refactor(xff): split into generic, echo, and fastly packages
Extract generic trusted proxy handling into xff/ (stdlib only), Echo framework adapter into xff/echo/, and slim xff/fastlyxff/ down to Fastly JSON loading. Key changes: - xff/ uses netip.Prefix for efficient IP matching - Fix XFF extraction to walk right-to-left per MDN spec - Remove echo dependency from core xff package - fastlyxff.New() now returns *xff.TrustedProxies
This commit is contained in:
@@ -1,356 +1,44 @@
|
||||
package fastlyxff
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestFastlyIPRanges(t *testing.T) {
|
||||
fastlyxff, err := New("fastly.json")
|
||||
func TestNew(t *testing.T) {
|
||||
tp, err := New("fastly.json")
|
||||
if err != nil {
|
||||
t.Fatalf("could not load test data: %s", err)
|
||||
}
|
||||
|
||||
data, err := fastlyxff.EchoTrustOption()
|
||||
prefixes := tp.Prefixes()
|
||||
if len(prefixes) < 10 {
|
||||
t.Errorf("only got %d prefixes, expected more", len(prefixes))
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewFileNotFound(t *testing.T) {
|
||||
_, err := New("nonexistent.json")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for missing file")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewInvalidJSON(t *testing.T) {
|
||||
// Create a temp file with invalid JSON
|
||||
f, err := os.CreateTemp("", "fastlyxff-test-*.json")
|
||||
if err != nil {
|
||||
t.Fatalf("could not parse test data: %s", err)
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer os.Remove(f.Name())
|
||||
|
||||
if len(data) < 10 {
|
||||
t.Logf("only got %d prefixes, expected more", len(data))
|
||||
t.Fail()
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTPMiddleware(t *testing.T) {
|
||||
// Create a test FastlyXFF instance with known IP ranges
|
||||
xff := &FastlyXFF{
|
||||
IPv4: []string{"192.0.2.0/24", "203.0.113.0/24"},
|
||||
IPv6: []string{"2001:db8::/32"},
|
||||
}
|
||||
|
||||
middleware := xff.HTTPMiddleware()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
remoteAddr string
|
||||
xForwardedFor string
|
||||
expectedRealIP string
|
||||
}{
|
||||
{
|
||||
name: "direct connection",
|
||||
remoteAddr: "198.51.100.1:12345",
|
||||
xForwardedFor: "",
|
||||
expectedRealIP: "198.51.100.1",
|
||||
},
|
||||
{
|
||||
name: "trusted proxy with XFF",
|
||||
remoteAddr: "192.0.2.1:80",
|
||||
xForwardedFor: "198.51.100.1",
|
||||
expectedRealIP: "198.51.100.1",
|
||||
},
|
||||
{
|
||||
name: "trusted proxy with multiple XFF",
|
||||
remoteAddr: "192.0.2.1:80",
|
||||
xForwardedFor: "198.51.100.1, 203.0.113.1",
|
||||
expectedRealIP: "198.51.100.1",
|
||||
},
|
||||
{
|
||||
name: "untrusted proxy ignored",
|
||||
remoteAddr: "198.51.100.2:80",
|
||||
xForwardedFor: "10.0.0.1",
|
||||
expectedRealIP: "198.51.100.2",
|
||||
},
|
||||
{
|
||||
name: "IPv6 trusted proxy",
|
||||
remoteAddr: "[2001:db8::1]:80",
|
||||
xForwardedFor: "198.51.100.1",
|
||||
expectedRealIP: "198.51.100.1",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Create test handler that captures both GetRealIP and r.RemoteAddr
|
||||
var capturedRealIP, capturedRemoteAddr string
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
capturedRealIP = GetRealIP(r)
|
||||
capturedRemoteAddr = r.RemoteAddr
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
// Create request with middleware
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
req.RemoteAddr = tt.remoteAddr
|
||||
if tt.xForwardedFor != "" {
|
||||
req.Header.Set("X-Forwarded-For", tt.xForwardedFor)
|
||||
}
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
middleware(handler).ServeHTTP(rr, req)
|
||||
|
||||
// Test GetRealIP function
|
||||
if capturedRealIP != tt.expectedRealIP {
|
||||
t.Errorf("GetRealIP: expected %s, got %s", tt.expectedRealIP, capturedRealIP)
|
||||
}
|
||||
|
||||
// Test that r.RemoteAddr is updated with real IP and port 0
|
||||
// (since the original port is from the proxy, not the real client)
|
||||
expectedRemoteAddr := net.JoinHostPort(tt.expectedRealIP, "0")
|
||||
if capturedRemoteAddr != expectedRemoteAddr {
|
||||
t.Errorf("RemoteAddr: expected %s, got %s", expectedRemoteAddr, capturedRemoteAddr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsTrustedProxy(t *testing.T) {
|
||||
xff := &FastlyXFF{
|
||||
IPv4: []string{"192.0.2.0/24", "203.0.113.0/24"},
|
||||
IPv6: []string{"2001:db8::/32"},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
ip string
|
||||
expected bool
|
||||
}{
|
||||
{"192.0.2.1", true},
|
||||
{"192.0.2.255", true},
|
||||
{"203.0.113.1", true},
|
||||
{"192.0.3.1", false},
|
||||
{"198.51.100.1", false},
|
||||
{"2001:db8::1", true},
|
||||
{"2001:db8:ffff::1", true},
|
||||
{"2001:db9::1", false},
|
||||
{"invalid-ip", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.ip, func(t *testing.T) {
|
||||
result := xff.isTrustedProxy(tt.ip)
|
||||
if result != tt.expected {
|
||||
t.Errorf("isTrustedProxy(%s) = %v, expected %v", tt.ip, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractRealIP(t *testing.T) {
|
||||
xff := &FastlyXFF{
|
||||
IPv4: []string{"192.0.2.0/24"},
|
||||
IPv6: []string{"2001:db8::/32"},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
remoteAddr string
|
||||
xForwardedFor string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "no XFF header",
|
||||
remoteAddr: "198.51.100.1:12345",
|
||||
xForwardedFor: "",
|
||||
expected: "198.51.100.1",
|
||||
},
|
||||
{
|
||||
name: "trusted proxy with single IP",
|
||||
remoteAddr: "192.0.2.1:80",
|
||||
xForwardedFor: "198.51.100.1",
|
||||
expected: "198.51.100.1",
|
||||
},
|
||||
{
|
||||
name: "trusted proxy with multiple IPs",
|
||||
remoteAddr: "192.0.2.1:80",
|
||||
xForwardedFor: "198.51.100.1, 203.0.113.5",
|
||||
expected: "198.51.100.1",
|
||||
},
|
||||
{
|
||||
name: "untrusted proxy",
|
||||
remoteAddr: "198.51.100.1:80",
|
||||
xForwardedFor: "10.0.0.1",
|
||||
expected: "198.51.100.1",
|
||||
},
|
||||
{
|
||||
name: "empty XFF",
|
||||
remoteAddr: "192.0.2.1:80",
|
||||
xForwardedFor: "",
|
||||
expected: "192.0.2.1",
|
||||
},
|
||||
{
|
||||
name: "malformed remote addr",
|
||||
remoteAddr: "192.0.2.1",
|
||||
xForwardedFor: "198.51.100.1",
|
||||
expected: "198.51.100.1",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
req.RemoteAddr = tt.remoteAddr
|
||||
if tt.xForwardedFor != "" {
|
||||
req.Header.Set("X-Forwarded-For", tt.xForwardedFor)
|
||||
}
|
||||
|
||||
result := xff.extractRealIP(req)
|
||||
if result != tt.expected {
|
||||
t.Errorf("extractRealIP() = %s, expected %s", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetRealIPWithoutMiddleware(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
req.RemoteAddr = "198.51.100.1:12345"
|
||||
|
||||
realIP := GetRealIP(req)
|
||||
expected := "198.51.100.1"
|
||||
if realIP != expected {
|
||||
t.Errorf("GetRealIP() = %s, expected %s", realIP, expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddTrustedCIDR(t *testing.T) {
|
||||
xff := &FastlyXFF{
|
||||
IPv4: []string{"192.0.2.0/24"},
|
||||
IPv6: []string{"2001:db8::/32"},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
cidr string
|
||||
wantErr bool
|
||||
}{
|
||||
{"valid IPv4 range", "10.0.0.0/8", false},
|
||||
{"valid IPv6 range", "fc00::/7", false},
|
||||
{"valid single IP", "203.0.113.1/32", false},
|
||||
{"invalid CIDR", "not-a-cidr", true},
|
||||
{"invalid format", "10.0.0.0/99", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := xff.AddTrustedCIDR(tt.cidr)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("AddTrustedCIDR(%s) error = %v, wantErr %v", tt.cidr, err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCustomTrustedCIDRs(t *testing.T) {
|
||||
xff := &FastlyXFF{
|
||||
IPv4: []string{"192.0.2.0/24"},
|
||||
IPv6: []string{"2001:db8::/32"},
|
||||
}
|
||||
|
||||
// Add custom trusted CIDRs
|
||||
err := xff.AddTrustedCIDR("10.0.0.0/8")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to add trusted CIDR: %v", err)
|
||||
}
|
||||
|
||||
err = xff.AddTrustedCIDR("172.16.0.0/12")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to add trusted CIDR: %v", err)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
ip string
|
||||
expected bool
|
||||
}{
|
||||
// Original Fastly ranges
|
||||
{"192.0.2.1", true},
|
||||
{"2001:db8::1", true},
|
||||
// Custom CIDRs
|
||||
{"10.1.2.3", true},
|
||||
{"172.16.1.1", true},
|
||||
// Not trusted
|
||||
{"198.51.100.1", false},
|
||||
{"172.15.1.1", false},
|
||||
{"10.0.0.0", true}, // Network address should still match
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.ip, func(t *testing.T) {
|
||||
result := xff.isTrustedProxy(tt.ip)
|
||||
if result != tt.expected {
|
||||
t.Errorf("isTrustedProxy(%s) = %v, expected %v", tt.ip, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTPMiddlewareWithCustomCIDRs(t *testing.T) {
|
||||
xff := &FastlyXFF{
|
||||
IPv4: []string{"192.0.2.0/24"},
|
||||
IPv6: []string{"2001:db8::/32"},
|
||||
}
|
||||
|
||||
// Add custom trusted CIDR for internal proxies
|
||||
err := xff.AddTrustedCIDR("10.0.0.0/8")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to add trusted CIDR: %v", err)
|
||||
}
|
||||
|
||||
middleware := xff.HTTPMiddleware()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
remoteAddr string
|
||||
xForwardedFor string
|
||||
expectedRealIP string
|
||||
}{
|
||||
{
|
||||
name: "custom trusted proxy with XFF",
|
||||
remoteAddr: "10.1.2.3:80",
|
||||
xForwardedFor: "198.51.100.1",
|
||||
expectedRealIP: "198.51.100.1",
|
||||
},
|
||||
{
|
||||
name: "fastly proxy with XFF",
|
||||
remoteAddr: "192.0.2.1:80",
|
||||
xForwardedFor: "198.51.100.1",
|
||||
expectedRealIP: "198.51.100.1",
|
||||
},
|
||||
{
|
||||
name: "untrusted proxy ignored",
|
||||
remoteAddr: "172.16.1.1:80",
|
||||
xForwardedFor: "198.51.100.1",
|
||||
expectedRealIP: "172.16.1.1",
|
||||
},
|
||||
{
|
||||
name: "chain through custom and fastly",
|
||||
remoteAddr: "192.0.2.1:80",
|
||||
xForwardedFor: "198.51.100.1, 10.1.2.3",
|
||||
expectedRealIP: "198.51.100.1",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var capturedIP string
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
capturedIP = GetRealIP(r)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
req.RemoteAddr = tt.remoteAddr
|
||||
if tt.xForwardedFor != "" {
|
||||
req.Header.Set("X-Forwarded-For", tt.xForwardedFor)
|
||||
}
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
middleware(handler).ServeHTTP(rr, req)
|
||||
|
||||
if capturedIP != tt.expectedRealIP {
|
||||
t.Errorf("expected real IP %s, got %s", tt.expectedRealIP, capturedIP)
|
||||
}
|
||||
})
|
||||
if _, err := f.WriteString("{invalid"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
f.Close()
|
||||
|
||||
_, err = New(f.Name())
|
||||
if err == nil {
|
||||
t.Fatal("expected error for invalid JSON")
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user