package tracerconfig import ( "context" "errors" "io" "net/http" "net/http/httptest" "testing" ) func TestBearerCredentials_GetRequestMetadata(t *testing.T) { tests := []struct { name string tokenFunc BearerTokenFunc wantMeta map[string]string wantErr bool }{ { name: "valid token", tokenFunc: func(ctx context.Context) (string, error) { return "test-token-123", nil }, wantMeta: map[string]string{"authorization": "Bearer test-token-123"}, wantErr: false, }, { name: "empty token returns nil map", tokenFunc: func(ctx context.Context) (string, error) { return "", nil }, wantMeta: nil, wantErr: false, }, { name: "token function error", tokenFunc: func(ctx context.Context) (string, error) { return "", errors.New("token retrieval failed") }, wantMeta: nil, wantErr: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c := &bearerCredentials{tokenFunc: tt.tokenFunc} meta, err := c.GetRequestMetadata(context.Background()) if (err != nil) != tt.wantErr { t.Errorf("GetRequestMetadata() error = %v, wantErr %v", err, tt.wantErr) return } if tt.wantMeta == nil && meta != nil { t.Errorf("GetRequestMetadata() = %v, want nil", meta) return } if tt.wantMeta != nil { if meta == nil { t.Errorf("GetRequestMetadata() = nil, want %v", tt.wantMeta) return } for k, v := range tt.wantMeta { if meta[k] != v { t.Errorf("GetRequestMetadata()[%q] = %q, want %q", k, meta[k], v) } } } }) } } func TestBearerCredentials_GetRequestMetadata_ContextPassed(t *testing.T) { type ctxKey string key := ctxKey("test-key") expectedValue := "test-value" var receivedCtx context.Context c := &bearerCredentials{ tokenFunc: func(ctx context.Context) (string, error) { receivedCtx = ctx return "token", nil }, } ctx := context.WithValue(context.Background(), key, expectedValue) _, err := c.GetRequestMetadata(ctx) if err != nil { t.Fatalf("GetRequestMetadata() error = %v", err) } if receivedCtx == nil { t.Fatal("context was not passed to tokenFunc") } if receivedCtx.Value(key) != expectedValue { t.Errorf("context value = %v, want %v", receivedCtx.Value(key), expectedValue) } } func TestBearerCredentials_RequireTransportSecurity(t *testing.T) { c := &bearerCredentials{tokenFunc: func(ctx context.Context) (string, error) { return "", nil }} if !c.RequireTransportSecurity() { t.Error("RequireTransportSecurity() = false, want true") } } func TestBearerRoundTripper_RoundTrip(t *testing.T) { tests := []struct { name string tokenFunc BearerTokenFunc wantAuthHeader string wantErr bool serverShouldRun bool }{ { name: "adds authorization header with valid token", tokenFunc: func(ctx context.Context) (string, error) { return "test-token-abc", nil }, wantAuthHeader: "Bearer test-token-abc", wantErr: false, serverShouldRun: true, }, { name: "omits header for empty token", tokenFunc: func(ctx context.Context) (string, error) { return "", nil }, wantAuthHeader: "", wantErr: false, serverShouldRun: true, }, { name: "propagates token function errors", tokenFunc: func(ctx context.Context) (string, error) { return "", errors.New("token error") }, wantAuthHeader: "", wantErr: true, serverShouldRun: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { var receivedAuthHeader string serverCalled := false server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { serverCalled = true receivedAuthHeader = r.Header.Get("Authorization") w.WriteHeader(http.StatusOK) })) defer server.Close() rt := &bearerRoundTripper{ base: http.DefaultTransport, tokenFunc: tt.tokenFunc, } req, err := http.NewRequestWithContext(context.Background(), "GET", server.URL, nil) if err != nil { t.Fatalf("failed to create request: %v", err) } resp, err := rt.RoundTrip(req) if (err != nil) != tt.wantErr { t.Errorf("RoundTrip() error = %v, wantErr %v", err, tt.wantErr) return } if tt.serverShouldRun { if !serverCalled { t.Error("expected server to be called but it wasn't") return } if resp != nil { resp.Body.Close() } } if serverCalled && receivedAuthHeader != tt.wantAuthHeader { t.Errorf("Authorization header = %q, want %q", receivedAuthHeader, tt.wantAuthHeader) } }) } } func TestBearerRoundTripper_PreservesOriginalRequest(t *testing.T) { originalHeader := "original-value" server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) })) defer server.Close() rt := &bearerRoundTripper{ base: http.DefaultTransport, tokenFunc: func(ctx context.Context) (string, error) { return "new-token", nil }, } req, err := http.NewRequestWithContext(context.Background(), "GET", server.URL, nil) if err != nil { t.Fatalf("failed to create request: %v", err) } req.Header.Set("X-Custom", originalHeader) resp, err := rt.RoundTrip(req) if err != nil { t.Fatalf("RoundTrip() error = %v", err) } resp.Body.Close() // Original request should not be modified if auth := req.Header.Get("Authorization"); auth != "" { t.Errorf("original request Authorization header was modified to %q", auth) } if custom := req.Header.Get("X-Custom"); custom != originalHeader { t.Errorf("original request X-Custom header = %q, want %q", custom, originalHeader) } } func TestBearerRoundTripper_UsesRequestContext(t *testing.T) { type ctxKey string key := ctxKey("test-key") expectedValue := "context-value" var receivedCtx context.Context server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) })) defer server.Close() rt := &bearerRoundTripper{ base: http.DefaultTransport, tokenFunc: func(ctx context.Context) (string, error) { receivedCtx = ctx return "token", nil }, } ctx := context.WithValue(context.Background(), key, expectedValue) req, err := http.NewRequestWithContext(ctx, "GET", server.URL, nil) if err != nil { t.Fatalf("failed to create request: %v", err) } resp, err := rt.RoundTrip(req) if err != nil { t.Fatalf("RoundTrip() error = %v", err) } resp.Body.Close() if receivedCtx == nil { t.Fatal("context was not passed to tokenFunc") } if receivedCtx.Value(key) != expectedValue { t.Errorf("context value = %v, want %v", receivedCtx.Value(key), expectedValue) } } func TestBearerRoundTripper_PreservesRequestBody(t *testing.T) { expectedBody := "request body content" var receivedBody string server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { body, _ := io.ReadAll(r.Body) receivedBody = string(body) w.WriteHeader(http.StatusOK) })) defer server.Close() rt := &bearerRoundTripper{ base: http.DefaultTransport, tokenFunc: func(ctx context.Context) (string, error) { return "token", nil }, } req, err := http.NewRequestWithContext(context.Background(), "POST", server.URL, io.NopCloser(newStringReader(expectedBody))) if err != nil { t.Fatalf("failed to create request: %v", err) } req.ContentLength = int64(len(expectedBody)) resp, err := rt.RoundTrip(req) if err != nil { t.Fatalf("RoundTrip() error = %v", err) } resp.Body.Close() if receivedBody != expectedBody { t.Errorf("body = %q, want %q", receivedBody, expectedBody) } } // stringReader is an io.Reader that reads from a string exactly once type stringReader struct { s string done bool } func newStringReader(s string) *stringReader { return &stringReader{s: s} } func (r *stringReader) Read(p []byte) (n int, err error) { if r.done { return 0, io.EOF } n = copy(p, r.s) r.done = true return n, io.EOF }