refactor(xff): split into generic, echo, and fastly packages
Extract generic trusted proxy handling into xff/ (stdlib only), Echo framework adapter into xff/echo/, and slim xff/fastlyxff/ down to Fastly JSON loading. Key changes: - xff/ uses netip.Prefix for efficient IP matching - Fix XFF extraction to walk right-to-left per MDN spec - Remove echo dependency from core xff package - fastlyxff.New() now returns *xff.TrustedProxies
This commit is contained in:
225
xff/xff_test.go
Normal file
225
xff/xff_test.go
Normal file
@@ -0,0 +1,225 @@
|
||||
package xff
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/netip"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func testProxies(t *testing.T) *TrustedProxies {
|
||||
t.Helper()
|
||||
tp, err := NewFromCIDRs([]string{"192.0.2.0/24", "203.0.113.0/24", "2001:db8::/32"})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return tp
|
||||
}
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
p := netip.MustParsePrefix("10.0.0.0/8")
|
||||
tp := New(p)
|
||||
if len(tp.Prefixes()) != 1 {
|
||||
t.Fatalf("expected 1 prefix, got %d", len(tp.Prefixes()))
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewFromCIDRs(t *testing.T) {
|
||||
_, err := NewFromCIDRs([]string{"not-a-cidr"})
|
||||
if err == nil {
|
||||
t.Fatal("expected error for invalid CIDR")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsTrusted(t *testing.T) {
|
||||
tp := testProxies(t)
|
||||
|
||||
tests := []struct {
|
||||
ip string
|
||||
expected bool
|
||||
}{
|
||||
{"192.0.2.1", true},
|
||||
{"192.0.2.255", true},
|
||||
{"203.0.113.1", true},
|
||||
{"192.0.3.1", false},
|
||||
{"198.51.100.1", false},
|
||||
{"2001:db8::1", true},
|
||||
{"2001:db8:ffff::1", true},
|
||||
{"2001:db9::1", false},
|
||||
{"invalid-ip", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.ip, func(t *testing.T) {
|
||||
if got := tp.IsTrusted(tt.ip); got != tt.expected {
|
||||
t.Errorf("IsTrusted(%s) = %v, want %v", tt.ip, got, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddCIDR(t *testing.T) {
|
||||
tp := testProxies(t)
|
||||
|
||||
if err := tp.AddCIDR("10.0.0.0/8"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !tp.IsTrusted("10.1.2.3") {
|
||||
t.Error("expected 10.1.2.3 to be trusted after AddCIDR")
|
||||
}
|
||||
|
||||
if err := tp.AddCIDR("bad"); err == nil {
|
||||
t.Error("expected error for invalid CIDR")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddPrefix(t *testing.T) {
|
||||
tp := testProxies(t)
|
||||
tp.AddPrefix(netip.MustParsePrefix("172.16.0.0/12"))
|
||||
if !tp.IsTrusted("172.16.1.1") {
|
||||
t.Error("expected 172.16.1.1 to be trusted after AddPrefix")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractRealIP(t *testing.T) {
|
||||
tp := testProxies(t)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
remoteAddr string
|
||||
xForwardedFor string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "no XFF, untrusted peer",
|
||||
remoteAddr: "198.51.100.1:12345",
|
||||
expected: "198.51.100.1",
|
||||
},
|
||||
{
|
||||
name: "trusted proxy, single XFF",
|
||||
remoteAddr: "192.0.2.1:80",
|
||||
xForwardedFor: "198.51.100.1",
|
||||
expected: "198.51.100.1",
|
||||
},
|
||||
{
|
||||
name: "trusted proxy, empty XFF",
|
||||
remoteAddr: "192.0.2.1:80",
|
||||
xForwardedFor: "",
|
||||
expected: "192.0.2.1",
|
||||
},
|
||||
{
|
||||
name: "untrusted peer ignores XFF",
|
||||
remoteAddr: "198.51.100.1:80",
|
||||
xForwardedFor: "10.0.0.1",
|
||||
expected: "198.51.100.1",
|
||||
},
|
||||
{
|
||||
name: "malformed remote addr",
|
||||
remoteAddr: "192.0.2.1",
|
||||
xForwardedFor: "198.51.100.1",
|
||||
expected: "198.51.100.1",
|
||||
},
|
||||
// Right-to-left: "client, proxy1(trusted)" -> skip proxy1, return client
|
||||
{
|
||||
name: "right-to-left skips trusted proxies in XFF",
|
||||
remoteAddr: "192.0.2.1:80",
|
||||
xForwardedFor: "198.51.100.1, 203.0.113.1",
|
||||
expected: "198.51.100.1",
|
||||
},
|
||||
// Right-to-left: "spoofed, real-client, trusted-proxy"
|
||||
// should return real-client (first untrusted from right)
|
||||
{
|
||||
name: "right-to-left stops at first untrusted from right",
|
||||
remoteAddr: "192.0.2.1:80",
|
||||
xForwardedFor: "198.51.100.50, 198.51.100.99, 203.0.113.1",
|
||||
expected: "198.51.100.99",
|
||||
},
|
||||
{
|
||||
name: "IPv6 trusted proxy",
|
||||
remoteAddr: "[2001:db8::1]:80",
|
||||
xForwardedFor: "198.51.100.1",
|
||||
expected: "198.51.100.1",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
req.RemoteAddr = tt.remoteAddr
|
||||
if tt.xForwardedFor != "" {
|
||||
req.Header.Set("X-Forwarded-For", tt.xForwardedFor)
|
||||
}
|
||||
if got := tp.ExtractRealIP(req); got != tt.expected {
|
||||
t.Errorf("ExtractRealIP() = %s, want %s", got, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTPMiddleware(t *testing.T) {
|
||||
tp := testProxies(t)
|
||||
mw := tp.HTTPMiddleware()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
remoteAddr string
|
||||
xForwardedFor string
|
||||
expectedRealIP string
|
||||
}{
|
||||
{
|
||||
name: "direct connection",
|
||||
remoteAddr: "198.51.100.1:12345",
|
||||
expectedRealIP: "198.51.100.1",
|
||||
},
|
||||
{
|
||||
name: "trusted proxy with XFF",
|
||||
remoteAddr: "192.0.2.1:80",
|
||||
xForwardedFor: "198.51.100.1",
|
||||
expectedRealIP: "198.51.100.1",
|
||||
},
|
||||
{
|
||||
name: "untrusted proxy ignored",
|
||||
remoteAddr: "198.51.100.2:80",
|
||||
xForwardedFor: "10.0.0.1",
|
||||
expectedRealIP: "198.51.100.2",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var capturedRealIP, capturedRemoteAddr string
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
capturedRealIP = GetRealIP(r)
|
||||
capturedRemoteAddr = r.RemoteAddr
|
||||
})
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
req.RemoteAddr = tt.remoteAddr
|
||||
if tt.xForwardedFor != "" {
|
||||
req.Header.Set("X-Forwarded-For", tt.xForwardedFor)
|
||||
}
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
mw(handler).ServeHTTP(rr, req)
|
||||
|
||||
if capturedRealIP != tt.expectedRealIP {
|
||||
t.Errorf("GetRealIP: got %s, want %s", capturedRealIP, tt.expectedRealIP)
|
||||
}
|
||||
|
||||
expectedAddr := net.JoinHostPort(tt.expectedRealIP, "0")
|
||||
if capturedRemoteAddr != expectedAddr {
|
||||
t.Errorf("RemoteAddr: got %s, want %s", capturedRemoteAddr, expectedAddr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetRealIPWithoutMiddleware(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
req.RemoteAddr = "198.51.100.1:12345"
|
||||
|
||||
if got := GetRealIP(req); got != "198.51.100.1" {
|
||||
t.Errorf("GetRealIP() = %s, want 198.51.100.1", got)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user