Extract generic trusted-proxy handling into xff/ (stdlib only), the Echo framework adapter into xff/echo/, and slim xff/fastlyxff/ down to Fastly JSON loading.

Key changes:
- xff/ uses netip.Prefix for efficient IP matching
- Fix X-Forwarded-For extraction to walk the header right-to-left, as recommended by MDN
- Remove the echo dependency from the core xff package
- fastlyxff.New() now returns *xff.TrustedProxies

226 lines · 5.4 KiB · Go
package xff
|
|
|
|
import (
|
|
"net"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"net/netip"
|
|
"testing"
|
|
)
|
|
|
|
func testProxies(t *testing.T) *TrustedProxies {
|
|
t.Helper()
|
|
tp, err := NewFromCIDRs([]string{"192.0.2.0/24", "203.0.113.0/24", "2001:db8::/32"})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
return tp
|
|
}
|
|
|
|
func TestNew(t *testing.T) {
|
|
p := netip.MustParsePrefix("10.0.0.0/8")
|
|
tp := New(p)
|
|
if len(tp.Prefixes()) != 1 {
|
|
t.Fatalf("expected 1 prefix, got %d", len(tp.Prefixes()))
|
|
}
|
|
}
|
|
|
|
func TestNewFromCIDRs(t *testing.T) {
|
|
_, err := NewFromCIDRs([]string{"not-a-cidr"})
|
|
if err == nil {
|
|
t.Fatal("expected error for invalid CIDR")
|
|
}
|
|
}
|
|
|
|
func TestIsTrusted(t *testing.T) {
|
|
tp := testProxies(t)
|
|
|
|
tests := []struct {
|
|
ip string
|
|
expected bool
|
|
}{
|
|
{"192.0.2.1", true},
|
|
{"192.0.2.255", true},
|
|
{"203.0.113.1", true},
|
|
{"192.0.3.1", false},
|
|
{"198.51.100.1", false},
|
|
{"2001:db8::1", true},
|
|
{"2001:db8:ffff::1", true},
|
|
{"2001:db9::1", false},
|
|
{"invalid-ip", false},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.ip, func(t *testing.T) {
|
|
if got := tp.IsTrusted(tt.ip); got != tt.expected {
|
|
t.Errorf("IsTrusted(%s) = %v, want %v", tt.ip, got, tt.expected)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestAddCIDR(t *testing.T) {
|
|
tp := testProxies(t)
|
|
|
|
if err := tp.AddCIDR("10.0.0.0/8"); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if !tp.IsTrusted("10.1.2.3") {
|
|
t.Error("expected 10.1.2.3 to be trusted after AddCIDR")
|
|
}
|
|
|
|
if err := tp.AddCIDR("bad"); err == nil {
|
|
t.Error("expected error for invalid CIDR")
|
|
}
|
|
}
|
|
|
|
func TestAddPrefix(t *testing.T) {
|
|
tp := testProxies(t)
|
|
tp.AddPrefix(netip.MustParsePrefix("172.16.0.0/12"))
|
|
if !tp.IsTrusted("172.16.1.1") {
|
|
t.Error("expected 172.16.1.1 to be trusted after AddPrefix")
|
|
}
|
|
}
|
|
|
|
func TestExtractRealIP(t *testing.T) {
|
|
tp := testProxies(t)
|
|
|
|
tests := []struct {
|
|
name string
|
|
remoteAddr string
|
|
xForwardedFor string
|
|
expected string
|
|
}{
|
|
{
|
|
name: "no XFF, untrusted peer",
|
|
remoteAddr: "198.51.100.1:12345",
|
|
expected: "198.51.100.1",
|
|
},
|
|
{
|
|
name: "trusted proxy, single XFF",
|
|
remoteAddr: "192.0.2.1:80",
|
|
xForwardedFor: "198.51.100.1",
|
|
expected: "198.51.100.1",
|
|
},
|
|
{
|
|
name: "trusted proxy, empty XFF",
|
|
remoteAddr: "192.0.2.1:80",
|
|
xForwardedFor: "",
|
|
expected: "192.0.2.1",
|
|
},
|
|
{
|
|
name: "untrusted peer ignores XFF",
|
|
remoteAddr: "198.51.100.1:80",
|
|
xForwardedFor: "10.0.0.1",
|
|
expected: "198.51.100.1",
|
|
},
|
|
{
|
|
name: "malformed remote addr",
|
|
remoteAddr: "192.0.2.1",
|
|
xForwardedFor: "198.51.100.1",
|
|
expected: "198.51.100.1",
|
|
},
|
|
// Right-to-left: "client, proxy1(trusted)" -> skip proxy1, return client
|
|
{
|
|
name: "right-to-left skips trusted proxies in XFF",
|
|
remoteAddr: "192.0.2.1:80",
|
|
xForwardedFor: "198.51.100.1, 203.0.113.1",
|
|
expected: "198.51.100.1",
|
|
},
|
|
// Right-to-left: "spoofed, real-client, trusted-proxy"
|
|
// should return real-client (first untrusted from right)
|
|
{
|
|
name: "right-to-left stops at first untrusted from right",
|
|
remoteAddr: "192.0.2.1:80",
|
|
xForwardedFor: "198.51.100.50, 198.51.100.99, 203.0.113.1",
|
|
expected: "198.51.100.99",
|
|
},
|
|
{
|
|
name: "IPv6 trusted proxy",
|
|
remoteAddr: "[2001:db8::1]:80",
|
|
xForwardedFor: "198.51.100.1",
|
|
expected: "198.51.100.1",
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
req := httptest.NewRequest("GET", "/", nil)
|
|
req.RemoteAddr = tt.remoteAddr
|
|
if tt.xForwardedFor != "" {
|
|
req.Header.Set("X-Forwarded-For", tt.xForwardedFor)
|
|
}
|
|
if got := tp.ExtractRealIP(req); got != tt.expected {
|
|
t.Errorf("ExtractRealIP() = %s, want %s", got, tt.expected)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestHTTPMiddleware(t *testing.T) {
|
|
tp := testProxies(t)
|
|
mw := tp.HTTPMiddleware()
|
|
|
|
tests := []struct {
|
|
name string
|
|
remoteAddr string
|
|
xForwardedFor string
|
|
expectedRealIP string
|
|
}{
|
|
{
|
|
name: "direct connection",
|
|
remoteAddr: "198.51.100.1:12345",
|
|
expectedRealIP: "198.51.100.1",
|
|
},
|
|
{
|
|
name: "trusted proxy with XFF",
|
|
remoteAddr: "192.0.2.1:80",
|
|
xForwardedFor: "198.51.100.1",
|
|
expectedRealIP: "198.51.100.1",
|
|
},
|
|
{
|
|
name: "untrusted proxy ignored",
|
|
remoteAddr: "198.51.100.2:80",
|
|
xForwardedFor: "10.0.0.1",
|
|
expectedRealIP: "198.51.100.2",
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
var capturedRealIP, capturedRemoteAddr string
|
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
capturedRealIP = GetRealIP(r)
|
|
capturedRemoteAddr = r.RemoteAddr
|
|
})
|
|
|
|
req := httptest.NewRequest("GET", "/", nil)
|
|
req.RemoteAddr = tt.remoteAddr
|
|
if tt.xForwardedFor != "" {
|
|
req.Header.Set("X-Forwarded-For", tt.xForwardedFor)
|
|
}
|
|
|
|
rr := httptest.NewRecorder()
|
|
mw(handler).ServeHTTP(rr, req)
|
|
|
|
if capturedRealIP != tt.expectedRealIP {
|
|
t.Errorf("GetRealIP: got %s, want %s", capturedRealIP, tt.expectedRealIP)
|
|
}
|
|
|
|
expectedAddr := net.JoinHostPort(tt.expectedRealIP, "0")
|
|
if capturedRemoteAddr != expectedAddr {
|
|
t.Errorf("RemoteAddr: got %s, want %s", capturedRemoteAddr, expectedAddr)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestGetRealIPWithoutMiddleware(t *testing.T) {
|
|
req := httptest.NewRequest("GET", "/", nil)
|
|
req.RemoteAddr = "198.51.100.1:12345"
|
|
|
|
if got := GetRealIP(req); got != "198.51.100.1" {
|
|
t.Errorf("GetRealIP() = %s, want 198.51.100.1", got)
|
|
}
|
|
}
|