package metricsserver

import (
	"context"
	"fmt"
	"io"
	"net/http"
	"net/http/httptest"
	"strings"
	"testing"
	"time"

	"github.com/prometheus/client_golang/prometheus"
)

// Fixed ports used by the ListenAndServe tests. They are untyped constants so
// they convert to whatever integer type ListenAndServe declares for its port.
// NOTE(review): fixed ports can still collide with other processes on shared
// CI hosts; waitForServer at least surfaces an early server exit clearly.
const (
	testListenPort = 9999
	testCancelPort = 9998
)

// serverTestTimeout bounds how long the tests wait for the server to start
// accepting requests or to shut down after cancellation.
const serverTestTimeout = 5 * time.Second

// hasMetricFamily gathers from g and reports whether a metric family named
// name is present. A Gather error fails the test immediately, because
// continuing would misreport it as a missing metric.
func hasMetricFamily(t *testing.T, g prometheus.Gatherer, name string) bool {
	t.Helper()
	families, err := g.Gather()
	if err != nil {
		t.Fatalf("failed to gather metrics: %v", err)
	}
	for _, mf := range families {
		if mf.GetName() == name {
			return true
		}
	}
	return false
}

// registerRequestsCounter registers a "test_requests_total" counter vector
// (label "method") with reg and increments its GET series once, so the family
// has a child and shows up in scrapes.
func registerRequestsCounter(reg prometheus.Registerer) {
	counter := prometheus.NewCounterVec(
		prometheus.CounterOpts{
			Name: "test_requests_total",
			Help: "Total number of test requests",
		},
		[]string{"method"},
	)
	reg.MustRegister(counter)
	counter.WithLabelValues("GET").Inc()
}

// waitForServer polls url with client until a request succeeds and returns
// that response; the caller must close its body. It replaces a fixed sleep.
// The test fails if the server goroutine reports on errCh first (for example
// if it could not bind the port) or if serverTestTimeout elapses.
func waitForServer(t *testing.T, client *http.Client, url string, errCh <-chan error) *http.Response {
	t.Helper()
	deadline := time.Now().Add(serverTestTimeout)
	for {
		select {
		case err := <-errCh:
			t.Fatalf("server exited before accepting requests: %v", err)
		default:
		}
		resp, err := client.Get(url)
		if err == nil {
			return resp
		}
		if time.Now().After(deadline) {
			t.Fatalf("server at %s not ready within %v: %v", url, serverTestTimeout, err)
		}
		time.Sleep(10 * time.Millisecond)
	}
}

// TestNew checks that New returns a value with a non-nil registry.
func TestNew(t *testing.T) {
	m := New()
	if m == nil {
		t.Fatal("New returned nil")
	}
	if m.r == nil {
		t.Error("metrics registry is nil")
	}
}

// TestRegistry checks that Registry exposes the internal registry and that
// metrics registered through it can be gathered.
func TestRegistry(t *testing.T) {
	m := New()

	registry := m.Registry()
	if registry == nil {
		t.Fatal("Registry() returned nil")
	}
	if registry != m.r {
		t.Error("Registry() did not return the metrics registry")
	}

	counter := prometheus.NewCounter(prometheus.CounterOpts{
		Name: "test_counter",
		Help: "A test counter",
	})
	if err := registry.Register(counter); err != nil {
		t.Fatalf("failed to register metric: %v", err)
	}

	if !hasMetricFamily(t, registry, "test_counter") {
		t.Error("registered metric not found in registry")
	}
}

// TestGatherer checks that, for a registry created by New, Gatherer returns
// the custom registry and sees metrics registered through Registry.
func TestGatherer(t *testing.T) {
	m := New()

	gatherer := m.Gatherer()
	if gatherer == nil {
		t.Fatal("Gatherer() returned nil")
	}

	counter := prometheus.NewCounter(prometheus.CounterOpts{
		Name: "test_gatherer_counter",
		Help: "A test counter for gatherer",
	})
	m.Registry().MustRegister(counter)
	counter.Inc()

	if !hasMetricFamily(t, gatherer, "test_gatherer_counter") {
		t.Error("registered metric not found via Gatherer()")
	}

	if gatherer != m.r {
		t.Error("Gatherer() should return the same object as the registry for custom registry mode")
	}
}

// TestHandler scrapes Handler in-process and checks for the registered
// metric and exposition-format TYPE comments.
func TestHandler(t *testing.T) {
	m := New()
	registerRequestsCounter(m.Registry())

	handler := m.Handler()
	if handler == nil {
		t.Fatal("Handler() returned nil")
	}

	req := httptest.NewRequest(http.MethodGet, "/metrics", nil)
	recorder := httptest.NewRecorder()
	handler.ServeHTTP(recorder, req)

	resp := recorder.Result()
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		t.Errorf("expected status 200, got %d", resp.StatusCode)
	}

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		t.Fatalf("failed to read response body: %v", err)
	}
	bodyStr := string(body)

	if !strings.Contains(bodyStr, "test_requests_total") {
		t.Error("test metric not found in metrics output")
	}
	// "# TYPE" lines appear in both the Prometheus text format and OpenMetrics.
	if !strings.Contains(bodyStr, "# TYPE") {
		t.Error("metrics output missing TYPE comments")
	}
}

