9 Commits

Author SHA1 Message Date
6420d0b174 feat(config): add pool_domain and upgrade dependencies
- Add configurable pool_domain with pool.ntp.org default
- Update Go from 1.23.5 to 1.24.0
- Update golang.org/x/* dependencies
- Add enumer and accessory as tool dependencies
- Update goreleaser to v2.12.3
2025-10-04 08:25:54 -07:00
45308cd4bf feat(database): add PostgreSQL support with native pgx pool
Add PostgreSQL support to database package alongside existing MySQL support.
Both databases share common infrastructure (pool management, metrics,
transactions) while using database-specific connectors.

database/ changes:
- Add PostgresConfig struct and PostgreSQL connector using pgx/stdlib
- Change MySQL config from DBConfig to *MySQLConfig (pointer)
- Add Config.Validate() to prevent multiple database configs
- Add PostgreSQL connector with secure config building (no password in DSN)
- Add field validation and secure defaults (SSLMode="prefer")
- Support legacy flat PostgreSQL config format for backward compatibility
- Add tests for PostgreSQL configs and validation

New database/pgdb/ package:
- Native pgx connection pool support (*pgxpool.Pool)
- OpenPool() and OpenPoolWithConfigFile() APIs
- CreatePoolConfig() for secure config conversion
- PoolOptions for fine-grained pool control
- Full test coverage and documentation

Security:
- Passwords never exposed in DSN strings
- Set passwords separately in pgx config objects
- Validate all configuration before connection

Architecture:
- Shared code in database/ for both MySQL and PostgreSQL (sql.DB)
- database/pgdb/ for PostgreSQL-specific native pool support
2025-09-27 16:55:54 -07:00
4767caf7b8 feat(xff): add AddTrustedCIDR for custom proxies
- Add AddTrustedCIDR() method to support non-Fastly proxies
- Enable trusting custom CIDR ranges (e.g., 10.0.0.0/8)
- Validate CIDR format before adding to trusted list
- Maintain backward compatibility with Fastly-only usage

Allows mixed proxy environments where requests pass through
both Fastly CDN and custom internal proxies/load balancers.
Uses precise CIDR terminology instead of generic "range".
2025-09-27 14:46:02 -07:00
f90281f472 feat(xff): add net/http middleware support
- Add HTTPMiddleware() method for standard net/http handlers
- Add GetRealIP() helper to extract client IP from context
- Update r.RemoteAddr with real IP and port 0 (proxy port invalid)
- Support both IPv4 and IPv6 Fastly IP range validation
- Maintain backward compatibility with existing Echo support

The middleware extracts real client IPs from X-Forwarded-For
headers when requests come from trusted Fastly proxy ranges.
2025-09-27 13:41:12 -07:00
ca190b0085 docs: add v0.5.2 release notes
Add changelog entries for recent commits:
- Health package: Kubernetes health probes
- Logger package: runtime level control and fixes
- Database package: config file override support
2025-09-21 12:10:05 -07:00
10864363e2 feat(health): enhance server with probe-specific handlers
- Add separate handlers for liveness (/healthz), readiness (/readyz),
  and startup (/startupz) probes
- Implement WithLivenessHandler, WithReadinessHandler, WithStartupHandler,
  and WithServiceName options
- Add probe-specific JSON response formats
- Add comprehensive package documentation with usage examples
- Maintain backward compatibility for /__health and / endpoints
- Add tests for all probe types and fallback scenarios

Enables proper Kubernetes health monitoring with different probe types.
2025-09-21 10:52:29 -07:00
66b51df2af feat(logger): add runtime log level control API
Add independent log level control for stderr and OTLP loggers.
Both can be configured via environment variables or programmatically
at runtime.

- Add SetLevel() and SetOTLPLevel() for runtime control
- Add ParseLevel() to convert strings to slog.Level
- Support LOG_LEVEL and OTLP_LOG_LEVEL env vars
- Maintain backward compatibility with DEBUG env var
- Add comprehensive test coverage
2025-09-06 05:21:33 -07:00
28d05d1d0e feat(database): add DATABASE_CONFIG_FILE env override
Allow overriding default database.yaml paths via DATABASE_CONFIG_FILE
environment variable. When set, uses single specified file instead of
default ["database.yaml", "/vault/secrets/database.yaml"] search paths.

Maintains backward compatibility when env var not set.
2025-08-03 12:20:35 -07:00
a774f92bf7 fix(logger): prevent mutex crash in bufferingExporter
Remove sync.Once reset that caused "unlock of unlocked mutex" panic.
Redesign initialization to use only checkReadiness goroutine for
retry attempts, eliminating race condition while preserving retry
functionality for TLS/tracing setup delays.
2025-08-02 22:55:57 -07:00
24 changed files with 2257 additions and 118 deletions

1
.gitignore vendored Normal file
View File

@@ -0,0 +1 @@
.aider*

28
.mcp.json Normal file
View File

@@ -0,0 +1,28 @@
{
"mcpServers": {
"context7": {
"type": "stdio",
"command": "npx",
"args": [
"-y",
"@upstash/context7-mcp@1.0.0"
],
"env": {}
},
"serena": {
"type": "stdio",
"command": "uvx",
"args": [
"--from",
"git+https://github.com/oraios/serena@v0.1.4",
"serena",
"start-mcp-server",
"--context",
"ide-assistant",
"--project",
"."
],
"env": {}
}
}
}

20
.pre-commit-config.yaml Normal file
View File

@@ -0,0 +1,20 @@
---
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v5.0.0
hooks:
- id: trailing-whitespace
- id: end-of-file-fixer
- id: check-added-large-files
args: ["--maxkb=400"]
- id: check-case-conflict
- id: check-executables-have-shebangs
- id: check-shebang-scripts-are-executable
- id: check-merge-conflict
- id: check-symlinks
- repo: https://github.com/adrienverge/yamllint
rev: v1.35.1
hooks:
- id: yamllint
args: [-c=.yamllint]

14
.yamllint Normal file
View File

@@ -0,0 +1,14 @@
---
extends: relaxed
rules:
braces:
level: error
brackets:
level: error
truthy:
level: warning
#ignore: |
# - ...

View File

@@ -1,3 +1,22 @@
# Release Notes - v0.5.2
## Health Package
- **Kubernetes-native health probes** - Added dedicated handlers for liveness (`/healthz`), readiness (`/readyz`), and startup (`/startupz`) probes
- **Flexible configuration options** - New `WithLivenessHandler`, `WithReadinessHandler`, `WithStartupHandler`, and `WithServiceName` options
- **JSON response formats** - Structured probe responses with service identification
- **Backward compatibility** - Maintains existing `/__health` and `/` endpoints
## Logger Package
- **Runtime log level control** - Independent level management for stderr and OTLP loggers via `SetLevel()` and `SetOTLPLevel()`
- **Environment variable support** - Configure levels with `LOG_LEVEL` and `OTLP_LOG_LEVEL` env vars
- **String parsing utility** - New `ParseLevel()` function for converting string levels to `slog.Level`
- **Buffering exporter fix** - Resolved "unlock of unlocked mutex" panic in `bufferingExporter`
- **Initialization redesign** - Eliminated race conditions in TLS/tracing setup retry logic
## Database Package
- **Configuration file override** - Added `DATABASE_CONFIG_FILE` environment variable to specify custom database configuration file paths
- **Flexible path configuration** - Override default `["database.yaml", "/vault/secrets/database.yaml"]` search paths when needed
# Release Notes - v0.5.1
## Observability Enhancements

View File

@@ -39,6 +39,8 @@ type Config struct {
webHostnames []string
webTLS bool
poolDomain string `accessor:"getter"`
valid bool `accessor:"getter"`
}
@@ -52,6 +54,7 @@ type Config struct {
// - web_hostname: Comma-separated web hostnames (first becomes primary)
// - manage_tls: Management interface TLS setting
// - web_tls: Web interface TLS setting
// - pool_domain: NTP pool domain (default: pool.ntp.org)
func New() *Config {
c := Config{}
c.deploymentMode = os.Getenv("deployment_mode")
@@ -69,6 +72,11 @@ func New() *Config {
c.manageTLS = parseBool(os.Getenv("manage_tls"))
c.webTLS = parseBool(os.Getenv("web_tls"))
c.poolDomain = os.Getenv("pool_domain")
if c.poolDomain == "" {
c.poolDomain = "pool.ntp.org"
}
return &c
}

View File

@@ -1,6 +1,8 @@
package database
import (
"fmt"
"os"
"time"
"github.com/prometheus/client_golang/prometheus"
@@ -8,15 +10,67 @@ import (
// Config represents the database configuration structure
type Config struct {
MySQL DBConfig `yaml:"mysql"`
// MySQL configuration (use this OR Postgres, not both)
MySQL *MySQLConfig `yaml:"mysql,omitempty"`
// Postgres configuration (use this OR MySQL, not both)
Postgres *PostgresConfig `yaml:"postgres,omitempty"`
// Legacy flat PostgreSQL format (deprecated, for backward compatibility only)
// If neither MySQL nor Postgres is set, these fields will be used for PostgreSQL
User string `yaml:"user,omitempty"`
Pass string `yaml:"pass,omitempty"`
Host string `yaml:"host,omitempty"`
Port uint16 `yaml:"port,omitempty"`
Name string `yaml:"name,omitempty"`
SSLMode string `yaml:"sslmode,omitempty"`
}
// DBConfig represents the MySQL database configuration
type DBConfig struct {
DSN string `default:"" flag:"dsn" usage:"Database DSN"`
User string `default:"" flag:"user"`
Pass string `default:"" flag:"pass"`
DBName string // Optional database name override
// MySQLConfig represents the MySQL database configuration
type MySQLConfig struct {
DSN string `yaml:"dsn" default:"" flag:"dsn" usage:"Database DSN"`
User string `yaml:"user" default:"" flag:"user"`
Pass string `yaml:"pass" default:"" flag:"pass"`
DBName string `yaml:"name,omitempty"` // Optional database name override
}
// PostgresConfig represents the PostgreSQL database configuration
type PostgresConfig struct {
User string `yaml:"user"`
Pass string `yaml:"pass"`
Host string `yaml:"host"`
Port uint16 `yaml:"port"`
Name string `yaml:"name"`
SSLMode string `yaml:"sslmode"`
}
// DBConfig is a legacy alias for MySQLConfig
type DBConfig = MySQLConfig
// Validate ensures the configuration is valid and unambiguous
func (c *Config) Validate() error {
hasMySQL := c.MySQL != nil
hasPostgres := c.Postgres != nil
hasLegacy := c.User != "" || c.Host != "" || c.Port != 0 || c.Name != ""
count := 0
if hasMySQL {
count++
}
if hasPostgres {
count++
}
if hasLegacy {
count++
}
if count == 0 {
return fmt.Errorf("no database configuration provided")
}
if count > 1 {
return fmt.Errorf("multiple database configurations provided (only one allowed)")
}
return nil
}
// ConfigOptions allows customization of database opening behavior
@@ -36,10 +90,20 @@ type ConfigOptions struct {
ConnMaxLifetime time.Duration
}
// getConfigFiles returns the list of config files to search for database configuration.
// If DATABASE_CONFIG_FILE environment variable is set, it returns that single file.
// Otherwise, it returns the default paths.
func getConfigFiles() []string {
if configFile := os.Getenv("DATABASE_CONFIG_FILE"); configFile != "" {
return []string{configFile}
}
return []string{"database.yaml", "/vault/secrets/database.yaml"}
}
// DefaultConfigOptions returns the standard configuration options used by API package
func DefaultConfigOptions() ConfigOptions {
return ConfigOptions{
ConfigFiles: []string{"database.yaml", "/vault/secrets/database.yaml"},
ConfigFiles: getConfigFiles(),
EnablePoolMonitoring: true,
PrometheusRegisterer: prometheus.DefaultRegisterer,
MaxOpenConns: 25,
@@ -51,7 +115,7 @@ func DefaultConfigOptions() ConfigOptions {
// MonitorConfigOptions returns configuration options optimized for Monitor package
func MonitorConfigOptions() ConfigOptions {
return ConfigOptions{
ConfigFiles: []string{"database.yaml", "/vault/secrets/database.yaml"},
ConfigFiles: getConfigFiles(),
EnablePoolMonitoring: false, // Monitor doesn't need metrics
PrometheusRegisterer: nil, // No Prometheus dependency
MaxOpenConns: 10,

View File

@@ -56,9 +56,9 @@ func TestMonitorConfigOptions(t *testing.T) {
}
func TestConfigStructures(t *testing.T) {
// Test that configuration structures can be created and populated
// Test that MySQL configuration structures can be created and populated
config := Config{
MySQL: DBConfig{
MySQL: &MySQLConfig{
DSN: "user:pass@tcp(localhost:3306)/dbname",
User: "testuser",
Pass: "testpass",
@@ -79,3 +79,118 @@ func TestConfigStructures(t *testing.T) {
t.Errorf("Expected DBName='testdb', got '%s'", config.MySQL.DBName)
}
}
func TestPostgresConfigStructures(t *testing.T) {
// Test that PostgreSQL configuration structures can be created and populated
config := Config{
Postgres: &PostgresConfig{
Host: "localhost",
Port: 5432,
User: "testuser",
Pass: "testpass",
Name: "testdb",
SSLMode: "require",
},
}
if config.Postgres.Host != "localhost" {
t.Errorf("Expected Host='localhost', got '%s'", config.Postgres.Host)
}
if config.Postgres.Port != 5432 {
t.Errorf("Expected Port=5432, got %d", config.Postgres.Port)
}
if config.Postgres.User != "testuser" {
t.Errorf("Expected User='testuser', got '%s'", config.Postgres.User)
}
if config.Postgres.Pass != "testpass" {
t.Errorf("Expected Pass='testpass', got '%s'", config.Postgres.Pass)
}
if config.Postgres.Name != "testdb" {
t.Errorf("Expected Name='testdb', got '%s'", config.Postgres.Name)
}
if config.Postgres.SSLMode != "require" {
t.Errorf("Expected SSLMode='require', got '%s'", config.Postgres.SSLMode)
}
}
func TestLegacyPostgresConfig(t *testing.T) {
// Test that legacy flat PostgreSQL format can be created
config := Config{
User: "testuser",
Pass: "testpass",
Host: "localhost",
Port: 5432,
Name: "testdb",
SSLMode: "require",
}
if config.User != "testuser" {
t.Errorf("Expected User='testuser', got '%s'", config.User)
}
if config.Name != "testdb" {
t.Errorf("Expected Name='testdb', got '%s'", config.Name)
}
}
func TestConfigValidation(t *testing.T) {
tests := []struct {
name string
config Config
wantErr bool
}{
{
name: "valid mysql config",
config: Config{
MySQL: &MySQLConfig{DSN: "test"},
},
wantErr: false,
},
{
name: "valid postgres config",
config: Config{
Postgres: &PostgresConfig{User: "test", Host: "localhost", Name: "test"},
},
wantErr: false,
},
{
name: "valid legacy postgres config",
config: Config{
User: "test",
Host: "localhost",
Name: "testdb",
},
wantErr: false,
},
{
name: "both mysql and postgres set",
config: Config{
MySQL: &MySQLConfig{DSN: "test"},
Postgres: &PostgresConfig{User: "test"},
},
wantErr: true,
},
{
name: "mysql and legacy postgres set",
config: Config{
MySQL: &MySQLConfig{DSN: "test"},
User: "test",
Name: "testdb",
},
wantErr: true,
},
{
name: "no config set",
config: Config{},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := tt.config.Validate()
if (err != nil) != tt.wantErr {
t.Errorf("Validate() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}

View File

@@ -8,6 +8,8 @@ import (
"os"
"github.com/go-sql-driver/mysql"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/stdlib"
"gopkg.in/yaml.v3"
)
@@ -58,31 +60,128 @@ func createConnector(configFile string) CreateConnectorFunc {
return nil, err
}
dsn := cfg.MySQL.DSN
if len(dsn) == 0 {
dsn = os.Getenv("DATABASE_DSN")
if len(dsn) == 0 {
return nil, fmt.Errorf("dsn config in database.yaml or DATABASE_DSN environment variable required")
}
// Validate configuration
if err := cfg.Validate(); err != nil {
return nil, fmt.Errorf("invalid configuration: %w", err)
}
dbcfg, err := mysql.ParseDSN(dsn)
if err != nil {
return nil, err
// Determine database type and create appropriate connector
if cfg.MySQL != nil {
return createMySQLConnector(cfg.MySQL)
} else if cfg.Postgres != nil {
return createPostgresConnector(cfg.Postgres)
} else if cfg.User != "" && cfg.Name != "" {
// Legacy flat PostgreSQL format (requires at minimum user and dbname)
return createPostgresConnectorFromFlat(&cfg)
}
if user := cfg.MySQL.User; len(user) > 0 {
dbcfg.User = user
}
if pass := cfg.MySQL.Pass; len(pass) > 0 {
dbcfg.Passwd = pass
}
if name := cfg.MySQL.DBName; len(name) > 0 {
dbcfg.DBName = name
}
return mysql.NewConnector(dbcfg)
return nil, fmt.Errorf("no valid database configuration found (mysql or postgres section required)")
}
}
// createMySQLConnector creates a MySQL connector from configuration
func createMySQLConnector(cfg *MySQLConfig) (driver.Connector, error) {
dsn := cfg.DSN
if len(dsn) == 0 {
dsn = os.Getenv("DATABASE_DSN")
if len(dsn) == 0 {
return nil, fmt.Errorf("dsn config in database.yaml or DATABASE_DSN environment variable required")
}
}
dbcfg, err := mysql.ParseDSN(dsn)
if err != nil {
return nil, err
}
if user := cfg.User; len(user) > 0 {
dbcfg.User = user
}
if pass := cfg.Pass; len(pass) > 0 {
dbcfg.Passwd = pass
}
if name := cfg.DBName; len(name) > 0 {
dbcfg.DBName = name
}
return mysql.NewConnector(dbcfg)
}
// createPostgresConnector creates a PostgreSQL connector from configuration
func createPostgresConnector(cfg *PostgresConfig) (driver.Connector, error) {
// Validate required fields
if cfg.Host == "" {
return nil, fmt.Errorf("postgres: host is required")
}
if cfg.User == "" {
return nil, fmt.Errorf("postgres: user is required")
}
if cfg.Name == "" {
return nil, fmt.Errorf("postgres: database name is required")
}
// Validate SSLMode
validSSLModes := map[string]bool{
"disable": true, "allow": true, "prefer": true,
"require": true, "verify-ca": true, "verify-full": true,
}
if cfg.SSLMode != "" && !validSSLModes[cfg.SSLMode] {
return nil, fmt.Errorf("postgres: invalid sslmode: %s", cfg.SSLMode)
}
// Build config directly (security: no DSN string with password)
connConfig, err := pgx.ParseConfig("")
if err != nil {
return nil, fmt.Errorf("postgres: failed to create pgx config: %w", err)
}
connConfig.Host = cfg.Host
connConfig.Port = cfg.Port
connConfig.User = cfg.User
connConfig.Password = cfg.Pass
connConfig.Database = cfg.Name
// Map SSLMode to pgx configuration
// Note: pgx uses different SSL handling than libpq
// For now, we'll construct a minimal DSN with sslmode for ParseConfig
if cfg.SSLMode != "" {
// Reconstruct with sslmode only (no password in DSN)
dsnWithoutPassword := fmt.Sprintf("host=%s port=%d user=%s dbname=%s sslmode=%s",
cfg.Host, cfg.Port, cfg.User, cfg.Name, cfg.SSLMode)
connConfig, err = pgx.ParseConfig(dsnWithoutPassword)
if err != nil {
return nil, fmt.Errorf("postgres: failed to parse config with sslmode: %w", err)
}
// Set password separately after parsing
connConfig.Password = cfg.Pass
}
return stdlib.GetConnector(*connConfig), nil
}
// createPostgresConnectorFromFlat creates a PostgreSQL connector from flat config format
func createPostgresConnectorFromFlat(cfg *Config) (driver.Connector, error) {
pgCfg := &PostgresConfig{
User: cfg.User,
Pass: cfg.Pass,
Host: cfg.Host,
Port: cfg.Port,
Name: cfg.Name,
SSLMode: cfg.SSLMode,
}
// Set defaults for PostgreSQL
if pgCfg.Host == "" {
pgCfg.Host = "localhost"
}
if pgCfg.Port == 0 {
pgCfg.Port = 5432
}
if pgCfg.SSLMode == "" {
pgCfg.SSLMode = "prefer"
}
return createPostgresConnector(pgCfg)
}

120
database/pgdb/CLAUDE.md Normal file
View File

@@ -0,0 +1,120 @@
# pgdb - Native PostgreSQL Connection Pool
Primary package for PostgreSQL connections using native pgx pool (`*pgxpool.Pool`). Provides better performance and PostgreSQL-specific features compared to `database/sql`.
## Usage
### Basic Example
```go
import (
"context"
"go.ntppool.org/common/database/pgdb"
)
func main() {
ctx := context.Background()
// Open pool with default options
pool, err := pgdb.OpenPool(ctx, pgdb.DefaultPoolOptions())
if err != nil {
log.Fatal(err)
}
defer pool.Close()
// Use the pool for queries
row := pool.QueryRow(ctx, "SELECT version()")
var version string
row.Scan(&version)
}
```
### With Custom Config File
```go
pool, err := pgdb.OpenPoolWithConfigFile(ctx, "/path/to/database.yaml")
```
### With Custom Pool Settings
```go
opts := pgdb.DefaultPoolOptions()
opts.MaxConns = 50
opts.MinConns = 5
opts.MaxConnLifetime = 2 * time.Hour
pool, err := pgdb.OpenPool(ctx, opts)
```
## Configuration Format
### Recommended: Nested Format (database.yaml)
```yaml
postgres:
host: localhost
port: 5432
user: myuser
pass: mypassword
name: mydb
sslmode: prefer
```
### Legacy: Flat Format (backward compatible)
```yaml
host: localhost
port: 5432
user: myuser
pass: mypassword
name: mydb
sslmode: prefer
```
## Configuration Options
### PoolOptions
- `ConfigFiles` - List of config file paths to search (default: `database.yaml`, `/vault/secrets/database.yaml`)
- `MinConns` - Minimum connections (default: 0)
- `MaxConns` - Maximum connections (default: 25)
- `MaxConnLifetime` - Connection lifetime (default: 1 hour)
- `MaxConnIdleTime` - Idle timeout (default: 30 minutes)
- `HealthCheckPeriod` - Health check interval (default: 1 minute)
### PostgreSQL Config Fields
- `host` - Database host (required)
- `user` - Database user (required)
- `pass` - Database password
- `name` - Database name (required)
- `port` - Port number (default: 5432)
- `sslmode` - SSL mode: `disable`, `allow`, `prefer`, `require`, `verify-ca`, `verify-full` (default: `prefer`)
## Environment Variables
- `DATABASE_CONFIG_FILE` - Override config file location
## When to Use
**Use `pgdb.OpenPool()`** (this package) when:
- You need native PostgreSQL features (LISTEN/NOTIFY, COPY, etc.)
- You want better performance
- You're writing new PostgreSQL code
**Use `database.OpenDB()`** (sql.DB) when:
- You need database-agnostic code
- You're using SQLC or other tools that expect `database/sql`
- You need to support both MySQL and PostgreSQL
## Security
This package avoids password exposure by:
1. Never constructing DSN strings with passwords
2. Setting passwords separately in pgx config objects
3. Validating all configuration before connection
## See Also
- `database/` - Generic sql.DB support for MySQL and PostgreSQL
- pgx documentation: https://github.com/jackc/pgx

64
database/pgdb/config.go Normal file
View File

@@ -0,0 +1,64 @@
package pgdb
import (
"fmt"
"github.com/jackc/pgx/v5/pgxpool"
"go.ntppool.org/common/database"
)
// CreatePoolConfig converts database.PostgresConfig to pgxpool.Config
// This is the secure way to create a config without exposing passwords in DSN strings
func CreatePoolConfig(cfg *database.PostgresConfig) (*pgxpool.Config, error) {
// Validate required fields
if cfg.Host == "" {
return nil, fmt.Errorf("postgres: host is required")
}
if cfg.User == "" {
return nil, fmt.Errorf("postgres: user is required")
}
if cfg.Name == "" {
return nil, fmt.Errorf("postgres: database name is required")
}
// Validate SSLMode
validSSLModes := map[string]bool{
"disable": true, "allow": true, "prefer": true,
"require": true, "verify-ca": true, "verify-full": true,
}
if cfg.SSLMode != "" && !validSSLModes[cfg.SSLMode] {
return nil, fmt.Errorf("postgres: invalid sslmode: %s", cfg.SSLMode)
}
// Set defaults
host := cfg.Host
if host == "" {
host = "localhost"
}
port := cfg.Port
if port == 0 {
port = 5432
}
sslmode := cfg.SSLMode
if sslmode == "" {
sslmode = "prefer"
}
// Build connection string WITHOUT password
// We'll set the password separately in the config
connString := fmt.Sprintf("host=%s port=%d user=%s dbname=%s sslmode=%s",
host, port, cfg.User, cfg.Name, sslmode)
// Parse the connection string
poolConfig, err := pgxpool.ParseConfig(connString)
if err != nil {
return nil, fmt.Errorf("postgres: failed to parse connection config: %w", err)
}
// Set password separately (security: never put password in the connection string)
poolConfig.ConnConfig.Password = cfg.Pass
return poolConfig, nil
}

173
database/pgdb/pool.go Normal file
View File

@@ -0,0 +1,173 @@
package pgdb
import (
"context"
"fmt"
"os"
"time"
"github.com/jackc/pgx/v5/pgxpool"
"go.ntppool.org/common/database"
"gopkg.in/yaml.v3"
)
// PoolOptions configures pgxpool connection behavior
type PoolOptions struct {
// ConfigFiles is a list of config file paths to search for database configuration
ConfigFiles []string
// MinConns is the minimum number of connections in the pool
// Default: 0 (no minimum)
MinConns int32
// MaxConns is the maximum number of connections in the pool
// Default: 25
MaxConns int32
// MaxConnLifetime is the maximum lifetime of a connection
// Default: 1 hour
MaxConnLifetime time.Duration
// MaxConnIdleTime is the maximum idle time of a connection
// Default: 30 minutes
MaxConnIdleTime time.Duration
// HealthCheckPeriod is how often to check connection health
// Default: 1 minute
HealthCheckPeriod time.Duration
}
// DefaultPoolOptions returns sensible defaults for pgxpool
func DefaultPoolOptions() PoolOptions {
return PoolOptions{
ConfigFiles: getConfigFiles(),
MinConns: 0,
MaxConns: 25,
MaxConnLifetime: time.Hour,
MaxConnIdleTime: 30 * time.Minute,
HealthCheckPeriod: time.Minute,
}
}
// OpenPool opens a native pgx connection pool with the specified configuration
// This is the primary and recommended way to connect to PostgreSQL
func OpenPool(ctx context.Context, options PoolOptions) (*pgxpool.Pool, error) {
// Find and read config file
pgCfg, err := findAndParseConfig(options.ConfigFiles)
if err != nil {
return nil, err
}
// Create pool config from PostgreSQL config
poolConfig, err := CreatePoolConfig(pgCfg)
if err != nil {
return nil, err
}
// Apply pool-specific settings
poolConfig.MinConns = options.MinConns
poolConfig.MaxConns = options.MaxConns
poolConfig.MaxConnLifetime = options.MaxConnLifetime
poolConfig.MaxConnIdleTime = options.MaxConnIdleTime
poolConfig.HealthCheckPeriod = options.HealthCheckPeriod
// Create the pool
pool, err := pgxpool.NewWithConfig(ctx, poolConfig)
if err != nil {
return nil, fmt.Errorf("failed to create connection pool: %w", err)
}
// Test the connection
if err := pool.Ping(ctx); err != nil {
pool.Close()
return nil, fmt.Errorf("failed to ping database: %w", err)
}
return pool, nil
}
// OpenPoolWithConfigFile opens a connection pool using an explicit config file path
// This is a convenience function for when you have a specific config file
func OpenPoolWithConfigFile(ctx context.Context, configFile string) (*pgxpool.Pool, error) {
options := DefaultPoolOptions()
options.ConfigFiles = []string{configFile}
return OpenPool(ctx, options)
}
// findAndParseConfig searches for and parses the first existing config file
func findAndParseConfig(configFiles []string) (*database.PostgresConfig, error) {
var firstErr error
for _, configFile := range configFiles {
if configFile == "" {
continue
}
// Check if file exists
if _, err := os.Stat(configFile); err != nil {
if firstErr == nil {
firstErr = err
}
continue
}
// Try to read and parse the file
pgCfg, err := parseConfigFile(configFile)
if err != nil {
if firstErr == nil {
firstErr = err
}
continue
}
return pgCfg, nil
}
if firstErr != nil {
return nil, fmt.Errorf("no config file found: %w", firstErr)
}
return nil, fmt.Errorf("no valid config files provided")
}
// parseConfigFile reads and parses a YAML config file
func parseConfigFile(configFile string) (*database.PostgresConfig, error) {
file, err := os.Open(configFile)
if err != nil {
return nil, fmt.Errorf("failed to open config file: %w", err)
}
defer file.Close()
dec := yaml.NewDecoder(file)
cfg := database.Config{}
if err := dec.Decode(&cfg); err != nil {
return nil, fmt.Errorf("failed to decode config: %w", err)
}
// Extract PostgreSQL config
if cfg.Postgres != nil {
return cfg.Postgres, nil
}
// Check for legacy flat format
if cfg.User != "" && cfg.Name != "" {
return &database.PostgresConfig{
User: cfg.User,
Pass: cfg.Pass,
Host: cfg.Host,
Port: cfg.Port,
Name: cfg.Name,
SSLMode: cfg.SSLMode,
}, nil
}
return nil, fmt.Errorf("no PostgreSQL configuration found in %s", configFile)
}
// getConfigFiles returns the list of config files to search
func getConfigFiles() []string {
if configFile := os.Getenv("DATABASE_CONFIG_FILE"); configFile != "" {
return []string{configFile}
}
return []string{"database.yaml", "/vault/secrets/database.yaml"}
}

151
database/pgdb/pool_test.go Normal file
View File

@@ -0,0 +1,151 @@
package pgdb
import (
"testing"
"time"
"go.ntppool.org/common/database"
)
func TestCreatePoolConfig(t *testing.T) {
tests := []struct {
name string
cfg *database.PostgresConfig
wantErr bool
}{
{
name: "valid config",
cfg: &database.PostgresConfig{
Host: "localhost",
Port: 5432,
User: "testuser",
Pass: "testpass",
Name: "testdb",
SSLMode: "require",
},
wantErr: false,
},
{
name: "valid config with defaults",
cfg: &database.PostgresConfig{
Host: "localhost",
User: "testuser",
Pass: "testpass",
Name: "testdb",
// Port and SSLMode will use defaults
},
wantErr: false,
},
{
name: "missing host",
cfg: &database.PostgresConfig{
User: "testuser",
Pass: "testpass",
Name: "testdb",
},
wantErr: true,
},
{
name: "missing user",
cfg: &database.PostgresConfig{
Host: "localhost",
Pass: "testpass",
Name: "testdb",
},
wantErr: true,
},
{
name: "missing database name",
cfg: &database.PostgresConfig{
Host: "localhost",
User: "testuser",
Pass: "testpass",
},
wantErr: true,
},
{
name: "invalid sslmode",
cfg: &database.PostgresConfig{
Host: "localhost",
User: "testuser",
Pass: "testpass",
Name: "testdb",
SSLMode: "invalid",
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
poolCfg, err := CreatePoolConfig(tt.cfg)
if (err != nil) != tt.wantErr {
t.Errorf("CreatePoolConfig() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !tt.wantErr && poolCfg == nil {
t.Error("CreatePoolConfig() returned nil config without error")
}
if !tt.wantErr && poolCfg != nil {
// Verify config fields are set correctly
if poolCfg.ConnConfig.Host != tt.cfg.Host && tt.cfg.Host != "" {
t.Errorf("Expected Host=%s, got %s", tt.cfg.Host, poolCfg.ConnConfig.Host)
}
if poolCfg.ConnConfig.User != tt.cfg.User {
t.Errorf("Expected User=%s, got %s", tt.cfg.User, poolCfg.ConnConfig.User)
}
if poolCfg.ConnConfig.Password != tt.cfg.Pass {
t.Errorf("Expected Password to be set correctly")
}
if poolCfg.ConnConfig.Database != tt.cfg.Name {
t.Errorf("Expected Database=%s, got %s", tt.cfg.Name, poolCfg.ConnConfig.Database)
}
}
})
}
}
func TestDefaultPoolOptions(t *testing.T) {
opts := DefaultPoolOptions()
// Verify expected defaults
if opts.MinConns != 0 {
t.Errorf("Expected MinConns=0, got %d", opts.MinConns)
}
if opts.MaxConns != 25 {
t.Errorf("Expected MaxConns=25, got %d", opts.MaxConns)
}
if opts.MaxConnLifetime != time.Hour {
t.Errorf("Expected MaxConnLifetime=1h, got %v", opts.MaxConnLifetime)
}
if opts.MaxConnIdleTime != 30*time.Minute {
t.Errorf("Expected MaxConnIdleTime=30m, got %v", opts.MaxConnIdleTime)
}
if opts.HealthCheckPeriod != time.Minute {
t.Errorf("Expected HealthCheckPeriod=1m, got %v", opts.HealthCheckPeriod)
}
if len(opts.ConfigFiles) == 0 {
t.Error("Expected ConfigFiles to be non-empty")
}
}
func TestCreatePoolConfigDefaults(t *testing.T) {
// Test that defaults are applied correctly
cfg := &database.PostgresConfig{
Host: "localhost",
User: "testuser",
Pass: "testpass",
Name: "testdb",
// Port and SSLMode not set
}
poolCfg, err := CreatePoolConfig(cfg)
if err != nil {
t.Fatalf("CreatePoolConfig() failed: %v", err)
}
// Verify defaults were applied
if poolCfg.ConnConfig.Port != 5432 {
t.Errorf("Expected default Port=5432, got %d", poolCfg.ConnConfig.Port)
}
}

28
go.mod
View File

@@ -1,10 +1,11 @@
module go.ntppool.org/common
go 1.23.5
go 1.24.0
require (
github.com/abh/certman v0.4.0
github.com/go-sql-driver/mysql v1.9.3
github.com/jackc/pgx/v5 v5.7.6
github.com/labstack/echo-contrib v0.17.2
github.com/labstack/echo/v4 v4.13.3
github.com/oklog/ulid/v2 v2.1.0
@@ -32,9 +33,9 @@ require (
go.opentelemetry.io/otel/sdk/log v0.9.0
go.opentelemetry.io/otel/sdk/metric v1.33.0
go.opentelemetry.io/otel/trace v1.33.0
golang.org/x/mod v0.22.0
golang.org/x/net v0.33.0
golang.org/x/sync v0.10.0
golang.org/x/mod v0.28.0
golang.org/x/net v0.44.0
golang.org/x/sync v0.17.0
google.golang.org/grpc v1.69.2
gopkg.in/yaml.v3 v3.0.1
)
@@ -44,22 +45,29 @@ require (
github.com/beorn7/perks v1.0.1 // indirect
github.com/cenkalti/backoff/v4 v4.3.0 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/dmarkham/enumer v1.6.1 // indirect
github.com/fsnotify/fsnotify v1.8.0 // indirect
github.com/go-logr/logr v1.4.2 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/grpc-ecosystem/grpc-gateway/v2 v2.25.1 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
github.com/jackc/puddle/v2 v2.2.2 // indirect
github.com/klauspost/compress v1.17.11 // indirect
github.com/labstack/gommon v0.4.2 // indirect
github.com/masaushi/accessory v0.6.0 // indirect
github.com/mattn/go-colorable v0.1.13 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
github.com/pascaldekloe/name v1.0.0 // indirect
github.com/pierrec/lz4/v4 v4.1.22 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/prometheus/common v0.61.0 // indirect
github.com/prometheus/procfs v0.15.1 // indirect
github.com/samber/lo v1.47.0 // indirect
github.com/spf13/afero v1.15.0 // indirect
github.com/spf13/pflag v1.0.5 // indirect
github.com/valyala/bytebufferpool v1.0.0 // indirect
github.com/valyala/fasttemplate v1.2.2 // indirect
@@ -70,11 +78,17 @@ require (
go.opentelemetry.io/otel/exporters/stdout/stdoutmetric v1.33.0 // indirect
go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.33.0 // indirect
go.opentelemetry.io/proto/otlp v1.4.0 // indirect
golang.org/x/crypto v0.31.0 // indirect
golang.org/x/sys v0.28.0 // indirect
golang.org/x/text v0.21.0 // indirect
golang.org/x/crypto v0.42.0 // indirect
golang.org/x/sys v0.36.0 // indirect
golang.org/x/text v0.29.0 // indirect
golang.org/x/time v0.8.0 // indirect
golang.org/x/tools v0.37.0 // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20241223144023-3abc09e42ca8 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20241223144023-3abc09e42ca8 // indirect
google.golang.org/protobuf v1.36.1 // indirect
)
tool (
github.com/dmarkham/enumer
github.com/masaushi/accessory
)

46
go.sum
View File

@@ -4,6 +4,8 @@ github.com/abh/certman v0.4.0 h1:XHoDtb0YyRQPclaHMrBDlKTVZpNjTK6vhB0S3Bd/Sbs=
github.com/abh/certman v0.4.0/go.mod h1:x8QhpKVZifmV1Hdiwdg9gLo2GMPAxezz1s3zrVnPs+I=
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
github.com/bradleyjkemp/cupaloy/v2 v2.8.0 h1:any4BmKE+jGIaMpnU8YgH/I2LPiLBufr6oMMlVBbn9M=
github.com/bradleyjkemp/cupaloy/v2 v2.8.0/go.mod h1:bm7JXdkRd4BHJk9HpwqAI8BoAY1lps46Enkdqw6aRX0=
github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8=
github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
@@ -12,6 +14,8 @@ github.com/cpuguy83/go-md2man/v2 v2.0.4/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46t
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dmarkham/enumer v1.6.1 h1:aSc9awYtZL07TUueWs40QcHtxTvHTAwG0EqrNsK45w4=
github.com/dmarkham/enumer v1.6.1/go.mod h1:yixql+kDDQRYqcuBM2n9Vlt7NoT9ixgXhaXry8vmRg8=
github.com/fsnotify/fsnotify v1.8.0 h1:dAwr6QBTBZIkG8roQaJjGof0pp0EeF+tNV7YBP3F/8M=
github.com/fsnotify/fsnotify v1.8.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
@@ -31,6 +35,14 @@ github.com/grpc-ecosystem/grpc-gateway/v2 v2.25.1 h1:VNqngBF40hVlDloBruUehVYC3Ar
github.com/grpc-ecosystem/grpc-gateway/v2 v2.25.1/go.mod h1:RBRO7fro65R6tjKzYgLAFo0t1QEXY1Dp+i/bvpRiqiQ=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
github.com/jackc/pgx/v5 v5.7.6 h1:rWQc5FwZSPX58r1OQmkuaNicxdmExaEz5A2DO2hUuTk=
github.com/jackc/pgx/v5 v5.7.6/go.mod h1:aruU7o91Tc2q2cFp5h4uP3f6ztExVpyVv88Xl/8Vl8M=
github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
github.com/klauspost/compress v1.15.9/go.mod h1:PhcZ0MbTNciWF3rruxRgKxI5NkcHHrHUDtV4Yw2GlzU=
github.com/klauspost/compress v1.17.11 h1:In6xLpyWOi1+C7tXUUWv2ot1QvBjxevKAaI6IXrJmUc=
github.com/klauspost/compress v1.17.11/go.mod h1:pMDklpSncoRMuLFrf1W9Ss9KT+0rH90U12bZKk7uwG0=
@@ -46,6 +58,8 @@ github.com/labstack/echo/v4 v4.13.3 h1:pwhpCPrTl5qry5HRdM5FwdXnhXSLSY+WE+YQSeCaa
github.com/labstack/echo/v4 v4.13.3/go.mod h1:o90YNEeQWjDozo584l7AwhJMHN0bOC4tAfg+Xox9q5g=
github.com/labstack/gommon v0.4.2 h1:F8qTUNXgG1+6WQmqoUWnz8WiEU60mXVVw0P4ht1WRA0=
github.com/labstack/gommon v0.4.2/go.mod h1:QlUFxVM+SNXhDL/Z7YhocGIBYOiwB0mXm1+1bAPHPyU=
github.com/masaushi/accessory v0.6.0 h1:HYAzxkuhfvlbaQwinxXTxsSPbFabAnwHt8K6I/DvNBU=
github.com/masaushi/accessory v0.6.0/go.mod h1:8GZMgq3wcIapVZWt7VVQCh5+onPc/8gJeHb8WRXezvQ=
github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA=
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
@@ -55,6 +69,8 @@ github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
github.com/oklog/ulid/v2 v2.1.0 h1:+9lhoxAP56we25tyYETBBY1YLA2SaoLvUFgrP2miPJU=
github.com/oklog/ulid/v2 v2.1.0/go.mod h1:rcEKHmBBKfef9DhnvX7y1HZBYxjXb0cP5ExxNsTT1QQ=
github.com/pascaldekloe/name v1.0.0 h1:n7LKFgHixETzxpRv2R77YgPUFo85QHGZKrdaYm7eY5U=
github.com/pascaldekloe/name v1.0.0/go.mod h1:Z//MfYJnH4jVpQ9wkclwu2I2MkHmXTlT9wR5UZScttM=
github.com/pborman/getopt v0.0.0-20170112200414-7148bc3a4c30/go.mod h1:85jBQOZwpVEaDAr341tbn15RS4fCAsIst0qp7i8ex1o=
github.com/pierrec/lz4/v4 v4.1.15/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4=
github.com/pierrec/lz4/v4 v4.1.22 h1:cKFw6uJDK+/gfw5BcDL0JL5aBsAFdsIT18eRtLj7VIU=
@@ -84,12 +100,16 @@ github.com/samber/slog-multi v1.2.4 h1:k9x3JAWKJFPKffx+oXZ8TasaNuorIW4tG+TXxkt6R
github.com/samber/slog-multi v1.2.4/go.mod h1:ACuZ5B6heK57TfMVkVknN2UZHoFfjCwRxR0Q2OXKHlo=
github.com/segmentio/kafka-go v0.4.47 h1:IqziR4pA3vrZq7YdRxaT3w1/5fvIH5qpCwstUanQQB0=
github.com/segmentio/kafka-go v0.4.47/go.mod h1:HjF6XbOKh0Pjlkr5GVZxt6CsjjwnmhVOfURM5KMd8qg=
github.com/spf13/afero v1.15.0 h1:b/YBCLWAJdFWJTN9cLhiXXcD7mzKn9Dm86dNnfyQw1I=
github.com/spf13/afero v1.15.0/go.mod h1:NC2ByUVxtQs4b3sIUphxK0NioZnmxgyCrfzeuq8lxMg=
github.com/spf13/cobra v1.8.1 h1:e5/vxKd/rZsfSJMUX1agtjeTDf+qv1/JdBF8gg5k9ZM=
github.com/spf13/cobra v1.8.1/go.mod h1:wHxEcudfqmLYa8iTfL+OuZPbBZkmvliBWKIezN3kD9Y=
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
@@ -160,25 +180,25 @@ go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4=
golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U=
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
golang.org/x/crypto v0.42.0 h1:chiH31gIWm57EkTXpwnqf8qeuMUi0yekh6mT2AvFlqI=
golang.org/x/crypto v0.42.0/go.mod h1:4+rDnOTJhQCx2q7/j6rAN5XDw8kPjeaXEUR2eL94ix8=
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/mod v0.22.0 h1:D4nJWe9zXqHOmWqj4VMOJhvzj7bEZg4wEYa759z1pH4=
golang.org/x/mod v0.22.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY=
golang.org/x/mod v0.28.0 h1:gQBtGhjxykdjY9YhZpSlZIsbnaE2+PgjfLWUQTnoZ1U=
golang.org/x/mod v0.28.0/go.mod h1:yfB/L0NOf/kmEbXjzCPOx1iK1fRutOydrCMsqRhEBxI=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE=
golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I=
golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4=
golang.org/x/net v0.44.0 h1:evd8IRDyfNBMBTTY5XRF1vaZlD+EmWx6x8PkhR04H/I=
golang.org/x/net v0.44.0/go.mod h1:ECOoLqd5U3Lhyeyo/QDCEVQ4sNgYsqvCZ722XogGieY=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ=
golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug=
golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
@@ -189,8 +209,8 @@ golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA=
golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.36.0 h1:KVRy2GtZBrk1cBYA7MKu5bEZFxQk4NIDV6RLVcC8o0k=
golang.org/x/sys v0.36.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
@@ -203,14 +223,16 @@ golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ=
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo=
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
golang.org/x/text v0.29.0 h1:1neNs90w9YzJ9BocxfsQNHKuAT4pkghyXc4nhZ6sJvk=
golang.org/x/text v0.29.0/go.mod h1:7MhJOA9CD2qZyOKYazxdYMF85OwPdEr9jTtBpO7ydH4=
golang.org/x/time v0.8.0 h1:9i3RxcPv3PZnitoVGMPDKZSq1xW1gK1Xy3ArNOGZfEg=
golang.org/x/time v0.8.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
golang.org/x/tools v0.37.0 h1:DVSRzp7FwePZW356yEAChSdNcQo6Nsp+fex1SUW09lE=
golang.org/x/tools v0.37.0/go.mod h1:MBN5QPQtLMHVdvsbtarmTNukZDdgwdwlO5qGacAzF0w=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/genproto/googleapis/api v0.0.0-20241223144023-3abc09e42ca8 h1:st3LcW/BPi75W4q1jJTEor/QWwbNlPlDG0JTn6XhZu0=
google.golang.org/genproto/googleapis/api v0.0.0-20241223144023-3abc09e42ca8/go.mod h1:klhJGKFyG8Tn50enBn7gizg4nXGXJ+jqEREdCWaPcV4=

View File

@@ -1,13 +1,71 @@
// Package health provides a standalone HTTP server for health checks.
//
// This package implements a simple health check server that can be used
// to expose health status endpoints for monitoring and load balancing.
// It supports custom health check handlers and provides structured logging
// with graceful shutdown capabilities.
// This package implements a flexible health check server that supports
// different handlers for Kubernetes probe types (liveness, readiness, startup).
// It provides structured logging, graceful shutdown, and standard HTTP endpoints
// for monitoring and load balancing.
//
// # Kubernetes Probe Types
//
// Liveness Probe: Detects when a container is "dead" and needs restarting.
// Should be a lightweight check that verifies the process is still running
// and not in an unrecoverable state.
//
// Readiness Probe: Determines when a container is ready to accept traffic.
// Controls which Pods are used as backends for Services. Should verify
// the application can handle requests properly.
//
// Startup Probe: Verifies when a container application has successfully started.
// Delays liveness and readiness probes until startup succeeds. Useful for
// slow-starting applications.
//
// # Usage Examples
//
// Basic usage with a single handler for all probes:
//
// srv := health.NewServer(myHealthHandler)
// srv.Listen(ctx, 9091)
//
// Advanced usage with separate handlers for each probe type:
//
// srv := health.NewServer(nil,
// health.WithLivenessHandler(func(w http.ResponseWriter, r *http.Request) {
// // Simple alive check
// w.WriteHeader(http.StatusOK)
// }),
// health.WithReadinessHandler(func(w http.ResponseWriter, r *http.Request) {
// // Check if ready to serve traffic
// if err := checkDatabase(); err != nil {
// w.WriteHeader(http.StatusServiceUnavailable)
// return
// }
// w.WriteHeader(http.StatusOK)
// }),
// health.WithStartupHandler(func(w http.ResponseWriter, r *http.Request) {
// // Check if startup is complete
// if !applicationReady() {
// w.WriteHeader(http.StatusServiceUnavailable)
// return
// }
// w.WriteHeader(http.StatusOK)
// }),
// health.WithServiceName("my-service"),
// )
// srv.Listen(ctx, 9091)
//
// # Standard Endpoints
//
// The server exposes these endpoints:
// - /healthz - liveness probe (or general health if no specific handler)
// - /readyz - readiness probe (or general health if no specific handler)
// - /startupz - startup probe (or general health if no specific handler)
// - /__health - general health endpoint (backward compatibility)
// - / - general health endpoint (root path)
package health
import (
"context"
"encoding/json"
"log/slog"
"net/http"
"strconv"
@@ -21,23 +79,74 @@ import (
// It runs separately from the main application server to ensure health
// checks remain available even if the main server is experiencing issues.
//
// The server includes built-in timeouts, graceful shutdown, and structured
// logging for monitoring and debugging health check behavior.
// The server supports separate handlers for different Kubernetes probe types
// (liveness, readiness, startup) and includes built-in timeouts, graceful
// shutdown, and structured logging.
type Server struct {
log *slog.Logger
healthFn http.HandlerFunc
log *slog.Logger
livenessHandler http.HandlerFunc
readinessHandler http.HandlerFunc
startupHandler http.HandlerFunc
generalHandler http.HandlerFunc // fallback for /__health and / paths
serviceName string
}
// NewServer creates a new health check server with the specified health handler.
// If healthFn is nil, a default handler that returns HTTP 200 "ok" is used.
func NewServer(healthFn http.HandlerFunc) *Server {
// Option represents a configuration option for the health server.
type Option func(*Server)
// WithLivenessHandler sets a specific handler for the /healthz endpoint.
// Liveness probes determine if a container should be restarted.
func WithLivenessHandler(handler http.HandlerFunc) Option {
return func(s *Server) {
s.livenessHandler = handler
}
}
// WithReadinessHandler sets a specific handler for the /readyz endpoint.
// Readiness probes determine if a container can receive traffic.
func WithReadinessHandler(handler http.HandlerFunc) Option {
return func(s *Server) {
s.readinessHandler = handler
}
}
// WithStartupHandler sets a specific handler for the /startupz endpoint.
// Startup probes determine if a container has finished initializing.
func WithStartupHandler(handler http.HandlerFunc) Option {
return func(s *Server) {
s.startupHandler = handler
}
}
// WithServiceName sets the service name for JSON responses and logging.
func WithServiceName(serviceName string) Option {
return func(s *Server) {
s.serviceName = serviceName
}
}
// NewServer creates a new health check server with optional probe-specific handlers.
//
// If healthFn is provided, it will be used as a fallback for any probe endpoints
// that don't have specific handlers configured. If healthFn is nil, a default
// handler that returns HTTP 200 "ok" is used as the fallback.
//
// Use the With* option functions to configure specific handlers for different
// probe types (liveness, readiness, startup).
func NewServer(healthFn http.HandlerFunc, opts ...Option) *Server {
if healthFn == nil {
healthFn = basicHealth
}
srv := &Server{
log: logger.Setup(),
healthFn: healthFn,
log: logger.Setup(),
generalHandler: healthFn,
}
for _, opt := range opts {
opt(srv)
}
return srv
}
@@ -47,13 +156,27 @@ func (srv *Server) SetLogger(log *slog.Logger) {
}
// Listen starts the health server on the specified port and blocks until ctx is cancelled.
// The server exposes the health handler at "/__health" with graceful shutdown support.
// The server exposes health check endpoints with graceful shutdown support.
//
// Standard endpoints exposed:
// - /healthz - liveness probe (uses livenessHandler or falls back to generalHandler)
// - /readyz - readiness probe (uses readinessHandler or falls back to generalHandler)
// - /startupz - startup probe (uses startupHandler or falls back to generalHandler)
// - /__health - general health endpoint (uses generalHandler)
// - / - root health endpoint (uses generalHandler)
func (srv *Server) Listen(ctx context.Context, port int) error {
srv.log.Info("starting health listener", "port", port)
serveMux := http.NewServeMux()
serveMux.HandleFunc("/__health", srv.healthFn)
// Register probe-specific handlers
serveMux.HandleFunc("/healthz", srv.createProbeHandler("liveness"))
serveMux.HandleFunc("/readyz", srv.createProbeHandler("readiness"))
serveMux.HandleFunc("/startupz", srv.createProbeHandler("startup"))
// Register general health endpoints for backward compatibility
serveMux.HandleFunc("/__health", srv.createGeneralHandler())
serveMux.HandleFunc("/", srv.createGeneralHandler())
hsrv := &http.Server{
Addr: ":" + strconv.Itoa(port),
@@ -89,6 +212,121 @@ func (srv *Server) Listen(ctx context.Context, port int) error {
return g.Wait()
}
// createProbeHandler creates a handler for a specific probe type that provides
// appropriate JSON responses and falls back to the general handler if no specific
// handler is configured.
func (srv *Server) createProbeHandler(probeType string) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
var handler http.HandlerFunc
// Select the appropriate handler
switch probeType {
case "liveness":
handler = srv.livenessHandler
case "readiness":
handler = srv.readinessHandler
case "startup":
handler = srv.startupHandler
}
// Fall back to general handler if no specific handler is configured
if handler == nil {
handler = srv.generalHandler
}
// Create a response recorder to capture the handler's status code
recorder := &statusRecorder{ResponseWriter: w, statusCode: 200}
handler(recorder, r)
// If the handler already wrote a response, we're done
if recorder.written {
return
}
// Otherwise, provide a standard JSON response based on the status code
w.Header().Set("Content-Type", "application/json")
if recorder.statusCode >= 400 {
// Handler indicated unhealthy
switch probeType {
case "liveness":
json.NewEncoder(w).Encode(map[string]string{"status": "unhealthy"})
case "readiness":
json.NewEncoder(w).Encode(map[string]bool{"ready": false})
case "startup":
json.NewEncoder(w).Encode(map[string]bool{"started": false})
}
} else {
// Handler indicated healthy
switch probeType {
case "liveness":
json.NewEncoder(w).Encode(map[string]string{"status": "alive"})
case "readiness":
json.NewEncoder(w).Encode(map[string]bool{"ready": true})
case "startup":
json.NewEncoder(w).Encode(map[string]bool{"started": true})
}
}
}
}
// createGeneralHandler creates a handler for general health endpoints that provides
// comprehensive health information.
func (srv *Server) createGeneralHandler() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
// Create a response recorder to capture the handler's status code
// Use a buffer to prevent the handler from writing to the actual response
recorder := &statusRecorder{ResponseWriter: &discardWriter{}, statusCode: 200}
srv.generalHandler(recorder, r)
// Always provide a comprehensive JSON response for general endpoints
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(recorder.statusCode)
response := map[string]interface{}{
"status": map[bool]string{true: "healthy", false: "unhealthy"}[recorder.statusCode < 400],
}
if srv.serviceName != "" {
response["service"] = srv.serviceName
}
json.NewEncoder(w).Encode(response)
}
}
// statusRecorder captures the response status code from handlers while allowing
// them to write their own response content if needed.
type statusRecorder struct {
http.ResponseWriter
statusCode int
written bool
}
func (r *statusRecorder) WriteHeader(code int) {
r.statusCode = code
r.ResponseWriter.WriteHeader(code)
}
func (r *statusRecorder) Write(data []byte) (int, error) {
r.written = true
return r.ResponseWriter.Write(data)
}
// discardWriter implements http.ResponseWriter but discards all writes.
// Used to capture status codes without writing response content.
type discardWriter struct{}
func (d *discardWriter) Header() http.Header {
return make(http.Header)
}
func (d *discardWriter) Write([]byte) (int, error) {
return 0, nil
}
func (d *discardWriter) WriteHeader(int) {}
// HealthCheckListener runs a simple HTTP server on the specified port for health check probes.
func HealthCheckListener(ctx context.Context, port int, log *slog.Logger) error {
srv := NewServer(nil)

View File

@@ -1,13 +1,14 @@
package health
import (
"fmt"
"io"
"net/http"
"net/http/httptest"
"testing"
)
func TestHealthHandler(t *testing.T) {
func TestBasicHealthHandler(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/__health", nil)
w := httptest.NewRecorder()
@@ -24,3 +25,129 @@ func TestHealthHandler(t *testing.T) {
t.Errorf("expected ok got %q", string(data))
}
}
func TestProbeHandlers(t *testing.T) {
// Test with separate handlers for each probe type
srv := NewServer(nil,
WithLivenessHandler(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}),
WithReadinessHandler(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}),
WithStartupHandler(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}),
WithServiceName("test-service"),
)
tests := []struct {
handler http.HandlerFunc
expectedStatus int
expectedBody string
}{
{srv.createProbeHandler("liveness"), 200, `{"status":"alive"}`},
{srv.createProbeHandler("readiness"), 200, `{"ready":true}`},
{srv.createProbeHandler("startup"), 200, `{"started":true}`},
{srv.createGeneralHandler(), 200, `{"service":"test-service","status":"healthy"}`},
}
for i, tt := range tests {
t.Run(fmt.Sprintf("test_%d", i), func(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/", nil)
w := httptest.NewRecorder()
tt.handler(w, req)
if w.Code != tt.expectedStatus {
t.Errorf("expected status %d, got %d", tt.expectedStatus, w.Code)
}
body := w.Body.String()
if body != tt.expectedBody+"\n" { // json.Encoder adds newline
t.Errorf("expected body %q, got %q", tt.expectedBody, body)
}
})
}
}
func TestProbeHandlerFallback(t *testing.T) {
// Test fallback to general handler when no specific handler is configured
generalHandler := func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}
srv := NewServer(generalHandler, WithServiceName("test-service"))
tests := []struct {
handler http.HandlerFunc
expectedStatus int
expectedBody string
}{
{srv.createProbeHandler("liveness"), 200, `{"status":"alive"}`},
{srv.createProbeHandler("readiness"), 200, `{"ready":true}`},
{srv.createProbeHandler("startup"), 200, `{"started":true}`},
}
for i, tt := range tests {
t.Run(fmt.Sprintf("fallback_%d", i), func(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/", nil)
w := httptest.NewRecorder()
tt.handler(w, req)
if w.Code != tt.expectedStatus {
t.Errorf("expected status %d, got %d", tt.expectedStatus, w.Code)
}
body := w.Body.String()
if body != tt.expectedBody+"\n" { // json.Encoder adds newline
t.Errorf("expected body %q, got %q", tt.expectedBody, body)
}
})
}
}
func TestUnhealthyProbeHandlers(t *testing.T) {
// Test with handlers that return unhealthy status
srv := NewServer(nil,
WithLivenessHandler(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusServiceUnavailable)
}),
WithReadinessHandler(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusServiceUnavailable)
}),
WithStartupHandler(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusServiceUnavailable)
}),
WithServiceName("test-service"),
)
tests := []struct {
handler http.HandlerFunc
expectedStatus int
expectedBody string
}{
{srv.createProbeHandler("liveness"), 503, `{"status":"unhealthy"}`},
{srv.createProbeHandler("readiness"), 503, `{"ready":false}`},
{srv.createProbeHandler("startup"), 503, `{"started":false}`},
}
for i, tt := range tests {
t.Run(fmt.Sprintf("unhealthy_%d", i), func(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/", nil)
w := httptest.NewRecorder()
tt.handler(w, req)
if w.Code != tt.expectedStatus {
t.Errorf("expected status %d, got %d", tt.expectedStatus, w.Code)
}
body := w.Body.String()
if body != tt.expectedBody+"\n" { // json.Encoder adds newline
t.Errorf("expected body %q, got %q", tt.expectedBody, body)
}
})
}
}

View File

@@ -23,9 +23,8 @@ type bufferingExporter struct {
// Real exporter (created when tracing is configured)
exporter otellog.Exporter
// Thread-safe initialization
initOnce sync.Once
initErr error
// Thread-safe initialization state (managed only by checkReadiness)
initErr error
// Background checker
stopChecker chan struct{}
@@ -48,20 +47,13 @@ func newBufferingExporter() *bufferingExporter {
// Export implements otellog.Exporter
func (e *bufferingExporter) Export(ctx context.Context, records []otellog.Record) error {
// Try initialization once
e.initOnce.Do(func() {
e.initErr = e.initialize()
})
// Check if exporter is ready (initialization handled by checkReadiness goroutine)
e.mu.RLock()
exporter := e.exporter
e.mu.RUnlock()
// If initialization succeeded, use the exporter
if e.initErr == nil {
e.mu.RLock()
exporter := e.exporter
e.mu.RUnlock()
if exporter != nil {
return exporter.Export(ctx, records)
}
if exporter != nil {
return exporter.Export(ctx, records)
}
// Not ready yet, buffer the records
@@ -117,24 +109,31 @@ func (e *bufferingExporter) bufferRecords(records []otellog.Record) error {
return nil
}
// checkReadiness periodically checks if tracing is configured
// checkReadiness periodically attempts initialization until successful
func (e *bufferingExporter) checkReadiness() {
defer close(e.checkerDone)
ticker := time.NewTicker(1 * time.Second) // Reduced frequency since OTLP handles retries
ticker := time.NewTicker(1 * time.Second)
defer ticker.Stop()
for {
select {
case <-ticker.C:
// If initialization failed, reset sync.Once to allow retry
// The OTLP exporter will handle its own retry logic
if e.initErr != nil {
e.initOnce = sync.Once{}
} else if e.exporter != nil {
// Check if we already have a working exporter
e.mu.RLock()
hasExporter := e.exporter != nil
e.mu.RUnlock()
if hasExporter {
return // Exporter ready, checker no longer needed
}
// Try to initialize
err := e.initialize()
e.mu.Lock()
e.initErr = err
e.mu.Unlock()
case <-e.stopChecker:
return
}
@@ -180,14 +179,21 @@ func (e *bufferingExporter) Shutdown(ctx context.Context) error {
// Stop the readiness checker from continuing
close(e.stopChecker)
// Give one final chance for TLS/tracing to become ready before fully shutting down
e.initOnce.Do(func() {
e.initErr = e.initialize()
})
// Wait for readiness checker goroutine to complete
<-e.checkerDone
// Give one final chance for TLS/tracing to become ready for buffer flushing
e.mu.RLock()
hasExporter := e.exporter != nil
e.mu.RUnlock()
if !hasExporter {
err := e.initialize()
e.mu.Lock()
e.initErr = err
e.mu.Unlock()
}
e.mu.Lock()
defer e.mu.Unlock()

235
logger/level_test.go Normal file
View File

@@ -0,0 +1,235 @@
package logger
import (
"context"
"log/slog"
"os"
"testing"
"time"
)
func TestParseLevel(t *testing.T) {
tests := []struct {
name string
input string
expected slog.Level
expectError bool
}{
{"empty string", "", slog.LevelInfo, false},
{"DEBUG upper", "DEBUG", slog.LevelDebug, false},
{"debug lower", "debug", slog.LevelDebug, false},
{"INFO upper", "INFO", slog.LevelInfo, false},
{"info lower", "info", slog.LevelInfo, false},
{"WARN upper", "WARN", slog.LevelWarn, false},
{"warn lower", "warn", slog.LevelWarn, false},
{"ERROR upper", "ERROR", slog.LevelError, false},
{"error lower", "error", slog.LevelError, false},
{"invalid level", "invalid", slog.LevelInfo, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
level, err := ParseLevel(tt.input)
if tt.expectError {
if err == nil {
t.Errorf("expected error for input %q, got nil", tt.input)
}
} else {
if err != nil {
t.Errorf("unexpected error for input %q: %v", tt.input, err)
}
if level != tt.expected {
t.Errorf("expected level %v for input %q, got %v", tt.expected, tt.input, level)
}
}
})
}
}
func TestSetLevel(t *testing.T) {
// Store original level to restore later
originalLevel := Level.Level()
defer Level.Set(originalLevel)
SetLevel(slog.LevelDebug)
if Level.Level() != slog.LevelDebug {
t.Errorf("expected Level to be Debug, got %v", Level.Level())
}
SetLevel(slog.LevelError)
if Level.Level() != slog.LevelError {
t.Errorf("expected Level to be Error, got %v", Level.Level())
}
}
func TestSetOTLPLevel(t *testing.T) {
// Store original level to restore later
originalLevel := OTLPLevel.Level()
defer OTLPLevel.Set(originalLevel)
SetOTLPLevel(slog.LevelWarn)
if OTLPLevel.Level() != slog.LevelWarn {
t.Errorf("expected OTLPLevel to be Warn, got %v", OTLPLevel.Level())
}
SetOTLPLevel(slog.LevelDebug)
if OTLPLevel.Level() != slog.LevelDebug {
t.Errorf("expected OTLPLevel to be Debug, got %v", OTLPLevel.Level())
}
}
func TestOTLPLevelHandler(t *testing.T) {
// Create a mock handler that counts calls
callCount := 0
mockHandler := &mockHandler{
handleFunc: func(ctx context.Context, r slog.Record) error {
callCount++
return nil
},
}
// Set OTLP level to Warn
originalLevel := OTLPLevel.Level()
defer OTLPLevel.Set(originalLevel)
OTLPLevel.Set(slog.LevelWarn)
// Create OTLP level handler
handler := newOTLPLevelHandler(mockHandler)
ctx := context.Background()
// Test that Debug and Info are filtered out
if handler.Enabled(ctx, slog.LevelDebug) {
t.Error("Debug level should be disabled when OTLP level is Warn")
}
if handler.Enabled(ctx, slog.LevelInfo) {
t.Error("Info level should be disabled when OTLP level is Warn")
}
// Test that Warn and Error are enabled
if !handler.Enabled(ctx, slog.LevelWarn) {
t.Error("Warn level should be enabled when OTLP level is Warn")
}
if !handler.Enabled(ctx, slog.LevelError) {
t.Error("Error level should be enabled when OTLP level is Warn")
}
// Test that Handle respects level filtering
now := time.Now()
debugRecord := slog.NewRecord(now, slog.LevelDebug, "debug message", 0)
warnRecord := slog.NewRecord(now, slog.LevelWarn, "warn message", 0)
handler.Handle(ctx, debugRecord)
if callCount != 0 {
t.Error("Debug record should not be passed to underlying handler")
}
handler.Handle(ctx, warnRecord)
if callCount != 1 {
t.Error("Warn record should be passed to underlying handler")
}
}
func TestEnvironmentVariables(t *testing.T) {
tests := []struct {
name string
envVar string
envValue string
configPrefix string
testFunc func(t *testing.T)
}{
{
name: "LOG_LEVEL sets stderr level",
envVar: "LOG_LEVEL",
envValue: "ERROR",
testFunc: func(t *testing.T) {
// Reset the setup state
resetLoggerSetup()
// Call setupStdErrHandler which should read the env var
handler := setupStdErrHandler()
if handler == nil {
t.Fatal("setupStdErrHandler returned nil")
}
if Level.Level() != slog.LevelError {
t.Errorf("expected Level to be Error after setting LOG_LEVEL=ERROR, got %v", Level.Level())
}
},
},
{
name: "Prefixed LOG_LEVEL",
envVar: "TEST_LOG_LEVEL",
envValue: "DEBUG",
configPrefix: "TEST",
testFunc: func(t *testing.T) {
ConfigPrefix = "TEST"
defer func() { ConfigPrefix = "" }()
resetLoggerSetup()
handler := setupStdErrHandler()
if handler == nil {
t.Fatal("setupStdErrHandler returned nil")
}
if Level.Level() != slog.LevelDebug {
t.Errorf("expected Level to be Debug after setting TEST_LOG_LEVEL=DEBUG, got %v", Level.Level())
}
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Store original env value and level
originalEnv := os.Getenv(tt.envVar)
originalLevel := Level.Level()
defer func() {
os.Setenv(tt.envVar, originalEnv)
Level.Set(originalLevel)
}()
// Set test environment variable
os.Setenv(tt.envVar, tt.envValue)
// Run the test
tt.testFunc(t)
})
}
}
// mockHandler is a simple mock implementation of slog.Handler for testing
type mockHandler struct {
handleFunc func(ctx context.Context, r slog.Record) error
}
func (m *mockHandler) Enabled(ctx context.Context, level slog.Level) bool {
return true
}
func (m *mockHandler) Handle(ctx context.Context, r slog.Record) error {
if m.handleFunc != nil {
return m.handleFunc(ctx, r)
}
return nil
}
func (m *mockHandler) WithAttrs(attrs []slog.Attr) slog.Handler {
return m
}
func (m *mockHandler) WithGroup(name string) slog.Handler {
return m
}
// resetLoggerSetup resets the sync.Once instances for testing
func resetLoggerSetup() {
// Reset package-level variables
textLogger = nil
otlpLogger = nil
multiLogger = nil
// Note: We can't easily reset sync.Once instances in tests,
// but for the specific test we're doing (environment variable parsing)
// we can test the setupStdErrHandler function directly
}

View File

@@ -18,12 +18,15 @@
// - Context propagation for request-scoped logging
//
// Environment variables:
// - DEBUG: Enable debug level logging (configurable prefix via ConfigPrefix)
// - LOG_LEVEL: Set stderr log level (DEBUG, INFO, WARN, ERROR) (configurable prefix via ConfigPrefix)
// - OTLP_LOG_LEVEL: Set OTLP log level independently (configurable prefix via ConfigPrefix)
// - DEBUG: Enable debug level logging for backward compatibility (configurable prefix via ConfigPrefix)
// - INVOCATION_ID: Systemd detection for timestamp handling
package logger
import (
"context"
"fmt"
"log"
"log/slog"
"os"
@@ -43,6 +46,16 @@ import (
// This enables multiple services to have independent logging configuration.
var ConfigPrefix = ""
var (
// Level controls the log level for the default stderr logger.
// Can be changed at runtime to adjust logging verbosity.
Level = new(slog.LevelVar) // Info by default
// OTLPLevel controls the log level for OTLP output.
// Can be changed independently from the stderr logger level.
OTLPLevel = new(slog.LevelVar) // Info by default
)
var (
textLogger *slog.Logger
otlpLogger *slog.Logger
@@ -56,21 +69,64 @@ var (
mu sync.Mutex
)
func setupStdErrHandler() slog.Handler {
programLevel := new(slog.LevelVar) // Info by default
// SetLevel sets the log level for the default stderr logger.
// This affects the primary application logger returned by Setup().
func SetLevel(level slog.Level) {
Level.Set(level)
}
envVar := "DEBUG"
// SetOTLPLevel sets the log level for OTLP output.
// This affects the logger returned by SetupOLTP() and the OTLP portion of SetupMultiLogger().
func SetOTLPLevel(level slog.Level) {
OTLPLevel.Set(level)
}
// ParseLevel converts a string log level to slog.Level.
// Supported levels: "DEBUG", "INFO", "WARN", "ERROR" (case insensitive).
// Returns an error for unrecognized level strings.
func ParseLevel(level string) (slog.Level, error) {
switch {
case level == "":
return slog.LevelInfo, nil
case level == "DEBUG" || level == "debug":
return slog.LevelDebug, nil
case level == "INFO" || level == "info":
return slog.LevelInfo, nil
case level == "WARN" || level == "warn":
return slog.LevelWarn, nil
case level == "ERROR" || level == "error":
return slog.LevelError, nil
default:
return slog.LevelInfo, fmt.Errorf("unknown log level: %s", level)
}
}
func setupStdErrHandler() slog.Handler {
// Parse LOG_LEVEL environment variable
logLevelVar := "LOG_LEVEL"
if len(ConfigPrefix) > 0 {
envVar = ConfigPrefix + "_" + envVar
logLevelVar = ConfigPrefix + "_" + logLevelVar
}
if opt := os.Getenv(envVar); len(opt) > 0 {
if debug, _ := strconv.ParseBool(opt); debug {
programLevel.Set(slog.LevelDebug)
if levelStr := os.Getenv(logLevelVar); levelStr != "" {
if level, err := ParseLevel(levelStr); err == nil {
Level.Set(level)
}
}
logOptions := &slog.HandlerOptions{Level: programLevel}
// Maintain backward compatibility with DEBUG environment variable
debugVar := "DEBUG"
if len(ConfigPrefix) > 0 {
debugVar = ConfigPrefix + "_" + debugVar
}
if opt := os.Getenv(debugVar); len(opt) > 0 {
if debug, _ := strconv.ParseBool(opt); debug {
Level.Set(slog.LevelDebug)
}
}
logOptions := &slog.HandlerOptions{Level: Level}
if len(os.Getenv("INVOCATION_ID")) > 0 {
// don't add timestamps when running under systemd
@@ -88,6 +144,18 @@ func setupStdErrHandler() slog.Handler {
func setupOtlpLogger() *slog.Logger {
setupOtlp.Do(func() {
// Parse OTLP_LOG_LEVEL environment variable
otlpLevelVar := "OTLP_LOG_LEVEL"
if len(ConfigPrefix) > 0 {
otlpLevelVar = ConfigPrefix + "_" + otlpLevelVar
}
if levelStr := os.Getenv(otlpLevelVar); levelStr != "" {
if level, err := ParseLevel(levelStr); err == nil {
OTLPLevel.Set(level)
}
}
// Create our buffering exporter
// It will buffer until tracing is configured
bufferingExp := newBufferingExporter()
@@ -107,8 +175,9 @@ func setupOtlpLogger() *slog.Logger {
// Set global provider
global.SetLoggerProvider(provider)
// Create slog handler
handler := newLogFmtHandler(otelslog.NewHandler("common"))
// Create slog handler with level control
baseHandler := newLogFmtHandler(otelslog.NewHandler("common"))
handler := newOTLPLevelHandler(baseHandler)
otlpLogger = slog.New(handler)
})
return otlpLogger

48
logger/otlp_handler.go Normal file
View File

@@ -0,0 +1,48 @@
package logger
import (
"context"
"log/slog"
)
// otlpLevelHandler is a wrapper that enforces level checking for OTLP handlers.
// This allows independent level control for OTLP output separate from stderr logging.
type otlpLevelHandler struct {
next slog.Handler
}
// newOTLPLevelHandler creates a new OTLP level wrapper handler.
func newOTLPLevelHandler(next slog.Handler) slog.Handler {
return &otlpLevelHandler{
next: next,
}
}
// Enabled checks if the log level should be processed by the OTLP handler.
// It uses the OTLPLevel variable to determine if the record should be processed.
func (h *otlpLevelHandler) Enabled(ctx context.Context, level slog.Level) bool {
return level >= OTLPLevel.Level()
}
// Handle processes the log record if the level is enabled.
// If disabled by level checking, the record is silently dropped.
func (h *otlpLevelHandler) Handle(ctx context.Context, r slog.Record) error {
if !h.Enabled(ctx, r.Level) {
return nil
}
return h.next.Handle(ctx, r)
}
// WithAttrs returns a new handler with the specified attributes added.
func (h *otlpLevelHandler) WithAttrs(attrs []slog.Attr) slog.Handler {
return &otlpLevelHandler{
next: h.next.WithAttrs(attrs),
}
}
// WithGroup returns a new handler with the specified group name.
func (h *otlpLevelHandler) WithGroup(name string) slog.Handler {
return &otlpLevelHandler{
next: h.next.WithGroup(name),
}
}

View File

@@ -2,7 +2,7 @@
set -euo pipefail
go install github.com/goreleaser/goreleaser/v2@v2.11.0
go install github.com/goreleaser/goreleaser/v2@v2.12.3
if [ ! -z "${harbor_username:-}" ]; then
DOCKER_FILE=~/.docker/config.json

View File

@@ -1,9 +1,10 @@
// Package fastlyxff provides Fastly CDN IP range management for trusted proxy handling.
//
// This package parses Fastly's public IP ranges JSON file and generates Echo framework
// trust options for proper client IP extraction from X-Forwarded-For headers.
// It's designed specifically for services deployed behind Fastly's CDN that need
// to identify real client IPs for logging, rate limiting, and security purposes.
// This package parses Fastly's public IP ranges JSON file and provides middleware
// for both Echo framework and standard net/http for proper client IP extraction
// from X-Forwarded-For headers. It's designed specifically for services deployed
// behind Fastly's CDN that need to identify real client IPs for logging, rate
// limiting, and security purposes.
//
// Fastly publishes their edge server IP ranges in a JSON format that this package
// consumes to automatically configure trusted proxy ranges. This ensures that
@@ -14,8 +15,55 @@
// - Automatic parsing of Fastly's IP ranges JSON format
// - Support for both IPv4 and IPv6 address ranges
// - Echo framework integration via TrustOption generation
// - Standard net/http middleware support
// - CIDR notation parsing and validation
//
// # Echo Framework Usage
//
// fastlyRanges, err := fastlyxff.New("fastly.json")
// if err != nil {
// return err
// }
// options, err := fastlyRanges.EchoTrustOption()
// if err != nil {
// return err
// }
// e.IPExtractor = echo.ExtractIPFromXFFHeader(options...)
//
// # Net/HTTP Usage
//
// fastlyRanges, err := fastlyxff.New("fastly.json")
// if err != nil {
// return err
// }
// middleware := fastlyRanges.HTTPMiddleware()
//
// handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// // Both methods work - middleware updates r.RemoteAddr (with port 0) and stores in context
// realIP := fastlyxff.GetRealIP(r) // Preferred method
// // OR: host, _, _ := net.SplitHostPort(r.RemoteAddr) // Direct access (port will be "0")
// fmt.Fprintf(w, "Real IP: %s\n", realIP)
// })
//
// http.ListenAndServe(":8080", middleware(handler))
//
// # Net/HTTP with Additional Trusted Ranges
//
// fastlyRanges, err := fastlyxff.New("fastly.json")
// if err != nil {
// return err
// }
//
// // Add custom trusted CIDRs (e.g., internal load balancers)
// // Note: For Echo framework, use the ekko package for additional ranges
// err = fastlyRanges.AddTrustedCIDR("10.0.0.0/8")
// if err != nil {
// return err
// }
//
// middleware := fastlyRanges.HTTPMiddleware()
// handler := middleware(yourHandler)
//
// The JSON file typically contains IP ranges in this format:
//
// {
@@ -25,29 +73,36 @@
package fastlyxff
import (
"context"
"encoding/json"
"net"
"net/http"
"net/netip"
"os"
"strings"
"github.com/labstack/echo/v4"
)
// FastlyXFF represents Fastly's published IP ranges for their CDN edge servers.
// This structure matches the JSON format provided by Fastly for their public IP ranges.
// It contains separate lists for IPv4 and IPv6 CIDR ranges.
// It contains separate lists for IPv4 and IPv6 CIDR ranges, plus additional trusted CIDRs.
type FastlyXFF struct {
IPv4 []string `json:"addresses"` // IPv4 CIDR ranges (e.g., "23.235.32.0/20")
IPv6 []string `json:"ipv6_addresses"` // IPv6 CIDR ranges (e.g., "2a04:4e40::/32")
IPv4 []string `json:"addresses"` // IPv4 CIDR ranges (e.g., "23.235.32.0/20")
IPv6 []string `json:"ipv6_addresses"` // IPv6 CIDR ranges (e.g., "2a04:4e40::/32")
extraCIDRs []string // Additional trusted CIDRs added via AddTrustedCIDR
}
// TrustedNets holds parsed network prefixes for efficient IP range checking.
// This type is currently unused but reserved for future optimizations
// where frequent IP range lookups might benefit from pre-parsed prefixes.
type TrustedNets struct {
prefixes []netip.Prefix // Parsed network prefixes for efficient lookups
}
// contextKey is used for storing the real client IP in request context
type contextKey string
const realIPKey contextKey = "fastly-real-ip"
// New loads and parses Fastly IP ranges from a JSON file.
// The file should contain Fastly's published IP ranges in their standard JSON format.
//
@@ -100,3 +155,116 @@ func (xff *FastlyXFF) EchoTrustOption() ([]echo.TrustOption, error) {
return ranges, nil
}
// AddTrustedCIDR adds an additional CIDR to the list of trusted proxies.
// This allows trusting proxies beyond Fastly's published ranges.
// The cidr parameter must be a valid CIDR notation (e.g., "10.0.0.0/8", "192.168.1.0/24").
// Returns an error if the CIDR format is invalid.
func (xff *FastlyXFF) AddTrustedCIDR(cidr string) error {
// Validate CIDR format
_, _, err := net.ParseCIDR(cidr)
if err != nil {
return err
}
// Add to extra CIDRs
xff.extraCIDRs = append(xff.extraCIDRs, cidr)
return nil
}
// isTrustedProxy checks if the given IP address belongs to Fastly's trusted IP ranges
// or any additional CIDRs added via AddTrustedCIDR.
func (xff *FastlyXFF) isTrustedProxy(ip string) bool {
addr, err := netip.ParseAddr(ip)
if err != nil {
return false
}
// Check all IPv4 and IPv6 ranges (Fastly + additional)
allRanges := append(append(xff.IPv4, xff.IPv6...), xff.extraCIDRs...)
for _, s := range allRanges {
_, cidr, err := net.ParseCIDR(s)
if err != nil {
continue
}
if cidr.Contains(net.IP(addr.AsSlice())) {
return true
}
}
return false
}
// extractRealIP extracts the real client IP from X-Forwarded-For header.
// It returns the rightmost IP that is not from a trusted Fastly proxy.
func (xff *FastlyXFF) extractRealIP(r *http.Request) string {
// Get the immediate peer IP
host, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
host = r.RemoteAddr
}
// If the immediate peer is not a trusted Fastly proxy, return it
if !xff.isTrustedProxy(host) {
return host
}
// Check X-Forwarded-For header
xff_header := r.Header.Get("X-Forwarded-For")
if xff_header == "" {
return host
}
// Parse comma-separated IP list
ips := strings.Split(xff_header, ",")
if len(ips) == 0 {
return host
}
// Find the leftmost IP that is not from a trusted proxy
// This represents the original client IP
for i := 0; i < len(ips); i++ {
ip := strings.TrimSpace(ips[i])
if ip != "" && !xff.isTrustedProxy(ip) {
return ip
}
}
// Fallback to the immediate peer
return host
}
// HTTPMiddleware returns a net/http middleware that extracts real client IP
// from X-Forwarded-For headers when the request comes from trusted Fastly proxies.
// The real IP is stored in the request context and also updates r.RemoteAddr
// with port 0 (since the original port is from the proxy, not the real client).
func (xff *FastlyXFF) HTTPMiddleware() func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
realIP := xff.extractRealIP(r)
// Store in context for GetRealIP function
ctx := context.WithValue(r.Context(), realIPKey, realIP)
r = r.WithContext(ctx)
// Update RemoteAddr to be consistent with extracted IP
// Use port 0 since the original port is from the proxy, not the real client
r.RemoteAddr = net.JoinHostPort(realIP, "0")
next.ServeHTTP(w, r)
})
}
}
// GetRealIP retrieves the real client IP from the request context.
// This should be used after the HTTPMiddleware has processed the request.
// Returns the remote address if no real IP was extracted.
func GetRealIP(r *http.Request) string {
if ip, ok := r.Context().Value(realIPKey).(string); ok {
return ip
}
host, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
return r.RemoteAddr
}
return host
}

View File

@@ -1,6 +1,11 @@
package fastlyxff
import "testing"
import (
"net"
"net/http"
"net/http/httptest"
"testing"
)
func TestFastlyIPRanges(t *testing.T) {
fastlyxff, err := New("fastly.json")
@@ -18,3 +23,334 @@ func TestFastlyIPRanges(t *testing.T) {
t.Fail()
}
}
func TestHTTPMiddleware(t *testing.T) {
// Create a test FastlyXFF instance with known IP ranges
xff := &FastlyXFF{
IPv4: []string{"192.0.2.0/24", "203.0.113.0/24"},
IPv6: []string{"2001:db8::/32"},
}
middleware := xff.HTTPMiddleware()
tests := []struct {
name string
remoteAddr string
xForwardedFor string
expectedRealIP string
}{
{
name: "direct connection",
remoteAddr: "198.51.100.1:12345",
xForwardedFor: "",
expectedRealIP: "198.51.100.1",
},
{
name: "trusted proxy with XFF",
remoteAddr: "192.0.2.1:80",
xForwardedFor: "198.51.100.1",
expectedRealIP: "198.51.100.1",
},
{
name: "trusted proxy with multiple XFF",
remoteAddr: "192.0.2.1:80",
xForwardedFor: "198.51.100.1, 203.0.113.1",
expectedRealIP: "198.51.100.1",
},
{
name: "untrusted proxy ignored",
remoteAddr: "198.51.100.2:80",
xForwardedFor: "10.0.0.1",
expectedRealIP: "198.51.100.2",
},
{
name: "IPv6 trusted proxy",
remoteAddr: "[2001:db8::1]:80",
xForwardedFor: "198.51.100.1",
expectedRealIP: "198.51.100.1",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Create test handler that captures both GetRealIP and r.RemoteAddr
var capturedRealIP, capturedRemoteAddr string
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
capturedRealIP = GetRealIP(r)
capturedRemoteAddr = r.RemoteAddr
w.WriteHeader(http.StatusOK)
})
// Create request with middleware
req := httptest.NewRequest("GET", "/", nil)
req.RemoteAddr = tt.remoteAddr
if tt.xForwardedFor != "" {
req.Header.Set("X-Forwarded-For", tt.xForwardedFor)
}
rr := httptest.NewRecorder()
middleware(handler).ServeHTTP(rr, req)
// Test GetRealIP function
if capturedRealIP != tt.expectedRealIP {
t.Errorf("GetRealIP: expected %s, got %s", tt.expectedRealIP, capturedRealIP)
}
// Test that r.RemoteAddr is updated with real IP and port 0
// (since the original port is from the proxy, not the real client)
expectedRemoteAddr := net.JoinHostPort(tt.expectedRealIP, "0")
if capturedRemoteAddr != expectedRemoteAddr {
t.Errorf("RemoteAddr: expected %s, got %s", expectedRemoteAddr, capturedRemoteAddr)
}
})
}
}
func TestIsTrustedProxy(t *testing.T) {
xff := &FastlyXFF{
IPv4: []string{"192.0.2.0/24", "203.0.113.0/24"},
IPv6: []string{"2001:db8::/32"},
}
tests := []struct {
ip string
expected bool
}{
{"192.0.2.1", true},
{"192.0.2.255", true},
{"203.0.113.1", true},
{"192.0.3.1", false},
{"198.51.100.1", false},
{"2001:db8::1", true},
{"2001:db8:ffff::1", true},
{"2001:db9::1", false},
{"invalid-ip", false},
}
for _, tt := range tests {
t.Run(tt.ip, func(t *testing.T) {
result := xff.isTrustedProxy(tt.ip)
if result != tt.expected {
t.Errorf("isTrustedProxy(%s) = %v, expected %v", tt.ip, result, tt.expected)
}
})
}
}
func TestExtractRealIP(t *testing.T) {
xff := &FastlyXFF{
IPv4: []string{"192.0.2.0/24"},
IPv6: []string{"2001:db8::/32"},
}
tests := []struct {
name string
remoteAddr string
xForwardedFor string
expected string
}{
{
name: "no XFF header",
remoteAddr: "198.51.100.1:12345",
xForwardedFor: "",
expected: "198.51.100.1",
},
{
name: "trusted proxy with single IP",
remoteAddr: "192.0.2.1:80",
xForwardedFor: "198.51.100.1",
expected: "198.51.100.1",
},
{
name: "trusted proxy with multiple IPs",
remoteAddr: "192.0.2.1:80",
xForwardedFor: "198.51.100.1, 203.0.113.5",
expected: "198.51.100.1",
},
{
name: "untrusted proxy",
remoteAddr: "198.51.100.1:80",
xForwardedFor: "10.0.0.1",
expected: "198.51.100.1",
},
{
name: "empty XFF",
remoteAddr: "192.0.2.1:80",
xForwardedFor: "",
expected: "192.0.2.1",
},
{
name: "malformed remote addr",
remoteAddr: "192.0.2.1",
xForwardedFor: "198.51.100.1",
expected: "198.51.100.1",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := httptest.NewRequest("GET", "/", nil)
req.RemoteAddr = tt.remoteAddr
if tt.xForwardedFor != "" {
req.Header.Set("X-Forwarded-For", tt.xForwardedFor)
}
result := xff.extractRealIP(req)
if result != tt.expected {
t.Errorf("extractRealIP() = %s, expected %s", result, tt.expected)
}
})
}
}
func TestGetRealIPWithoutMiddleware(t *testing.T) {
req := httptest.NewRequest("GET", "/", nil)
req.RemoteAddr = "198.51.100.1:12345"
realIP := GetRealIP(req)
expected := "198.51.100.1"
if realIP != expected {
t.Errorf("GetRealIP() = %s, expected %s", realIP, expected)
}
}
func TestAddTrustedCIDR(t *testing.T) {
xff := &FastlyXFF{
IPv4: []string{"192.0.2.0/24"},
IPv6: []string{"2001:db8::/32"},
}
tests := []struct {
name string
cidr string
wantErr bool
}{
{"valid IPv4 range", "10.0.0.0/8", false},
{"valid IPv6 range", "fc00::/7", false},
{"valid single IP", "203.0.113.1/32", false},
{"invalid CIDR", "not-a-cidr", true},
{"invalid format", "10.0.0.0/99", true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := xff.AddTrustedCIDR(tt.cidr)
if (err != nil) != tt.wantErr {
t.Errorf("AddTrustedCIDR(%s) error = %v, wantErr %v", tt.cidr, err, tt.wantErr)
}
})
}
}
func TestCustomTrustedCIDRs(t *testing.T) {
xff := &FastlyXFF{
IPv4: []string{"192.0.2.0/24"},
IPv6: []string{"2001:db8::/32"},
}
// Add custom trusted CIDRs
err := xff.AddTrustedCIDR("10.0.0.0/8")
if err != nil {
t.Fatalf("Failed to add trusted CIDR: %v", err)
}
err = xff.AddTrustedCIDR("172.16.0.0/12")
if err != nil {
t.Fatalf("Failed to add trusted CIDR: %v", err)
}
tests := []struct {
ip string
expected bool
}{
// Original Fastly ranges
{"192.0.2.1", true},
{"2001:db8::1", true},
// Custom CIDRs
{"10.1.2.3", true},
{"172.16.1.1", true},
// Not trusted
{"198.51.100.1", false},
{"172.15.1.1", false},
{"10.0.0.0", true}, // Network address should still match
}
for _, tt := range tests {
t.Run(tt.ip, func(t *testing.T) {
result := xff.isTrustedProxy(tt.ip)
if result != tt.expected {
t.Errorf("isTrustedProxy(%s) = %v, expected %v", tt.ip, result, tt.expected)
}
})
}
}
func TestHTTPMiddlewareWithCustomCIDRs(t *testing.T) {
xff := &FastlyXFF{
IPv4: []string{"192.0.2.0/24"},
IPv6: []string{"2001:db8::/32"},
}
// Add custom trusted CIDR for internal proxies
err := xff.AddTrustedCIDR("10.0.0.0/8")
if err != nil {
t.Fatalf("Failed to add trusted CIDR: %v", err)
}
middleware := xff.HTTPMiddleware()
tests := []struct {
name string
remoteAddr string
xForwardedFor string
expectedRealIP string
}{
{
name: "custom trusted proxy with XFF",
remoteAddr: "10.1.2.3:80",
xForwardedFor: "198.51.100.1",
expectedRealIP: "198.51.100.1",
},
{
name: "fastly proxy with XFF",
remoteAddr: "192.0.2.1:80",
xForwardedFor: "198.51.100.1",
expectedRealIP: "198.51.100.1",
},
{
name: "untrusted proxy ignored",
remoteAddr: "172.16.1.1:80",
xForwardedFor: "198.51.100.1",
expectedRealIP: "172.16.1.1",
},
{
name: "chain through custom and fastly",
remoteAddr: "192.0.2.1:80",
xForwardedFor: "198.51.100.1, 10.1.2.3",
expectedRealIP: "198.51.100.1",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var capturedIP string
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
capturedIP = GetRealIP(r)
w.WriteHeader(http.StatusOK)
})
req := httptest.NewRequest("GET", "/", nil)
req.RemoteAddr = tt.remoteAddr
if tt.xForwardedFor != "" {
req.Header.Set("X-Forwarded-For", tt.xForwardedFor)
}
rr := httptest.NewRecorder()
middleware(handler).ServeHTTP(rr, req)
if capturedIP != tt.expectedRealIP {
t.Errorf("expected real IP %s, got %s", tt.expectedRealIP, capturedIP)
}
})
}
}