Compare commits
4 Commits
Author | SHA1 | Date | |
---|---|---|---|
2670d25b52 | |||
45308cd4bf | |||
4767caf7b8 | |||
f90281f472 |
1
.gitignore
vendored
Normal file
1
.gitignore
vendored
Normal file
@@ -0,0 +1 @@
|
||||
.aider*
|
28
.mcp.json
Normal file
28
.mcp.json
Normal file
@@ -0,0 +1,28 @@
|
||||
{
|
||||
"mcpServers": {
|
||||
"context7": {
|
||||
"type": "stdio",
|
||||
"command": "npx",
|
||||
"args": [
|
||||
"-y",
|
||||
"@upstash/context7-mcp@1.0.0"
|
||||
],
|
||||
"env": {}
|
||||
},
|
||||
"serena": {
|
||||
"type": "stdio",
|
||||
"command": "uvx",
|
||||
"args": [
|
||||
"--from",
|
||||
"git+https://github.com/oraios/serena@v0.1.4",
|
||||
"serena",
|
||||
"start-mcp-server",
|
||||
"--context",
|
||||
"ide-assistant",
|
||||
"--project",
|
||||
"."
|
||||
],
|
||||
"env": {}
|
||||
}
|
||||
}
|
||||
}
|
20
.pre-commit-config.yaml
Normal file
20
.pre-commit-config.yaml
Normal file
@@ -0,0 +1,20 @@
|
||||
---
|
||||
repos:
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
rev: v5.0.0
|
||||
hooks:
|
||||
- id: trailing-whitespace
|
||||
- id: end-of-file-fixer
|
||||
- id: check-added-large-files
|
||||
args: ["--maxkb=400"]
|
||||
- id: check-case-conflict
|
||||
- id: check-executables-have-shebangs
|
||||
- id: check-shebang-scripts-are-executable
|
||||
- id: check-merge-conflict
|
||||
- id: check-symlinks
|
||||
|
||||
- repo: https://github.com/adrienverge/yamllint
|
||||
rev: v1.35.1
|
||||
hooks:
|
||||
- id: yamllint
|
||||
args: [-c=.yamllint]
|
14
.yamllint
Normal file
14
.yamllint
Normal file
@@ -0,0 +1,14 @@
|
||||
---
|
||||
extends: relaxed
|
||||
|
||||
rules:
|
||||
braces:
|
||||
level: error
|
||||
brackets:
|
||||
level: error
|
||||
|
||||
truthy:
|
||||
level: warning
|
||||
|
||||
#ignore: |
|
||||
# - ...
|
@@ -39,6 +39,8 @@ type Config struct {
|
||||
webHostnames []string
|
||||
webTLS bool
|
||||
|
||||
poolDomain string `accessor:"getter"`
|
||||
|
||||
valid bool `accessor:"getter"`
|
||||
}
|
||||
|
||||
@@ -52,6 +54,7 @@ type Config struct {
|
||||
// - web_hostname: Comma-separated web hostnames (first becomes primary)
|
||||
// - manage_tls: Management interface TLS setting
|
||||
// - web_tls: Web interface TLS setting
|
||||
// - pool_domain: NTP pool domain (default: pool.ntp.org)
|
||||
func New() *Config {
|
||||
c := Config{}
|
||||
c.deploymentMode = os.Getenv("deployment_mode")
|
||||
@@ -69,6 +72,11 @@ func New() *Config {
|
||||
c.manageTLS = parseBool(os.Getenv("manage_tls"))
|
||||
c.webTLS = parseBool(os.Getenv("web_tls"))
|
||||
|
||||
c.poolDomain = os.Getenv("pool_domain")
|
||||
if c.poolDomain == "" {
|
||||
c.poolDomain = "pool.ntp.org"
|
||||
}
|
||||
|
||||
return &c
|
||||
}
|
||||
|
||||
|
@@ -23,6 +23,13 @@ func (c *Config) WebHostname() string {
|
||||
return c.webHostname
|
||||
}
|
||||
|
||||
func (c *Config) PoolDomain() string {
|
||||
if c == nil {
|
||||
return ""
|
||||
}
|
||||
return c.poolDomain
|
||||
}
|
||||
|
||||
func (c *Config) Valid() bool {
|
||||
if c == nil {
|
||||
return false
|
||||
|
@@ -1,6 +1,7 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
@@ -9,15 +10,67 @@ import (
|
||||
|
||||
// Config represents the database configuration structure
|
||||
type Config struct {
|
||||
MySQL DBConfig `yaml:"mysql"`
|
||||
// MySQL configuration (use this OR Postgres, not both)
|
||||
MySQL *MySQLConfig `yaml:"mysql,omitempty"`
|
||||
// Postgres configuration (use this OR MySQL, not both)
|
||||
Postgres *PostgresConfig `yaml:"postgres,omitempty"`
|
||||
|
||||
// Legacy flat PostgreSQL format (deprecated, for backward compatibility only)
|
||||
// If neither MySQL nor Postgres is set, these fields will be used for PostgreSQL
|
||||
User string `yaml:"user,omitempty"`
|
||||
Pass string `yaml:"pass,omitempty"`
|
||||
Host string `yaml:"host,omitempty"`
|
||||
Port uint16 `yaml:"port,omitempty"`
|
||||
Name string `yaml:"name,omitempty"`
|
||||
SSLMode string `yaml:"sslmode,omitempty"`
|
||||
}
|
||||
|
||||
// DBConfig represents the MySQL database configuration
|
||||
type DBConfig struct {
|
||||
DSN string `default:"" flag:"dsn" usage:"Database DSN"`
|
||||
User string `default:"" flag:"user"`
|
||||
Pass string `default:"" flag:"pass"`
|
||||
DBName string // Optional database name override
|
||||
// MySQLConfig represents the MySQL database configuration
|
||||
type MySQLConfig struct {
|
||||
DSN string `yaml:"dsn" default:"" flag:"dsn" usage:"Database DSN"`
|
||||
User string `yaml:"user" default:"" flag:"user"`
|
||||
Pass string `yaml:"pass" default:"" flag:"pass"`
|
||||
DBName string `yaml:"name,omitempty"` // Optional database name override
|
||||
}
|
||||
|
||||
// PostgresConfig represents the PostgreSQL database configuration
|
||||
type PostgresConfig struct {
|
||||
User string `yaml:"user"`
|
||||
Pass string `yaml:"pass"`
|
||||
Host string `yaml:"host"`
|
||||
Port uint16 `yaml:"port"`
|
||||
Name string `yaml:"name"`
|
||||
SSLMode string `yaml:"sslmode"`
|
||||
}
|
||||
|
||||
// DBConfig is a legacy alias for MySQLConfig
|
||||
type DBConfig = MySQLConfig
|
||||
|
||||
// Validate ensures the configuration is valid and unambiguous
|
||||
func (c *Config) Validate() error {
|
||||
hasMySQL := c.MySQL != nil
|
||||
hasPostgres := c.Postgres != nil
|
||||
hasLegacy := c.User != "" || c.Host != "" || c.Port != 0 || c.Name != ""
|
||||
|
||||
count := 0
|
||||
if hasMySQL {
|
||||
count++
|
||||
}
|
||||
if hasPostgres {
|
||||
count++
|
||||
}
|
||||
if hasLegacy {
|
||||
count++
|
||||
}
|
||||
|
||||
if count == 0 {
|
||||
return fmt.Errorf("no database configuration provided")
|
||||
}
|
||||
if count > 1 {
|
||||
return fmt.Errorf("multiple database configurations provided (only one allowed)")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ConfigOptions allows customization of database opening behavior
|
||||
|
@@ -56,9 +56,9 @@ func TestMonitorConfigOptions(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestConfigStructures(t *testing.T) {
|
||||
// Test that configuration structures can be created and populated
|
||||
// Test that MySQL configuration structures can be created and populated
|
||||
config := Config{
|
||||
MySQL: DBConfig{
|
||||
MySQL: &MySQLConfig{
|
||||
DSN: "user:pass@tcp(localhost:3306)/dbname",
|
||||
User: "testuser",
|
||||
Pass: "testpass",
|
||||
@@ -79,3 +79,118 @@ func TestConfigStructures(t *testing.T) {
|
||||
t.Errorf("Expected DBName='testdb', got '%s'", config.MySQL.DBName)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPostgresConfigStructures(t *testing.T) {
|
||||
// Test that PostgreSQL configuration structures can be created and populated
|
||||
config := Config{
|
||||
Postgres: &PostgresConfig{
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
User: "testuser",
|
||||
Pass: "testpass",
|
||||
Name: "testdb",
|
||||
SSLMode: "require",
|
||||
},
|
||||
}
|
||||
|
||||
if config.Postgres.Host != "localhost" {
|
||||
t.Errorf("Expected Host='localhost', got '%s'", config.Postgres.Host)
|
||||
}
|
||||
if config.Postgres.Port != 5432 {
|
||||
t.Errorf("Expected Port=5432, got %d", config.Postgres.Port)
|
||||
}
|
||||
if config.Postgres.User != "testuser" {
|
||||
t.Errorf("Expected User='testuser', got '%s'", config.Postgres.User)
|
||||
}
|
||||
if config.Postgres.Pass != "testpass" {
|
||||
t.Errorf("Expected Pass='testpass', got '%s'", config.Postgres.Pass)
|
||||
}
|
||||
if config.Postgres.Name != "testdb" {
|
||||
t.Errorf("Expected Name='testdb', got '%s'", config.Postgres.Name)
|
||||
}
|
||||
if config.Postgres.SSLMode != "require" {
|
||||
t.Errorf("Expected SSLMode='require', got '%s'", config.Postgres.SSLMode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLegacyPostgresConfig(t *testing.T) {
|
||||
// Test that legacy flat PostgreSQL format can be created
|
||||
config := Config{
|
||||
User: "testuser",
|
||||
Pass: "testpass",
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
Name: "testdb",
|
||||
SSLMode: "require",
|
||||
}
|
||||
|
||||
if config.User != "testuser" {
|
||||
t.Errorf("Expected User='testuser', got '%s'", config.User)
|
||||
}
|
||||
if config.Name != "testdb" {
|
||||
t.Errorf("Expected Name='testdb', got '%s'", config.Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigValidation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
config Config
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid mysql config",
|
||||
config: Config{
|
||||
MySQL: &MySQLConfig{DSN: "test"},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "valid postgres config",
|
||||
config: Config{
|
||||
Postgres: &PostgresConfig{User: "test", Host: "localhost", Name: "test"},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "valid legacy postgres config",
|
||||
config: Config{
|
||||
User: "test",
|
||||
Host: "localhost",
|
||||
Name: "testdb",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "both mysql and postgres set",
|
||||
config: Config{
|
||||
MySQL: &MySQLConfig{DSN: "test"},
|
||||
Postgres: &PostgresConfig{User: "test"},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "mysql and legacy postgres set",
|
||||
config: Config{
|
||||
MySQL: &MySQLConfig{DSN: "test"},
|
||||
User: "test",
|
||||
Name: "testdb",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "no config set",
|
||||
config: Config{},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := tt.config.Validate()
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Validate() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@@ -8,6 +8,8 @@ import (
|
||||
"os"
|
||||
|
||||
"github.com/go-sql-driver/mysql"
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/stdlib"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
@@ -58,31 +60,128 @@ func createConnector(configFile string) CreateConnectorFunc {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
dsn := cfg.MySQL.DSN
|
||||
if len(dsn) == 0 {
|
||||
dsn = os.Getenv("DATABASE_DSN")
|
||||
if len(dsn) == 0 {
|
||||
return nil, fmt.Errorf("dsn config in database.yaml or DATABASE_DSN environment variable required")
|
||||
}
|
||||
// Validate configuration
|
||||
if err := cfg.Validate(); err != nil {
|
||||
return nil, fmt.Errorf("invalid configuration: %w", err)
|
||||
}
|
||||
|
||||
dbcfg, err := mysql.ParseDSN(dsn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
// Determine database type and create appropriate connector
|
||||
if cfg.MySQL != nil {
|
||||
return createMySQLConnector(cfg.MySQL)
|
||||
} else if cfg.Postgres != nil {
|
||||
return createPostgresConnector(cfg.Postgres)
|
||||
} else if cfg.User != "" && cfg.Name != "" {
|
||||
// Legacy flat PostgreSQL format (requires at minimum user and dbname)
|
||||
return createPostgresConnectorFromFlat(&cfg)
|
||||
}
|
||||
|
||||
if user := cfg.MySQL.User; len(user) > 0 {
|
||||
dbcfg.User = user
|
||||
}
|
||||
|
||||
if pass := cfg.MySQL.Pass; len(pass) > 0 {
|
||||
dbcfg.Passwd = pass
|
||||
}
|
||||
|
||||
if name := cfg.MySQL.DBName; len(name) > 0 {
|
||||
dbcfg.DBName = name
|
||||
}
|
||||
|
||||
return mysql.NewConnector(dbcfg)
|
||||
return nil, fmt.Errorf("no valid database configuration found (mysql or postgres section required)")
|
||||
}
|
||||
}
|
||||
|
||||
// createMySQLConnector creates a MySQL connector from configuration
|
||||
func createMySQLConnector(cfg *MySQLConfig) (driver.Connector, error) {
|
||||
dsn := cfg.DSN
|
||||
if len(dsn) == 0 {
|
||||
dsn = os.Getenv("DATABASE_DSN")
|
||||
if len(dsn) == 0 {
|
||||
return nil, fmt.Errorf("dsn config in database.yaml or DATABASE_DSN environment variable required")
|
||||
}
|
||||
}
|
||||
|
||||
dbcfg, err := mysql.ParseDSN(dsn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if user := cfg.User; len(user) > 0 {
|
||||
dbcfg.User = user
|
||||
}
|
||||
|
||||
if pass := cfg.Pass; len(pass) > 0 {
|
||||
dbcfg.Passwd = pass
|
||||
}
|
||||
|
||||
if name := cfg.DBName; len(name) > 0 {
|
||||
dbcfg.DBName = name
|
||||
}
|
||||
|
||||
return mysql.NewConnector(dbcfg)
|
||||
}
|
||||
|
||||
// createPostgresConnector creates a PostgreSQL connector from configuration
|
||||
func createPostgresConnector(cfg *PostgresConfig) (driver.Connector, error) {
|
||||
// Validate required fields
|
||||
if cfg.Host == "" {
|
||||
return nil, fmt.Errorf("postgres: host is required")
|
||||
}
|
||||
if cfg.User == "" {
|
||||
return nil, fmt.Errorf("postgres: user is required")
|
||||
}
|
||||
if cfg.Name == "" {
|
||||
return nil, fmt.Errorf("postgres: database name is required")
|
||||
}
|
||||
|
||||
// Validate SSLMode
|
||||
validSSLModes := map[string]bool{
|
||||
"disable": true, "allow": true, "prefer": true,
|
||||
"require": true, "verify-ca": true, "verify-full": true,
|
||||
}
|
||||
if cfg.SSLMode != "" && !validSSLModes[cfg.SSLMode] {
|
||||
return nil, fmt.Errorf("postgres: invalid sslmode: %s", cfg.SSLMode)
|
||||
}
|
||||
|
||||
// Build config directly (security: no DSN string with password)
|
||||
connConfig, err := pgx.ParseConfig("")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("postgres: failed to create pgx config: %w", err)
|
||||
}
|
||||
|
||||
connConfig.Host = cfg.Host
|
||||
connConfig.Port = cfg.Port
|
||||
connConfig.User = cfg.User
|
||||
connConfig.Password = cfg.Pass
|
||||
connConfig.Database = cfg.Name
|
||||
|
||||
// Map SSLMode to pgx configuration
|
||||
// Note: pgx uses different SSL handling than libpq
|
||||
// For now, we'll construct a minimal DSN with sslmode for ParseConfig
|
||||
if cfg.SSLMode != "" {
|
||||
// Reconstruct with sslmode only (no password in DSN)
|
||||
dsnWithoutPassword := fmt.Sprintf("host=%s port=%d user=%s dbname=%s sslmode=%s",
|
||||
cfg.Host, cfg.Port, cfg.User, cfg.Name, cfg.SSLMode)
|
||||
connConfig, err = pgx.ParseConfig(dsnWithoutPassword)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("postgres: failed to parse config with sslmode: %w", err)
|
||||
}
|
||||
// Set password separately after parsing
|
||||
connConfig.Password = cfg.Pass
|
||||
}
|
||||
|
||||
return stdlib.GetConnector(*connConfig), nil
|
||||
}
|
||||
|
||||
// createPostgresConnectorFromFlat creates a PostgreSQL connector from flat config format
|
||||
func createPostgresConnectorFromFlat(cfg *Config) (driver.Connector, error) {
|
||||
pgCfg := &PostgresConfig{
|
||||
User: cfg.User,
|
||||
Pass: cfg.Pass,
|
||||
Host: cfg.Host,
|
||||
Port: cfg.Port,
|
||||
Name: cfg.Name,
|
||||
SSLMode: cfg.SSLMode,
|
||||
}
|
||||
|
||||
// Set defaults for PostgreSQL
|
||||
if pgCfg.Host == "" {
|
||||
pgCfg.Host = "localhost"
|
||||
}
|
||||
if pgCfg.Port == 0 {
|
||||
pgCfg.Port = 5432
|
||||
}
|
||||
if pgCfg.SSLMode == "" {
|
||||
pgCfg.SSLMode = "prefer"
|
||||
}
|
||||
|
||||
return createPostgresConnector(pgCfg)
|
||||
}
|
||||
|
120
database/pgdb/CLAUDE.md
Normal file
120
database/pgdb/CLAUDE.md
Normal file
@@ -0,0 +1,120 @@
|
||||
# pgdb - Native PostgreSQL Connection Pool
|
||||
|
||||
Primary package for PostgreSQL connections using native pgx pool (`*pgxpool.Pool`). Provides better performance and PostgreSQL-specific features compared to `database/sql`.
|
||||
|
||||
## Usage
|
||||
|
||||
### Basic Example
|
||||
|
||||
```go
|
||||
import (
|
||||
"context"
|
||||
"go.ntppool.org/common/database/pgdb"
|
||||
)
|
||||
|
||||
func main() {
|
||||
ctx := context.Background()
|
||||
|
||||
// Open pool with default options
|
||||
pool, err := pgdb.OpenPool(ctx, pgdb.DefaultPoolOptions())
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer pool.Close()
|
||||
|
||||
// Use the pool for queries
|
||||
row := pool.QueryRow(ctx, "SELECT version()")
|
||||
var version string
|
||||
row.Scan(&version)
|
||||
}
|
||||
```
|
||||
|
||||
### With Custom Config File
|
||||
|
||||
```go
|
||||
pool, err := pgdb.OpenPoolWithConfigFile(ctx, "/path/to/database.yaml")
|
||||
```
|
||||
|
||||
### With Custom Pool Settings
|
||||
|
||||
```go
|
||||
opts := pgdb.DefaultPoolOptions()
|
||||
opts.MaxConns = 50
|
||||
opts.MinConns = 5
|
||||
opts.MaxConnLifetime = 2 * time.Hour
|
||||
|
||||
pool, err := pgdb.OpenPool(ctx, opts)
|
||||
```
|
||||
|
||||
## Configuration Format
|
||||
|
||||
### Recommended: Nested Format (database.yaml)
|
||||
|
||||
```yaml
|
||||
postgres:
|
||||
host: localhost
|
||||
port: 5432
|
||||
user: myuser
|
||||
pass: mypassword
|
||||
name: mydb
|
||||
sslmode: prefer
|
||||
```
|
||||
|
||||
### Legacy: Flat Format (backward compatible)
|
||||
|
||||
```yaml
|
||||
host: localhost
|
||||
port: 5432
|
||||
user: myuser
|
||||
pass: mypassword
|
||||
name: mydb
|
||||
sslmode: prefer
|
||||
```
|
||||
|
||||
## Configuration Options
|
||||
|
||||
### PoolOptions
|
||||
|
||||
- `ConfigFiles` - List of config file paths to search (default: `database.yaml`, `/vault/secrets/database.yaml`)
|
||||
- `MinConns` - Minimum connections (default: 0)
|
||||
- `MaxConns` - Maximum connections (default: 25)
|
||||
- `MaxConnLifetime` - Connection lifetime (default: 1 hour)
|
||||
- `MaxConnIdleTime` - Idle timeout (default: 30 minutes)
|
||||
- `HealthCheckPeriod` - Health check interval (default: 1 minute)
|
||||
|
||||
### PostgreSQL Config Fields
|
||||
|
||||
- `host` - Database host (required)
|
||||
- `user` - Database user (required)
|
||||
- `pass` - Database password
|
||||
- `name` - Database name (required)
|
||||
- `port` - Port number (default: 5432)
|
||||
- `sslmode` - SSL mode: `disable`, `allow`, `prefer`, `require`, `verify-ca`, `verify-full` (default: `prefer`)
|
||||
|
||||
## Environment Variables
|
||||
|
||||
- `DATABASE_CONFIG_FILE` - Override config file location
|
||||
|
||||
## When to Use
|
||||
|
||||
**Use `pgdb.OpenPool()`** (this package) when:
|
||||
- You need native PostgreSQL features (LISTEN/NOTIFY, COPY, etc.)
|
||||
- You want better performance
|
||||
- You're writing new PostgreSQL code
|
||||
|
||||
**Use `database.OpenDB()`** (sql.DB) when:
|
||||
- You need database-agnostic code
|
||||
- You're using SQLC or other tools that expect `database/sql`
|
||||
- You need to support both MySQL and PostgreSQL
|
||||
|
||||
## Security
|
||||
|
||||
This package avoids password exposure by:
|
||||
1. Never constructing DSN strings with passwords
|
||||
2. Setting passwords separately in pgx config objects
|
||||
3. Validating all configuration before connection
|
||||
|
||||
## See Also
|
||||
|
||||
- `database/` - Generic sql.DB support for MySQL and PostgreSQL
|
||||
- pgx documentation: https://github.com/jackc/pgx
|
64
database/pgdb/config.go
Normal file
64
database/pgdb/config.go
Normal file
@@ -0,0 +1,64 @@
|
||||
package pgdb
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
"go.ntppool.org/common/database"
|
||||
)
|
||||
|
||||
// CreatePoolConfig converts database.PostgresConfig to pgxpool.Config
|
||||
// This is the secure way to create a config without exposing passwords in DSN strings
|
||||
func CreatePoolConfig(cfg *database.PostgresConfig) (*pgxpool.Config, error) {
|
||||
// Validate required fields
|
||||
if cfg.Host == "" {
|
||||
return nil, fmt.Errorf("postgres: host is required")
|
||||
}
|
||||
if cfg.User == "" {
|
||||
return nil, fmt.Errorf("postgres: user is required")
|
||||
}
|
||||
if cfg.Name == "" {
|
||||
return nil, fmt.Errorf("postgres: database name is required")
|
||||
}
|
||||
|
||||
// Validate SSLMode
|
||||
validSSLModes := map[string]bool{
|
||||
"disable": true, "allow": true, "prefer": true,
|
||||
"require": true, "verify-ca": true, "verify-full": true,
|
||||
}
|
||||
if cfg.SSLMode != "" && !validSSLModes[cfg.SSLMode] {
|
||||
return nil, fmt.Errorf("postgres: invalid sslmode: %s", cfg.SSLMode)
|
||||
}
|
||||
|
||||
// Set defaults
|
||||
host := cfg.Host
|
||||
if host == "" {
|
||||
host = "localhost"
|
||||
}
|
||||
|
||||
port := cfg.Port
|
||||
if port == 0 {
|
||||
port = 5432
|
||||
}
|
||||
|
||||
sslmode := cfg.SSLMode
|
||||
if sslmode == "" {
|
||||
sslmode = "prefer"
|
||||
}
|
||||
|
||||
// Build connection string WITHOUT password
|
||||
// We'll set the password separately in the config
|
||||
connString := fmt.Sprintf("host=%s port=%d user=%s dbname=%s sslmode=%s",
|
||||
host, port, cfg.User, cfg.Name, sslmode)
|
||||
|
||||
// Parse the connection string
|
||||
poolConfig, err := pgxpool.ParseConfig(connString)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("postgres: failed to parse connection config: %w", err)
|
||||
}
|
||||
|
||||
// Set password separately (security: never put password in the connection string)
|
||||
poolConfig.ConnConfig.Password = cfg.Pass
|
||||
|
||||
return poolConfig, nil
|
||||
}
|
173
database/pgdb/pool.go
Normal file
173
database/pgdb/pool.go
Normal file
@@ -0,0 +1,173 @@
|
||||
package pgdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
"go.ntppool.org/common/database"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// PoolOptions configures pgxpool connection behavior
|
||||
type PoolOptions struct {
|
||||
// ConfigFiles is a list of config file paths to search for database configuration
|
||||
ConfigFiles []string
|
||||
|
||||
// MinConns is the minimum number of connections in the pool
|
||||
// Default: 0 (no minimum)
|
||||
MinConns int32
|
||||
|
||||
// MaxConns is the maximum number of connections in the pool
|
||||
// Default: 25
|
||||
MaxConns int32
|
||||
|
||||
// MaxConnLifetime is the maximum lifetime of a connection
|
||||
// Default: 1 hour
|
||||
MaxConnLifetime time.Duration
|
||||
|
||||
// MaxConnIdleTime is the maximum idle time of a connection
|
||||
// Default: 30 minutes
|
||||
MaxConnIdleTime time.Duration
|
||||
|
||||
// HealthCheckPeriod is how often to check connection health
|
||||
// Default: 1 minute
|
||||
HealthCheckPeriod time.Duration
|
||||
}
|
||||
|
||||
// DefaultPoolOptions returns sensible defaults for pgxpool
|
||||
func DefaultPoolOptions() PoolOptions {
|
||||
return PoolOptions{
|
||||
ConfigFiles: getConfigFiles(),
|
||||
MinConns: 0,
|
||||
MaxConns: 25,
|
||||
MaxConnLifetime: time.Hour,
|
||||
MaxConnIdleTime: 30 * time.Minute,
|
||||
HealthCheckPeriod: time.Minute,
|
||||
}
|
||||
}
|
||||
|
||||
// OpenPool opens a native pgx connection pool with the specified configuration
|
||||
// This is the primary and recommended way to connect to PostgreSQL
|
||||
func OpenPool(ctx context.Context, options PoolOptions) (*pgxpool.Pool, error) {
|
||||
// Find and read config file
|
||||
pgCfg, err := findAndParseConfig(options.ConfigFiles)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Create pool config from PostgreSQL config
|
||||
poolConfig, err := CreatePoolConfig(pgCfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Apply pool-specific settings
|
||||
poolConfig.MinConns = options.MinConns
|
||||
poolConfig.MaxConns = options.MaxConns
|
||||
poolConfig.MaxConnLifetime = options.MaxConnLifetime
|
||||
poolConfig.MaxConnIdleTime = options.MaxConnIdleTime
|
||||
poolConfig.HealthCheckPeriod = options.HealthCheckPeriod
|
||||
|
||||
// Create the pool
|
||||
pool, err := pgxpool.NewWithConfig(ctx, poolConfig)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create connection pool: %w", err)
|
||||
}
|
||||
|
||||
// Test the connection
|
||||
if err := pool.Ping(ctx); err != nil {
|
||||
pool.Close()
|
||||
return nil, fmt.Errorf("failed to ping database: %w", err)
|
||||
}
|
||||
|
||||
return pool, nil
|
||||
}
|
||||
|
||||
// OpenPoolWithConfigFile opens a connection pool using an explicit config file path
|
||||
// This is a convenience function for when you have a specific config file
|
||||
func OpenPoolWithConfigFile(ctx context.Context, configFile string) (*pgxpool.Pool, error) {
|
||||
options := DefaultPoolOptions()
|
||||
options.ConfigFiles = []string{configFile}
|
||||
return OpenPool(ctx, options)
|
||||
}
|
||||
|
||||
// findAndParseConfig searches for and parses the first existing config file
|
||||
func findAndParseConfig(configFiles []string) (*database.PostgresConfig, error) {
|
||||
var firstErr error
|
||||
|
||||
for _, configFile := range configFiles {
|
||||
if configFile == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if file exists
|
||||
if _, err := os.Stat(configFile); err != nil {
|
||||
if firstErr == nil {
|
||||
firstErr = err
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Try to read and parse the file
|
||||
pgCfg, err := parseConfigFile(configFile)
|
||||
if err != nil {
|
||||
if firstErr == nil {
|
||||
firstErr = err
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
return pgCfg, nil
|
||||
}
|
||||
|
||||
if firstErr != nil {
|
||||
return nil, fmt.Errorf("no config file found: %w", firstErr)
|
||||
}
|
||||
return nil, fmt.Errorf("no valid config files provided")
|
||||
}
|
||||
|
||||
// parseConfigFile reads and parses a YAML config file
|
||||
func parseConfigFile(configFile string) (*database.PostgresConfig, error) {
|
||||
file, err := os.Open(configFile)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open config file: %w", err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
dec := yaml.NewDecoder(file)
|
||||
cfg := database.Config{}
|
||||
|
||||
if err := dec.Decode(&cfg); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode config: %w", err)
|
||||
}
|
||||
|
||||
// Extract PostgreSQL config
|
||||
if cfg.Postgres != nil {
|
||||
return cfg.Postgres, nil
|
||||
}
|
||||
|
||||
// Check for legacy flat format
|
||||
if cfg.User != "" && cfg.Name != "" {
|
||||
return &database.PostgresConfig{
|
||||
User: cfg.User,
|
||||
Pass: cfg.Pass,
|
||||
Host: cfg.Host,
|
||||
Port: cfg.Port,
|
||||
Name: cfg.Name,
|
||||
SSLMode: cfg.SSLMode,
|
||||
}, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("no PostgreSQL configuration found in %s", configFile)
|
||||
}
|
||||
|
||||
// getConfigFiles returns the list of config files to search
|
||||
func getConfigFiles() []string {
|
||||
if configFile := os.Getenv("DATABASE_CONFIG_FILE"); configFile != "" {
|
||||
return []string{configFile}
|
||||
}
|
||||
return []string{"database.yaml", "/vault/secrets/database.yaml"}
|
||||
}
|
151
database/pgdb/pool_test.go
Normal file
151
database/pgdb/pool_test.go
Normal file
@@ -0,0 +1,151 @@
|
||||
package pgdb
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"go.ntppool.org/common/database"
|
||||
)
|
||||
|
||||
func TestCreatePoolConfig(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
cfg *database.PostgresConfig
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid config",
|
||||
cfg: &database.PostgresConfig{
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
User: "testuser",
|
||||
Pass: "testpass",
|
||||
Name: "testdb",
|
||||
SSLMode: "require",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "valid config with defaults",
|
||||
cfg: &database.PostgresConfig{
|
||||
Host: "localhost",
|
||||
User: "testuser",
|
||||
Pass: "testpass",
|
||||
Name: "testdb",
|
||||
// Port and SSLMode will use defaults
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "missing host",
|
||||
cfg: &database.PostgresConfig{
|
||||
User: "testuser",
|
||||
Pass: "testpass",
|
||||
Name: "testdb",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "missing user",
|
||||
cfg: &database.PostgresConfig{
|
||||
Host: "localhost",
|
||||
Pass: "testpass",
|
||||
Name: "testdb",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "missing database name",
|
||||
cfg: &database.PostgresConfig{
|
||||
Host: "localhost",
|
||||
User: "testuser",
|
||||
Pass: "testpass",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid sslmode",
|
||||
cfg: &database.PostgresConfig{
|
||||
Host: "localhost",
|
||||
User: "testuser",
|
||||
Pass: "testpass",
|
||||
Name: "testdb",
|
||||
SSLMode: "invalid",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
poolCfg, err := CreatePoolConfig(tt.cfg)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("CreatePoolConfig() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !tt.wantErr && poolCfg == nil {
|
||||
t.Error("CreatePoolConfig() returned nil config without error")
|
||||
}
|
||||
if !tt.wantErr && poolCfg != nil {
|
||||
// Verify config fields are set correctly
|
||||
if poolCfg.ConnConfig.Host != tt.cfg.Host && tt.cfg.Host != "" {
|
||||
t.Errorf("Expected Host=%s, got %s", tt.cfg.Host, poolCfg.ConnConfig.Host)
|
||||
}
|
||||
if poolCfg.ConnConfig.User != tt.cfg.User {
|
||||
t.Errorf("Expected User=%s, got %s", tt.cfg.User, poolCfg.ConnConfig.User)
|
||||
}
|
||||
if poolCfg.ConnConfig.Password != tt.cfg.Pass {
|
||||
t.Errorf("Expected Password to be set correctly")
|
||||
}
|
||||
if poolCfg.ConnConfig.Database != tt.cfg.Name {
|
||||
t.Errorf("Expected Database=%s, got %s", tt.cfg.Name, poolCfg.ConnConfig.Database)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultPoolOptions(t *testing.T) {
|
||||
opts := DefaultPoolOptions()
|
||||
|
||||
// Verify expected defaults
|
||||
if opts.MinConns != 0 {
|
||||
t.Errorf("Expected MinConns=0, got %d", opts.MinConns)
|
||||
}
|
||||
if opts.MaxConns != 25 {
|
||||
t.Errorf("Expected MaxConns=25, got %d", opts.MaxConns)
|
||||
}
|
||||
if opts.MaxConnLifetime != time.Hour {
|
||||
t.Errorf("Expected MaxConnLifetime=1h, got %v", opts.MaxConnLifetime)
|
||||
}
|
||||
if opts.MaxConnIdleTime != 30*time.Minute {
|
||||
t.Errorf("Expected MaxConnIdleTime=30m, got %v", opts.MaxConnIdleTime)
|
||||
}
|
||||
if opts.HealthCheckPeriod != time.Minute {
|
||||
t.Errorf("Expected HealthCheckPeriod=1m, got %v", opts.HealthCheckPeriod)
|
||||
}
|
||||
if len(opts.ConfigFiles) == 0 {
|
||||
t.Error("Expected ConfigFiles to be non-empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreatePoolConfigDefaults(t *testing.T) {
|
||||
// Test that defaults are applied correctly
|
||||
cfg := &database.PostgresConfig{
|
||||
Host: "localhost",
|
||||
User: "testuser",
|
||||
Pass: "testpass",
|
||||
Name: "testdb",
|
||||
// Port and SSLMode not set
|
||||
}
|
||||
|
||||
poolCfg, err := CreatePoolConfig(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("CreatePoolConfig() failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify defaults were applied
|
||||
if poolCfg.ConnConfig.Port != 5432 {
|
||||
t.Errorf("Expected default Port=5432, got %d", poolCfg.ConnConfig.Port)
|
||||
}
|
||||
}
|
28
go.mod
28
go.mod
@@ -1,10 +1,11 @@
|
||||
module go.ntppool.org/common
|
||||
|
||||
go 1.23.5
|
||||
go 1.24.0
|
||||
|
||||
require (
|
||||
github.com/abh/certman v0.4.0
|
||||
github.com/go-sql-driver/mysql v1.9.3
|
||||
github.com/jackc/pgx/v5 v5.7.6
|
||||
github.com/labstack/echo-contrib v0.17.2
|
||||
github.com/labstack/echo/v4 v4.13.3
|
||||
github.com/oklog/ulid/v2 v2.1.0
|
||||
@@ -32,9 +33,9 @@ require (
|
||||
go.opentelemetry.io/otel/sdk/log v0.9.0
|
||||
go.opentelemetry.io/otel/sdk/metric v1.33.0
|
||||
go.opentelemetry.io/otel/trace v1.33.0
|
||||
golang.org/x/mod v0.22.0
|
||||
golang.org/x/net v0.33.0
|
||||
golang.org/x/sync v0.10.0
|
||||
golang.org/x/mod v0.28.0
|
||||
golang.org/x/net v0.44.0
|
||||
golang.org/x/sync v0.17.0
|
||||
google.golang.org/grpc v1.69.2
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
)
|
||||
@@ -44,22 +45,29 @@ require (
|
||||
github.com/beorn7/perks v1.0.1 // indirect
|
||||
github.com/cenkalti/backoff/v4 v4.3.0 // indirect
|
||||
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||
github.com/dmarkham/enumer v1.6.1 // indirect
|
||||
github.com/fsnotify/fsnotify v1.8.0 // indirect
|
||||
github.com/go-logr/logr v1.4.2 // indirect
|
||||
github.com/go-logr/stdr v1.2.2 // indirect
|
||||
github.com/google/uuid v1.6.0 // indirect
|
||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.25.1 // indirect
|
||||
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
||||
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
|
||||
github.com/jackc/puddle/v2 v2.2.2 // indirect
|
||||
github.com/klauspost/compress v1.17.11 // indirect
|
||||
github.com/labstack/gommon v0.4.2 // indirect
|
||||
github.com/masaushi/accessory v0.6.0 // indirect
|
||||
github.com/mattn/go-colorable v0.1.13 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
|
||||
github.com/pascaldekloe/name v1.0.0 // indirect
|
||||
github.com/pierrec/lz4/v4 v4.1.22 // indirect
|
||||
github.com/pkg/errors v0.9.1 // indirect
|
||||
github.com/prometheus/common v0.61.0 // indirect
|
||||
github.com/prometheus/procfs v0.15.1 // indirect
|
||||
github.com/samber/lo v1.47.0 // indirect
|
||||
github.com/spf13/afero v1.15.0 // indirect
|
||||
github.com/spf13/pflag v1.0.5 // indirect
|
||||
github.com/valyala/bytebufferpool v1.0.0 // indirect
|
||||
github.com/valyala/fasttemplate v1.2.2 // indirect
|
||||
@@ -70,11 +78,17 @@ require (
|
||||
go.opentelemetry.io/otel/exporters/stdout/stdoutmetric v1.33.0 // indirect
|
||||
go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.33.0 // indirect
|
||||
go.opentelemetry.io/proto/otlp v1.4.0 // indirect
|
||||
golang.org/x/crypto v0.31.0 // indirect
|
||||
golang.org/x/sys v0.28.0 // indirect
|
||||
golang.org/x/text v0.21.0 // indirect
|
||||
golang.org/x/crypto v0.42.0 // indirect
|
||||
golang.org/x/sys v0.36.0 // indirect
|
||||
golang.org/x/text v0.29.0 // indirect
|
||||
golang.org/x/time v0.8.0 // indirect
|
||||
golang.org/x/tools v0.37.0 // indirect
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20241223144023-3abc09e42ca8 // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20241223144023-3abc09e42ca8 // indirect
|
||||
google.golang.org/protobuf v1.36.1 // indirect
|
||||
)
|
||||
|
||||
tool (
|
||||
github.com/dmarkham/enumer
|
||||
github.com/masaushi/accessory
|
||||
)
|
||||
|
46
go.sum
46
go.sum
@@ -4,6 +4,8 @@ github.com/abh/certman v0.4.0 h1:XHoDtb0YyRQPclaHMrBDlKTVZpNjTK6vhB0S3Bd/Sbs=
|
||||
github.com/abh/certman v0.4.0/go.mod h1:x8QhpKVZifmV1Hdiwdg9gLo2GMPAxezz1s3zrVnPs+I=
|
||||
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
|
||||
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
|
||||
github.com/bradleyjkemp/cupaloy/v2 v2.8.0 h1:any4BmKE+jGIaMpnU8YgH/I2LPiLBufr6oMMlVBbn9M=
|
||||
github.com/bradleyjkemp/cupaloy/v2 v2.8.0/go.mod h1:bm7JXdkRd4BHJk9HpwqAI8BoAY1lps46Enkdqw6aRX0=
|
||||
github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8=
|
||||
github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE=
|
||||
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
|
||||
@@ -12,6 +14,8 @@ github.com/cpuguy83/go-md2man/v2 v2.0.4/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46t
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/dmarkham/enumer v1.6.1 h1:aSc9awYtZL07TUueWs40QcHtxTvHTAwG0EqrNsK45w4=
|
||||
github.com/dmarkham/enumer v1.6.1/go.mod h1:yixql+kDDQRYqcuBM2n9Vlt7NoT9ixgXhaXry8vmRg8=
|
||||
github.com/fsnotify/fsnotify v1.8.0 h1:dAwr6QBTBZIkG8roQaJjGof0pp0EeF+tNV7YBP3F/8M=
|
||||
github.com/fsnotify/fsnotify v1.8.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
|
||||
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
|
||||
@@ -31,6 +35,14 @@ github.com/grpc-ecosystem/grpc-gateway/v2 v2.25.1 h1:VNqngBF40hVlDloBruUehVYC3Ar
|
||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.25.1/go.mod h1:RBRO7fro65R6tjKzYgLAFo0t1QEXY1Dp+i/bvpRiqiQ=
|
||||
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
|
||||
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
|
||||
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
|
||||
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
|
||||
github.com/jackc/pgx/v5 v5.7.6 h1:rWQc5FwZSPX58r1OQmkuaNicxdmExaEz5A2DO2hUuTk=
|
||||
github.com/jackc/pgx/v5 v5.7.6/go.mod h1:aruU7o91Tc2q2cFp5h4uP3f6ztExVpyVv88Xl/8Vl8M=
|
||||
github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
|
||||
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
|
||||
github.com/klauspost/compress v1.15.9/go.mod h1:PhcZ0MbTNciWF3rruxRgKxI5NkcHHrHUDtV4Yw2GlzU=
|
||||
github.com/klauspost/compress v1.17.11 h1:In6xLpyWOi1+C7tXUUWv2ot1QvBjxevKAaI6IXrJmUc=
|
||||
github.com/klauspost/compress v1.17.11/go.mod h1:pMDklpSncoRMuLFrf1W9Ss9KT+0rH90U12bZKk7uwG0=
|
||||
@@ -46,6 +58,8 @@ github.com/labstack/echo/v4 v4.13.3 h1:pwhpCPrTl5qry5HRdM5FwdXnhXSLSY+WE+YQSeCaa
|
||||
github.com/labstack/echo/v4 v4.13.3/go.mod h1:o90YNEeQWjDozo584l7AwhJMHN0bOC4tAfg+Xox9q5g=
|
||||
github.com/labstack/gommon v0.4.2 h1:F8qTUNXgG1+6WQmqoUWnz8WiEU60mXVVw0P4ht1WRA0=
|
||||
github.com/labstack/gommon v0.4.2/go.mod h1:QlUFxVM+SNXhDL/Z7YhocGIBYOiwB0mXm1+1bAPHPyU=
|
||||
github.com/masaushi/accessory v0.6.0 h1:HYAzxkuhfvlbaQwinxXTxsSPbFabAnwHt8K6I/DvNBU=
|
||||
github.com/masaushi/accessory v0.6.0/go.mod h1:8GZMgq3wcIapVZWt7VVQCh5+onPc/8gJeHb8WRXezvQ=
|
||||
github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA=
|
||||
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
|
||||
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
|
||||
@@ -55,6 +69,8 @@ github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq
|
||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
|
||||
github.com/oklog/ulid/v2 v2.1.0 h1:+9lhoxAP56we25tyYETBBY1YLA2SaoLvUFgrP2miPJU=
|
||||
github.com/oklog/ulid/v2 v2.1.0/go.mod h1:rcEKHmBBKfef9DhnvX7y1HZBYxjXb0cP5ExxNsTT1QQ=
|
||||
github.com/pascaldekloe/name v1.0.0 h1:n7LKFgHixETzxpRv2R77YgPUFo85QHGZKrdaYm7eY5U=
|
||||
github.com/pascaldekloe/name v1.0.0/go.mod h1:Z//MfYJnH4jVpQ9wkclwu2I2MkHmXTlT9wR5UZScttM=
|
||||
github.com/pborman/getopt v0.0.0-20170112200414-7148bc3a4c30/go.mod h1:85jBQOZwpVEaDAr341tbn15RS4fCAsIst0qp7i8ex1o=
|
||||
github.com/pierrec/lz4/v4 v4.1.15/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4=
|
||||
github.com/pierrec/lz4/v4 v4.1.22 h1:cKFw6uJDK+/gfw5BcDL0JL5aBsAFdsIT18eRtLj7VIU=
|
||||
@@ -84,12 +100,16 @@ github.com/samber/slog-multi v1.2.4 h1:k9x3JAWKJFPKffx+oXZ8TasaNuorIW4tG+TXxkt6R
|
||||
github.com/samber/slog-multi v1.2.4/go.mod h1:ACuZ5B6heK57TfMVkVknN2UZHoFfjCwRxR0Q2OXKHlo=
|
||||
github.com/segmentio/kafka-go v0.4.47 h1:IqziR4pA3vrZq7YdRxaT3w1/5fvIH5qpCwstUanQQB0=
|
||||
github.com/segmentio/kafka-go v0.4.47/go.mod h1:HjF6XbOKh0Pjlkr5GVZxt6CsjjwnmhVOfURM5KMd8qg=
|
||||
github.com/spf13/afero v1.15.0 h1:b/YBCLWAJdFWJTN9cLhiXXcD7mzKn9Dm86dNnfyQw1I=
|
||||
github.com/spf13/afero v1.15.0/go.mod h1:NC2ByUVxtQs4b3sIUphxK0NioZnmxgyCrfzeuq8lxMg=
|
||||
github.com/spf13/cobra v1.8.1 h1:e5/vxKd/rZsfSJMUX1agtjeTDf+qv1/JdBF8gg5k9ZM=
|
||||
github.com/spf13/cobra v1.8.1/go.mod h1:wHxEcudfqmLYa8iTfL+OuZPbBZkmvliBWKIezN3kD9Y=
|
||||
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
|
||||
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
|
||||
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
|
||||
@@ -160,25 +180,25 @@ go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||
golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4=
|
||||
golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U=
|
||||
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
|
||||
golang.org/x/crypto v0.42.0 h1:chiH31gIWm57EkTXpwnqf8qeuMUi0yekh6mT2AvFlqI=
|
||||
golang.org/x/crypto v0.42.0/go.mod h1:4+rDnOTJhQCx2q7/j6rAN5XDw8kPjeaXEUR2eL94ix8=
|
||||
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
|
||||
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
|
||||
golang.org/x/mod v0.22.0 h1:D4nJWe9zXqHOmWqj4VMOJhvzj7bEZg4wEYa759z1pH4=
|
||||
golang.org/x/mod v0.22.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY=
|
||||
golang.org/x/mod v0.28.0 h1:gQBtGhjxykdjY9YhZpSlZIsbnaE2+PgjfLWUQTnoZ1U=
|
||||
golang.org/x/mod v0.28.0/go.mod h1:yfB/L0NOf/kmEbXjzCPOx1iK1fRutOydrCMsqRhEBxI=
|
||||
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
|
||||
golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
|
||||
golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
|
||||
golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE=
|
||||
golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I=
|
||||
golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4=
|
||||
golang.org/x/net v0.44.0 h1:evd8IRDyfNBMBTTY5XRF1vaZlD+EmWx6x8PkhR04H/I=
|
||||
golang.org/x/net v0.44.0/go.mod h1:ECOoLqd5U3Lhyeyo/QDCEVQ4sNgYsqvCZ722XogGieY=
|
||||
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ=
|
||||
golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug=
|
||||
golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
@@ -189,8 +209,8 @@ golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA=
|
||||
golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.36.0 h1:KVRy2GtZBrk1cBYA7MKu5bEZFxQk4NIDV6RLVcC8o0k=
|
||||
golang.org/x/sys v0.36.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
|
||||
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
|
||||
@@ -203,14 +223,16 @@ golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ=
|
||||
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
|
||||
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
|
||||
golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
|
||||
golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo=
|
||||
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
|
||||
golang.org/x/text v0.29.0 h1:1neNs90w9YzJ9BocxfsQNHKuAT4pkghyXc4nhZ6sJvk=
|
||||
golang.org/x/text v0.29.0/go.mod h1:7MhJOA9CD2qZyOKYazxdYMF85OwPdEr9jTtBpO7ydH4=
|
||||
golang.org/x/time v0.8.0 h1:9i3RxcPv3PZnitoVGMPDKZSq1xW1gK1Xy3ArNOGZfEg=
|
||||
golang.org/x/time v0.8.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
|
||||
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
|
||||
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
|
||||
golang.org/x/tools v0.37.0 h1:DVSRzp7FwePZW356yEAChSdNcQo6Nsp+fex1SUW09lE=
|
||||
golang.org/x/tools v0.37.0/go.mod h1:MBN5QPQtLMHVdvsbtarmTNukZDdgwdwlO5qGacAzF0w=
|
||||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20241223144023-3abc09e42ca8 h1:st3LcW/BPi75W4q1jJTEor/QWwbNlPlDG0JTn6XhZu0=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20241223144023-3abc09e42ca8/go.mod h1:klhJGKFyG8Tn50enBn7gizg4nXGXJ+jqEREdCWaPcV4=
|
||||
|
@@ -2,7 +2,7 @@
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
go install github.com/goreleaser/goreleaser/v2@v2.11.0
|
||||
go install github.com/goreleaser/goreleaser/v2@v2.12.3
|
||||
|
||||
if [ ! -z "${harbor_username:-}" ]; then
|
||||
DOCKER_FILE=~/.docker/config.json
|
||||
|
@@ -1,9 +1,10 @@
|
||||
// Package fastlyxff provides Fastly CDN IP range management for trusted proxy handling.
|
||||
//
|
||||
// This package parses Fastly's public IP ranges JSON file and generates Echo framework
|
||||
// trust options for proper client IP extraction from X-Forwarded-For headers.
|
||||
// It's designed specifically for services deployed behind Fastly's CDN that need
|
||||
// to identify real client IPs for logging, rate limiting, and security purposes.
|
||||
// This package parses Fastly's public IP ranges JSON file and provides middleware
|
||||
// for both Echo framework and standard net/http for proper client IP extraction
|
||||
// from X-Forwarded-For headers. It's designed specifically for services deployed
|
||||
// behind Fastly's CDN that need to identify real client IPs for logging, rate
|
||||
// limiting, and security purposes.
|
||||
//
|
||||
// Fastly publishes their edge server IP ranges in a JSON format that this package
|
||||
// consumes to automatically configure trusted proxy ranges. This ensures that
|
||||
@@ -14,8 +15,55 @@
|
||||
// - Automatic parsing of Fastly's IP ranges JSON format
|
||||
// - Support for both IPv4 and IPv6 address ranges
|
||||
// - Echo framework integration via TrustOption generation
|
||||
// - Standard net/http middleware support
|
||||
// - CIDR notation parsing and validation
|
||||
//
|
||||
// # Echo Framework Usage
|
||||
//
|
||||
// fastlyRanges, err := fastlyxff.New("fastly.json")
|
||||
// if err != nil {
|
||||
// return err
|
||||
// }
|
||||
// options, err := fastlyRanges.EchoTrustOption()
|
||||
// if err != nil {
|
||||
// return err
|
||||
// }
|
||||
// e.IPExtractor = echo.ExtractIPFromXFFHeader(options...)
|
||||
//
|
||||
// # Net/HTTP Usage
|
||||
//
|
||||
// fastlyRanges, err := fastlyxff.New("fastly.json")
|
||||
// if err != nil {
|
||||
// return err
|
||||
// }
|
||||
// middleware := fastlyRanges.HTTPMiddleware()
|
||||
//
|
||||
// handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// // Both methods work - middleware updates r.RemoteAddr (with port 0) and stores in context
|
||||
// realIP := fastlyxff.GetRealIP(r) // Preferred method
|
||||
// // OR: host, _, _ := net.SplitHostPort(r.RemoteAddr) // Direct access (port will be "0")
|
||||
// fmt.Fprintf(w, "Real IP: %s\n", realIP)
|
||||
// })
|
||||
//
|
||||
// http.ListenAndServe(":8080", middleware(handler))
|
||||
//
|
||||
// # Net/HTTP with Additional Trusted Ranges
|
||||
//
|
||||
// fastlyRanges, err := fastlyxff.New("fastly.json")
|
||||
// if err != nil {
|
||||
// return err
|
||||
// }
|
||||
//
|
||||
// // Add custom trusted CIDRs (e.g., internal load balancers)
|
||||
// // Note: For Echo framework, use the ekko package for additional ranges
|
||||
// err = fastlyRanges.AddTrustedCIDR("10.0.0.0/8")
|
||||
// if err != nil {
|
||||
// return err
|
||||
// }
|
||||
//
|
||||
// middleware := fastlyRanges.HTTPMiddleware()
|
||||
// handler := middleware(yourHandler)
|
||||
//
|
||||
// The JSON file typically contains IP ranges in this format:
|
||||
//
|
||||
// {
|
||||
@@ -25,29 +73,36 @@
|
||||
package fastlyxff
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
)
|
||||
|
||||
// FastlyXFF represents Fastly's published IP ranges for their CDN edge servers.
|
||||
// This structure matches the JSON format provided by Fastly for their public IP ranges.
|
||||
// It contains separate lists for IPv4 and IPv6 CIDR ranges.
|
||||
// It contains separate lists for IPv4 and IPv6 CIDR ranges, plus additional trusted CIDRs.
|
||||
type FastlyXFF struct {
|
||||
IPv4 []string `json:"addresses"` // IPv4 CIDR ranges (e.g., "23.235.32.0/20")
|
||||
IPv6 []string `json:"ipv6_addresses"` // IPv6 CIDR ranges (e.g., "2a04:4e40::/32")
|
||||
IPv4 []string `json:"addresses"` // IPv4 CIDR ranges (e.g., "23.235.32.0/20")
|
||||
IPv6 []string `json:"ipv6_addresses"` // IPv6 CIDR ranges (e.g., "2a04:4e40::/32")
|
||||
extraCIDRs []string // Additional trusted CIDRs added via AddTrustedCIDR
|
||||
}
|
||||
|
||||
// TrustedNets holds parsed network prefixes for efficient IP range checking.
|
||||
// This type is currently unused but reserved for future optimizations
|
||||
// where frequent IP range lookups might benefit from pre-parsed prefixes.
|
||||
type TrustedNets struct {
|
||||
prefixes []netip.Prefix // Parsed network prefixes for efficient lookups
|
||||
}
|
||||
|
||||
// contextKey is used for storing the real client IP in request context
|
||||
type contextKey string
|
||||
|
||||
const realIPKey contextKey = "fastly-real-ip"
|
||||
|
||||
// New loads and parses Fastly IP ranges from a JSON file.
|
||||
// The file should contain Fastly's published IP ranges in their standard JSON format.
|
||||
//
|
||||
@@ -100,3 +155,116 @@ func (xff *FastlyXFF) EchoTrustOption() ([]echo.TrustOption, error) {
|
||||
|
||||
return ranges, nil
|
||||
}
|
||||
|
||||
// AddTrustedCIDR adds an additional CIDR to the list of trusted proxies.
|
||||
// This allows trusting proxies beyond Fastly's published ranges.
|
||||
// The cidr parameter must be a valid CIDR notation (e.g., "10.0.0.0/8", "192.168.1.0/24").
|
||||
// Returns an error if the CIDR format is invalid.
|
||||
func (xff *FastlyXFF) AddTrustedCIDR(cidr string) error {
|
||||
// Validate CIDR format
|
||||
_, _, err := net.ParseCIDR(cidr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Add to extra CIDRs
|
||||
xff.extraCIDRs = append(xff.extraCIDRs, cidr)
|
||||
return nil
|
||||
}
|
||||
|
||||
// isTrustedProxy checks if the given IP address belongs to Fastly's trusted IP ranges
|
||||
// or any additional CIDRs added via AddTrustedCIDR.
|
||||
func (xff *FastlyXFF) isTrustedProxy(ip string) bool {
|
||||
addr, err := netip.ParseAddr(ip)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check all IPv4 and IPv6 ranges (Fastly + additional)
|
||||
allRanges := append(append(xff.IPv4, xff.IPv6...), xff.extraCIDRs...)
|
||||
for _, s := range allRanges {
|
||||
_, cidr, err := net.ParseCIDR(s)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if cidr.Contains(net.IP(addr.AsSlice())) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// extractRealIP extracts the real client IP from X-Forwarded-For header.
|
||||
// It returns the rightmost IP that is not from a trusted Fastly proxy.
|
||||
func (xff *FastlyXFF) extractRealIP(r *http.Request) string {
|
||||
// Get the immediate peer IP
|
||||
host, _, err := net.SplitHostPort(r.RemoteAddr)
|
||||
if err != nil {
|
||||
host = r.RemoteAddr
|
||||
}
|
||||
|
||||
// If the immediate peer is not a trusted Fastly proxy, return it
|
||||
if !xff.isTrustedProxy(host) {
|
||||
return host
|
||||
}
|
||||
|
||||
// Check X-Forwarded-For header
|
||||
xff_header := r.Header.Get("X-Forwarded-For")
|
||||
if xff_header == "" {
|
||||
return host
|
||||
}
|
||||
|
||||
// Parse comma-separated IP list
|
||||
ips := strings.Split(xff_header, ",")
|
||||
if len(ips) == 0 {
|
||||
return host
|
||||
}
|
||||
|
||||
// Find the leftmost IP that is not from a trusted proxy
|
||||
// This represents the original client IP
|
||||
for i := 0; i < len(ips); i++ {
|
||||
ip := strings.TrimSpace(ips[i])
|
||||
if ip != "" && !xff.isTrustedProxy(ip) {
|
||||
return ip
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback to the immediate peer
|
||||
return host
|
||||
}
|
||||
|
||||
// HTTPMiddleware returns a net/http middleware that extracts real client IP
|
||||
// from X-Forwarded-For headers when the request comes from trusted Fastly proxies.
|
||||
// The real IP is stored in the request context and also updates r.RemoteAddr
|
||||
// with port 0 (since the original port is from the proxy, not the real client).
|
||||
func (xff *FastlyXFF) HTTPMiddleware() func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
realIP := xff.extractRealIP(r)
|
||||
|
||||
// Store in context for GetRealIP function
|
||||
ctx := context.WithValue(r.Context(), realIPKey, realIP)
|
||||
r = r.WithContext(ctx)
|
||||
|
||||
// Update RemoteAddr to be consistent with extracted IP
|
||||
// Use port 0 since the original port is from the proxy, not the real client
|
||||
r.RemoteAddr = net.JoinHostPort(realIP, "0")
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// GetRealIP retrieves the real client IP from the request context.
|
||||
// This should be used after the HTTPMiddleware has processed the request.
|
||||
// Returns the remote address if no real IP was extracted.
|
||||
func GetRealIP(r *http.Request) string {
|
||||
if ip, ok := r.Context().Value(realIPKey).(string); ok {
|
||||
return ip
|
||||
}
|
||||
host, _, err := net.SplitHostPort(r.RemoteAddr)
|
||||
if err != nil {
|
||||
return r.RemoteAddr
|
||||
}
|
||||
return host
|
||||
}
|
||||
|
@@ -1,6 +1,11 @@
|
||||
package fastlyxff
|
||||
|
||||
import "testing"
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestFastlyIPRanges(t *testing.T) {
|
||||
fastlyxff, err := New("fastly.json")
|
||||
@@ -18,3 +23,334 @@ func TestFastlyIPRanges(t *testing.T) {
|
||||
t.Fail()
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTPMiddleware(t *testing.T) {
|
||||
// Create a test FastlyXFF instance with known IP ranges
|
||||
xff := &FastlyXFF{
|
||||
IPv4: []string{"192.0.2.0/24", "203.0.113.0/24"},
|
||||
IPv6: []string{"2001:db8::/32"},
|
||||
}
|
||||
|
||||
middleware := xff.HTTPMiddleware()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
remoteAddr string
|
||||
xForwardedFor string
|
||||
expectedRealIP string
|
||||
}{
|
||||
{
|
||||
name: "direct connection",
|
||||
remoteAddr: "198.51.100.1:12345",
|
||||
xForwardedFor: "",
|
||||
expectedRealIP: "198.51.100.1",
|
||||
},
|
||||
{
|
||||
name: "trusted proxy with XFF",
|
||||
remoteAddr: "192.0.2.1:80",
|
||||
xForwardedFor: "198.51.100.1",
|
||||
expectedRealIP: "198.51.100.1",
|
||||
},
|
||||
{
|
||||
name: "trusted proxy with multiple XFF",
|
||||
remoteAddr: "192.0.2.1:80",
|
||||
xForwardedFor: "198.51.100.1, 203.0.113.1",
|
||||
expectedRealIP: "198.51.100.1",
|
||||
},
|
||||
{
|
||||
name: "untrusted proxy ignored",
|
||||
remoteAddr: "198.51.100.2:80",
|
||||
xForwardedFor: "10.0.0.1",
|
||||
expectedRealIP: "198.51.100.2",
|
||||
},
|
||||
{
|
||||
name: "IPv6 trusted proxy",
|
||||
remoteAddr: "[2001:db8::1]:80",
|
||||
xForwardedFor: "198.51.100.1",
|
||||
expectedRealIP: "198.51.100.1",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Create test handler that captures both GetRealIP and r.RemoteAddr
|
||||
var capturedRealIP, capturedRemoteAddr string
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
capturedRealIP = GetRealIP(r)
|
||||
capturedRemoteAddr = r.RemoteAddr
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
// Create request with middleware
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
req.RemoteAddr = tt.remoteAddr
|
||||
if tt.xForwardedFor != "" {
|
||||
req.Header.Set("X-Forwarded-For", tt.xForwardedFor)
|
||||
}
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
middleware(handler).ServeHTTP(rr, req)
|
||||
|
||||
// Test GetRealIP function
|
||||
if capturedRealIP != tt.expectedRealIP {
|
||||
t.Errorf("GetRealIP: expected %s, got %s", tt.expectedRealIP, capturedRealIP)
|
||||
}
|
||||
|
||||
// Test that r.RemoteAddr is updated with real IP and port 0
|
||||
// (since the original port is from the proxy, not the real client)
|
||||
expectedRemoteAddr := net.JoinHostPort(tt.expectedRealIP, "0")
|
||||
if capturedRemoteAddr != expectedRemoteAddr {
|
||||
t.Errorf("RemoteAddr: expected %s, got %s", expectedRemoteAddr, capturedRemoteAddr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsTrustedProxy(t *testing.T) {
|
||||
xff := &FastlyXFF{
|
||||
IPv4: []string{"192.0.2.0/24", "203.0.113.0/24"},
|
||||
IPv6: []string{"2001:db8::/32"},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
ip string
|
||||
expected bool
|
||||
}{
|
||||
{"192.0.2.1", true},
|
||||
{"192.0.2.255", true},
|
||||
{"203.0.113.1", true},
|
||||
{"192.0.3.1", false},
|
||||
{"198.51.100.1", false},
|
||||
{"2001:db8::1", true},
|
||||
{"2001:db8:ffff::1", true},
|
||||
{"2001:db9::1", false},
|
||||
{"invalid-ip", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.ip, func(t *testing.T) {
|
||||
result := xff.isTrustedProxy(tt.ip)
|
||||
if result != tt.expected {
|
||||
t.Errorf("isTrustedProxy(%s) = %v, expected %v", tt.ip, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractRealIP(t *testing.T) {
|
||||
xff := &FastlyXFF{
|
||||
IPv4: []string{"192.0.2.0/24"},
|
||||
IPv6: []string{"2001:db8::/32"},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
remoteAddr string
|
||||
xForwardedFor string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "no XFF header",
|
||||
remoteAddr: "198.51.100.1:12345",
|
||||
xForwardedFor: "",
|
||||
expected: "198.51.100.1",
|
||||
},
|
||||
{
|
||||
name: "trusted proxy with single IP",
|
||||
remoteAddr: "192.0.2.1:80",
|
||||
xForwardedFor: "198.51.100.1",
|
||||
expected: "198.51.100.1",
|
||||
},
|
||||
{
|
||||
name: "trusted proxy with multiple IPs",
|
||||
remoteAddr: "192.0.2.1:80",
|
||||
xForwardedFor: "198.51.100.1, 203.0.113.5",
|
||||
expected: "198.51.100.1",
|
||||
},
|
||||
{
|
||||
name: "untrusted proxy",
|
||||
remoteAddr: "198.51.100.1:80",
|
||||
xForwardedFor: "10.0.0.1",
|
||||
expected: "198.51.100.1",
|
||||
},
|
||||
{
|
||||
name: "empty XFF",
|
||||
remoteAddr: "192.0.2.1:80",
|
||||
xForwardedFor: "",
|
||||
expected: "192.0.2.1",
|
||||
},
|
||||
{
|
||||
name: "malformed remote addr",
|
||||
remoteAddr: "192.0.2.1",
|
||||
xForwardedFor: "198.51.100.1",
|
||||
expected: "198.51.100.1",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
req.RemoteAddr = tt.remoteAddr
|
||||
if tt.xForwardedFor != "" {
|
||||
req.Header.Set("X-Forwarded-For", tt.xForwardedFor)
|
||||
}
|
||||
|
||||
result := xff.extractRealIP(req)
|
||||
if result != tt.expected {
|
||||
t.Errorf("extractRealIP() = %s, expected %s", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetRealIPWithoutMiddleware(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
req.RemoteAddr = "198.51.100.1:12345"
|
||||
|
||||
realIP := GetRealIP(req)
|
||||
expected := "198.51.100.1"
|
||||
if realIP != expected {
|
||||
t.Errorf("GetRealIP() = %s, expected %s", realIP, expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddTrustedCIDR(t *testing.T) {
|
||||
xff := &FastlyXFF{
|
||||
IPv4: []string{"192.0.2.0/24"},
|
||||
IPv6: []string{"2001:db8::/32"},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
cidr string
|
||||
wantErr bool
|
||||
}{
|
||||
{"valid IPv4 range", "10.0.0.0/8", false},
|
||||
{"valid IPv6 range", "fc00::/7", false},
|
||||
{"valid single IP", "203.0.113.1/32", false},
|
||||
{"invalid CIDR", "not-a-cidr", true},
|
||||
{"invalid format", "10.0.0.0/99", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := xff.AddTrustedCIDR(tt.cidr)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("AddTrustedCIDR(%s) error = %v, wantErr %v", tt.cidr, err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCustomTrustedCIDRs(t *testing.T) {
|
||||
xff := &FastlyXFF{
|
||||
IPv4: []string{"192.0.2.0/24"},
|
||||
IPv6: []string{"2001:db8::/32"},
|
||||
}
|
||||
|
||||
// Add custom trusted CIDRs
|
||||
err := xff.AddTrustedCIDR("10.0.0.0/8")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to add trusted CIDR: %v", err)
|
||||
}
|
||||
|
||||
err = xff.AddTrustedCIDR("172.16.0.0/12")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to add trusted CIDR: %v", err)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
ip string
|
||||
expected bool
|
||||
}{
|
||||
// Original Fastly ranges
|
||||
{"192.0.2.1", true},
|
||||
{"2001:db8::1", true},
|
||||
// Custom CIDRs
|
||||
{"10.1.2.3", true},
|
||||
{"172.16.1.1", true},
|
||||
// Not trusted
|
||||
{"198.51.100.1", false},
|
||||
{"172.15.1.1", false},
|
||||
{"10.0.0.0", true}, // Network address should still match
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.ip, func(t *testing.T) {
|
||||
result := xff.isTrustedProxy(tt.ip)
|
||||
if result != tt.expected {
|
||||
t.Errorf("isTrustedProxy(%s) = %v, expected %v", tt.ip, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTPMiddlewareWithCustomCIDRs(t *testing.T) {
|
||||
xff := &FastlyXFF{
|
||||
IPv4: []string{"192.0.2.0/24"},
|
||||
IPv6: []string{"2001:db8::/32"},
|
||||
}
|
||||
|
||||
// Add custom trusted CIDR for internal proxies
|
||||
err := xff.AddTrustedCIDR("10.0.0.0/8")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to add trusted CIDR: %v", err)
|
||||
}
|
||||
|
||||
middleware := xff.HTTPMiddleware()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
remoteAddr string
|
||||
xForwardedFor string
|
||||
expectedRealIP string
|
||||
}{
|
||||
{
|
||||
name: "custom trusted proxy with XFF",
|
||||
remoteAddr: "10.1.2.3:80",
|
||||
xForwardedFor: "198.51.100.1",
|
||||
expectedRealIP: "198.51.100.1",
|
||||
},
|
||||
{
|
||||
name: "fastly proxy with XFF",
|
||||
remoteAddr: "192.0.2.1:80",
|
||||
xForwardedFor: "198.51.100.1",
|
||||
expectedRealIP: "198.51.100.1",
|
||||
},
|
||||
{
|
||||
name: "untrusted proxy ignored",
|
||||
remoteAddr: "172.16.1.1:80",
|
||||
xForwardedFor: "198.51.100.1",
|
||||
expectedRealIP: "172.16.1.1",
|
||||
},
|
||||
{
|
||||
name: "chain through custom and fastly",
|
||||
remoteAddr: "192.0.2.1:80",
|
||||
xForwardedFor: "198.51.100.1, 10.1.2.3",
|
||||
expectedRealIP: "198.51.100.1",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var capturedIP string
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
capturedIP = GetRealIP(r)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
req.RemoteAddr = tt.remoteAddr
|
||||
if tt.xForwardedFor != "" {
|
||||
req.Header.Set("X-Forwarded-For", tt.xForwardedFor)
|
||||
}
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
middleware(handler).ServeHTTP(rr, req)
|
||||
|
||||
if capturedIP != tt.expectedRealIP {
|
||||
t.Errorf("expected real IP %s, got %s", tt.expectedRealIP, capturedIP)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user