package tracing

// TODO: review this package against:
// https://github.com/ttys3/tracing-go/blob/main/tracing.go#L136

import (
	"context"
	"crypto/tls"
	"crypto/x509"
	"errors"
	"os"
	"slices"
	"time"

	"go.ntppool.org/common/logger"
	"go.ntppool.org/common/version"
	"google.golang.org/grpc/credentials"

	"go.opentelemetry.io/contrib/exporters/autoexport"
	"go.opentelemetry.io/otel"
	"go.opentelemetry.io/otel/attribute"
	"go.opentelemetry.io/otel/exporters/otlp/otlptrace"
	"go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc"
	"go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp"
	logglobal "go.opentelemetry.io/otel/log/global"
	"go.opentelemetry.io/otel/propagation"
	sdklog "go.opentelemetry.io/otel/sdk/log"
	"go.opentelemetry.io/otel/sdk/resource"
	sdktrace "go.opentelemetry.io/otel/sdk/trace"
	semconv "go.opentelemetry.io/otel/semconv/v1.26.0"
	"go.opentelemetry.io/otel/trace"
)

// svcNameKey is the environment variable the service name is read from
// (SetupSDK populates it from TracerConfig.ServiceName when it is unset).
const svcNameKey = "OTEL_SERVICE_NAME"

// Environment variables selecting the OTLP transport protocol. The
// traces-specific variable is consulted before the generic one.
const (
	otelExporterOTLPTracesProtoEnvKey = "OTEL_EXPORTER_OTLP_TRACES_PROTOCOL"
	otelExporterOTLPProtoEnvKey       = "OTEL_EXPORTER_OTLP_PROTOCOL"
)

// errInvalidOTLPProtocol is returned when the configured OTLP protocol is not recognized.
var errInvalidOTLPProtocol = errors.New("invalid OTLP protocol - should be one of ['grpc', 'http/protobuf']")

// Reference usage example for the OTLP HTTP trace exporter:
// https://github.com/open-telemetry/opentelemetry-go/blob/main/exporters/otlp/otlptrace/otlptracehttp/example_test.go

// TpShutdownFunc tears down the telemetry pipeline set up by SetupSDK /
// InitTracer and reports any errors encountered while doing so.
type TpShutdownFunc func(context.Context) error

// Tracer returns the "ntppool-tracer" tracer from the globally registered
// OpenTelemetry tracer provider. The global provider is looked up on every
// call rather than cached.
func Tracer() trace.Tracer {
	return otel.GetTracerProvider().Tracer("ntppool-tracer")
}

// Start begins a span named spanName on the package tracer (see Tracer) and
// returns the derived context together with the new span. The span is not
// ended here; callers must call span.End().
func Start(ctx context.Context, spanName string, opts ...trace.SpanStartOption) (context.Context, trace.Span) {
	tracer := Tracer()
	return tracer.Start(ctx, spanName, opts...)
}

// GetClientCertificate has the same signature as tls.Config.GetClientCertificate
// and is used to supply a client certificate to the OTLP exporter's TLS config.
type GetClientCertificate func(*tls.CertificateRequestInfo) (*tls.Certificate, error)

// TracerConfig holds optional settings for SetupSDK / InitTracer. Zero-valued
// fields leave the corresponding defaults in place.
type TracerConfig struct {
	// ServiceName is exported as OTEL_SERVICE_NAME when that variable is unset.
	ServiceName string
	// Environment, when non-empty, is added as the "environment" resource attribute.
	Environment string
	// Endpoint, when non-empty, is handed to the built-in OTLP span exporter
	// (not used when autoexport selects the exporter).
	Endpoint string
	// EndpointURL, when non-empty, is handed to the built-in OTLP span exporter
	// as a full URL (not used when autoexport selects the exporter).
	EndpointURL string

	// CertificateProvider, when set, is installed in the exporter's custom TLS config.
	CertificateProvider GetClientCertificate
	// RootCAs is the CA pool placed in the exporter's custom TLS config.
	RootCAs *x509.CertPool
}

// InitTracer sets up the OpenTelemetry SDK from cfg. It currently delegates
// directly to SetupSDK and returns its results unchanged.
func InitTracer(ctx context.Context, cfg *TracerConfig) (TpShutdownFunc, error) {
	// TODO: apply environment settings from cfg before delegating.
	shutdown, err := SetupSDK(ctx, cfg)
	return shutdown, err
}

