Files
common/metricsserver/metrics_test.go
Ask Bjørn Hansen 7291f00f48 feat(metricsserver): add Gatherer method
Add explicit Gatherer() method to improve API discoverability
and prevent users from accidentally using prometheus.DefaultGatherer
instead of the custom registry.

Changes:
- Add Gatherer() method returning prometheus.Gatherer interface
- Add NewWithDefaultGatherer() constructor for opt-in default usage
- Update package docs with usage examples
- Add tests for both gatherer modes
2025-10-12 16:13:19 -07:00

323 lines
7.1 KiB
Go

package metricsserver
import (
	"context"
	"fmt"
	"io"
	"net"
	"net/http"
	"net/http/httptest"
	"strings"
	"testing"
	"time"

	"github.com/prometheus/client_golang/prometheus"
)
// TestNew verifies that New hands back a usable Metrics value whose
// internal registry has been initialised.
func TestNew(t *testing.T) {
	m := New()
	switch {
	case m == nil:
		t.Fatal("New returned nil")
	case m.r == nil:
		t.Error("metrics registry is nil")
	}
}
// TestRegistry checks that Registry returns the Metrics' own registry and
// that a collector registered on it shows up in the gathered metric families.
func TestRegistry(t *testing.T) {
	metrics := New()
	registry := metrics.Registry()
	if registry == nil {
		t.Fatal("Registry() returned nil")
	}
	if registry != metrics.r {
		t.Error("Registry() did not return the metrics registry")
	}

	// Register a metric. A registration failure is fatal: carrying on would
	// only add a misleading "metric not found" failure on top of the real cause.
	counter := prometheus.NewCounter(prometheus.CounterOpts{
		Name: "test_counter",
		Help: "A test counter",
	})
	if err := registry.Register(counter); err != nil {
		t.Fatalf("failed to register metric: %v", err)
	}

	// Gather may return partial results together with an error, so an error
	// is reported but the returned families are still searched.
	metricFamilies, err := registry.Gather()
	if err != nil {
		t.Errorf("failed to gather metrics: %v", err)
	}
	found := false
	for _, mf := range metricFamilies {
		if mf.GetName() == "test_counter" {
			found = true
			break
		}
	}
	if !found {
		t.Error("registered metric not found in registry")
	}
}
// TestGatherer ensures that, for a Metrics built with New, Gatherer exposes
// collectors registered on the custom registry and is that very registry.
func TestGatherer(t *testing.T) {
	m := New()
	g := m.Gatherer()
	if g == nil {
		t.Fatal("Gatherer() returned nil")
	}

	const name = "test_gatherer_counter"
	c := prometheus.NewCounter(prometheus.CounterOpts{
		Name: name,
		Help: "A test counter for gatherer",
	})
	m.Registry().MustRegister(c)
	c.Inc()

	families, err := g.Gather()
	if err != nil {
		t.Errorf("failed to gather metrics: %v", err)
	}
	seen := false
	for i := 0; i < len(families) && !seen; i++ {
		seen = families[i].GetName() == name
	}
	if !seen {
		t.Error("registered metric not found via Gatherer()")
	}

	// In custom-registry mode the gatherer and the registry are one object.
	if g != m.r {
		t.Error("Gatherer() should return the same object as the registry for custom registry mode")
	}
}
// TestHandler drives Handler in-process through httptest and checks that a
// registered counter is exposed in the scrape output with TYPE comments.
func TestHandler(t *testing.T) {
	m := New()

	requests := prometheus.NewCounterVec(
		prometheus.CounterOpts{
			Name: "test_requests_total",
			Help: "Total number of test requests",
		},
		[]string{"method"},
	)
	m.Registry().MustRegister(requests)
	requests.WithLabelValues("GET").Inc()

	h := m.Handler()
	if h == nil {
		t.Fatal("Handler() returned nil")
	}

	rec := httptest.NewRecorder()
	h.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/metrics", nil))

	res := rec.Result()
	defer res.Body.Close()

	if got := res.StatusCode; got != http.StatusOK {
		t.Errorf("expected status 200, got %d", got)
	}

	raw, err := io.ReadAll(res.Body)
	if err != nil {
		t.Fatalf("failed to read response body: %v", err)
	}
	out := string(raw)

	if !strings.Contains(out, "test_requests_total") {
		t.Error("test metric not found in metrics output")
	}
	// No Accept header is sent, so this is the plain text exposition format;
	// "# TYPE" lines appear in both that format and OpenMetrics.
	if !strings.Contains(out, "# TYPE") {
		t.Error("metrics output missing TYPE comments")
	}
}
// freeTCPPort asks the kernel for an unused TCP port and releases it
// immediately so ListenAndServe can bind it. Another process could grab the
// port in the short gap, but that is far less likely than colliding on a
// fixed, well-known port number.
func freeTCPPort(t *testing.T) int {
	t.Helper()
	ln, err := net.Listen("tcp", ":0")
	if err != nil {
		t.Fatalf("failed to reserve a free port: %v", err)
	}
	defer ln.Close()
	return ln.Addr().(*net.TCPAddr).Port
}

// waitForMetricsServer polls endpoint until the server answers or a 5s
// deadline passes, and returns the first successful response (the caller
// must close its body). It fails the test early if ListenAndServe has
// already returned on errCh, for example because the port could not be bound.
func waitForMetricsServer(t *testing.T, client *http.Client, endpoint string, errCh <-chan error) *http.Response {
	t.Helper()
	deadline := time.Now().Add(5 * time.Second)
	for {
		resp, err := client.Get(endpoint)
		if err == nil {
			return resp
		}
		select {
		case serveErr := <-errCh:
			t.Fatalf("ListenAndServe returned before serving: %v", serveErr)
		default:
		}
		if time.Now().After(deadline) {
			t.Fatalf("server at %s not reachable: %v", endpoint, err)
		}
		time.Sleep(10 * time.Millisecond)
	}
}

