diff --git a/internal/tracerconfig/auth.go b/internal/tracerconfig/auth.go new file mode 100644 index 0000000..eed2ffe --- /dev/null +++ b/internal/tracerconfig/auth.go @@ -0,0 +1,52 @@ +package tracerconfig + +import ( + "context" + "net/http" +) + +// bearerCredentials implements gRPC PerRPCCredentials for bearer token authentication. +// It is safe for concurrent use as required by the gRPC PerRPCCredentials interface. +type bearerCredentials struct { + tokenFunc BearerTokenFunc +} + +// GetRequestMetadata returns authorization metadata for each RPC call. +// It calls the token function to retrieve the current token. +func (c *bearerCredentials) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) { + token, err := c.tokenFunc(ctx) + if err != nil { + return nil, err + } + if token == "" { + return nil, nil // Omit header for empty token + } + return map[string]string{"authorization": "Bearer " + token}, nil +} + +// RequireTransportSecurity returns true because bearer tokens require TLS. +func (c *bearerCredentials) RequireTransportSecurity() bool { + return true +} + +// bearerRoundTripper wraps an http.RoundTripper to add bearer token authentication. +// It is safe for concurrent use as required by the http.RoundTripper interface. +type bearerRoundTripper struct { + base http.RoundTripper + tokenFunc BearerTokenFunc +} + +// RoundTrip adds the Authorization header with the bearer token. +func (rt *bearerRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + token, err := rt.tokenFunc(req.Context()) + if err != nil { + return nil, err + } + if token == "" { + return rt.base.RoundTrip(req) + } + // Clone only when adding a header to preserve the original request + req = req.Clone(req.Context()) + req.Header.Set("Authorization", "Bearer "+token) + return rt.base.RoundTrip(req) +} diff --git a/internal/tracerconfig/auth_test.go b/internal/tracerconfig/auth_test.go new file mode 100644 index 0000000..df03b70 --- /dev/null +++ b/internal/tracerconfig/auth_test.go @@ -0,0 +1,326 @@ +package tracerconfig + +import ( + "context" + "errors" + "io" + "net/http" + "net/http/httptest" + "testing" +) + +func TestBearerCredentials_GetRequestMetadata(t *testing.T) { + tests := []struct { + name string + tokenFunc BearerTokenFunc + wantMeta map[string]string + wantErr bool + }{ + { + name: "valid token", + tokenFunc: func(ctx context.Context) (string, error) { + return "test-token-123", nil + }, + wantMeta: map[string]string{"authorization": "Bearer test-token-123"}, + wantErr: false, + }, + { + name: "empty token returns nil map", + tokenFunc: func(ctx context.Context) (string, error) { + return "", nil + }, + wantMeta: nil, + wantErr: false, + }, + { + name: "token function error", + tokenFunc: func(ctx context.Context) (string, error) { + return "", errors.New("token retrieval failed") + }, + wantMeta: nil, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &bearerCredentials{tokenFunc: tt.tokenFunc} + meta, err := c.GetRequestMetadata(context.Background()) + + if (err != nil) != tt.wantErr { + t.Errorf("GetRequestMetadata() error = %v, wantErr %v", err, tt.wantErr) + return + } + + if tt.wantMeta == nil && meta != nil { + t.Errorf("GetRequestMetadata() = %v, want nil", meta) + return + } + + if tt.wantMeta != nil { + if meta == nil { + t.Errorf("GetRequestMetadata() = nil, want %v", tt.wantMeta) + return + } + for k, v := range tt.wantMeta { + if meta[k] != v { + t.Errorf("GetRequestMetadata()[%q] = %q, want %q", k, meta[k], v) + } + } + } + }) + } +} + +func TestBearerCredentials_GetRequestMetadata_ContextPassed(t *testing.T) { + type ctxKey string + key := ctxKey("test-key") + expectedValue := "test-value" + + var receivedCtx context.Context + c := &bearerCredentials{ + tokenFunc: func(ctx context.Context) (string, error) { + receivedCtx = ctx + return "token", nil + }, + } + + ctx := context.WithValue(context.Background(), key, expectedValue) + _, err := c.GetRequestMetadata(ctx) + if err != nil { + t.Fatalf("GetRequestMetadata() error = %v", err) + } + + if receivedCtx == nil { + t.Fatal("context was not passed to tokenFunc") + } + + if receivedCtx.Value(key) != expectedValue { + t.Errorf("context value = %v, want %v", receivedCtx.Value(key), expectedValue) + } +} + +func TestBearerCredentials_RequireTransportSecurity(t *testing.T) { + c := &bearerCredentials{tokenFunc: func(ctx context.Context) (string, error) { + return "", nil + }} + + if !c.RequireTransportSecurity() { + t.Error("RequireTransportSecurity() = false, want true") + } +} + +func TestBearerRoundTripper_RoundTrip(t *testing.T) { + tests := []struct { + name string + tokenFunc BearerTokenFunc + wantAuthHeader string + wantErr bool + serverShouldRun bool + }{ + { + name: "adds authorization header with valid token", + tokenFunc: func(ctx context.Context) (string, error) { + return "test-token-abc", nil + }, + wantAuthHeader: "Bearer test-token-abc", + wantErr: false, + serverShouldRun: true, + }, + { + name: "omits header for empty token", + tokenFunc: func(ctx context.Context) (string, error) { + return "", nil + }, + wantAuthHeader: "", + wantErr: false, + serverShouldRun: true, + }, + { + name: "propagates token function errors", + tokenFunc: func(ctx context.Context) (string, error) { + return "", errors.New("token error") + }, + wantAuthHeader: "", + wantErr: true, + serverShouldRun: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var receivedAuthHeader string + serverCalled := false + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + serverCalled = true + receivedAuthHeader = r.Header.Get("Authorization") + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + rt := &bearerRoundTripper{ + base: http.DefaultTransport, + tokenFunc: tt.tokenFunc, + } + + req, err := http.NewRequestWithContext(context.Background(), "GET", server.URL, nil) + if err != nil { + t.Fatalf("failed to create request: %v", err) + } + + resp, err := rt.RoundTrip(req) + + if (err != nil) != tt.wantErr { + t.Errorf("RoundTrip() error = %v, wantErr %v", err, tt.wantErr) + return + } + + if tt.serverShouldRun { + if !serverCalled { + t.Error("expected server to be called but it wasn't") + return + } + if resp != nil { + resp.Body.Close() + } + } + + if serverCalled && receivedAuthHeader != tt.wantAuthHeader { + t.Errorf("Authorization header = %q, want %q", receivedAuthHeader, tt.wantAuthHeader) + } + }) + } +} + +func TestBearerRoundTripper_PreservesOriginalRequest(t *testing.T) { + originalHeader := "original-value" + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + rt := &bearerRoundTripper{ + base: http.DefaultTransport, + tokenFunc: func(ctx context.Context) (string, error) { + return "new-token", nil + }, + } + + req, err := http.NewRequestWithContext(context.Background(), "GET", server.URL, nil) + if err != nil { + t.Fatalf("failed to create request: %v", err) + } + req.Header.Set("X-Custom", originalHeader) + + resp, err := rt.RoundTrip(req) + if err != nil { + t.Fatalf("RoundTrip() error = %v", err) + } + resp.Body.Close() + + // Original request should not be modified + if auth := req.Header.Get("Authorization"); auth != "" { + t.Errorf("original request Authorization header was modified to %q", auth) + } + + if custom := req.Header.Get("X-Custom"); custom != originalHeader { + t.Errorf("original request X-Custom header = %q, want %q", custom, originalHeader) + } +} + +func TestBearerRoundTripper_UsesRequestContext(t *testing.T) { + type ctxKey string + key := ctxKey("test-key") + expectedValue := "context-value" + + var receivedCtx context.Context + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + rt := &bearerRoundTripper{ + base: http.DefaultTransport, + tokenFunc: func(ctx context.Context) (string, error) { + receivedCtx = ctx + return "token", nil + }, + } + + ctx := context.WithValue(context.Background(), key, expectedValue) + req, err := http.NewRequestWithContext(ctx, "GET", server.URL, nil) + if err != nil { + t.Fatalf("failed to create request: %v", err) + } + + resp, err := rt.RoundTrip(req) + if err != nil { + t.Fatalf("RoundTrip() error = %v", err) + } + resp.Body.Close() + + if receivedCtx == nil { + t.Fatal("context was not passed to tokenFunc") + } + + if receivedCtx.Value(key) != expectedValue { + t.Errorf("context value = %v, want %v", receivedCtx.Value(key), expectedValue) + } +} + +func TestBearerRoundTripper_PreservesRequestBody(t *testing.T) { + expectedBody := "request body content" + var receivedBody string + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + receivedBody = string(body) + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + rt := &bearerRoundTripper{ + base: http.DefaultTransport, + tokenFunc: func(ctx context.Context) (string, error) { + return "token", nil + }, + } + + req, err := http.NewRequestWithContext(context.Background(), "POST", server.URL, + io.NopCloser(newStringReader(expectedBody))) + if err != nil { + t.Fatalf("failed to create request: %v", err) + } + req.ContentLength = int64(len(expectedBody)) + + resp, err := rt.RoundTrip(req) + if err != nil { + t.Fatalf("RoundTrip() error = %v", err) + } + resp.Body.Close() + + if receivedBody != expectedBody { + t.Errorf("body = %q, want %q", receivedBody, expectedBody) + } +} + +// stringReader is an io.Reader that reads from a string exactly once +type stringReader struct { + s string + done bool +} + +func newStringReader(s string) *stringReader { + return &stringReader{s: s} +} + +func (r *stringReader) Read(p []byte) (n int, err error) { + if r.done { + return 0, io.EOF + } + n = copy(p, r.s) + r.done = true + return n, io.EOF +} diff --git a/internal/tracerconfig/config.go b/internal/tracerconfig/config.go index 4318f7d..977b28f 100644 --- a/internal/tracerconfig/config.go +++ b/internal/tracerconfig/config.go @@ -25,6 +25,7 @@ import ( sdklog "go.opentelemetry.io/otel/sdk/log" sdkmetric "go.opentelemetry.io/otel/sdk/metric" sdktrace "go.opentelemetry.io/otel/sdk/trace" + "google.golang.org/grpc" "google.golang.org/grpc/credentials" ) @@ -109,6 +110,22 @@ func ValidateAndStore(ctx context.Context, cfg *Config, logFactory LogExporterFa // client certificate authentication. type GetClientCertificate func(*tls.CertificateRequestInfo) (*tls.Certificate, error) +// BearerTokenFunc retrieves a bearer token for OTLP authentication. +// It is called for each export request (traces, logs, metrics). +// The caller is responsible for caching and token renewal. +// Returns the token string (without "Bearer " prefix) or an error. +// An empty string with no error means skip the Authorization header. +// +// Thread safety: This function may be called concurrently from multiple +// goroutines. Implementations must be safe for concurrent use. +// +// Protocol support: Bearer authentication is fully supported for gRPC exporters +// via PerRPCCredentials. HTTP exporters do not currently support dynamic bearer +// tokens due to OpenTelemetry SDK limitations (no WithHTTPClient option). +// For HTTP, use OTEL_EXPORTER_OTLP_HEADERS environment variable for static tokens, +// or switch to gRPC protocol (OTEL_EXPORTER_OTLP_PROTOCOL=grpc) for dynamic tokens. +type BearerTokenFunc func(ctx context.Context) (string, error) + // Config provides configuration options for OpenTelemetry tracing setup. // It supplements standard OpenTelemetry environment variables with additional // NTP Pool-specific configuration including TLS settings for secure OTLP export. @@ -119,6 +136,7 @@ type Config struct { EndpointURL string // Complete OTLP endpoint URL (e.g., "https://otlp.example.com:4317/v1/traces") CertificateProvider GetClientCertificate // Client certificate provider for mutual TLS RootCAs *x509.CertPool // CA certificate pool for server verification + BearerTokenFunc BearerTokenFunc // Token provider for bearer authentication } // LogExporterFactory creates an OTLP log exporter using the provided configuration. @@ -242,6 +260,10 @@ func CreateOTLPLogExporter(ctx context.Context, cfg *Config) (sdklog.Exporter, e if tlsConfig != nil { opts = append(opts, otlploggrpc.WithTLSCredentials(credentials.NewTLS(tlsConfig))) } + if cfg.BearerTokenFunc != nil { + creds := &bearerCredentials{tokenFunc: cfg.BearerTokenFunc} + opts = append(opts, otlploggrpc.WithDialOption(grpc.WithPerRPCCredentials(creds))) + } if len(cfg.Endpoint) > 0 { opts = append(opts, otlploggrpc.WithEndpoint(cfg.Endpoint)) } @@ -290,6 +312,10 @@ func CreateOTLPMetricExporter(ctx context.Context, cfg *Config) (sdkmetric.Expor if tlsConfig != nil { opts = append(opts, otlpmetricgrpc.WithTLSCredentials(credentials.NewTLS(tlsConfig))) } + if cfg.BearerTokenFunc != nil { + creds := &bearerCredentials{tokenFunc: cfg.BearerTokenFunc} + opts = append(opts, otlpmetricgrpc.WithDialOption(grpc.WithPerRPCCredentials(creds))) + } if len(cfg.Endpoint) > 0 { opts = append(opts, otlpmetricgrpc.WithEndpoint(cfg.Endpoint)) } @@ -340,6 +366,10 @@ func CreateOTLPTraceExporter(ctx context.Context, cfg *Config) (sdktrace.SpanExp if tlsConfig != nil { opts = append(opts, otlptracegrpc.WithTLSCredentials(credentials.NewTLS(tlsConfig))) } + if cfg.BearerTokenFunc != nil { + creds := &bearerCredentials{tokenFunc: cfg.BearerTokenFunc} + opts = append(opts, otlptracegrpc.WithDialOption(grpc.WithPerRPCCredentials(creds))) + } if len(cfg.Endpoint) > 0 { opts = append(opts, otlptracegrpc.WithEndpoint(cfg.Endpoint)) } diff --git a/logger/buffering_exporter.go b/logger/buffering_exporter.go index 3774f54..51e839e 100644 --- a/logger/buffering_exporter.go +++ b/logger/buffering_exporter.go @@ -23,6 +23,9 @@ type bufferingExporter struct { // Real exporter (created when tracing is configured) exporter otellog.Exporter + // Track whether buffer has been flushed (separate from exporter creation) + bufferFlushed bool + // Thread-safe initialization state (managed only by checkReadiness) initErr error @@ -71,22 +74,63 @@ func (e *bufferingExporter) initialize() error { initCtx, cancel := context.WithTimeout(ctx, 10*time.Second) defer cancel() - exporter, err := factory(initCtx, cfg) - if err != nil { - return fmt.Errorf("failed to create OTLP exporter: %w", err) + e.mu.RLock() + hasExporter := e.exporter != nil + e.mu.RUnlock() + + // Create exporter if not already created + if !hasExporter { + exporter, err := factory(initCtx, cfg) + if err != nil { + return fmt.Errorf("failed to create OTLP exporter: %w", err) + } + e.mu.Lock() + // Double-check: another goroutine may have created it while we were waiting + if e.exporter == nil { + e.exporter = exporter + } else { + // Another goroutine beat us, close the one we created + _ = exporter.Shutdown(context.Background()) + } + e.mu.Unlock() + } + + // Check if we can flush (token verification if configured) + if !e.canFlush(initCtx, cfg, false) { + return errors.New("waiting for token authentication") } e.mu.Lock() - e.exporter = exporter - flushErr := e.flushBuffer(initCtx) + if !e.bufferFlushed { + flushErr := e.flushBuffer(initCtx) + if flushErr != nil { + e.mu.Unlock() + // Log but don't fail initialization + Setup().Warn("buffer flush failed during initialization", "error", flushErr) + return nil + } + e.bufferFlushed = true + } e.mu.Unlock() - if flushErr != nil { - // Log but don't fail initialization - Setup().Warn("buffer flush failed during initialization", "error", flushErr) + return nil +} + +// canFlush checks if we're ready to flush buffered logs. +// If BearerTokenFunc is configured, it must return without error. +// If forceFlush is true (during shutdown with cancelled context), skip token check. +func (e *bufferingExporter) canFlush(ctx context.Context, cfg *tracerconfig.Config, forceFlush bool) bool { + if cfg.BearerTokenFunc == nil { + return true // No token auth configured, can flush immediately } - return nil + if forceFlush { + return true // During shutdown, proceed with best-effort flush + } + + // Check if token is available (call returns without error) + _, err := cfg.BearerTokenFunc(ctx) + return err == nil } // bufferRecords adds records to the buffer for later processing @@ -119,16 +163,16 @@ func (e *bufferingExporter) checkReadiness() { for { select { case <-ticker.C: - // Check if we already have a working exporter + // Check if we're fully ready (exporter created AND buffer flushed) e.mu.RLock() - hasExporter := e.exporter != nil + fullyReady := e.exporter != nil && e.bufferFlushed e.mu.RUnlock() - if hasExporter { - return // Exporter ready, checker no longer needed + if fullyReady { + return // Fully initialized, checker no longer needed } - // Try to initialize + // Try to initialize (creates exporter and flushes if token ready) err := e.initialize() e.mu.Lock() e.initErr = err @@ -182,18 +226,42 @@ func (e *bufferingExporter) Shutdown(ctx context.Context) error { // Wait for readiness checker goroutine to complete <-e.checkerDone + cfg, _, _ := tracerconfig.Get() + + // Check if context is cancelled for best-effort flush + forceFlush := ctx.Err() != nil + // Give one final chance for TLS/tracing to become ready for buffer flushing e.mu.RLock() hasExporter := e.exporter != nil + bufferFlushed := e.bufferFlushed e.mu.RUnlock() if !hasExporter { err := e.initialize() e.mu.Lock() e.initErr = err + hasExporter = e.exporter != nil + bufferFlushed = e.bufferFlushed e.mu.Unlock() } + // If exporter exists but buffer not flushed, try to flush now + if hasExporter && !bufferFlushed { + canFlushNow := cfg == nil || e.canFlush(ctx, cfg, forceFlush) + if canFlushNow { + e.mu.Lock() + if !e.bufferFlushed { + flushErr := e.flushBuffer(ctx) + if flushErr != nil { + Setup().Warn("buffer flush failed during shutdown", "error", flushErr) + } + e.bufferFlushed = true + } + e.mu.Unlock() + } + } + e.mu.Lock() defer e.mu.Unlock() diff --git a/tracing/tracing.go b/tracing/tracing.go index acf987d..c055eca 100644 --- a/tracing/tracing.go +++ b/tracing/tracing.go @@ -112,6 +112,10 @@ func Start(ctx context.Context, spanName string, opts ...trace.SpanStartOption) // This maintains backward compatibility for existing code. type GetClientCertificate = tracerconfig.GetClientCertificate +// BearerTokenFunc is an alias for the type defined in tracerconfig. +// It retrieves a bearer token for OTLP authentication. +type BearerTokenFunc = tracerconfig.BearerTokenFunc + // TracerConfig provides configuration options for OpenTelemetry tracing setup. // It supplements standard OpenTelemetry environment variables with additional // NTP Pool-specific configuration including TLS settings for secure OTLP export. @@ -123,6 +127,7 @@ type TracerConfig struct { CertificateProvider GetClientCertificate // Client certificate provider for mutual TLS RootCAs *x509.CertPool // CA certificate pool for server verification + BearerTokenFunc BearerTokenFunc // Token provider for bearer authentication } // InitTracer initializes the OpenTelemetry SDK with the provided configuration. @@ -160,6 +165,7 @@ func SetupSDK(ctx context.Context, cfg *TracerConfig) (shutdown TpShutdownFunc, EndpointURL: cfg.EndpointURL, CertificateProvider: cfg.CertificateProvider, RootCAs: cfg.RootCAs, + BearerTokenFunc: cfg.BearerTokenFunc, } tracerconfig.Store(ctx, bridgeConfig, createOTLPLogExporter, createOTLPMetricExporter, createOTLPTraceExporter)