// TestListenAndServe starts the real server, scrapes /metrics over HTTP, and
// then checks that cancelling the context makes ListenAndServe return nil.
func TestListenAndServe(t *testing.T) {
	m := New()
	registerRequestsCounter(m.Registry())

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// Buffered so the server goroutine can always deliver its result and exit,
	// even if the test has already failed and stopped receiving.
	errCh := make(chan error, 1)
	go func() {
		errCh <- m.ListenAndServe(ctx, testListenPort)
	}()

	client := &http.Client{Timeout: serverTestTimeout}
	defer client.CloseIdleConnections()

	url := fmt.Sprintf("http://localhost:%d/metrics", testListenPort)
	resp := waitForServer(t, client, url, errCh)
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		t.Errorf("expected status 200, got %d", resp.StatusCode)
	}

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		t.Fatalf("failed to read response body: %v", err)
	}
	if !strings.Contains(string(body), "test_requests_total") {
		t.Error("test metric not found in metrics output")
	}

	cancel()

	select {
	case err := <-errCh:
		if err != nil {
			t.Errorf("server returned error: %v", err)
		}
	case <-time.After(serverTestTimeout):
		t.Error("server did not stop within timeout")
	}
}

// TestListenAndServeContextCancellation checks that a running server stops
// and ListenAndServe returns nil once its context is cancelled.
func TestListenAndServeContextCancellation(t *testing.T) {
	m := New()

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	errCh := make(chan error, 1)
	go func() {
		errCh <- m.ListenAndServe(ctx, testCancelPort)
	}()

	client := &http.Client{Timeout: serverTestTimeout}
	defer client.CloseIdleConnections()

	// Confirm the server is actually serving before cancelling, so the test
	// exercises shutdown of a running server rather than one still starting.
	url := fmt.Sprintf("http://localhost:%d/metrics", testCancelPort)
	resp := waitForServer(t, client, url, errCh)
	_, _ = io.Copy(io.Discard, resp.Body) // drained only to release the connection; contents are not under test
	resp.Body.Close()

	cancel()

	select {
	case err := <-errCh:
		if err != nil {
			t.Errorf("server returned error on graceful shutdown: %v", err)
		}
	case <-time.After(serverTestTimeout):
		t.Error("server did not stop within timeout after context cancellation")
	}
}

// TestNewWithDefaultGatherer checks that NewWithDefaultGatherer exposes
// prometheus.DefaultGatherer while still providing a separate custom registry.
func TestNewWithDefaultGatherer(t *testing.T) {
	m := NewWithDefaultGatherer()
	if m == nil {
		t.Fatal("NewWithDefaultGatherer returned nil")
	}
	if !m.useDefaultGatherer {
		t.Error("useDefaultGatherer should be true")
	}

	gatherer := m.Gatherer()
	if gatherer == nil {
		t.Fatal("Gatherer() returned nil")
	}
	if gatherer != prometheus.DefaultGatherer {
		t.Error("Gatherer() should return prometheus.DefaultGatherer when useDefaultGatherer is true")
	}

	if m.Registry() == nil {
		t.Error("Registry() should still return a custom registry")
	}

	counter := prometheus.NewCounter(prometheus.CounterOpts{
		Name: "test_default_gatherer_counter",
		Help: "A test counter",
	})
	m.Registry().MustRegister(counter)

	// Registering in the custom registry must not change which gatherer is used.
	if m.Gatherer() != prometheus.DefaultGatherer {
		t.Error("Gatherer() should continue to return prometheus.DefaultGatherer")
	}
	// The custom registry is meant to be separate, so its metrics must not
	// appear in the default gatherer's output.
	if hasMetricFamily(t, prometheus.DefaultGatherer, "test_default_gatherer_counter") {
		t.Error("metric registered in the custom registry leaked into prometheus.DefaultGatherer")
	}
}

// BenchmarkMetricsHandler measures an in-process scrape of a registry holding
// ten counters.
func BenchmarkMetricsHandler(b *testing.B) {
	m := New()

	for i := 0; i < 10; i++ {
		counter := prometheus.NewCounter(prometheus.CounterOpts{
			Name: fmt.Sprintf("bench_counter_%d", i),
			Help: "A benchmark counter",
		})
		m.Registry().MustRegister(counter)
		counter.Add(float64(i * 100))
	}

	handler := m.Handler()

	b.ReportAllocs()
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		req := httptest.NewRequest(http.MethodGet, "/metrics", nil)
		recorder := httptest.NewRecorder()
		handler.ServeHTTP(recorder, req)
		if recorder.Code != http.StatusOK {
			b.Fatalf("unexpected status code: %d", recorder.Code)
		}
	}
}