package metricsserver

import (
	"context"
	"fmt"
	"io"
	"net/http"
	"net/http/httptest"
	"strings"
	"testing"
	"time"

	"github.com/prometheus/client_golang/prometheus"
)

// fetchWhenReady issues GET requests against url until one succeeds, the
// server goroutine reports an early exit on serverErr, or a five second
// deadline passes. It replaces a fixed start-up sleep so the tests do not
// depend on how quickly the listener comes up. The caller owns the returned
// response and must close its body.
func fetchWhenReady(t *testing.T, url string, serverErr <-chan error) *http.Response {
	t.Helper()

	// A bounded client timeout keeps a hung server from blocking the test.
	client := &http.Client{Timeout: time.Second}
	deadline := time.Now().Add(5 * time.Second)

	var lastErr error
	for time.Now().Before(deadline) {
		// If the server already returned (for example because the port is
		// taken), report that instead of a generic connection failure.
		select {
		case err := <-serverErr:
			t.Fatalf("server exited before serving requests: %v", err)
		default:
		}

		resp, err := client.Get(url)
		if err == nil {
			return resp
		}
		lastErr = err
		time.Sleep(10 * time.Millisecond)
	}

	t.Fatalf("failed to GET /metrics: %v", lastErr)
	return nil
}

// TestNew verifies that New returns a value whose registry is initialised.
func TestNew(t *testing.T) {
	metrics := New()
	if metrics == nil {
		t.Fatal("New returned nil")
	}
	if metrics.r == nil {
		t.Error("metrics registry is nil")
	}
}

// TestRegistry verifies that Registry returns the underlying registry and
// that a collector registered through it shows up when gathering.
func TestRegistry(t *testing.T) {
	metrics := New()

	registry := metrics.Registry()
	if registry == nil {
		t.Fatal("Registry() returned nil")
	}
	if registry != metrics.r {
		t.Error("Registry() did not return the metrics registry")
	}

	// Test that we can register a metric.
	counter := prometheus.NewCounter(prometheus.CounterOpts{
		Name: "test_counter",
		Help: "A test counter",
	})
	// Stop here on failure: the lookup below would only add a misleading
	// second error.
	if err := registry.Register(counter); err != nil {
		t.Fatalf("failed to register metric: %v", err)
	}

	// Test that the metric is registered.
	metricFamilies, err := registry.Gather()
	if err != nil {
		t.Fatalf("failed to gather metrics: %v", err)
	}

	found := false
	for _, mf := range metricFamilies {
		if mf.GetName() == "test_counter" {
			found = true
			break
		}
	}
	if !found {
		t.Error("registered metric not found in registry")
	}
}

// TestHandler verifies that the HTTP handler answers with status 200 and
// that the response body contains a registered metric.
func TestHandler(t *testing.T) {
	metrics := New()

	// Register a test metric.
	counter := prometheus.NewCounterVec(
		prometheus.CounterOpts{
			Name: "test_requests_total",
			Help: "Total number of test requests",
		},
		[]string{"method"},
	)
	metrics.Registry().MustRegister(counter)
	counter.WithLabelValues("GET").Inc()

	// Test the handler.
	handler := metrics.Handler()
	if handler == nil {
		t.Fatal("Handler() returned nil")
	}

	// Drive the handler directly through a recorder; no network involved.
	req := httptest.NewRequest(http.MethodGet, "/metrics", nil)
	recorder := httptest.NewRecorder()
	handler.ServeHTTP(recorder, req)

	// Check response.
	resp := recorder.Result()
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		t.Errorf("expected status 200, got %d", resp.StatusCode)
	}

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		t.Fatalf("failed to read response body: %v", err)
	}
	bodyStr := string(body)

	// Check for our test metric.
	if !strings.Contains(bodyStr, "test_requests_total") {
		t.Error("test metric not found in metrics output")
	}

	// Check that the output carries "# TYPE" comment lines.
	if !strings.Contains(bodyStr, "# TYPE") {
		t.Error("metrics output missing TYPE comments")
	}
}

