package xff

import (
	"net"
	"net/http"
	"net/http/httptest"
	"net/netip"
	"testing"
)

// testProxies returns a TrustedProxies that trusts 192.0.2.0/24,
// 203.0.113.0/24 and 2001:db8::/32. 198.51.100.0/24 is deliberately left
// untrusted so tests can use it for "real client" addresses.
func testProxies(t *testing.T) *TrustedProxies {
	t.Helper()
	tp, err := NewFromCIDRs([]string{"192.0.2.0/24", "203.0.113.0/24", "2001:db8::/32"})
	if err != nil {
		t.Fatal(err)
	}
	return tp
}

// TestNew checks that New keeps the single prefix it is given.
func TestNew(t *testing.T) {
	tp := New(netip.MustParsePrefix("10.0.0.0/8"))
	if got := len(tp.Prefixes()); got != 1 {
		t.Fatalf("expected 1 prefix, got %d", got)
	}
}

// TestNewFromCIDRs checks both the success path (a parsed CIDR becomes
// trusted) and the failure path (an unparsable CIDR yields an error).
func TestNewFromCIDRs(t *testing.T) {
	tp, err := NewFromCIDRs([]string{"10.0.0.0/8"})
	if err != nil {
		t.Fatalf("NewFromCIDRs with a valid CIDR returned error: %v", err)
	}
	if !tp.IsTrusted("10.1.2.3") {
		t.Error("expected 10.1.2.3 to be trusted")
	}

	if _, err := NewFromCIDRs([]string{"not-a-cidr"}); err == nil {
		t.Fatal("expected error for invalid CIDR")
	}
}

// TestIsTrusted covers IPv4/IPv6 addresses inside and outside the trusted
// ranges, plus an unparsable address, which must never be trusted.
func TestIsTrusted(t *testing.T) {
	tp := testProxies(t)
	tests := []struct {
		ip       string
		expected bool
	}{
		{"192.0.2.1", true},
		{"192.0.2.255", true},
		{"203.0.113.1", true},
		{"192.0.3.1", false},
		{"198.51.100.1", false},
		{"2001:db8::1", true},
		{"2001:db8:ffff::1", true},
		{"2001:db9::1", false},
		{"invalid-ip", false},
	}
	for _, tt := range tests {
		t.Run(tt.ip, func(t *testing.T) {
			if got := tp.IsTrusted(tt.ip); got != tt.expected {
				t.Errorf("IsTrusted(%q) = %v, want %v", tt.ip, got, tt.expected)
			}
		})
	}
}

// TestAddCIDR checks that AddCIDR extends the trusted set and rejects
// unparsable input.
func TestAddCIDR(t *testing.T) {
	tp := testProxies(t)
	if err := tp.AddCIDR("10.0.0.0/8"); err != nil {
		t.Fatal(err)
	}
	if !tp.IsTrusted("10.1.2.3") {
		t.Error("expected 10.1.2.3 to be trusted after AddCIDR")
	}
	if err := tp.AddCIDR("bad"); err == nil {
		t.Error("expected error for invalid CIDR")
	}
}

// TestAddPrefix checks that AddPrefix extends the trusted set.
func TestAddPrefix(t *testing.T) {
	tp := testProxies(t)
	tp.AddPrefix(netip.MustParsePrefix("172.16.0.0/12"))
	if !tp.IsTrusted("172.16.1.1") {
		t.Error("expected 172.16.1.1 to be trusted after AddPrefix")
	}
}

