package main import ( "encoding/json" "io" "net/http" "net/http/httptest" "net/url" "strings" "testing" "time" "github.com/oschwald/geoip2-golang" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp" "go.opentelemetry.io/otel/trace" ) // Integration tests that test the HTTP API endpoints with a running server func TestHTTPIntegration(t *testing.T) { // Skip integration tests if no database is available if testing.Short() { t.Skip("Skipping integration tests in short mode") } // Create test server server := createTestServer(t) defer server.Close() baseURL := server.URL t.Run("Health check endpoint", func(t *testing.T) { resp, err := http.Get(baseURL + "/healthz") require.NoError(t, err) defer resp.Body.Close() // Health check might fail without a real database, but should respond assert.Contains(t, []int{200, 500}, resp.StatusCode) body, err := io.ReadAll(resp.Body) require.NoError(t, err) assert.NotEmpty(t, body) }) t.Run("Country API endpoint - invalid IP", func(t *testing.T) { resp, err := http.Get(baseURL + "/api/country?ip=invalid") require.NoError(t, err) defer resp.Body.Close() assert.Equal(t, http.StatusInternalServerError, resp.StatusCode) body, err := io.ReadAll(resp.Body) require.NoError(t, err) assert.Contains(t, string(body), "data error") }) t.Run("JSON API endpoint - invalid IP", func(t *testing.T) { resp, err := http.Get(baseURL + "/api/json?ip=invalid") require.NoError(t, err) defer resp.Body.Close() assert.Equal(t, http.StatusInternalServerError, resp.StatusCode) body, err := io.ReadAll(resp.Body) require.NoError(t, err) assert.Contains(t, string(body), "data error") }) t.Run("Country API endpoint - missing IP", func(t *testing.T) { resp, err := http.Get(baseURL + "/api/country") require.NoError(t, err) defer resp.Body.Close() assert.Equal(t, http.StatusInternalServerError, resp.StatusCode) }) t.Run("JSON API endpoint - missing IP", func(t *testing.T) { resp, err := http.Get(baseURL + "/api/json") require.NoError(t, err) defer resp.Body.Close() assert.Equal(t, http.StatusInternalServerError, resp.StatusCode) }) t.Run("Non-existent endpoint", func(t *testing.T) { resp, err := http.Get(baseURL + "/nonexistent") require.NoError(t, err) defer resp.Body.Close() assert.Equal(t, http.StatusNotFound, resp.StatusCode) }) t.Run("HTTP headers verification", func(t *testing.T) { resp, err := http.Get(baseURL + "/healthz") require.NoError(t, err) defer resp.Body.Close() // Check for version header serverHeader := resp.Header.Get("Server") assert.Contains(t, serverHeader, "geoipapi/") // Check for traceparent header (from OpenTelemetry) traceparent := resp.Header.Get("Traceparent") if traceparent != "" { assert.NotEmpty(t, traceparent) } }) t.Run("Multiple different IP formats", func(t *testing.T) { testIPs := []string{ "8.8.8.8", "127.0.0.1", "192.168.1.1", } for _, ip := range testIPs { t.Run("IP_"+ip, func(t *testing.T) { // Test country endpoint resp, err := http.Get(baseURL + "/api/country?ip=" + url.QueryEscape(ip)) require.NoError(t, err) defer resp.Body.Close() // Should get some response (might be error due to no DB) assert.Contains(t, []int{200, 500}, resp.StatusCode) // Test JSON endpoint resp2, err := http.Get(baseURL + "/api/json?ip=" + url.QueryEscape(ip)) require.NoError(t, err) defer resp2.Body.Close() assert.Contains(t, []int{200, 500}, resp2.StatusCode) }) } }) t.Run("Concurrent requests", func(t *testing.T) { numRequests := 10 results := make(chan int, numRequests) for i := 0; i < numRequests; i++ { go func() { resp, err := http.Get(baseURL + "/healthz") if err != nil { results <- 0 return } defer resp.Body.Close() results <- resp.StatusCode }() } // Collect all results for i := 0; i < numRequests; i++ { statusCode := <-results assert.Contains(t, []int{200, 500}, statusCode) } }) t.Run("Request timeout handling", func(t *testing.T) { client := &http.Client{ Timeout: 1 * time.Millisecond, // Very short timeout } // This might timeout or succeed depending on timing resp, err := client.Get(baseURL + "/healthz") if err == nil { resp.Body.Close() } // We just want to ensure the server handles timeouts gracefully }) } func TestHTTPMethodSupport(t *testing.T) { server := createTestServer(t) defer server.Close() baseURL := server.URL methods := []string{"GET", "POST", "PUT", "DELETE", "HEAD", "OPTIONS"} for _, method := range methods { t.Run("Method_"+method, func(t *testing.T) { req, err := http.NewRequest(method, baseURL+"/api/country?ip=8.8.8.8", nil) require.NoError(t, err) client := &http.Client{} resp, err := client.Do(req) require.NoError(t, err) defer resp.Body.Close() // Our handlers should accept all methods assert.NotEqual(t, http.StatusMethodNotAllowed, resp.StatusCode) }) } } func TestOpenTelemetryIntegration(t *testing.T) { server := createTestServer(t) defer server.Close() baseURL := server.URL t.Run("Tracing headers", func(t *testing.T) { req, err := http.NewRequest("GET", baseURL+"/api/country?ip=8.8.8.8", nil) require.NoError(t, err) // Add tracing headers req.Header.Set("traceparent", "00-12345678901234567890123456789012-1234567890123456-01") client := &http.Client{} resp, err := client.Do(req) require.NoError(t, err) defer resp.Body.Close() // Server should handle tracing headers gracefully assert.Contains(t, []int{200, 500}, resp.StatusCode) }) t.Run("Health check filtering", func(t *testing.T) { // Health checks should be filtered from tracing resp, err := http.Get(baseURL + "/healthz") require.NoError(t, err) defer resp.Body.Close() // This tests that health check requests don't cause tracing issues assert.Contains(t, []int{200, 500}, resp.StatusCode) }) } func TestResponseFormats(t *testing.T) { server := createTestServer(t) defer server.Close() baseURL := server.URL t.Run("Country endpoint response format", func(t *testing.T) { resp, err := http.Get(baseURL + "/api/country?ip=invalid") require.NoError(t, err) defer resp.Body.Close() body, err := io.ReadAll(resp.Body) require.NoError(t, err) // Should be plain text error or country code if resp.StatusCode == 200 { // Should be plain text country code (lowercase) assert.True(t, len(body) >= 2 && len(body) <= 3) } else { // Should be error message assert.Contains(t, string(body), "error") } }) t.Run("JSON endpoint response format", func(t *testing.T) { resp, err := http.Get(baseURL + "/api/json?ip=invalid") require.NoError(t, err) defer resp.Body.Close() body, err := io.ReadAll(resp.Body) require.NoError(t, err) if resp.StatusCode == 200 { // Should be valid JSON var city geoip2.City err := json.Unmarshal(body, &city) assert.NoError(t, err) } else { // Should be error message assert.Contains(t, string(body), "error") } }) } func TestQueryParameterHandling(t *testing.T) { server := createTestServer(t) defer server.Close() baseURL := server.URL t.Run("Multiple query parameters", func(t *testing.T) { resp, err := http.Get(baseURL + "/api/country?ip=8.8.8.8&extra=value&another=param") require.NoError(t, err) defer resp.Body.Close() // Should handle extra parameters gracefully assert.Contains(t, []int{200, 500}, resp.StatusCode) }) t.Run("URL encoded parameters", func(t *testing.T) { // Test with IPv6 address that needs encoding ip := "2001:4860:4860::8888" encodedIP := url.QueryEscape(ip) resp, err := http.Get(baseURL + "/api/country?ip=" + encodedIP) require.NoError(t, err) defer resp.Body.Close() assert.Contains(t, []int{200, 500}, resp.StatusCode) }) t.Run("Duplicate parameters", func(t *testing.T) { resp, err := http.Get(baseURL + "/api/country?ip=8.8.8.8&ip=1.1.1.1") require.NoError(t, err) defer resp.Body.Close() // Should handle duplicate parameters (typically uses first value) assert.Contains(t, []int{200, 500}, resp.StatusCode) }) } func TestErrorHandling(t *testing.T) { server := createTestServer(t) defer server.Close() baseURL := server.URL t.Run("Malformed requests", func(t *testing.T) { // Test various malformed requests malformedRequests := []string{ "/api/country?ip=", "/api/country?ip=%", "/api/json?ip=256.256.256.256", "/api/json?ip=not.an.ip.address", } for _, reqURL := range malformedRequests { t.Run("Request_"+reqURL, func(t *testing.T) { resp, err := http.Get(baseURL + reqURL) require.NoError(t, err) defer resp.Body.Close() // Should handle malformed requests gracefully assert.Equal(t, http.StatusInternalServerError, resp.StatusCode) }) } }) t.Run("Large request handling", func(t *testing.T) { // Test with very long IP parameter longParam := strings.Repeat("1", 1000) resp, err := http.Get(baseURL + "/api/country?ip=" + longParam) require.NoError(t, err) defer resp.Body.Close() // Should handle gracefully assert.Equal(t, http.StatusInternalServerError, resp.StatusCode) }) } // createTestServer creates a test HTTP server with the same configuration as the main server func createTestServer(t *testing.T) *httptest.Server { t.Helper() mux := http.NewServeMux() mux.HandleFunc("/api/country", handleCountry) mux.HandleFunc("/api/json", handleJSON) mux.HandleFunc("/healthz", handleHealth) // Add version handler (simplified version for testing) versionHandler := func(next http.Handler) http.Handler { return http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Server", "geoipapi/test") span := trace.SpanFromContext(r.Context()) if span.SpanContext().IsValid() { w.Header().Set("Traceparent", span.SpanContext().TraceID().String()) } next.ServeHTTP(w, r) }) } // Use OTel HTTP handler with health check filtering handler := otelhttp.NewHandler( versionHandler(mux), "geoipapi-test", otelhttp.WithFilter(func(r *http.Request) bool { return r.URL.Path != "/healthz" }), ) return httptest.NewServer(handler) }