refactor(xff): split into generic, echo, and fastly packages
Extract generic trusted proxy handling into xff/ (stdlib only), Echo framework adapter into xff/echo/, and slim xff/fastlyxff/ down to Fastly JSON loading. Key changes: - xff/ uses netip.Prefix for efficient IP matching - Fix XFF extraction to walk right-to-left per MDN spec - Remove echo dependency from core xff package - fastlyxff.New() now returns *xff.TrustedProxies
This commit is contained in:
48
xff/echo/echo.go
Normal file
48
xff/echo/echo.go
Normal file
@@ -0,0 +1,48 @@
|
|||||||
|
// Package xffecho adapts [xff.TrustedProxies] for use with the Echo web
|
||||||
|
// framework's X-Forwarded-For IP extraction.
|
||||||
|
//
|
||||||
|
// # Usage
|
||||||
|
//
|
||||||
|
// tp, err := fastlyxff.New("fastly.json")
|
||||||
|
// if err != nil {
|
||||||
|
// return err
|
||||||
|
// }
|
||||||
|
// trustOpts := xffecho.TrustOptions(tp)
|
||||||
|
// e.IPExtractor = echo.ExtractIPFromXFFHeader(trustOpts...)
|
||||||
|
package xffecho
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
|
||||||
|
"github.com/labstack/echo/v4"
|
||||||
|
|
||||||
|
"go.ntppool.org/common/xff"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TrustOptions converts a [xff.TrustedProxies] into Echo trust options
|
||||||
|
// for use with [echo.ExtractIPFromXFFHeader].
|
||||||
|
func TrustOptions(tp *xff.TrustedProxies) []echo.TrustOption {
|
||||||
|
prefixes := tp.Prefixes()
|
||||||
|
opts := make([]echo.TrustOption, 0, len(prefixes))
|
||||||
|
for _, p := range prefixes {
|
||||||
|
opts = append(opts, echo.TrustIPRange(prefixToIPNet(p)))
|
||||||
|
}
|
||||||
|
return opts
|
||||||
|
}
|
||||||
|
|
||||||
|
// prefixToIPNet bridges netip.Prefix (used by xff) to net.IPNet (used by Echo).
|
||||||
|
func prefixToIPNet(p netip.Prefix) *net.IPNet {
|
||||||
|
addr := p.Masked().Addr()
|
||||||
|
bits := p.Bits()
|
||||||
|
|
||||||
|
ipLen := 128
|
||||||
|
if addr.Is4() {
|
||||||
|
ipLen = 32
|
||||||
|
}
|
||||||
|
|
||||||
|
return &net.IPNet{
|
||||||
|
IP: net.IP(addr.AsSlice()),
|
||||||
|
Mask: net.CIDRMask(bits, ipLen),
|
||||||
|
}
|
||||||
|
}
|
||||||
31
xff/echo/echo_test.go
Normal file
31
xff/echo/echo_test.go
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
package xffecho
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"go.ntppool.org/common/xff"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestTrustOptions(t *testing.T) {
|
||||||
|
tp, err := xff.NewFromCIDRs([]string{
|
||||||
|
"192.0.2.0/24",
|
||||||
|
"203.0.113.0/24",
|
||||||
|
"2001:db8::/32",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
opts := TrustOptions(tp)
|
||||||
|
if len(opts) != 3 {
|
||||||
|
t.Errorf("expected 3 trust options, got %d", len(opts))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTrustOptionsEmpty(t *testing.T) {
|
||||||
|
tp := xff.New()
|
||||||
|
opts := TrustOptions(tp)
|
||||||
|
if len(opts) != 0 {
|
||||||
|
t.Errorf("expected 0 trust options, got %d", len(opts))
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,270 +1,53 @@
|
|||||||
// Package fastlyxff provides Fastly CDN IP range management for trusted proxy handling.
|
// Package fastlyxff loads Fastly CDN IP ranges and returns a generic
|
||||||
|
// [xff.TrustedProxies] for trusted proxy handling.
|
||||||
//
|
//
|
||||||
// This package parses Fastly's public IP ranges JSON file and provides middleware
|
// Fastly publishes their edge server IP ranges in a JSON format:
|
||||||
// for both Echo framework and standard net/http for proper client IP extraction
|
|
||||||
// from X-Forwarded-For headers. It's designed specifically for services deployed
|
|
||||||
// behind Fastly's CDN that need to identify real client IPs for logging, rate
|
|
||||||
// limiting, and security purposes.
|
|
||||||
//
|
|
||||||
// Fastly publishes their edge server IP ranges in a JSON format that this package
|
|
||||||
// consumes to automatically configure trusted proxy ranges. This ensures that
|
|
||||||
// X-Forwarded-For headers are only trusted when they originate from legitimate
|
|
||||||
// Fastly edge servers.
|
|
||||||
//
|
|
||||||
// Key features:
|
|
||||||
// - Automatic parsing of Fastly's IP ranges JSON format
|
|
||||||
// - Support for both IPv4 and IPv6 address ranges
|
|
||||||
// - Echo framework integration via TrustOption generation
|
|
||||||
// - Standard net/http middleware support
|
|
||||||
// - CIDR notation parsing and validation
|
|
||||||
//
|
|
||||||
// # Echo Framework Usage
|
|
||||||
//
|
|
||||||
// fastlyRanges, err := fastlyxff.New("fastly.json")
|
|
||||||
// if err != nil {
|
|
||||||
// return err
|
|
||||||
// }
|
|
||||||
// options, err := fastlyRanges.EchoTrustOption()
|
|
||||||
// if err != nil {
|
|
||||||
// return err
|
|
||||||
// }
|
|
||||||
// e.IPExtractor = echo.ExtractIPFromXFFHeader(options...)
|
|
||||||
//
|
|
||||||
// # Net/HTTP Usage
|
|
||||||
//
|
|
||||||
// fastlyRanges, err := fastlyxff.New("fastly.json")
|
|
||||||
// if err != nil {
|
|
||||||
// return err
|
|
||||||
// }
|
|
||||||
// middleware := fastlyRanges.HTTPMiddleware()
|
|
||||||
//
|
|
||||||
// handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
// // Both methods work - middleware updates r.RemoteAddr (with port 0) and stores in context
|
|
||||||
// realIP := fastlyxff.GetRealIP(r) // Preferred method
|
|
||||||
// // OR: host, _, _ := net.SplitHostPort(r.RemoteAddr) // Direct access (port will be "0")
|
|
||||||
// fmt.Fprintf(w, "Real IP: %s\n", realIP)
|
|
||||||
// })
|
|
||||||
//
|
|
||||||
// http.ListenAndServe(":8080", middleware(handler))
|
|
||||||
//
|
|
||||||
// # Net/HTTP with Additional Trusted Ranges
|
|
||||||
//
|
|
||||||
// fastlyRanges, err := fastlyxff.New("fastly.json")
|
|
||||||
// if err != nil {
|
|
||||||
// return err
|
|
||||||
// }
|
|
||||||
//
|
|
||||||
// // Add custom trusted CIDRs (e.g., internal load balancers)
|
|
||||||
// // Note: For Echo framework, use the ekko package for additional ranges
|
|
||||||
// err = fastlyRanges.AddTrustedCIDR("10.0.0.0/8")
|
|
||||||
// if err != nil {
|
|
||||||
// return err
|
|
||||||
// }
|
|
||||||
//
|
|
||||||
// middleware := fastlyRanges.HTTPMiddleware()
|
|
||||||
// handler := middleware(yourHandler)
|
|
||||||
//
|
|
||||||
// The JSON file typically contains IP ranges in this format:
|
|
||||||
//
|
//
|
||||||
// {
|
// {
|
||||||
// "addresses": ["23.235.32.0/20", "43.249.72.0/22", ...],
|
// "addresses": ["23.235.32.0/20", "43.249.72.0/22", ...],
|
||||||
// "ipv6_addresses": ["2a04:4e40::/32", "2a04:4e42::/32", ...]
|
// "ipv6_addresses": ["2a04:4e40::/32", "2a04:4e42::/32", ...]
|
||||||
// }
|
// }
|
||||||
|
//
|
||||||
|
// # Usage
|
||||||
|
//
|
||||||
|
// tp, err := fastlyxff.New("fastly.json")
|
||||||
|
// if err != nil {
|
||||||
|
// return err
|
||||||
|
// }
|
||||||
|
// // Use tp.HTTPMiddleware(), tp.ExtractRealIP(r), etc.
|
||||||
|
// // For Echo framework, use the xff/echo package:
|
||||||
|
// // opts, err := xffecho.TrustOptions(tp)
|
||||||
package fastlyxff
|
package fastlyxff
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"net"
|
|
||||||
"net/http"
|
|
||||||
"net/netip"
|
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/labstack/echo/v4"
|
"go.ntppool.org/common/xff"
|
||||||
)
|
)
|
||||||
|
|
||||||
// FastlyXFF represents Fastly's published IP ranges for their CDN edge servers.
|
// fastlyIPRanges matches the JSON format published by Fastly for their
|
||||||
// This structure matches the JSON format provided by Fastly for their public IP ranges.
|
// edge server IP ranges.
|
||||||
// It contains separate lists for IPv4 and IPv6 CIDR ranges, plus additional trusted CIDRs.
|
type fastlyIPRanges struct {
|
||||||
type FastlyXFF struct {
|
IPv4 []string `json:"addresses"`
|
||||||
IPv4 []string `json:"addresses"` // IPv4 CIDR ranges (e.g., "23.235.32.0/20")
|
IPv6 []string `json:"ipv6_addresses"`
|
||||||
IPv6 []string `json:"ipv6_addresses"` // IPv6 CIDR ranges (e.g., "2a04:4e40::/32")
|
|
||||||
extraCIDRs []string // Additional trusted CIDRs added via AddTrustedCIDR
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// TrustedNets holds parsed network prefixes for efficient IP range checking.
|
// New loads Fastly IP ranges from a JSON file and returns a [xff.TrustedProxies].
|
||||||
type TrustedNets struct {
|
func New(fileName string) (*xff.TrustedProxies, error) {
|
||||||
prefixes []netip.Prefix // Parsed network prefixes for efficient lookups
|
|
||||||
}
|
|
||||||
|
|
||||||
// contextKey is used for storing the real client IP in request context
|
|
||||||
type contextKey string
|
|
||||||
|
|
||||||
const realIPKey contextKey = "fastly-real-ip"
|
|
||||||
|
|
||||||
// New loads and parses Fastly IP ranges from a JSON file.
|
|
||||||
// The file should contain Fastly's published IP ranges in their standard JSON format.
|
|
||||||
//
|
|
||||||
// Parameters:
|
|
||||||
// - fileName: Path to the Fastly IP ranges JSON file
|
|
||||||
//
|
|
||||||
// Returns the parsed FastlyXFF structure or an error if the file cannot be
|
|
||||||
// read or the JSON format is invalid.
|
|
||||||
func New(fileName string) (*FastlyXFF, error) {
|
|
||||||
b, err := os.ReadFile(fileName)
|
b, err := os.ReadFile(fileName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
d := FastlyXFF{}
|
var ranges fastlyIPRanges
|
||||||
|
if err := json.Unmarshal(b, &ranges); err != nil {
|
||||||
err = json.Unmarshal(b, &d)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return &d, nil
|
cidrs := make([]string, 0, len(ranges.IPv4)+len(ranges.IPv6))
|
||||||
}
|
cidrs = append(cidrs, ranges.IPv4...)
|
||||||
|
cidrs = append(cidrs, ranges.IPv6...)
|
||||||
// EchoTrustOption converts Fastly IP ranges into Echo framework trust options.
|
|
||||||
// This method generates trust configurations that tell Echo to accept X-Forwarded-For
|
return xff.NewFromCIDRs(cidrs)
|
||||||
// headers only from Fastly's edge servers, ensuring accurate client IP extraction.
|
|
||||||
//
|
|
||||||
// The generated trust options should be used with Echo's IP extractor:
|
|
||||||
//
|
|
||||||
// options, err := fastlyRanges.EchoTrustOption()
|
|
||||||
// if err != nil {
|
|
||||||
// return err
|
|
||||||
// }
|
|
||||||
// e.IPExtractor = echo.ExtractIPFromXFFHeader(options...)
|
|
||||||
//
|
|
||||||
// Returns a slice of Echo trust options or an error if any CIDR range cannot be parsed.
|
|
||||||
func (xff *FastlyXFF) EchoTrustOption() ([]echo.TrustOption, error) {
|
|
||||||
ranges := []echo.TrustOption{}
|
|
||||||
|
|
||||||
for _, s := range append(xff.IPv4, xff.IPv6...) {
|
|
||||||
_, cidr, err := net.ParseCIDR(s)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
trust := echo.TrustIPRange(cidr)
|
|
||||||
ranges = append(ranges, trust)
|
|
||||||
}
|
|
||||||
|
|
||||||
return ranges, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddTrustedCIDR adds an additional CIDR to the list of trusted proxies.
|
|
||||||
// This allows trusting proxies beyond Fastly's published ranges.
|
|
||||||
// The cidr parameter must be a valid CIDR notation (e.g., "10.0.0.0/8", "192.168.1.0/24").
|
|
||||||
// Returns an error if the CIDR format is invalid.
|
|
||||||
func (xff *FastlyXFF) AddTrustedCIDR(cidr string) error {
|
|
||||||
// Validate CIDR format
|
|
||||||
_, _, err := net.ParseCIDR(cidr)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add to extra CIDRs
|
|
||||||
xff.extraCIDRs = append(xff.extraCIDRs, cidr)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// isTrustedProxy checks if the given IP address belongs to Fastly's trusted IP ranges
|
|
||||||
// or any additional CIDRs added via AddTrustedCIDR.
|
|
||||||
func (xff *FastlyXFF) isTrustedProxy(ip string) bool {
|
|
||||||
addr, err := netip.ParseAddr(ip)
|
|
||||||
if err != nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check all IPv4 and IPv6 ranges (Fastly + additional)
|
|
||||||
allRanges := append(append(xff.IPv4, xff.IPv6...), xff.extraCIDRs...)
|
|
||||||
for _, s := range allRanges {
|
|
||||||
_, cidr, err := net.ParseCIDR(s)
|
|
||||||
if err != nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if cidr.Contains(net.IP(addr.AsSlice())) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// extractRealIP extracts the real client IP from X-Forwarded-For header.
|
|
||||||
// It returns the rightmost IP that is not from a trusted Fastly proxy.
|
|
||||||
func (xff *FastlyXFF) extractRealIP(r *http.Request) string {
|
|
||||||
// Get the immediate peer IP
|
|
||||||
host, _, err := net.SplitHostPort(r.RemoteAddr)
|
|
||||||
if err != nil {
|
|
||||||
host = r.RemoteAddr
|
|
||||||
}
|
|
||||||
|
|
||||||
// If the immediate peer is not a trusted Fastly proxy, return it
|
|
||||||
if !xff.isTrustedProxy(host) {
|
|
||||||
return host
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check X-Forwarded-For header
|
|
||||||
xff_header := r.Header.Get("X-Forwarded-For")
|
|
||||||
if xff_header == "" {
|
|
||||||
return host
|
|
||||||
}
|
|
||||||
|
|
||||||
// Parse comma-separated IP list
|
|
||||||
ips := strings.Split(xff_header, ",")
|
|
||||||
if len(ips) == 0 {
|
|
||||||
return host
|
|
||||||
}
|
|
||||||
|
|
||||||
// Find the leftmost IP that is not from a trusted proxy
|
|
||||||
// This represents the original client IP
|
|
||||||
for i := 0; i < len(ips); i++ {
|
|
||||||
ip := strings.TrimSpace(ips[i])
|
|
||||||
if ip != "" && !xff.isTrustedProxy(ip) {
|
|
||||||
return ip
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Fallback to the immediate peer
|
|
||||||
return host
|
|
||||||
}
|
|
||||||
|
|
||||||
// HTTPMiddleware returns a net/http middleware that extracts real client IP
|
|
||||||
// from X-Forwarded-For headers when the request comes from trusted Fastly proxies.
|
|
||||||
// The real IP is stored in the request context and also updates r.RemoteAddr
|
|
||||||
// with port 0 (since the original port is from the proxy, not the real client).
|
|
||||||
func (xff *FastlyXFF) HTTPMiddleware() func(http.Handler) http.Handler {
|
|
||||||
return func(next http.Handler) http.Handler {
|
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
realIP := xff.extractRealIP(r)
|
|
||||||
|
|
||||||
// Store in context for GetRealIP function
|
|
||||||
ctx := context.WithValue(r.Context(), realIPKey, realIP)
|
|
||||||
r = r.WithContext(ctx)
|
|
||||||
|
|
||||||
// Update RemoteAddr to be consistent with extracted IP
|
|
||||||
// Use port 0 since the original port is from the proxy, not the real client
|
|
||||||
r.RemoteAddr = net.JoinHostPort(realIP, "0")
|
|
||||||
|
|
||||||
next.ServeHTTP(w, r)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetRealIP retrieves the real client IP from the request context.
|
|
||||||
// This should be used after the HTTPMiddleware has processed the request.
|
|
||||||
// Returns the remote address if no real IP was extracted.
|
|
||||||
func GetRealIP(r *http.Request) string {
|
|
||||||
if ip, ok := r.Context().Value(realIPKey).(string); ok {
|
|
||||||
return ip
|
|
||||||
}
|
|
||||||
host, _, err := net.SplitHostPort(r.RemoteAddr)
|
|
||||||
if err != nil {
|
|
||||||
return r.RemoteAddr
|
|
||||||
}
|
|
||||||
return host
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,356 +1,44 @@
|
|||||||
package fastlyxff
|
package fastlyxff
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
"os"
|
||||||
"net/http"
|
|
||||||
"net/http/httptest"
|
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestFastlyIPRanges(t *testing.T) {
|
func TestNew(t *testing.T) {
|
||||||
fastlyxff, err := New("fastly.json")
|
tp, err := New("fastly.json")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("could not load test data: %s", err)
|
t.Fatalf("could not load test data: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
data, err := fastlyxff.EchoTrustOption()
|
prefixes := tp.Prefixes()
|
||||||
|
if len(prefixes) < 10 {
|
||||||
|
t.Errorf("only got %d prefixes, expected more", len(prefixes))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewFileNotFound(t *testing.T) {
|
||||||
|
_, err := New("nonexistent.json")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for missing file")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewInvalidJSON(t *testing.T) {
|
||||||
|
// Create a temp file with invalid JSON
|
||||||
|
f, err := os.CreateTemp("", "fastlyxff-test-*.json")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("could not parse test data: %s", err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
defer os.Remove(f.Name())
|
||||||
|
|
||||||
if len(data) < 10 {
|
if _, err := f.WriteString("{invalid"); err != nil {
|
||||||
t.Logf("only got %d prefixes, expected more", len(data))
|
t.Fatal(err)
|
||||||
t.Fail()
|
}
|
||||||
}
|
f.Close()
|
||||||
}
|
|
||||||
|
_, err = New(f.Name())
|
||||||
func TestHTTPMiddleware(t *testing.T) {
|
if err == nil {
|
||||||
// Create a test FastlyXFF instance with known IP ranges
|
t.Fatal("expected error for invalid JSON")
|
||||||
xff := &FastlyXFF{
|
|
||||||
IPv4: []string{"192.0.2.0/24", "203.0.113.0/24"},
|
|
||||||
IPv6: []string{"2001:db8::/32"},
|
|
||||||
}
|
|
||||||
|
|
||||||
middleware := xff.HTTPMiddleware()
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
remoteAddr string
|
|
||||||
xForwardedFor string
|
|
||||||
expectedRealIP string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "direct connection",
|
|
||||||
remoteAddr: "198.51.100.1:12345",
|
|
||||||
xForwardedFor: "",
|
|
||||||
expectedRealIP: "198.51.100.1",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "trusted proxy with XFF",
|
|
||||||
remoteAddr: "192.0.2.1:80",
|
|
||||||
xForwardedFor: "198.51.100.1",
|
|
||||||
expectedRealIP: "198.51.100.1",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "trusted proxy with multiple XFF",
|
|
||||||
remoteAddr: "192.0.2.1:80",
|
|
||||||
xForwardedFor: "198.51.100.1, 203.0.113.1",
|
|
||||||
expectedRealIP: "198.51.100.1",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "untrusted proxy ignored",
|
|
||||||
remoteAddr: "198.51.100.2:80",
|
|
||||||
xForwardedFor: "10.0.0.1",
|
|
||||||
expectedRealIP: "198.51.100.2",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "IPv6 trusted proxy",
|
|
||||||
remoteAddr: "[2001:db8::1]:80",
|
|
||||||
xForwardedFor: "198.51.100.1",
|
|
||||||
expectedRealIP: "198.51.100.1",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
// Create test handler that captures both GetRealIP and r.RemoteAddr
|
|
||||||
var capturedRealIP, capturedRemoteAddr string
|
|
||||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
capturedRealIP = GetRealIP(r)
|
|
||||||
capturedRemoteAddr = r.RemoteAddr
|
|
||||||
w.WriteHeader(http.StatusOK)
|
|
||||||
})
|
|
||||||
|
|
||||||
// Create request with middleware
|
|
||||||
req := httptest.NewRequest("GET", "/", nil)
|
|
||||||
req.RemoteAddr = tt.remoteAddr
|
|
||||||
if tt.xForwardedFor != "" {
|
|
||||||
req.Header.Set("X-Forwarded-For", tt.xForwardedFor)
|
|
||||||
}
|
|
||||||
|
|
||||||
rr := httptest.NewRecorder()
|
|
||||||
middleware(handler).ServeHTTP(rr, req)
|
|
||||||
|
|
||||||
// Test GetRealIP function
|
|
||||||
if capturedRealIP != tt.expectedRealIP {
|
|
||||||
t.Errorf("GetRealIP: expected %s, got %s", tt.expectedRealIP, capturedRealIP)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Test that r.RemoteAddr is updated with real IP and port 0
|
|
||||||
// (since the original port is from the proxy, not the real client)
|
|
||||||
expectedRemoteAddr := net.JoinHostPort(tt.expectedRealIP, "0")
|
|
||||||
if capturedRemoteAddr != expectedRemoteAddr {
|
|
||||||
t.Errorf("RemoteAddr: expected %s, got %s", expectedRemoteAddr, capturedRemoteAddr)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestIsTrustedProxy(t *testing.T) {
|
|
||||||
xff := &FastlyXFF{
|
|
||||||
IPv4: []string{"192.0.2.0/24", "203.0.113.0/24"},
|
|
||||||
IPv6: []string{"2001:db8::/32"},
|
|
||||||
}
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
ip string
|
|
||||||
expected bool
|
|
||||||
}{
|
|
||||||
{"192.0.2.1", true},
|
|
||||||
{"192.0.2.255", true},
|
|
||||||
{"203.0.113.1", true},
|
|
||||||
{"192.0.3.1", false},
|
|
||||||
{"198.51.100.1", false},
|
|
||||||
{"2001:db8::1", true},
|
|
||||||
{"2001:db8:ffff::1", true},
|
|
||||||
{"2001:db9::1", false},
|
|
||||||
{"invalid-ip", false},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.ip, func(t *testing.T) {
|
|
||||||
result := xff.isTrustedProxy(tt.ip)
|
|
||||||
if result != tt.expected {
|
|
||||||
t.Errorf("isTrustedProxy(%s) = %v, expected %v", tt.ip, result, tt.expected)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestExtractRealIP(t *testing.T) {
|
|
||||||
xff := &FastlyXFF{
|
|
||||||
IPv4: []string{"192.0.2.0/24"},
|
|
||||||
IPv6: []string{"2001:db8::/32"},
|
|
||||||
}
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
remoteAddr string
|
|
||||||
xForwardedFor string
|
|
||||||
expected string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "no XFF header",
|
|
||||||
remoteAddr: "198.51.100.1:12345",
|
|
||||||
xForwardedFor: "",
|
|
||||||
expected: "198.51.100.1",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "trusted proxy with single IP",
|
|
||||||
remoteAddr: "192.0.2.1:80",
|
|
||||||
xForwardedFor: "198.51.100.1",
|
|
||||||
expected: "198.51.100.1",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "trusted proxy with multiple IPs",
|
|
||||||
remoteAddr: "192.0.2.1:80",
|
|
||||||
xForwardedFor: "198.51.100.1, 203.0.113.5",
|
|
||||||
expected: "198.51.100.1",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "untrusted proxy",
|
|
||||||
remoteAddr: "198.51.100.1:80",
|
|
||||||
xForwardedFor: "10.0.0.1",
|
|
||||||
expected: "198.51.100.1",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "empty XFF",
|
|
||||||
remoteAddr: "192.0.2.1:80",
|
|
||||||
xForwardedFor: "",
|
|
||||||
expected: "192.0.2.1",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "malformed remote addr",
|
|
||||||
remoteAddr: "192.0.2.1",
|
|
||||||
xForwardedFor: "198.51.100.1",
|
|
||||||
expected: "198.51.100.1",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
req := httptest.NewRequest("GET", "/", nil)
|
|
||||||
req.RemoteAddr = tt.remoteAddr
|
|
||||||
if tt.xForwardedFor != "" {
|
|
||||||
req.Header.Set("X-Forwarded-For", tt.xForwardedFor)
|
|
||||||
}
|
|
||||||
|
|
||||||
result := xff.extractRealIP(req)
|
|
||||||
if result != tt.expected {
|
|
||||||
t.Errorf("extractRealIP() = %s, expected %s", result, tt.expected)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGetRealIPWithoutMiddleware(t *testing.T) {
|
|
||||||
req := httptest.NewRequest("GET", "/", nil)
|
|
||||||
req.RemoteAddr = "198.51.100.1:12345"
|
|
||||||
|
|
||||||
realIP := GetRealIP(req)
|
|
||||||
expected := "198.51.100.1"
|
|
||||||
if realIP != expected {
|
|
||||||
t.Errorf("GetRealIP() = %s, expected %s", realIP, expected)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAddTrustedCIDR(t *testing.T) {
|
|
||||||
xff := &FastlyXFF{
|
|
||||||
IPv4: []string{"192.0.2.0/24"},
|
|
||||||
IPv6: []string{"2001:db8::/32"},
|
|
||||||
}
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
cidr string
|
|
||||||
wantErr bool
|
|
||||||
}{
|
|
||||||
{"valid IPv4 range", "10.0.0.0/8", false},
|
|
||||||
{"valid IPv6 range", "fc00::/7", false},
|
|
||||||
{"valid single IP", "203.0.113.1/32", false},
|
|
||||||
{"invalid CIDR", "not-a-cidr", true},
|
|
||||||
{"invalid format", "10.0.0.0/99", true},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
err := xff.AddTrustedCIDR(tt.cidr)
|
|
||||||
if (err != nil) != tt.wantErr {
|
|
||||||
t.Errorf("AddTrustedCIDR(%s) error = %v, wantErr %v", tt.cidr, err, tt.wantErr)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCustomTrustedCIDRs(t *testing.T) {
|
|
||||||
xff := &FastlyXFF{
|
|
||||||
IPv4: []string{"192.0.2.0/24"},
|
|
||||||
IPv6: []string{"2001:db8::/32"},
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add custom trusted CIDRs
|
|
||||||
err := xff.AddTrustedCIDR("10.0.0.0/8")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to add trusted CIDR: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = xff.AddTrustedCIDR("172.16.0.0/12")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to add trusted CIDR: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
ip string
|
|
||||||
expected bool
|
|
||||||
}{
|
|
||||||
// Original Fastly ranges
|
|
||||||
{"192.0.2.1", true},
|
|
||||||
{"2001:db8::1", true},
|
|
||||||
// Custom CIDRs
|
|
||||||
{"10.1.2.3", true},
|
|
||||||
{"172.16.1.1", true},
|
|
||||||
// Not trusted
|
|
||||||
{"198.51.100.1", false},
|
|
||||||
{"172.15.1.1", false},
|
|
||||||
{"10.0.0.0", true}, // Network address should still match
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.ip, func(t *testing.T) {
|
|
||||||
result := xff.isTrustedProxy(tt.ip)
|
|
||||||
if result != tt.expected {
|
|
||||||
t.Errorf("isTrustedProxy(%s) = %v, expected %v", tt.ip, result, tt.expected)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestHTTPMiddlewareWithCustomCIDRs(t *testing.T) {
|
|
||||||
xff := &FastlyXFF{
|
|
||||||
IPv4: []string{"192.0.2.0/24"},
|
|
||||||
IPv6: []string{"2001:db8::/32"},
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add custom trusted CIDR for internal proxies
|
|
||||||
err := xff.AddTrustedCIDR("10.0.0.0/8")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to add trusted CIDR: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
middleware := xff.HTTPMiddleware()
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
remoteAddr string
|
|
||||||
xForwardedFor string
|
|
||||||
expectedRealIP string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "custom trusted proxy with XFF",
|
|
||||||
remoteAddr: "10.1.2.3:80",
|
|
||||||
xForwardedFor: "198.51.100.1",
|
|
||||||
expectedRealIP: "198.51.100.1",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "fastly proxy with XFF",
|
|
||||||
remoteAddr: "192.0.2.1:80",
|
|
||||||
xForwardedFor: "198.51.100.1",
|
|
||||||
expectedRealIP: "198.51.100.1",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "untrusted proxy ignored",
|
|
||||||
remoteAddr: "172.16.1.1:80",
|
|
||||||
xForwardedFor: "198.51.100.1",
|
|
||||||
expectedRealIP: "172.16.1.1",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "chain through custom and fastly",
|
|
||||||
remoteAddr: "192.0.2.1:80",
|
|
||||||
xForwardedFor: "198.51.100.1, 10.1.2.3",
|
|
||||||
expectedRealIP: "198.51.100.1",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
var capturedIP string
|
|
||||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
capturedIP = GetRealIP(r)
|
|
||||||
w.WriteHeader(http.StatusOK)
|
|
||||||
})
|
|
||||||
|
|
||||||
req := httptest.NewRequest("GET", "/", nil)
|
|
||||||
req.RemoteAddr = tt.remoteAddr
|
|
||||||
if tt.xForwardedFor != "" {
|
|
||||||
req.Header.Set("X-Forwarded-For", tt.xForwardedFor)
|
|
||||||
}
|
|
||||||
|
|
||||||
rr := httptest.NewRecorder()
|
|
||||||
middleware(handler).ServeHTTP(rr, req)
|
|
||||||
|
|
||||||
if capturedIP != tt.expectedRealIP {
|
|
||||||
t.Errorf("expected real IP %s, got %s", tt.expectedRealIP, capturedIP)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
163
xff/xff.go
Normal file
163
xff/xff.go
Normal file
@@ -0,0 +1,163 @@
|
|||||||
|
// Package xff provides trusted proxy handling and real client IP extraction
|
||||||
|
// from X-Forwarded-For headers.
|
||||||
|
//
|
||||||
|
// This package has no external dependencies — it uses only the Go standard library.
|
||||||
|
//
|
||||||
|
// The XFF extraction algorithm walks right-to-left through the X-Forwarded-For
|
||||||
|
// header, skipping trusted proxy IPs, and returns the first untrusted IP as the
|
||||||
|
// real client address. This follows the MDN-recommended approach for secure
|
||||||
|
// client IP extraction.
|
||||||
|
//
|
||||||
|
// # Usage with net/http middleware
|
||||||
|
//
|
||||||
|
// tp, err := xff.NewFromCIDRs([]string{"10.0.0.0/8", "192.168.0.0/16"})
|
||||||
|
// if err != nil {
|
||||||
|
// return err
|
||||||
|
// }
|
||||||
|
// handler := tp.HTTPMiddleware()(yourHandler)
|
||||||
|
//
|
||||||
|
// # Direct extraction
|
||||||
|
//
|
||||||
|
// realIP := tp.ExtractRealIP(r)
|
||||||
|
package xff
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/netip"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TrustedProxies holds a set of trusted proxy network prefixes and provides
|
||||||
|
// methods for extracting the real client IP from X-Forwarded-For headers.
|
||||||
|
type TrustedProxies struct {
|
||||||
|
prefixes []netip.Prefix
|
||||||
|
}
|
||||||
|
|
||||||
|
type contextKey string
|
||||||
|
|
||||||
|
const realIPKey contextKey = "xff-real-ip"
|
||||||
|
|
||||||
|
// New creates a TrustedProxies from already-parsed prefixes.
|
||||||
|
func New(prefixes ...netip.Prefix) *TrustedProxies {
|
||||||
|
return &TrustedProxies{prefixes: prefixes}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewFromCIDRs creates a TrustedProxies from CIDR strings (e.g., "10.0.0.0/8").
|
||||||
|
func NewFromCIDRs(cidrs []string) (*TrustedProxies, error) {
|
||||||
|
prefixes := make([]netip.Prefix, 0, len(cidrs))
|
||||||
|
for _, s := range cidrs {
|
||||||
|
p, err := netip.ParsePrefix(s)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
prefixes = append(prefixes, p)
|
||||||
|
}
|
||||||
|
return &TrustedProxies{prefixes: prefixes}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddCIDR adds a CIDR string to the trusted proxy list.
|
||||||
|
func (tp *TrustedProxies) AddCIDR(cidr string) error {
|
||||||
|
p, err := netip.ParsePrefix(cidr)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
tp.prefixes = append(tp.prefixes, p)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddPrefix adds a parsed prefix to the trusted proxy list.
|
||||||
|
func (tp *TrustedProxies) AddPrefix(prefix netip.Prefix) {
|
||||||
|
tp.prefixes = append(tp.prefixes, prefix)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Prefixes returns a copy of the trusted proxy prefixes.
|
||||||
|
func (tp *TrustedProxies) Prefixes() []netip.Prefix {
|
||||||
|
out := make([]netip.Prefix, len(tp.prefixes))
|
||||||
|
copy(out, tp.prefixes)
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsTrusted reports whether ip belongs to any of the trusted proxy ranges.
|
||||||
|
func (tp *TrustedProxies) IsTrusted(ip string) bool {
|
||||||
|
addr, err := netip.ParseAddr(ip)
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return tp.isTrustedAddr(addr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tp *TrustedProxies) isTrustedAddr(addr netip.Addr) bool {
|
||||||
|
for _, p := range tp.prefixes {
|
||||||
|
if p.Contains(addr) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExtractRealIP extracts the real client IP from a request by walking the
|
||||||
|
// X-Forwarded-For header right-to-left, skipping trusted proxy IPs.
|
||||||
|
// If the immediate peer is not a trusted proxy, its IP is returned.
|
||||||
|
func (tp *TrustedProxies) ExtractRealIP(r *http.Request) string {
|
||||||
|
host, _, err := net.SplitHostPort(r.RemoteAddr)
|
||||||
|
if err != nil {
|
||||||
|
host = r.RemoteAddr
|
||||||
|
}
|
||||||
|
|
||||||
|
hostAddr, err := netip.ParseAddr(host)
|
||||||
|
if err != nil || !tp.isTrustedAddr(hostAddr) {
|
||||||
|
return host
|
||||||
|
}
|
||||||
|
|
||||||
|
xffHeader := r.Header.Get("X-Forwarded-For")
|
||||||
|
if xffHeader == "" {
|
||||||
|
return host
|
||||||
|
}
|
||||||
|
|
||||||
|
ips := strings.Split(xffHeader, ",")
|
||||||
|
|
||||||
|
// Walk right-to-left: skip trusted proxies, return first untrusted IP.
|
||||||
|
for i := len(ips) - 1; i >= 0; i-- {
|
||||||
|
ip := strings.TrimSpace(ips[i])
|
||||||
|
if ip == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
addr, err := netip.ParseAddr(ip)
|
||||||
|
if err != nil || !tp.isTrustedAddr(addr) {
|
||||||
|
return ip
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return host
|
||||||
|
}
|
||||||
|
|
||||||
|
// HTTPMiddleware returns a net/http middleware that extracts the real client IP
|
||||||
|
// from X-Forwarded-For headers and stores it in the request context and
|
||||||
|
// RemoteAddr. The port in RemoteAddr is set to 0 because the original port
|
||||||
|
// belongs to the proxy connection, not the real client.
|
||||||
|
func (tp *TrustedProxies) HTTPMiddleware() func(http.Handler) http.Handler {
|
||||||
|
return func(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
realIP := tp.ExtractRealIP(r)
|
||||||
|
ctx := context.WithValue(r.Context(), realIPKey, realIP)
|
||||||
|
r = r.WithContext(ctx)
|
||||||
|
r.RemoteAddr = net.JoinHostPort(realIP, "0")
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetRealIP retrieves the real client IP from the request context.
|
||||||
|
// Returns the remote address host if no real IP was extracted by middleware.
|
||||||
|
func GetRealIP(r *http.Request) string {
|
||||||
|
if ip, ok := r.Context().Value(realIPKey).(string); ok {
|
||||||
|
return ip
|
||||||
|
}
|
||||||
|
host, _, err := net.SplitHostPort(r.RemoteAddr)
|
||||||
|
if err != nil {
|
||||||
|
return r.RemoteAddr
|
||||||
|
}
|
||||||
|
return host
|
||||||
|
}
|
||||||
225
xff/xff_test.go
Normal file
225
xff/xff_test.go
Normal file
@@ -0,0 +1,225 @@
|
|||||||
|
package xff
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"net/netip"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func testProxies(t *testing.T) *TrustedProxies {
|
||||||
|
t.Helper()
|
||||||
|
tp, err := NewFromCIDRs([]string{"192.0.2.0/24", "203.0.113.0/24", "2001:db8::/32"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
return tp
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNew(t *testing.T) {
|
||||||
|
p := netip.MustParsePrefix("10.0.0.0/8")
|
||||||
|
tp := New(p)
|
||||||
|
if len(tp.Prefixes()) != 1 {
|
||||||
|
t.Fatalf("expected 1 prefix, got %d", len(tp.Prefixes()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewFromCIDRs(t *testing.T) {
|
||||||
|
_, err := NewFromCIDRs([]string{"not-a-cidr"})
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for invalid CIDR")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsTrusted(t *testing.T) {
|
||||||
|
tp := testProxies(t)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
ip string
|
||||||
|
expected bool
|
||||||
|
}{
|
||||||
|
{"192.0.2.1", true},
|
||||||
|
{"192.0.2.255", true},
|
||||||
|
{"203.0.113.1", true},
|
||||||
|
{"192.0.3.1", false},
|
||||||
|
{"198.51.100.1", false},
|
||||||
|
{"2001:db8::1", true},
|
||||||
|
{"2001:db8:ffff::1", true},
|
||||||
|
{"2001:db9::1", false},
|
||||||
|
{"invalid-ip", false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.ip, func(t *testing.T) {
|
||||||
|
if got := tp.IsTrusted(tt.ip); got != tt.expected {
|
||||||
|
t.Errorf("IsTrusted(%s) = %v, want %v", tt.ip, got, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAddCIDR(t *testing.T) {
|
||||||
|
tp := testProxies(t)
|
||||||
|
|
||||||
|
if err := tp.AddCIDR("10.0.0.0/8"); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if !tp.IsTrusted("10.1.2.3") {
|
||||||
|
t.Error("expected 10.1.2.3 to be trusted after AddCIDR")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := tp.AddCIDR("bad"); err == nil {
|
||||||
|
t.Error("expected error for invalid CIDR")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAddPrefix(t *testing.T) {
|
||||||
|
tp := testProxies(t)
|
||||||
|
tp.AddPrefix(netip.MustParsePrefix("172.16.0.0/12"))
|
||||||
|
if !tp.IsTrusted("172.16.1.1") {
|
||||||
|
t.Error("expected 172.16.1.1 to be trusted after AddPrefix")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractRealIP(t *testing.T) {
|
||||||
|
tp := testProxies(t)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
remoteAddr string
|
||||||
|
xForwardedFor string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "no XFF, untrusted peer",
|
||||||
|
remoteAddr: "198.51.100.1:12345",
|
||||||
|
expected: "198.51.100.1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "trusted proxy, single XFF",
|
||||||
|
remoteAddr: "192.0.2.1:80",
|
||||||
|
xForwardedFor: "198.51.100.1",
|
||||||
|
expected: "198.51.100.1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "trusted proxy, empty XFF",
|
||||||
|
remoteAddr: "192.0.2.1:80",
|
||||||
|
xForwardedFor: "",
|
||||||
|
expected: "192.0.2.1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "untrusted peer ignores XFF",
|
||||||
|
remoteAddr: "198.51.100.1:80",
|
||||||
|
xForwardedFor: "10.0.0.1",
|
||||||
|
expected: "198.51.100.1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "malformed remote addr",
|
||||||
|
remoteAddr: "192.0.2.1",
|
||||||
|
xForwardedFor: "198.51.100.1",
|
||||||
|
expected: "198.51.100.1",
|
||||||
|
},
|
||||||
|
// Right-to-left: "client, proxy1(trusted)" -> skip proxy1, return client
|
||||||
|
{
|
||||||
|
name: "right-to-left skips trusted proxies in XFF",
|
||||||
|
remoteAddr: "192.0.2.1:80",
|
||||||
|
xForwardedFor: "198.51.100.1, 203.0.113.1",
|
||||||
|
expected: "198.51.100.1",
|
||||||
|
},
|
||||||
|
// Right-to-left: "spoofed, real-client, trusted-proxy"
|
||||||
|
// should return real-client (first untrusted from right)
|
||||||
|
{
|
||||||
|
name: "right-to-left stops at first untrusted from right",
|
||||||
|
remoteAddr: "192.0.2.1:80",
|
||||||
|
xForwardedFor: "198.51.100.50, 198.51.100.99, 203.0.113.1",
|
||||||
|
expected: "198.51.100.99",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "IPv6 trusted proxy",
|
||||||
|
remoteAddr: "[2001:db8::1]:80",
|
||||||
|
xForwardedFor: "198.51.100.1",
|
||||||
|
expected: "198.51.100.1",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
req := httptest.NewRequest("GET", "/", nil)
|
||||||
|
req.RemoteAddr = tt.remoteAddr
|
||||||
|
if tt.xForwardedFor != "" {
|
||||||
|
req.Header.Set("X-Forwarded-For", tt.xForwardedFor)
|
||||||
|
}
|
||||||
|
if got := tp.ExtractRealIP(req); got != tt.expected {
|
||||||
|
t.Errorf("ExtractRealIP() = %s, want %s", got, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHTTPMiddleware(t *testing.T) {
|
||||||
|
tp := testProxies(t)
|
||||||
|
mw := tp.HTTPMiddleware()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
remoteAddr string
|
||||||
|
xForwardedFor string
|
||||||
|
expectedRealIP string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "direct connection",
|
||||||
|
remoteAddr: "198.51.100.1:12345",
|
||||||
|
expectedRealIP: "198.51.100.1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "trusted proxy with XFF",
|
||||||
|
remoteAddr: "192.0.2.1:80",
|
||||||
|
xForwardedFor: "198.51.100.1",
|
||||||
|
expectedRealIP: "198.51.100.1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "untrusted proxy ignored",
|
||||||
|
remoteAddr: "198.51.100.2:80",
|
||||||
|
xForwardedFor: "10.0.0.1",
|
||||||
|
expectedRealIP: "198.51.100.2",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
var capturedRealIP, capturedRemoteAddr string
|
||||||
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
capturedRealIP = GetRealIP(r)
|
||||||
|
capturedRemoteAddr = r.RemoteAddr
|
||||||
|
})
|
||||||
|
|
||||||
|
req := httptest.NewRequest("GET", "/", nil)
|
||||||
|
req.RemoteAddr = tt.remoteAddr
|
||||||
|
if tt.xForwardedFor != "" {
|
||||||
|
req.Header.Set("X-Forwarded-For", tt.xForwardedFor)
|
||||||
|
}
|
||||||
|
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
mw(handler).ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
if capturedRealIP != tt.expectedRealIP {
|
||||||
|
t.Errorf("GetRealIP: got %s, want %s", capturedRealIP, tt.expectedRealIP)
|
||||||
|
}
|
||||||
|
|
||||||
|
expectedAddr := net.JoinHostPort(tt.expectedRealIP, "0")
|
||||||
|
if capturedRemoteAddr != expectedAddr {
|
||||||
|
t.Errorf("RemoteAddr: got %s, want %s", capturedRemoteAddr, expectedAddr)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetRealIPWithoutMiddleware(t *testing.T) {
|
||||||
|
req := httptest.NewRequest("GET", "/", nil)
|
||||||
|
req.RemoteAddr = "198.51.100.1:12345"
|
||||||
|
|
||||||
|
if got := GetRealIP(req); got != "198.51.100.1" {
|
||||||
|
t.Errorf("GetRealIP() = %s, want 198.51.100.1", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user