// SetupSDK installs the global OpenTelemetry text-map propagator, tracer
// provider and logger provider, and returns a function that tears them down.
//
// A nil cfg is treated as an empty TracerConfig. When OTEL_SERVICE_NAME is
// unset and cfg.ServiceName is non-empty, the environment variable is set so
// resource.WithFromEnv picks the name up (this mutates the process
// environment).
//
// Spans go to the package's OTLP exporter when OTEL_TRACES_EXPORTER is unset
// or "otlp", and to whatever autoexport selects otherwise. The log exporter
// always comes from autoexport.
//
// The returned shutdown is never nil. It runs the registered cleanup steps in
// reverse creation order (providers before the exporters they feed) and joins
// their errors. Subsequent calls are no-ops. If setup fails part way, the
// components created so far are shut down before returning, and their errors
// are joined into err.
func SetupSDK(ctx context.Context, cfg *TracerConfig) (shutdown TpShutdownFunc, err error) {
	if cfg == nil {
		cfg = &TracerConfig{}
	}

	log := logger.Setup()

	// resource.WithFromEnv reads OTEL_SERVICE_NAME, so publish the configured
	// name there unless the environment already provides one.
	if os.Getenv(svcNameKey) == "" && cfg.ServiceName != "" {
		if setErr := os.Setenv(svcNameKey, cfg.ServiceName); setErr != nil {
			log.Warn("otel service name setup", "err", setErr)
		}
	}

	// shutdownFuncs collects cleanup steps in creation order. It is defined
	// before anything can fail, so every return path hands back a usable
	// shutdown.
	var shutdownFuncs []func(context.Context) error
	shutdown = func(ctx context.Context) error {
		var errs error
		// Providers must be shut down before their exporters, which is the
		// opposite of the order they were set up in.
		slices.Reverse(shutdownFuncs)
		for _, fn := range shutdownFuncs {
			errs = errors.Join(errs, fn(ctx))
		}
		shutdownFuncs = nil // repeated calls become no-ops
		if errs != nil {
			log.Warn("shutdown returned errors", "err", errs)
		}
		return errs
	}

	// fail cleans up whatever has been set up so far and returns inErr joined
	// with any errors from that cleanup.
	fail := func(inErr error) (TpShutdownFunc, error) {
		return shutdown, errors.Join(inErr, shutdown(ctx))
	}

	resources := []resource.Option{
		resource.WithFromEnv(),      // OTEL_RESOURCE_ATTRIBUTES and OTEL_SERVICE_NAME (service name set above).
		resource.WithTelemetrySDK(), // Information about the OpenTelemetry SDK in use.
		resource.WithProcess(),      // Process information.
		resource.WithOS(),           // OS information.
		resource.WithContainer(),    // Container information.
		resource.WithHost(),         // Host information.
		resource.WithAttributes(semconv.ServiceVersionKey.String(version.Version())),
	}
	if cfg.Environment != "" {
		resources = append(resources,
			resource.WithAttributes(attribute.String("environment", cfg.Environment)),
		)
	}

	res, resErr := resource.New(ctx, resources...)
	switch {
	case errors.Is(resErr, resource.ErrPartialResource), errors.Is(resErr, resource.ErrSchemaURLConflict):
		// Treated as non-fatal: setup continues with the returned resource.
		log.Warn("otel resource setup", "err", resErr)
	case resErr != nil:
		log.Error("otel resource setup", "err", resErr)
		return shutdown, resErr
	}

	otel.SetTextMapPropagator(newPropagator())

	var spanExporter sdktrace.SpanExporter
	switch os.Getenv("OTEL_TRACES_EXPORTER") {
	case "", "otlp":
		spanExporter, err = newOLTPExporter(ctx, cfg)
	default:
		spanExporter, err = autoexport.NewSpanExporter(ctx)
	}
	if err != nil {
		return fail(err)
	}
	shutdownFuncs = append(shutdownFuncs, spanExporter.Shutdown)

	logExporter, err := autoexport.NewLogExporter(ctx)
	if err != nil {
		return fail(err)
	}
	shutdownFuncs = append(shutdownFuncs, logExporter.Shutdown)

	tracerProvider, err := newTraceProvider(spanExporter, res)
	if err != nil {
		return fail(err)
	}
	shutdownFuncs = append(shutdownFuncs, tracerProvider.Shutdown)
	otel.SetTracerProvider(tracerProvider)

	logProvider := sdklog.NewLoggerProvider(
		sdklog.WithResource(res),
		sdklog.WithProcessor(
			sdklog.NewBatchProcessor(logExporter, sdklog.WithExportBufferSize(10)),
		),
	)
	logglobal.SetLoggerProvider(logProvider)
	shutdownFuncs = append(shutdownFuncs, func(ctx context.Context) error {
		// Flush explicitly before shutting down, and report errors from both steps.
		return errors.Join(logProvider.ForceFlush(ctx), logProvider.Shutdown(ctx))
	})

	return shutdown, nil
}