// TestListenAndServe starts the metrics server on a free port, scrapes
// /metrics over real HTTP, and checks that cancelling the context makes
// ListenAndServe return nil.
func TestListenAndServe(t *testing.T) {
	metrics := New()

	// Register a test metric.
	counter := prometheus.NewCounterVec(
		prometheus.CounterOpts{
			Name: "test_requests_total",
			Help: "Total number of test requests",
		},
		[]string{"method"},
	)
	metrics.Registry().MustRegister(counter)
	counter.WithLabelValues("GET").Inc()

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	port := freeTCPPort(t)
	errCh := make(chan error, 1)
	go func() {
		errCh <- metrics.ListenAndServe(ctx, port)
	}()

	// A bounded timeout keeps a wedged server from hanging the test. Disabling
	// keep-alives stops pooled idle connections from outliving the test.
	client := &http.Client{
		Timeout:   2 * time.Second,
		Transport: &http.Transport{DisableKeepAlives: true},
	}

	// Poll rather than sleeping for a fixed time. A fixed sleep is flaky on
	// loaded CI machines and needlessly slow everywhere else.
	endpoint := fmt.Sprintf("http://localhost:%d/metrics", port)
	resp := waitForMetricsServer(t, client, endpoint, errCh)
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		t.Errorf("expected status 200, got %d", resp.StatusCode)
	}
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		t.Fatalf("failed to read response body: %v", err)
	}
	// Check for our test metric.
	if !strings.Contains(string(body), "test_requests_total") {
		t.Error("test metric not found in metrics output")
	}

	// Cancel the context to stop the server, then wait for it to return.
	// NOTE(review): if ListenAndServe passes the already-cancelled ctx to
	// http.Server.Shutdown, a connection that is not yet idle could make it
	// return context.Canceled. Confirm against the implementation.
	cancel()
	select {
	case err := <-errCh:
		if err != nil {
			t.Errorf("server returned error: %v", err)
		}
	case <-time.After(5 * time.Second):
		t.Error("server did not stop within timeout")
	}
}

// TestListenAndServeContextCancellation checks that ListenAndServe returns
// nil once its context is cancelled.
func TestListenAndServeContextCancellation(t *testing.T) {
	metrics := New()
	ctx, cancel := context.WithCancel(context.Background())
	// Also deferred so that a failure path can never leave the server running.
	defer cancel()

	port := freeTCPPort(t)
	errCh := make(chan error, 1)
	go func() {
		errCh <- metrics.ListenAndServe(ctx, port)
	}()

	// Best-effort pause so cancellation usually hits a running server. No
	// request is made, so no connection is left open during shutdown.
	time.Sleep(100 * time.Millisecond)

	cancel()

	// The server should stop gracefully.
	select {
	case err := <-errCh:
		if err != nil {
			t.Errorf("server returned error on graceful shutdown: %v", err)
		}
	case <-time.After(5 * time.Second):
		t.Error("server did not stop within timeout after context cancellation")
	}
}
// TestNewWithDefaultGatherer covers the opt-in constructor. Gatherer must
// hand back prometheus.DefaultGatherer, while Registry keeps returning the
// instance's own custom registry.
func TestNewWithDefaultGatherer(t *testing.T) {
	m := NewWithDefaultGatherer()
	if m == nil {
		t.Fatal("NewWithDefaultGatherer returned nil")
	}
	if !m.useDefaultGatherer {
		t.Error("useDefaultGatherer should be true")
	}

	switch g := m.Gatherer(); {
	case g == nil:
		t.Fatal("Gatherer() returned nil")
	case g != prometheus.DefaultGatherer:
		t.Error("Gatherer() should return prometheus.DefaultGatherer when useDefaultGatherer is true")
	}

	// The custom registry still exists alongside the default gatherer.
	if m.Registry() == nil {
		t.Error("Registry() should still return a custom registry")
	}

	// Registering on the custom registry must not switch Gatherer over to it.
	m.Registry().MustRegister(prometheus.NewCounter(prometheus.CounterOpts{
		Name: "test_default_gatherer_counter",
		Help: "A test counter",
	}))
	if m.Gatherer() != prometheus.DefaultGatherer {
		t.Error("Gatherer() should continue to return prometheus.DefaultGatherer")
	}
}
// BenchmarkMetricsHandler measures how long Handler takes to serve one
// in-process scrape of a registry holding ten counters. Registry setup is
// excluded from the timing.
func BenchmarkMetricsHandler(b *testing.B) {
	m := New()

	const counters = 10
	for idx := 0; idx < counters; idx++ {
		c := prometheus.NewCounter(prometheus.CounterOpts{
			Name: fmt.Sprintf("bench_counter_%d", idx),
			Help: "A benchmark counter",
		})
		m.Registry().MustRegister(c)
		c.Add(float64(idx * 100))
	}

	h := m.Handler()
	b.ResetTimer()
	for iter := 0; iter < b.N; iter++ {
		rec := httptest.NewRecorder()
		h.ServeHTTP(rec, httptest.NewRequest("GET", "/metrics", nil))
		if got := rec.Code; got != http.StatusOK {
			b.Fatalf("unexpected status code: %d", got)
		}
	}
}