// TestExtractRealIP exercises client-IP extraction: X-Forwarded-For is only
// honored when the direct peer is trusted, and is walked right-to-left,
// skipping trusted hops, until the first untrusted address.
func TestExtractRealIP(t *testing.T) {
	tp := testProxies(t)
	tests := []struct {
		name          string
		remoteAddr    string
		xForwardedFor string
		sendEmptyXFF  bool // set the header even when xForwardedFor is ""
		expected      string
	}{
		{
			name:       "no XFF, untrusted peer",
			remoteAddr: "198.51.100.1:12345",
			expected:   "198.51.100.1",
		},
		{
			name:          "trusted proxy, single XFF",
			remoteAddr:    "192.0.2.1:80",
			xForwardedFor: "198.51.100.1",
			expected:      "198.51.100.1",
		},
		{
			name:       "trusted proxy, no XFF header",
			remoteAddr: "192.0.2.1:80",
			expected:   "192.0.2.1",
		},
		// Unlike the case above, the header is present but empty.
		{
			name:         "trusted proxy, empty XFF header",
			remoteAddr:   "192.0.2.1:80",
			sendEmptyXFF: true,
			expected:     "192.0.2.1",
		},
		{
			name:          "untrusted peer ignores XFF",
			remoteAddr:    "198.51.100.1:80",
			xForwardedFor: "10.0.0.1",
			expected:      "198.51.100.1",
		},
		{
			name:          "malformed remote addr",
			remoteAddr:    "192.0.2.1",
			xForwardedFor: "198.51.100.1",
			expected:      "198.51.100.1",
		},
		// Right-to-left: "client, proxy1(trusted)" -> skip proxy1, return client.
		{
			name:          "right-to-left skips trusted proxies in XFF",
			remoteAddr:    "192.0.2.1:80",
			xForwardedFor: "198.51.100.1, 203.0.113.1",
			expected:      "198.51.100.1",
		},
		// Right-to-left: "spoofed, real-client, trusted-proxy" should return
		// real-client (the first untrusted address from the right).
		{
			name:          "right-to-left stops at first untrusted from right",
			remoteAddr:    "192.0.2.1:80",
			xForwardedFor: "198.51.100.50, 198.51.100.99, 203.0.113.1",
			expected:      "198.51.100.99",
		},
		{
			name:          "IPv6 trusted proxy",
			remoteAddr:    "[2001:db8::1]:80",
			xForwardedFor: "198.51.100.1",
			expected:      "198.51.100.1",
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			req := httptest.NewRequest(http.MethodGet, "/", nil)
			req.RemoteAddr = tt.remoteAddr
			if tt.xForwardedFor != "" || tt.sendEmptyXFF {
				req.Header.Set("X-Forwarded-For", tt.xForwardedFor)
			}
			if got := tp.ExtractRealIP(req); got != tt.expected {
				t.Errorf("ExtractRealIP() = %q, want %q", got, tt.expected)
			}
		})
	}
}

// TestHTTPMiddleware checks that the middleware exposes the extracted client
// IP via GetRealIP and rewrites r.RemoteAddr to "<realIP>:0" before calling
// the next handler.
func TestHTTPMiddleware(t *testing.T) {
	tp := testProxies(t)
	mw := tp.HTTPMiddleware()

	tests := []struct {
		name           string
		remoteAddr     string
		xForwardedFor  string
		expectedRealIP string
	}{
		{
			name:           "direct connection",
			remoteAddr:     "198.51.100.1:12345",
			expectedRealIP: "198.51.100.1",
		},
		{
			name:           "trusted proxy with XFF",
			remoteAddr:     "192.0.2.1:80",
			xForwardedFor:  "198.51.100.1",
			expectedRealIP: "198.51.100.1",
		},
		{
			name:           "untrusted proxy ignored",
			remoteAddr:     "198.51.100.2:80",
			xForwardedFor:  "10.0.0.1",
			expectedRealIP: "198.51.100.2",
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			var (
				called             bool
				capturedRealIP     string
				capturedRemoteAddr string
			)
			handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				called = true
				capturedRealIP = GetRealIP(r)
				capturedRemoteAddr = r.RemoteAddr
			})

			req := httptest.NewRequest(http.MethodGet, "/", nil)
			req.RemoteAddr = tt.remoteAddr
			if tt.xForwardedFor != "" {
				req.Header.Set("X-Forwarded-For", tt.xForwardedFor)
			}

			rr := httptest.NewRecorder()
			mw(handler).ServeHTTP(rr, req)

			// Fail clearly instead of reporting empty-string mismatches below.
			if !called {
				t.Fatal("middleware did not call the next handler")
			}
			if capturedRealIP != tt.expectedRealIP {
				t.Errorf("GetRealIP: got %q, want %q", capturedRealIP, tt.expectedRealIP)
			}
			expectedAddr := net.JoinHostPort(tt.expectedRealIP, "0")
			if capturedRemoteAddr != expectedAddr {
				t.Errorf("RemoteAddr: got %q, want %q", capturedRemoteAddr, expectedAddr)
			}
		})
	}
}

// TestGetRealIPWithoutMiddleware checks that GetRealIP falls back to the host
// part of r.RemoteAddr when the middleware has not run.
func TestGetRealIPWithoutMiddleware(t *testing.T) {
	req := httptest.NewRequest(http.MethodGet, "/", nil)
	req.RemoteAddr = "198.51.100.1:12345"
	if got := GetRealIP(req); got != "198.51.100.1" {
		t.Errorf("GetRealIP() = %q, want %q", got, "198.51.100.1")
	}
}