// TestListenAndServe starts the server on a fixed local port, fetches
// /metrics over HTTP, and verifies that cancelling the context makes
// ListenAndServe return without error.
func TestListenAndServe(t *testing.T) {
	// Use a high port number to avoid conflicts. Kept as an untyped constant
	// so it fits whatever numeric type ListenAndServe declares.
	const port = 9999

	metrics := New()

	// Register a test metric.
	counter := prometheus.NewCounterVec(
		prometheus.CounterOpts{
			Name: "test_requests_total",
			Help: "Total number of test requests",
		},
		[]string{"method"},
	)
	metrics.Registry().MustRegister(counter)
	counter.WithLabelValues("GET").Inc()

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// Start server in a goroutine. The channel is buffered so the goroutine
	// can finish even if the test has already given up.
	errCh := make(chan error, 1)
	go func() {
		errCh <- metrics.ListenAndServe(ctx, port)
	}()

	// Poll the metrics endpoint until the server is up.
	url := fmt.Sprintf("http://localhost:%d/metrics", port)
	resp := fetchWhenReady(t, url, errCh)
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		t.Errorf("expected status 200, got %d", resp.StatusCode)
	}

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		t.Fatalf("failed to read response body: %v", err)
	}
	bodyStr := string(body)

	// Check for our test metric.
	if !strings.Contains(bodyStr, "test_requests_total") {
		t.Error("test metric not found in metrics output")
	}

	// Cancel context to stop server.
	cancel()

	// Wait for server to stop.
	select {
	case err := <-errCh:
		if err != nil {
			t.Errorf("server returned error: %v", err)
		}
	case <-time.After(5 * time.Second):
		t.Error("server did not stop within timeout")
	}
}

// TestListenAndServeContextCancellation verifies that a running server
// returns without error once its context is cancelled.
func TestListenAndServeContextCancellation(t *testing.T) {
	const port = 9998

	metrics := New()

	ctx, cancel := context.WithCancel(context.Background())
	// Guarantee the server is told to stop even if the test exits early;
	// calling cancel more than once is harmless.
	defer cancel()

	// Start server.
	errCh := make(chan error, 1)
	go func() {
		errCh <- metrics.ListenAndServe(ctx, port)
	}()

	// Wait until the server answers so cancellation hits a running server
	// rather than racing with start-up.
	url := fmt.Sprintf("http://localhost:%d/metrics", port)
	resp := fetchWhenReady(t, url, errCh)
	// Drain and close so the connection is idle before shutdown begins.
	if _, err := io.Copy(io.Discard, resp.Body); err != nil {
		t.Errorf("failed to read response body: %v", err)
	}
	resp.Body.Close()

	// Cancel context.
	cancel()

	// Server should stop gracefully.
	select {
	case err := <-errCh:
		if err != nil {
			t.Errorf("server returned error on graceful shutdown: %v", err)
		}
	case <-time.After(5 * time.Second):
		t.Error("server did not stop within timeout after context cancellation")
	}
}

// BenchmarkMetricsHandler measures the metrics handler's response time with
// ten registered counters.
func BenchmarkMetricsHandler(b *testing.B) {
	metrics := New()

	// Register some test metrics.
	for i := 0; i < 10; i++ {
		counter := prometheus.NewCounter(prometheus.CounterOpts{
			Name: fmt.Sprintf("bench_counter_%d", i),
			Help: "A benchmark counter",
		})
		metrics.Registry().MustRegister(counter)
		counter.Add(float64(i * 100))
	}

	handler := metrics.Handler()

	// Exclude the setup above from the measurement.
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		req := httptest.NewRequest(http.MethodGet, "/metrics", nil)
		recorder := httptest.NewRecorder()
		handler.ServeHTTP(recorder, req)
		if recorder.Code != http.StatusOK {
			b.Fatalf("unexpected status code: %d", recorder.Code)
		}
	}
}