// newOLTPExporter builds and starts an OTLP span exporter.
//
// The protocol is taken from OTEL_EXPORTER_OTLP_TRACES_PROTOCOL, then
// OTEL_EXPORTER_OTLP_PROTOCOL, and defaults to "http/protobuf". Both
// transports use gzip compression. cfg.Endpoint and cfg.EndpointURL are
// passed to the client when non-empty. A custom TLS config is used when cfg
// supplies a client-certificate provider and/or root CAs. A nil cfg is
// treated as empty.
//
// It returns errInvalidOTLPProtocol for an unrecognized protocol, and a
// nil exporter together with the error if the exporter fails to start.
func newOLTPExporter(ctx context.Context, cfg *TracerConfig) (sdktrace.SpanExporter, error) {
	log := logger.Setup()

	if cfg == nil {
		cfg = &TracerConfig{}
	}

	// Either setting alone needs a custom TLS config. Custom root CAs must be
	// honoured even when no client certificate is presented.
	var tlsConfig *tls.Config
	if cfg.CertificateProvider != nil || cfg.RootCAs != nil {
		tlsConfig = &tls.Config{
			GetClientCertificate: cfg.CertificateProvider,
			RootCAs:              cfg.RootCAs,
		}
	}

	proto := os.Getenv(otelExporterOTLPTracesProtoEnvKey)
	if proto == "" {
		proto = os.Getenv(otelExporterOTLPProtoEnvKey)
	}
	if proto == "" {
		proto = "http/protobuf" // default when neither variable is set
	}

	var client otlptrace.Client

	switch proto {
	case "grpc":
		opts := []otlptracegrpc.Option{
			otlptracegrpc.WithCompressor("gzip"),
		}
		if tlsConfig != nil {
			opts = append(opts, otlptracegrpc.WithTLSCredentials(credentials.NewTLS(tlsConfig)))
		}
		if cfg.Endpoint != "" {
			log.Info("adding option", "Endpoint", cfg.Endpoint)
			opts = append(opts, otlptracegrpc.WithEndpoint(cfg.Endpoint))
		}
		if cfg.EndpointURL != "" {
			log.Info("adding option", "EndpointURL", cfg.EndpointURL)
			opts = append(opts, otlptracegrpc.WithEndpointURL(cfg.EndpointURL))
		}

		client = otlptracegrpc.NewClient(opts...)
	case "http/protobuf", "http/json":
		if proto == "http/json" {
			// otlptracehttp only sends protobuf payloads. The value is still
			// accepted for compatibility, but the mismatch is made visible.
			log.Warn("OTLP protocol http/json is not supported by the trace exporter, using http/protobuf", "protocol", proto)
		}
		opts := []otlptracehttp.Option{
			otlptracehttp.WithCompression(otlptracehttp.GzipCompression),
		}
		if tlsConfig != nil {
			opts = append(opts, otlptracehttp.WithTLSClientConfig(tlsConfig))
		}
		if cfg.Endpoint != "" {
			log.Info("adding option", "Endpoint", cfg.Endpoint)
			opts = append(opts, otlptracehttp.WithEndpoint(cfg.Endpoint))
		}
		if cfg.EndpointURL != "" {
			log.Info("adding option", "EndpointURL", cfg.EndpointURL)
			opts = append(opts, otlptracehttp.WithEndpointURL(cfg.EndpointURL))
		}

		client = otlptracehttp.NewClient(opts...)
	default:
		return nil, errInvalidOTLPProtocol
	}

	exporter, err := otlptrace.New(ctx, client)
	if err != nil {
		log.ErrorContext(ctx, "creating OTLP trace exporter", "err", err)
		// Return an untyped nil so callers never see a non-nil interface
		// wrapping a nil *otlptrace.Exporter.
		return nil, err
	}
	return exporter, nil
}

// newTraceProvider creates an SDK tracer provider that attaches res to every
// span and sends spans to traceExporter through a batching span processor
// with a 3-second batch timeout. The error result is always nil.
func newTraceProvider(traceExporter sdktrace.SpanExporter, res *resource.Resource) (*sdktrace.TracerProvider, error) {
	const batchTimeout = 3 * time.Second

	batcher := sdktrace.WithBatcher(traceExporter, sdktrace.WithBatchTimeout(batchTimeout))
	return sdktrace.NewTracerProvider(sdktrace.WithResource(res), batcher), nil
}

// newPropagator returns a composite text-map propagator that handles both
// trace-context and baggage headers, in that order.
func newPropagator() propagation.TextMapPropagator {
	propagators := []propagation.TextMapPropagator{
		propagation.TraceContext{},
		propagation.Baggage{},
	}
	return propagation.NewCompositeTextMapPropagator(propagators...)
}