Compare commits
	
		
			29 Commits
		
	
	
		
			62a7605869
			...
			v0.6.2
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| 7291f00f48 | |||
| 2670d25b52 | |||
| 45308cd4bf | |||
| 4767caf7b8 | |||
| f90281f472 | |||
| ca190b0085 | |||
| 10864363e2 | |||
| 66b51df2af | |||
| 28d05d1d0e | |||
| a774f92bf7 | |||
| 0b9769dc39 | |||
| 9dadd9edc3 | |||
| c6230be91e | |||
| 796b2a8412 | |||
| 6a3bc7bab3 | |||
| da13a371b4 | |||
| a1a5a6b8be | |||
| 96afb77844 | |||
| c372d79d1d | |||
| b5141d6a70 | |||
| 694f8ba1d3 | |||
| 09b52f92d7 | |||
| 785abdec8d | |||
| ce203a4618 | |||
| 3c994a7343 | |||
| f69c3e9c3c | |||
| fac5b1f275 | |||
| a37559b93e | |||
| faac09ac0c | 
							
								
								
									
										1
									
								
								.github/copilot-instructions.md
									
									
									
									
										vendored
									
									
										Symbolic link
									
								
							
							
						
						
									
										1
									
								
								.github/copilot-instructions.md
									
									
									
									
										vendored
									
									
										Symbolic link
									
								
							@@ -0,0 +1 @@
 | 
			
		||||
../CLAUDE.md
 | 
			
		||||
							
								
								
									
										1
									
								
								.gitignore
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										1
									
								
								.gitignore
									
									
									
									
										vendored
									
									
										Normal file
									
								
							@@ -0,0 +1 @@
 | 
			
		||||
.aider*
 | 
			
		||||
							
								
								
									
										28
									
								
								.mcp.json
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										28
									
								
								.mcp.json
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,28 @@
 | 
			
		||||
{
 | 
			
		||||
  "mcpServers": {
 | 
			
		||||
    "context7": {
 | 
			
		||||
      "type": "stdio",
 | 
			
		||||
      "command": "npx",
 | 
			
		||||
      "args": [
 | 
			
		||||
        "-y",
 | 
			
		||||
        "@upstash/context7-mcp@1.0.0"
 | 
			
		||||
      ],
 | 
			
		||||
      "env": {}
 | 
			
		||||
    },
 | 
			
		||||
    "serena": {
 | 
			
		||||
      "type": "stdio",
 | 
			
		||||
      "command": "uvx",
 | 
			
		||||
      "args": [
 | 
			
		||||
        "--from",
 | 
			
		||||
        "git+https://github.com/oraios/serena@v0.1.4",
 | 
			
		||||
        "serena",
 | 
			
		||||
        "start-mcp-server",
 | 
			
		||||
        "--context",
 | 
			
		||||
        "ide-assistant",
 | 
			
		||||
        "--project",
 | 
			
		||||
        "."
 | 
			
		||||
      ],
 | 
			
		||||
      "env": {}
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										20
									
								
								.pre-commit-config.yaml
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										20
									
								
								.pre-commit-config.yaml
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,20 @@
 | 
			
		||||
---
 | 
			
		||||
repos:
 | 
			
		||||
  - repo: https://github.com/pre-commit/pre-commit-hooks
 | 
			
		||||
    rev: v5.0.0
 | 
			
		||||
    hooks:
 | 
			
		||||
      - id: trailing-whitespace
 | 
			
		||||
      - id: end-of-file-fixer
 | 
			
		||||
      - id: check-added-large-files
 | 
			
		||||
        args: ["--maxkb=400"]
 | 
			
		||||
      - id: check-case-conflict
 | 
			
		||||
      - id: check-executables-have-shebangs
 | 
			
		||||
      - id: check-shebang-scripts-are-executable
 | 
			
		||||
      - id: check-merge-conflict
 | 
			
		||||
      - id: check-symlinks
 | 
			
		||||
 | 
			
		||||
  - repo: https://github.com/adrienverge/yamllint
 | 
			
		||||
    rev: v1.35.1
 | 
			
		||||
    hooks:
 | 
			
		||||
      - id: yamllint
 | 
			
		||||
        args: [-c=.yamllint]
 | 
			
		||||
							
								
								
									
										14
									
								
								.yamllint
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										14
									
								
								.yamllint
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,14 @@
 | 
			
		||||
---
 | 
			
		||||
extends: relaxed
 | 
			
		||||
 | 
			
		||||
rules:
 | 
			
		||||
  braces:
 | 
			
		||||
    level: error
 | 
			
		||||
  brackets:
 | 
			
		||||
    level: error
 | 
			
		||||
 | 
			
		||||
  truthy:
 | 
			
		||||
    level: warning
 | 
			
		||||
 | 
			
		||||
#ignore: |
 | 
			
		||||
#  - ...
 | 
			
		||||
							
								
								
									
										53
									
								
								CHANGELOG.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										53
									
								
								CHANGELOG.md
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,53 @@
 | 
			
		||||
# Release Notes - v0.5.2
 | 
			
		||||
 | 
			
		||||
## Health Package
 | 
			
		||||
- **Kubernetes-native health probes** - Added dedicated handlers for liveness (`/healthz`), readiness (`/readyz`), and startup (`/startupz`) probes
 | 
			
		||||
- **Flexible configuration options** - New `WithLivenessHandler`, `WithReadinessHandler`, `WithStartupHandler`, and `WithServiceName` options
 | 
			
		||||
- **JSON response formats** - Structured probe responses with service identification
 | 
			
		||||
- **Backward compatibility** - Maintains existing `/__health` and `/` endpoints
 | 
			
		||||
 | 
			
		||||
## Logger Package
 | 
			
		||||
- **Runtime log level control** - Independent level management for stderr and OTLP loggers via `SetLevel()` and `SetOTLPLevel()`
 | 
			
		||||
- **Environment variable support** - Configure levels with `LOG_LEVEL` and `OTLP_LOG_LEVEL` env vars
 | 
			
		||||
- **String parsing utility** - New `ParseLevel()` function for converting string levels to `slog.Level`
 | 
			
		||||
- **Buffering exporter fix** - Resolved "unlock of unlocked mutex" panic in `bufferingExporter`
 | 
			
		||||
- **Initialization redesign** - Eliminated race conditions in TLS/tracing setup retry logic
 | 
			
		||||
 | 
			
		||||
## Database Package
 | 
			
		||||
- **Configuration file override** - Added `DATABASE_CONFIG_FILE` environment variable to specify custom database configuration file paths
 | 
			
		||||
- **Flexible path configuration** - Override default `["database.yaml", "/vault/secrets/database.yaml"]` search paths when needed
 | 
			
		||||
 | 
			
		||||
# Release Notes - v0.5.1
 | 
			
		||||
 | 
			
		||||
## Observability Enhancements
 | 
			
		||||
 | 
			
		||||
### OTLP Metrics Support
 | 
			
		||||
- **New `metrics/` package** - OpenTelemetry-native metrics with OTLP export support for structured metrics collection
 | 
			
		||||
- **Centralized OTLP configuration** - Refactored configuration to `internal/tracerconfig/` to eliminate code duplication across tracing, logging, and metrics
 | 
			
		||||
- **HTTP retry support** - Added consistent retry configuration for all HTTP OTLP exporters to improve reliability
 | 
			
		||||
 | 
			
		||||
### Enhanced Logging
 | 
			
		||||
- **Buffering exporter** - Added OTLP log buffering to queue logs until tracing configuration is available
 | 
			
		||||
- **TLS support for logs** - Client certificate authentication support for secure OTLP log export
 | 
			
		||||
- **Improved logfmt formatting** - Better structured output for log messages
 | 
			
		||||
 | 
			
		||||
### Tracing Improvements
 | 
			
		||||
- **HTTP retry support** - OTLP trace requests now automatically retry on failure when using HTTP transport
 | 
			
		||||
 | 
			
		||||
## Build System
 | 
			
		||||
 | 
			
		||||
### Version Package Enhancements
 | 
			
		||||
- **Unix epoch build time support** - Build time can now be injected as Unix timestamps (`$(date +%s)`) in addition to RFC3339 format
 | 
			
		||||
- **Simplified build commands** - Reduces complexity of ldflags injection while maintaining backward compatibility
 | 
			
		||||
- **Consistent output format** - All build times normalize to RFC3339 format regardless of input
 | 
			
		||||
 | 
			
		||||
## API Changes
 | 
			
		||||
 | 
			
		||||
### New Public Interfaces
 | 
			
		||||
- `metrics.NewMeterProvider()` - Create OTLP metrics provider with centralized configuration
 | 
			
		||||
- `metrics.Shutdown()` - Graceful shutdown for metrics exporters
 | 
			
		||||
- `internal/tracerconfig` - Shared OTLP configuration utilities (internal package)
 | 
			
		||||
 | 
			
		||||
### Dependencies
 | 
			
		||||
- Added explicit OpenTelemetry metrics dependencies to `go.mod`
 | 
			
		||||
- Updated tracing dependencies for retry support
 | 
			
		||||
							
								
								
									
										163
									
								
								CLAUDE.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										163
									
								
								CLAUDE.md
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,163 @@
 | 
			
		||||
# CLAUDE.md
 | 
			
		||||
 | 
			
		||||
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
 | 
			
		||||
 | 
			
		||||
## Commands
 | 
			
		||||
 | 
			
		||||
### Testing
 | 
			
		||||
- Run all tests: `go test ./...`
 | 
			
		||||
- Run tests with verbose output: `go test -v ./...`
 | 
			
		||||
- Run tests for specific package: `go test ./config`
 | 
			
		||||
- Run specific test: `go test -run TestConfigBool ./config`
 | 
			
		||||
 | 
			
		||||
### Building
 | 
			
		||||
- Build all packages: `go build ./...`
 | 
			
		||||
- Check module dependencies: `go mod tidy`
 | 
			
		||||
- Verify dependencies: `go mod verify`
 | 
			
		||||
 | 
			
		||||
### Code Quality
 | 
			
		||||
- Format code: `go fmt ./...`
 | 
			
		||||
- Vet code: `go vet ./...`
 | 
			
		||||
- Run static analysis: `staticcheck ./...` (if available)
 | 
			
		||||
 | 
			
		||||
## Architecture
 | 
			
		||||
 | 
			
		||||
This is a common library (`go.ntppool.org/common`) providing shared infrastructure for the NTP Pool project. The codebase emphasizes observability, security, and modern Go practices.
 | 
			
		||||
 | 
			
		||||
### Core Components
 | 
			
		||||
 | 
			
		||||
**Web Service Foundation:**
 | 
			
		||||
- `ekko/` - Enhanced Echo web framework with pre-configured middleware (OpenTelemetry, Prometheus, logging, security headers)
 | 
			
		||||
- `health/` - Standalone health check HTTP server with `/__health` endpoint
 | 
			
		||||
- `metricsserver/` - Prometheus metrics exposure via `/metrics` endpoint
 | 
			
		||||
 | 
			
		||||
**Observability Stack:**
 | 
			
		||||
- `logger/` - Structured logging with OpenTelemetry trace integration and multiple output formats
 | 
			
		||||
- `tracing/` - OpenTelemetry distributed tracing with OTLP export support
 | 
			
		||||
- `metricsserver/` - Prometheus metrics with custom registry
 | 
			
		||||
 | 
			
		||||
**Configuration & Environment:**
 | 
			
		||||
- `config/` - Environment-based configuration with code-generated accessors (`config_accessor.go`)
 | 
			
		||||
- `version/` - Build metadata and version information with Cobra CLI integration
 | 
			
		||||
 | 
			
		||||
**Security & Communication:**
 | 
			
		||||
- `apitls/` - TLS certificate management with automatic renewal via certman
 | 
			
		||||
- `kafka/` - Kafka client wrapper with TLS support for log streaming
 | 
			
		||||
- `xff/fastlyxff/` - Fastly CDN IP range management for trusted proxy handling
 | 
			
		||||
 | 
			
		||||
**Utilities:**
 | 
			
		||||
- `ulid/` - Thread-safe ULID generation with monotonic ordering
 | 
			
		||||
- `timeutil/` - JSON-serializable duration types
 | 
			
		||||
- `types/` - Shared data structures (LogScoreAttributes for NTP server scoring)
 | 
			
		||||
 | 
			
		||||
### Key Patterns
 | 
			
		||||
 | 
			
		||||
**Functional Options:** Used extensively in `ekko/` for flexible service configuration
 | 
			
		||||
**Interface-Based Design:** `CertificateProvider` in `apitls/` for pluggable certificate management
 | 
			
		||||
**Context Propagation:** Throughout the codebase for cancellation and tracing
 | 
			
		||||
**Graceful Shutdown:** Implemented in web servers and background services
 | 
			
		||||
 | 
			
		||||
### Dependencies
 | 
			
		||||
 | 
			
		||||
The codebase heavily uses:
 | 
			
		||||
- Echo web framework with custom middleware stack
 | 
			
		||||
- OpenTelemetry for observability (traces, metrics, logs)
 | 
			
		||||
- Prometheus for metrics collection
 | 
			
		||||
- Kafka for message streaming
 | 
			
		||||
- Cobra for CLI applications
 | 
			
		||||
 | 
			
		||||
### Code Generation
 | 
			
		||||
 | 
			
		||||
`config/config_accessor.go` is generated - modify `config.go` and regenerate accessors when adding new configuration options.
 | 
			
		||||
 | 
			
		||||
## Package Overview
 | 
			
		||||
 | 
			
		||||
### `apitls/`
 | 
			
		||||
TLS certificate management with automatic renewal support via certman. Provides a CA pool for trusted certificates and interfaces for pluggable certificate providers. Used for secure inter-service communication.
 | 
			
		||||
 | 
			
		||||
### `config/`
 | 
			
		||||
Environment-based configuration system with code-generated accessor methods. Handles deployment mode, hostname configuration, and TLS settings. Provides URL building utilities for web and management interfaces.
 | 
			
		||||
 | 
			
		||||
### `ekko/`
 | 
			
		||||
Enhanced Echo web framework wrapper with pre-configured middleware stack including OpenTelemetry tracing, Prometheus metrics, structured logging, gzip compression, and security headers. Supports HTTP/2 with graceful shutdown.
 | 
			
		||||
 | 
			
		||||
### `health/`
 | 
			
		||||
Standalone HTTP health check server that runs independently from the main application. Exposes `/__health` endpoint with configurable health handlers, timeouts, and graceful shutdown capabilities.
 | 
			
		||||
 | 
			
		||||
### `kafka/`
 | 
			
		||||
Kafka client wrapper with TLS support for secure log streaming. Provides connection management, broker discovery, and reader/writer factories with compression and batching optimizations.
 | 
			
		||||
 | 
			
		||||
### `logger/`
 | 
			
		||||
Structured logging system with OpenTelemetry trace integration. Supports multiple output formats (text, OTLP) with configurable log levels, systemd compatibility, and context-aware logging.
 | 
			
		||||
 | 
			
		||||
### `metricsserver/`
 | 
			
		||||
Dedicated Prometheus metrics HTTP server with custom registry isolation. Exposes `/metrics` endpoint with OpenMetrics support and graceful shutdown handling.
 | 
			
		||||
 | 
			
		||||
### `timeutil/`
 | 
			
		||||
JSON-serializable duration types that support both string parsing ("30s", "5m") and numeric nanosecond values. Compatible with configuration files and REST APIs.
 | 
			
		||||
 | 
			
		||||
### `tracing/`
 | 
			
		||||
OpenTelemetry distributed tracing setup with support for OTLP export via gRPC or HTTP. Handles resource detection, propagation, and automatic instrumentation with configurable TLS.
 | 
			
		||||
 | 
			
		||||
### `types/`
 | 
			
		||||
Shared data structures for the NTP Pool project. Currently contains `LogScoreAttributes` for NTP server scoring with JSON and SQL database compatibility.
 | 
			
		||||
 | 
			
		||||
### `ulid/`
 | 
			
		||||
Thread-safe ULID (Universally Unique Lexicographically Sortable Identifier) generation using cryptographically secure randomness. Optimized for simplicity and performance in high-concurrency environments.
 | 
			
		||||
 | 
			
		||||
### `version/`
 | 
			
		||||
Build metadata and version information system with Git integration. Provides CLI commands for Cobra and Kong frameworks, Prometheus build info metrics, and semantic version validation.
 | 
			
		||||
 | 
			
		||||
### `xff/fastlyxff/`
 | 
			
		||||
Fastly CDN IP range management for trusted proxy handling. Parses Fastly's IP ranges JSON file and generates Echo framework trust options for proper client IP extraction.
 | 
			
		||||
 | 
			
		||||
## Go Development Best Practices
 | 
			
		||||
 | 
			
		||||
### Code Style
 | 
			
		||||
- Follow standard Go formatting (`go fmt ./...`)
 | 
			
		||||
- Use `go vet ./...` for static analysis
 | 
			
		||||
- Run `staticcheck ./...` when available
 | 
			
		||||
- Prefer short, descriptive variable names
 | 
			
		||||
- Use interfaces for testability and flexibility
 | 
			
		||||
 | 
			
		||||
### Error Handling
 | 
			
		||||
- Always handle errors explicitly
 | 
			
		||||
- Use `errors.Join()` for combining multiple errors
 | 
			
		||||
- Wrap errors with context using `fmt.Errorf("context: %w", err)`
 | 
			
		||||
- Return early on errors to reduce nesting
 | 
			
		||||
 | 
			
		||||
### Testing
 | 
			
		||||
- Write table-driven tests when testing multiple scenarios
 | 
			
		||||
- Use `t.Helper()` in test helper functions
 | 
			
		||||
- Test error conditions, not just happy paths
 | 
			
		||||
- Use `testing.Short()` for integration tests that can be skipped
 | 
			
		||||
 | 
			
		||||
### Concurrency
 | 
			
		||||
- Use contexts for cancellation and timeouts
 | 
			
		||||
- Prefer channels for communication over shared memory
 | 
			
		||||
- Use `sync.Once` for one-time initialization
 | 
			
		||||
- Always call `defer cancel()` after `context.WithCancel()`
 | 
			
		||||
 | 
			
		||||
### Performance
 | 
			
		||||
- Use `sync.Pool` for frequently allocated objects
 | 
			
		||||
- Prefer slices over arrays for better performance
 | 
			
		||||
- Use `strings.Builder` for string concatenation in loops
 | 
			
		||||
- Profile before optimizing with `go tool pprof`
 | 
			
		||||
 | 
			
		||||
### Observability
 | 
			
		||||
- Use structured logging with key-value pairs
 | 
			
		||||
- Add OpenTelemetry spans for external calls
 | 
			
		||||
- Include trace IDs in error messages
 | 
			
		||||
- Use metrics for monitoring application health
 | 
			
		||||
 | 
			
		||||
### Dependencies
 | 
			
		||||
- Keep dependencies minimal and well-maintained
 | 
			
		||||
- Use `go mod tidy` to clean up unused dependencies
 | 
			
		||||
- Pin major versions to avoid breaking changes
 | 
			
		||||
- Prefer standard library when possible
 | 
			
		||||
 | 
			
		||||
### Security
 | 
			
		||||
- Never log sensitive information (passwords, tokens)
 | 
			
		||||
- Use `crypto/rand` for cryptographic randomness
 | 
			
		||||
- Validate all inputs at API boundaries
 | 
			
		||||
- Use TLS for all network communication
 | 
			
		||||
							
								
								
									
										20
									
								
								README.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										20
									
								
								README.md
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,20 @@
 | 
			
		||||
 | 
			
		||||
Common library for the NTP Pool project with shared infrastructure components.
 | 
			
		||||
 | 
			
		||||
## Packages
 | 
			
		||||
 | 
			
		||||
- **apitls** - TLS setup for NTP Pool internal services with embedded CA
 | 
			
		||||
- **config** - NTP Pool project configuration with environment variables
 | 
			
		||||
- **ekko** - Enhanced Echo web framework with observability middleware
 | 
			
		||||
- **health** - Standalone health check HTTP server
 | 
			
		||||
- **kafka** - Kafka client wrapper with TLS support
 | 
			
		||||
- **logger** - Structured logging with OpenTelemetry integration
 | 
			
		||||
- **metricsserver** - Prometheus metrics HTTP server
 | 
			
		||||
- **timeutil** - JSON-serializable duration types
 | 
			
		||||
- **tracing** - OpenTelemetry distributed tracing setup
 | 
			
		||||
- **types** - Shared data structures for NTP Pool
 | 
			
		||||
- **ulid** - Thread-safe ULID generation
 | 
			
		||||
- **version** - Build metadata and version information
 | 
			
		||||
- **xff/fastlyxff** - Fastly CDN IP range management
 | 
			
		||||
 | 
			
		||||
[](https://pkg.go.dev/go.ntppool.org/common)
 | 
			
		||||
@@ -1,3 +1,14 @@
 | 
			
		||||
// Package apitls provides TLS certificate management with automatic renewal support.
 | 
			
		||||
//
 | 
			
		||||
// This package handles TLS certificate provisioning and management for secure
 | 
			
		||||
// inter-service communication within the NTP Pool project infrastructure.
 | 
			
		||||
// It provides both server and client certificate management through the
 | 
			
		||||
// CertificateProvider interface and includes a trusted CA certificate pool
 | 
			
		||||
// for validating certificates.
 | 
			
		||||
//
 | 
			
		||||
// The package integrates with certman for automatic certificate renewal
 | 
			
		||||
// and includes embedded CA certificates for establishing trust relationships
 | 
			
		||||
// between services.
 | 
			
		||||
package apitls
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
@@ -13,11 +24,32 @@ import (
 | 
			
		||||
//go:embed ca.pem
 | 
			
		||||
var caBytes []byte
 | 
			
		||||
 | 
			
		||||
// CertificateProvider defines the interface for providing TLS certificates
 | 
			
		||||
// for both server and client connections. Implementations should handle
 | 
			
		||||
// certificate retrieval, caching, and renewal as needed.
 | 
			
		||||
//
 | 
			
		||||
// This interface supports both server-side certificate provisioning
 | 
			
		||||
// (via GetCertificate) and client-side certificate authentication
 | 
			
		||||
// (via GetClientCertificate).
 | 
			
		||||
type CertificateProvider interface {
 | 
			
		||||
	// GetCertificate retrieves a server certificate based on the client hello information.
 | 
			
		||||
	// This method is typically used in tls.Config.GetCertificate for server-side TLS.
 | 
			
		||||
	GetCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate, error)
 | 
			
		||||
 | 
			
		||||
	// GetClientCertificate retrieves a client certificate for mutual TLS authentication.
 | 
			
		||||
	// This method is used in tls.Config.GetClientCertificate for client-side TLS.
 | 
			
		||||
	GetClientCertificate(certRequestInfo *tls.CertificateRequestInfo) (*tls.Certificate, error)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// CAPool returns a certificate pool containing trusted CA certificates
 | 
			
		||||
// for validating TLS connections within the NTP Pool infrastructure.
 | 
			
		||||
//
 | 
			
		||||
// The CA certificates are embedded in the binary and include the trusted
 | 
			
		||||
// certificate authorities used for inter-service communication.
 | 
			
		||||
// This pool should be used in tls.Config.RootCAs for client connections
 | 
			
		||||
// or tls.Config.ClientCAs for server connections requiring client certificates.
 | 
			
		||||
//
 | 
			
		||||
// Returns an error if the embedded CA certificates cannot be parsed or loaded.
 | 
			
		||||
func CAPool() (*x509.CertPool, error) {
 | 
			
		||||
	capool := x509.NewCertPool()
 | 
			
		||||
	if !capool.AppendCertsFromPEM(caBytes) {
 | 
			
		||||
 
 | 
			
		||||
@@ -1,5 +1,18 @@
 | 
			
		||||
// Package config provides NTP Pool specific
 | 
			
		||||
// configuration tools.
 | 
			
		||||
// Package config provides environment-based configuration management for NTP Pool services.
 | 
			
		||||
//
 | 
			
		||||
// This package handles configuration loading from environment variables and provides
 | 
			
		||||
// utilities for constructing URLs for web and management interfaces. It supports
 | 
			
		||||
// deployment-specific settings including hostname configuration, TLS settings,
 | 
			
		||||
// and deployment modes.
 | 
			
		||||
//
 | 
			
		||||
// Configuration is loaded automatically from environment variables:
 | 
			
		||||
//   - deployment_mode: The deployment environment (devel, production, etc.)
 | 
			
		||||
//   - manage_hostname: Hostname for management interface
 | 
			
		||||
//   - web_hostname: Comma-separated list of web hostnames (first is primary)
 | 
			
		||||
//   - manage_tls: Enable TLS for management interface (yes, no, true, false)
 | 
			
		||||
//   - web_tls: Enable TLS for web interface (yes, no, true, false)
 | 
			
		||||
//
 | 
			
		||||
// The package includes code generation for accessor methods using the accessory tool.
 | 
			
		||||
package config
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
@@ -13,6 +26,9 @@ import (
 | 
			
		||||
 | 
			
		||||
//go:generate go tool github.com/masaushi/accessory -type Config
 | 
			
		||||
 | 
			
		||||
// Config holds environment-based configuration for NTP Pool services.
 | 
			
		||||
// It manages hostnames, TLS settings, and deployment modes loaded from
 | 
			
		||||
// environment variables. The struct includes code-generated accessor methods.
 | 
			
		||||
type Config struct {
 | 
			
		||||
	deploymentMode string `accessor:"getter"`
 | 
			
		||||
 | 
			
		||||
@@ -23,9 +39,22 @@ type Config struct {
 | 
			
		||||
	webHostnames []string
 | 
			
		||||
	webTLS       bool
 | 
			
		||||
 | 
			
		||||
	poolDomain string `accessor:"getter"`
 | 
			
		||||
 | 
			
		||||
	valid bool `accessor:"getter"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// New creates a new Config instance by loading configuration from environment variables.
 | 
			
		||||
// It automatically parses hostnames, TLS settings, and deployment mode from the environment.
 | 
			
		||||
// The configuration is considered valid if at least one web hostname is provided.
 | 
			
		||||
//
 | 
			
		||||
// Environment variables used:
 | 
			
		||||
//   - deployment_mode: Deployment environment identifier
 | 
			
		||||
//   - manage_hostname: Management interface hostname
 | 
			
		||||
//   - web_hostname: Comma-separated web hostnames (first becomes primary)
 | 
			
		||||
//   - manage_tls: Management interface TLS setting
 | 
			
		||||
//   - web_tls: Web interface TLS setting
 | 
			
		||||
//   - pool_domain: NTP pool domain (default: pool.ntp.org)
 | 
			
		||||
func New() *Config {
 | 
			
		||||
	c := Config{}
 | 
			
		||||
	c.deploymentMode = os.Getenv("deployment_mode")
 | 
			
		||||
@@ -43,13 +72,34 @@ func New() *Config {
 | 
			
		||||
	c.manageTLS = parseBool(os.Getenv("manage_tls"))
 | 
			
		||||
	c.webTLS = parseBool(os.Getenv("web_tls"))
 | 
			
		||||
 | 
			
		||||
	c.poolDomain = os.Getenv("pool_domain")
 | 
			
		||||
	if c.poolDomain == "" {
 | 
			
		||||
		c.poolDomain = "pool.ntp.org"
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return &c
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// WebURL constructs a complete URL for the web interface using the primary web hostname.
 | 
			
		||||
// It automatically selects HTTP or HTTPS based on the web_tls configuration setting.
 | 
			
		||||
//
 | 
			
		||||
// Parameters:
 | 
			
		||||
//   - path: URL path component (should start with "/")
 | 
			
		||||
//   - query: Optional URL query parameters (can be nil)
 | 
			
		||||
//
 | 
			
		||||
// Returns a complete URL string suitable for web interface requests.
 | 
			
		||||
func (c *Config) WebURL(path string, query *url.Values) string {
 | 
			
		||||
	return baseURL(c.webHostname, c.webTLS, path, query)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ManageURL constructs a complete URL for the management interface using the management hostname.
 | 
			
		||||
// It automatically selects HTTP or HTTPS based on the manage_tls configuration setting.
 | 
			
		||||
//
 | 
			
		||||
// Parameters:
 | 
			
		||||
//   - path: URL path component (should start with "/")
 | 
			
		||||
//   - query: Optional URL query parameters (can be nil)
 | 
			
		||||
//
 | 
			
		||||
// Returns a complete URL string suitable for management interface requests.
 | 
			
		||||
func (c *Config) ManageURL(path string, query *url.Values) string {
 | 
			
		||||
	return baseURL(c.manageHostname, c.webTLS, path, query)
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -23,6 +23,13 @@ func (c *Config) WebHostname() string {
 | 
			
		||||
	return c.webHostname
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *Config) PoolDomain() string {
 | 
			
		||||
	if c == nil {
 | 
			
		||||
		return ""
 | 
			
		||||
	}
 | 
			
		||||
	return c.poolDomain
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *Config) Valid() bool {
 | 
			
		||||
	if c == nil {
 | 
			
		||||
		return false
 | 
			
		||||
 
 | 
			
		||||
@@ -1,3 +1,19 @@
 | 
			
		||||
// Package depenv provides deployment environment management for NTP Pool services.
 | 
			
		||||
//
 | 
			
		||||
// This package handles different deployment environments (development, test, production)
 | 
			
		||||
// and provides environment-specific configuration including API endpoints, management URLs,
 | 
			
		||||
// and monitoring domains. It supports string-based environment identification and
 | 
			
		||||
// automatic URL construction for various service endpoints.
 | 
			
		||||
//
 | 
			
		||||
// The package defines three main deployment environments:
 | 
			
		||||
//   - DeployDevel: Development environment with dev-specific endpoints
 | 
			
		||||
//   - DeployTest: Test/beta environment for staging
 | 
			
		||||
//   - DeployProd: Production environment with live endpoints
 | 
			
		||||
//
 | 
			
		||||
// Environment detection supports both short and long forms:
 | 
			
		||||
//   - "dev" or "devel" → DeployDevel
 | 
			
		||||
//   - "test" or "beta" → DeployTest
 | 
			
		||||
//   - "prod" → DeployProd
 | 
			
		||||
package depenv
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
@@ -24,14 +40,27 @@ var apiServers = map[DeploymentEnvironment]string{
 | 
			
		||||
// }
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	// DeployUndefined represents an unrecognized or unset deployment environment.
 | 
			
		||||
	DeployUndefined DeploymentEnvironment = iota
 | 
			
		||||
	// DeployDevel represents the development environment.
 | 
			
		||||
	DeployDevel
 | 
			
		||||
	// DeployTest represents the test/beta environment.
 | 
			
		||||
	DeployTest
 | 
			
		||||
	// DeployProd represents the production environment.
 | 
			
		||||
	DeployProd
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// DeploymentEnvironment represents a deployment environment type.
 | 
			
		||||
// It provides methods for environment-specific URL construction and
 | 
			
		||||
// supports text marshaling/unmarshaling for configuration files.
 | 
			
		||||
type DeploymentEnvironment uint8
 | 
			
		||||
 | 
			
		||||
// DeploymentEnvironmentFromString parses a string into a DeploymentEnvironment.
 | 
			
		||||
// It supports both short and long forms of environment names:
 | 
			
		||||
//   - "dev" or "devel" → DeployDevel
 | 
			
		||||
//   - "test" or "beta" → DeployTest
 | 
			
		||||
//   - "prod" → DeployProd
 | 
			
		||||
//   - any other value → DeployUndefined
 | 
			
		||||
func DeploymentEnvironmentFromString(s string) DeploymentEnvironment {
 | 
			
		||||
	switch s {
 | 
			
		||||
	case "devel", "dev":
 | 
			
		||||
@@ -45,6 +74,8 @@ func DeploymentEnvironmentFromString(s string) DeploymentEnvironment {
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// String returns the canonical string representation of the deployment environment.
 | 
			
		||||
// Returns "prod", "test", "devel", or panics for invalid environments.
 | 
			
		||||
func (d DeploymentEnvironment) String() string {
 | 
			
		||||
	switch d {
 | 
			
		||||
	case DeployProd:
 | 
			
		||||
@@ -58,6 +89,9 @@ func (d DeploymentEnvironment) String() string {
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// APIHost returns the API server URL for this deployment environment.
 | 
			
		||||
// It first checks the API_HOST environment variable for overrides,
 | 
			
		||||
// then falls back to the environment-specific default API endpoint.
 | 
			
		||||
func (d DeploymentEnvironment) APIHost() string {
 | 
			
		||||
	if apiHost := os.Getenv("API_HOST"); apiHost != "" {
 | 
			
		||||
		return apiHost
 | 
			
		||||
@@ -65,14 +99,26 @@ func (d DeploymentEnvironment) APIHost() string {
 | 
			
		||||
	return apiServers[d]
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ManageURL constructs a management interface URL for this deployment environment.
 | 
			
		||||
// It combines the environment-specific management server base URL with the provided path.
 | 
			
		||||
//
 | 
			
		||||
// The path parameter should start with "/" for proper URL construction.
 | 
			
		||||
func (d DeploymentEnvironment) ManageURL(path string) string {
 | 
			
		||||
	return manageServers[d] + path
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// MonitorDomain returns the monitoring domain for this deployment environment.
 | 
			
		||||
// The domain follows the pattern: {environment}.mon.ntppool.dev
 | 
			
		||||
// For example: "devel.mon.ntppool.dev" for the development environment.
 | 
			
		||||
func (d DeploymentEnvironment) MonitorDomain() string {
 | 
			
		||||
	return d.String() + ".mon.ntppool.dev"
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// UnmarshalText implements the encoding.TextUnmarshaler interface.
 | 
			
		||||
// It allows DeploymentEnvironment to be unmarshaled from configuration files
 | 
			
		||||
// and other text-based formats. Empty strings are treated as valid (no-op).
 | 
			
		||||
//
 | 
			
		||||
// Returns an error if the text represents an invalid deployment environment.
 | 
			
		||||
func (d *DeploymentEnvironment) UnmarshalText(text []byte) error {
 | 
			
		||||
	s := string(text)
 | 
			
		||||
	if s == "" {
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										125
									
								
								database/config.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										125
									
								
								database/config.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,125 @@
 | 
			
		||||
package database
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"os"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/prometheus/client_golang/prometheus"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// Config represents the database configuration structure
 | 
			
		||||
type Config struct {
 | 
			
		||||
	// MySQL configuration (use this OR Postgres, not both)
 | 
			
		||||
	MySQL *MySQLConfig `yaml:"mysql,omitempty"`
 | 
			
		||||
	// Postgres configuration (use this OR MySQL, not both)
 | 
			
		||||
	Postgres *PostgresConfig `yaml:"postgres,omitempty"`
 | 
			
		||||
 | 
			
		||||
	// Legacy flat PostgreSQL format (deprecated, for backward compatibility only)
 | 
			
		||||
	// If neither MySQL nor Postgres is set, these fields will be used for PostgreSQL
 | 
			
		||||
	User    string `yaml:"user,omitempty"`
 | 
			
		||||
	Pass    string `yaml:"pass,omitempty"`
 | 
			
		||||
	Host    string `yaml:"host,omitempty"`
 | 
			
		||||
	Port    uint16 `yaml:"port,omitempty"`
 | 
			
		||||
	Name    string `yaml:"name,omitempty"`
 | 
			
		||||
	SSLMode string `yaml:"sslmode,omitempty"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// MySQLConfig represents the MySQL database configuration
 | 
			
		||||
type MySQLConfig struct {
 | 
			
		||||
	DSN    string `yaml:"dsn" default:"" flag:"dsn" usage:"Database DSN"`
 | 
			
		||||
	User   string `yaml:"user" default:"" flag:"user"`
 | 
			
		||||
	Pass   string `yaml:"pass" default:"" flag:"pass"`
 | 
			
		||||
	DBName string `yaml:"name,omitempty"` // Optional database name override
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// PostgresConfig represents the PostgreSQL database configuration
 | 
			
		||||
type PostgresConfig struct {
 | 
			
		||||
	User    string `yaml:"user"`
 | 
			
		||||
	Pass    string `yaml:"pass"`
 | 
			
		||||
	Host    string `yaml:"host"`
 | 
			
		||||
	Port    uint16 `yaml:"port"`
 | 
			
		||||
	Name    string `yaml:"name"`
 | 
			
		||||
	SSLMode string `yaml:"sslmode"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// DBConfig is a legacy alias for MySQLConfig
 | 
			
		||||
type DBConfig = MySQLConfig
 | 
			
		||||
 | 
			
		||||
// Validate ensures the configuration is valid and unambiguous
 | 
			
		||||
func (c *Config) Validate() error {
 | 
			
		||||
	hasMySQL := c.MySQL != nil
 | 
			
		||||
	hasPostgres := c.Postgres != nil
 | 
			
		||||
	hasLegacy := c.User != "" || c.Host != "" || c.Port != 0 || c.Name != ""
 | 
			
		||||
 | 
			
		||||
	count := 0
 | 
			
		||||
	if hasMySQL {
 | 
			
		||||
		count++
 | 
			
		||||
	}
 | 
			
		||||
	if hasPostgres {
 | 
			
		||||
		count++
 | 
			
		||||
	}
 | 
			
		||||
	if hasLegacy {
 | 
			
		||||
		count++
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if count == 0 {
 | 
			
		||||
		return fmt.Errorf("no database configuration provided")
 | 
			
		||||
	}
 | 
			
		||||
	if count > 1 {
 | 
			
		||||
		return fmt.Errorf("multiple database configurations provided (only one allowed)")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ConfigOptions allows customization of database opening behavior
 | 
			
		||||
type ConfigOptions struct {
 | 
			
		||||
	// ConfigFiles is a list of config file paths to search for database configuration
 | 
			
		||||
	ConfigFiles []string
 | 
			
		||||
 | 
			
		||||
	// EnablePoolMonitoring enables connection pool metrics collection
 | 
			
		||||
	EnablePoolMonitoring bool
 | 
			
		||||
 | 
			
		||||
	// PrometheusRegisterer for metrics collection. If nil, no metrics are collected.
 | 
			
		||||
	PrometheusRegisterer prometheus.Registerer
 | 
			
		||||
 | 
			
		||||
	// Connection pool settings
 | 
			
		||||
	MaxOpenConns    int
 | 
			
		||||
	MaxIdleConns    int
 | 
			
		||||
	ConnMaxLifetime time.Duration
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// getConfigFiles returns the list of config files to search for database configuration.
 | 
			
		||||
// If DATABASE_CONFIG_FILE environment variable is set, it returns that single file.
 | 
			
		||||
// Otherwise, it returns the default paths.
 | 
			
		||||
func getConfigFiles() []string {
 | 
			
		||||
	if configFile := os.Getenv("DATABASE_CONFIG_FILE"); configFile != "" {
 | 
			
		||||
		return []string{configFile}
 | 
			
		||||
	}
 | 
			
		||||
	return []string{"database.yaml", "/vault/secrets/database.yaml"}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// DefaultConfigOptions returns the standard configuration options used by API package
 | 
			
		||||
func DefaultConfigOptions() ConfigOptions {
 | 
			
		||||
	return ConfigOptions{
 | 
			
		||||
		ConfigFiles:          getConfigFiles(),
 | 
			
		||||
		EnablePoolMonitoring: true,
 | 
			
		||||
		PrometheusRegisterer: prometheus.DefaultRegisterer,
 | 
			
		||||
		MaxOpenConns:         25,
 | 
			
		||||
		MaxIdleConns:         10,
 | 
			
		||||
		ConnMaxLifetime:      3 * time.Minute,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// MonitorConfigOptions returns configuration options optimized for Monitor package
 | 
			
		||||
func MonitorConfigOptions() ConfigOptions {
 | 
			
		||||
	return ConfigOptions{
 | 
			
		||||
		ConfigFiles:          getConfigFiles(),
 | 
			
		||||
		EnablePoolMonitoring: false, // Monitor doesn't need metrics
 | 
			
		||||
		PrometheusRegisterer: nil,   // No Prometheus dependency
 | 
			
		||||
		MaxOpenConns:         10,
 | 
			
		||||
		MaxIdleConns:         5,
 | 
			
		||||
		ConnMaxLifetime:      3 * time.Minute,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										196
									
								
								database/config_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										196
									
								
								database/config_test.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,196 @@
 | 
			
		||||
package database
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"testing"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/prometheus/client_golang/prometheus"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func TestDefaultConfigOptions(t *testing.T) {
 | 
			
		||||
	opts := DefaultConfigOptions()
 | 
			
		||||
 | 
			
		||||
	// Verify expected defaults for API package
 | 
			
		||||
	if opts.MaxOpenConns != 25 {
 | 
			
		||||
		t.Errorf("Expected MaxOpenConns=25, got %d", opts.MaxOpenConns)
 | 
			
		||||
	}
 | 
			
		||||
	if opts.MaxIdleConns != 10 {
 | 
			
		||||
		t.Errorf("Expected MaxIdleConns=10, got %d", opts.MaxIdleConns)
 | 
			
		||||
	}
 | 
			
		||||
	if opts.ConnMaxLifetime != 3*time.Minute {
 | 
			
		||||
		t.Errorf("Expected ConnMaxLifetime=3m, got %v", opts.ConnMaxLifetime)
 | 
			
		||||
	}
 | 
			
		||||
	if !opts.EnablePoolMonitoring {
 | 
			
		||||
		t.Error("Expected EnablePoolMonitoring=true")
 | 
			
		||||
	}
 | 
			
		||||
	if opts.PrometheusRegisterer != prometheus.DefaultRegisterer {
 | 
			
		||||
		t.Error("Expected PrometheusRegisterer to be DefaultRegisterer")
 | 
			
		||||
	}
 | 
			
		||||
	if len(opts.ConfigFiles) == 0 {
 | 
			
		||||
		t.Error("Expected ConfigFiles to be non-empty")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestMonitorConfigOptions(t *testing.T) {
 | 
			
		||||
	opts := MonitorConfigOptions()
 | 
			
		||||
 | 
			
		||||
	// Verify expected defaults for Monitor package
 | 
			
		||||
	if opts.MaxOpenConns != 10 {
 | 
			
		||||
		t.Errorf("Expected MaxOpenConns=10, got %d", opts.MaxOpenConns)
 | 
			
		||||
	}
 | 
			
		||||
	if opts.MaxIdleConns != 5 {
 | 
			
		||||
		t.Errorf("Expected MaxIdleConns=5, got %d", opts.MaxIdleConns)
 | 
			
		||||
	}
 | 
			
		||||
	if opts.ConnMaxLifetime != 3*time.Minute {
 | 
			
		||||
		t.Errorf("Expected ConnMaxLifetime=3m, got %v", opts.ConnMaxLifetime)
 | 
			
		||||
	}
 | 
			
		||||
	if opts.EnablePoolMonitoring {
 | 
			
		||||
		t.Error("Expected EnablePoolMonitoring=false")
 | 
			
		||||
	}
 | 
			
		||||
	if opts.PrometheusRegisterer != nil {
 | 
			
		||||
		t.Error("Expected PrometheusRegisterer to be nil")
 | 
			
		||||
	}
 | 
			
		||||
	if len(opts.ConfigFiles) == 0 {
 | 
			
		||||
		t.Error("Expected ConfigFiles to be non-empty")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestConfigStructures(t *testing.T) {
 | 
			
		||||
	// Test that MySQL configuration structures can be created and populated
 | 
			
		||||
	config := Config{
 | 
			
		||||
		MySQL: &MySQLConfig{
 | 
			
		||||
			DSN:    "user:pass@tcp(localhost:3306)/dbname",
 | 
			
		||||
			User:   "testuser",
 | 
			
		||||
			Pass:   "testpass",
 | 
			
		||||
			DBName: "testdb",
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if config.MySQL.DSN == "" {
 | 
			
		||||
		t.Error("Expected DSN to be set")
 | 
			
		||||
	}
 | 
			
		||||
	if config.MySQL.User != "testuser" {
 | 
			
		||||
		t.Errorf("Expected User='testuser', got '%s'", config.MySQL.User)
 | 
			
		||||
	}
 | 
			
		||||
	if config.MySQL.Pass != "testpass" {
 | 
			
		||||
		t.Errorf("Expected Pass='testpass', got '%s'", config.MySQL.Pass)
 | 
			
		||||
	}
 | 
			
		||||
	if config.MySQL.DBName != "testdb" {
 | 
			
		||||
		t.Errorf("Expected DBName='testdb', got '%s'", config.MySQL.DBName)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestPostgresConfigStructures(t *testing.T) {
 | 
			
		||||
	// Test that PostgreSQL configuration structures can be created and populated
 | 
			
		||||
	config := Config{
 | 
			
		||||
		Postgres: &PostgresConfig{
 | 
			
		||||
			Host:    "localhost",
 | 
			
		||||
			Port:    5432,
 | 
			
		||||
			User:    "testuser",
 | 
			
		||||
			Pass:    "testpass",
 | 
			
		||||
			Name:    "testdb",
 | 
			
		||||
			SSLMode: "require",
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if config.Postgres.Host != "localhost" {
 | 
			
		||||
		t.Errorf("Expected Host='localhost', got '%s'", config.Postgres.Host)
 | 
			
		||||
	}
 | 
			
		||||
	if config.Postgres.Port != 5432 {
 | 
			
		||||
		t.Errorf("Expected Port=5432, got %d", config.Postgres.Port)
 | 
			
		||||
	}
 | 
			
		||||
	if config.Postgres.User != "testuser" {
 | 
			
		||||
		t.Errorf("Expected User='testuser', got '%s'", config.Postgres.User)
 | 
			
		||||
	}
 | 
			
		||||
	if config.Postgres.Pass != "testpass" {
 | 
			
		||||
		t.Errorf("Expected Pass='testpass', got '%s'", config.Postgres.Pass)
 | 
			
		||||
	}
 | 
			
		||||
	if config.Postgres.Name != "testdb" {
 | 
			
		||||
		t.Errorf("Expected Name='testdb', got '%s'", config.Postgres.Name)
 | 
			
		||||
	}
 | 
			
		||||
	if config.Postgres.SSLMode != "require" {
 | 
			
		||||
		t.Errorf("Expected SSLMode='require', got '%s'", config.Postgres.SSLMode)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestLegacyPostgresConfig(t *testing.T) {
 | 
			
		||||
	// Test that legacy flat PostgreSQL format can be created
 | 
			
		||||
	config := Config{
 | 
			
		||||
		User:    "testuser",
 | 
			
		||||
		Pass:    "testpass",
 | 
			
		||||
		Host:    "localhost",
 | 
			
		||||
		Port:    5432,
 | 
			
		||||
		Name:    "testdb",
 | 
			
		||||
		SSLMode: "require",
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if config.User != "testuser" {
 | 
			
		||||
		t.Errorf("Expected User='testuser', got '%s'", config.User)
 | 
			
		||||
	}
 | 
			
		||||
	if config.Name != "testdb" {
 | 
			
		||||
		t.Errorf("Expected Name='testdb', got '%s'", config.Name)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestConfigValidation(t *testing.T) {
 | 
			
		||||
	tests := []struct {
 | 
			
		||||
		name    string
 | 
			
		||||
		config  Config
 | 
			
		||||
		wantErr bool
 | 
			
		||||
	}{
 | 
			
		||||
		{
 | 
			
		||||
			name: "valid mysql config",
 | 
			
		||||
			config: Config{
 | 
			
		||||
				MySQL: &MySQLConfig{DSN: "test"},
 | 
			
		||||
			},
 | 
			
		||||
			wantErr: false,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name: "valid postgres config",
 | 
			
		||||
			config: Config{
 | 
			
		||||
				Postgres: &PostgresConfig{User: "test", Host: "localhost", Name: "test"},
 | 
			
		||||
			},
 | 
			
		||||
			wantErr: false,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name: "valid legacy postgres config",
 | 
			
		||||
			config: Config{
 | 
			
		||||
				User: "test",
 | 
			
		||||
				Host: "localhost",
 | 
			
		||||
				Name: "testdb",
 | 
			
		||||
			},
 | 
			
		||||
			wantErr: false,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name: "both mysql and postgres set",
 | 
			
		||||
			config: Config{
 | 
			
		||||
				MySQL:    &MySQLConfig{DSN: "test"},
 | 
			
		||||
				Postgres: &PostgresConfig{User: "test"},
 | 
			
		||||
			},
 | 
			
		||||
			wantErr: true,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name: "mysql and legacy postgres set",
 | 
			
		||||
			config: Config{
 | 
			
		||||
				MySQL: &MySQLConfig{DSN: "test"},
 | 
			
		||||
				User:  "test",
 | 
			
		||||
				Name:  "testdb",
 | 
			
		||||
			},
 | 
			
		||||
			wantErr: true,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:    "no config set",
 | 
			
		||||
			config:  Config{},
 | 
			
		||||
			wantErr: true,
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for _, tt := range tests {
 | 
			
		||||
		t.Run(tt.name, func(t *testing.T) {
 | 
			
		||||
			err := tt.config.Validate()
 | 
			
		||||
			if (err != nil) != tt.wantErr {
 | 
			
		||||
				t.Errorf("Validate() error = %v, wantErr %v", err, tt.wantErr)
 | 
			
		||||
			}
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										187
									
								
								database/connector.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										187
									
								
								database/connector.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,187 @@
 | 
			
		||||
package database
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"database/sql/driver"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"os"
 | 
			
		||||
 | 
			
		||||
	"github.com/go-sql-driver/mysql"
 | 
			
		||||
	"github.com/jackc/pgx/v5"
 | 
			
		||||
	"github.com/jackc/pgx/v5/stdlib"
 | 
			
		||||
	"gopkg.in/yaml.v3"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// from https://github.com/Boostport/dynamic-database-config
 | 
			
		||||
 | 
			
		||||
// CreateConnectorFunc is a function that creates a database connector
 | 
			
		||||
type CreateConnectorFunc func() (driver.Connector, error)
 | 
			
		||||
 | 
			
		||||
// Driver implements the sql/driver interface with dynamic configuration
 | 
			
		||||
type Driver struct {
 | 
			
		||||
	CreateConnectorFunc CreateConnectorFunc
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Driver returns the driver instance
 | 
			
		||||
func (d Driver) Driver() driver.Driver {
 | 
			
		||||
	return d
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Connect creates a new database connection using the dynamic connector
 | 
			
		||||
func (d Driver) Connect(ctx context.Context) (driver.Conn, error) {
 | 
			
		||||
	connector, err := d.CreateConnectorFunc()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, fmt.Errorf("error creating connector from function: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return connector.Connect(ctx)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Open is not supported for dynamic configuration
 | 
			
		||||
func (d Driver) Open(name string) (driver.Conn, error) {
 | 
			
		||||
	return nil, errors.New("open is not supported")
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// createConnector creates a connector function that reads configuration from a file
 | 
			
		||||
func createConnector(configFile string) CreateConnectorFunc {
 | 
			
		||||
	return func() (driver.Connector, error) {
 | 
			
		||||
		dbFile, err := os.Open(configFile)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
		defer dbFile.Close()
 | 
			
		||||
 | 
			
		||||
		dec := yaml.NewDecoder(dbFile)
 | 
			
		||||
		cfg := Config{}
 | 
			
		||||
 | 
			
		||||
		err = dec.Decode(&cfg)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// Validate configuration
 | 
			
		||||
		if err := cfg.Validate(); err != nil {
 | 
			
		||||
			return nil, fmt.Errorf("invalid configuration: %w", err)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// Determine database type and create appropriate connector
 | 
			
		||||
		if cfg.MySQL != nil {
 | 
			
		||||
			return createMySQLConnector(cfg.MySQL)
 | 
			
		||||
		} else if cfg.Postgres != nil {
 | 
			
		||||
			return createPostgresConnector(cfg.Postgres)
 | 
			
		||||
		} else if cfg.User != "" && cfg.Name != "" {
 | 
			
		||||
			// Legacy flat PostgreSQL format (requires at minimum user and dbname)
 | 
			
		||||
			return createPostgresConnectorFromFlat(&cfg)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		return nil, fmt.Errorf("no valid database configuration found (mysql or postgres section required)")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// createMySQLConnector creates a MySQL connector from configuration
 | 
			
		||||
func createMySQLConnector(cfg *MySQLConfig) (driver.Connector, error) {
 | 
			
		||||
	dsn := cfg.DSN
 | 
			
		||||
	if len(dsn) == 0 {
 | 
			
		||||
		dsn = os.Getenv("DATABASE_DSN")
 | 
			
		||||
		if len(dsn) == 0 {
 | 
			
		||||
			return nil, fmt.Errorf("dsn config in database.yaml or DATABASE_DSN environment variable required")
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	dbcfg, err := mysql.ParseDSN(dsn)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if user := cfg.User; len(user) > 0 {
 | 
			
		||||
		dbcfg.User = user
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if pass := cfg.Pass; len(pass) > 0 {
 | 
			
		||||
		dbcfg.Passwd = pass
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if name := cfg.DBName; len(name) > 0 {
 | 
			
		||||
		dbcfg.DBName = name
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return mysql.NewConnector(dbcfg)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// createPostgresConnector creates a PostgreSQL connector from configuration
 | 
			
		||||
func createPostgresConnector(cfg *PostgresConfig) (driver.Connector, error) {
 | 
			
		||||
	// Validate required fields
 | 
			
		||||
	if cfg.Host == "" {
 | 
			
		||||
		return nil, fmt.Errorf("postgres: host is required")
 | 
			
		||||
	}
 | 
			
		||||
	if cfg.User == "" {
 | 
			
		||||
		return nil, fmt.Errorf("postgres: user is required")
 | 
			
		||||
	}
 | 
			
		||||
	if cfg.Name == "" {
 | 
			
		||||
		return nil, fmt.Errorf("postgres: database name is required")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Validate SSLMode
 | 
			
		||||
	validSSLModes := map[string]bool{
 | 
			
		||||
		"disable": true, "allow": true, "prefer": true,
 | 
			
		||||
		"require": true, "verify-ca": true, "verify-full": true,
 | 
			
		||||
	}
 | 
			
		||||
	if cfg.SSLMode != "" && !validSSLModes[cfg.SSLMode] {
 | 
			
		||||
		return nil, fmt.Errorf("postgres: invalid sslmode: %s", cfg.SSLMode)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Build config directly (security: no DSN string with password)
 | 
			
		||||
	connConfig, err := pgx.ParseConfig("")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, fmt.Errorf("postgres: failed to create pgx config: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	connConfig.Host = cfg.Host
 | 
			
		||||
	connConfig.Port = cfg.Port
 | 
			
		||||
	connConfig.User = cfg.User
 | 
			
		||||
	connConfig.Password = cfg.Pass
 | 
			
		||||
	connConfig.Database = cfg.Name
 | 
			
		||||
 | 
			
		||||
	// Map SSLMode to pgx configuration
 | 
			
		||||
	// Note: pgx uses different SSL handling than libpq
 | 
			
		||||
	// For now, we'll construct a minimal DSN with sslmode for ParseConfig
 | 
			
		||||
	if cfg.SSLMode != "" {
 | 
			
		||||
		// Reconstruct with sslmode only (no password in DSN)
 | 
			
		||||
		dsnWithoutPassword := fmt.Sprintf("host=%s port=%d user=%s dbname=%s sslmode=%s",
 | 
			
		||||
			cfg.Host, cfg.Port, cfg.User, cfg.Name, cfg.SSLMode)
 | 
			
		||||
		connConfig, err = pgx.ParseConfig(dsnWithoutPassword)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return nil, fmt.Errorf("postgres: failed to parse config with sslmode: %w", err)
 | 
			
		||||
		}
 | 
			
		||||
		// Set password separately after parsing
 | 
			
		||||
		connConfig.Password = cfg.Pass
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return stdlib.GetConnector(*connConfig), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// createPostgresConnectorFromFlat creates a PostgreSQL connector from flat config format
 | 
			
		||||
func createPostgresConnectorFromFlat(cfg *Config) (driver.Connector, error) {
 | 
			
		||||
	pgCfg := &PostgresConfig{
 | 
			
		||||
		User:    cfg.User,
 | 
			
		||||
		Pass:    cfg.Pass,
 | 
			
		||||
		Host:    cfg.Host,
 | 
			
		||||
		Port:    cfg.Port,
 | 
			
		||||
		Name:    cfg.Name,
 | 
			
		||||
		SSLMode: cfg.SSLMode,
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Set defaults for PostgreSQL
 | 
			
		||||
	if pgCfg.Host == "" {
 | 
			
		||||
		pgCfg.Host = "localhost"
 | 
			
		||||
	}
 | 
			
		||||
	if pgCfg.Port == 0 {
 | 
			
		||||
		pgCfg.Port = 5432
 | 
			
		||||
	}
 | 
			
		||||
	if pgCfg.SSLMode == "" {
 | 
			
		||||
		pgCfg.SSLMode = "prefer"
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return createPostgresConnector(pgCfg)
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										117
									
								
								database/integration_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										117
									
								
								database/integration_test.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,117 @@
 | 
			
		||||
package database
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"database/sql"
 | 
			
		||||
	"testing"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// Mock types for testing SQLC integration patterns
 | 
			
		||||
type mockQueries struct {
 | 
			
		||||
	db DBTX
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type mockQueriesTx struct {
 | 
			
		||||
	*mockQueries
 | 
			
		||||
	tx *sql.Tx
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Mock the Begin method pattern that SQLC generates
 | 
			
		||||
func (q *mockQueries) Begin(ctx context.Context) (*mockQueriesTx, error) {
 | 
			
		||||
	// This would normally be: tx, err := q.db.(*sql.DB).BeginTx(ctx, nil)
 | 
			
		||||
	// For our test, we return a mock
 | 
			
		||||
	return &mockQueriesTx{mockQueries: q, tx: nil}, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (qtx *mockQueriesTx) Commit(ctx context.Context) error {
 | 
			
		||||
	return nil // Mock implementation
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (qtx *mockQueriesTx) Rollback(ctx context.Context) error {
 | 
			
		||||
	return nil // Mock implementation
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// This test verifies that our common database interfaces are compatible with SQLC-generated code
 | 
			
		||||
func TestSQLCIntegration(t *testing.T) {
 | 
			
		||||
	// Test that SQLC's DBTX interface matches our DBTX interface
 | 
			
		||||
	t.Run("DBTX Interface Compatibility", func(t *testing.T) {
 | 
			
		||||
		// Test interface compatibility by assignment without execution
 | 
			
		||||
		var ourDBTX DBTX
 | 
			
		||||
 | 
			
		||||
		// Test with sql.DB (should implement DBTX)
 | 
			
		||||
		var db *sql.DB
 | 
			
		||||
		ourDBTX = db // This will compile only if interfaces are compatible
 | 
			
		||||
		_ = ourDBTX  // Use the variable to avoid "unused" warning
 | 
			
		||||
 | 
			
		||||
		// Test with sql.Tx (should implement DBTX)
 | 
			
		||||
		var tx *sql.Tx
 | 
			
		||||
		ourDBTX = tx // This will compile only if interfaces are compatible
 | 
			
		||||
		_ = ourDBTX  // Use the variable to avoid "unused" warning
 | 
			
		||||
 | 
			
		||||
		// If we reach here, interfaces are compatible
 | 
			
		||||
		t.Log("DBTX interface is compatible with sql.DB and sql.Tx")
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	t.Run("Transaction Interface Compatibility", func(t *testing.T) {
 | 
			
		||||
		// This test verifies our transaction interfaces work with SQLC patterns
 | 
			
		||||
		// We can't define methods inside a function, so we test interface compatibility
 | 
			
		||||
 | 
			
		||||
		// Verify our DB interface is compatible with what SQLC expects
 | 
			
		||||
		var dbInterface DB[*mockQueriesTx]
 | 
			
		||||
		var mockDB *mockQueries = &mockQueries{}
 | 
			
		||||
		dbInterface = mockDB
 | 
			
		||||
 | 
			
		||||
		// Test that our transaction helper can work with this pattern
 | 
			
		||||
		err := WithTransaction(context.Background(), dbInterface, func(ctx context.Context, qtx *mockQueriesTx) error {
 | 
			
		||||
			// This would be where you'd call SQLC-generated query methods
 | 
			
		||||
			return nil
 | 
			
		||||
		})
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			t.Errorf("Transaction helper failed: %v", err)
 | 
			
		||||
		}
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Test that demonstrates how the common package would be used with real SQLC patterns
 | 
			
		||||
func TestRealWorldUsagePattern(t *testing.T) {
 | 
			
		||||
	// This test shows how a package would typically use our common database code
 | 
			
		||||
 | 
			
		||||
	t.Run("Database Opening Pattern", func(t *testing.T) {
 | 
			
		||||
		// Test that our configuration options work as expected
 | 
			
		||||
		opts := DefaultConfigOptions()
 | 
			
		||||
 | 
			
		||||
		// Modify for test environment (no actual database connection)
 | 
			
		||||
		opts.ConfigFiles = []string{}   // No config files for unit test
 | 
			
		||||
		opts.PrometheusRegisterer = nil // No metrics for unit test
 | 
			
		||||
 | 
			
		||||
		// This would normally open a database: db, err := OpenDB(ctx, opts)
 | 
			
		||||
		// For our unit test, we just verify the options are reasonable
 | 
			
		||||
		if opts.MaxOpenConns <= 0 {
 | 
			
		||||
			t.Error("MaxOpenConns should be positive")
 | 
			
		||||
		}
 | 
			
		||||
		if opts.MaxIdleConns <= 0 {
 | 
			
		||||
			t.Error("MaxIdleConns should be positive")
 | 
			
		||||
		}
 | 
			
		||||
		if opts.ConnMaxLifetime <= 0 {
 | 
			
		||||
			t.Error("ConnMaxLifetime should be positive")
 | 
			
		||||
		}
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	t.Run("Monitor Package Configuration", func(t *testing.T) {
 | 
			
		||||
		opts := MonitorConfigOptions()
 | 
			
		||||
 | 
			
		||||
		// Verify monitor-specific settings
 | 
			
		||||
		if opts.EnablePoolMonitoring {
 | 
			
		||||
			t.Error("Monitor package should not enable pool monitoring")
 | 
			
		||||
		}
 | 
			
		||||
		if opts.PrometheusRegisterer != nil {
 | 
			
		||||
			t.Error("Monitor package should not have Prometheus registerer")
 | 
			
		||||
		}
 | 
			
		||||
		if opts.MaxOpenConns != 10 {
 | 
			
		||||
			t.Errorf("Expected MaxOpenConns=10 for monitor, got %d", opts.MaxOpenConns)
 | 
			
		||||
		}
 | 
			
		||||
		if opts.MaxIdleConns != 5 {
 | 
			
		||||
			t.Errorf("Expected MaxIdleConns=5 for monitor, got %d", opts.MaxIdleConns)
 | 
			
		||||
		}
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										34
									
								
								database/interfaces.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										34
									
								
								database/interfaces.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,34 @@
 | 
			
		||||
package database
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"database/sql"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// DBTX matches the interface expected by SQLC-generated code
 | 
			
		||||
// This interface is implemented by both *sql.DB and *sql.Tx
 | 
			
		||||
type DBTX interface {
 | 
			
		||||
	ExecContext(context.Context, string, ...interface{}) (sql.Result, error)
 | 
			
		||||
	PrepareContext(context.Context, string) (*sql.Stmt, error)
 | 
			
		||||
	QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error)
 | 
			
		||||
	QueryRowContext(context.Context, string, ...interface{}) *sql.Row
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// BaseQuerier provides basic query functionality
 | 
			
		||||
// This interface should be implemented by package-specific Queries types
 | 
			
		||||
type BaseQuerier interface {
 | 
			
		||||
	WithTx(tx *sql.Tx) BaseQuerier
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// BaseQuerierTx provides transaction functionality
 | 
			
		||||
// This interface should be implemented by package-specific Queries types
 | 
			
		||||
type BaseQuerierTx interface {
 | 
			
		||||
	BaseQuerier
 | 
			
		||||
	Begin(ctx context.Context) (BaseQuerierTx, error)
 | 
			
		||||
	Commit(ctx context.Context) error
 | 
			
		||||
	Rollback(ctx context.Context) error
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// TransactionFunc represents a function that operates within a database transaction
 | 
			
		||||
// This is used by the shared transaction helpers in transaction.go
 | 
			
		||||
type TransactionFunc[Q any] func(ctx context.Context, q Q) error
 | 
			
		||||
							
								
								
									
										93
									
								
								database/metrics.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										93
									
								
								database/metrics.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,93 @@
 | 
			
		||||
package database
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"database/sql"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/prometheus/client_golang/prometheus"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// DatabaseMetrics holds the Prometheus metrics for database connection pool monitoring
 | 
			
		||||
type DatabaseMetrics struct {
 | 
			
		||||
	ConnectionsOpen         prometheus.Gauge
 | 
			
		||||
	ConnectionsIdle         prometheus.Gauge
 | 
			
		||||
	ConnectionsInUse        prometheus.Gauge
 | 
			
		||||
	ConnectionsWaitCount    prometheus.Counter
 | 
			
		||||
	ConnectionsWaitDuration prometheus.Histogram
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// NewDatabaseMetrics creates a new set of database metrics and registers them
 | 
			
		||||
func NewDatabaseMetrics(registerer prometheus.Registerer) *DatabaseMetrics {
 | 
			
		||||
	metrics := &DatabaseMetrics{
 | 
			
		||||
		ConnectionsOpen: prometheus.NewGauge(prometheus.GaugeOpts{
 | 
			
		||||
			Name: "database_connections_open",
 | 
			
		||||
			Help: "Number of open database connections",
 | 
			
		||||
		}),
 | 
			
		||||
		ConnectionsIdle: prometheus.NewGauge(prometheus.GaugeOpts{
 | 
			
		||||
			Name: "database_connections_idle",
 | 
			
		||||
			Help: "Number of idle database connections",
 | 
			
		||||
		}),
 | 
			
		||||
		ConnectionsInUse: prometheus.NewGauge(prometheus.GaugeOpts{
 | 
			
		||||
			Name: "database_connections_in_use",
 | 
			
		||||
			Help: "Number of database connections in use",
 | 
			
		||||
		}),
 | 
			
		||||
		ConnectionsWaitCount: prometheus.NewCounter(prometheus.CounterOpts{
 | 
			
		||||
			Name: "database_connections_wait_count_total",
 | 
			
		||||
			Help: "Total number of times a connection had to wait",
 | 
			
		||||
		}),
 | 
			
		||||
		ConnectionsWaitDuration: prometheus.NewHistogram(prometheus.HistogramOpts{
 | 
			
		||||
			Name:    "database_connections_wait_duration_seconds",
 | 
			
		||||
			Help:    "Time spent waiting for a database connection",
 | 
			
		||||
			Buckets: prometheus.DefBuckets,
 | 
			
		||||
		}),
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if registerer != nil {
 | 
			
		||||
		registerer.MustRegister(
 | 
			
		||||
			metrics.ConnectionsOpen,
 | 
			
		||||
			metrics.ConnectionsIdle,
 | 
			
		||||
			metrics.ConnectionsInUse,
 | 
			
		||||
			metrics.ConnectionsWaitCount,
 | 
			
		||||
			metrics.ConnectionsWaitDuration,
 | 
			
		||||
		)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return metrics
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// monitorConnectionPool runs a background goroutine to collect connection pool metrics
 | 
			
		||||
func monitorConnectionPool(ctx context.Context, db *sql.DB, registerer prometheus.Registerer) {
 | 
			
		||||
	if registerer == nil {
 | 
			
		||||
		return // No metrics collection if no registerer provided
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	metrics := NewDatabaseMetrics(registerer)
 | 
			
		||||
	ticker := time.NewTicker(30 * time.Second)
 | 
			
		||||
	defer ticker.Stop()
 | 
			
		||||
 | 
			
		||||
	for {
 | 
			
		||||
		select {
 | 
			
		||||
		case <-ctx.Done():
 | 
			
		||||
			return
 | 
			
		||||
		case <-ticker.C:
 | 
			
		||||
			stats := db.Stats()
 | 
			
		||||
 | 
			
		||||
			metrics.ConnectionsOpen.Set(float64(stats.OpenConnections))
 | 
			
		||||
			metrics.ConnectionsIdle.Set(float64(stats.Idle))
 | 
			
		||||
			metrics.ConnectionsInUse.Set(float64(stats.InUse))
 | 
			
		||||
			metrics.ConnectionsWaitCount.Add(float64(stats.WaitCount))
 | 
			
		||||
 | 
			
		||||
			if stats.WaitDuration > 0 {
 | 
			
		||||
				metrics.ConnectionsWaitDuration.Observe(stats.WaitDuration.Seconds())
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			// Log connection pool stats for high usage or waiting
 | 
			
		||||
			if stats.OpenConnections > 20 || stats.WaitCount > 0 {
 | 
			
		||||
				fmt.Printf("Connection pool stats: open=%d idle=%d in_use=%d wait_count=%d wait_duration=%s\n",
 | 
			
		||||
					stats.OpenConnections, stats.Idle, stats.InUse, stats.WaitCount, stats.WaitDuration)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										120
									
								
								database/pgdb/CLAUDE.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										120
									
								
								database/pgdb/CLAUDE.md
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,120 @@
 | 
			
		||||
# pgdb - Native PostgreSQL Connection Pool
 | 
			
		||||
 | 
			
		||||
Primary package for PostgreSQL connections using native pgx pool (`*pgxpool.Pool`). Provides better performance and PostgreSQL-specific features compared to `database/sql`.
 | 
			
		||||
 | 
			
		||||
## Usage
 | 
			
		||||
 | 
			
		||||
### Basic Example
 | 
			
		||||
 | 
			
		||||
```go
 | 
			
		||||
import (
 | 
			
		||||
    "context"
 | 
			
		||||
    "go.ntppool.org/common/database/pgdb"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func main() {
 | 
			
		||||
    ctx := context.Background()
 | 
			
		||||
 | 
			
		||||
    // Open pool with default options
 | 
			
		||||
    pool, err := pgdb.OpenPool(ctx, pgdb.DefaultPoolOptions())
 | 
			
		||||
    if err != nil {
 | 
			
		||||
        log.Fatal(err)
 | 
			
		||||
    }
 | 
			
		||||
    defer pool.Close()
 | 
			
		||||
 | 
			
		||||
    // Use the pool for queries
 | 
			
		||||
    row := pool.QueryRow(ctx, "SELECT version()")
 | 
			
		||||
    var version string
 | 
			
		||||
    row.Scan(&version)
 | 
			
		||||
}
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
### With Custom Config File
 | 
			
		||||
 | 
			
		||||
```go
 | 
			
		||||
pool, err := pgdb.OpenPoolWithConfigFile(ctx, "/path/to/database.yaml")
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
### With Custom Pool Settings
 | 
			
		||||
 | 
			
		||||
```go
 | 
			
		||||
opts := pgdb.DefaultPoolOptions()
 | 
			
		||||
opts.MaxConns = 50
 | 
			
		||||
opts.MinConns = 5
 | 
			
		||||
opts.MaxConnLifetime = 2 * time.Hour
 | 
			
		||||
 | 
			
		||||
pool, err := pgdb.OpenPool(ctx, opts)
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
## Configuration Format
 | 
			
		||||
 | 
			
		||||
### Recommended: Nested Format (database.yaml)
 | 
			
		||||
 | 
			
		||||
```yaml
 | 
			
		||||
postgres:
 | 
			
		||||
  host: localhost
 | 
			
		||||
  port: 5432
 | 
			
		||||
  user: myuser
 | 
			
		||||
  pass: mypassword
 | 
			
		||||
  name: mydb
 | 
			
		||||
  sslmode: prefer
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
### Legacy: Flat Format (backward compatible)
 | 
			
		||||
 | 
			
		||||
```yaml
 | 
			
		||||
host: localhost
 | 
			
		||||
port: 5432
 | 
			
		||||
user: myuser
 | 
			
		||||
pass: mypassword
 | 
			
		||||
name: mydb
 | 
			
		||||
sslmode: prefer
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
## Configuration Options
 | 
			
		||||
 | 
			
		||||
### PoolOptions
 | 
			
		||||
 | 
			
		||||
- `ConfigFiles` - List of config file paths to search (default: `database.yaml`, `/vault/secrets/database.yaml`)
 | 
			
		||||
- `MinConns` - Minimum connections (default: 0)
 | 
			
		||||
- `MaxConns` - Maximum connections (default: 25)
 | 
			
		||||
- `MaxConnLifetime` - Connection lifetime (default: 1 hour)
 | 
			
		||||
- `MaxConnIdleTime` - Idle timeout (default: 30 minutes)
 | 
			
		||||
- `HealthCheckPeriod` - Health check interval (default: 1 minute)
 | 
			
		||||
 | 
			
		||||
### PostgreSQL Config Fields
 | 
			
		||||
 | 
			
		||||
- `host` - Database host (required)
 | 
			
		||||
- `user` - Database user (required)
 | 
			
		||||
- `pass` - Database password
 | 
			
		||||
- `name` - Database name (required)
 | 
			
		||||
- `port` - Port number (default: 5432)
 | 
			
		||||
- `sslmode` - SSL mode: `disable`, `allow`, `prefer`, `require`, `verify-ca`, `verify-full` (default: `prefer`)
 | 
			
		||||
 | 
			
		||||
## Environment Variables
 | 
			
		||||
 | 
			
		||||
- `DATABASE_CONFIG_FILE` - Override config file location
 | 
			
		||||
 | 
			
		||||
## When to Use
 | 
			
		||||
 | 
			
		||||
**Use `pgdb.OpenPool()`** (this package) when:
 | 
			
		||||
- You need native PostgreSQL features (LISTEN/NOTIFY, COPY, etc.)
 | 
			
		||||
- You want better performance
 | 
			
		||||
- You're writing new PostgreSQL code
 | 
			
		||||
 | 
			
		||||
**Use `database.OpenDB()`** (sql.DB) when:
 | 
			
		||||
- You need database-agnostic code
 | 
			
		||||
- You're using SQLC or other tools that expect `database/sql`
 | 
			
		||||
- You need to support both MySQL and PostgreSQL
 | 
			
		||||
 | 
			
		||||
## Security
 | 
			
		||||
 | 
			
		||||
This package avoids password exposure by:
 | 
			
		||||
1. Never constructing DSN strings with passwords
 | 
			
		||||
2. Setting passwords separately in pgx config objects
 | 
			
		||||
3. Validating all configuration before connection
 | 
			
		||||
 | 
			
		||||
## See Also
 | 
			
		||||
 | 
			
		||||
- `database/` - Generic sql.DB support for MySQL and PostgreSQL
 | 
			
		||||
- pgx documentation: https://github.com/jackc/pgx
 | 
			
		||||
							
								
								
									
										64
									
								
								database/pgdb/config.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										64
									
								
								database/pgdb/config.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,64 @@
 | 
			
		||||
package pgdb
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
 | 
			
		||||
	"github.com/jackc/pgx/v5/pgxpool"
 | 
			
		||||
	"go.ntppool.org/common/database"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// CreatePoolConfig converts database.PostgresConfig to pgxpool.Config
 | 
			
		||||
// This is the secure way to create a config without exposing passwords in DSN strings
 | 
			
		||||
func CreatePoolConfig(cfg *database.PostgresConfig) (*pgxpool.Config, error) {
 | 
			
		||||
	// Validate required fields
 | 
			
		||||
	if cfg.Host == "" {
 | 
			
		||||
		return nil, fmt.Errorf("postgres: host is required")
 | 
			
		||||
	}
 | 
			
		||||
	if cfg.User == "" {
 | 
			
		||||
		return nil, fmt.Errorf("postgres: user is required")
 | 
			
		||||
	}
 | 
			
		||||
	if cfg.Name == "" {
 | 
			
		||||
		return nil, fmt.Errorf("postgres: database name is required")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Validate SSLMode
 | 
			
		||||
	validSSLModes := map[string]bool{
 | 
			
		||||
		"disable": true, "allow": true, "prefer": true,
 | 
			
		||||
		"require": true, "verify-ca": true, "verify-full": true,
 | 
			
		||||
	}
 | 
			
		||||
	if cfg.SSLMode != "" && !validSSLModes[cfg.SSLMode] {
 | 
			
		||||
		return nil, fmt.Errorf("postgres: invalid sslmode: %s", cfg.SSLMode)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Set defaults
 | 
			
		||||
	host := cfg.Host
 | 
			
		||||
	if host == "" {
 | 
			
		||||
		host = "localhost"
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	port := cfg.Port
 | 
			
		||||
	if port == 0 {
 | 
			
		||||
		port = 5432
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	sslmode := cfg.SSLMode
 | 
			
		||||
	if sslmode == "" {
 | 
			
		||||
		sslmode = "prefer"
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Build connection string WITHOUT password
 | 
			
		||||
	// We'll set the password separately in the config
 | 
			
		||||
	connString := fmt.Sprintf("host=%s port=%d user=%s dbname=%s sslmode=%s",
 | 
			
		||||
		host, port, cfg.User, cfg.Name, sslmode)
 | 
			
		||||
 | 
			
		||||
	// Parse the connection string
 | 
			
		||||
	poolConfig, err := pgxpool.ParseConfig(connString)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, fmt.Errorf("postgres: failed to parse connection config: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Set password separately (security: never put password in the connection string)
 | 
			
		||||
	poolConfig.ConnConfig.Password = cfg.Pass
 | 
			
		||||
 | 
			
		||||
	return poolConfig, nil
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										173
									
								
								database/pgdb/pool.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										173
									
								
								database/pgdb/pool.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,173 @@
 | 
			
		||||
package pgdb
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"os"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/jackc/pgx/v5/pgxpool"
 | 
			
		||||
	"go.ntppool.org/common/database"
 | 
			
		||||
	"gopkg.in/yaml.v3"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// PoolOptions configures pgxpool connection behavior
 | 
			
		||||
type PoolOptions struct {
 | 
			
		||||
	// ConfigFiles is a list of config file paths to search for database configuration
 | 
			
		||||
	ConfigFiles []string
 | 
			
		||||
 | 
			
		||||
	// MinConns is the minimum number of connections in the pool
 | 
			
		||||
	// Default: 0 (no minimum)
 | 
			
		||||
	MinConns int32
 | 
			
		||||
 | 
			
		||||
	// MaxConns is the maximum number of connections in the pool
 | 
			
		||||
	// Default: 25
 | 
			
		||||
	MaxConns int32
 | 
			
		||||
 | 
			
		||||
	// MaxConnLifetime is the maximum lifetime of a connection
 | 
			
		||||
	// Default: 1 hour
 | 
			
		||||
	MaxConnLifetime time.Duration
 | 
			
		||||
 | 
			
		||||
	// MaxConnIdleTime is the maximum idle time of a connection
 | 
			
		||||
	// Default: 30 minutes
 | 
			
		||||
	MaxConnIdleTime time.Duration
 | 
			
		||||
 | 
			
		||||
	// HealthCheckPeriod is how often to check connection health
 | 
			
		||||
	// Default: 1 minute
 | 
			
		||||
	HealthCheckPeriod time.Duration
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// DefaultPoolOptions returns sensible defaults for pgxpool
 | 
			
		||||
func DefaultPoolOptions() PoolOptions {
 | 
			
		||||
	return PoolOptions{
 | 
			
		||||
		ConfigFiles:       getConfigFiles(),
 | 
			
		||||
		MinConns:          0,
 | 
			
		||||
		MaxConns:          25,
 | 
			
		||||
		MaxConnLifetime:   time.Hour,
 | 
			
		||||
		MaxConnIdleTime:   30 * time.Minute,
 | 
			
		||||
		HealthCheckPeriod: time.Minute,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// OpenPool opens a native pgx connection pool with the specified configuration
 | 
			
		||||
// This is the primary and recommended way to connect to PostgreSQL
 | 
			
		||||
func OpenPool(ctx context.Context, options PoolOptions) (*pgxpool.Pool, error) {
 | 
			
		||||
	// Find and read config file
 | 
			
		||||
	pgCfg, err := findAndParseConfig(options.ConfigFiles)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Create pool config from PostgreSQL config
 | 
			
		||||
	poolConfig, err := CreatePoolConfig(pgCfg)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Apply pool-specific settings
 | 
			
		||||
	poolConfig.MinConns = options.MinConns
 | 
			
		||||
	poolConfig.MaxConns = options.MaxConns
 | 
			
		||||
	poolConfig.MaxConnLifetime = options.MaxConnLifetime
 | 
			
		||||
	poolConfig.MaxConnIdleTime = options.MaxConnIdleTime
 | 
			
		||||
	poolConfig.HealthCheckPeriod = options.HealthCheckPeriod
 | 
			
		||||
 | 
			
		||||
	// Create the pool
 | 
			
		||||
	pool, err := pgxpool.NewWithConfig(ctx, poolConfig)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, fmt.Errorf("failed to create connection pool: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Test the connection
 | 
			
		||||
	if err := pool.Ping(ctx); err != nil {
 | 
			
		||||
		pool.Close()
 | 
			
		||||
		return nil, fmt.Errorf("failed to ping database: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return pool, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// OpenPoolWithConfigFile opens a connection pool using an explicit config file path
 | 
			
		||||
// This is a convenience function for when you have a specific config file
 | 
			
		||||
func OpenPoolWithConfigFile(ctx context.Context, configFile string) (*pgxpool.Pool, error) {
 | 
			
		||||
	options := DefaultPoolOptions()
 | 
			
		||||
	options.ConfigFiles = []string{configFile}
 | 
			
		||||
	return OpenPool(ctx, options)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// findAndParseConfig searches for and parses the first existing config file
 | 
			
		||||
func findAndParseConfig(configFiles []string) (*database.PostgresConfig, error) {
 | 
			
		||||
	var firstErr error
 | 
			
		||||
 | 
			
		||||
	for _, configFile := range configFiles {
 | 
			
		||||
		if configFile == "" {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// Check if file exists
 | 
			
		||||
		if _, err := os.Stat(configFile); err != nil {
 | 
			
		||||
			if firstErr == nil {
 | 
			
		||||
				firstErr = err
 | 
			
		||||
			}
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// Try to read and parse the file
 | 
			
		||||
		pgCfg, err := parseConfigFile(configFile)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			if firstErr == nil {
 | 
			
		||||
				firstErr = err
 | 
			
		||||
			}
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		return pgCfg, nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if firstErr != nil {
 | 
			
		||||
		return nil, fmt.Errorf("no config file found: %w", firstErr)
 | 
			
		||||
	}
 | 
			
		||||
	return nil, fmt.Errorf("no valid config files provided")
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// parseConfigFile reads and parses a YAML config file
 | 
			
		||||
func parseConfigFile(configFile string) (*database.PostgresConfig, error) {
 | 
			
		||||
	file, err := os.Open(configFile)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, fmt.Errorf("failed to open config file: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
	defer file.Close()
 | 
			
		||||
 | 
			
		||||
	dec := yaml.NewDecoder(file)
 | 
			
		||||
	cfg := database.Config{}
 | 
			
		||||
 | 
			
		||||
	if err := dec.Decode(&cfg); err != nil {
 | 
			
		||||
		return nil, fmt.Errorf("failed to decode config: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Extract PostgreSQL config
 | 
			
		||||
	if cfg.Postgres != nil {
 | 
			
		||||
		return cfg.Postgres, nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Check for legacy flat format
 | 
			
		||||
	if cfg.User != "" && cfg.Name != "" {
 | 
			
		||||
		return &database.PostgresConfig{
 | 
			
		||||
			User:    cfg.User,
 | 
			
		||||
			Pass:    cfg.Pass,
 | 
			
		||||
			Host:    cfg.Host,
 | 
			
		||||
			Port:    cfg.Port,
 | 
			
		||||
			Name:    cfg.Name,
 | 
			
		||||
			SSLMode: cfg.SSLMode,
 | 
			
		||||
		}, nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil, fmt.Errorf("no PostgreSQL configuration found in %s", configFile)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// getConfigFiles returns the list of config files to search
 | 
			
		||||
func getConfigFiles() []string {
 | 
			
		||||
	if configFile := os.Getenv("DATABASE_CONFIG_FILE"); configFile != "" {
 | 
			
		||||
		return []string{configFile}
 | 
			
		||||
	}
 | 
			
		||||
	return []string{"database.yaml", "/vault/secrets/database.yaml"}
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										151
									
								
								database/pgdb/pool_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										151
									
								
								database/pgdb/pool_test.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,151 @@
 | 
			
		||||
package pgdb
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"testing"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"go.ntppool.org/common/database"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func TestCreatePoolConfig(t *testing.T) {
 | 
			
		||||
	tests := []struct {
 | 
			
		||||
		name    string
 | 
			
		||||
		cfg     *database.PostgresConfig
 | 
			
		||||
		wantErr bool
 | 
			
		||||
	}{
 | 
			
		||||
		{
 | 
			
		||||
			name: "valid config",
 | 
			
		||||
			cfg: &database.PostgresConfig{
 | 
			
		||||
				Host:    "localhost",
 | 
			
		||||
				Port:    5432,
 | 
			
		||||
				User:    "testuser",
 | 
			
		||||
				Pass:    "testpass",
 | 
			
		||||
				Name:    "testdb",
 | 
			
		||||
				SSLMode: "require",
 | 
			
		||||
			},
 | 
			
		||||
			wantErr: false,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name: "valid config with defaults",
 | 
			
		||||
			cfg: &database.PostgresConfig{
 | 
			
		||||
				Host: "localhost",
 | 
			
		||||
				User: "testuser",
 | 
			
		||||
				Pass: "testpass",
 | 
			
		||||
				Name: "testdb",
 | 
			
		||||
				// Port and SSLMode will use defaults
 | 
			
		||||
			},
 | 
			
		||||
			wantErr: false,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name: "missing host",
 | 
			
		||||
			cfg: &database.PostgresConfig{
 | 
			
		||||
				User: "testuser",
 | 
			
		||||
				Pass: "testpass",
 | 
			
		||||
				Name: "testdb",
 | 
			
		||||
			},
 | 
			
		||||
			wantErr: true,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name: "missing user",
 | 
			
		||||
			cfg: &database.PostgresConfig{
 | 
			
		||||
				Host: "localhost",
 | 
			
		||||
				Pass: "testpass",
 | 
			
		||||
				Name: "testdb",
 | 
			
		||||
			},
 | 
			
		||||
			wantErr: true,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name: "missing database name",
 | 
			
		||||
			cfg: &database.PostgresConfig{
 | 
			
		||||
				Host: "localhost",
 | 
			
		||||
				User: "testuser",
 | 
			
		||||
				Pass: "testpass",
 | 
			
		||||
			},
 | 
			
		||||
			wantErr: true,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name: "invalid sslmode",
 | 
			
		||||
			cfg: &database.PostgresConfig{
 | 
			
		||||
				Host:    "localhost",
 | 
			
		||||
				User:    "testuser",
 | 
			
		||||
				Pass:    "testpass",
 | 
			
		||||
				Name:    "testdb",
 | 
			
		||||
				SSLMode: "invalid",
 | 
			
		||||
			},
 | 
			
		||||
			wantErr: true,
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for _, tt := range tests {
 | 
			
		||||
		t.Run(tt.name, func(t *testing.T) {
 | 
			
		||||
			poolCfg, err := CreatePoolConfig(tt.cfg)
 | 
			
		||||
			if (err != nil) != tt.wantErr {
 | 
			
		||||
				t.Errorf("CreatePoolConfig() error = %v, wantErr %v", err, tt.wantErr)
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
			if !tt.wantErr && poolCfg == nil {
 | 
			
		||||
				t.Error("CreatePoolConfig() returned nil config without error")
 | 
			
		||||
			}
 | 
			
		||||
			if !tt.wantErr && poolCfg != nil {
 | 
			
		||||
				// Verify config fields are set correctly
 | 
			
		||||
				if poolCfg.ConnConfig.Host != tt.cfg.Host && tt.cfg.Host != "" {
 | 
			
		||||
					t.Errorf("Expected Host=%s, got %s", tt.cfg.Host, poolCfg.ConnConfig.Host)
 | 
			
		||||
				}
 | 
			
		||||
				if poolCfg.ConnConfig.User != tt.cfg.User {
 | 
			
		||||
					t.Errorf("Expected User=%s, got %s", tt.cfg.User, poolCfg.ConnConfig.User)
 | 
			
		||||
				}
 | 
			
		||||
				if poolCfg.ConnConfig.Password != tt.cfg.Pass {
 | 
			
		||||
					t.Errorf("Expected Password to be set correctly")
 | 
			
		||||
				}
 | 
			
		||||
				if poolCfg.ConnConfig.Database != tt.cfg.Name {
 | 
			
		||||
					t.Errorf("Expected Database=%s, got %s", tt.cfg.Name, poolCfg.ConnConfig.Database)
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestDefaultPoolOptions(t *testing.T) {
 | 
			
		||||
	opts := DefaultPoolOptions()
 | 
			
		||||
 | 
			
		||||
	// Verify expected defaults
 | 
			
		||||
	if opts.MinConns != 0 {
 | 
			
		||||
		t.Errorf("Expected MinConns=0, got %d", opts.MinConns)
 | 
			
		||||
	}
 | 
			
		||||
	if opts.MaxConns != 25 {
 | 
			
		||||
		t.Errorf("Expected MaxConns=25, got %d", opts.MaxConns)
 | 
			
		||||
	}
 | 
			
		||||
	if opts.MaxConnLifetime != time.Hour {
 | 
			
		||||
		t.Errorf("Expected MaxConnLifetime=1h, got %v", opts.MaxConnLifetime)
 | 
			
		||||
	}
 | 
			
		||||
	if opts.MaxConnIdleTime != 30*time.Minute {
 | 
			
		||||
		t.Errorf("Expected MaxConnIdleTime=30m, got %v", opts.MaxConnIdleTime)
 | 
			
		||||
	}
 | 
			
		||||
	if opts.HealthCheckPeriod != time.Minute {
 | 
			
		||||
		t.Errorf("Expected HealthCheckPeriod=1m, got %v", opts.HealthCheckPeriod)
 | 
			
		||||
	}
 | 
			
		||||
	if len(opts.ConfigFiles) == 0 {
 | 
			
		||||
		t.Error("Expected ConfigFiles to be non-empty")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestCreatePoolConfigDefaults(t *testing.T) {
 | 
			
		||||
	// Test that defaults are applied correctly
 | 
			
		||||
	cfg := &database.PostgresConfig{
 | 
			
		||||
		Host: "localhost",
 | 
			
		||||
		User: "testuser",
 | 
			
		||||
		Pass: "testpass",
 | 
			
		||||
		Name: "testdb",
 | 
			
		||||
		// Port and SSLMode not set
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	poolCfg, err := CreatePoolConfig(cfg)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatalf("CreatePoolConfig() failed: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Verify defaults were applied
 | 
			
		||||
	if poolCfg.ConnConfig.Port != 5432 {
 | 
			
		||||
		t.Errorf("Expected default Port=5432, got %d", poolCfg.ConnConfig.Port)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										78
									
								
								database/pool.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										78
									
								
								database/pool.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,78 @@
 | 
			
		||||
package database
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"database/sql"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"os"
 | 
			
		||||
 | 
			
		||||
	"go.ntppool.org/common/logger"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// OpenDB opens a database connection with the specified configuration options
 | 
			
		||||
func OpenDB(ctx context.Context, options ConfigOptions) (*sql.DB, error) {
 | 
			
		||||
	log := logger.Setup()
 | 
			
		||||
 | 
			
		||||
	configFile, err := findConfigFile(options.ConfigFiles)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	dbconn := sql.OpenDB(Driver{
 | 
			
		||||
		CreateConnectorFunc: createConnector(configFile),
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	// Set connection pool parameters
 | 
			
		||||
	dbconn.SetConnMaxLifetime(options.ConnMaxLifetime)
 | 
			
		||||
	dbconn.SetMaxOpenConns(options.MaxOpenConns)
 | 
			
		||||
	dbconn.SetMaxIdleConns(options.MaxIdleConns)
 | 
			
		||||
 | 
			
		||||
	err = dbconn.Ping()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		log.Error("could not connect to database", "err", err)
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Start optional connection pool monitoring
 | 
			
		||||
	if options.EnablePoolMonitoring && options.PrometheusRegisterer != nil {
 | 
			
		||||
		go monitorConnectionPool(ctx, dbconn, options.PrometheusRegisterer)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return dbconn, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// OpenDBWithConfigFile opens a database connection using an explicit config file path
 | 
			
		||||
// This is a convenience function for API package compatibility
 | 
			
		||||
func OpenDBWithConfigFile(ctx context.Context, configFile string) (*sql.DB, error) {
 | 
			
		||||
	options := DefaultConfigOptions()
 | 
			
		||||
	options.ConfigFiles = []string{configFile}
 | 
			
		||||
	return OpenDB(ctx, options)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// OpenDBMonitor opens a database connection with monitor-specific defaults
 | 
			
		||||
// This is a convenience function for Monitor package compatibility
 | 
			
		||||
func OpenDBMonitor() (*sql.DB, error) {
 | 
			
		||||
	options := MonitorConfigOptions()
 | 
			
		||||
	return OpenDB(context.Background(), options)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// findConfigFile searches for the first existing config file from the list
 | 
			
		||||
func findConfigFile(configFiles []string) (string, error) {
 | 
			
		||||
	var firstErr error
 | 
			
		||||
 | 
			
		||||
	for _, configFile := range configFiles {
 | 
			
		||||
		if configFile == "" {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		if _, err := os.Stat(configFile); err == nil {
 | 
			
		||||
			return configFile, nil
 | 
			
		||||
		} else if firstErr == nil {
 | 
			
		||||
			firstErr = err
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if firstErr != nil {
 | 
			
		||||
		return "", fmt.Errorf("no config file found: %w", firstErr)
 | 
			
		||||
	}
 | 
			
		||||
	return "", fmt.Errorf("no valid config files provided")
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										69
									
								
								database/transaction.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										69
									
								
								database/transaction.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,69 @@
 | 
			
		||||
package database
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"fmt"
 | 
			
		||||
 | 
			
		||||
	"go.ntppool.org/common/logger"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// DB interface for database operations that can begin transactions
 | 
			
		||||
type DB[Q any] interface {
 | 
			
		||||
	Begin(ctx context.Context) (Q, error)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// TX interface for transaction operations
 | 
			
		||||
type TX interface {
 | 
			
		||||
	Commit(ctx context.Context) error
 | 
			
		||||
	Rollback(ctx context.Context) error
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// WithTransaction executes a function within a database transaction
 | 
			
		||||
// Handles proper rollback on error and commit on success
 | 
			
		||||
func WithTransaction[Q TX](ctx context.Context, db DB[Q], fn func(ctx context.Context, q Q) error) error {
 | 
			
		||||
	tx, err := db.Begin(ctx)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return fmt.Errorf("failed to begin transaction: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var committed bool
 | 
			
		||||
	defer func() {
 | 
			
		||||
		if !committed {
 | 
			
		||||
			if rbErr := tx.Rollback(ctx); rbErr != nil {
 | 
			
		||||
				// Log rollback error but don't override original error
 | 
			
		||||
				log := logger.FromContext(ctx)
 | 
			
		||||
				log.ErrorContext(ctx, "failed to rollback transaction", "error", rbErr)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	if err := fn(ctx, tx); err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	err = tx.Commit(ctx)
 | 
			
		||||
	committed = true // Mark as committed regardless of commit success/failure
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return fmt.Errorf("failed to commit transaction: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// WithReadOnlyTransaction executes a read-only function within a transaction
 | 
			
		||||
// Always rolls back at the end (for consistent read isolation)
 | 
			
		||||
func WithReadOnlyTransaction[Q TX](ctx context.Context, db DB[Q], fn func(ctx context.Context, q Q) error) error {
 | 
			
		||||
	tx, err := db.Begin(ctx)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return fmt.Errorf("failed to begin read-only transaction: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	defer func() {
 | 
			
		||||
		if rbErr := tx.Rollback(ctx); rbErr != nil {
 | 
			
		||||
			log := logger.FromContext(ctx)
 | 
			
		||||
			log.ErrorContext(ctx, "failed to rollback read-only transaction", "error", rbErr)
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	return fn(ctx, tx)
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										69
									
								
								database/transaction_base.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										69
									
								
								database/transaction_base.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,69 @@
 | 
			
		||||
package database
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"database/sql"
 | 
			
		||||
	"fmt"
 | 
			
		||||
 | 
			
		||||
	"go.ntppool.org/common/logger"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// Shared interface definitions that both packages use identically
 | 
			
		||||
type BaseBeginner interface {
 | 
			
		||||
	Begin(context.Context) (sql.Tx, error)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type BaseTx interface {
 | 
			
		||||
	BaseBeginner
 | 
			
		||||
	Commit(ctx context.Context) error
 | 
			
		||||
	Rollback(ctx context.Context) error
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// BeginTransactionForQuerier contains the shared Begin() logic from both packages
 | 
			
		||||
func BeginTransactionForQuerier(ctx context.Context, db DBTX) (DBTX, error) {
 | 
			
		||||
	if sqlDB, ok := db.(*sql.DB); ok {
 | 
			
		||||
		tx, err := sqlDB.BeginTx(ctx, &sql.TxOptions{})
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
		return tx, nil
 | 
			
		||||
	} else {
 | 
			
		||||
		// Handle transaction case
 | 
			
		||||
		if beginner, ok := db.(BaseBeginner); ok {
 | 
			
		||||
			tx, err := beginner.Begin(ctx)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return nil, err
 | 
			
		||||
			}
 | 
			
		||||
			return &tx, nil
 | 
			
		||||
		}
 | 
			
		||||
		return nil, fmt.Errorf("database connection does not support transactions")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// CommitTransactionForQuerier contains the shared Commit() logic from both packages
 | 
			
		||||
func CommitTransactionForQuerier(ctx context.Context, db DBTX) error {
 | 
			
		||||
	if sqlTx, ok := db.(*sql.Tx); ok {
 | 
			
		||||
		return sqlTx.Commit()
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	tx, ok := db.(BaseTx)
 | 
			
		||||
	if !ok {
 | 
			
		||||
		log := logger.FromContext(ctx)
 | 
			
		||||
		log.ErrorContext(ctx, "could not get a Tx", "type", fmt.Sprintf("%T", db))
 | 
			
		||||
		return sql.ErrTxDone
 | 
			
		||||
	}
 | 
			
		||||
	return tx.Commit(ctx)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// RollbackTransactionForQuerier contains the shared Rollback() logic from both packages
 | 
			
		||||
func RollbackTransactionForQuerier(ctx context.Context, db DBTX) error {
 | 
			
		||||
	if sqlTx, ok := db.(*sql.Tx); ok {
 | 
			
		||||
		return sqlTx.Rollback()
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	tx, ok := db.(BaseTx)
 | 
			
		||||
	if !ok {
 | 
			
		||||
		return sql.ErrTxDone
 | 
			
		||||
	}
 | 
			
		||||
	return tx.Rollback(ctx)
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										157
									
								
								database/transaction_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										157
									
								
								database/transaction_test.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,157 @@
 | 
			
		||||
package database
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"testing"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// Mock implementations for testing
 | 
			
		||||
type mockDB struct {
 | 
			
		||||
	beginError error
 | 
			
		||||
	txMock     *mockTX
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (m *mockDB) Begin(ctx context.Context) (*mockTX, error) {
 | 
			
		||||
	if m.beginError != nil {
 | 
			
		||||
		return nil, m.beginError
 | 
			
		||||
	}
 | 
			
		||||
	return m.txMock, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type mockTX struct {
 | 
			
		||||
	commitError    error
 | 
			
		||||
	rollbackError  error
 | 
			
		||||
	commitCalled   bool
 | 
			
		||||
	rollbackCalled bool
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (m *mockTX) Commit(ctx context.Context) error {
 | 
			
		||||
	m.commitCalled = true
 | 
			
		||||
	return m.commitError
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (m *mockTX) Rollback(ctx context.Context) error {
 | 
			
		||||
	m.rollbackCalled = true
 | 
			
		||||
	return m.rollbackError
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestWithTransaction_Success(t *testing.T) {
 | 
			
		||||
	tx := &mockTX{}
 | 
			
		||||
	db := &mockDB{txMock: tx}
 | 
			
		||||
 | 
			
		||||
	var functionCalled bool
 | 
			
		||||
	err := WithTransaction(context.Background(), db, func(ctx context.Context, q *mockTX) error {
 | 
			
		||||
		functionCalled = true
 | 
			
		||||
		if q != tx {
 | 
			
		||||
			t.Error("Expected transaction to be passed to function")
 | 
			
		||||
		}
 | 
			
		||||
		return nil
 | 
			
		||||
	})
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Errorf("Expected no error, got %v", err)
 | 
			
		||||
	}
 | 
			
		||||
	if !functionCalled {
 | 
			
		||||
		t.Error("Expected function to be called")
 | 
			
		||||
	}
 | 
			
		||||
	if !tx.commitCalled {
 | 
			
		||||
		t.Error("Expected commit to be called")
 | 
			
		||||
	}
 | 
			
		||||
	if tx.rollbackCalled {
 | 
			
		||||
		t.Error("Expected rollback NOT to be called on success")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestWithTransaction_FunctionError(t *testing.T) {
 | 
			
		||||
	tx := &mockTX{}
 | 
			
		||||
	db := &mockDB{txMock: tx}
 | 
			
		||||
 | 
			
		||||
	expectedError := errors.New("function error")
 | 
			
		||||
	err := WithTransaction(context.Background(), db, func(ctx context.Context, q *mockTX) error {
 | 
			
		||||
		return expectedError
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	if err != expectedError {
 | 
			
		||||
		t.Errorf("Expected error %v, got %v", expectedError, err)
 | 
			
		||||
	}
 | 
			
		||||
	if tx.commitCalled {
 | 
			
		||||
		t.Error("Expected commit NOT to be called on function error")
 | 
			
		||||
	}
 | 
			
		||||
	if !tx.rollbackCalled {
 | 
			
		||||
		t.Error("Expected rollback to be called on function error")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestWithTransaction_BeginError(t *testing.T) {
 | 
			
		||||
	expectedError := errors.New("begin error")
 | 
			
		||||
	db := &mockDB{beginError: expectedError}
 | 
			
		||||
 | 
			
		||||
	err := WithTransaction(context.Background(), db, func(ctx context.Context, q *mockTX) error {
 | 
			
		||||
		t.Error("Function should not be called when Begin fails")
 | 
			
		||||
		return nil
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	if err == nil || !errors.Is(err, expectedError) {
 | 
			
		||||
		t.Errorf("Expected wrapped begin error, got %v", err)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestWithTransaction_CommitError(t *testing.T) {
 | 
			
		||||
	commitError := errors.New("commit error")
 | 
			
		||||
	tx := &mockTX{commitError: commitError}
 | 
			
		||||
	db := &mockDB{txMock: tx}
 | 
			
		||||
 | 
			
		||||
	err := WithTransaction(context.Background(), db, func(ctx context.Context, q *mockTX) error {
 | 
			
		||||
		return nil
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	if err == nil || !errors.Is(err, commitError) {
 | 
			
		||||
		t.Errorf("Expected wrapped commit error, got %v", err)
 | 
			
		||||
	}
 | 
			
		||||
	if !tx.commitCalled {
 | 
			
		||||
		t.Error("Expected commit to be called")
 | 
			
		||||
	}
 | 
			
		||||
	if tx.rollbackCalled {
 | 
			
		||||
		t.Error("Expected rollback NOT to be called when commit fails")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestWithReadOnlyTransaction_Success(t *testing.T) {
 | 
			
		||||
	tx := &mockTX{}
 | 
			
		||||
	db := &mockDB{txMock: tx}
 | 
			
		||||
 | 
			
		||||
	var functionCalled bool
 | 
			
		||||
	err := WithReadOnlyTransaction(context.Background(), db, func(ctx context.Context, q *mockTX) error {
 | 
			
		||||
		functionCalled = true
 | 
			
		||||
		return nil
 | 
			
		||||
	})
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Errorf("Expected no error, got %v", err)
 | 
			
		||||
	}
 | 
			
		||||
	if !functionCalled {
 | 
			
		||||
		t.Error("Expected function to be called")
 | 
			
		||||
	}
 | 
			
		||||
	if tx.commitCalled {
 | 
			
		||||
		t.Error("Expected commit NOT to be called in read-only transaction")
 | 
			
		||||
	}
 | 
			
		||||
	if !tx.rollbackCalled {
 | 
			
		||||
		t.Error("Expected rollback to be called in read-only transaction")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestWithReadOnlyTransaction_FunctionError(t *testing.T) {
 | 
			
		||||
	tx := &mockTX{}
 | 
			
		||||
	db := &mockDB{txMock: tx}
 | 
			
		||||
 | 
			
		||||
	expectedError := errors.New("function error")
 | 
			
		||||
	err := WithReadOnlyTransaction(context.Background(), db, func(ctx context.Context, q *mockTX) error {
 | 
			
		||||
		return expectedError
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	if err != expectedError {
 | 
			
		||||
		t.Errorf("Expected error %v, got %v", expectedError, err)
 | 
			
		||||
	}
 | 
			
		||||
	if !tx.rollbackCalled {
 | 
			
		||||
		t.Error("Expected rollback to be called")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										66
									
								
								ekko/ekko.go
									
									
									
									
									
								
							
							
						
						
									
										66
									
								
								ekko/ekko.go
									
									
									
									
									
								
							@@ -1,3 +1,32 @@
 | 
			
		||||
// Package ekko provides an enhanced Echo web framework wrapper with pre-configured middleware.
 | 
			
		||||
//
 | 
			
		||||
// This package wraps the Echo web framework with a comprehensive middleware stack including:
 | 
			
		||||
//   - OpenTelemetry distributed tracing with request context propagation
 | 
			
		||||
//   - Prometheus metrics collection with per-service subsystems
 | 
			
		||||
//   - Structured logging with trace ID correlation
 | 
			
		||||
//   - Security headers (HSTS, content security policy)
 | 
			
		||||
//   - Gzip compression for response optimization
 | 
			
		||||
//   - Recovery middleware with detailed error logging
 | 
			
		||||
//   - HTTP/2 support with H2C (HTTP/2 Cleartext) capability
 | 
			
		||||
//
 | 
			
		||||
// The package uses functional options pattern for flexible configuration
 | 
			
		||||
// and supports graceful shutdown with configurable timeouts. It's designed
 | 
			
		||||
// as the standard web service foundation for NTP Pool project services.
 | 
			
		||||
//
 | 
			
		||||
// Example usage:
 | 
			
		||||
//
 | 
			
		||||
//	ekko, err := ekko.New("myservice",
 | 
			
		||||
//		ekko.WithPort(8080),
 | 
			
		||||
//		ekko.WithPrometheus(prometheus.DefaultRegisterer),
 | 
			
		||||
//		ekko.WithEchoSetup(func(e *echo.Echo) error {
 | 
			
		||||
//			e.GET("/health", healthHandler)
 | 
			
		||||
//			return nil
 | 
			
		||||
//		}),
 | 
			
		||||
//	)
 | 
			
		||||
//	if err != nil {
 | 
			
		||||
//		log.Fatal(err)
 | 
			
		||||
//	}
 | 
			
		||||
//	err = ekko.Start(ctx)
 | 
			
		||||
package ekko
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
@@ -20,6 +49,25 @@ import (
 | 
			
		||||
	"golang.org/x/sync/errgroup"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// New creates a new Ekko instance with the specified service name and functional options.
 | 
			
		||||
// The name parameter is used for OpenTelemetry service identification, Prometheus metrics
 | 
			
		||||
// subsystem naming, and server identification headers.
 | 
			
		||||
//
 | 
			
		||||
// Default configuration includes:
 | 
			
		||||
//   - 60 second write timeout
 | 
			
		||||
//   - 30 second read header timeout
 | 
			
		||||
//   - HTTP/2 support with H2C
 | 
			
		||||
//   - Standard middleware stack (tracing, metrics, logging, security)
 | 
			
		||||
//
 | 
			
		||||
// Use functional options to customize behavior:
 | 
			
		||||
//   - WithPort(): Set server port (required for Start())
 | 
			
		||||
//   - WithPrometheus(): Enable Prometheus metrics
 | 
			
		||||
//   - WithEchoSetup(): Configure routes and handlers
 | 
			
		||||
//   - WithLogFilters(): Filter access logs
 | 
			
		||||
//   - WithOtelMiddleware(): Custom OpenTelemetry middleware
 | 
			
		||||
//   - WithWriteTimeout(): Custom write timeout
 | 
			
		||||
//   - WithReadHeaderTimeout(): Custom read header timeout
 | 
			
		||||
//   - WithGzipConfig(): Custom gzip compression settings
 | 
			
		||||
func New(name string, options ...func(*Ekko)) (*Ekko, error) {
 | 
			
		||||
	ek := &Ekko{
 | 
			
		||||
		writeTimeout:      60 * time.Second,
 | 
			
		||||
@@ -32,13 +80,25 @@ func New(name string, options ...func(*Ekko)) (*Ekko, error) {
 | 
			
		||||
	return ek, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Setup Echo; only intended for testing
 | 
			
		||||
// SetupEcho creates and configures an Echo instance without starting the server.
 | 
			
		||||
// This method is primarily intended for testing scenarios where you need access
 | 
			
		||||
// to the configured Echo instance without starting the HTTP server.
 | 
			
		||||
//
 | 
			
		||||
// The returned Echo instance includes all configured middleware and routes
 | 
			
		||||
// but requires manual server lifecycle management.
 | 
			
		||||
func (ek *Ekko) SetupEcho(ctx context.Context) (*echo.Echo, error) {
 | 
			
		||||
	return ek.setup(ctx)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Setup Echo and start the server. Will return if the http server
 | 
			
		||||
// returns or the context is done.
 | 
			
		||||
// Start creates the Echo instance and starts the HTTP server with graceful shutdown support.
 | 
			
		||||
// The server runs until either an error occurs or the provided context is cancelled.
 | 
			
		||||
//
 | 
			
		||||
// The server supports HTTP/2 with H2C (HTTP/2 Cleartext) and includes a 5-second
 | 
			
		||||
// graceful shutdown timeout when the context is cancelled. Server configuration
 | 
			
		||||
// (port, timeouts, middleware) must be set via functional options during New().
 | 
			
		||||
//
 | 
			
		||||
// Returns an error if server startup fails or if shutdown doesn't complete within
 | 
			
		||||
// the timeout period. Returns nil for clean shutdown via context cancellation.
 | 
			
		||||
func (ek *Ekko) Start(ctx context.Context) error {
 | 
			
		||||
	log := logger.Setup()
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -9,6 +9,9 @@ import (
 | 
			
		||||
	slogecho "github.com/samber/slog-echo"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// Ekko represents an enhanced Echo web server with pre-configured middleware stack.
 | 
			
		||||
// It encapsulates server configuration, middleware options, and lifecycle management
 | 
			
		||||
// for NTP Pool web services. Use New() with functional options to configure.
 | 
			
		||||
type Ekko struct {
 | 
			
		||||
	name           string
 | 
			
		||||
	prom           prometheus.Registerer
 | 
			
		||||
@@ -22,50 +25,76 @@ type Ekko struct {
 | 
			
		||||
	readHeaderTimeout time.Duration
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// RouteFn defines a function type for configuring Echo routes and handlers.
 | 
			
		||||
// It receives a configured Echo instance and should register all application
 | 
			
		||||
// routes, middleware, and handlers. Return an error to abort server startup.
 | 
			
		||||
type RouteFn func(e *echo.Echo) error
 | 
			
		||||
 | 
			
		||||
// WithPort sets the HTTP server port. This option is required when using Start().
 | 
			
		||||
// The port should be available and the process should have permission to bind to it.
 | 
			
		||||
func WithPort(port int) func(*Ekko) {
 | 
			
		||||
	return func(ek *Ekko) {
 | 
			
		||||
		ek.port = port
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// WithPrometheus enables Prometheus metrics collection using the provided registerer.
 | 
			
		||||
// Metrics include HTTP request duration, request count, and response size histograms.
 | 
			
		||||
// The service name is used as the metrics subsystem for namespacing.
 | 
			
		||||
func WithPrometheus(reg prometheus.Registerer) func(*Ekko) {
 | 
			
		||||
	return func(ek *Ekko) {
 | 
			
		||||
		ek.prom = reg
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// WithEchoSetup configures application routes and handlers via a setup function.
 | 
			
		||||
// The provided function receives the configured Echo instance after all middleware
 | 
			
		||||
// is applied and should register routes, custom middleware, and handlers.
 | 
			
		||||
func WithEchoSetup(rfn RouteFn) func(*Ekko) {
 | 
			
		||||
	return func(ek *Ekko) {
 | 
			
		||||
		ek.routeFn = rfn
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// WithLogFilters configures access log filtering to reduce log noise.
 | 
			
		||||
// Filters can exclude specific paths, methods, or status codes from access logs.
 | 
			
		||||
// Useful for excluding health checks, metrics endpoints, and other high-frequency requests.
 | 
			
		||||
func WithLogFilters(f []slogecho.Filter) func(*Ekko) {
 | 
			
		||||
	return func(ek *Ekko) {
 | 
			
		||||
		ek.logFilters = f
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// WithOtelMiddleware replaces the default OpenTelemetry middleware with a custom implementation.
 | 
			
		||||
// The default middleware provides distributed tracing for all requests. Use this option
 | 
			
		||||
// when you need custom trace configuration or want to disable tracing entirely.
 | 
			
		||||
func WithOtelMiddleware(mw echo.MiddlewareFunc) func(*Ekko) {
 | 
			
		||||
	return func(ek *Ekko) {
 | 
			
		||||
		ek.otelmiddleware = mw
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// WithWriteTimeout configures the HTTP server write timeout.
 | 
			
		||||
// This is the maximum duration before timing out writes of the response.
 | 
			
		||||
// Default is 60 seconds. Should be longer than expected response generation time.
 | 
			
		||||
func WithWriteTimeout(t time.Duration) func(*Ekko) {
 | 
			
		||||
	return func(ek *Ekko) {
 | 
			
		||||
		ek.writeTimeout = t
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// WithReadHeaderTimeout configures the HTTP server read header timeout.
 | 
			
		||||
// This is the amount of time allowed to read request headers.
 | 
			
		||||
// Default is 30 seconds. Should be sufficient for slow clients and large headers.
 | 
			
		||||
func WithReadHeaderTimeout(t time.Duration) func(*Ekko) {
 | 
			
		||||
	return func(ek *Ekko) {
 | 
			
		||||
		ek.readHeaderTimeout = t
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// WithGzipConfig provides custom gzip compression configuration.
 | 
			
		||||
// By default, gzip compression is enabled with standard settings.
 | 
			
		||||
// Use this option to customize compression level, skip patterns, or disable compression.
 | 
			
		||||
func WithGzipConfig(gzipConfig *middleware.GzipConfig) func(*Ekko) {
 | 
			
		||||
	return func(ek *Ekko) {
 | 
			
		||||
		ek.gzipConfig = gzipConfig
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										45
									
								
								go.mod
									
									
									
									
									
								
							
							
						
						
									
										45
									
								
								go.mod
									
									
									
									
									
								
							@@ -1,13 +1,16 @@
 | 
			
		||||
module go.ntppool.org/common
 | 
			
		||||
 | 
			
		||||
go 1.23.5
 | 
			
		||||
go 1.24.0
 | 
			
		||||
 | 
			
		||||
require (
 | 
			
		||||
	github.com/abh/certman v0.4.0
 | 
			
		||||
	github.com/go-sql-driver/mysql v1.9.3
 | 
			
		||||
	github.com/jackc/pgx/v5 v5.7.6
 | 
			
		||||
	github.com/labstack/echo-contrib v0.17.2
 | 
			
		||||
	github.com/labstack/echo/v4 v4.13.3
 | 
			
		||||
	github.com/oklog/ulid/v2 v2.1.0
 | 
			
		||||
	github.com/prometheus/client_golang v1.20.5
 | 
			
		||||
	github.com/prometheus/client_model v0.6.1
 | 
			
		||||
	github.com/remychantenay/slog-otel v1.3.2
 | 
			
		||||
	github.com/samber/slog-echo v1.14.8
 | 
			
		||||
	github.com/samber/slog-multi v1.2.4
 | 
			
		||||
@@ -17,61 +20,75 @@ require (
 | 
			
		||||
	go.opentelemetry.io/contrib/exporters/autoexport v0.58.0
 | 
			
		||||
	go.opentelemetry.io/contrib/instrumentation/github.com/labstack/echo/otelecho v0.58.0
 | 
			
		||||
	go.opentelemetry.io/otel v1.33.0
 | 
			
		||||
	go.opentelemetry.io/otel/exporters/otlp/otlplog/otlploggrpc v0.9.0
 | 
			
		||||
	go.opentelemetry.io/otel/exporters/otlp/otlplog/otlploghttp v0.9.0
 | 
			
		||||
	go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v1.33.0
 | 
			
		||||
	go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetrichttp v1.33.0
 | 
			
		||||
	go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.33.0
 | 
			
		||||
	go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.33.0
 | 
			
		||||
	go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.33.0
 | 
			
		||||
	go.opentelemetry.io/otel/log v0.9.0
 | 
			
		||||
	go.opentelemetry.io/otel/metric v1.33.0
 | 
			
		||||
	go.opentelemetry.io/otel/sdk v1.33.0
 | 
			
		||||
	go.opentelemetry.io/otel/sdk/log v0.9.0
 | 
			
		||||
	go.opentelemetry.io/otel/sdk/metric v1.33.0
 | 
			
		||||
	go.opentelemetry.io/otel/trace v1.33.0
 | 
			
		||||
	golang.org/x/mod v0.22.0
 | 
			
		||||
	golang.org/x/net v0.33.0
 | 
			
		||||
	golang.org/x/sync v0.10.0
 | 
			
		||||
	golang.org/x/mod v0.28.0
 | 
			
		||||
	golang.org/x/net v0.44.0
 | 
			
		||||
	golang.org/x/sync v0.17.0
 | 
			
		||||
	google.golang.org/grpc v1.69.2
 | 
			
		||||
	gopkg.in/yaml.v3 v3.0.1
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
require (
 | 
			
		||||
	filippo.io/edwards25519 v1.1.0 // indirect
 | 
			
		||||
	github.com/beorn7/perks v1.0.1 // indirect
 | 
			
		||||
	github.com/cenkalti/backoff/v4 v4.3.0 // indirect
 | 
			
		||||
	github.com/cespare/xxhash/v2 v2.3.0 // indirect
 | 
			
		||||
	github.com/dmarkham/enumer v1.6.1 // indirect
 | 
			
		||||
	github.com/fsnotify/fsnotify v1.8.0 // indirect
 | 
			
		||||
	github.com/go-logr/logr v1.4.2 // indirect
 | 
			
		||||
	github.com/go-logr/stdr v1.2.2 // indirect
 | 
			
		||||
	github.com/google/uuid v1.6.0 // indirect
 | 
			
		||||
	github.com/grpc-ecosystem/grpc-gateway/v2 v2.25.1 // indirect
 | 
			
		||||
	github.com/inconshreveable/mousetrap v1.1.0 // indirect
 | 
			
		||||
	github.com/jackc/pgpassfile v1.0.0 // indirect
 | 
			
		||||
	github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
 | 
			
		||||
	github.com/jackc/puddle/v2 v2.2.2 // indirect
 | 
			
		||||
	github.com/klauspost/compress v1.17.11 // indirect
 | 
			
		||||
	github.com/labstack/gommon v0.4.2 // indirect
 | 
			
		||||
	github.com/masaushi/accessory v0.6.0 // indirect
 | 
			
		||||
	github.com/mattn/go-colorable v0.1.13 // indirect
 | 
			
		||||
	github.com/mattn/go-isatty v0.0.20 // indirect
 | 
			
		||||
	github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
 | 
			
		||||
	github.com/pascaldekloe/name v1.0.0 // indirect
 | 
			
		||||
	github.com/pierrec/lz4/v4 v4.1.22 // indirect
 | 
			
		||||
	github.com/pkg/errors v0.9.1 // indirect
 | 
			
		||||
	github.com/prometheus/client_model v0.6.1 // indirect
 | 
			
		||||
	github.com/prometheus/common v0.61.0 // indirect
 | 
			
		||||
	github.com/prometheus/procfs v0.15.1 // indirect
 | 
			
		||||
	github.com/samber/lo v1.47.0 // indirect
 | 
			
		||||
	github.com/spf13/afero v1.15.0 // indirect
 | 
			
		||||
	github.com/spf13/pflag v1.0.5 // indirect
 | 
			
		||||
	github.com/valyala/bytebufferpool v1.0.0 // indirect
 | 
			
		||||
	github.com/valyala/fasttemplate v1.2.2 // indirect
 | 
			
		||||
	go.opentelemetry.io/auto/sdk v1.1.0 // indirect
 | 
			
		||||
	go.opentelemetry.io/contrib/bridges/prometheus v0.58.0 // indirect
 | 
			
		||||
	go.opentelemetry.io/otel/exporters/otlp/otlplog/otlploggrpc v0.9.0 // indirect
 | 
			
		||||
	go.opentelemetry.io/otel/exporters/otlp/otlplog/otlploghttp v0.9.0 // indirect
 | 
			
		||||
	go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v1.33.0 // indirect
 | 
			
		||||
	go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetrichttp v1.33.0 // indirect
 | 
			
		||||
	go.opentelemetry.io/otel/exporters/prometheus v0.55.0 // indirect
 | 
			
		||||
	go.opentelemetry.io/otel/exporters/stdout/stdoutlog v0.9.0 // indirect
 | 
			
		||||
	go.opentelemetry.io/otel/exporters/stdout/stdoutmetric v1.33.0 // indirect
 | 
			
		||||
	go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.33.0 // indirect
 | 
			
		||||
	go.opentelemetry.io/otel/metric v1.33.0 // indirect
 | 
			
		||||
	go.opentelemetry.io/otel/sdk/metric v1.33.0 // indirect
 | 
			
		||||
	go.opentelemetry.io/proto/otlp v1.4.0 // indirect
 | 
			
		||||
	golang.org/x/crypto v0.31.0 // indirect
 | 
			
		||||
	golang.org/x/sys v0.28.0 // indirect
 | 
			
		||||
	golang.org/x/text v0.21.0 // indirect
 | 
			
		||||
	golang.org/x/crypto v0.42.0 // indirect
 | 
			
		||||
	golang.org/x/sys v0.36.0 // indirect
 | 
			
		||||
	golang.org/x/text v0.29.0 // indirect
 | 
			
		||||
	golang.org/x/time v0.8.0 // indirect
 | 
			
		||||
	golang.org/x/tools v0.37.0 // indirect
 | 
			
		||||
	google.golang.org/genproto/googleapis/api v0.0.0-20241223144023-3abc09e42ca8 // indirect
 | 
			
		||||
	google.golang.org/genproto/googleapis/rpc v0.0.0-20241223144023-3abc09e42ca8 // indirect
 | 
			
		||||
	google.golang.org/protobuf v1.36.1 // indirect
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
tool (
 | 
			
		||||
	github.com/dmarkham/enumer
 | 
			
		||||
	github.com/masaushi/accessory
 | 
			
		||||
)
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										58
									
								
								go.sum
									
									
									
									
									
								
							
							
						
						
									
										58
									
								
								go.sum
									
									
									
									
									
								
							@@ -1,7 +1,11 @@
 | 
			
		||||
filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA=
 | 
			
		||||
filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
 | 
			
		||||
github.com/abh/certman v0.4.0 h1:XHoDtb0YyRQPclaHMrBDlKTVZpNjTK6vhB0S3Bd/Sbs=
 | 
			
		||||
github.com/abh/certman v0.4.0/go.mod h1:x8QhpKVZifmV1Hdiwdg9gLo2GMPAxezz1s3zrVnPs+I=
 | 
			
		||||
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
 | 
			
		||||
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
 | 
			
		||||
github.com/bradleyjkemp/cupaloy/v2 v2.8.0 h1:any4BmKE+jGIaMpnU8YgH/I2LPiLBufr6oMMlVBbn9M=
 | 
			
		||||
github.com/bradleyjkemp/cupaloy/v2 v2.8.0/go.mod h1:bm7JXdkRd4BHJk9HpwqAI8BoAY1lps46Enkdqw6aRX0=
 | 
			
		||||
github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8=
 | 
			
		||||
github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE=
 | 
			
		||||
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
 | 
			
		||||
@@ -10,6 +14,8 @@ github.com/cpuguy83/go-md2man/v2 v2.0.4/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46t
 | 
			
		||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
 | 
			
		||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
 | 
			
		||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
 | 
			
		||||
github.com/dmarkham/enumer v1.6.1 h1:aSc9awYtZL07TUueWs40QcHtxTvHTAwG0EqrNsK45w4=
 | 
			
		||||
github.com/dmarkham/enumer v1.6.1/go.mod h1:yixql+kDDQRYqcuBM2n9Vlt7NoT9ixgXhaXry8vmRg8=
 | 
			
		||||
github.com/fsnotify/fsnotify v1.8.0 h1:dAwr6QBTBZIkG8roQaJjGof0pp0EeF+tNV7YBP3F/8M=
 | 
			
		||||
github.com/fsnotify/fsnotify v1.8.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
 | 
			
		||||
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
 | 
			
		||||
@@ -17,6 +23,8 @@ github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY=
 | 
			
		||||
github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
 | 
			
		||||
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
 | 
			
		||||
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
 | 
			
		||||
github.com/go-sql-driver/mysql v1.9.3 h1:U/N249h2WzJ3Ukj8SowVFjdtZKfu9vlLZxjPXV1aweo=
 | 
			
		||||
github.com/go-sql-driver/mysql v1.9.3/go.mod h1:qn46aNg1333BRMNU69Lq93t8du/dwxI64Gl8i5p1WMU=
 | 
			
		||||
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
 | 
			
		||||
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
 | 
			
		||||
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
 | 
			
		||||
@@ -27,9 +35,21 @@ github.com/grpc-ecosystem/grpc-gateway/v2 v2.25.1 h1:VNqngBF40hVlDloBruUehVYC3Ar
 | 
			
		||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.25.1/go.mod h1:RBRO7fro65R6tjKzYgLAFo0t1QEXY1Dp+i/bvpRiqiQ=
 | 
			
		||||
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
 | 
			
		||||
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
 | 
			
		||||
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
 | 
			
		||||
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
 | 
			
		||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
 | 
			
		||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
 | 
			
		||||
github.com/jackc/pgx/v5 v5.7.6 h1:rWQc5FwZSPX58r1OQmkuaNicxdmExaEz5A2DO2hUuTk=
 | 
			
		||||
github.com/jackc/pgx/v5 v5.7.6/go.mod h1:aruU7o91Tc2q2cFp5h4uP3f6ztExVpyVv88Xl/8Vl8M=
 | 
			
		||||
github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
 | 
			
		||||
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
 | 
			
		||||
github.com/klauspost/compress v1.15.9/go.mod h1:PhcZ0MbTNciWF3rruxRgKxI5NkcHHrHUDtV4Yw2GlzU=
 | 
			
		||||
github.com/klauspost/compress v1.17.11 h1:In6xLpyWOi1+C7tXUUWv2ot1QvBjxevKAaI6IXrJmUc=
 | 
			
		||||
github.com/klauspost/compress v1.17.11/go.mod h1:pMDklpSncoRMuLFrf1W9Ss9KT+0rH90U12bZKk7uwG0=
 | 
			
		||||
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
 | 
			
		||||
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
 | 
			
		||||
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
 | 
			
		||||
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
 | 
			
		||||
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
 | 
			
		||||
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
 | 
			
		||||
github.com/labstack/echo-contrib v0.17.2 h1:K1zivqmtcC70X9VdBFdLomjPDEVHlrcAObqmuFj1c6w=
 | 
			
		||||
@@ -38,6 +58,8 @@ github.com/labstack/echo/v4 v4.13.3 h1:pwhpCPrTl5qry5HRdM5FwdXnhXSLSY+WE+YQSeCaa
 | 
			
		||||
github.com/labstack/echo/v4 v4.13.3/go.mod h1:o90YNEeQWjDozo584l7AwhJMHN0bOC4tAfg+Xox9q5g=
 | 
			
		||||
github.com/labstack/gommon v0.4.2 h1:F8qTUNXgG1+6WQmqoUWnz8WiEU60mXVVw0P4ht1WRA0=
 | 
			
		||||
github.com/labstack/gommon v0.4.2/go.mod h1:QlUFxVM+SNXhDL/Z7YhocGIBYOiwB0mXm1+1bAPHPyU=
 | 
			
		||||
github.com/masaushi/accessory v0.6.0 h1:HYAzxkuhfvlbaQwinxXTxsSPbFabAnwHt8K6I/DvNBU=
 | 
			
		||||
github.com/masaushi/accessory v0.6.0/go.mod h1:8GZMgq3wcIapVZWt7VVQCh5+onPc/8gJeHb8WRXezvQ=
 | 
			
		||||
github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA=
 | 
			
		||||
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
 | 
			
		||||
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
 | 
			
		||||
@@ -47,6 +69,8 @@ github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq
 | 
			
		||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
 | 
			
		||||
github.com/oklog/ulid/v2 v2.1.0 h1:+9lhoxAP56we25tyYETBBY1YLA2SaoLvUFgrP2miPJU=
 | 
			
		||||
github.com/oklog/ulid/v2 v2.1.0/go.mod h1:rcEKHmBBKfef9DhnvX7y1HZBYxjXb0cP5ExxNsTT1QQ=
 | 
			
		||||
github.com/pascaldekloe/name v1.0.0 h1:n7LKFgHixETzxpRv2R77YgPUFo85QHGZKrdaYm7eY5U=
 | 
			
		||||
github.com/pascaldekloe/name v1.0.0/go.mod h1:Z//MfYJnH4jVpQ9wkclwu2I2MkHmXTlT9wR5UZScttM=
 | 
			
		||||
github.com/pborman/getopt v0.0.0-20170112200414-7148bc3a4c30/go.mod h1:85jBQOZwpVEaDAr341tbn15RS4fCAsIst0qp7i8ex1o=
 | 
			
		||||
github.com/pierrec/lz4/v4 v4.1.15/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4=
 | 
			
		||||
github.com/pierrec/lz4/v4 v4.1.22 h1:cKFw6uJDK+/gfw5BcDL0JL5aBsAFdsIT18eRtLj7VIU=
 | 
			
		||||
@@ -65,6 +89,8 @@ github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0leargg
 | 
			
		||||
github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk=
 | 
			
		||||
github.com/remychantenay/slog-otel v1.3.2 h1:ZBx8qnwfLJ6e18Vba4e9Xp9B7khTmpIwFsU1sAmActw=
 | 
			
		||||
github.com/remychantenay/slog-otel v1.3.2/go.mod h1:gKW4tQ8cGOKoA+bi7wtYba/tcJ6Tc9XyQ/EW8gHA/2E=
 | 
			
		||||
github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII=
 | 
			
		||||
github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o=
 | 
			
		||||
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
 | 
			
		||||
github.com/samber/lo v1.47.0 h1:z7RynLwP5nbyRscyvcD043DWYoOcYRv3mV8lBeqOCLc=
 | 
			
		||||
github.com/samber/lo v1.47.0/go.mod h1:RmDH9Ct32Qy3gduHQuKJ3gW1fMHAnE/fAzQuf6He5cU=
 | 
			
		||||
@@ -74,12 +100,16 @@ github.com/samber/slog-multi v1.2.4 h1:k9x3JAWKJFPKffx+oXZ8TasaNuorIW4tG+TXxkt6R
 | 
			
		||||
github.com/samber/slog-multi v1.2.4/go.mod h1:ACuZ5B6heK57TfMVkVknN2UZHoFfjCwRxR0Q2OXKHlo=
 | 
			
		||||
github.com/segmentio/kafka-go v0.4.47 h1:IqziR4pA3vrZq7YdRxaT3w1/5fvIH5qpCwstUanQQB0=
 | 
			
		||||
github.com/segmentio/kafka-go v0.4.47/go.mod h1:HjF6XbOKh0Pjlkr5GVZxt6CsjjwnmhVOfURM5KMd8qg=
 | 
			
		||||
github.com/spf13/afero v1.15.0 h1:b/YBCLWAJdFWJTN9cLhiXXcD7mzKn9Dm86dNnfyQw1I=
 | 
			
		||||
github.com/spf13/afero v1.15.0/go.mod h1:NC2ByUVxtQs4b3sIUphxK0NioZnmxgyCrfzeuq8lxMg=
 | 
			
		||||
github.com/spf13/cobra v1.8.1 h1:e5/vxKd/rZsfSJMUX1agtjeTDf+qv1/JdBF8gg5k9ZM=
 | 
			
		||||
github.com/spf13/cobra v1.8.1/go.mod h1:wHxEcudfqmLYa8iTfL+OuZPbBZkmvliBWKIezN3kD9Y=
 | 
			
		||||
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
 | 
			
		||||
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
 | 
			
		||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
 | 
			
		||||
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
 | 
			
		||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
 | 
			
		||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
 | 
			
		||||
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
 | 
			
		||||
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
 | 
			
		||||
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
 | 
			
		||||
@@ -150,25 +180,25 @@ go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
 | 
			
		||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
 | 
			
		||||
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
 | 
			
		||||
golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4=
 | 
			
		||||
golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U=
 | 
			
		||||
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
 | 
			
		||||
golang.org/x/crypto v0.42.0 h1:chiH31gIWm57EkTXpwnqf8qeuMUi0yekh6mT2AvFlqI=
 | 
			
		||||
golang.org/x/crypto v0.42.0/go.mod h1:4+rDnOTJhQCx2q7/j6rAN5XDw8kPjeaXEUR2eL94ix8=
 | 
			
		||||
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
 | 
			
		||||
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
 | 
			
		||||
golang.org/x/mod v0.22.0 h1:D4nJWe9zXqHOmWqj4VMOJhvzj7bEZg4wEYa759z1pH4=
 | 
			
		||||
golang.org/x/mod v0.22.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY=
 | 
			
		||||
golang.org/x/mod v0.28.0 h1:gQBtGhjxykdjY9YhZpSlZIsbnaE2+PgjfLWUQTnoZ1U=
 | 
			
		||||
golang.org/x/mod v0.28.0/go.mod h1:yfB/L0NOf/kmEbXjzCPOx1iK1fRutOydrCMsqRhEBxI=
 | 
			
		||||
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
 | 
			
		||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
 | 
			
		||||
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
 | 
			
		||||
golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
 | 
			
		||||
golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
 | 
			
		||||
golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE=
 | 
			
		||||
golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I=
 | 
			
		||||
golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4=
 | 
			
		||||
golang.org/x/net v0.44.0 h1:evd8IRDyfNBMBTTY5XRF1vaZlD+EmWx6x8PkhR04H/I=
 | 
			
		||||
golang.org/x/net v0.44.0/go.mod h1:ECOoLqd5U3Lhyeyo/QDCEVQ4sNgYsqvCZ722XogGieY=
 | 
			
		||||
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
 | 
			
		||||
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
 | 
			
		||||
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
 | 
			
		||||
golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ=
 | 
			
		||||
golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
 | 
			
		||||
golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug=
 | 
			
		||||
golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
 | 
			
		||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
 | 
			
		||||
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 | 
			
		||||
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 | 
			
		||||
@@ -179,8 +209,8 @@ golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 | 
			
		||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 | 
			
		||||
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 | 
			
		||||
golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 | 
			
		||||
golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA=
 | 
			
		||||
golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
 | 
			
		||||
golang.org/x/sys v0.36.0 h1:KVRy2GtZBrk1cBYA7MKu5bEZFxQk4NIDV6RLVcC8o0k=
 | 
			
		||||
golang.org/x/sys v0.36.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
 | 
			
		||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
 | 
			
		||||
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
 | 
			
		||||
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
 | 
			
		||||
@@ -193,14 +223,16 @@ golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ=
 | 
			
		||||
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
 | 
			
		||||
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
 | 
			
		||||
golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
 | 
			
		||||
golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo=
 | 
			
		||||
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
 | 
			
		||||
golang.org/x/text v0.29.0 h1:1neNs90w9YzJ9BocxfsQNHKuAT4pkghyXc4nhZ6sJvk=
 | 
			
		||||
golang.org/x/text v0.29.0/go.mod h1:7MhJOA9CD2qZyOKYazxdYMF85OwPdEr9jTtBpO7ydH4=
 | 
			
		||||
golang.org/x/time v0.8.0 h1:9i3RxcPv3PZnitoVGMPDKZSq1xW1gK1Xy3ArNOGZfEg=
 | 
			
		||||
golang.org/x/time v0.8.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
 | 
			
		||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
 | 
			
		||||
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
 | 
			
		||||
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
 | 
			
		||||
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
 | 
			
		||||
golang.org/x/tools v0.37.0 h1:DVSRzp7FwePZW356yEAChSdNcQo6Nsp+fex1SUW09lE=
 | 
			
		||||
golang.org/x/tools v0.37.0/go.mod h1:MBN5QPQtLMHVdvsbtarmTNukZDdgwdwlO5qGacAzF0w=
 | 
			
		||||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
 | 
			
		||||
google.golang.org/genproto/googleapis/api v0.0.0-20241223144023-3abc09e42ca8 h1:st3LcW/BPi75W4q1jJTEor/QWwbNlPlDG0JTn6XhZu0=
 | 
			
		||||
google.golang.org/genproto/googleapis/api v0.0.0-20241223144023-3abc09e42ca8/go.mod h1:klhJGKFyG8Tn50enBn7gizg4nXGXJ+jqEREdCWaPcV4=
 | 
			
		||||
@@ -211,6 +243,8 @@ google.golang.org/grpc v1.69.2/go.mod h1:vyjdE6jLBI76dgpDojsFGNaHlxdjXN9ghpnd2o7
 | 
			
		||||
google.golang.org/protobuf v1.36.1 h1:yBPeRvTftaleIgM3PZ/WBIZ7XM/eEYAaEyCwvyjq/gk=
 | 
			
		||||
google.golang.org/protobuf v1.36.1/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE=
 | 
			
		||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
 | 
			
		||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
 | 
			
		||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
 | 
			
		||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
 | 
			
		||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
 | 
			
		||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
 | 
			
		||||
 
 | 
			
		||||
@@ -1,7 +1,71 @@
 | 
			
		||||
// Package health provides a standalone HTTP server for health checks.
 | 
			
		||||
//
 | 
			
		||||
// This package implements a flexible health check server that supports
 | 
			
		||||
// different handlers for Kubernetes probe types (liveness, readiness, startup).
 | 
			
		||||
// It provides structured logging, graceful shutdown, and standard HTTP endpoints
 | 
			
		||||
// for monitoring and load balancing.
 | 
			
		||||
//
 | 
			
		||||
// # Kubernetes Probe Types
 | 
			
		||||
//
 | 
			
		||||
// Liveness Probe: Detects when a container is "dead" and needs restarting.
 | 
			
		||||
// Should be a lightweight check that verifies the process is still running
 | 
			
		||||
// and not in an unrecoverable state.
 | 
			
		||||
//
 | 
			
		||||
// Readiness Probe: Determines when a container is ready to accept traffic.
 | 
			
		||||
// Controls which Pods are used as backends for Services. Should verify
 | 
			
		||||
// the application can handle requests properly.
 | 
			
		||||
//
 | 
			
		||||
// Startup Probe: Verifies when a container application has successfully started.
 | 
			
		||||
// Delays liveness and readiness probes until startup succeeds. Useful for
 | 
			
		||||
// slow-starting applications.
 | 
			
		||||
//
 | 
			
		||||
// # Usage Examples
 | 
			
		||||
//
 | 
			
		||||
// Basic usage with a single handler for all probes:
 | 
			
		||||
//
 | 
			
		||||
//	srv := health.NewServer(myHealthHandler)
 | 
			
		||||
//	srv.Listen(ctx, 9091)
 | 
			
		||||
//
 | 
			
		||||
// Advanced usage with separate handlers for each probe type:
 | 
			
		||||
//
 | 
			
		||||
//	srv := health.NewServer(nil,
 | 
			
		||||
//		health.WithLivenessHandler(func(w http.ResponseWriter, r *http.Request) {
 | 
			
		||||
//			// Simple alive check
 | 
			
		||||
//			w.WriteHeader(http.StatusOK)
 | 
			
		||||
//		}),
 | 
			
		||||
//		health.WithReadinessHandler(func(w http.ResponseWriter, r *http.Request) {
 | 
			
		||||
//			// Check if ready to serve traffic
 | 
			
		||||
//			if err := checkDatabase(); err != nil {
 | 
			
		||||
//				w.WriteHeader(http.StatusServiceUnavailable)
 | 
			
		||||
//				return
 | 
			
		||||
//			}
 | 
			
		||||
//			w.WriteHeader(http.StatusOK)
 | 
			
		||||
//		}),
 | 
			
		||||
//		health.WithStartupHandler(func(w http.ResponseWriter, r *http.Request) {
 | 
			
		||||
//			// Check if startup is complete
 | 
			
		||||
//			if !applicationReady() {
 | 
			
		||||
//				w.WriteHeader(http.StatusServiceUnavailable)
 | 
			
		||||
//				return
 | 
			
		||||
//			}
 | 
			
		||||
//			w.WriteHeader(http.StatusOK)
 | 
			
		||||
//		}),
 | 
			
		||||
//		health.WithServiceName("my-service"),
 | 
			
		||||
//	)
 | 
			
		||||
//	srv.Listen(ctx, 9091)
 | 
			
		||||
//
 | 
			
		||||
// # Standard Endpoints
 | 
			
		||||
//
 | 
			
		||||
// The server exposes these endpoints:
 | 
			
		||||
//   - /healthz - liveness probe (or general health if no specific handler)
 | 
			
		||||
//   - /readyz - readiness probe (or general health if no specific handler)
 | 
			
		||||
//   - /startupz - startup probe (or general health if no specific handler)
 | 
			
		||||
//   - /__health - general health endpoint (backward compatibility)
 | 
			
		||||
//   - / - general health endpoint (root path)
 | 
			
		||||
package health
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"log/slog"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"strconv"
 | 
			
		||||
@@ -11,32 +75,108 @@ import (
 | 
			
		||||
	"golang.org/x/sync/errgroup"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// Server is a standalone HTTP server dedicated to health checks.
 | 
			
		||||
// It runs separately from the main application server to ensure health
 | 
			
		||||
// checks remain available even if the main server is experiencing issues.
 | 
			
		||||
//
 | 
			
		||||
// The server supports separate handlers for different Kubernetes probe types
 | 
			
		||||
// (liveness, readiness, startup) and includes built-in timeouts, graceful
 | 
			
		||||
// shutdown, and structured logging.
 | 
			
		||||
type Server struct {
 | 
			
		||||
	log              *slog.Logger
 | 
			
		||||
	healthFn http.HandlerFunc
 | 
			
		||||
	livenessHandler  http.HandlerFunc
 | 
			
		||||
	readinessHandler http.HandlerFunc
 | 
			
		||||
	startupHandler   http.HandlerFunc
 | 
			
		||||
	generalHandler   http.HandlerFunc // fallback for /__health and / paths
 | 
			
		||||
	serviceName      string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewServer(healthFn http.HandlerFunc) *Server {
 | 
			
		||||
// Option represents a configuration option for the health server.
 | 
			
		||||
type Option func(*Server)
 | 
			
		||||
 | 
			
		||||
// WithLivenessHandler sets a specific handler for the /healthz endpoint.
 | 
			
		||||
// Liveness probes determine if a container should be restarted.
 | 
			
		||||
func WithLivenessHandler(handler http.HandlerFunc) Option {
 | 
			
		||||
	return func(s *Server) {
 | 
			
		||||
		s.livenessHandler = handler
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// WithReadinessHandler sets a specific handler for the /readyz endpoint.
 | 
			
		||||
// Readiness probes determine if a container can receive traffic.
 | 
			
		||||
func WithReadinessHandler(handler http.HandlerFunc) Option {
 | 
			
		||||
	return func(s *Server) {
 | 
			
		||||
		s.readinessHandler = handler
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// WithStartupHandler sets a specific handler for the /startupz endpoint.
 | 
			
		||||
// Startup probes determine if a container has finished initializing.
 | 
			
		||||
func WithStartupHandler(handler http.HandlerFunc) Option {
 | 
			
		||||
	return func(s *Server) {
 | 
			
		||||
		s.startupHandler = handler
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// WithServiceName sets the service name for JSON responses and logging.
 | 
			
		||||
func WithServiceName(serviceName string) Option {
 | 
			
		||||
	return func(s *Server) {
 | 
			
		||||
		s.serviceName = serviceName
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
// NewServer creates a new health check server with optional probe-specific handlers.
 | 
			
		||||
//
 | 
			
		||||
// If healthFn is provided, it will be used as a fallback for any probe endpoints
 | 
			
		||||
// that don't have specific handlers configured. If healthFn is nil, a default
 | 
			
		||||
// handler that returns HTTP 200 "ok" is used as the fallback.
 | 
			
		||||
//
 | 
			
		||||
// Use the With* option functions to configure specific handlers for different
 | 
			
		||||
// probe types (liveness, readiness, startup).
 | 
			
		||||
func NewServer(healthFn http.HandlerFunc, opts ...Option) *Server {
 | 
			
		||||
	if healthFn == nil {
 | 
			
		||||
		healthFn = basicHealth
 | 
			
		||||
	}
 | 
			
		||||
	srv := &Server{
 | 
			
		||||
		log:            logger.Setup(),
 | 
			
		||||
		healthFn: healthFn,
 | 
			
		||||
		generalHandler: healthFn,
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for _, opt := range opts {
 | 
			
		||||
		opt(srv)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return srv
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// SetLogger replaces the default logger with a custom one.
 | 
			
		||||
func (srv *Server) SetLogger(log *slog.Logger) {
 | 
			
		||||
	srv.log = log
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Listen starts the health server on the specified port and blocks until ctx is cancelled.
 | 
			
		||||
// The server exposes health check endpoints with graceful shutdown support.
 | 
			
		||||
//
 | 
			
		||||
// Standard endpoints exposed:
 | 
			
		||||
//   - /healthz - liveness probe (uses livenessHandler or falls back to generalHandler)
 | 
			
		||||
//   - /readyz - readiness probe (uses readinessHandler or falls back to generalHandler)
 | 
			
		||||
//   - /startupz - startup probe (uses startupHandler or falls back to generalHandler)
 | 
			
		||||
//   - /__health - general health endpoint (uses generalHandler)
 | 
			
		||||
//   - / - root health endpoint (uses generalHandler)
 | 
			
		||||
func (srv *Server) Listen(ctx context.Context, port int) error {
 | 
			
		||||
	srv.log.Info("starting health listener", "port", port)
 | 
			
		||||
 | 
			
		||||
	serveMux := http.NewServeMux()
 | 
			
		||||
 | 
			
		||||
	serveMux.HandleFunc("/__health", srv.healthFn)
 | 
			
		||||
	// Register probe-specific handlers
 | 
			
		||||
	serveMux.HandleFunc("/healthz", srv.createProbeHandler("liveness"))
 | 
			
		||||
	serveMux.HandleFunc("/readyz", srv.createProbeHandler("readiness"))
 | 
			
		||||
	serveMux.HandleFunc("/startupz", srv.createProbeHandler("startup"))
 | 
			
		||||
 | 
			
		||||
	// Register general health endpoints for backward compatibility
 | 
			
		||||
	serveMux.HandleFunc("/__health", srv.createGeneralHandler())
 | 
			
		||||
	serveMux.HandleFunc("/", srv.createGeneralHandler())
 | 
			
		||||
 | 
			
		||||
	hsrv := &http.Server{
 | 
			
		||||
		Addr:         ":" + strconv.Itoa(port),
 | 
			
		||||
@@ -72,8 +212,122 @@ func (srv *Server) Listen(ctx context.Context, port int) error {
 | 
			
		||||
	return g.Wait()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// HealthCheckListener runs simple http server on the specified port for
 | 
			
		||||
// health check probes
 | 
			
		||||
// createProbeHandler creates a handler for a specific probe type that provides
 | 
			
		||||
// appropriate JSON responses and falls back to the general handler if no specific
 | 
			
		||||
// handler is configured.
 | 
			
		||||
func (srv *Server) createProbeHandler(probeType string) http.HandlerFunc {
 | 
			
		||||
	return func(w http.ResponseWriter, r *http.Request) {
 | 
			
		||||
		var handler http.HandlerFunc
 | 
			
		||||
 | 
			
		||||
		// Select the appropriate handler
 | 
			
		||||
		switch probeType {
 | 
			
		||||
		case "liveness":
 | 
			
		||||
			handler = srv.livenessHandler
 | 
			
		||||
		case "readiness":
 | 
			
		||||
			handler = srv.readinessHandler
 | 
			
		||||
		case "startup":
 | 
			
		||||
			handler = srv.startupHandler
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// Fall back to general handler if no specific handler is configured
 | 
			
		||||
		if handler == nil {
 | 
			
		||||
			handler = srv.generalHandler
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// Create a response recorder to capture the handler's status code
 | 
			
		||||
		recorder := &statusRecorder{ResponseWriter: w, statusCode: 200}
 | 
			
		||||
		handler(recorder, r)
 | 
			
		||||
 | 
			
		||||
		// If the handler already wrote a response, we're done
 | 
			
		||||
		if recorder.written {
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// Otherwise, provide a standard JSON response based on the status code
 | 
			
		||||
		w.Header().Set("Content-Type", "application/json")
 | 
			
		||||
 | 
			
		||||
		if recorder.statusCode >= 400 {
 | 
			
		||||
			// Handler indicated unhealthy
 | 
			
		||||
			switch probeType {
 | 
			
		||||
			case "liveness":
 | 
			
		||||
				json.NewEncoder(w).Encode(map[string]string{"status": "unhealthy"})
 | 
			
		||||
			case "readiness":
 | 
			
		||||
				json.NewEncoder(w).Encode(map[string]bool{"ready": false})
 | 
			
		||||
			case "startup":
 | 
			
		||||
				json.NewEncoder(w).Encode(map[string]bool{"started": false})
 | 
			
		||||
			}
 | 
			
		||||
		} else {
 | 
			
		||||
			// Handler indicated healthy
 | 
			
		||||
			switch probeType {
 | 
			
		||||
			case "liveness":
 | 
			
		||||
				json.NewEncoder(w).Encode(map[string]string{"status": "alive"})
 | 
			
		||||
			case "readiness":
 | 
			
		||||
				json.NewEncoder(w).Encode(map[string]bool{"ready": true})
 | 
			
		||||
			case "startup":
 | 
			
		||||
				json.NewEncoder(w).Encode(map[string]bool{"started": true})
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// createGeneralHandler creates a handler for general health endpoints that provides
 | 
			
		||||
// comprehensive health information.
 | 
			
		||||
func (srv *Server) createGeneralHandler() http.HandlerFunc {
 | 
			
		||||
	return func(w http.ResponseWriter, r *http.Request) {
 | 
			
		||||
		// Create a response recorder to capture the handler's status code
 | 
			
		||||
		// Use a buffer to prevent the handler from writing to the actual response
 | 
			
		||||
		recorder := &statusRecorder{ResponseWriter: &discardWriter{}, statusCode: 200}
 | 
			
		||||
		srv.generalHandler(recorder, r)
 | 
			
		||||
 | 
			
		||||
		// Always provide a comprehensive JSON response for general endpoints
 | 
			
		||||
		w.Header().Set("Content-Type", "application/json")
 | 
			
		||||
		w.WriteHeader(recorder.statusCode)
 | 
			
		||||
 | 
			
		||||
		response := map[string]interface{}{
 | 
			
		||||
			"status": map[bool]string{true: "healthy", false: "unhealthy"}[recorder.statusCode < 400],
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if srv.serviceName != "" {
 | 
			
		||||
			response["service"] = srv.serviceName
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		json.NewEncoder(w).Encode(response)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// statusRecorder captures the response status code from handlers while allowing
 | 
			
		||||
// them to write their own response content if needed.
 | 
			
		||||
type statusRecorder struct {
 | 
			
		||||
	http.ResponseWriter
 | 
			
		||||
	statusCode int
 | 
			
		||||
	written    bool
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *statusRecorder) WriteHeader(code int) {
 | 
			
		||||
	r.statusCode = code
 | 
			
		||||
	r.ResponseWriter.WriteHeader(code)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *statusRecorder) Write(data []byte) (int, error) {
 | 
			
		||||
	r.written = true
 | 
			
		||||
	return r.ResponseWriter.Write(data)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// discardWriter implements http.ResponseWriter but discards all writes.
 | 
			
		||||
// Used to capture status codes without writing response content.
 | 
			
		||||
type discardWriter struct{}
 | 
			
		||||
 | 
			
		||||
func (d *discardWriter) Header() http.Header {
 | 
			
		||||
	return make(http.Header)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (d *discardWriter) Write([]byte) (int, error) {
 | 
			
		||||
	return 0, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (d *discardWriter) WriteHeader(int) {}
 | 
			
		||||
 | 
			
		||||
// HealthCheckListener runs a simple HTTP server on the specified port for health check probes.
 | 
			
		||||
func HealthCheckListener(ctx context.Context, port int, log *slog.Logger) error {
 | 
			
		||||
	srv := NewServer(nil)
 | 
			
		||||
	srv.SetLogger(log)
 | 
			
		||||
 
 | 
			
		||||
@@ -1,13 +1,14 @@
 | 
			
		||||
package health
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"net/http/httptest"
 | 
			
		||||
	"testing"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func TestHealthHandler(t *testing.T) {
 | 
			
		||||
func TestBasicHealthHandler(t *testing.T) {
 | 
			
		||||
	req := httptest.NewRequest(http.MethodGet, "/__health", nil)
 | 
			
		||||
	w := httptest.NewRecorder()
 | 
			
		||||
 | 
			
		||||
@@ -24,3 +25,129 @@ func TestHealthHandler(t *testing.T) {
 | 
			
		||||
		t.Errorf("expected ok got %q", string(data))
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestProbeHandlers(t *testing.T) {
 | 
			
		||||
	// Test with separate handlers for each probe type
 | 
			
		||||
	srv := NewServer(nil,
 | 
			
		||||
		WithLivenessHandler(func(w http.ResponseWriter, r *http.Request) {
 | 
			
		||||
			w.WriteHeader(http.StatusOK)
 | 
			
		||||
		}),
 | 
			
		||||
		WithReadinessHandler(func(w http.ResponseWriter, r *http.Request) {
 | 
			
		||||
			w.WriteHeader(http.StatusOK)
 | 
			
		||||
		}),
 | 
			
		||||
		WithStartupHandler(func(w http.ResponseWriter, r *http.Request) {
 | 
			
		||||
			w.WriteHeader(http.StatusOK)
 | 
			
		||||
		}),
 | 
			
		||||
		WithServiceName("test-service"),
 | 
			
		||||
	)
 | 
			
		||||
 | 
			
		||||
	tests := []struct {
 | 
			
		||||
		handler        http.HandlerFunc
 | 
			
		||||
		expectedStatus int
 | 
			
		||||
		expectedBody   string
 | 
			
		||||
	}{
 | 
			
		||||
		{srv.createProbeHandler("liveness"), 200, `{"status":"alive"}`},
 | 
			
		||||
		{srv.createProbeHandler("readiness"), 200, `{"ready":true}`},
 | 
			
		||||
		{srv.createProbeHandler("startup"), 200, `{"started":true}`},
 | 
			
		||||
		{srv.createGeneralHandler(), 200, `{"service":"test-service","status":"healthy"}`},
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for i, tt := range tests {
 | 
			
		||||
		t.Run(fmt.Sprintf("test_%d", i), func(t *testing.T) {
 | 
			
		||||
			req := httptest.NewRequest(http.MethodGet, "/", nil)
 | 
			
		||||
			w := httptest.NewRecorder()
 | 
			
		||||
 | 
			
		||||
			tt.handler(w, req)
 | 
			
		||||
 | 
			
		||||
			if w.Code != tt.expectedStatus {
 | 
			
		||||
				t.Errorf("expected status %d, got %d", tt.expectedStatus, w.Code)
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			body := w.Body.String()
 | 
			
		||||
			if body != tt.expectedBody+"\n" { // json.Encoder adds newline
 | 
			
		||||
				t.Errorf("expected body %q, got %q", tt.expectedBody, body)
 | 
			
		||||
			}
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestProbeHandlerFallback(t *testing.T) {
 | 
			
		||||
	// Test fallback to general handler when no specific handler is configured
 | 
			
		||||
	generalHandler := func(w http.ResponseWriter, r *http.Request) {
 | 
			
		||||
		w.WriteHeader(http.StatusOK)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	srv := NewServer(generalHandler, WithServiceName("test-service"))
 | 
			
		||||
 | 
			
		||||
	tests := []struct {
 | 
			
		||||
		handler        http.HandlerFunc
 | 
			
		||||
		expectedStatus int
 | 
			
		||||
		expectedBody   string
 | 
			
		||||
	}{
 | 
			
		||||
		{srv.createProbeHandler("liveness"), 200, `{"status":"alive"}`},
 | 
			
		||||
		{srv.createProbeHandler("readiness"), 200, `{"ready":true}`},
 | 
			
		||||
		{srv.createProbeHandler("startup"), 200, `{"started":true}`},
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for i, tt := range tests {
 | 
			
		||||
		t.Run(fmt.Sprintf("fallback_%d", i), func(t *testing.T) {
 | 
			
		||||
			req := httptest.NewRequest(http.MethodGet, "/", nil)
 | 
			
		||||
			w := httptest.NewRecorder()
 | 
			
		||||
 | 
			
		||||
			tt.handler(w, req)
 | 
			
		||||
 | 
			
		||||
			if w.Code != tt.expectedStatus {
 | 
			
		||||
				t.Errorf("expected status %d, got %d", tt.expectedStatus, w.Code)
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			body := w.Body.String()
 | 
			
		||||
			if body != tt.expectedBody+"\n" { // json.Encoder adds newline
 | 
			
		||||
				t.Errorf("expected body %q, got %q", tt.expectedBody, body)
 | 
			
		||||
			}
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestUnhealthyProbeHandlers(t *testing.T) {
 | 
			
		||||
	// Test with handlers that return unhealthy status
 | 
			
		||||
	srv := NewServer(nil,
 | 
			
		||||
		WithLivenessHandler(func(w http.ResponseWriter, r *http.Request) {
 | 
			
		||||
			w.WriteHeader(http.StatusServiceUnavailable)
 | 
			
		||||
		}),
 | 
			
		||||
		WithReadinessHandler(func(w http.ResponseWriter, r *http.Request) {
 | 
			
		||||
			w.WriteHeader(http.StatusServiceUnavailable)
 | 
			
		||||
		}),
 | 
			
		||||
		WithStartupHandler(func(w http.ResponseWriter, r *http.Request) {
 | 
			
		||||
			w.WriteHeader(http.StatusServiceUnavailable)
 | 
			
		||||
		}),
 | 
			
		||||
		WithServiceName("test-service"),
 | 
			
		||||
	)
 | 
			
		||||
 | 
			
		||||
	tests := []struct {
 | 
			
		||||
		handler        http.HandlerFunc
 | 
			
		||||
		expectedStatus int
 | 
			
		||||
		expectedBody   string
 | 
			
		||||
	}{
 | 
			
		||||
		{srv.createProbeHandler("liveness"), 503, `{"status":"unhealthy"}`},
 | 
			
		||||
		{srv.createProbeHandler("readiness"), 503, `{"ready":false}`},
 | 
			
		||||
		{srv.createProbeHandler("startup"), 503, `{"started":false}`},
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for i, tt := range tests {
 | 
			
		||||
		t.Run(fmt.Sprintf("unhealthy_%d", i), func(t *testing.T) {
 | 
			
		||||
			req := httptest.NewRequest(http.MethodGet, "/", nil)
 | 
			
		||||
			w := httptest.NewRecorder()
 | 
			
		||||
 | 
			
		||||
			tt.handler(w, req)
 | 
			
		||||
 | 
			
		||||
			if w.Code != tt.expectedStatus {
 | 
			
		||||
				t.Errorf("expected status %d, got %d", tt.expectedStatus, w.Code)
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			body := w.Body.String()
 | 
			
		||||
			if body != tt.expectedBody+"\n" { // json.Encoder adds newline
 | 
			
		||||
				t.Errorf("expected body %q, got %q", tt.expectedBody, body)
 | 
			
		||||
			}
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										378
									
								
								internal/tracerconfig/config.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										378
									
								
								internal/tracerconfig/config.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,378 @@
 | 
			
		||||
// Package tracerconfig provides a bridge to eliminate circular dependencies between
 | 
			
		||||
// the logger and tracing packages. It stores tracer configuration and provides
 | 
			
		||||
// factory functions that can be used by the logger package without importing tracing.
 | 
			
		||||
package tracerconfig
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"crypto/tls"
 | 
			
		||||
	"crypto/x509"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"net/url"
 | 
			
		||||
	"os"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"sync"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"go.opentelemetry.io/otel/exporters/otlp/otlplog/otlploggrpc"
 | 
			
		||||
	"go.opentelemetry.io/otel/exporters/otlp/otlplog/otlploghttp"
 | 
			
		||||
	"go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc"
 | 
			
		||||
	"go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetrichttp"
 | 
			
		||||
	"go.opentelemetry.io/otel/exporters/otlp/otlptrace"
 | 
			
		||||
	"go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc"
 | 
			
		||||
	"go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp"
 | 
			
		||||
	sdklog "go.opentelemetry.io/otel/sdk/log"
 | 
			
		||||
	sdkmetric "go.opentelemetry.io/otel/sdk/metric"
 | 
			
		||||
	sdktrace "go.opentelemetry.io/otel/sdk/trace"
 | 
			
		||||
	"google.golang.org/grpc/credentials"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	otelExporterOTLPProtoEnvKey        = "OTEL_EXPORTER_OTLP_PROTOCOL"
 | 
			
		||||
	otelExporterOTLPTracesProtoEnvKey  = "OTEL_EXPORTER_OTLP_TRACES_PROTOCOL"
 | 
			
		||||
	otelExporterOTLPLogsProtoEnvKey    = "OTEL_EXPORTER_OTLP_LOGS_PROTOCOL"
 | 
			
		||||
	otelExporterOTLPMetricsProtoEnvKey = "OTEL_EXPORTER_OTLP_METRICS_PROTOCOL"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var errInvalidOTLPProtocol = errors.New("invalid OTLP protocol - should be one of ['grpc', 'http/protobuf']")
 | 
			
		||||
 | 
			
		||||
// newInvalidProtocolError creates a specific error message for invalid protocols
 | 
			
		||||
func newInvalidProtocolError(protocol, signalType string) error {
 | 
			
		||||
	return fmt.Errorf("invalid OTLP protocol '%s' for %s - should be one of ['grpc', 'http/protobuf', 'http/json']", protocol, signalType)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Validate checks the configuration for common errors and inconsistencies
 | 
			
		||||
func (c *Config) Validate() error {
 | 
			
		||||
	var errs []error
 | 
			
		||||
 | 
			
		||||
	// Check that both Endpoint and EndpointURL are not specified
 | 
			
		||||
	if c.Endpoint != "" && c.EndpointURL != "" {
 | 
			
		||||
		errs = append(errs, errors.New("cannot specify both Endpoint and EndpointURL - use one or the other"))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Validate EndpointURL format if specified
 | 
			
		||||
	if c.EndpointURL != "" {
 | 
			
		||||
		if _, err := url.Parse(c.EndpointURL); err != nil {
 | 
			
		||||
			errs = append(errs, fmt.Errorf("invalid EndpointURL format: %w", err))
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Validate Endpoint format if specified
 | 
			
		||||
	if c.Endpoint != "" {
 | 
			
		||||
		// Basic validation - should not contain protocol scheme
 | 
			
		||||
		if strings.Contains(c.Endpoint, "://") {
 | 
			
		||||
			errs = append(errs, errors.New("Endpoint should not include protocol scheme (use EndpointURL for full URLs)"))
 | 
			
		||||
		}
 | 
			
		||||
		// Should not be empty after trimming whitespace
 | 
			
		||||
		if strings.TrimSpace(c.Endpoint) == "" {
 | 
			
		||||
			errs = append(errs, errors.New("Endpoint cannot be empty or whitespace"))
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Validate TLS configuration consistency
 | 
			
		||||
	if c.CertificateProvider != nil && c.RootCAs == nil {
 | 
			
		||||
		// This is just a warning - client cert without custom CAs is valid
 | 
			
		||||
		// but might indicate a configuration issue
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Validate service name if specified
 | 
			
		||||
	if c.ServiceName != "" && strings.TrimSpace(c.ServiceName) == "" {
 | 
			
		||||
		errs = append(errs, errors.New("ServiceName cannot be empty or whitespace"))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Combine all errors
 | 
			
		||||
	if len(errs) > 0 {
 | 
			
		||||
		var errMsgs []string
 | 
			
		||||
		for _, err := range errs {
 | 
			
		||||
			errMsgs = append(errMsgs, err.Error())
 | 
			
		||||
		}
 | 
			
		||||
		return fmt.Errorf("configuration validation failed: %s", strings.Join(errMsgs, "; "))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ValidateAndStore validates the configuration before storing it
 | 
			
		||||
func ValidateAndStore(ctx context.Context, cfg *Config, logFactory LogExporterFactory, metricFactory MetricExporterFactory, traceFactory TraceExporterFactory) error {
 | 
			
		||||
	if cfg != nil {
 | 
			
		||||
		if err := cfg.Validate(); err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	Store(ctx, cfg, logFactory, metricFactory, traceFactory)
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// GetClientCertificate defines a function type for providing client certificates for mutual TLS.
 | 
			
		||||
// This is used when exporting telemetry data to secured OTLP endpoints that require
 | 
			
		||||
// client certificate authentication.
 | 
			
		||||
type GetClientCertificate func(*tls.CertificateRequestInfo) (*tls.Certificate, error)
 | 
			
		||||
 | 
			
		||||
// Config provides configuration options for OpenTelemetry tracing setup.
 | 
			
		||||
// It supplements standard OpenTelemetry environment variables with additional
 | 
			
		||||
// NTP Pool-specific configuration including TLS settings for secure OTLP export.
 | 
			
		||||
type Config struct {
 | 
			
		||||
	ServiceName         string               // Service name for resource identification (overrides OTEL_SERVICE_NAME)
 | 
			
		||||
	Environment         string               // Deployment environment (development, staging, production)
 | 
			
		||||
	Endpoint            string               // OTLP endpoint hostname/port (e.g., "otlp.example.com:4317")
 | 
			
		||||
	EndpointURL         string               // Complete OTLP endpoint URL (e.g., "https://otlp.example.com:4317/v1/traces")
 | 
			
		||||
	CertificateProvider GetClientCertificate // Client certificate provider for mutual TLS
 | 
			
		||||
	RootCAs             *x509.CertPool       // CA certificate pool for server verification
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// LogExporterFactory creates an OTLP log exporter using the provided configuration.
 | 
			
		||||
// This allows the logger package to create exporters without importing the tracing package.
 | 
			
		||||
type LogExporterFactory func(context.Context, *Config) (sdklog.Exporter, error)
 | 
			
		||||
 | 
			
		||||
// MetricExporterFactory creates an OTLP metric exporter using the provided configuration.
 | 
			
		||||
// This allows the metrics package to create exporters without importing the tracing package.
 | 
			
		||||
type MetricExporterFactory func(context.Context, *Config) (sdkmetric.Exporter, error)
 | 
			
		||||
 | 
			
		||||
// TraceExporterFactory creates an OTLP trace exporter using the provided configuration.
 | 
			
		||||
// This allows for consistent trace exporter creation across packages.
 | 
			
		||||
type TraceExporterFactory func(context.Context, *Config) (sdktrace.SpanExporter, error)
 | 
			
		||||
 | 
			
		||||
// Global state for sharing configuration between packages
 | 
			
		||||
var (
 | 
			
		||||
	globalConfig          *Config
 | 
			
		||||
	globalContext         context.Context
 | 
			
		||||
	logExporterFactory    LogExporterFactory
 | 
			
		||||
	metricExporterFactory MetricExporterFactory
 | 
			
		||||
	traceExporterFactory  TraceExporterFactory
 | 
			
		||||
	configMu              sync.RWMutex
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// Store saves the tracer configuration and exporter factories for use by other packages.
 | 
			
		||||
// This should be called by the tracing package during initialization.
 | 
			
		||||
func Store(ctx context.Context, cfg *Config, logFactory LogExporterFactory, metricFactory MetricExporterFactory, traceFactory TraceExporterFactory) {
 | 
			
		||||
	configMu.Lock()
 | 
			
		||||
	defer configMu.Unlock()
 | 
			
		||||
	globalConfig = cfg
 | 
			
		||||
	globalContext = ctx
 | 
			
		||||
	logExporterFactory = logFactory
 | 
			
		||||
	metricExporterFactory = metricFactory
 | 
			
		||||
	traceExporterFactory = traceFactory
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// GetLogExporter returns the stored configuration and log exporter factory.
 | 
			
		||||
// Returns nil values if no configuration has been stored yet.
 | 
			
		||||
func GetLogExporter() (*Config, context.Context, LogExporterFactory) {
 | 
			
		||||
	configMu.RLock()
 | 
			
		||||
	defer configMu.RUnlock()
 | 
			
		||||
	return globalConfig, globalContext, logExporterFactory
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// GetMetricExporter returns the stored configuration and metric exporter factory.
 | 
			
		||||
// Returns nil values if no configuration has been stored yet.
 | 
			
		||||
func GetMetricExporter() (*Config, context.Context, MetricExporterFactory) {
 | 
			
		||||
	configMu.RLock()
 | 
			
		||||
	defer configMu.RUnlock()
 | 
			
		||||
	return globalConfig, globalContext, metricExporterFactory
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// GetTraceExporter returns the stored configuration and trace exporter factory.
 | 
			
		||||
// Returns nil values if no configuration has been stored yet.
 | 
			
		||||
func GetTraceExporter() (*Config, context.Context, TraceExporterFactory) {
 | 
			
		||||
	configMu.RLock()
 | 
			
		||||
	defer configMu.RUnlock()
 | 
			
		||||
	return globalConfig, globalContext, traceExporterFactory
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Get returns the stored tracer configuration, context, and log exporter factory.
 | 
			
		||||
// This maintains backward compatibility for the logger package.
 | 
			
		||||
// Returns nil values if no configuration has been stored yet.
 | 
			
		||||
func Get() (*Config, context.Context, LogExporterFactory) {
 | 
			
		||||
	return GetLogExporter()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// IsConfigured returns true if tracer configuration has been stored.
 | 
			
		||||
func IsConfigured() bool {
 | 
			
		||||
	configMu.RLock()
 | 
			
		||||
	defer configMu.RUnlock()
 | 
			
		||||
	return globalConfig != nil && globalContext != nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Clear removes the stored configuration. This is primarily useful for testing.
 | 
			
		||||
func Clear() {
 | 
			
		||||
	configMu.Lock()
 | 
			
		||||
	defer configMu.Unlock()
 | 
			
		||||
	globalConfig = nil
 | 
			
		||||
	globalContext = nil
 | 
			
		||||
	logExporterFactory = nil
 | 
			
		||||
	metricExporterFactory = nil
 | 
			
		||||
	traceExporterFactory = nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// getTLSConfig creates a TLS configuration from the provided Config.
 | 
			
		||||
func getTLSConfig(cfg *Config) *tls.Config {
 | 
			
		||||
	if cfg.CertificateProvider == nil {
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
	return &tls.Config{
 | 
			
		||||
		GetClientCertificate: cfg.CertificateProvider,
 | 
			
		||||
		RootCAs:              cfg.RootCAs,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// getProtocol determines the OTLP protocol to use for the given signal type.
 | 
			
		||||
// It follows OpenTelemetry environment variable precedence.
 | 
			
		||||
func getProtocol(signalSpecificEnv string) string {
 | 
			
		||||
	proto := os.Getenv(signalSpecificEnv)
 | 
			
		||||
	if proto == "" {
 | 
			
		||||
		proto = os.Getenv(otelExporterOTLPProtoEnvKey)
 | 
			
		||||
	}
 | 
			
		||||
	// Fallback to default, http/protobuf.
 | 
			
		||||
	if proto == "" {
 | 
			
		||||
		proto = "http/protobuf"
 | 
			
		||||
	}
 | 
			
		||||
	return proto
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// CreateOTLPLogExporter creates an OTLP log exporter using the provided configuration.
 | 
			
		||||
func CreateOTLPLogExporter(ctx context.Context, cfg *Config) (sdklog.Exporter, error) {
 | 
			
		||||
	tlsConfig := getTLSConfig(cfg)
 | 
			
		||||
	proto := getProtocol(otelExporterOTLPLogsProtoEnvKey)
 | 
			
		||||
 | 
			
		||||
	switch proto {
 | 
			
		||||
	case "grpc":
 | 
			
		||||
		opts := []otlploggrpc.Option{
 | 
			
		||||
			otlploggrpc.WithCompressor("gzip"),
 | 
			
		||||
		}
 | 
			
		||||
		if tlsConfig != nil {
 | 
			
		||||
			opts = append(opts, otlploggrpc.WithTLSCredentials(credentials.NewTLS(tlsConfig)))
 | 
			
		||||
		}
 | 
			
		||||
		if len(cfg.Endpoint) > 0 {
 | 
			
		||||
			opts = append(opts, otlploggrpc.WithEndpoint(cfg.Endpoint))
 | 
			
		||||
		}
 | 
			
		||||
		if len(cfg.EndpointURL) > 0 {
 | 
			
		||||
			opts = append(opts, otlploggrpc.WithEndpointURL(cfg.EndpointURL))
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		return otlploggrpc.New(ctx, opts...)
 | 
			
		||||
	case "http/protobuf", "http/json":
 | 
			
		||||
		opts := []otlploghttp.Option{
 | 
			
		||||
			otlploghttp.WithCompression(otlploghttp.GzipCompression),
 | 
			
		||||
		}
 | 
			
		||||
		if tlsConfig != nil {
 | 
			
		||||
			opts = append(opts, otlploghttp.WithTLSClientConfig(tlsConfig))
 | 
			
		||||
		}
 | 
			
		||||
		if len(cfg.Endpoint) > 0 {
 | 
			
		||||
			opts = append(opts, otlploghttp.WithEndpoint(cfg.Endpoint))
 | 
			
		||||
		}
 | 
			
		||||
		if len(cfg.EndpointURL) > 0 {
 | 
			
		||||
			opts = append(opts, otlploghttp.WithEndpointURL(cfg.EndpointURL))
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		opts = append(opts, otlploghttp.WithRetry(otlploghttp.RetryConfig{
 | 
			
		||||
			Enabled:         true,
 | 
			
		||||
			InitialInterval: 3 * time.Second,
 | 
			
		||||
			MaxInterval:     60 * time.Second,
 | 
			
		||||
			MaxElapsedTime:  5 * time.Minute,
 | 
			
		||||
		}))
 | 
			
		||||
 | 
			
		||||
		return otlploghttp.New(ctx, opts...)
 | 
			
		||||
	default:
 | 
			
		||||
		return nil, newInvalidProtocolError(proto, "logs")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// CreateOTLPMetricExporter creates an OTLP metric exporter using the provided configuration.
 | 
			
		||||
func CreateOTLPMetricExporter(ctx context.Context, cfg *Config) (sdkmetric.Exporter, error) {
 | 
			
		||||
	tlsConfig := getTLSConfig(cfg)
 | 
			
		||||
	proto := getProtocol(otelExporterOTLPMetricsProtoEnvKey)
 | 
			
		||||
 | 
			
		||||
	switch proto {
 | 
			
		||||
	case "grpc":
 | 
			
		||||
		opts := []otlpmetricgrpc.Option{
 | 
			
		||||
			otlpmetricgrpc.WithCompressor("gzip"),
 | 
			
		||||
		}
 | 
			
		||||
		if tlsConfig != nil {
 | 
			
		||||
			opts = append(opts, otlpmetricgrpc.WithTLSCredentials(credentials.NewTLS(tlsConfig)))
 | 
			
		||||
		}
 | 
			
		||||
		if len(cfg.Endpoint) > 0 {
 | 
			
		||||
			opts = append(opts, otlpmetricgrpc.WithEndpoint(cfg.Endpoint))
 | 
			
		||||
		}
 | 
			
		||||
		if len(cfg.EndpointURL) > 0 {
 | 
			
		||||
			opts = append(opts, otlpmetricgrpc.WithEndpointURL(cfg.EndpointURL))
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		return otlpmetricgrpc.New(ctx, opts...)
 | 
			
		||||
	case "http/protobuf", "http/json":
 | 
			
		||||
		opts := []otlpmetrichttp.Option{
 | 
			
		||||
			otlpmetrichttp.WithCompression(otlpmetrichttp.GzipCompression),
 | 
			
		||||
		}
 | 
			
		||||
		if tlsConfig != nil {
 | 
			
		||||
			opts = append(opts, otlpmetrichttp.WithTLSClientConfig(tlsConfig))
 | 
			
		||||
		}
 | 
			
		||||
		if len(cfg.Endpoint) > 0 {
 | 
			
		||||
			opts = append(opts, otlpmetrichttp.WithEndpoint(cfg.Endpoint))
 | 
			
		||||
		}
 | 
			
		||||
		if len(cfg.EndpointURL) > 0 {
 | 
			
		||||
			opts = append(opts, otlpmetrichttp.WithEndpointURL(cfg.EndpointURL))
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		opts = append(opts, otlpmetrichttp.WithRetry(otlpmetrichttp.RetryConfig{
 | 
			
		||||
			Enabled:         true,
 | 
			
		||||
			InitialInterval: 3 * time.Second,
 | 
			
		||||
			MaxInterval:     60 * time.Second,
 | 
			
		||||
			MaxElapsedTime:  5 * time.Minute,
 | 
			
		||||
		}))
 | 
			
		||||
 | 
			
		||||
		return otlpmetrichttp.New(ctx, opts...)
 | 
			
		||||
	default:
 | 
			
		||||
		return nil, newInvalidProtocolError(proto, "metrics")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// CreateOTLPTraceExporter creates an OTLP trace exporter using the provided configuration.
 | 
			
		||||
func CreateOTLPTraceExporter(ctx context.Context, cfg *Config) (sdktrace.SpanExporter, error) {
 | 
			
		||||
	tlsConfig := getTLSConfig(cfg)
 | 
			
		||||
	proto := getProtocol(otelExporterOTLPTracesProtoEnvKey)
 | 
			
		||||
 | 
			
		||||
	var client otlptrace.Client
 | 
			
		||||
 | 
			
		||||
	switch proto {
 | 
			
		||||
	case "grpc":
 | 
			
		||||
		opts := []otlptracegrpc.Option{
 | 
			
		||||
			otlptracegrpc.WithCompressor("gzip"),
 | 
			
		||||
		}
 | 
			
		||||
		if tlsConfig != nil {
 | 
			
		||||
			opts = append(opts, otlptracegrpc.WithTLSCredentials(credentials.NewTLS(tlsConfig)))
 | 
			
		||||
		}
 | 
			
		||||
		if len(cfg.Endpoint) > 0 {
 | 
			
		||||
			opts = append(opts, otlptracegrpc.WithEndpoint(cfg.Endpoint))
 | 
			
		||||
		}
 | 
			
		||||
		if len(cfg.EndpointURL) > 0 {
 | 
			
		||||
			opts = append(opts, otlptracegrpc.WithEndpointURL(cfg.EndpointURL))
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		client = otlptracegrpc.NewClient(opts...)
 | 
			
		||||
	case "http/protobuf", "http/json":
 | 
			
		||||
		opts := []otlptracehttp.Option{
 | 
			
		||||
			otlptracehttp.WithCompression(otlptracehttp.GzipCompression),
 | 
			
		||||
		}
 | 
			
		||||
		if tlsConfig != nil {
 | 
			
		||||
			opts = append(opts, otlptracehttp.WithTLSClientConfig(tlsConfig))
 | 
			
		||||
		}
 | 
			
		||||
		if len(cfg.Endpoint) > 0 {
 | 
			
		||||
			opts = append(opts, otlptracehttp.WithEndpoint(cfg.Endpoint))
 | 
			
		||||
		}
 | 
			
		||||
		if len(cfg.EndpointURL) > 0 {
 | 
			
		||||
			opts = append(opts, otlptracehttp.WithEndpointURL(cfg.EndpointURL))
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		opts = append(opts, otlptracehttp.WithRetry(otlptracehttp.RetryConfig{
 | 
			
		||||
			Enabled:         true,
 | 
			
		||||
			InitialInterval: 3 * time.Second,
 | 
			
		||||
			MaxInterval:     60 * time.Second,
 | 
			
		||||
			MaxElapsedTime:  5 * time.Minute,
 | 
			
		||||
		}))
 | 
			
		||||
 | 
			
		||||
		client = otlptracehttp.NewClient(opts...)
 | 
			
		||||
	default:
 | 
			
		||||
		return nil, newInvalidProtocolError(proto, "traces")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return otlptrace.New(ctx, client)
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										474
									
								
								internal/tracerconfig/config_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										474
									
								
								internal/tracerconfig/config_test.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,474 @@
 | 
			
		||||
package tracerconfig
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"crypto/tls"
 | 
			
		||||
	"crypto/x509"
 | 
			
		||||
	"os"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"sync"
 | 
			
		||||
	"testing"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	sdklog "go.opentelemetry.io/otel/sdk/log"
 | 
			
		||||
	sdkmetric "go.opentelemetry.io/otel/sdk/metric"
 | 
			
		||||
	sdktrace "go.opentelemetry.io/otel/sdk/trace"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func TestStore_And_Retrieve(t *testing.T) {
 | 
			
		||||
	// Clear any existing configuration
 | 
			
		||||
	Clear()
 | 
			
		||||
 | 
			
		||||
	ctx := context.Background()
 | 
			
		||||
	config := &Config{
 | 
			
		||||
		ServiceName: "test-service",
 | 
			
		||||
		Environment: "test",
 | 
			
		||||
		Endpoint:    "localhost:4317",
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Create mock factories
 | 
			
		||||
	logFactory := func(context.Context, *Config) (sdklog.Exporter, error) { return nil, nil }
 | 
			
		||||
	metricFactory := func(context.Context, *Config) (sdkmetric.Exporter, error) { return nil, nil }
 | 
			
		||||
	traceFactory := func(context.Context, *Config) (sdktrace.SpanExporter, error) { return nil, nil }
 | 
			
		||||
 | 
			
		||||
	// Store configuration
 | 
			
		||||
	Store(ctx, config, logFactory, metricFactory, traceFactory)
 | 
			
		||||
 | 
			
		||||
	// Test IsConfigured
 | 
			
		||||
	if !IsConfigured() {
 | 
			
		||||
		t.Error("IsConfigured() should return true after Store()")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Test GetLogExporter
 | 
			
		||||
	cfg, ctx2, factory := GetLogExporter()
 | 
			
		||||
	if cfg == nil || ctx2 == nil || factory == nil {
 | 
			
		||||
		t.Error("GetLogExporter() should return non-nil values")
 | 
			
		||||
	}
 | 
			
		||||
	if cfg.ServiceName != "test-service" {
 | 
			
		||||
		t.Errorf("Expected ServiceName 'test-service', got '%s'", cfg.ServiceName)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Test GetMetricExporter
 | 
			
		||||
	cfg, ctx3, metricFact := GetMetricExporter()
 | 
			
		||||
	if cfg == nil || ctx3 == nil || metricFact == nil {
 | 
			
		||||
		t.Error("GetMetricExporter() should return non-nil values")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Test GetTraceExporter
 | 
			
		||||
	cfg, ctx4, traceFact := GetTraceExporter()
 | 
			
		||||
	if cfg == nil || ctx4 == nil || traceFact == nil {
 | 
			
		||||
		t.Error("GetTraceExporter() should return non-nil values")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Test backward compatibility Get()
 | 
			
		||||
	cfg, ctx5, logFact := Get()
 | 
			
		||||
	if cfg == nil || ctx5 == nil || logFact == nil {
 | 
			
		||||
		t.Error("Get() should return non-nil values for backward compatibility")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestClear(t *testing.T) {
 | 
			
		||||
	// Store some configuration first
 | 
			
		||||
	ctx := context.Background()
 | 
			
		||||
	config := &Config{ServiceName: "test"}
 | 
			
		||||
	Store(ctx, config, nil, nil, nil)
 | 
			
		||||
 | 
			
		||||
	if !IsConfigured() {
 | 
			
		||||
		t.Error("Should be configured before Clear()")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Clear configuration
 | 
			
		||||
	Clear()
 | 
			
		||||
 | 
			
		||||
	if IsConfigured() {
 | 
			
		||||
		t.Error("Should not be configured after Clear()")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// All getters should return nil
 | 
			
		||||
	cfg, ctx2, factory := GetLogExporter()
 | 
			
		||||
	if cfg != nil || ctx2 != nil || factory != nil {
 | 
			
		||||
		t.Error("GetLogExporter() should return nil values after Clear()")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestConcurrentAccess(t *testing.T) {
 | 
			
		||||
	Clear()
 | 
			
		||||
 | 
			
		||||
	ctx := context.Background()
 | 
			
		||||
	config := &Config{ServiceName: "concurrent-test"}
 | 
			
		||||
 | 
			
		||||
	var wg sync.WaitGroup
 | 
			
		||||
	const numGoroutines = 10
 | 
			
		||||
 | 
			
		||||
	// Test concurrent Store and Get operations
 | 
			
		||||
	wg.Add(numGoroutines * 2)
 | 
			
		||||
 | 
			
		||||
	// Concurrent Store operations
 | 
			
		||||
	for i := 0; i < numGoroutines; i++ {
 | 
			
		||||
		go func() {
 | 
			
		||||
			defer wg.Done()
 | 
			
		||||
			Store(ctx, config, nil, nil, nil)
 | 
			
		||||
		}()
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Concurrent Get operations
 | 
			
		||||
	for i := 0; i < numGoroutines; i++ {
 | 
			
		||||
		go func() {
 | 
			
		||||
			defer wg.Done()
 | 
			
		||||
			IsConfigured()
 | 
			
		||||
			GetLogExporter()
 | 
			
		||||
			GetMetricExporter()
 | 
			
		||||
			GetTraceExporter()
 | 
			
		||||
		}()
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	wg.Wait()
 | 
			
		||||
 | 
			
		||||
	// Should be configured after all operations
 | 
			
		||||
	if !IsConfigured() {
 | 
			
		||||
		t.Error("Should be configured after concurrent operations")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestGetTLSConfig(t *testing.T) {
 | 
			
		||||
	tests := []struct {
 | 
			
		||||
		name     string
 | 
			
		||||
		config   *Config
 | 
			
		||||
		expected bool // whether TLS config should be nil
 | 
			
		||||
	}{
 | 
			
		||||
		{
 | 
			
		||||
			name:     "nil certificate provider",
 | 
			
		||||
			config:   &Config{},
 | 
			
		||||
			expected: true, // should be nil
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name: "with certificate provider",
 | 
			
		||||
			config: &Config{
 | 
			
		||||
				CertificateProvider: func(*tls.CertificateRequestInfo) (*tls.Certificate, error) {
 | 
			
		||||
					return &tls.Certificate{}, nil
 | 
			
		||||
				},
 | 
			
		||||
			},
 | 
			
		||||
			expected: false, // should not be nil
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name: "with certificate provider and RootCAs",
 | 
			
		||||
			config: &Config{
 | 
			
		||||
				CertificateProvider: func(*tls.CertificateRequestInfo) (*tls.Certificate, error) {
 | 
			
		||||
					return &tls.Certificate{}, nil
 | 
			
		||||
				},
 | 
			
		||||
				RootCAs: x509.NewCertPool(),
 | 
			
		||||
			},
 | 
			
		||||
			expected: false, // should not be nil
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for _, tt := range tests {
 | 
			
		||||
		t.Run(tt.name, func(t *testing.T) {
 | 
			
		||||
			tlsConfig := getTLSConfig(tt.config)
 | 
			
		||||
			if tt.expected && tlsConfig != nil {
 | 
			
		||||
				t.Errorf("Expected nil TLS config, got %v", tlsConfig)
 | 
			
		||||
			}
 | 
			
		||||
			if !tt.expected && tlsConfig == nil {
 | 
			
		||||
				t.Error("Expected non-nil TLS config, got nil")
 | 
			
		||||
			}
 | 
			
		||||
			if !tt.expected && tlsConfig != nil {
 | 
			
		||||
				if tlsConfig.GetClientCertificate == nil {
 | 
			
		||||
					t.Error("Expected GetClientCertificate to be set")
 | 
			
		||||
				}
 | 
			
		||||
				if tt.config.RootCAs != nil && tlsConfig.RootCAs != tt.config.RootCAs {
 | 
			
		||||
					t.Error("Expected RootCAs to be set correctly")
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestGetProtocol(t *testing.T) {
 | 
			
		||||
	// Save original env vars
 | 
			
		||||
	originalGeneral := os.Getenv(otelExporterOTLPProtoEnvKey)
 | 
			
		||||
	originalLogs := os.Getenv(otelExporterOTLPLogsProtoEnvKey)
 | 
			
		||||
 | 
			
		||||
	defer func() {
 | 
			
		||||
		// Restore original env vars
 | 
			
		||||
		if originalGeneral != "" {
 | 
			
		||||
			os.Setenv(otelExporterOTLPProtoEnvKey, originalGeneral)
 | 
			
		||||
		} else {
 | 
			
		||||
			os.Unsetenv(otelExporterOTLPProtoEnvKey)
 | 
			
		||||
		}
 | 
			
		||||
		if originalLogs != "" {
 | 
			
		||||
			os.Setenv(otelExporterOTLPLogsProtoEnvKey, originalLogs)
 | 
			
		||||
		} else {
 | 
			
		||||
			os.Unsetenv(otelExporterOTLPLogsProtoEnvKey)
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	tests := []struct {
 | 
			
		||||
		name           string
 | 
			
		||||
		signalSpecific string
 | 
			
		||||
		generalProto   string
 | 
			
		||||
		specificProto  string
 | 
			
		||||
		expectedResult string
 | 
			
		||||
	}{
 | 
			
		||||
		{
 | 
			
		||||
			name:           "no env vars set - default",
 | 
			
		||||
			signalSpecific: otelExporterOTLPLogsProtoEnvKey,
 | 
			
		||||
			expectedResult: "http/protobuf",
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:           "general env var set",
 | 
			
		||||
			signalSpecific: otelExporterOTLPLogsProtoEnvKey,
 | 
			
		||||
			generalProto:   "grpc",
 | 
			
		||||
			expectedResult: "grpc",
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:           "specific env var overrides general",
 | 
			
		||||
			signalSpecific: otelExporterOTLPLogsProtoEnvKey,
 | 
			
		||||
			generalProto:   "grpc",
 | 
			
		||||
			specificProto:  "http/protobuf",
 | 
			
		||||
			expectedResult: "http/protobuf",
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for _, tt := range tests {
 | 
			
		||||
		t.Run(tt.name, func(t *testing.T) {
 | 
			
		||||
			// Clear env vars
 | 
			
		||||
			os.Unsetenv(otelExporterOTLPProtoEnvKey)
 | 
			
		||||
			os.Unsetenv(otelExporterOTLPLogsProtoEnvKey)
 | 
			
		||||
 | 
			
		||||
			// Set test env vars
 | 
			
		||||
			if tt.generalProto != "" {
 | 
			
		||||
				os.Setenv(otelExporterOTLPProtoEnvKey, tt.generalProto)
 | 
			
		||||
			}
 | 
			
		||||
			if tt.specificProto != "" {
 | 
			
		||||
				os.Setenv(tt.signalSpecific, tt.specificProto)
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			result := getProtocol(tt.signalSpecific)
 | 
			
		||||
			if result != tt.expectedResult {
 | 
			
		||||
				t.Errorf("Expected protocol '%s', got '%s'", tt.expectedResult, result)
 | 
			
		||||
			}
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestCreateExporterErrors(t *testing.T) {
 | 
			
		||||
	ctx := context.Background()
 | 
			
		||||
	config := &Config{
 | 
			
		||||
		ServiceName: "test-service",
 | 
			
		||||
		Endpoint:    "invalid-endpoint",
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Test with invalid protocol for logs
 | 
			
		||||
	os.Setenv(otelExporterOTLPLogsProtoEnvKey, "invalid-protocol")
 | 
			
		||||
	defer os.Unsetenv(otelExporterOTLPLogsProtoEnvKey)
 | 
			
		||||
 | 
			
		||||
	_, err := CreateOTLPLogExporter(ctx, config)
 | 
			
		||||
	if err == nil {
 | 
			
		||||
		t.Error("Expected error for invalid protocol")
 | 
			
		||||
	}
 | 
			
		||||
	// Check that it's a protocol error (the specific message will be different now)
 | 
			
		||||
	if !strings.Contains(err.Error(), "invalid OTLP protocol") {
 | 
			
		||||
		t.Errorf("Expected protocol error, got %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Test with invalid protocol for metrics
 | 
			
		||||
	os.Setenv(otelExporterOTLPMetricsProtoEnvKey, "invalid-protocol")
 | 
			
		||||
	defer os.Unsetenv(otelExporterOTLPMetricsProtoEnvKey)
 | 
			
		||||
 | 
			
		||||
	_, err = CreateOTLPMetricExporter(ctx, config)
 | 
			
		||||
	if err == nil {
 | 
			
		||||
		t.Error("Expected error for invalid protocol")
 | 
			
		||||
	}
 | 
			
		||||
	if !strings.Contains(err.Error(), "invalid OTLP protocol") {
 | 
			
		||||
		t.Errorf("Expected protocol error, got %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Test with invalid protocol for traces
 | 
			
		||||
	os.Setenv(otelExporterOTLPTracesProtoEnvKey, "invalid-protocol")
 | 
			
		||||
	defer os.Unsetenv(otelExporterOTLPTracesProtoEnvKey)
 | 
			
		||||
 | 
			
		||||
	_, err = CreateOTLPTraceExporter(ctx, config)
 | 
			
		||||
	if err == nil {
 | 
			
		||||
		t.Error("Expected error for invalid protocol")
 | 
			
		||||
	}
 | 
			
		||||
	if !strings.Contains(err.Error(), "invalid OTLP protocol") {
 | 
			
		||||
		t.Errorf("Expected protocol error, got %v", err)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestCreateExporterValidProtocols(t *testing.T) {
 | 
			
		||||
	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
 | 
			
		||||
	defer cancel()
 | 
			
		||||
 | 
			
		||||
	config := &Config{
 | 
			
		||||
		ServiceName: "test-service",
 | 
			
		||||
		Endpoint:    "localhost:4317", // This will likely fail to connect, but should create exporter
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	protocols := []string{"grpc", "http/protobuf", "http/json"}
 | 
			
		||||
 | 
			
		||||
	for _, proto := range protocols {
 | 
			
		||||
		t.Run("logs_"+proto, func(t *testing.T) {
 | 
			
		||||
			os.Setenv(otelExporterOTLPLogsProtoEnvKey, proto)
 | 
			
		||||
			defer os.Unsetenv(otelExporterOTLPLogsProtoEnvKey)
 | 
			
		||||
 | 
			
		||||
			exporter, err := CreateOTLPLogExporter(ctx, config)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				// Connection errors are expected since we're not running a real OTLP server
 | 
			
		||||
				// but the exporter should be created successfully
 | 
			
		||||
				t.Logf("Connection error expected: %v", err)
 | 
			
		||||
			}
 | 
			
		||||
			if exporter != nil {
 | 
			
		||||
				exporter.Shutdown(ctx)
 | 
			
		||||
			}
 | 
			
		||||
		})
 | 
			
		||||
 | 
			
		||||
		t.Run("metrics_"+proto, func(t *testing.T) {
 | 
			
		||||
			os.Setenv(otelExporterOTLPMetricsProtoEnvKey, proto)
 | 
			
		||||
			defer os.Unsetenv(otelExporterOTLPMetricsProtoEnvKey)
 | 
			
		||||
 | 
			
		||||
			exporter, err := CreateOTLPMetricExporter(ctx, config)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				t.Logf("Connection error expected: %v", err)
 | 
			
		||||
			}
 | 
			
		||||
			if exporter != nil {
 | 
			
		||||
				exporter.Shutdown(ctx)
 | 
			
		||||
			}
 | 
			
		||||
		})
 | 
			
		||||
 | 
			
		||||
		t.Run("traces_"+proto, func(t *testing.T) {
 | 
			
		||||
			os.Setenv(otelExporterOTLPTracesProtoEnvKey, proto)
 | 
			
		||||
			defer os.Unsetenv(otelExporterOTLPTracesProtoEnvKey)
 | 
			
		||||
 | 
			
		||||
			exporter, err := CreateOTLPTraceExporter(ctx, config)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				t.Logf("Connection error expected: %v", err)
 | 
			
		||||
			}
 | 
			
		||||
			if exporter != nil {
 | 
			
		||||
				exporter.Shutdown(ctx)
 | 
			
		||||
			}
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestConfigValidation(t *testing.T) {
 | 
			
		||||
	tests := []struct {
 | 
			
		||||
		name      string
 | 
			
		||||
		config    *Config
 | 
			
		||||
		shouldErr bool
 | 
			
		||||
	}{
 | 
			
		||||
		{
 | 
			
		||||
			name:      "valid empty config",
 | 
			
		||||
			config:    &Config{},
 | 
			
		||||
			shouldErr: false,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name: "valid config with endpoint",
 | 
			
		||||
			config: &Config{
 | 
			
		||||
				ServiceName: "test-service",
 | 
			
		||||
				Endpoint:    "localhost:4317",
 | 
			
		||||
			},
 | 
			
		||||
			shouldErr: false,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name: "valid config with endpoint URL",
 | 
			
		||||
			config: &Config{
 | 
			
		||||
				ServiceName: "test-service",
 | 
			
		||||
				EndpointURL: "https://otlp.example.com:4317/v1/traces",
 | 
			
		||||
			},
 | 
			
		||||
			shouldErr: false,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name: "invalid - both endpoint and endpoint URL",
 | 
			
		||||
			config: &Config{
 | 
			
		||||
				ServiceName: "test-service",
 | 
			
		||||
				Endpoint:    "localhost:4317",
 | 
			
		||||
				EndpointURL: "https://otlp.example.com:4317/v1/traces",
 | 
			
		||||
			},
 | 
			
		||||
			shouldErr: true,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name: "invalid - endpoint with protocol",
 | 
			
		||||
			config: &Config{
 | 
			
		||||
				ServiceName: "test-service",
 | 
			
		||||
				Endpoint:    "https://localhost:4317",
 | 
			
		||||
			},
 | 
			
		||||
			shouldErr: true,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name: "invalid - empty endpoint",
 | 
			
		||||
			config: &Config{
 | 
			
		||||
				ServiceName: "test-service",
 | 
			
		||||
				Endpoint:    "   ",
 | 
			
		||||
			},
 | 
			
		||||
			shouldErr: true,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name: "invalid - malformed endpoint URL",
 | 
			
		||||
			config: &Config{
 | 
			
		||||
				ServiceName: "test-service",
 | 
			
		||||
				EndpointURL: "://invalid-url-missing-scheme",
 | 
			
		||||
			},
 | 
			
		||||
			shouldErr: true,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name: "invalid - empty service name",
 | 
			
		||||
			config: &Config{
 | 
			
		||||
				ServiceName: "   ",
 | 
			
		||||
				Endpoint:    "localhost:4317",
 | 
			
		||||
			},
 | 
			
		||||
			shouldErr: true,
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for _, tt := range tests {
 | 
			
		||||
		t.Run(tt.name, func(t *testing.T) {
 | 
			
		||||
			err := tt.config.Validate()
 | 
			
		||||
			if tt.shouldErr && err == nil {
 | 
			
		||||
				t.Error("Expected validation error, got nil")
 | 
			
		||||
			}
 | 
			
		||||
			if !tt.shouldErr && err != nil {
 | 
			
		||||
				t.Errorf("Expected no validation error, got: %v", err)
 | 
			
		||||
			}
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestValidateAndStore(t *testing.T) {
 | 
			
		||||
	Clear()
 | 
			
		||||
 | 
			
		||||
	ctx := context.Background()
 | 
			
		||||
 | 
			
		||||
	// Test with valid config
 | 
			
		||||
	validConfig := &Config{
 | 
			
		||||
		ServiceName: "test-service",
 | 
			
		||||
		Endpoint:    "localhost:4317",
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	err := ValidateAndStore(ctx, validConfig, nil, nil, nil)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Errorf("ValidateAndStore with valid config should not error: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if !IsConfigured() {
 | 
			
		||||
		t.Error("Should be configured after ValidateAndStore")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	Clear()
 | 
			
		||||
 | 
			
		||||
	// Test with invalid config
 | 
			
		||||
	invalidConfig := &Config{
 | 
			
		||||
		ServiceName: "test-service",
 | 
			
		||||
		Endpoint:    "localhost:4317",
 | 
			
		||||
		EndpointURL: "https://example.com:4317", // both specified - invalid
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	err = ValidateAndStore(ctx, invalidConfig, nil, nil, nil)
 | 
			
		||||
	if err == nil {
 | 
			
		||||
		t.Error("ValidateAndStore with invalid config should return error")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if IsConfigured() {
 | 
			
		||||
		t.Error("Should not be configured after failed ValidateAndStore")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
@@ -1,3 +1,32 @@
 | 
			
		||||
// Package kafconn provides a Kafka client wrapper with TLS support for secure log streaming.
 | 
			
		||||
//
 | 
			
		||||
// This package handles Kafka connections with mutual TLS authentication for the NTP Pool
 | 
			
		||||
// project's log streaming infrastructure. It provides factories for creating Kafka readers
 | 
			
		||||
// and writers with automatic broker discovery, TLS configuration, and connection management.
 | 
			
		||||
//
 | 
			
		||||
// The package is designed specifically for the NTP Pool pipeline infrastructure and includes
 | 
			
		||||
// hardcoded bootstrap servers and group configurations. It uses certman for automatic
 | 
			
		||||
// certificate renewal and provides compression and batching optimizations.
 | 
			
		||||
//
 | 
			
		||||
// Key features:
 | 
			
		||||
//   - Mutual TLS authentication with automatic certificate renewal
 | 
			
		||||
//   - Broker discovery and connection pooling
 | 
			
		||||
//   - Reader and writer factory methods with optimized configurations
 | 
			
		||||
//   - LZ4 compression for efficient data transfer
 | 
			
		||||
//   - Configurable batch sizes and load balancing
 | 
			
		||||
//
 | 
			
		||||
// Example usage:
 | 
			
		||||
//
 | 
			
		||||
//	tlsSetup := kafconn.TLSSetup{
 | 
			
		||||
//		CA:   "/path/to/ca.pem",
 | 
			
		||||
//		Cert: "/path/to/client.pem",
 | 
			
		||||
//		Key:  "/path/to/client.key",
 | 
			
		||||
//	}
 | 
			
		||||
//	kafka, err := kafconn.NewKafka(ctx, tlsSetup)
 | 
			
		||||
//	if err != nil {
 | 
			
		||||
//		log.Fatal(err)
 | 
			
		||||
//	}
 | 
			
		||||
//	writer, err := kafka.NewWriter("logs")
 | 
			
		||||
package kafconn
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
@@ -24,12 +53,17 @@ const (
 | 
			
		||||
	// kafkaMinBatchSize = 1000
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// TLSSetup contains file paths for TLS certificate configuration.
 | 
			
		||||
// All fields are required for establishing secure Kafka connections.
 | 
			
		||||
type TLSSetup struct {
 | 
			
		||||
	CA   string
 | 
			
		||||
	Key  string
 | 
			
		||||
	Cert string
 | 
			
		||||
	CA   string // Path to CA certificate file for server verification
 | 
			
		||||
	Key  string // Path to client private key file
 | 
			
		||||
	Cert string // Path to client certificate file
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Kafka represents a configured Kafka client with TLS support.
 | 
			
		||||
// It manages connections, brokers, and provides factory methods for readers and writers.
 | 
			
		||||
// The client handles broker discovery, connection pooling, and TLS configuration automatically.
 | 
			
		||||
type Kafka struct {
 | 
			
		||||
	tls TLSSetup
 | 
			
		||||
 | 
			
		||||
@@ -116,6 +150,19 @@ func (k *Kafka) kafkaTransport(ctx context.Context) (*kafka.Transport, error) {
 | 
			
		||||
	return transport, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// NewKafka creates a new Kafka client with TLS configuration and establishes initial connections.
 | 
			
		||||
// It performs broker discovery, validates TLS certificates, and prepares the client for creating
 | 
			
		||||
// readers and writers.
 | 
			
		||||
//
 | 
			
		||||
// The function validates TLS configuration, establishes a connection to the bootstrap server,
 | 
			
		||||
// discovers all available brokers, and configures transport layers for optimal performance.
 | 
			
		||||
//
 | 
			
		||||
// Parameters:
 | 
			
		||||
//   - ctx: Context for connection establishment and timeouts
 | 
			
		||||
//   - tls: TLS configuration with paths to CA, certificate, and key files
 | 
			
		||||
//
 | 
			
		||||
// Returns a configured Kafka client ready for creating readers and writers, or an error
 | 
			
		||||
// if TLS setup fails, connection cannot be established, or broker discovery fails.
 | 
			
		||||
func NewKafka(ctx context.Context, tls TLSSetup) (*Kafka, error) {
 | 
			
		||||
	l := log.New(os.Stdout, "kafka: ", log.Ldate|log.Ltime|log.LUTC|log.Lmsgprefix|log.Lmicroseconds)
 | 
			
		||||
 | 
			
		||||
@@ -171,6 +218,12 @@ func NewKafka(ctx context.Context, tls TLSSetup) (*Kafka, error) {
 | 
			
		||||
	return k, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// NewReader creates a new Kafka reader with the client's broker list and TLS configuration.
 | 
			
		||||
// The provided config is enhanced with the discovered brokers and configured dialer.
 | 
			
		||||
// The reader supports automatic offset management, consumer group coordination, and reconnection.
 | 
			
		||||
//
 | 
			
		||||
// The caller should configure the reader's Topic, GroupID, and other consumer-specific settings
 | 
			
		||||
// in the provided config. The client automatically sets Brokers and Dialer fields.
 | 
			
		||||
func (k *Kafka) NewReader(config kafka.ReaderConfig) *kafka.Reader {
 | 
			
		||||
	config.Brokers = k.brokerAddrs()
 | 
			
		||||
	config.Dialer = k.dialer
 | 
			
		||||
@@ -186,6 +239,16 @@ func (k *Kafka) brokerAddrs() []string {
 | 
			
		||||
	return addrs
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// NewWriter creates a new Kafka writer for the specified topic with optimized configuration.
 | 
			
		||||
// The writer uses LZ4 compression, least-bytes load balancing, and batching for performance.
 | 
			
		||||
//
 | 
			
		||||
// Configuration includes:
 | 
			
		||||
//   - Batch size: 2000 messages for efficient throughput
 | 
			
		||||
//   - Compression: LZ4 for fast compression with good ratios
 | 
			
		||||
//   - Balancer: LeastBytes for optimal partition distribution
 | 
			
		||||
//   - Transport: TLS-configured transport with connection pooling
 | 
			
		||||
//
 | 
			
		||||
// The writer is ready for immediate use and handles connection management automatically.
 | 
			
		||||
func (k *Kafka) NewWriter(topic string) (*kafka.Writer, error) {
 | 
			
		||||
	// https://pkg.go.dev/github.com/segmentio/kafka-go#Writer
 | 
			
		||||
	w := &kafka.Writer{
 | 
			
		||||
@@ -202,6 +265,12 @@ func (k *Kafka) NewWriter(topic string) (*kafka.Writer, error) {
 | 
			
		||||
	return w, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// CheckPartitions verifies that the Kafka connection can read partition metadata.
 | 
			
		||||
// This method is useful for health checks and connection validation.
 | 
			
		||||
//
 | 
			
		||||
// Returns an error if partition metadata cannot be retrieved, which typically
 | 
			
		||||
// indicates connection problems, authentication failures, or broker unavailability.
 | 
			
		||||
// Logs a warning if no partitions are available but does not return an error.
 | 
			
		||||
func (k *Kafka) CheckPartitions() error {
 | 
			
		||||
	partitions, err := k.conn.ReadPartitions()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										204
									
								
								logger/buffering_exporter.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										204
									
								
								logger/buffering_exporter.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,204 @@
 | 
			
		||||
package logger
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"sync"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"go.ntppool.org/common/internal/tracerconfig"
 | 
			
		||||
	otellog "go.opentelemetry.io/otel/sdk/log"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// bufferingExporter wraps an OTLP exporter and buffers logs until tracing is configured
 | 
			
		||||
type bufferingExporter struct {
 | 
			
		||||
	mu sync.RWMutex
 | 
			
		||||
 | 
			
		||||
	// Buffered records while waiting for tracing config
 | 
			
		||||
	buffer      [][]otellog.Record
 | 
			
		||||
	bufferSize  int
 | 
			
		||||
	maxBuffSize int
 | 
			
		||||
 | 
			
		||||
	// Real exporter (created when tracing is configured)
 | 
			
		||||
	exporter otellog.Exporter
 | 
			
		||||
 | 
			
		||||
	// Thread-safe initialization state (managed only by checkReadiness)
 | 
			
		||||
	initErr error
 | 
			
		||||
 | 
			
		||||
	// Background checker
 | 
			
		||||
	stopChecker chan struct{}
 | 
			
		||||
	checkerDone chan struct{}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// newBufferingExporter creates a new exporter that buffers logs until tracing is configured
 | 
			
		||||
func newBufferingExporter() *bufferingExporter {
 | 
			
		||||
	e := &bufferingExporter{
 | 
			
		||||
		maxBuffSize: 1000, // Max number of batches to buffer
 | 
			
		||||
		stopChecker: make(chan struct{}),
 | 
			
		||||
		checkerDone: make(chan struct{}),
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Start background readiness checker
 | 
			
		||||
	go e.checkReadiness()
 | 
			
		||||
 | 
			
		||||
	return e
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Export implements otellog.Exporter
 | 
			
		||||
func (e *bufferingExporter) Export(ctx context.Context, records []otellog.Record) error {
 | 
			
		||||
	// Check if exporter is ready (initialization handled by checkReadiness goroutine)
 | 
			
		||||
	e.mu.RLock()
 | 
			
		||||
	exporter := e.exporter
 | 
			
		||||
	e.mu.RUnlock()
 | 
			
		||||
 | 
			
		||||
	if exporter != nil {
 | 
			
		||||
		return exporter.Export(ctx, records)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Not ready yet, buffer the records
 | 
			
		||||
	return e.bufferRecords(records)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// initialize attempts to create the real OTLP exporter using tracing config
 | 
			
		||||
func (e *bufferingExporter) initialize() error {
 | 
			
		||||
	cfg, ctx, factory := tracerconfig.Get()
 | 
			
		||||
	if cfg == nil || ctx == nil || factory == nil {
 | 
			
		||||
		return errors.New("tracer not configured yet")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Add timeout for initialization
 | 
			
		||||
	initCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
 | 
			
		||||
	defer cancel()
 | 
			
		||||
 | 
			
		||||
	exporter, err := factory(initCtx, cfg)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return fmt.Errorf("failed to create OTLP exporter: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	e.mu.Lock()
 | 
			
		||||
	e.exporter = exporter
 | 
			
		||||
	flushErr := e.flushBuffer(initCtx)
 | 
			
		||||
	e.mu.Unlock()
 | 
			
		||||
 | 
			
		||||
	if flushErr != nil {
 | 
			
		||||
		// Log but don't fail initialization
 | 
			
		||||
		Setup().Warn("buffer flush failed during initialization", "error", flushErr)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// bufferRecords adds records to the buffer for later processing
 | 
			
		||||
func (e *bufferingExporter) bufferRecords(records []otellog.Record) error {
 | 
			
		||||
	e.mu.Lock()
 | 
			
		||||
	defer e.mu.Unlock()
 | 
			
		||||
 | 
			
		||||
	// Buffer the batch if we have space
 | 
			
		||||
	if e.bufferSize < e.maxBuffSize {
 | 
			
		||||
		// Clone records to avoid retention issues
 | 
			
		||||
		cloned := make([]otellog.Record, len(records))
 | 
			
		||||
		for i, r := range records {
 | 
			
		||||
			cloned[i] = r.Clone()
 | 
			
		||||
		}
 | 
			
		||||
		e.buffer = append(e.buffer, cloned)
 | 
			
		||||
		e.bufferSize++
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Always return success to BatchProcessor
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// checkReadiness periodically attempts initialization until successful
 | 
			
		||||
func (e *bufferingExporter) checkReadiness() {
 | 
			
		||||
	defer close(e.checkerDone)
 | 
			
		||||
 | 
			
		||||
	ticker := time.NewTicker(1 * time.Second)
 | 
			
		||||
	defer ticker.Stop()
 | 
			
		||||
 | 
			
		||||
	for {
 | 
			
		||||
		select {
 | 
			
		||||
		case <-ticker.C:
 | 
			
		||||
			// Check if we already have a working exporter
 | 
			
		||||
			e.mu.RLock()
 | 
			
		||||
			hasExporter := e.exporter != nil
 | 
			
		||||
			e.mu.RUnlock()
 | 
			
		||||
 | 
			
		||||
			if hasExporter {
 | 
			
		||||
				return // Exporter ready, checker no longer needed
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			// Try to initialize
 | 
			
		||||
			err := e.initialize()
 | 
			
		||||
			e.mu.Lock()
 | 
			
		||||
			e.initErr = err
 | 
			
		||||
			e.mu.Unlock()
 | 
			
		||||
 | 
			
		||||
		case <-e.stopChecker:
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// flushBuffer sends all buffered batches through the real exporter
 | 
			
		||||
func (e *bufferingExporter) flushBuffer(ctx context.Context) error {
 | 
			
		||||
	if len(e.buffer) == 0 {
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	flushCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
 | 
			
		||||
	defer cancel()
 | 
			
		||||
 | 
			
		||||
	var lastErr error
 | 
			
		||||
	for _, batch := range e.buffer {
 | 
			
		||||
		if err := e.exporter.Export(flushCtx, batch); err != nil {
 | 
			
		||||
			lastErr = err
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Clear buffer after flush attempt
 | 
			
		||||
	e.buffer = nil
 | 
			
		||||
	e.bufferSize = 0
 | 
			
		||||
 | 
			
		||||
	return lastErr
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ForceFlush implements otellog.Exporter
 | 
			
		||||
func (e *bufferingExporter) ForceFlush(ctx context.Context) error {
 | 
			
		||||
	e.mu.RLock()
 | 
			
		||||
	defer e.mu.RUnlock()
 | 
			
		||||
 | 
			
		||||
	if e.exporter != nil {
 | 
			
		||||
		return e.exporter.ForceFlush(ctx)
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Shutdown implements otellog.Exporter
 | 
			
		||||
func (e *bufferingExporter) Shutdown(ctx context.Context) error {
 | 
			
		||||
	// Stop the readiness checker from continuing
 | 
			
		||||
	close(e.stopChecker)
 | 
			
		||||
 | 
			
		||||
	// Wait for readiness checker goroutine to complete
 | 
			
		||||
	<-e.checkerDone
 | 
			
		||||
 | 
			
		||||
	// Give one final chance for TLS/tracing to become ready for buffer flushing
 | 
			
		||||
	e.mu.RLock()
 | 
			
		||||
	hasExporter := e.exporter != nil
 | 
			
		||||
	e.mu.RUnlock()
 | 
			
		||||
 | 
			
		||||
	if !hasExporter {
 | 
			
		||||
		err := e.initialize()
 | 
			
		||||
		e.mu.Lock()
 | 
			
		||||
		e.initErr = err
 | 
			
		||||
		e.mu.Unlock()
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	e.mu.Lock()
 | 
			
		||||
	defer e.mu.Unlock()
 | 
			
		||||
 | 
			
		||||
	if e.exporter != nil {
 | 
			
		||||
		return e.exporter.Shutdown(ctx)
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										235
									
								
								logger/level_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										235
									
								
								logger/level_test.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,235 @@
 | 
			
		||||
package logger
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"log/slog"
 | 
			
		||||
	"os"
 | 
			
		||||
	"testing"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func TestParseLevel(t *testing.T) {
 | 
			
		||||
	tests := []struct {
 | 
			
		||||
		name        string
 | 
			
		||||
		input       string
 | 
			
		||||
		expected    slog.Level
 | 
			
		||||
		expectError bool
 | 
			
		||||
	}{
 | 
			
		||||
		{"empty string", "", slog.LevelInfo, false},
 | 
			
		||||
		{"DEBUG upper", "DEBUG", slog.LevelDebug, false},
 | 
			
		||||
		{"debug lower", "debug", slog.LevelDebug, false},
 | 
			
		||||
		{"INFO upper", "INFO", slog.LevelInfo, false},
 | 
			
		||||
		{"info lower", "info", slog.LevelInfo, false},
 | 
			
		||||
		{"WARN upper", "WARN", slog.LevelWarn, false},
 | 
			
		||||
		{"warn lower", "warn", slog.LevelWarn, false},
 | 
			
		||||
		{"ERROR upper", "ERROR", slog.LevelError, false},
 | 
			
		||||
		{"error lower", "error", slog.LevelError, false},
 | 
			
		||||
		{"invalid level", "invalid", slog.LevelInfo, true},
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for _, tt := range tests {
 | 
			
		||||
		t.Run(tt.name, func(t *testing.T) {
 | 
			
		||||
			level, err := ParseLevel(tt.input)
 | 
			
		||||
			if tt.expectError {
 | 
			
		||||
				if err == nil {
 | 
			
		||||
					t.Errorf("expected error for input %q, got nil", tt.input)
 | 
			
		||||
				}
 | 
			
		||||
			} else {
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					t.Errorf("unexpected error for input %q: %v", tt.input, err)
 | 
			
		||||
				}
 | 
			
		||||
				if level != tt.expected {
 | 
			
		||||
					t.Errorf("expected level %v for input %q, got %v", tt.expected, tt.input, level)
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestSetLevel(t *testing.T) {
 | 
			
		||||
	// Store original level to restore later
 | 
			
		||||
	originalLevel := Level.Level()
 | 
			
		||||
	defer Level.Set(originalLevel)
 | 
			
		||||
 | 
			
		||||
	SetLevel(slog.LevelDebug)
 | 
			
		||||
	if Level.Level() != slog.LevelDebug {
 | 
			
		||||
		t.Errorf("expected Level to be Debug, got %v", Level.Level())
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	SetLevel(slog.LevelError)
 | 
			
		||||
	if Level.Level() != slog.LevelError {
 | 
			
		||||
		t.Errorf("expected Level to be Error, got %v", Level.Level())
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestSetOTLPLevel(t *testing.T) {
 | 
			
		||||
	// Store original level to restore later
 | 
			
		||||
	originalLevel := OTLPLevel.Level()
 | 
			
		||||
	defer OTLPLevel.Set(originalLevel)
 | 
			
		||||
 | 
			
		||||
	SetOTLPLevel(slog.LevelWarn)
 | 
			
		||||
	if OTLPLevel.Level() != slog.LevelWarn {
 | 
			
		||||
		t.Errorf("expected OTLPLevel to be Warn, got %v", OTLPLevel.Level())
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	SetOTLPLevel(slog.LevelDebug)
 | 
			
		||||
	if OTLPLevel.Level() != slog.LevelDebug {
 | 
			
		||||
		t.Errorf("expected OTLPLevel to be Debug, got %v", OTLPLevel.Level())
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestOTLPLevelHandler(t *testing.T) {
 | 
			
		||||
	// Create a mock handler that counts calls
 | 
			
		||||
	callCount := 0
 | 
			
		||||
	mockHandler := &mockHandler{
 | 
			
		||||
		handleFunc: func(ctx context.Context, r slog.Record) error {
 | 
			
		||||
			callCount++
 | 
			
		||||
			return nil
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Set OTLP level to Warn
 | 
			
		||||
	originalLevel := OTLPLevel.Level()
 | 
			
		||||
	defer OTLPLevel.Set(originalLevel)
 | 
			
		||||
	OTLPLevel.Set(slog.LevelWarn)
 | 
			
		||||
 | 
			
		||||
	// Create OTLP level handler
 | 
			
		||||
	handler := newOTLPLevelHandler(mockHandler)
 | 
			
		||||
 | 
			
		||||
	ctx := context.Background()
 | 
			
		||||
 | 
			
		||||
	// Test that Debug and Info are filtered out
 | 
			
		||||
	if handler.Enabled(ctx, slog.LevelDebug) {
 | 
			
		||||
		t.Error("Debug level should be disabled when OTLP level is Warn")
 | 
			
		||||
	}
 | 
			
		||||
	if handler.Enabled(ctx, slog.LevelInfo) {
 | 
			
		||||
		t.Error("Info level should be disabled when OTLP level is Warn")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Test that Warn and Error are enabled
 | 
			
		||||
	if !handler.Enabled(ctx, slog.LevelWarn) {
 | 
			
		||||
		t.Error("Warn level should be enabled when OTLP level is Warn")
 | 
			
		||||
	}
 | 
			
		||||
	if !handler.Enabled(ctx, slog.LevelError) {
 | 
			
		||||
		t.Error("Error level should be enabled when OTLP level is Warn")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Test that Handle respects level filtering
 | 
			
		||||
	now := time.Now()
 | 
			
		||||
	debugRecord := slog.NewRecord(now, slog.LevelDebug, "debug message", 0)
 | 
			
		||||
	warnRecord := slog.NewRecord(now, slog.LevelWarn, "warn message", 0)
 | 
			
		||||
 | 
			
		||||
	handler.Handle(ctx, debugRecord)
 | 
			
		||||
	if callCount != 0 {
 | 
			
		||||
		t.Error("Debug record should not be passed to underlying handler")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	handler.Handle(ctx, warnRecord)
 | 
			
		||||
	if callCount != 1 {
 | 
			
		||||
		t.Error("Warn record should be passed to underlying handler")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestEnvironmentVariables(t *testing.T) {
 | 
			
		||||
	tests := []struct {
 | 
			
		||||
		name         string
 | 
			
		||||
		envVar       string
 | 
			
		||||
		envValue     string
 | 
			
		||||
		configPrefix string
 | 
			
		||||
		testFunc     func(t *testing.T)
 | 
			
		||||
	}{
 | 
			
		||||
		{
 | 
			
		||||
			name:     "LOG_LEVEL sets stderr level",
 | 
			
		||||
			envVar:   "LOG_LEVEL",
 | 
			
		||||
			envValue: "ERROR",
 | 
			
		||||
			testFunc: func(t *testing.T) {
 | 
			
		||||
				// Reset the setup state
 | 
			
		||||
				resetLoggerSetup()
 | 
			
		||||
 | 
			
		||||
				// Call setupStdErrHandler which should read the env var
 | 
			
		||||
				handler := setupStdErrHandler()
 | 
			
		||||
				if handler == nil {
 | 
			
		||||
					t.Fatal("setupStdErrHandler returned nil")
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				if Level.Level() != slog.LevelError {
 | 
			
		||||
					t.Errorf("expected Level to be Error after setting LOG_LEVEL=ERROR, got %v", Level.Level())
 | 
			
		||||
				}
 | 
			
		||||
			},
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:         "Prefixed LOG_LEVEL",
 | 
			
		||||
			envVar:       "TEST_LOG_LEVEL",
 | 
			
		||||
			envValue:     "DEBUG",
 | 
			
		||||
			configPrefix: "TEST",
 | 
			
		||||
			testFunc: func(t *testing.T) {
 | 
			
		||||
				ConfigPrefix = "TEST"
 | 
			
		||||
				defer func() { ConfigPrefix = "" }()
 | 
			
		||||
 | 
			
		||||
				resetLoggerSetup()
 | 
			
		||||
				handler := setupStdErrHandler()
 | 
			
		||||
				if handler == nil {
 | 
			
		||||
					t.Fatal("setupStdErrHandler returned nil")
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				if Level.Level() != slog.LevelDebug {
 | 
			
		||||
					t.Errorf("expected Level to be Debug after setting TEST_LOG_LEVEL=DEBUG, got %v", Level.Level())
 | 
			
		||||
				}
 | 
			
		||||
			},
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for _, tt := range tests {
 | 
			
		||||
		t.Run(tt.name, func(t *testing.T) {
 | 
			
		||||
			// Store original env value and level
 | 
			
		||||
			originalEnv := os.Getenv(tt.envVar)
 | 
			
		||||
			originalLevel := Level.Level()
 | 
			
		||||
			defer func() {
 | 
			
		||||
				os.Setenv(tt.envVar, originalEnv)
 | 
			
		||||
				Level.Set(originalLevel)
 | 
			
		||||
			}()
 | 
			
		||||
 | 
			
		||||
			// Set test environment variable
 | 
			
		||||
			os.Setenv(tt.envVar, tt.envValue)
 | 
			
		||||
 | 
			
		||||
			// Run the test
 | 
			
		||||
			tt.testFunc(t)
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// mockHandler is a simple mock implementation of slog.Handler for testing
 | 
			
		||||
type mockHandler struct {
 | 
			
		||||
	handleFunc func(ctx context.Context, r slog.Record) error
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (m *mockHandler) Enabled(ctx context.Context, level slog.Level) bool {
 | 
			
		||||
	return true
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (m *mockHandler) Handle(ctx context.Context, r slog.Record) error {
 | 
			
		||||
	if m.handleFunc != nil {
 | 
			
		||||
		return m.handleFunc(ctx, r)
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (m *mockHandler) WithAttrs(attrs []slog.Attr) slog.Handler {
 | 
			
		||||
	return m
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (m *mockHandler) WithGroup(name string) slog.Handler {
 | 
			
		||||
	return m
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// resetLoggerSetup resets the sync.Once instances for testing
 | 
			
		||||
func resetLoggerSetup() {
 | 
			
		||||
	// Reset package-level variables
 | 
			
		||||
	textLogger = nil
 | 
			
		||||
	otlpLogger = nil
 | 
			
		||||
	multiLogger = nil
 | 
			
		||||
 | 
			
		||||
	// Note: We can't easily reset sync.Once instances in tests,
 | 
			
		||||
	// but for the specific test we're doing (environment variable parsing)
 | 
			
		||||
	// we can test the setupStdErrHandler function directly
 | 
			
		||||
}
 | 
			
		||||
@@ -16,13 +16,9 @@ type logfmt struct {
 | 
			
		||||
	mu   sync.Mutex
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func newLogFmtHandler(next slog.Handler) slog.Handler {
 | 
			
		||||
	buf := bytes.NewBuffer([]byte{})
 | 
			
		||||
 | 
			
		||||
	h := &logfmt{
 | 
			
		||||
		buf:  buf,
 | 
			
		||||
		next: next,
 | 
			
		||||
		txt: slog.NewTextHandler(buf, &slog.HandlerOptions{
 | 
			
		||||
// createTextHandlerOptions creates the common slog.HandlerOptions used by all logfmt handlers
 | 
			
		||||
func createTextHandlerOptions() *slog.HandlerOptions {
 | 
			
		||||
	return &slog.HandlerOptions{
 | 
			
		||||
		ReplaceAttr: func(groups []string, a slog.Attr) slog.Attr {
 | 
			
		||||
			if a.Key == slog.TimeKey && len(groups) == 0 {
 | 
			
		||||
				return slog.Attr{}
 | 
			
		||||
@@ -32,7 +28,16 @@ func newLogFmtHandler(next slog.Handler) slog.Handler {
 | 
			
		||||
			}
 | 
			
		||||
			return a
 | 
			
		||||
		},
 | 
			
		||||
		}),
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func newLogFmtHandler(next slog.Handler) slog.Handler {
 | 
			
		||||
	buf := bytes.NewBuffer([]byte{})
 | 
			
		||||
 | 
			
		||||
	h := &logfmt{
 | 
			
		||||
		buf:  buf,
 | 
			
		||||
		next: next,
 | 
			
		||||
		txt:  slog.NewTextHandler(buf, createTextHandlerOptions()),
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return h
 | 
			
		||||
@@ -43,10 +48,11 @@ func (h *logfmt) Enabled(ctx context.Context, lvl slog.Level) bool {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h *logfmt) WithAttrs(attrs []slog.Attr) slog.Handler {
 | 
			
		||||
	buf := bytes.NewBuffer([]byte{})
 | 
			
		||||
	return &logfmt{
 | 
			
		||||
		buf:  bytes.NewBuffer([]byte{}),
 | 
			
		||||
		buf:  buf,
 | 
			
		||||
		next: h.next.WithAttrs(slices.Clone(attrs)),
 | 
			
		||||
		txt:  h.txt.WithAttrs(slices.Clone(attrs)),
 | 
			
		||||
		txt:  slog.NewTextHandler(buf, createTextHandlerOptions()).WithAttrs(slices.Clone(attrs)),
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -54,10 +60,11 @@ func (h *logfmt) WithGroup(g string) slog.Handler {
 | 
			
		||||
	if g == "" {
 | 
			
		||||
		return h
 | 
			
		||||
	}
 | 
			
		||||
	buf := bytes.NewBuffer([]byte{})
 | 
			
		||||
	return &logfmt{
 | 
			
		||||
		buf:  bytes.NewBuffer([]byte{}),
 | 
			
		||||
		buf:  buf,
 | 
			
		||||
		next: h.next.WithGroup(g),
 | 
			
		||||
		txt:  h.txt.WithGroup(g),
 | 
			
		||||
		txt:  slog.NewTextHandler(buf, createTextHandlerOptions()).WithGroup(g),
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -69,10 +76,22 @@ func (h *logfmt) Handle(ctx context.Context, r slog.Record) error {
 | 
			
		||||
		panic("buffer wasn't empty")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	h.txt.Handle(ctx, r)
 | 
			
		||||
	r.Message = h.buf.String()
 | 
			
		||||
	r.Message = strings.TrimSuffix(r.Message, "\n")
 | 
			
		||||
	// Format using text handler to get the formatted message
 | 
			
		||||
	err := h.txt.Handle(ctx, r)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	formattedMessage := h.buf.String()
 | 
			
		||||
	formattedMessage = strings.TrimSuffix(formattedMessage, "\n")
 | 
			
		||||
	h.buf.Reset()
 | 
			
		||||
 | 
			
		||||
	return h.next.Handle(ctx, r)
 | 
			
		||||
	// Create a new record with the formatted message
 | 
			
		||||
	newRecord := slog.NewRecord(r.Time, r.Level, formattedMessage, r.PC)
 | 
			
		||||
	r.Attrs(func(a slog.Attr) bool {
 | 
			
		||||
		newRecord.AddAttrs(a)
 | 
			
		||||
		return true
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	return h.next.Handle(ctx, newRecord)
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										217
									
								
								logger/logger.go
									
									
									
									
									
								
							
							
						
						
									
										217
									
								
								logger/logger.go
									
									
									
									
									
								
							@@ -1,20 +1,61 @@
 | 
			
		||||
// Package logger provides structured logging with OpenTelemetry trace integration.
 | 
			
		||||
//
 | 
			
		||||
// This package offers multiple logging configurations for different deployment scenarios:
 | 
			
		||||
//   - Text logging to stderr with optional timestamp removal for systemd
 | 
			
		||||
//   - OTLP (OpenTelemetry Protocol) logging for observability pipelines
 | 
			
		||||
//   - Multi-logger setup that outputs to both text and OTLP simultaneously
 | 
			
		||||
//   - Context-aware logging with trace ID correlation
 | 
			
		||||
//
 | 
			
		||||
// The package automatically detects systemd environments and adjusts timestamp handling
 | 
			
		||||
// accordingly. It supports debug level configuration via environment variables and
 | 
			
		||||
// provides compatibility bridges for legacy logging interfaces.
 | 
			
		||||
//
 | 
			
		||||
// Key features:
 | 
			
		||||
//   - Automatic OpenTelemetry trace and span ID inclusion in log entries
 | 
			
		||||
//   - Configurable log levels via DEBUG environment variable (with optional prefix)
 | 
			
		||||
//   - Systemd-compatible output (no timestamps when INVOCATION_ID is present)
 | 
			
		||||
//   - Thread-safe logger setup with sync.Once protection
 | 
			
		||||
//   - Context propagation for request-scoped logging
 | 
			
		||||
//
 | 
			
		||||
// Environment variables:
 | 
			
		||||
//   - LOG_LEVEL: Set stderr log level (DEBUG, INFO, WARN, ERROR) (configurable prefix via ConfigPrefix)
 | 
			
		||||
//   - OTLP_LOG_LEVEL: Set OTLP log level independently (configurable prefix via ConfigPrefix)
 | 
			
		||||
//   - DEBUG: Enable debug level logging for backward compatibility (configurable prefix via ConfigPrefix)
 | 
			
		||||
//   - INVOCATION_ID: Systemd detection for timestamp handling
 | 
			
		||||
package logger
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"log"
 | 
			
		||||
	"log/slog"
 | 
			
		||||
	"os"
 | 
			
		||||
	"strconv"
 | 
			
		||||
	"sync"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	slogtraceid "github.com/remychantenay/slog-otel"
 | 
			
		||||
	slogmulti "github.com/samber/slog-multi"
 | 
			
		||||
	"go.opentelemetry.io/contrib/bridges/otelslog"
 | 
			
		||||
	"go.opentelemetry.io/otel/log/global"
 | 
			
		||||
	otellog "go.opentelemetry.io/otel/sdk/log"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// ConfigPrefix allows customizing the environment variable prefix for configuration.
 | 
			
		||||
// When set, environment variables like DEBUG become {ConfigPrefix}_DEBUG.
 | 
			
		||||
// This enables multiple services to have independent logging configuration.
 | 
			
		||||
var ConfigPrefix = ""
 | 
			
		||||
 | 
			
		||||
var (
 | 
			
		||||
	// Level controls the log level for the default stderr logger.
 | 
			
		||||
	// Can be changed at runtime to adjust logging verbosity.
 | 
			
		||||
	Level = new(slog.LevelVar) // Info by default
 | 
			
		||||
 | 
			
		||||
	// OTLPLevel controls the log level for OTLP output.
 | 
			
		||||
	// Can be changed independently from the stderr logger level.
 | 
			
		||||
	OTLPLevel = new(slog.LevelVar) // Info by default
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var (
 | 
			
		||||
	textLogger  *slog.Logger
 | 
			
		||||
	otlpLogger  *slog.Logger
 | 
			
		||||
@@ -28,21 +69,64 @@ var (
 | 
			
		||||
	mu         sync.Mutex
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// SetLevel sets the log level for the default stderr logger.
 | 
			
		||||
// This affects the primary application logger returned by Setup().
 | 
			
		||||
func SetLevel(level slog.Level) {
 | 
			
		||||
	Level.Set(level)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// SetOTLPLevel sets the log level for OTLP output.
 | 
			
		||||
// This affects the logger returned by SetupOLTP() and the OTLP portion of SetupMultiLogger().
 | 
			
		||||
func SetOTLPLevel(level slog.Level) {
 | 
			
		||||
	OTLPLevel.Set(level)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ParseLevel converts a string log level to slog.Level.
 | 
			
		||||
// Supported levels: "DEBUG", "INFO", "WARN", "ERROR" (case insensitive).
 | 
			
		||||
// Returns an error for unrecognized level strings.
 | 
			
		||||
func ParseLevel(level string) (slog.Level, error) {
 | 
			
		||||
	switch {
 | 
			
		||||
	case level == "":
 | 
			
		||||
		return slog.LevelInfo, nil
 | 
			
		||||
	case level == "DEBUG" || level == "debug":
 | 
			
		||||
		return slog.LevelDebug, nil
 | 
			
		||||
	case level == "INFO" || level == "info":
 | 
			
		||||
		return slog.LevelInfo, nil
 | 
			
		||||
	case level == "WARN" || level == "warn":
 | 
			
		||||
		return slog.LevelWarn, nil
 | 
			
		||||
	case level == "ERROR" || level == "error":
 | 
			
		||||
		return slog.LevelError, nil
 | 
			
		||||
	default:
 | 
			
		||||
		return slog.LevelInfo, fmt.Errorf("unknown log level: %s", level)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func setupStdErrHandler() slog.Handler {
 | 
			
		||||
	programLevel := new(slog.LevelVar) // Info by default
 | 
			
		||||
 | 
			
		||||
	envVar := "DEBUG"
 | 
			
		||||
	// Parse LOG_LEVEL environment variable
 | 
			
		||||
	logLevelVar := "LOG_LEVEL"
 | 
			
		||||
	if len(ConfigPrefix) > 0 {
 | 
			
		||||
		envVar = ConfigPrefix + "_" + envVar
 | 
			
		||||
		logLevelVar = ConfigPrefix + "_" + logLevelVar
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if opt := os.Getenv(envVar); len(opt) > 0 {
 | 
			
		||||
	if levelStr := os.Getenv(logLevelVar); levelStr != "" {
 | 
			
		||||
		if level, err := ParseLevel(levelStr); err == nil {
 | 
			
		||||
			Level.Set(level)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Maintain backward compatibility with DEBUG environment variable
 | 
			
		||||
	debugVar := "DEBUG"
 | 
			
		||||
	if len(ConfigPrefix) > 0 {
 | 
			
		||||
		debugVar = ConfigPrefix + "_" + debugVar
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if opt := os.Getenv(debugVar); len(opt) > 0 {
 | 
			
		||||
		if debug, _ := strconv.ParseBool(opt); debug {
 | 
			
		||||
			programLevel.Set(slog.LevelDebug)
 | 
			
		||||
			Level.Set(slog.LevelDebug)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	logOptions := &slog.HandlerOptions{Level: programLevel}
 | 
			
		||||
	logOptions := &slog.HandlerOptions{Level: Level}
 | 
			
		||||
 | 
			
		||||
	if len(os.Getenv("INVOCATION_ID")) > 0 {
 | 
			
		||||
		// don't add timestamps when running under systemd
 | 
			
		||||
@@ -60,17 +144,54 @@ func setupStdErrHandler() slog.Handler {
 | 
			
		||||
 | 
			
		||||
func setupOtlpLogger() *slog.Logger {
 | 
			
		||||
	setupOtlp.Do(func() {
 | 
			
		||||
		otlpLogger = slog.New(
 | 
			
		||||
			newLogFmtHandler(otelslog.NewHandler("common")),
 | 
			
		||||
		// Parse OTLP_LOG_LEVEL environment variable
 | 
			
		||||
		otlpLevelVar := "OTLP_LOG_LEVEL"
 | 
			
		||||
		if len(ConfigPrefix) > 0 {
 | 
			
		||||
			otlpLevelVar = ConfigPrefix + "_" + otlpLevelVar
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if levelStr := os.Getenv(otlpLevelVar); levelStr != "" {
 | 
			
		||||
			if level, err := ParseLevel(levelStr); err == nil {
 | 
			
		||||
				OTLPLevel.Set(level)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// Create our buffering exporter
 | 
			
		||||
		// It will buffer until tracing is configured
 | 
			
		||||
		bufferingExp := newBufferingExporter()
 | 
			
		||||
 | 
			
		||||
		// Use BatchProcessor with our custom exporter
 | 
			
		||||
		processor := otellog.NewBatchProcessor(bufferingExp,
 | 
			
		||||
			otellog.WithExportInterval(10*time.Second),
 | 
			
		||||
			otellog.WithMaxQueueSize(2048),
 | 
			
		||||
			otellog.WithExportMaxBatchSize(512),
 | 
			
		||||
		)
 | 
			
		||||
 | 
			
		||||
		// Create logger provider
 | 
			
		||||
		provider := otellog.NewLoggerProvider(
 | 
			
		||||
			otellog.WithProcessor(processor),
 | 
			
		||||
		)
 | 
			
		||||
 | 
			
		||||
		// Set global provider
 | 
			
		||||
		global.SetLoggerProvider(provider)
 | 
			
		||||
 | 
			
		||||
		// Create slog handler with level control
 | 
			
		||||
		baseHandler := newLogFmtHandler(otelslog.NewHandler("common"))
 | 
			
		||||
		handler := newOTLPLevelHandler(baseHandler)
 | 
			
		||||
		otlpLogger = slog.New(handler)
 | 
			
		||||
	})
 | 
			
		||||
	return otlpLogger
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// SetupMultiLogger will setup and make default a logger that
 | 
			
		||||
// logs as described in Setup() as well as an OLTP logger.
 | 
			
		||||
// The "multi logger" is made the default the first time
 | 
			
		||||
// this function is called
 | 
			
		||||
// SetupMultiLogger creates a logger that outputs to both text (stderr) and OTLP simultaneously.
 | 
			
		||||
// This is useful for services that need both human-readable logs and structured observability data.
 | 
			
		||||
//
 | 
			
		||||
// The multi-logger combines:
 | 
			
		||||
//   - Text handler: Stderr output with OpenTelemetry trace integration
 | 
			
		||||
//   - OTLP handler: Structured logs sent via OpenTelemetry Protocol
 | 
			
		||||
//
 | 
			
		||||
// On first call, this logger becomes the default logger returned by Setup().
 | 
			
		||||
// The function is thread-safe and uses sync.Once to ensure single initialization.
 | 
			
		||||
func SetupMultiLogger() *slog.Logger {
 | 
			
		||||
	setupMulti.Do(func() {
 | 
			
		||||
		textHandler := Setup().Handler()
 | 
			
		||||
@@ -89,28 +210,38 @@ func SetupMultiLogger() *slog.Logger {
 | 
			
		||||
	return multiLogger
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// SetupOLTP configures and returns a logger sending logs
 | 
			
		||||
// via OpenTelemetry (configured via the tracing package).
 | 
			
		||||
// SetupOLTP creates a logger that sends structured logs via OpenTelemetry Protocol.
 | 
			
		||||
// This logger is designed for observability pipelines and log aggregation systems.
 | 
			
		||||
//
 | 
			
		||||
// This was made to work with Loki + Grafana that makes it
 | 
			
		||||
// hard to view the log attributes in the UI, so the log
 | 
			
		||||
// message is formatted similarly to the text logger. The
 | 
			
		||||
// attributes are duplicated as OLTP attributes in the
 | 
			
		||||
// log messages. https://github.com/grafana/loki/issues/14788
 | 
			
		||||
// The OTLP logger formats log messages similarly to the text logger for better
 | 
			
		||||
// compatibility with Loki + Grafana, while still providing structured attributes.
 | 
			
		||||
// Log attributes are available both in the message format and as OTLP attributes.
 | 
			
		||||
//
 | 
			
		||||
// This logger does not become the default logger and must be used explicitly.
 | 
			
		||||
// It requires OpenTelemetry tracing configuration to be set up via the tracing package.
 | 
			
		||||
//
 | 
			
		||||
// See: https://github.com/grafana/loki/issues/14788 for formatting rationale.
 | 
			
		||||
func SetupOLTP() *slog.Logger {
 | 
			
		||||
	return setupOtlpLogger()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Setup returns an slog.Logger configured for text formatting
 | 
			
		||||
// to stderr.
 | 
			
		||||
// OpenTelemetry trace_id and span_id's are logged as attributes
 | 
			
		||||
// when available.
 | 
			
		||||
// When the application is running under systemd timestamps are
 | 
			
		||||
// omitted. On first call the slog default logger is set to this
 | 
			
		||||
// logger as well.
 | 
			
		||||
// Setup creates and returns the standard text logger for the application.
 | 
			
		||||
// This is the primary logging function that most applications should use.
 | 
			
		||||
//
 | 
			
		||||
// If SetupMultiLogger has been called Setup() will return
 | 
			
		||||
// the "multi logger"
 | 
			
		||||
// Features:
 | 
			
		||||
//   - Text formatting to stderr with human-readable output
 | 
			
		||||
//   - Automatic OpenTelemetry trace_id and span_id inclusion when available
 | 
			
		||||
//   - Systemd compatibility: omits timestamps when INVOCATION_ID environment variable is present
 | 
			
		||||
//   - Debug level support via DEBUG environment variable (respects ConfigPrefix)
 | 
			
		||||
//   - Thread-safe initialization with sync.Once
 | 
			
		||||
//
 | 
			
		||||
// On first call, this logger becomes the slog default logger. If SetupMultiLogger()
 | 
			
		||||
// has been called previously, Setup() returns the multi-logger instead of the text logger.
 | 
			
		||||
//
 | 
			
		||||
// The logger automatically detects execution context:
 | 
			
		||||
//   - Systemd: Removes timestamps (systemd adds its own)
 | 
			
		||||
//   - Debug mode: Enables debug level logging based on environment variables
 | 
			
		||||
//   - OpenTelemetry: Includes trace correlation when tracing is active
 | 
			
		||||
func Setup() *slog.Logger {
 | 
			
		||||
	setupText.Do(func() {
 | 
			
		||||
		h := setupStdErrHandler()
 | 
			
		||||
@@ -129,15 +260,33 @@ func Setup() *slog.Logger {
 | 
			
		||||
 | 
			
		||||
type loggerKey struct{}
 | 
			
		||||
 | 
			
		||||
// NewContext adds the logger to the context. Use this
 | 
			
		||||
// to for example make a request specific logger available
 | 
			
		||||
// to other functions through the context
 | 
			
		||||
// NewContext stores a logger in the context for request-scoped logging.
 | 
			
		||||
// This enables passing request-specific loggers (e.g., with request IDs,
 | 
			
		||||
// user context, or other correlation data) through the call stack.
 | 
			
		||||
//
 | 
			
		||||
// Use this to create context-aware logging where different parts of the
 | 
			
		||||
// application can access the same enriched logger instance.
 | 
			
		||||
//
 | 
			
		||||
// Example:
 | 
			
		||||
//
 | 
			
		||||
//	logger := slog.With("request_id", requestID)
 | 
			
		||||
//	ctx := logger.NewContext(ctx, logger)
 | 
			
		||||
//	// Pass ctx to downstream functions
 | 
			
		||||
func NewContext(ctx context.Context, l *slog.Logger) context.Context {
 | 
			
		||||
	return context.WithValue(ctx, loggerKey{}, l)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// FromContext retrieves a logger from the context. If there is none,
 | 
			
		||||
// it returns the default logger
 | 
			
		||||
// FromContext retrieves a logger from the context.
 | 
			
		||||
// If no logger is stored in the context, it returns the default logger from Setup().
 | 
			
		||||
//
 | 
			
		||||
// This function provides a safe way to access context-scoped loggers without
 | 
			
		||||
// needing to check for nil values. It ensures that logging is always available,
 | 
			
		||||
// falling back to the application's standard logger configuration.
 | 
			
		||||
//
 | 
			
		||||
// Example:
 | 
			
		||||
//
 | 
			
		||||
//	log := logger.FromContext(ctx)
 | 
			
		||||
//	log.Info("processing request") // Uses context logger or default
 | 
			
		||||
func FromContext(ctx context.Context) *slog.Logger {
 | 
			
		||||
	if l, ok := ctx.Value(loggerKey{}).(*slog.Logger); ok {
 | 
			
		||||
		return l
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										48
									
								
								logger/otlp_handler.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										48
									
								
								logger/otlp_handler.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,48 @@
 | 
			
		||||
package logger
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"log/slog"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// otlpLevelHandler is a wrapper that enforces level checking for OTLP handlers.
 | 
			
		||||
// This allows independent level control for OTLP output separate from stderr logging.
 | 
			
		||||
type otlpLevelHandler struct {
 | 
			
		||||
	next slog.Handler
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// newOTLPLevelHandler creates a new OTLP level wrapper handler.
 | 
			
		||||
func newOTLPLevelHandler(next slog.Handler) slog.Handler {
 | 
			
		||||
	return &otlpLevelHandler{
 | 
			
		||||
		next: next,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Enabled checks if the log level should be processed by the OTLP handler.
 | 
			
		||||
// It uses the OTLPLevel variable to determine if the record should be processed.
 | 
			
		||||
func (h *otlpLevelHandler) Enabled(ctx context.Context, level slog.Level) bool {
 | 
			
		||||
	return level >= OTLPLevel.Level()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Handle processes the log record if the level is enabled.
 | 
			
		||||
// If disabled by level checking, the record is silently dropped.
 | 
			
		||||
func (h *otlpLevelHandler) Handle(ctx context.Context, r slog.Record) error {
 | 
			
		||||
	if !h.Enabled(ctx, r.Level) {
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
	return h.next.Handle(ctx, r)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// WithAttrs returns a new handler with the specified attributes added.
 | 
			
		||||
func (h *otlpLevelHandler) WithAttrs(attrs []slog.Attr) slog.Handler {
 | 
			
		||||
	return &otlpLevelHandler{
 | 
			
		||||
		next: h.next.WithAttrs(attrs),
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// WithGroup returns a new handler with the specified group name.
 | 
			
		||||
func (h *otlpLevelHandler) WithGroup(name string) slog.Handler {
 | 
			
		||||
	return &otlpLevelHandler{
 | 
			
		||||
		next: h.next.WithGroup(name),
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
@@ -5,12 +5,24 @@ import (
 | 
			
		||||
	"log/slog"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// stdLoggerish provides a bridge between legacy log interfaces and slog.
 | 
			
		||||
// It implements common logging methods (Println, Printf, Fatalf) that
 | 
			
		||||
// delegate to structured logging with a consistent key prefix.
 | 
			
		||||
type stdLoggerish struct {
 | 
			
		||||
	key string
 | 
			
		||||
	log *slog.Logger
 | 
			
		||||
	f   func(string, ...any)
 | 
			
		||||
	key string               // Prefix key for all log messages
 | 
			
		||||
	log *slog.Logger         // Underlying structured logger
 | 
			
		||||
	f   func(string, ...any) // Log function (Info or Debug level)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// NewStdLog creates a legacy-compatible logger that bridges to structured logging.
 | 
			
		||||
// This is useful for third-party libraries that expect a standard log.Logger interface.
 | 
			
		||||
//
 | 
			
		||||
// Parameters:
 | 
			
		||||
//   - key: Prefix added to all log messages for identification
 | 
			
		||||
//   - debug: If true, logs at debug level; otherwise logs at info level
 | 
			
		||||
//   - log: Underlying slog.Logger (uses Setup() if nil)
 | 
			
		||||
//
 | 
			
		||||
// The returned logger implements Println, Printf, and Fatalf methods.
 | 
			
		||||
func NewStdLog(key string, debug bool, log *slog.Logger) *stdLoggerish {
 | 
			
		||||
	if log == nil {
 | 
			
		||||
		log = Setup()
 | 
			
		||||
@@ -27,14 +39,18 @@ func NewStdLog(key string, debug bool, log *slog.Logger) *stdLoggerish {
 | 
			
		||||
	return sl
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Println logs the arguments using the configured log level with the instance key.
 | 
			
		||||
func (l stdLoggerish) Println(msg ...any) {
 | 
			
		||||
	l.f(l.key, "msg", msg)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Printf logs a formatted message using the configured log level with the instance key.
 | 
			
		||||
func (l stdLoggerish) Printf(msg string, args ...any) {
 | 
			
		||||
	l.f(l.key, "msg", fmt.Sprintf(msg, args...))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Fatalf logs a formatted error message and panics.
 | 
			
		||||
// Note: This implementation panics instead of calling os.Exit for testability.
 | 
			
		||||
func (l stdLoggerish) Fatalf(msg string, args ...any) {
 | 
			
		||||
	l.log.Error(l.key, "msg", fmt.Sprintf(msg, args...))
 | 
			
		||||
	panic("fatal error") // todo: does this make sense at all?
 | 
			
		||||
 
 | 
			
		||||
@@ -1,17 +0,0 @@
 | 
			
		||||
package logger
 | 
			
		||||
 | 
			
		||||
type Error struct {
 | 
			
		||||
	Msg  string
 | 
			
		||||
	Data []any
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewError(msg string, data ...any) *Error {
 | 
			
		||||
	return &Error{
 | 
			
		||||
		Msg:  msg,
 | 
			
		||||
		Data: data,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (e *Error) Error() string {
 | 
			
		||||
	return "not implemented"
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										122
									
								
								metrics/metrics.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										122
									
								
								metrics/metrics.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,122 @@
 | 
			
		||||
// Package metrics provides OpenTelemetry-native metrics with OTLP export support.
 | 
			
		||||
//
 | 
			
		||||
// This package implements a metrics system using the OpenTelemetry metrics data model
 | 
			
		||||
// with OTLP export capabilities. It's designed for new applications that want to use
 | 
			
		||||
// structured metrics export to observability backends.
 | 
			
		||||
//
 | 
			
		||||
// Key features:
 | 
			
		||||
//   - OpenTelemetry native metric types (Counter, Histogram, Gauge, etc.)
 | 
			
		||||
//   - OTLP export for sending metrics to observability backends
 | 
			
		||||
//   - Resource detection and correlation with traces/logs
 | 
			
		||||
//   - Graceful handling when OTLP configuration is not available
 | 
			
		||||
//
 | 
			
		||||
// Example usage:
 | 
			
		||||
//
 | 
			
		||||
//	// Initialize metrics along with tracing
 | 
			
		||||
//	shutdown, err := tracing.InitTracer(ctx, cfg)
 | 
			
		||||
//	if err != nil {
 | 
			
		||||
//		log.Fatal(err)
 | 
			
		||||
//	}
 | 
			
		||||
//	defer shutdown(ctx)
 | 
			
		||||
//
 | 
			
		||||
//	// Get a meter and create instruments
 | 
			
		||||
//	meter := metrics.GetMeter("my-service")
 | 
			
		||||
//	counter, _ := meter.Int64Counter("requests_total")
 | 
			
		||||
//	counter.Add(ctx, 1, metric.WithAttributes(attribute.String("method", "GET")))
 | 
			
		||||
package metrics
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"log/slog"
 | 
			
		||||
	"sync"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"go.ntppool.org/common/internal/tracerconfig"
 | 
			
		||||
	"go.opentelemetry.io/otel"
 | 
			
		||||
	"go.opentelemetry.io/otel/metric"
 | 
			
		||||
	sdkmetric "go.opentelemetry.io/otel/sdk/metric"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var (
 | 
			
		||||
	meterProvider metric.MeterProvider
 | 
			
		||||
	setupOnce     sync.Once
 | 
			
		||||
	setupErr      error
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// Setup initializes the OpenTelemetry metrics provider with OTLP export.
 | 
			
		||||
// This function uses the configuration stored by the tracing package and
 | 
			
		||||
// creates a metrics provider that exports to the same OTLP endpoint.
 | 
			
		||||
//
 | 
			
		||||
// The function is safe to call multiple times - it will only initialize once.
 | 
			
		||||
// If tracing configuration is not available, it returns a no-op provider that
 | 
			
		||||
// doesn't export metrics.
 | 
			
		||||
//
 | 
			
		||||
// Returns an error only if there's a configuration problem. Missing tracing
 | 
			
		||||
// configuration is handled gracefully with a warning log.
 | 
			
		||||
func Setup(ctx context.Context) error {
 | 
			
		||||
	setupOnce.Do(func() {
 | 
			
		||||
		setupErr = initializeMetrics(ctx)
 | 
			
		||||
	})
 | 
			
		||||
	return setupErr
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// GetMeter returns a named meter for creating metric instruments.
 | 
			
		||||
// The meter uses the configured metrics provider, or the global provider
 | 
			
		||||
// if metrics haven't been set up yet.
 | 
			
		||||
//
 | 
			
		||||
// This is the primary entry point for creating metric instruments in your application.
 | 
			
		||||
func GetMeter(name string, opts ...metric.MeterOption) metric.Meter {
 | 
			
		||||
	if meterProvider == nil {
 | 
			
		||||
		// Return the global provider as fallback (no-op if not configured)
 | 
			
		||||
		return otel.GetMeterProvider().Meter(name, opts...)
 | 
			
		||||
	}
 | 
			
		||||
	return meterProvider.Meter(name, opts...)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// initializeMetrics sets up the OpenTelemetry metrics provider with OTLP export.
 | 
			
		||||
func initializeMetrics(ctx context.Context) error {
 | 
			
		||||
	log := slog.Default()
 | 
			
		||||
 | 
			
		||||
	// Check if tracing configuration is available
 | 
			
		||||
	cfg, configCtx, factory := tracerconfig.GetMetricExporter()
 | 
			
		||||
	if cfg == nil || configCtx == nil || factory == nil {
 | 
			
		||||
		log.Warn("metrics setup: tracing configuration not available, using no-op provider")
 | 
			
		||||
		// Set the global provider as fallback - metrics just won't be exported
 | 
			
		||||
		meterProvider = otel.GetMeterProvider()
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Create OTLP metrics exporter
 | 
			
		||||
	exporter, err := factory(ctx, cfg)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		log.Error("metrics setup: failed to create OTLP exporter", "error", err)
 | 
			
		||||
		// Fall back to global provider
 | 
			
		||||
		meterProvider = otel.GetMeterProvider()
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Create metrics provider with the exporter
 | 
			
		||||
	provider := sdkmetric.NewMeterProvider(
 | 
			
		||||
		sdkmetric.WithReader(sdkmetric.NewPeriodicReader(
 | 
			
		||||
			exporter,
 | 
			
		||||
			sdkmetric.WithInterval(15*time.Second),
 | 
			
		||||
		)),
 | 
			
		||||
	)
 | 
			
		||||
 | 
			
		||||
	// Set the global provider
 | 
			
		||||
	otel.SetMeterProvider(provider)
 | 
			
		||||
	meterProvider = provider
 | 
			
		||||
 | 
			
		||||
	log.Info("metrics setup: OTLP metrics provider initialized")
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Shutdown gracefully shuts down the metrics provider.
 | 
			
		||||
// This should be called during application shutdown to ensure all metrics
 | 
			
		||||
// are properly flushed and exported.
 | 
			
		||||
func Shutdown(ctx context.Context) error {
 | 
			
		||||
	if provider, ok := meterProvider.(*sdkmetric.MeterProvider); ok {
 | 
			
		||||
		return provider.Shutdown(ctx)
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										296
									
								
								metrics/metrics_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										296
									
								
								metrics/metrics_test.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,296 @@
 | 
			
		||||
package metrics
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"os"
 | 
			
		||||
	"testing"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"go.ntppool.org/common/internal/tracerconfig"
 | 
			
		||||
	"go.opentelemetry.io/otel/attribute"
 | 
			
		||||
	"go.opentelemetry.io/otel/metric"
 | 
			
		||||
	sdkmetric "go.opentelemetry.io/otel/sdk/metric"
 | 
			
		||||
	"go.opentelemetry.io/otel/sdk/metric/metricdata"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func TestSetup_NoConfiguration(t *testing.T) {
 | 
			
		||||
	// Clear any existing configuration
 | 
			
		||||
	tracerconfig.Clear()
 | 
			
		||||
 | 
			
		||||
	ctx := context.Background()
 | 
			
		||||
	err := Setup(ctx)
 | 
			
		||||
	// Should not return an error even when no configuration is available
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Errorf("Setup() returned unexpected error: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Should be able to get a meter (even if it's a no-op)
 | 
			
		||||
	meter := GetMeter("test-meter")
 | 
			
		||||
	if meter == nil {
 | 
			
		||||
		t.Error("GetMeter() returned nil")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestGetMeter(t *testing.T) {
 | 
			
		||||
	// Clear any existing configuration
 | 
			
		||||
	tracerconfig.Clear()
 | 
			
		||||
 | 
			
		||||
	ctx := context.Background()
 | 
			
		||||
	_ = Setup(ctx)
 | 
			
		||||
 | 
			
		||||
	meter := GetMeter("test-service")
 | 
			
		||||
	if meter == nil {
 | 
			
		||||
		t.Fatal("GetMeter() returned nil")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Test creating a counter instrument
 | 
			
		||||
	counter, err := meter.Int64Counter("test_counter")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Errorf("Failed to create counter: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Test using the counter (should not error even with no-op provider)
 | 
			
		||||
	counter.Add(ctx, 1, metric.WithAttributes(attribute.String("test", "value")))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestSetup_MultipleCallsSafe(t *testing.T) {
 | 
			
		||||
	// Clear any existing configuration
 | 
			
		||||
	tracerconfig.Clear()
 | 
			
		||||
 | 
			
		||||
	ctx := context.Background()
 | 
			
		||||
 | 
			
		||||
	// Call Setup multiple times
 | 
			
		||||
	err1 := Setup(ctx)
 | 
			
		||||
	err2 := Setup(ctx)
 | 
			
		||||
	err3 := Setup(ctx)
 | 
			
		||||
 | 
			
		||||
	if err1 != nil {
 | 
			
		||||
		t.Errorf("First Setup() call returned error: %v", err1)
 | 
			
		||||
	}
 | 
			
		||||
	if err2 != nil {
 | 
			
		||||
		t.Errorf("Second Setup() call returned error: %v", err2)
 | 
			
		||||
	}
 | 
			
		||||
	if err3 != nil {
 | 
			
		||||
		t.Errorf("Third Setup() call returned error: %v", err3)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Should still be able to get meters
 | 
			
		||||
	meter := GetMeter("test-meter")
 | 
			
		||||
	if meter == nil {
 | 
			
		||||
		t.Error("GetMeter() returned nil after multiple Setup() calls")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestSetup_WithConfiguration(t *testing.T) {
 | 
			
		||||
	// Clear any existing configuration
 | 
			
		||||
	tracerconfig.Clear()
 | 
			
		||||
 | 
			
		||||
	ctx := context.Background()
 | 
			
		||||
	config := &tracerconfig.Config{
 | 
			
		||||
		ServiceName: "test-metrics-service",
 | 
			
		||||
		Environment: "test",
 | 
			
		||||
		Endpoint:    "localhost:4317", // Will likely fail to connect, but should set up provider
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Create a mock exporter factory that returns a working exporter
 | 
			
		||||
	mockFactory := func(ctx context.Context, cfg *tracerconfig.Config) (sdkmetric.Exporter, error) {
 | 
			
		||||
		// Create a simple in-memory exporter for testing
 | 
			
		||||
		return &mockMetricExporter{}, nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Store configuration with mock factory
 | 
			
		||||
	tracerconfig.Store(ctx, config, nil, mockFactory, nil)
 | 
			
		||||
 | 
			
		||||
	// Setup metrics
 | 
			
		||||
	err := Setup(ctx)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Errorf("Setup() returned error: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Should be able to get a meter
 | 
			
		||||
	meter := GetMeter("test-service")
 | 
			
		||||
	if meter == nil {
 | 
			
		||||
		t.Fatal("GetMeter() returned nil")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Test creating and using instruments
 | 
			
		||||
	counter, err := meter.Int64Counter("test_counter")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Errorf("Failed to create counter: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	histogram, err := meter.Float64Histogram("test_histogram")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Errorf("Failed to create histogram: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	gauge, err := meter.Int64UpDownCounter("test_gauge")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Errorf("Failed to create gauge: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Use the instruments
 | 
			
		||||
	counter.Add(ctx, 1, metric.WithAttributes(attribute.String("test", "value")))
 | 
			
		||||
	histogram.Record(ctx, 1.5, metric.WithAttributes(attribute.String("test", "value")))
 | 
			
		||||
	gauge.Add(ctx, 10, metric.WithAttributes(attribute.String("test", "value")))
 | 
			
		||||
 | 
			
		||||
	// Test shutdown
 | 
			
		||||
	err = Shutdown(ctx)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Errorf("Shutdown() returned error: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestSetup_WithRealOTLPConfig(t *testing.T) {
 | 
			
		||||
	// Skip this test in short mode since it may try to make network connections
 | 
			
		||||
	if testing.Short() {
 | 
			
		||||
		t.Skip("Skipping integration test in short mode")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Clear any existing configuration
 | 
			
		||||
	tracerconfig.Clear()
 | 
			
		||||
 | 
			
		||||
	// Set environment variables for OTLP configuration
 | 
			
		||||
	originalEndpoint := os.Getenv("OTEL_EXPORTER_OTLP_ENDPOINT")
 | 
			
		||||
	originalProtocol := os.Getenv("OTEL_EXPORTER_OTLP_PROTOCOL")
 | 
			
		||||
 | 
			
		||||
	defer func() {
 | 
			
		||||
		if originalEndpoint != "" {
 | 
			
		||||
			os.Setenv("OTEL_EXPORTER_OTLP_ENDPOINT", originalEndpoint)
 | 
			
		||||
		} else {
 | 
			
		||||
			os.Unsetenv("OTEL_EXPORTER_OTLP_ENDPOINT")
 | 
			
		||||
		}
 | 
			
		||||
		if originalProtocol != "" {
 | 
			
		||||
			os.Setenv("OTEL_EXPORTER_OTLP_PROTOCOL", originalProtocol)
 | 
			
		||||
		} else {
 | 
			
		||||
			os.Unsetenv("OTEL_EXPORTER_OTLP_PROTOCOL")
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	os.Setenv("OTEL_EXPORTER_OTLP_ENDPOINT", "http://localhost:4318") // HTTP endpoint
 | 
			
		||||
	os.Setenv("OTEL_EXPORTER_OTLP_PROTOCOL", "http/protobuf")
 | 
			
		||||
 | 
			
		||||
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
 | 
			
		||||
	defer cancel()
 | 
			
		||||
 | 
			
		||||
	config := &tracerconfig.Config{
 | 
			
		||||
		ServiceName: "test-metrics-e2e",
 | 
			
		||||
		Environment: "test",
 | 
			
		||||
		Endpoint:    "localhost:4318",
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Store configuration with real factory
 | 
			
		||||
	tracerconfig.Store(ctx, config, nil, tracerconfig.CreateOTLPMetricExporter, nil)
 | 
			
		||||
 | 
			
		||||
	// Setup metrics - this may fail if no OTLP collector is running, which is okay
 | 
			
		||||
	err := Setup(ctx)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Logf("Setup() returned error (expected if no OTLP collector): %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Should still be able to get a meter
 | 
			
		||||
	meter := GetMeter("test-service-e2e")
 | 
			
		||||
	if meter == nil {
 | 
			
		||||
		t.Fatal("GetMeter() returned nil")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Create and use instruments
 | 
			
		||||
	counter, err := meter.Int64Counter("e2e_test_counter")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Errorf("Failed to create counter: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Add some metrics
 | 
			
		||||
	for i := 0; i < 5; i++ {
 | 
			
		||||
		counter.Add(ctx, 1, metric.WithAttributes(
 | 
			
		||||
			attribute.String("iteration", string(rune('0'+i))),
 | 
			
		||||
			attribute.String("test_type", "e2e"),
 | 
			
		||||
		))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Give some time for export (if collector is running)
 | 
			
		||||
	time.Sleep(100 * time.Millisecond)
 | 
			
		||||
 | 
			
		||||
	// Test shutdown
 | 
			
		||||
	err = Shutdown(ctx)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Logf("Shutdown() returned error (may be expected): %v", err)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestConcurrentMetricUsage(t *testing.T) {
 | 
			
		||||
	// Clear any existing configuration
 | 
			
		||||
	tracerconfig.Clear()
 | 
			
		||||
 | 
			
		||||
	ctx := context.Background()
 | 
			
		||||
	config := &tracerconfig.Config{
 | 
			
		||||
		ServiceName: "concurrent-test",
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Use mock factory
 | 
			
		||||
	mockFactory := func(ctx context.Context, cfg *tracerconfig.Config) (sdkmetric.Exporter, error) {
 | 
			
		||||
		return &mockMetricExporter{}, nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	tracerconfig.Store(ctx, config, nil, mockFactory, nil)
 | 
			
		||||
	Setup(ctx)
 | 
			
		||||
 | 
			
		||||
	meter := GetMeter("concurrent-test")
 | 
			
		||||
	counter, err := meter.Int64Counter("concurrent_counter")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatalf("Failed to create counter: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Test concurrent metric usage
 | 
			
		||||
	const numGoroutines = 10
 | 
			
		||||
	const metricsPerGoroutine = 100
 | 
			
		||||
 | 
			
		||||
	done := make(chan bool, numGoroutines)
 | 
			
		||||
 | 
			
		||||
	for i := 0; i < numGoroutines; i++ {
 | 
			
		||||
		go func(goroutineID int) {
 | 
			
		||||
			for j := 0; j < metricsPerGoroutine; j++ {
 | 
			
		||||
				counter.Add(ctx, 1, metric.WithAttributes(
 | 
			
		||||
					attribute.Int("goroutine", goroutineID),
 | 
			
		||||
					attribute.Int("iteration", j),
 | 
			
		||||
				))
 | 
			
		||||
			}
 | 
			
		||||
			done <- true
 | 
			
		||||
		}(i)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Wait for all goroutines to complete
 | 
			
		||||
	for i := 0; i < numGoroutines; i++ {
 | 
			
		||||
		<-done
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Shutdown
 | 
			
		||||
	err = Shutdown(ctx)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Errorf("Shutdown() returned error: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// mockMetricExporter is a simple mock exporter for testing
 | 
			
		||||
type mockMetricExporter struct{}
 | 
			
		||||
 | 
			
		||||
func (m *mockMetricExporter) Export(ctx context.Context, rm *metricdata.ResourceMetrics) error {
 | 
			
		||||
	// Just pretend to export
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (m *mockMetricExporter) ForceFlush(ctx context.Context) error {
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (m *mockMetricExporter) Shutdown(ctx context.Context) error {
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (m *mockMetricExporter) Temporality(kind sdkmetric.InstrumentKind) metricdata.Temporality {
 | 
			
		||||
	return metricdata.CumulativeTemporality
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (m *mockMetricExporter) Aggregation(kind sdkmetric.InstrumentKind) sdkmetric.Aggregation {
 | 
			
		||||
	return sdkmetric.DefaultAggregationSelector(kind)
 | 
			
		||||
}
 | 
			
		||||
@@ -1,3 +1,26 @@
 | 
			
		||||
// Package metricsserver provides a standalone HTTP server for exposing Prometheus metrics.
 | 
			
		||||
//
 | 
			
		||||
// This package implements a dedicated metrics server that exposes application metrics
 | 
			
		||||
// via HTTP. It uses a custom Prometheus registry to avoid conflicts with other metric
 | 
			
		||||
// collectors and provides graceful shutdown capabilities.
 | 
			
		||||
//
 | 
			
		||||
// # Usage
 | 
			
		||||
//
 | 
			
		||||
// Create a new metrics server and register your metrics with its Registry():
 | 
			
		||||
//
 | 
			
		||||
//	m := metricsserver.New()
 | 
			
		||||
//	myCounter := prometheus.NewCounter(...)
 | 
			
		||||
//	m.Registry().MustRegister(myCounter)
 | 
			
		||||
//
 | 
			
		||||
// When you need a Gatherer (for example, to pass to other libraries), use the Gatherer() method
 | 
			
		||||
// instead of prometheus.DefaultGatherer to ensure your custom metrics are collected:
 | 
			
		||||
//
 | 
			
		||||
//	gatherer := m.Gatherer()  // Returns the custom registry as a Gatherer
 | 
			
		||||
//
 | 
			
		||||
// To use the default Prometheus gatherer alongside your custom registry:
 | 
			
		||||
//
 | 
			
		||||
//	m := metricsserver.NewWithDefaultGatherer()
 | 
			
		||||
//	m.Gatherer()  // Returns prometheus.DefaultGatherer
 | 
			
		||||
package metricsserver
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
@@ -13,24 +36,64 @@ import (
 | 
			
		||||
	"go.ntppool.org/common/logger"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// Metrics provides a custom Prometheus registry and HTTP handlers for metrics exposure.
 | 
			
		||||
// It isolates application metrics from the default global registry.
 | 
			
		||||
type Metrics struct {
 | 
			
		||||
	r                  *prometheus.Registry
 | 
			
		||||
	useDefaultGatherer bool
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// New creates a new Metrics instance with a custom Prometheus registry.
 | 
			
		||||
// Use this when you want isolated metrics that don't interfere with the default registry.
 | 
			
		||||
func New() *Metrics {
 | 
			
		||||
	r := prometheus.NewRegistry()
 | 
			
		||||
 | 
			
		||||
	m := &Metrics{
 | 
			
		||||
		r:                  r,
 | 
			
		||||
		useDefaultGatherer: false,
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return m
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// NewWithDefaultGatherer creates a new Metrics instance that uses the default Prometheus gatherer.
 | 
			
		||||
// This is useful when you want to expose metrics from the default registry alongside your custom metrics.
 | 
			
		||||
// The custom registry will still be available via Registry() but Gatherer() will return prometheus.DefaultGatherer.
 | 
			
		||||
func NewWithDefaultGatherer() *Metrics {
 | 
			
		||||
	r := prometheus.NewRegistry()
 | 
			
		||||
 | 
			
		||||
	m := &Metrics{
 | 
			
		||||
		r:                  r,
 | 
			
		||||
		useDefaultGatherer: true,
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return m
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Registry returns the custom Prometheus registry.
 | 
			
		||||
// Use this to register your application's metrics collectors.
 | 
			
		||||
func (m *Metrics) Registry() *prometheus.Registry {
 | 
			
		||||
	return m.r
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Gatherer returns the Prometheus gatherer to use for collecting metrics.
 | 
			
		||||
// This returns the custom registry's Gatherer by default, ensuring your registered
 | 
			
		||||
// metrics are collected. If the instance was created with NewWithDefaultGatherer(),
 | 
			
		||||
// this returns prometheus.DefaultGatherer instead.
 | 
			
		||||
//
 | 
			
		||||
// Use this method when you need a prometheus.Gatherer interface, for example when
 | 
			
		||||
// configuring other libraries that need to collect metrics.
 | 
			
		||||
//
 | 
			
		||||
// IMPORTANT: Do not use prometheus.DefaultGatherer directly if you want to collect
 | 
			
		||||
// metrics registered with this instance's Registry(). Always use this Gatherer() method.
 | 
			
		||||
func (m *Metrics) Gatherer() prometheus.Gatherer {
 | 
			
		||||
	if m.useDefaultGatherer {
 | 
			
		||||
		return prometheus.DefaultGatherer
 | 
			
		||||
	}
 | 
			
		||||
	return m.r
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Handler returns an HTTP handler for the /metrics endpoint with OpenMetrics support.
 | 
			
		||||
func (m *Metrics) Handler() http.Handler {
 | 
			
		||||
	log := logger.NewStdLog("prom http", false, nil)
 | 
			
		||||
 | 
			
		||||
@@ -41,9 +104,8 @@ func (m *Metrics) Handler() http.Handler {
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ListenAndServe starts a goroutine with a server running on
 | 
			
		||||
// the specified port. The server will shutdown and return when
 | 
			
		||||
// the provided context is done
 | 
			
		||||
// ListenAndServe starts a metrics server on the specified port and blocks until ctx is done.
 | 
			
		||||
// The server exposes the metrics handler and shuts down gracefully when the context is cancelled.
 | 
			
		||||
func (m *Metrics) ListenAndServe(ctx context.Context, port int) error {
 | 
			
		||||
	log := logger.Setup()
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										322
									
								
								metricsserver/metrics_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										322
									
								
								metricsserver/metrics_test.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,322 @@
 | 
			
		||||
package metricsserver
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"net/http/httptest"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"testing"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/prometheus/client_golang/prometheus"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func TestNew(t *testing.T) {
 | 
			
		||||
	metrics := New()
 | 
			
		||||
 | 
			
		||||
	if metrics == nil {
 | 
			
		||||
		t.Fatal("New returned nil")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if metrics.r == nil {
 | 
			
		||||
		t.Error("metrics registry is nil")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestRegistry(t *testing.T) {
 | 
			
		||||
	metrics := New()
 | 
			
		||||
	registry := metrics.Registry()
 | 
			
		||||
 | 
			
		||||
	if registry == nil {
 | 
			
		||||
		t.Fatal("Registry() returned nil")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if registry != metrics.r {
 | 
			
		||||
		t.Error("Registry() did not return the metrics registry")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Test that we can register a metric
 | 
			
		||||
	counter := prometheus.NewCounter(prometheus.CounterOpts{
 | 
			
		||||
		Name: "test_counter",
 | 
			
		||||
		Help: "A test counter",
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	err := registry.Register(counter)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Errorf("failed to register metric: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Test that the metric is registered
 | 
			
		||||
	metricFamilies, err := registry.Gather()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Errorf("failed to gather metrics: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	found := false
 | 
			
		||||
	for _, mf := range metricFamilies {
 | 
			
		||||
		if mf.GetName() == "test_counter" {
 | 
			
		||||
			found = true
 | 
			
		||||
			break
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if !found {
 | 
			
		||||
		t.Error("registered metric not found in registry")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestGatherer(t *testing.T) {
 | 
			
		||||
	metrics := New()
 | 
			
		||||
 | 
			
		||||
	gatherer := metrics.Gatherer()
 | 
			
		||||
	if gatherer == nil {
 | 
			
		||||
		t.Fatal("Gatherer() returned nil")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Register a test metric
 | 
			
		||||
	counter := prometheus.NewCounter(prometheus.CounterOpts{
 | 
			
		||||
		Name: "test_gatherer_counter",
 | 
			
		||||
		Help: "A test counter for gatherer",
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	metrics.Registry().MustRegister(counter)
 | 
			
		||||
	counter.Inc()
 | 
			
		||||
 | 
			
		||||
	// Test that the gatherer collects our custom metric
 | 
			
		||||
	metricFamilies, err := gatherer.Gather()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Errorf("failed to gather metrics: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	found := false
 | 
			
		||||
	for _, mf := range metricFamilies {
 | 
			
		||||
		if mf.GetName() == "test_gatherer_counter" {
 | 
			
		||||
			found = true
 | 
			
		||||
			break
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if !found {
 | 
			
		||||
		t.Error("registered metric not found via Gatherer()")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Verify gatherer is the same as registry
 | 
			
		||||
	if gatherer != metrics.r {
 | 
			
		||||
		t.Error("Gatherer() should return the same object as the registry for custom registry mode")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestHandler(t *testing.T) {
 | 
			
		||||
	metrics := New()
 | 
			
		||||
 | 
			
		||||
	// Register a test metric
 | 
			
		||||
	counter := prometheus.NewCounterVec(
 | 
			
		||||
		prometheus.CounterOpts{
 | 
			
		||||
			Name: "test_requests_total",
 | 
			
		||||
			Help: "Total number of test requests",
 | 
			
		||||
		},
 | 
			
		||||
		[]string{"method"},
 | 
			
		||||
	)
 | 
			
		||||
	metrics.Registry().MustRegister(counter)
 | 
			
		||||
	counter.WithLabelValues("GET").Inc()
 | 
			
		||||
 | 
			
		||||
	// Test the handler
 | 
			
		||||
	handler := metrics.Handler()
 | 
			
		||||
	if handler == nil {
 | 
			
		||||
		t.Fatal("Handler() returned nil")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Create a test request
 | 
			
		||||
	req := httptest.NewRequest("GET", "/metrics", nil)
 | 
			
		||||
	recorder := httptest.NewRecorder()
 | 
			
		||||
 | 
			
		||||
	// Call the handler
 | 
			
		||||
	handler.ServeHTTP(recorder, req)
 | 
			
		||||
 | 
			
		||||
	// Check response
 | 
			
		||||
	resp := recorder.Result()
 | 
			
		||||
	defer resp.Body.Close()
 | 
			
		||||
 | 
			
		||||
	if resp.StatusCode != http.StatusOK {
 | 
			
		||||
		t.Errorf("expected status 200, got %d", resp.StatusCode)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	body, err := io.ReadAll(resp.Body)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatalf("failed to read response body: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	bodyStr := string(body)
 | 
			
		||||
 | 
			
		||||
	// Check for our test metric
 | 
			
		||||
	if !strings.Contains(bodyStr, "test_requests_total") {
 | 
			
		||||
		t.Error("test metric not found in metrics output")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Check for OpenMetrics format indicators
 | 
			
		||||
	if !strings.Contains(bodyStr, "# TYPE") {
 | 
			
		||||
		t.Error("metrics output missing TYPE comments")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestListenAndServe(t *testing.T) {
 | 
			
		||||
	metrics := New()
 | 
			
		||||
 | 
			
		||||
	// Register a test metric
 | 
			
		||||
	counter := prometheus.NewCounterVec(
 | 
			
		||||
		prometheus.CounterOpts{
 | 
			
		||||
			Name: "test_requests_total",
 | 
			
		||||
			Help: "Total number of test requests",
 | 
			
		||||
		},
 | 
			
		||||
		[]string{"method"},
 | 
			
		||||
	)
 | 
			
		||||
	metrics.Registry().MustRegister(counter)
 | 
			
		||||
	counter.WithLabelValues("GET").Inc()
 | 
			
		||||
 | 
			
		||||
	ctx, cancel := context.WithCancel(context.Background())
 | 
			
		||||
	defer cancel()
 | 
			
		||||
 | 
			
		||||
	// Start server in a goroutine
 | 
			
		||||
	errCh := make(chan error, 1)
 | 
			
		||||
	go func() {
 | 
			
		||||
		// Use a high port number to avoid conflicts
 | 
			
		||||
		errCh <- metrics.ListenAndServe(ctx, 9999)
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	// Give the server a moment to start
 | 
			
		||||
	time.Sleep(100 * time.Millisecond)
 | 
			
		||||
 | 
			
		||||
	// Test metrics endpoint
 | 
			
		||||
	resp, err := http.Get("http://localhost:9999/metrics")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatalf("failed to GET /metrics: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
	defer resp.Body.Close()
 | 
			
		||||
 | 
			
		||||
	if resp.StatusCode != http.StatusOK {
 | 
			
		||||
		t.Errorf("expected status 200, got %d", resp.StatusCode)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	body, err := io.ReadAll(resp.Body)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatalf("failed to read response body: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	bodyStr := string(body)
 | 
			
		||||
 | 
			
		||||
	// Check for our test metric
 | 
			
		||||
	if !strings.Contains(bodyStr, "test_requests_total") {
 | 
			
		||||
		t.Error("test metric not found in metrics output")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Cancel context to stop server
 | 
			
		||||
	cancel()
 | 
			
		||||
 | 
			
		||||
	// Wait for server to stop
 | 
			
		||||
	select {
 | 
			
		||||
	case err := <-errCh:
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			t.Errorf("server returned error: %v", err)
 | 
			
		||||
		}
 | 
			
		||||
	case <-time.After(5 * time.Second):
 | 
			
		||||
		t.Error("server did not stop within timeout")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestListenAndServeContextCancellation(t *testing.T) {
 | 
			
		||||
	metrics := New()
 | 
			
		||||
 | 
			
		||||
	ctx, cancel := context.WithCancel(context.Background())
 | 
			
		||||
 | 
			
		||||
	// Start server
 | 
			
		||||
	errCh := make(chan error, 1)
 | 
			
		||||
	go func() {
 | 
			
		||||
		errCh <- metrics.ListenAndServe(ctx, 9998)
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	// Give server time to start
 | 
			
		||||
	time.Sleep(100 * time.Millisecond)
 | 
			
		||||
 | 
			
		||||
	// Cancel context
 | 
			
		||||
	cancel()
 | 
			
		||||
 | 
			
		||||
	// Server should stop gracefully
 | 
			
		||||
	select {
 | 
			
		||||
	case err := <-errCh:
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			t.Errorf("server returned error on graceful shutdown: %v", err)
 | 
			
		||||
		}
 | 
			
		||||
	case <-time.After(5 * time.Second):
 | 
			
		||||
		t.Error("server did not stop within timeout after context cancellation")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestNewWithDefaultGatherer(t *testing.T) {
 | 
			
		||||
	metrics := NewWithDefaultGatherer()
 | 
			
		||||
 | 
			
		||||
	if metrics == nil {
 | 
			
		||||
		t.Fatal("NewWithDefaultGatherer returned nil")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if !metrics.useDefaultGatherer {
 | 
			
		||||
		t.Error("useDefaultGatherer should be true")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	gatherer := metrics.Gatherer()
 | 
			
		||||
	if gatherer == nil {
 | 
			
		||||
		t.Fatal("Gatherer() returned nil")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Verify it returns the default gatherer
 | 
			
		||||
	if gatherer != prometheus.DefaultGatherer {
 | 
			
		||||
		t.Error("Gatherer() should return prometheus.DefaultGatherer when useDefaultGatherer is true")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Verify the custom registry is still available and separate
 | 
			
		||||
	if metrics.Registry() == nil {
 | 
			
		||||
		t.Error("Registry() should still return a custom registry")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Test that registering in custom registry doesn't affect default gatherer check
 | 
			
		||||
	counter := prometheus.NewCounter(prometheus.CounterOpts{
 | 
			
		||||
		Name: "test_default_gatherer_counter",
 | 
			
		||||
		Help: "A test counter",
 | 
			
		||||
	})
 | 
			
		||||
	metrics.Registry().MustRegister(counter)
 | 
			
		||||
 | 
			
		||||
	// The gatherer should still be the default one, not our custom registry
 | 
			
		||||
	if metrics.Gatherer() != prometheus.DefaultGatherer {
 | 
			
		||||
		t.Error("Gatherer() should continue to return prometheus.DefaultGatherer")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Benchmark the metrics handler response time
 | 
			
		||||
func BenchmarkMetricsHandler(b *testing.B) {
 | 
			
		||||
	metrics := New()
 | 
			
		||||
 | 
			
		||||
	// Register some test metrics
 | 
			
		||||
	for i := 0; i < 10; i++ {
 | 
			
		||||
		counter := prometheus.NewCounter(prometheus.CounterOpts{
 | 
			
		||||
			Name: fmt.Sprintf("bench_counter_%d", i),
 | 
			
		||||
			Help: "A benchmark counter",
 | 
			
		||||
		})
 | 
			
		||||
		metrics.Registry().MustRegister(counter)
 | 
			
		||||
		counter.Add(float64(i * 100))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	handler := metrics.Handler()
 | 
			
		||||
 | 
			
		||||
	b.ResetTimer()
 | 
			
		||||
 | 
			
		||||
	for i := 0; i < b.N; i++ {
 | 
			
		||||
		req := httptest.NewRequest("GET", "/metrics", nil)
 | 
			
		||||
		recorder := httptest.NewRecorder()
 | 
			
		||||
		handler.ServeHTTP(recorder, req)
 | 
			
		||||
 | 
			
		||||
		if recorder.Code != http.StatusOK {
 | 
			
		||||
			b.Fatalf("unexpected status code: %d", recorder.Code)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
@@ -15,7 +15,7 @@ mkdir -p $DIR
 | 
			
		||||
 | 
			
		||||
BASE=https://geodns.bitnames.com/${BASE}/builds/${BUILD}
 | 
			
		||||
 | 
			
		||||
files=`curl -sSf ${BASE}/checksums.txt | awk '{print $2}'`
 | 
			
		||||
files=`curl -sSf ${BASE}/checksums.txt | sed 's/^[a-f0-9]*[[:space:]]*//'`
 | 
			
		||||
metafiles="checksums.txt metadata.json CHANGELOG.md artifacts.json"
 | 
			
		||||
 | 
			
		||||
for f in $metafiles; do
 | 
			
		||||
 
 | 
			
		||||
@@ -2,7 +2,7 @@
 | 
			
		||||
 | 
			
		||||
set -euo pipefail
 | 
			
		||||
 | 
			
		||||
go install github.com/goreleaser/goreleaser/v2@v2.8.2
 | 
			
		||||
go install github.com/goreleaser/goreleaser/v2@v2.12.3
 | 
			
		||||
 | 
			
		||||
if [ ! -z "${harbor_username:-}" ]; then
 | 
			
		||||
  DOCKER_FILE=~/.docker/config.json
 | 
			
		||||
 
 | 
			
		||||
@@ -1,3 +1,4 @@
 | 
			
		||||
// Package timeutil provides JSON-serializable time utilities.
 | 
			
		||||
package timeutil
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
@@ -6,14 +7,37 @@ import (
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// Duration is a wrapper around time.Duration that supports JSON marshaling/unmarshaling.
 | 
			
		||||
//
 | 
			
		||||
// When marshaling to JSON, it outputs the duration as a string using time.Duration.String().
 | 
			
		||||
// When unmarshaling from JSON, it accepts both:
 | 
			
		||||
//   - String values that can be parsed by time.ParseDuration (e.g., "30s", "5m", "1h30m")
 | 
			
		||||
//   - Numeric values that represent nanoseconds as a float64
 | 
			
		||||
//
 | 
			
		||||
// This makes it compatible with configuration files and APIs that need to represent
 | 
			
		||||
// durations in a human-readable format.
 | 
			
		||||
//
 | 
			
		||||
// Example usage:
 | 
			
		||||
//
 | 
			
		||||
//	type Config struct {
 | 
			
		||||
//		Timeout timeutil.Duration `json:"timeout"`
 | 
			
		||||
//	}
 | 
			
		||||
//
 | 
			
		||||
//	// JSON: {"timeout": "30s"}
 | 
			
		||||
//	// or:   {"timeout": 30000000000}
 | 
			
		||||
type Duration struct {
 | 
			
		||||
	time.Duration
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// MarshalJSON implements json.Marshaler.
 | 
			
		||||
// It marshals the duration as a string using time.Duration.String().
 | 
			
		||||
func (d Duration) MarshalJSON() ([]byte, error) {
 | 
			
		||||
	return json.Marshal(time.Duration(d.Duration).String())
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// UnmarshalJSON implements json.Unmarshaler.
 | 
			
		||||
// It accepts both string values (parsed via time.ParseDuration) and
 | 
			
		||||
// numeric values (interpreted as nanoseconds).
 | 
			
		||||
func (d *Duration) UnmarshalJSON(b []byte) error {
 | 
			
		||||
	var v any
 | 
			
		||||
	if err := json.Unmarshal(b, &v); err != nil {
 | 
			
		||||
 
 | 
			
		||||
@@ -1,3 +1,36 @@
 | 
			
		||||
// Package tracing provides OpenTelemetry distributed tracing setup with OTLP export support.
 | 
			
		||||
//
 | 
			
		||||
// This package handles the complete OpenTelemetry SDK initialization including:
 | 
			
		||||
//   - Trace provider configuration with batching and resource detection
 | 
			
		||||
//   - Log provider setup for structured log export via OTLP
 | 
			
		||||
//   - Automatic resource discovery (service name, version, host, container, process info)
 | 
			
		||||
//   - Support for both gRPC and HTTP OTLP exporters with TLS configuration
 | 
			
		||||
//   - Propagation context setup for distributed tracing across services
 | 
			
		||||
//   - Graceful shutdown handling for all telemetry components
 | 
			
		||||
//
 | 
			
		||||
// The package supports various deployment scenarios:
 | 
			
		||||
//   - Development: Local OTLP collectors or observability backends
 | 
			
		||||
//   - Production: Secure OTLP export with mutual TLS authentication
 | 
			
		||||
//   - Container environments: Automatic container and Kubernetes resource detection
 | 
			
		||||
//
 | 
			
		||||
// Configuration is primarily handled via standard OpenTelemetry environment variables:
 | 
			
		||||
//   - OTEL_SERVICE_NAME: Service identification
 | 
			
		||||
//   - OTEL_EXPORTER_OTLP_PROTOCOL: Protocol selection (grpc, http/protobuf)
 | 
			
		||||
//   - OTEL_TRACES_EXPORTER: Exporter type (otlp, autoexport)
 | 
			
		||||
//   - OTEL_RESOURCE_ATTRIBUTES: Additional resource attributes
 | 
			
		||||
//
 | 
			
		||||
// Example usage:
 | 
			
		||||
//
 | 
			
		||||
//	cfg := &tracing.TracerConfig{
 | 
			
		||||
//		ServiceName: "my-service",
 | 
			
		||||
//		Environment: "production",
 | 
			
		||||
//		Endpoint:    "https://otlp.example.com:4317",
 | 
			
		||||
//	}
 | 
			
		||||
//	shutdown, err := tracing.InitTracer(ctx, cfg)
 | 
			
		||||
//	if err != nil {
 | 
			
		||||
//		log.Fatal(err)
 | 
			
		||||
//	}
 | 
			
		||||
//	defer shutdown(ctx)
 | 
			
		||||
package tracing
 | 
			
		||||
 | 
			
		||||
// todo, review:
 | 
			
		||||
@@ -5,26 +38,23 @@ package tracing
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"crypto/tls"
 | 
			
		||||
	"crypto/x509"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"log/slog"
 | 
			
		||||
	"os"
 | 
			
		||||
	"slices"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"go.ntppool.org/common/logger"
 | 
			
		||||
	"go.ntppool.org/common/internal/tracerconfig"
 | 
			
		||||
	"go.ntppool.org/common/version"
 | 
			
		||||
	"google.golang.org/grpc/credentials"
 | 
			
		||||
 | 
			
		||||
	"go.opentelemetry.io/contrib/exporters/autoexport"
 | 
			
		||||
	"go.opentelemetry.io/otel"
 | 
			
		||||
	"go.opentelemetry.io/otel/attribute"
 | 
			
		||||
	"go.opentelemetry.io/otel/exporters/otlp/otlptrace"
 | 
			
		||||
	"go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc"
 | 
			
		||||
	"go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp"
 | 
			
		||||
	logglobal "go.opentelemetry.io/otel/log/global"
 | 
			
		||||
	"go.opentelemetry.io/otel/log/global"
 | 
			
		||||
	"go.opentelemetry.io/otel/propagation"
 | 
			
		||||
	sdklog "go.opentelemetry.io/otel/sdk/log"
 | 
			
		||||
	sdkmetric "go.opentelemetry.io/otel/sdk/metric"
 | 
			
		||||
	"go.opentelemetry.io/otel/sdk/resource"
 | 
			
		||||
	sdktrace "go.opentelemetry.io/otel/sdk/trace"
 | 
			
		||||
	semconv "go.opentelemetry.io/otel/semconv/v1.26.0"
 | 
			
		||||
@@ -34,49 +64,106 @@ import (
 | 
			
		||||
const (
 | 
			
		||||
	// svcNameKey is the environment variable name that Service Name information will be read from.
 | 
			
		||||
	svcNameKey = "OTEL_SERVICE_NAME"
 | 
			
		||||
 | 
			
		||||
	otelExporterOTLPProtoEnvKey       = "OTEL_EXPORTER_OTLP_PROTOCOL"
 | 
			
		||||
	otelExporterOTLPTracesProtoEnvKey = "OTEL_EXPORTER_OTLP_TRACES_PROTOCOL"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var errInvalidOTLPProtocol = errors.New("invalid OTLP protocol - should be one of ['grpc', 'http/protobuf']")
 | 
			
		||||
// createOTLPLogExporter creates an OTLP log exporter using the provided configuration.
 | 
			
		||||
// This function is used as the LogExporterFactory for the tracerconfig bridge.
 | 
			
		||||
func createOTLPLogExporter(ctx context.Context, cfg *tracerconfig.Config) (sdklog.Exporter, error) {
 | 
			
		||||
	return tracerconfig.CreateOTLPLogExporter(ctx, cfg)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// createOTLPMetricExporter creates an OTLP metric exporter using the provided configuration.
 | 
			
		||||
// This function is used as the MetricExporterFactory for the tracerconfig bridge.
 | 
			
		||||
func createOTLPMetricExporter(ctx context.Context, cfg *tracerconfig.Config) (sdkmetric.Exporter, error) {
 | 
			
		||||
	return tracerconfig.CreateOTLPMetricExporter(ctx, cfg)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// createOTLPTraceExporter creates an OTLP trace exporter using the provided configuration.
 | 
			
		||||
// This function is used as the TraceExporterFactory for the tracerconfig bridge.
 | 
			
		||||
func createOTLPTraceExporter(ctx context.Context, cfg *tracerconfig.Config) (sdktrace.SpanExporter, error) {
 | 
			
		||||
	return tracerconfig.CreateOTLPTraceExporter(ctx, cfg)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// https://github.com/open-telemetry/opentelemetry-go/blob/main/exporters/otlp/otlptrace/otlptracehttp/example_test.go
 | 
			
		||||
 | 
			
		||||
// TpShutdownFunc represents a function that gracefully shuts down telemetry providers.
 | 
			
		||||
// It should be called during application shutdown to ensure all telemetry data is flushed
 | 
			
		||||
// and exporters are properly closed. The context can be used to set shutdown timeouts.
 | 
			
		||||
type TpShutdownFunc func(ctx context.Context) error
 | 
			
		||||
 | 
			
		||||
// Tracer returns the configured OpenTelemetry tracer for the NTP Pool project.
 | 
			
		||||
// This tracer should be used for creating spans and distributed tracing throughout
 | 
			
		||||
// the application. It uses the global tracer provider set up by InitTracer/SetupSDK.
 | 
			
		||||
func Tracer() trace.Tracer {
 | 
			
		||||
	traceProvider := otel.GetTracerProvider()
 | 
			
		||||
	return traceProvider.Tracer("ntppool-tracer")
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Start creates a new span with the given name and options using the configured tracer.
 | 
			
		||||
// This is a convenience function that wraps the standard OpenTelemetry span creation.
 | 
			
		||||
// It returns a new context containing the span and the span itself for further configuration.
 | 
			
		||||
//
 | 
			
		||||
// The returned context should be used for downstream operations to maintain trace correlation.
 | 
			
		||||
func Start(ctx context.Context, spanName string, opts ...trace.SpanStartOption) (context.Context, trace.Span) {
 | 
			
		||||
	return Tracer().Start(ctx, spanName, opts...)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type GetClientCertificate func(*tls.CertificateRequestInfo) (*tls.Certificate, error)
 | 
			
		||||
// GetClientCertificate is an alias for the type defined in tracerconfig.
 | 
			
		||||
// This maintains backward compatibility for existing code.
 | 
			
		||||
type GetClientCertificate = tracerconfig.GetClientCertificate
 | 
			
		||||
 | 
			
		||||
// TracerConfig provides configuration options for OpenTelemetry tracing setup.
 | 
			
		||||
// It supplements standard OpenTelemetry environment variables with additional
 | 
			
		||||
// NTP Pool-specific configuration including TLS settings for secure OTLP export.
 | 
			
		||||
type TracerConfig struct {
 | 
			
		||||
	ServiceName string
 | 
			
		||||
	Environment string
 | 
			
		||||
	Endpoint    string
 | 
			
		||||
	EndpointURL string
 | 
			
		||||
	ServiceName string // Service name for resource identification (overrides OTEL_SERVICE_NAME)
 | 
			
		||||
	Environment string // Deployment environment (development, staging, production)
 | 
			
		||||
	Endpoint    string // OTLP endpoint hostname/port (e.g., "otlp.example.com:4317")
 | 
			
		||||
	EndpointURL string // Complete OTLP endpoint URL (e.g., "https://otlp.example.com:4317/v1/traces")
 | 
			
		||||
 | 
			
		||||
	CertificateProvider GetClientCertificate
 | 
			
		||||
	RootCAs             *x509.CertPool
 | 
			
		||||
	CertificateProvider GetClientCertificate // Client certificate provider for mutual TLS
 | 
			
		||||
	RootCAs             *x509.CertPool       // CA certificate pool for server verification
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// InitTracer initializes the OpenTelemetry SDK with the provided configuration.
 | 
			
		||||
// This is the main entry point for setting up distributed tracing in applications.
 | 
			
		||||
//
 | 
			
		||||
// The function configures trace and log providers, sets up OTLP exporters,
 | 
			
		||||
// and returns a shutdown function that must be called during application termination.
 | 
			
		||||
//
 | 
			
		||||
// Returns a shutdown function and an error. The shutdown function should be called
 | 
			
		||||
// with a context that has an appropriate timeout for graceful shutdown.
 | 
			
		||||
func InitTracer(ctx context.Context, cfg *TracerConfig) (TpShutdownFunc, error) {
 | 
			
		||||
	// todo: setup environment from cfg
 | 
			
		||||
	return SetupSDK(ctx, cfg)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// SetupSDK performs the complete OpenTelemetry SDK initialization including resource
 | 
			
		||||
// discovery, exporter configuration, provider setup, and shutdown function creation.
 | 
			
		||||
//
 | 
			
		||||
// The function automatically discovers system resources (service info, host, container,
 | 
			
		||||
// process details) and configures both trace and log exporters. It supports multiple
 | 
			
		||||
// OTLP protocols (gRPC, HTTP) and handles TLS configuration for secure deployments.
 | 
			
		||||
//
 | 
			
		||||
// The returned shutdown function coordinates graceful shutdown of all telemetry
 | 
			
		||||
// components in the reverse order of their initialization.
 | 
			
		||||
func SetupSDK(ctx context.Context, cfg *TracerConfig) (shutdown TpShutdownFunc, err error) {
 | 
			
		||||
	if cfg == nil {
 | 
			
		||||
		cfg = &TracerConfig{}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	log := logger.Setup()
 | 
			
		||||
	// Store configuration for use by logger and metrics packages via bridge
 | 
			
		||||
	bridgeConfig := &tracerconfig.Config{
 | 
			
		||||
		ServiceName:         cfg.ServiceName,
 | 
			
		||||
		Environment:         cfg.Environment,
 | 
			
		||||
		Endpoint:            cfg.Endpoint,
 | 
			
		||||
		EndpointURL:         cfg.EndpointURL,
 | 
			
		||||
		CertificateProvider: cfg.CertificateProvider,
 | 
			
		||||
		RootCAs:             cfg.RootCAs,
 | 
			
		||||
	}
 | 
			
		||||
	tracerconfig.Store(ctx, bridgeConfig, createOTLPLogExporter, createOTLPMetricExporter, createOTLPTraceExporter)
 | 
			
		||||
 | 
			
		||||
	log := slog.Default()
 | 
			
		||||
 | 
			
		||||
	if serviceName := os.Getenv(svcNameKey); len(serviceName) == 0 {
 | 
			
		||||
		if len(cfg.ServiceName) > 0 {
 | 
			
		||||
@@ -117,13 +204,21 @@ func SetupSDK(ctx context.Context, cfg *TracerConfig) (shutdown TpShutdownFunc,
 | 
			
		||||
 | 
			
		||||
	var shutdownFuncs []func(context.Context) error
 | 
			
		||||
	shutdown = func(ctx context.Context) error {
 | 
			
		||||
		// Force flush the global logger provider before shutting down anything else
 | 
			
		||||
		if loggerProvider := global.GetLoggerProvider(); loggerProvider != nil {
 | 
			
		||||
			if sdkProvider, ok := loggerProvider.(*sdklog.LoggerProvider); ok {
 | 
			
		||||
				if flushErr := sdkProvider.ForceFlush(ctx); flushErr != nil {
 | 
			
		||||
					log.Warn("logger provider force flush failed", "err", flushErr)
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		var err error
 | 
			
		||||
		// need to shutdown the providers first,
 | 
			
		||||
		// exporters after which is the opposite
 | 
			
		||||
		// order they are setup.
 | 
			
		||||
		slices.Reverse(shutdownFuncs)
 | 
			
		||||
		for _, fn := range shutdownFuncs {
 | 
			
		||||
			// log.Warn("shutting down", "fn", fn)
 | 
			
		||||
			err = errors.Join(err, fn(ctx))
 | 
			
		||||
		}
 | 
			
		||||
		shutdownFuncs = nil
 | 
			
		||||
@@ -145,9 +240,9 @@ func SetupSDK(ctx context.Context, cfg *TracerConfig) (shutdown TpShutdownFunc,
 | 
			
		||||
 | 
			
		||||
	switch os.Getenv("OTEL_TRACES_EXPORTER") {
 | 
			
		||||
	case "":
 | 
			
		||||
		spanExporter, err = newOLTPExporter(ctx, cfg)
 | 
			
		||||
		spanExporter, err = createOTLPTraceExporter(ctx, bridgeConfig)
 | 
			
		||||
	case "otlp":
 | 
			
		||||
		spanExporter, err = newOLTPExporter(ctx, cfg)
 | 
			
		||||
		spanExporter, err = createOTLPTraceExporter(ctx, bridgeConfig)
 | 
			
		||||
	default:
 | 
			
		||||
		// log.Debug("OTEL_TRACES_EXPORTER", "fallback", os.Getenv("OTEL_TRACES_EXPORTER"))
 | 
			
		||||
		spanExporter, err = autoexport.NewSpanExporter(ctx)
 | 
			
		||||
@@ -158,13 +253,6 @@ func SetupSDK(ctx context.Context, cfg *TracerConfig) (shutdown TpShutdownFunc,
 | 
			
		||||
	}
 | 
			
		||||
	shutdownFuncs = append(shutdownFuncs, spanExporter.Shutdown)
 | 
			
		||||
 | 
			
		||||
	logExporter, err := autoexport.NewLogExporter(ctx)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		handleErr(err)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	shutdownFuncs = append(shutdownFuncs, logExporter.Shutdown)
 | 
			
		||||
 | 
			
		||||
	// Set up trace provider.
 | 
			
		||||
	tracerProvider, err := newTraceProvider(spanExporter, res)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
@@ -174,19 +262,6 @@ func SetupSDK(ctx context.Context, cfg *TracerConfig) (shutdown TpShutdownFunc,
 | 
			
		||||
	shutdownFuncs = append(shutdownFuncs, tracerProvider.Shutdown)
 | 
			
		||||
	otel.SetTracerProvider(tracerProvider)
 | 
			
		||||
 | 
			
		||||
	logProvider := sdklog.NewLoggerProvider(sdklog.WithResource(res),
 | 
			
		||||
		sdklog.WithProcessor(
 | 
			
		||||
			sdklog.NewBatchProcessor(logExporter, sdklog.WithExportBufferSize(10)),
 | 
			
		||||
		),
 | 
			
		||||
	)
 | 
			
		||||
 | 
			
		||||
	logglobal.SetLoggerProvider(logProvider)
 | 
			
		||||
	shutdownFuncs = append(shutdownFuncs, func(ctx context.Context) error {
 | 
			
		||||
		logProvider.ForceFlush(ctx)
 | 
			
		||||
		return logProvider.Shutdown(ctx)
 | 
			
		||||
	},
 | 
			
		||||
	)
 | 
			
		||||
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		handleErr(err)
 | 
			
		||||
		return
 | 
			
		||||
@@ -195,74 +270,6 @@ func SetupSDK(ctx context.Context, cfg *TracerConfig) (shutdown TpShutdownFunc,
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func newOLTPExporter(ctx context.Context, cfg *TracerConfig) (sdktrace.SpanExporter, error) {
 | 
			
		||||
	log := logger.Setup()
 | 
			
		||||
 | 
			
		||||
	var tlsConfig *tls.Config
 | 
			
		||||
 | 
			
		||||
	if cfg.CertificateProvider != nil {
 | 
			
		||||
		tlsConfig = &tls.Config{
 | 
			
		||||
			GetClientCertificate: cfg.CertificateProvider,
 | 
			
		||||
			RootCAs:              cfg.RootCAs,
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	proto := os.Getenv(otelExporterOTLPTracesProtoEnvKey)
 | 
			
		||||
	if proto == "" {
 | 
			
		||||
		proto = os.Getenv(otelExporterOTLPProtoEnvKey)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Fallback to default, http/protobuf.
 | 
			
		||||
	if proto == "" {
 | 
			
		||||
		proto = "http/protobuf"
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var client otlptrace.Client
 | 
			
		||||
 | 
			
		||||
	switch proto {
 | 
			
		||||
	case "grpc":
 | 
			
		||||
		opts := []otlptracegrpc.Option{
 | 
			
		||||
			otlptracegrpc.WithCompressor("gzip"),
 | 
			
		||||
		}
 | 
			
		||||
		if tlsConfig != nil {
 | 
			
		||||
			opts = append(opts, otlptracegrpc.WithTLSCredentials(credentials.NewTLS(tlsConfig)))
 | 
			
		||||
		}
 | 
			
		||||
		if len(cfg.Endpoint) > 0 {
 | 
			
		||||
			log.Info("adding option", "Endpoint", cfg.Endpoint)
 | 
			
		||||
			opts = append(opts, otlptracegrpc.WithEndpoint(cfg.Endpoint))
 | 
			
		||||
		}
 | 
			
		||||
		if len(cfg.EndpointURL) > 0 {
 | 
			
		||||
			log.Info("adding option", "EndpointURL", cfg.EndpointURL)
 | 
			
		||||
			opts = append(opts, otlptracegrpc.WithEndpointURL(cfg.EndpointURL))
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		client = otlptracegrpc.NewClient(opts...)
 | 
			
		||||
	case "http/protobuf", "http/json":
 | 
			
		||||
		opts := []otlptracehttp.Option{
 | 
			
		||||
			otlptracehttp.WithCompression(otlptracehttp.GzipCompression),
 | 
			
		||||
		}
 | 
			
		||||
		if tlsConfig != nil {
 | 
			
		||||
			opts = append(opts, otlptracehttp.WithTLSClientConfig(tlsConfig))
 | 
			
		||||
		}
 | 
			
		||||
		if len(cfg.Endpoint) > 0 {
 | 
			
		||||
			opts = append(opts, otlptracehttp.WithEndpoint(cfg.Endpoint))
 | 
			
		||||
		}
 | 
			
		||||
		if len(cfg.EndpointURL) > 0 {
 | 
			
		||||
			opts = append(opts, otlptracehttp.WithEndpointURL(cfg.EndpointURL))
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		client = otlptracehttp.NewClient(opts...)
 | 
			
		||||
	default:
 | 
			
		||||
		return nil, errInvalidOTLPProtocol
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	exporter, err := otlptrace.New(ctx, client)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		log.ErrorContext(ctx, "creating OTLP trace exporter", "err", err)
 | 
			
		||||
	}
 | 
			
		||||
	return exporter, err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func newTraceProvider(traceExporter sdktrace.SpanExporter, res *resource.Resource) (*sdktrace.TracerProvider, error) {
 | 
			
		||||
	traceProvider := sdktrace.NewTracerProvider(
 | 
			
		||||
		sdktrace.WithResource(res),
 | 
			
		||||
 
 | 
			
		||||
@@ -1,3 +1,17 @@
 | 
			
		||||
// Package types provides shared data structures for the NTP Pool project.
 | 
			
		||||
//
 | 
			
		||||
// This package contains common types used across different NTP Pool services
 | 
			
		||||
// for data exchange, logging, and database operations. The types are designed
 | 
			
		||||
// to support JSON serialization for API responses and SQL database storage
 | 
			
		||||
// with automatic marshaling/unmarshaling.
 | 
			
		||||
//
 | 
			
		||||
// Current types include:
 | 
			
		||||
//   - LogScoreAttributes: NTP server scoring metadata for monitoring and analysis
 | 
			
		||||
//
 | 
			
		||||
// All types implement appropriate interfaces for:
 | 
			
		||||
//   - JSON serialization (json.Marshaler/json.Unmarshaler)
 | 
			
		||||
//   - SQL database storage (database/sql/driver.Valuer/sql.Scanner)
 | 
			
		||||
//   - String representation for logging and debugging
 | 
			
		||||
package types
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
@@ -6,17 +20,26 @@ import (
 | 
			
		||||
	"errors"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// LogScoreAttributes contains metadata about NTP server scoring and monitoring results.
 | 
			
		||||
// This structure captures both NTP protocol-specific information (leap, stratum) and
 | 
			
		||||
// operational data (errors, warnings, response status) for analysis and alerting.
 | 
			
		||||
//
 | 
			
		||||
// The type supports JSON serialization for API responses and database storage
 | 
			
		||||
// via the database/sql/driver interfaces. Fields use omitempty tags to minimize
 | 
			
		||||
// JSON payload size when values are at their zero state.
 | 
			
		||||
type LogScoreAttributes struct {
 | 
			
		||||
	Leap       int8   `json:"leap,omitempty"`
 | 
			
		||||
	Stratum    int8   `json:"stratum,omitempty"`
 | 
			
		||||
	NoResponse bool   `json:"no_response,omitempty"`
 | 
			
		||||
	Error      string `json:"error,omitempty"`
 | 
			
		||||
	Warning    string `json:"warning,omitempty"`
 | 
			
		||||
	Leap       int8   `json:"leap,omitempty"`        // NTP leap indicator (0=no warning, 1=+1s, 2=-1s, 3=unsynchronized)
 | 
			
		||||
	Stratum    int8   `json:"stratum,omitempty"`     // NTP stratum level (1=primary, 2-15=secondary, 16=unsynchronized)
 | 
			
		||||
	NoResponse bool   `json:"no_response,omitempty"` // True if server failed to respond to NTP queries
 | 
			
		||||
	Error      string `json:"error,omitempty"`       // Error message if scoring failed
 | 
			
		||||
	Warning    string `json:"warning,omitempty"`     // Warning message for non-fatal issues
 | 
			
		||||
 | 
			
		||||
	FromLSID int `json:"from_ls_id,omitempty"`
 | 
			
		||||
	FromSSID int `json:"from_ss_id,omitempty"`
 | 
			
		||||
	FromLSID int `json:"from_ls_id,omitempty"` // Source log server ID for traceability
 | 
			
		||||
	FromSSID int `json:"from_ss_id,omitempty"` // Source scoring system ID for traceability
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// String returns a JSON representation of the LogScoreAttributes for logging and debugging.
 | 
			
		||||
// Returns an empty string if JSON marshaling fails.
 | 
			
		||||
func (lsa *LogScoreAttributes) String() string {
 | 
			
		||||
	b, err := json.Marshal(lsa)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
@@ -25,10 +48,17 @@ func (lsa *LogScoreAttributes) String() string {
 | 
			
		||||
	return string(b)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Value implements the database/sql/driver.Valuer interface for database storage.
 | 
			
		||||
// It serializes the LogScoreAttributes to JSON for storage in SQL databases.
 | 
			
		||||
// Returns the JSON bytes or an error if marshaling fails.
 | 
			
		||||
func (lsa *LogScoreAttributes) Value() (driver.Value, error) {
 | 
			
		||||
	return json.Marshal(lsa)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Scan implements the database/sql.Scanner interface for reading from SQL databases.
 | 
			
		||||
// It deserializes JSON data from the database back into LogScoreAttributes.
 | 
			
		||||
// Supports both []byte and string input types, with nil values treated as no-op.
 | 
			
		||||
// Returns an error if the input type is unsupported or JSON unmarshaling fails.
 | 
			
		||||
func (lsa *LogScoreAttributes) Scan(value any) error {
 | 
			
		||||
	var source []byte
 | 
			
		||||
	_t := LogScoreAttributes{}
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										64
									
								
								ulid/ulid.go
									
									
									
									
									
								
							
							
						
						
									
										64
									
								
								ulid/ulid.go
									
									
									
									
									
								
							@@ -1,46 +1,44 @@
 | 
			
		||||
// Package ulid provides thread-safe ULID (Universally Unique Lexicographically Sortable Identifier) generation.
 | 
			
		||||
//
 | 
			
		||||
// ULIDs are 128-bit identifiers that are lexicographically sortable and contain
 | 
			
		||||
// a timestamp component. This package uses cryptographically secure random
 | 
			
		||||
// generation optimized for simplicity and performance in concurrent environments.
 | 
			
		||||
package ulid
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	cryptorand "crypto/rand"
 | 
			
		||||
	"encoding/binary"
 | 
			
		||||
	"io"
 | 
			
		||||
	mathrand "math/rand"
 | 
			
		||||
	"os"
 | 
			
		||||
	"sync"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	oklid "github.com/oklog/ulid/v2"
 | 
			
		||||
	"go.ntppool.org/common/logger"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var monotonicPool = sync.Pool{
 | 
			
		||||
	New: func() any {
 | 
			
		||||
		log := logger.Setup()
 | 
			
		||||
 | 
			
		||||
		var seed int64
 | 
			
		||||
		err := binary.Read(cryptorand.Reader, binary.BigEndian, &seed)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			log.Error("crypto/rand error", "err", err)
 | 
			
		||||
			os.Exit(10)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		rand := mathrand.New(mathrand.NewSource(seed))
 | 
			
		||||
 | 
			
		||||
		inc := uint64(mathrand.Int63())
 | 
			
		||||
 | 
			
		||||
		// log.Printf("seed: %d", seed)
 | 
			
		||||
		// log.Printf("inc:  %d", inc)
 | 
			
		||||
 | 
			
		||||
		// inc = inc & ^uint64(1<<63) // only want 63 bits
 | 
			
		||||
		mono := oklid.Monotonic(rand, inc)
 | 
			
		||||
		return mono
 | 
			
		||||
	},
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// MakeULID generates a new ULID with the specified timestamp using cryptographically secure randomness.
 | 
			
		||||
// The function is thread-safe and optimized for high-concurrency environments.
 | 
			
		||||
//
 | 
			
		||||
// This implementation prioritizes simplicity and performance over strict monotonicity within
 | 
			
		||||
// the same millisecond. Each ULID is guaranteed to be unique and lexicographically sortable
 | 
			
		||||
// across different timestamps.
 | 
			
		||||
//
 | 
			
		||||
// Returns a pointer to the generated ULID or an error if generation fails.
 | 
			
		||||
// Generation should only fail under extreme circumstances (entropy exhaustion).
 | 
			
		||||
func MakeULID(t time.Time) (*oklid.ULID, error) {
 | 
			
		||||
	mono := monotonicPool.Get().(io.Reader)
 | 
			
		||||
 | 
			
		||||
	id, err := oklid.New(oklid.Timestamp(t), mono)
 | 
			
		||||
	id, err := oklid.New(oklid.Timestamp(t), cryptorand.Reader)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return &id, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Make generates a new ULID with the current timestamp using cryptographically secure randomness.
 | 
			
		||||
// This is a convenience function equivalent to MakeULID(time.Now()).
 | 
			
		||||
//
 | 
			
		||||
// The function is thread-safe and optimized for high-concurrency environments.
 | 
			
		||||
//
 | 
			
		||||
// Returns a pointer to the generated ULID or an error if generation fails.
 | 
			
		||||
// Generation should only fail under extreme circumstances (entropy exhaustion).
 | 
			
		||||
func Make() (*oklid.ULID, error) {
 | 
			
		||||
	id, err := oklid.New(oklid.Now(), cryptorand.Reader)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 
 | 
			
		||||
@@ -1,25 +1,336 @@
 | 
			
		||||
package ulid
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	cryptorand "crypto/rand"
 | 
			
		||||
	"sort"
 | 
			
		||||
	"sync"
 | 
			
		||||
	"testing"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	oklid "github.com/oklog/ulid/v2"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func TestULID(t *testing.T) {
 | 
			
		||||
func TestMakeULID(t *testing.T) {
 | 
			
		||||
	tm := time.Now()
 | 
			
		||||
	ul1, err := MakeULID(tm)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Logf("makeULID failed: %s", err)
 | 
			
		||||
		t.Fail()
 | 
			
		||||
		t.Fatalf("MakeULID failed: %s", err)
 | 
			
		||||
	}
 | 
			
		||||
	ul2, err := MakeULID(tm)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Logf("MakeULID failed: %s", err)
 | 
			
		||||
		t.Fail()
 | 
			
		||||
		t.Fatalf("MakeULID failed: %s", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if ul1 == nil || ul2 == nil {
 | 
			
		||||
		t.Fatal("MakeULID returned nil ULID")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if ul1.String() == ul2.String() {
 | 
			
		||||
		t.Logf("ul1 and ul2 got the same string: %s", ul1.String())
 | 
			
		||||
		t.Fail()
 | 
			
		||||
		t.Errorf("ul1 and ul2 should be different: %s", ul1.String())
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Verify they have the same timestamp
 | 
			
		||||
	if ul1.Time() != ul2.Time() {
 | 
			
		||||
		t.Errorf("ULIDs with same input time should have same timestamp: %d != %d", ul1.Time(), ul2.Time())
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	t.Logf("ulid string 1 and 2: %s | %s", ul1.String(), ul2.String())
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestMake(t *testing.T) {
 | 
			
		||||
	// Test Make() function (uses current time)
 | 
			
		||||
	ul1, err := Make()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatalf("Make failed: %s", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if ul1 == nil {
 | 
			
		||||
		t.Fatal("Make returned nil ULID")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Sleep a bit and generate another
 | 
			
		||||
	time.Sleep(2 * time.Millisecond)
 | 
			
		||||
 | 
			
		||||
	ul2, err := Make()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatalf("Make failed: %s", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Should be different ULIDs
 | 
			
		||||
	if ul1.String() == ul2.String() {
 | 
			
		||||
		t.Errorf("ULIDs from Make() should be different: %s", ul1.String())
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Second should be later (or at least not earlier)
 | 
			
		||||
	if ul1.Time() > ul2.Time() {
 | 
			
		||||
		t.Errorf("second ULID should not have earlier timestamp: %d > %d", ul1.Time(), ul2.Time())
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	t.Logf("Make() ULIDs: %s | %s", ul1.String(), ul2.String())
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestMakeULIDUniqueness(t *testing.T) {
 | 
			
		||||
	tm := time.Now()
 | 
			
		||||
	seen := make(map[string]bool)
 | 
			
		||||
 | 
			
		||||
	for i := 0; i < 1000; i++ {
 | 
			
		||||
		ul, err := MakeULID(tm)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			t.Fatalf("MakeULID failed on iteration %d: %s", i, err)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		str := ul.String()
 | 
			
		||||
		if seen[str] {
 | 
			
		||||
			t.Errorf("duplicate ULID generated: %s", str)
 | 
			
		||||
		}
 | 
			
		||||
		seen[str] = true
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestMakeUniqueness(t *testing.T) {
 | 
			
		||||
	seen := make(map[string]bool)
 | 
			
		||||
 | 
			
		||||
	for i := 0; i < 1000; i++ {
 | 
			
		||||
		ul, err := Make()
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			t.Fatalf("Make failed on iteration %d: %s", i, err)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		str := ul.String()
 | 
			
		||||
		if seen[str] {
 | 
			
		||||
			t.Errorf("duplicate ULID generated: %s", str)
 | 
			
		||||
		}
 | 
			
		||||
		seen[str] = true
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestMakeULIDTimestampProgression(t *testing.T) {
 | 
			
		||||
	t1 := time.Now()
 | 
			
		||||
	ul1, err := MakeULID(t1)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatalf("MakeULID failed: %s", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Wait to ensure different timestamp
 | 
			
		||||
	time.Sleep(2 * time.Millisecond)
 | 
			
		||||
 | 
			
		||||
	t2 := time.Now()
 | 
			
		||||
	ul2, err := MakeULID(t2)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatalf("MakeULID failed: %s", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if ul1.Time() >= ul2.Time() {
 | 
			
		||||
		t.Errorf("second ULID should have later timestamp: %d >= %d", ul1.Time(), ul2.Time())
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if ul1.Compare(*ul2) >= 0 {
 | 
			
		||||
		t.Errorf("second ULID should be greater: %s >= %s", ul1.String(), ul2.String())
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestMakeULIDConcurrency(t *testing.T) {
 | 
			
		||||
	const numGoroutines = 10
 | 
			
		||||
	const numULIDsPerGoroutine = 100
 | 
			
		||||
 | 
			
		||||
	var wg sync.WaitGroup
 | 
			
		||||
	ulidChan := make(chan *oklid.ULID, numGoroutines*numULIDsPerGoroutine)
 | 
			
		||||
	tm := time.Now()
 | 
			
		||||
 | 
			
		||||
	// Start multiple goroutines generating ULIDs concurrently
 | 
			
		||||
	for i := 0; i < numGoroutines; i++ {
 | 
			
		||||
		wg.Add(1)
 | 
			
		||||
		go func() {
 | 
			
		||||
			defer wg.Done()
 | 
			
		||||
			for j := 0; j < numULIDsPerGoroutine; j++ {
 | 
			
		||||
				ul, err := MakeULID(tm)
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					t.Errorf("MakeULID failed: %s", err)
 | 
			
		||||
					return
 | 
			
		||||
				}
 | 
			
		||||
				ulidChan <- ul
 | 
			
		||||
			}
 | 
			
		||||
		}()
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	wg.Wait()
 | 
			
		||||
	close(ulidChan)
 | 
			
		||||
 | 
			
		||||
	// Collect all ULIDs and check uniqueness
 | 
			
		||||
	seen := make(map[string]bool)
 | 
			
		||||
	count := 0
 | 
			
		||||
 | 
			
		||||
	for ul := range ulidChan {
 | 
			
		||||
		str := ul.String()
 | 
			
		||||
		if seen[str] {
 | 
			
		||||
			t.Errorf("duplicate ULID generated in concurrent test: %s", str)
 | 
			
		||||
		}
 | 
			
		||||
		seen[str] = true
 | 
			
		||||
		count++
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if count != numGoroutines*numULIDsPerGoroutine {
 | 
			
		||||
		t.Errorf("expected %d ULIDs, got %d", numGoroutines*numULIDsPerGoroutine, count)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestMakeConcurrency(t *testing.T) {
 | 
			
		||||
	const numGoroutines = 10
 | 
			
		||||
	const numULIDsPerGoroutine = 100
 | 
			
		||||
 | 
			
		||||
	var wg sync.WaitGroup
 | 
			
		||||
	ulidChan := make(chan *oklid.ULID, numGoroutines*numULIDsPerGoroutine)
 | 
			
		||||
 | 
			
		||||
	// Start multiple goroutines generating ULIDs concurrently
 | 
			
		||||
	for i := 0; i < numGoroutines; i++ {
 | 
			
		||||
		wg.Add(1)
 | 
			
		||||
		go func() {
 | 
			
		||||
			defer wg.Done()
 | 
			
		||||
			for j := 0; j < numULIDsPerGoroutine; j++ {
 | 
			
		||||
				ul, err := Make()
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					t.Errorf("Make failed: %s", err)
 | 
			
		||||
					return
 | 
			
		||||
				}
 | 
			
		||||
				ulidChan <- ul
 | 
			
		||||
			}
 | 
			
		||||
		}()
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	wg.Wait()
 | 
			
		||||
	close(ulidChan)
 | 
			
		||||
 | 
			
		||||
	// Collect all ULIDs and check uniqueness
 | 
			
		||||
	seen := make(map[string]bool)
 | 
			
		||||
	count := 0
 | 
			
		||||
 | 
			
		||||
	for ul := range ulidChan {
 | 
			
		||||
		str := ul.String()
 | 
			
		||||
		if seen[str] {
 | 
			
		||||
			t.Errorf("duplicate ULID generated in concurrent test: %s", str)
 | 
			
		||||
		}
 | 
			
		||||
		seen[str] = true
 | 
			
		||||
		count++
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if count != numGoroutines*numULIDsPerGoroutine {
 | 
			
		||||
		t.Errorf("expected %d ULIDs, got %d", numGoroutines*numULIDsPerGoroutine, count)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestMakeULIDErrorHandling(t *testing.T) {
 | 
			
		||||
	// Test with various timestamps
 | 
			
		||||
	timestamps := []time.Time{
 | 
			
		||||
		time.Unix(0, 0),           // Unix epoch
 | 
			
		||||
		time.Now(),                // Current time
 | 
			
		||||
		time.Now().Add(time.Hour), // Future time
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for i, tm := range timestamps {
 | 
			
		||||
		ul, err := MakeULID(tm)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			t.Errorf("MakeULID failed with timestamp %d: %s", i, err)
 | 
			
		||||
		}
 | 
			
		||||
		if ul == nil {
 | 
			
		||||
			t.Errorf("MakeULID returned nil ULID with timestamp %d", i)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestMakeULIDLexicographicOrdering(t *testing.T) {
 | 
			
		||||
	var ulids []*oklid.ULID
 | 
			
		||||
	var timestamps []time.Time
 | 
			
		||||
 | 
			
		||||
	// Generate ULIDs with increasing timestamps
 | 
			
		||||
	for i := 0; i < 10; i++ {
 | 
			
		||||
		tm := time.Now().Add(time.Duration(i) * time.Millisecond)
 | 
			
		||||
		timestamps = append(timestamps, tm)
 | 
			
		||||
 | 
			
		||||
		ul, err := MakeULID(tm)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			t.Fatalf("MakeULID failed: %s", err)
 | 
			
		||||
		}
 | 
			
		||||
		ulids = append(ulids, ul)
 | 
			
		||||
 | 
			
		||||
		// Small delay to ensure different timestamps
 | 
			
		||||
		time.Sleep(time.Millisecond)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Sort ULID strings lexicographically
 | 
			
		||||
	ulidStrings := make([]string, len(ulids))
 | 
			
		||||
	for i, ul := range ulids {
 | 
			
		||||
		ulidStrings[i] = ul.String()
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	originalOrder := make([]string, len(ulidStrings))
 | 
			
		||||
	copy(originalOrder, ulidStrings)
 | 
			
		||||
 | 
			
		||||
	sort.Strings(ulidStrings)
 | 
			
		||||
 | 
			
		||||
	// Verify lexicographic order matches chronological order
 | 
			
		||||
	for i := 0; i < len(originalOrder); i++ {
 | 
			
		||||
		if originalOrder[i] != ulidStrings[i] {
 | 
			
		||||
			t.Errorf("lexicographic order doesn't match chronological order at index %d: %s != %s",
 | 
			
		||||
				i, originalOrder[i], ulidStrings[i])
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Benchmark ULID generation performance
 | 
			
		||||
func BenchmarkMakeULID(b *testing.B) {
 | 
			
		||||
	tm := time.Now()
 | 
			
		||||
 | 
			
		||||
	b.ResetTimer()
 | 
			
		||||
	for i := 0; i < b.N; i++ {
 | 
			
		||||
		_, err := MakeULID(tm)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			b.Fatalf("MakeULID failed: %s", err)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Benchmark Make function
 | 
			
		||||
func BenchmarkMake(b *testing.B) {
 | 
			
		||||
	b.ResetTimer()
 | 
			
		||||
	for i := 0; i < b.N; i++ {
 | 
			
		||||
		_, err := Make()
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			b.Fatalf("Make failed: %s", err)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Benchmark concurrent ULID generation
 | 
			
		||||
func BenchmarkMakeULIDConcurrent(b *testing.B) {
 | 
			
		||||
	tm := time.Now()
 | 
			
		||||
 | 
			
		||||
	b.RunParallel(func(pb *testing.PB) {
 | 
			
		||||
		for pb.Next() {
 | 
			
		||||
			_, err := MakeULID(tm)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				b.Fatalf("MakeULID failed: %s", err)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Benchmark concurrent Make function
 | 
			
		||||
func BenchmarkMakeConcurrent(b *testing.B) {
 | 
			
		||||
	b.RunParallel(func(pb *testing.PB) {
 | 
			
		||||
		for pb.Next() {
 | 
			
		||||
			_, err := Make()
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				b.Fatalf("Make failed: %s", err)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Benchmark random number generation
 | 
			
		||||
func BenchmarkCryptoRand(b *testing.B) {
 | 
			
		||||
	buf := make([]byte, 10) // ULID entropy size
 | 
			
		||||
	b.ResetTimer()
 | 
			
		||||
	for i := 0; i < b.N; i++ {
 | 
			
		||||
		cryptorand.Read(buf)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -1,3 +1,27 @@
 | 
			
		||||
// Package version provides build metadata and version information management.
 | 
			
		||||
//
 | 
			
		||||
// This package manages application version information including semantic version,
 | 
			
		||||
// Git revision, build time, and provides integration with CLI frameworks (Cobra, Kong)
 | 
			
		||||
// and Prometheus metrics for operational visibility.
 | 
			
		||||
//
 | 
			
		||||
// Version information can be injected at build time using ldflags:
 | 
			
		||||
//
 | 
			
		||||
//	go build -ldflags "-X go.ntppool.org/common/version.VERSION=v1.0.0 \
 | 
			
		||||
//	  -X go.ntppool.org/common/version.buildTime=2023-01-01T00:00:00Z \
 | 
			
		||||
//	  -X go.ntppool.org/common/version.gitVersion=abc123"
 | 
			
		||||
//
 | 
			
		||||
// Build time supports both Unix epoch timestamps and RFC3339 format:
 | 
			
		||||
//
 | 
			
		||||
//	# Unix epoch (simpler, recommended)
 | 
			
		||||
//	go build -ldflags "-X go.ntppool.org/common/version.buildTime=$(date +%s)"
 | 
			
		||||
//
 | 
			
		||||
//	# RFC3339 format
 | 
			
		||||
//	go build -ldflags "-X go.ntppool.org/common/version.buildTime=$(date -u +%Y-%m-%dT%H:%M:%SZ)"
 | 
			
		||||
//
 | 
			
		||||
// Both formats are automatically converted to RFC3339 for consistent output. The buildTime
 | 
			
		||||
// parameter takes priority over Git commit time. If buildTime is not specified, the package
 | 
			
		||||
// automatically extracts build information from Go's debug.BuildInfo when available,
 | 
			
		||||
// providing fallback values for VCS time and revision.
 | 
			
		||||
package version
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
@@ -5,31 +29,60 @@ import (
 | 
			
		||||
	"log/slog"
 | 
			
		||||
	"runtime"
 | 
			
		||||
	"runtime/debug"
 | 
			
		||||
	"strconv"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/prometheus/client_golang/prometheus"
 | 
			
		||||
	"github.com/spf13/cobra"
 | 
			
		||||
	"golang.org/x/mod/semver"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// VERSION has the current software version (set in the build process)
 | 
			
		||||
// VERSION contains the current software version (typically set during the build process via ldflags).
 | 
			
		||||
// If not set, defaults to "dev-snapshot". The version should follow semantic versioning.
 | 
			
		||||
var (
 | 
			
		||||
	VERSION     string
 | 
			
		||||
	buildTime   string
 | 
			
		||||
	gitVersion  string
 | 
			
		||||
	gitModified bool
 | 
			
		||||
	VERSION     string // Semantic version (e.g., "1.0.0" or "v1.0.0")
 | 
			
		||||
	buildTime   string // Build timestamp (Unix epoch or RFC3339, normalized to RFC3339)
 | 
			
		||||
	gitVersion  string // Git commit hash
 | 
			
		||||
	gitModified bool   // Whether the working tree was modified during build
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// info holds the consolidated version information extracted from build variables and debug.BuildInfo.
 | 
			
		||||
var info Info
 | 
			
		||||
 | 
			
		||||
// parseBuildTime converts a build time string to RFC3339 format.
 | 
			
		||||
// Supports both Unix epoch timestamps (numeric strings) and RFC3339 format.
 | 
			
		||||
// Returns the input unchanged if it cannot be parsed as either format.
 | 
			
		||||
func parseBuildTime(s string) string {
 | 
			
		||||
	if s == "" {
 | 
			
		||||
		return s
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Try parsing as Unix epoch timestamp (numeric string)
 | 
			
		||||
	if epoch, err := strconv.ParseInt(s, 10, 64); err == nil {
 | 
			
		||||
		return time.Unix(epoch, 0).UTC().Format(time.RFC3339)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Try parsing as RFC3339 to validate format
 | 
			
		||||
	if _, err := time.Parse(time.RFC3339, s); err == nil {
 | 
			
		||||
		return s // Already in RFC3339 format
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Return original string if neither format works (graceful fallback)
 | 
			
		||||
	return s
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Info represents structured version and build information.
 | 
			
		||||
// This struct is used for JSON serialization and programmatic access to build metadata.
 | 
			
		||||
type Info struct {
 | 
			
		||||
	Version     string `json:",omitempty"`
 | 
			
		||||
	GitRev      string `json:",omitempty"`
 | 
			
		||||
	GitRevShort string `json:",omitempty"`
 | 
			
		||||
	BuildTime   string `json:",omitempty"`
 | 
			
		||||
	Version     string `json:",omitempty"` // Semantic version with "v" prefix
 | 
			
		||||
	GitRev      string `json:",omitempty"` // Full Git commit hash
 | 
			
		||||
	GitRevShort string `json:",omitempty"` // Shortened Git commit hash (7 characters)
 | 
			
		||||
	BuildTime   string `json:",omitempty"` // Build timestamp
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func init() {
 | 
			
		||||
	buildTime = parseBuildTime(buildTime)
 | 
			
		||||
	info.BuildTime = buildTime
 | 
			
		||||
	info.GitRev = gitVersion
 | 
			
		||||
 | 
			
		||||
@@ -49,9 +102,9 @@ func init() {
 | 
			
		||||
			switch h.Key {
 | 
			
		||||
			case "vcs.time":
 | 
			
		||||
				if len(buildTime) == 0 {
 | 
			
		||||
					buildTime = h.Value
 | 
			
		||||
					buildTime = parseBuildTime(h.Value)
 | 
			
		||||
					info.BuildTime = buildTime
 | 
			
		||||
				}
 | 
			
		||||
				info.BuildTime = h.Value
 | 
			
		||||
			case "vcs.revision":
 | 
			
		||||
				// https://blog.carlmjohnson.net/post/2023/golang-git-hash-how-to/
 | 
			
		||||
				// todo: use BuildInfo.Main.Version if revision is empty
 | 
			
		||||
@@ -79,10 +132,16 @@ func init() {
 | 
			
		||||
	Version()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// VersionCmd creates a Cobra command for displaying version information.
 | 
			
		||||
// The name parameter is used as a prefix in the output (e.g., "myapp v1.0.0").
 | 
			
		||||
// Returns a configured cobra.Command that can be added to any CLI application.
 | 
			
		||||
func VersionCmd(name string) *cobra.Command {
 | 
			
		||||
	versionCmd := &cobra.Command{
 | 
			
		||||
		Use:   "version",
 | 
			
		||||
		Short: "Print version and build information",
 | 
			
		||||
		Long: `Print detailed version information including semantic version,
 | 
			
		||||
Git revision, build time, and Go version. Build information is automatically
 | 
			
		||||
extracted from Go's debug.BuildInfo when available.`,
 | 
			
		||||
		Run: func(cmd *cobra.Command, args []string) {
 | 
			
		||||
			ver := Version()
 | 
			
		||||
			fmt.Printf("%s %s\n", name, ver)
 | 
			
		||||
@@ -91,15 +150,23 @@ func VersionCmd(name string) *cobra.Command {
 | 
			
		||||
	return versionCmd
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// KongVersionCmd provides a Kong CLI framework compatible version command.
 | 
			
		||||
// The Name field should be set to the application name for proper output formatting.
 | 
			
		||||
type KongVersionCmd struct {
 | 
			
		||||
	Name string `kong:"-"`
 | 
			
		||||
	Name string `kong:"-"` // Application name, excluded from Kong parsing
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Run executes the version command for Kong CLI framework.
 | 
			
		||||
// Prints the application name and version information to stdout.
 | 
			
		||||
func (cmd *KongVersionCmd) Run() error {
 | 
			
		||||
	fmt.Printf("%s %s\n", cmd.Name, Version())
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// RegisterMetric registers a Prometheus gauge metric with build information.
 | 
			
		||||
// If name is provided, it creates a metric named "{name}_build_info", otherwise "build_info".
 | 
			
		||||
// The metric includes labels for version, build time, Git time, and Git revision.
 | 
			
		||||
// This is useful for exposing build information in monitoring systems.
 | 
			
		||||
func RegisterMetric(name string, registry prometheus.Registerer) {
 | 
			
		||||
	if len(name) > 0 {
 | 
			
		||||
		name = strings.ReplaceAll(name, "-", "_")
 | 
			
		||||
@@ -110,13 +177,13 @@ func RegisterMetric(name string, registry prometheus.Registerer) {
 | 
			
		||||
	buildInfo := prometheus.NewGaugeVec(
 | 
			
		||||
		prometheus.GaugeOpts{
 | 
			
		||||
			Name: name,
 | 
			
		||||
			Help: "Build information",
 | 
			
		||||
			Help: "Build information including version, build time, and git revision",
 | 
			
		||||
		},
 | 
			
		||||
		[]string{
 | 
			
		||||
			"version",
 | 
			
		||||
			"buildtime",
 | 
			
		||||
			"gittime",
 | 
			
		||||
			"git",
 | 
			
		||||
			"version",   // Combined version/git format (e.g., "v1.0.0/abc123")
 | 
			
		||||
			"buildtime", // Build timestamp from ldflags
 | 
			
		||||
			"gittime",   // Git commit timestamp from VCS info
 | 
			
		||||
			"git",       // Full Git commit hash
 | 
			
		||||
		},
 | 
			
		||||
	)
 | 
			
		||||
	registry.MustRegister(buildInfo)
 | 
			
		||||
@@ -131,12 +198,20 @@ func RegisterMetric(name string, registry prometheus.Registerer) {
 | 
			
		||||
	).Set(1)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// v caches the formatted version string to avoid repeated computation.
 | 
			
		||||
var v string
 | 
			
		||||
 | 
			
		||||
// VersionInfo returns the structured version information.
 | 
			
		||||
// This provides programmatic access to version details for JSON serialization
 | 
			
		||||
// or other structured uses.
 | 
			
		||||
func VersionInfo() Info {
 | 
			
		||||
	return info
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Version returns a human-readable version string suitable for display.
 | 
			
		||||
// The format includes semantic version, Git revision, build time, and Go version.
 | 
			
		||||
// Example: "v1.0.0/abc123f-M (2023-01-01T00:00:00Z, go1.21.0)"
 | 
			
		||||
// The "-M" suffix indicates the working tree was modified during build.
 | 
			
		||||
func Version() string {
 | 
			
		||||
	if len(v) > 0 {
 | 
			
		||||
		return v
 | 
			
		||||
@@ -164,10 +239,20 @@ func Version() string {
 | 
			
		||||
	return v
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// CheckVersion compares a version against a minimum required version.
 | 
			
		||||
// Returns true if the version meets or exceeds the minimum requirement.
 | 
			
		||||
//
 | 
			
		||||
// Special handling:
 | 
			
		||||
//   - "dev-snapshot" is always considered valid (returns true)
 | 
			
		||||
//   - Git hash suffixes (e.g., "v1.0.0/abc123") are stripped before comparison
 | 
			
		||||
//   - Uses semantic version comparison rules
 | 
			
		||||
//
 | 
			
		||||
// Both version and minimumVersion should follow semantic versioning with "v" prefix.
 | 
			
		||||
func CheckVersion(version, minimumVersion string) bool {
 | 
			
		||||
	if version == "dev-snapshot" {
 | 
			
		||||
		return true
 | 
			
		||||
	}
 | 
			
		||||
	// Strip Git hash suffix if present (e.g., "v1.0.0/abc123" -> "v1.0.0")
 | 
			
		||||
	if idx := strings.Index(version, "/"); idx >= 0 {
 | 
			
		||||
		version = version[0:idx]
 | 
			
		||||
	}
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										414
									
								
								version/version_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										414
									
								
								version/version_test.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,414 @@
 | 
			
		||||
package version
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"runtime"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"testing"
 | 
			
		||||
 | 
			
		||||
	"github.com/prometheus/client_golang/prometheus"
 | 
			
		||||
	dto "github.com/prometheus/client_model/go"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func TestCheckVersion(t *testing.T) {
 | 
			
		||||
	tests := []struct {
 | 
			
		||||
		In       string
 | 
			
		||||
		Min      string
 | 
			
		||||
		Expected bool
 | 
			
		||||
	}{
 | 
			
		||||
		// Basic version comparisons
 | 
			
		||||
		{"v3.8.4", "v3.8.5", false},
 | 
			
		||||
		{"v3.9.3", "v3.8.5", true},
 | 
			
		||||
		{"v3.8.5", "v3.8.5", true},
 | 
			
		||||
		// Dev snapshot should always pass
 | 
			
		||||
		{"dev-snapshot", "v3.8.5", true},
 | 
			
		||||
		{"dev-snapshot", "v99.99.99", true},
 | 
			
		||||
		// Versions with Git hashes should be stripped
 | 
			
		||||
		{"v3.8.5/abc123", "v3.8.5", true},
 | 
			
		||||
		{"v3.8.4/abc123", "v3.8.5", false},
 | 
			
		||||
		{"v3.9.0/def456", "v3.8.5", true},
 | 
			
		||||
		// Pre-release versions
 | 
			
		||||
		{"v3.8.5-alpha", "v3.8.5", false},
 | 
			
		||||
		{"v3.8.5", "v3.8.5-alpha", true},
 | 
			
		||||
		{"v3.8.5-beta", "v3.8.5-alpha", true},
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for _, d := range tests {
 | 
			
		||||
		r := CheckVersion(d.In, d.Min)
 | 
			
		||||
		if r != d.Expected {
 | 
			
		||||
			t.Errorf("CheckVersion(%q, %q) = %t, expected %t", d.In, d.Min, r, d.Expected)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestVersionInfo(t *testing.T) {
 | 
			
		||||
	info := VersionInfo()
 | 
			
		||||
 | 
			
		||||
	// Check that we get a valid Info struct
 | 
			
		||||
	if info.Version == "" {
 | 
			
		||||
		t.Error("VersionInfo().Version should not be empty")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Version should start with "v" or be "dev-snapshot"
 | 
			
		||||
	if !strings.HasPrefix(info.Version, "v") && info.Version != "dev-snapshot" {
 | 
			
		||||
		t.Errorf("Version should start with 'v' or be 'dev-snapshot', got: %s", info.Version)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// GitRevShort should be <= 7 characters if set
 | 
			
		||||
	if info.GitRevShort != "" && len(info.GitRevShort) > 7 {
 | 
			
		||||
		t.Errorf("GitRevShort should be <= 7 characters, got: %s", info.GitRevShort)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// GitRevShort should be prefix of GitRev if both are set
 | 
			
		||||
	if info.GitRev != "" && info.GitRevShort != "" {
 | 
			
		||||
		if !strings.HasPrefix(info.GitRev, info.GitRevShort) {
 | 
			
		||||
			t.Errorf("GitRevShort should be prefix of GitRev: %s not prefix of %s",
 | 
			
		||||
				info.GitRevShort, info.GitRev)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestVersion(t *testing.T) {
 | 
			
		||||
	version := Version()
 | 
			
		||||
 | 
			
		||||
	if version == "" {
 | 
			
		||||
		t.Error("Version() should not return empty string")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Should contain Go version
 | 
			
		||||
	if !strings.Contains(version, runtime.Version()) {
 | 
			
		||||
		t.Errorf("Version should contain Go version %s, got: %s", runtime.Version(), version)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Should contain the VERSION variable (or dev-snapshot)
 | 
			
		||||
	info := VersionInfo()
 | 
			
		||||
	if !strings.Contains(version, info.Version) {
 | 
			
		||||
		t.Errorf("Version should contain %s, got: %s", info.Version, version)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Should be in expected format: "version (extras)"
 | 
			
		||||
	if !strings.Contains(version, "(") || !strings.Contains(version, ")") {
 | 
			
		||||
		t.Errorf("Version should be in format 'version (extras)', got: %s", version)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestVersionCmd(t *testing.T) {
 | 
			
		||||
	appName := "testapp"
 | 
			
		||||
	cmd := VersionCmd(appName)
 | 
			
		||||
 | 
			
		||||
	// Test basic command properties
 | 
			
		||||
	if cmd.Use != "version" {
 | 
			
		||||
		t.Errorf("Expected command use to be 'version', got: %s", cmd.Use)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if cmd.Short == "" {
 | 
			
		||||
		t.Error("Command should have a short description")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if cmd.Long == "" {
 | 
			
		||||
		t.Error("Command should have a long description")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if cmd.Run == nil {
 | 
			
		||||
		t.Error("Command should have a Run function")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Test that the command can be executed without error
 | 
			
		||||
	cmd.SetArgs([]string{})
 | 
			
		||||
	err := cmd.Execute()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Errorf("VersionCmd execution should not return error, got: %s", err)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestKongVersionCmd(t *testing.T) {
 | 
			
		||||
	cmd := &KongVersionCmd{Name: "testapp"}
 | 
			
		||||
 | 
			
		||||
	// Test that Run() doesn't return an error
 | 
			
		||||
	err := cmd.Run()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Errorf("KongVersionCmd.Run() should not return error, got: %s", err)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestRegisterMetric(t *testing.T) {
 | 
			
		||||
	// Create a test registry
 | 
			
		||||
	registry := prometheus.NewRegistry()
 | 
			
		||||
 | 
			
		||||
	// Test registering metric without name
 | 
			
		||||
	RegisterMetric("", registry)
 | 
			
		||||
 | 
			
		||||
	// Gather metrics
 | 
			
		||||
	metricFamilies, err := registry.Gather()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatalf("Failed to gather metrics: %s", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Find the build_info metric
 | 
			
		||||
	var buildInfoFamily *dto.MetricFamily
 | 
			
		||||
	for _, family := range metricFamilies {
 | 
			
		||||
		if family.GetName() == "build_info" {
 | 
			
		||||
			buildInfoFamily = family
 | 
			
		||||
			break
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if buildInfoFamily == nil {
 | 
			
		||||
		t.Fatal("build_info metric not found")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if buildInfoFamily.GetHelp() == "" {
 | 
			
		||||
		t.Error("build_info metric should have help text")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	metrics := buildInfoFamily.GetMetric()
 | 
			
		||||
	if len(metrics) == 0 {
 | 
			
		||||
		t.Fatal("build_info metric should have at least one sample")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Check that the metric has the expected labels
 | 
			
		||||
	metric := metrics[0]
 | 
			
		||||
	labels := metric.GetLabel()
 | 
			
		||||
 | 
			
		||||
	expectedLabels := []string{"version", "buildtime", "gittime", "git"}
 | 
			
		||||
	labelMap := make(map[string]string)
 | 
			
		||||
 | 
			
		||||
	for _, label := range labels {
 | 
			
		||||
		labelMap[label.GetName()] = label.GetValue()
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for _, expectedLabel := range expectedLabels {
 | 
			
		||||
		if _, exists := labelMap[expectedLabel]; !exists {
 | 
			
		||||
			t.Errorf("Expected label %s not found in metric", expectedLabel)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Check that the metric value is 1
 | 
			
		||||
	if metric.GetGauge().GetValue() != 1 {
 | 
			
		||||
		t.Errorf("Expected build_info metric value to be 1, got %f", metric.GetGauge().GetValue())
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestRegisterMetricWithName(t *testing.T) {
 | 
			
		||||
	// Create a test registry
 | 
			
		||||
	registry := prometheus.NewRegistry()
 | 
			
		||||
 | 
			
		||||
	// Test registering metric with custom name
 | 
			
		||||
	appName := "my-test-app"
 | 
			
		||||
	RegisterMetric(appName, registry)
 | 
			
		||||
 | 
			
		||||
	// Gather metrics
 | 
			
		||||
	metricFamilies, err := registry.Gather()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatalf("Failed to gather metrics: %s", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Find the my_test_app_build_info metric
 | 
			
		||||
	expectedName := "my_test_app_build_info"
 | 
			
		||||
	var buildInfoFamily *dto.MetricFamily
 | 
			
		||||
	for _, family := range metricFamilies {
 | 
			
		||||
		if family.GetName() == expectedName {
 | 
			
		||||
			buildInfoFamily = family
 | 
			
		||||
			break
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if buildInfoFamily == nil {
 | 
			
		||||
		t.Fatalf("%s metric not found", expectedName)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestVersionConsistency(t *testing.T) {
 | 
			
		||||
	// Call Version() multiple times and ensure it returns the same result
 | 
			
		||||
	v1 := Version()
 | 
			
		||||
	v2 := Version()
 | 
			
		||||
 | 
			
		||||
	if v1 != v2 {
 | 
			
		||||
		t.Errorf("Version() should return consistent results: %s != %s", v1, v2)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestVersionInfoConsistency(t *testing.T) {
 | 
			
		||||
	// Ensure VersionInfo() is consistent with Version()
 | 
			
		||||
	info := VersionInfo()
 | 
			
		||||
	version := Version()
 | 
			
		||||
 | 
			
		||||
	// Version string should contain the semantic version
 | 
			
		||||
	if !strings.Contains(version, info.Version) {
 | 
			
		||||
		t.Errorf("Version() should contain VersionInfo().Version: %s not in %s",
 | 
			
		||||
			info.Version, version)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// If GitRevShort is set, version should contain it
 | 
			
		||||
	if info.GitRevShort != "" {
 | 
			
		||||
		if !strings.Contains(version, info.GitRevShort) {
 | 
			
		||||
			t.Errorf("Version() should contain GitRevShort: %s not in %s",
 | 
			
		||||
				info.GitRevShort, version)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Test edge cases
 | 
			
		||||
func TestCheckVersionEdgeCases(t *testing.T) {
 | 
			
		||||
	// Test with empty strings
 | 
			
		||||
	if CheckVersion("", "v1.0.0") {
 | 
			
		||||
		t.Error("Empty version should not be >= v1.0.0")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Test with malformed versions (should be handled gracefully)
 | 
			
		||||
	// Note: semver.Compare might panic or return unexpected results for invalid versions
 | 
			
		||||
	// but our function should handle the common cases
 | 
			
		||||
	tests := []struct {
 | 
			
		||||
		version string
 | 
			
		||||
		minimum string
 | 
			
		||||
		desc    string
 | 
			
		||||
	}{
 | 
			
		||||
		{"v1.0.0/", "v1.0.0", "version with trailing slash"},
 | 
			
		||||
		{"v1.0.0/abc/def", "v1.0.0", "version with multiple slashes"},
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for _, test := range tests {
 | 
			
		||||
		// This should not panic
 | 
			
		||||
		result := CheckVersion(test.version, test.minimum)
 | 
			
		||||
		t.Logf("%s: CheckVersion(%q, %q) = %t", test.desc, test.version, test.minimum, result)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Benchmark version operations
 | 
			
		||||
func BenchmarkVersion(b *testing.B) {
 | 
			
		||||
	// Reset the cached version to test actual computation
 | 
			
		||||
	v = ""
 | 
			
		||||
 | 
			
		||||
	b.ResetTimer()
 | 
			
		||||
	for i := 0; i < b.N; i++ {
 | 
			
		||||
		_ = Version()
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func BenchmarkVersionInfo(b *testing.B) {
 | 
			
		||||
	for i := 0; i < b.N; i++ {
 | 
			
		||||
		_ = VersionInfo()
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func BenchmarkCheckVersion(b *testing.B) {
 | 
			
		||||
	version := "v1.2.3/abc123"
 | 
			
		||||
	minimum := "v1.2.0"
 | 
			
		||||
 | 
			
		||||
	b.ResetTimer()
 | 
			
		||||
	for i := 0; i < b.N; i++ {
 | 
			
		||||
		_ = CheckVersion(version, minimum)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func BenchmarkCheckVersionDevSnapshot(b *testing.B) {
 | 
			
		||||
	version := "dev-snapshot"
 | 
			
		||||
	minimum := "v1.2.0"
 | 
			
		||||
 | 
			
		||||
	b.ResetTimer()
 | 
			
		||||
	for i := 0; i < b.N; i++ {
 | 
			
		||||
		_ = CheckVersion(version, minimum)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestParseBuildTime(t *testing.T) {
 | 
			
		||||
	tests := []struct {
 | 
			
		||||
		name     string
 | 
			
		||||
		input    string
 | 
			
		||||
		expected string
 | 
			
		||||
	}{
 | 
			
		||||
		{
 | 
			
		||||
			name:     "Unix epoch timestamp",
 | 
			
		||||
			input:    "1672531200", // 2023-01-01T00:00:00Z
 | 
			
		||||
			expected: "2023-01-01T00:00:00Z",
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:     "Unix epoch zero",
 | 
			
		||||
			input:    "0",
 | 
			
		||||
			expected: "1970-01-01T00:00:00Z",
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:     "Valid RFC3339 format",
 | 
			
		||||
			input:    "2023-12-25T15:30:45Z",
 | 
			
		||||
			expected: "2023-12-25T15:30:45Z",
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:     "RFC3339 with timezone",
 | 
			
		||||
			input:    "2023-12-25T10:30:45-05:00",
 | 
			
		||||
			expected: "2023-12-25T10:30:45-05:00",
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:     "Empty string",
 | 
			
		||||
			input:    "",
 | 
			
		||||
			expected: "",
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:     "Invalid format - return unchanged",
 | 
			
		||||
			input:    "not-a-date",
 | 
			
		||||
			expected: "not-a-date",
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:     "Invalid timestamp - return unchanged",
 | 
			
		||||
			input:    "invalid-timestamp",
 | 
			
		||||
			expected: "invalid-timestamp",
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:     "Partial date - return unchanged",
 | 
			
		||||
			input:    "2023-01-01",
 | 
			
		||||
			expected: "2023-01-01",
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:     "Negative epoch - should work",
 | 
			
		||||
			input:    "-1",
 | 
			
		||||
			expected: "1969-12-31T23:59:59Z",
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:     "Large epoch timestamp",
 | 
			
		||||
			input:    "4102444800", // 2100-01-01T00:00:00Z
 | 
			
		||||
			expected: "2100-01-01T00:00:00Z",
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for _, tt := range tests {
 | 
			
		||||
		t.Run(tt.name, func(t *testing.T) {
 | 
			
		||||
			result := parseBuildTime(tt.input)
 | 
			
		||||
			if result != tt.expected {
 | 
			
		||||
				t.Errorf("parseBuildTime(%q) = %q, expected %q", tt.input, result, tt.expected)
 | 
			
		||||
			}
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestParseBuildTimeConsistency(t *testing.T) {
 | 
			
		||||
	// Test that calling parseBuildTime multiple times with the same input returns the same result
 | 
			
		||||
	testInputs := []string{
 | 
			
		||||
		"1672531200",
 | 
			
		||||
		"2023-01-01T00:00:00Z",
 | 
			
		||||
		"invalid-date",
 | 
			
		||||
		"",
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for _, input := range testInputs {
 | 
			
		||||
		result1 := parseBuildTime(input)
 | 
			
		||||
		result2 := parseBuildTime(input)
 | 
			
		||||
		if result1 != result2 {
 | 
			
		||||
			t.Errorf("parseBuildTime(%q) not consistent: %q != %q", input, result1, result2)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func BenchmarkParseBuildTime(b *testing.B) {
 | 
			
		||||
	inputs := []string{
 | 
			
		||||
		"1672531200",                 // Unix epoch
 | 
			
		||||
		"2023-01-01T00:00:00Z",      // RFC3339
 | 
			
		||||
		"invalid-timestamp",          // Invalid
 | 
			
		||||
		"",                          // Empty
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for _, input := range inputs {
 | 
			
		||||
		b.Run(input, func(b *testing.B) {
 | 
			
		||||
			for i := 0; i < b.N; i++ {
 | 
			
		||||
				_ = parseBuildTime(input)
 | 
			
		||||
			}
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
@@ -1,23 +1,116 @@
 | 
			
		||||
// Package fastlyxff provides Fastly CDN IP range management for trusted proxy handling.
 | 
			
		||||
//
 | 
			
		||||
// This package parses Fastly's public IP ranges JSON file and provides middleware
 | 
			
		||||
// for both Echo framework and standard net/http for proper client IP extraction
 | 
			
		||||
// from X-Forwarded-For headers. It's designed specifically for services deployed
 | 
			
		||||
// behind Fastly's CDN that need to identify real client IPs for logging, rate
 | 
			
		||||
// limiting, and security purposes.
 | 
			
		||||
//
 | 
			
		||||
// Fastly publishes their edge server IP ranges in a JSON format that this package
 | 
			
		||||
// consumes to automatically configure trusted proxy ranges. This ensures that
 | 
			
		||||
// X-Forwarded-For headers are only trusted when they originate from legitimate
 | 
			
		||||
// Fastly edge servers.
 | 
			
		||||
//
 | 
			
		||||
// Key features:
 | 
			
		||||
//   - Automatic parsing of Fastly's IP ranges JSON format
 | 
			
		||||
//   - Support for both IPv4 and IPv6 address ranges
 | 
			
		||||
//   - Echo framework integration via TrustOption generation
 | 
			
		||||
//   - Standard net/http middleware support
 | 
			
		||||
//   - CIDR notation parsing and validation
 | 
			
		||||
//
 | 
			
		||||
// # Echo Framework Usage
 | 
			
		||||
//
 | 
			
		||||
//	fastlyRanges, err := fastlyxff.New("fastly.json")
 | 
			
		||||
//	if err != nil {
 | 
			
		||||
//		return err
 | 
			
		||||
//	}
 | 
			
		||||
//	options, err := fastlyRanges.EchoTrustOption()
 | 
			
		||||
//	if err != nil {
 | 
			
		||||
//		return err
 | 
			
		||||
//	}
 | 
			
		||||
//	e.IPExtractor = echo.ExtractIPFromXFFHeader(options...)
 | 
			
		||||
//
 | 
			
		||||
// # Net/HTTP Usage
 | 
			
		||||
//
 | 
			
		||||
//	fastlyRanges, err := fastlyxff.New("fastly.json")
 | 
			
		||||
//	if err != nil {
 | 
			
		||||
//		return err
 | 
			
		||||
//	}
 | 
			
		||||
//	middleware := fastlyRanges.HTTPMiddleware()
 | 
			
		||||
//
 | 
			
		||||
//	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 | 
			
		||||
//		// Both methods work - middleware updates r.RemoteAddr (with port 0) and stores in context
 | 
			
		||||
//		realIP := fastlyxff.GetRealIP(r)  // Preferred method
 | 
			
		||||
//		// OR: host, _, _ := net.SplitHostPort(r.RemoteAddr)  // Direct access (port will be "0")
 | 
			
		||||
//		fmt.Fprintf(w, "Real IP: %s\n", realIP)
 | 
			
		||||
//	})
 | 
			
		||||
//
 | 
			
		||||
//	http.ListenAndServe(":8080", middleware(handler))
 | 
			
		||||
//
 | 
			
		||||
// # Net/HTTP with Additional Trusted Ranges
 | 
			
		||||
//
 | 
			
		||||
//	fastlyRanges, err := fastlyxff.New("fastly.json")
 | 
			
		||||
//	if err != nil {
 | 
			
		||||
//		return err
 | 
			
		||||
//	}
 | 
			
		||||
//
 | 
			
		||||
//	// Add custom trusted CIDRs (e.g., internal load balancers)
 | 
			
		||||
//	// Note: For Echo framework, use the ekko package for additional ranges
 | 
			
		||||
//	err = fastlyRanges.AddTrustedCIDR("10.0.0.0/8")
 | 
			
		||||
//	if err != nil {
 | 
			
		||||
//		return err
 | 
			
		||||
//	}
 | 
			
		||||
//
 | 
			
		||||
//	middleware := fastlyRanges.HTTPMiddleware()
 | 
			
		||||
//	handler := middleware(yourHandler)
 | 
			
		||||
//
 | 
			
		||||
// The JSON file typically contains IP ranges in this format:
 | 
			
		||||
//
 | 
			
		||||
//	{
 | 
			
		||||
//	  "addresses": ["23.235.32.0/20", "43.249.72.0/22", ...],
 | 
			
		||||
//	  "ipv6_addresses": ["2a04:4e40::/32", "2a04:4e42::/32", ...]
 | 
			
		||||
//	}
 | 
			
		||||
package fastlyxff
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"net"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"net/netip"
 | 
			
		||||
	"os"
 | 
			
		||||
	"strings"
 | 
			
		||||
 | 
			
		||||
	"github.com/labstack/echo/v4"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// FastlyXFF represents Fastly's published IP ranges for their CDN edge servers.
 | 
			
		||||
// This structure matches the JSON format provided by Fastly for their public IP ranges.
 | 
			
		||||
// It contains separate lists for IPv4 and IPv6 CIDR ranges, plus additional trusted CIDRs.
 | 
			
		||||
type FastlyXFF struct {
 | 
			
		||||
	IPv4 []string `json:"addresses"`
 | 
			
		||||
	IPv6 []string `json:"ipv6_addresses"`
 | 
			
		||||
	IPv4       []string `json:"addresses"`      // IPv4 CIDR ranges (e.g., "23.235.32.0/20")
 | 
			
		||||
	IPv6       []string `json:"ipv6_addresses"` // IPv6 CIDR ranges (e.g., "2a04:4e40::/32")
 | 
			
		||||
	extraCIDRs []string // Additional trusted CIDRs added via AddTrustedCIDR
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// TrustedNets holds parsed network prefixes for efficient IP range checking.
 | 
			
		||||
type TrustedNets struct {
 | 
			
		||||
	prefixes []netip.Prefix
 | 
			
		||||
	prefixes []netip.Prefix // Parsed network prefixes for efficient lookups
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// contextKey is used for storing the real client IP in request context
 | 
			
		||||
type contextKey string
 | 
			
		||||
 | 
			
		||||
const realIPKey contextKey = "fastly-real-ip"
 | 
			
		||||
 | 
			
		||||
// New loads and parses Fastly IP ranges from a JSON file.
 | 
			
		||||
// The file should contain Fastly's published IP ranges in their standard JSON format.
 | 
			
		||||
//
 | 
			
		||||
// Parameters:
 | 
			
		||||
//   - fileName: Path to the Fastly IP ranges JSON file
 | 
			
		||||
//
 | 
			
		||||
// Returns the parsed FastlyXFF structure or an error if the file cannot be
 | 
			
		||||
// read or the JSON format is invalid.
 | 
			
		||||
func New(fileName string) (*FastlyXFF, error) {
 | 
			
		||||
	b, err := os.ReadFile(fileName)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
@@ -34,6 +127,19 @@ func New(fileName string) (*FastlyXFF, error) {
 | 
			
		||||
	return &d, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// EchoTrustOption converts Fastly IP ranges into Echo framework trust options.
 | 
			
		||||
// This method generates trust configurations that tell Echo to accept X-Forwarded-For
 | 
			
		||||
// headers only from Fastly's edge servers, ensuring accurate client IP extraction.
 | 
			
		||||
//
 | 
			
		||||
// The generated trust options should be used with Echo's IP extractor:
 | 
			
		||||
//
 | 
			
		||||
//	options, err := fastlyRanges.EchoTrustOption()
 | 
			
		||||
//	if err != nil {
 | 
			
		||||
//		return err
 | 
			
		||||
//	}
 | 
			
		||||
//	e.IPExtractor = echo.ExtractIPFromXFFHeader(options...)
 | 
			
		||||
//
 | 
			
		||||
// Returns a slice of Echo trust options or an error if any CIDR range cannot be parsed.
 | 
			
		||||
func (xff *FastlyXFF) EchoTrustOption() ([]echo.TrustOption, error) {
 | 
			
		||||
	ranges := []echo.TrustOption{}
 | 
			
		||||
 | 
			
		||||
@@ -49,3 +155,116 @@ func (xff *FastlyXFF) EchoTrustOption() ([]echo.TrustOption, error) {
 | 
			
		||||
 | 
			
		||||
	return ranges, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// AddTrustedCIDR adds an additional CIDR to the list of trusted proxies.
 | 
			
		||||
// This allows trusting proxies beyond Fastly's published ranges.
 | 
			
		||||
// The cidr parameter must be a valid CIDR notation (e.g., "10.0.0.0/8", "192.168.1.0/24").
 | 
			
		||||
// Returns an error if the CIDR format is invalid.
 | 
			
		||||
func (xff *FastlyXFF) AddTrustedCIDR(cidr string) error {
 | 
			
		||||
	// Validate CIDR format
 | 
			
		||||
	_, _, err := net.ParseCIDR(cidr)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Add to extra CIDRs
 | 
			
		||||
	xff.extraCIDRs = append(xff.extraCIDRs, cidr)
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// isTrustedProxy checks if the given IP address belongs to Fastly's trusted IP ranges
 | 
			
		||||
// or any additional CIDRs added via AddTrustedCIDR.
 | 
			
		||||
func (xff *FastlyXFF) isTrustedProxy(ip string) bool {
 | 
			
		||||
	addr, err := netip.ParseAddr(ip)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return false
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Check all IPv4 and IPv6 ranges (Fastly + additional)
 | 
			
		||||
	allRanges := append(append(xff.IPv4, xff.IPv6...), xff.extraCIDRs...)
 | 
			
		||||
	for _, s := range allRanges {
 | 
			
		||||
		_, cidr, err := net.ParseCIDR(s)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		if cidr.Contains(net.IP(addr.AsSlice())) {
 | 
			
		||||
			return true
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return false
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// extractRealIP extracts the real client IP from X-Forwarded-For header.
 | 
			
		||||
// It returns the rightmost IP that is not from a trusted Fastly proxy.
 | 
			
		||||
func (xff *FastlyXFF) extractRealIP(r *http.Request) string {
 | 
			
		||||
	// Get the immediate peer IP
 | 
			
		||||
	host, _, err := net.SplitHostPort(r.RemoteAddr)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		host = r.RemoteAddr
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// If the immediate peer is not a trusted Fastly proxy, return it
 | 
			
		||||
	if !xff.isTrustedProxy(host) {
 | 
			
		||||
		return host
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Check X-Forwarded-For header
 | 
			
		||||
	xff_header := r.Header.Get("X-Forwarded-For")
 | 
			
		||||
	if xff_header == "" {
 | 
			
		||||
		return host
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Parse comma-separated IP list
 | 
			
		||||
	ips := strings.Split(xff_header, ",")
 | 
			
		||||
	if len(ips) == 0 {
 | 
			
		||||
		return host
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Find the leftmost IP that is not from a trusted proxy
 | 
			
		||||
	// This represents the original client IP
 | 
			
		||||
	for i := 0; i < len(ips); i++ {
 | 
			
		||||
		ip := strings.TrimSpace(ips[i])
 | 
			
		||||
		if ip != "" && !xff.isTrustedProxy(ip) {
 | 
			
		||||
			return ip
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Fallback to the immediate peer
 | 
			
		||||
	return host
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// HTTPMiddleware returns a net/http middleware that extracts real client IP
 | 
			
		||||
// from X-Forwarded-For headers when the request comes from trusted Fastly proxies.
 | 
			
		||||
// The real IP is stored in the request context and also updates r.RemoteAddr
 | 
			
		||||
// with port 0 (since the original port is from the proxy, not the real client).
 | 
			
		||||
func (xff *FastlyXFF) HTTPMiddleware() func(http.Handler) http.Handler {
 | 
			
		||||
	return func(next http.Handler) http.Handler {
 | 
			
		||||
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 | 
			
		||||
			realIP := xff.extractRealIP(r)
 | 
			
		||||
 | 
			
		||||
			// Store in context for GetRealIP function
 | 
			
		||||
			ctx := context.WithValue(r.Context(), realIPKey, realIP)
 | 
			
		||||
			r = r.WithContext(ctx)
 | 
			
		||||
 | 
			
		||||
			// Update RemoteAddr to be consistent with extracted IP
 | 
			
		||||
			// Use port 0 since the original port is from the proxy, not the real client
 | 
			
		||||
			r.RemoteAddr = net.JoinHostPort(realIP, "0")
 | 
			
		||||
 | 
			
		||||
			next.ServeHTTP(w, r)
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// GetRealIP retrieves the real client IP from the request context.
 | 
			
		||||
// This should be used after the HTTPMiddleware has processed the request.
 | 
			
		||||
// Returns the remote address if no real IP was extracted.
 | 
			
		||||
func GetRealIP(r *http.Request) string {
 | 
			
		||||
	if ip, ok := r.Context().Value(realIPKey).(string); ok {
 | 
			
		||||
		return ip
 | 
			
		||||
	}
 | 
			
		||||
	host, _, err := net.SplitHostPort(r.RemoteAddr)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return r.RemoteAddr
 | 
			
		||||
	}
 | 
			
		||||
	return host
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -1,6 +1,11 @@
 | 
			
		||||
package fastlyxff
 | 
			
		||||
 | 
			
		||||
import "testing"
 | 
			
		||||
import (
 | 
			
		||||
	"net"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"net/http/httptest"
 | 
			
		||||
	"testing"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func TestFastlyIPRanges(t *testing.T) {
 | 
			
		||||
	fastlyxff, err := New("fastly.json")
 | 
			
		||||
@@ -18,3 +23,334 @@ func TestFastlyIPRanges(t *testing.T) {
 | 
			
		||||
		t.Fail()
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestHTTPMiddleware(t *testing.T) {
 | 
			
		||||
	// Create a test FastlyXFF instance with known IP ranges
 | 
			
		||||
	xff := &FastlyXFF{
 | 
			
		||||
		IPv4: []string{"192.0.2.0/24", "203.0.113.0/24"},
 | 
			
		||||
		IPv6: []string{"2001:db8::/32"},
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	middleware := xff.HTTPMiddleware()
 | 
			
		||||
 | 
			
		||||
	tests := []struct {
 | 
			
		||||
		name           string
 | 
			
		||||
		remoteAddr     string
 | 
			
		||||
		xForwardedFor  string
 | 
			
		||||
		expectedRealIP string
 | 
			
		||||
	}{
 | 
			
		||||
		{
 | 
			
		||||
			name:           "direct connection",
 | 
			
		||||
			remoteAddr:     "198.51.100.1:12345",
 | 
			
		||||
			xForwardedFor:  "",
 | 
			
		||||
			expectedRealIP: "198.51.100.1",
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:           "trusted proxy with XFF",
 | 
			
		||||
			remoteAddr:     "192.0.2.1:80",
 | 
			
		||||
			xForwardedFor:  "198.51.100.1",
 | 
			
		||||
			expectedRealIP: "198.51.100.1",
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:           "trusted proxy with multiple XFF",
 | 
			
		||||
			remoteAddr:     "192.0.2.1:80",
 | 
			
		||||
			xForwardedFor:  "198.51.100.1, 203.0.113.1",
 | 
			
		||||
			expectedRealIP: "198.51.100.1",
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:           "untrusted proxy ignored",
 | 
			
		||||
			remoteAddr:     "198.51.100.2:80",
 | 
			
		||||
			xForwardedFor:  "10.0.0.1",
 | 
			
		||||
			expectedRealIP: "198.51.100.2",
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:           "IPv6 trusted proxy",
 | 
			
		||||
			remoteAddr:     "[2001:db8::1]:80",
 | 
			
		||||
			xForwardedFor:  "198.51.100.1",
 | 
			
		||||
			expectedRealIP: "198.51.100.1",
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for _, tt := range tests {
 | 
			
		||||
		t.Run(tt.name, func(t *testing.T) {
 | 
			
		||||
			// Create test handler that captures both GetRealIP and r.RemoteAddr
 | 
			
		||||
			var capturedRealIP, capturedRemoteAddr string
 | 
			
		||||
			handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 | 
			
		||||
				capturedRealIP = GetRealIP(r)
 | 
			
		||||
				capturedRemoteAddr = r.RemoteAddr
 | 
			
		||||
				w.WriteHeader(http.StatusOK)
 | 
			
		||||
			})
 | 
			
		||||
 | 
			
		||||
			// Create request with middleware
 | 
			
		||||
			req := httptest.NewRequest("GET", "/", nil)
 | 
			
		||||
			req.RemoteAddr = tt.remoteAddr
 | 
			
		||||
			if tt.xForwardedFor != "" {
 | 
			
		||||
				req.Header.Set("X-Forwarded-For", tt.xForwardedFor)
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			rr := httptest.NewRecorder()
 | 
			
		||||
			middleware(handler).ServeHTTP(rr, req)
 | 
			
		||||
 | 
			
		||||
			// Test GetRealIP function
 | 
			
		||||
			if capturedRealIP != tt.expectedRealIP {
 | 
			
		||||
				t.Errorf("GetRealIP: expected %s, got %s", tt.expectedRealIP, capturedRealIP)
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			// Test that r.RemoteAddr is updated with real IP and port 0
 | 
			
		||||
			// (since the original port is from the proxy, not the real client)
 | 
			
		||||
			expectedRemoteAddr := net.JoinHostPort(tt.expectedRealIP, "0")
 | 
			
		||||
			if capturedRemoteAddr != expectedRemoteAddr {
 | 
			
		||||
				t.Errorf("RemoteAddr: expected %s, got %s", expectedRemoteAddr, capturedRemoteAddr)
 | 
			
		||||
			}
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestIsTrustedProxy(t *testing.T) {
 | 
			
		||||
	xff := &FastlyXFF{
 | 
			
		||||
		IPv4: []string{"192.0.2.0/24", "203.0.113.0/24"},
 | 
			
		||||
		IPv6: []string{"2001:db8::/32"},
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	tests := []struct {
 | 
			
		||||
		ip       string
 | 
			
		||||
		expected bool
 | 
			
		||||
	}{
 | 
			
		||||
		{"192.0.2.1", true},
 | 
			
		||||
		{"192.0.2.255", true},
 | 
			
		||||
		{"203.0.113.1", true},
 | 
			
		||||
		{"192.0.3.1", false},
 | 
			
		||||
		{"198.51.100.1", false},
 | 
			
		||||
		{"2001:db8::1", true},
 | 
			
		||||
		{"2001:db8:ffff::1", true},
 | 
			
		||||
		{"2001:db9::1", false},
 | 
			
		||||
		{"invalid-ip", false},
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for _, tt := range tests {
 | 
			
		||||
		t.Run(tt.ip, func(t *testing.T) {
 | 
			
		||||
			result := xff.isTrustedProxy(tt.ip)
 | 
			
		||||
			if result != tt.expected {
 | 
			
		||||
				t.Errorf("isTrustedProxy(%s) = %v, expected %v", tt.ip, result, tt.expected)
 | 
			
		||||
			}
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestExtractRealIP(t *testing.T) {
 | 
			
		||||
	xff := &FastlyXFF{
 | 
			
		||||
		IPv4: []string{"192.0.2.0/24"},
 | 
			
		||||
		IPv6: []string{"2001:db8::/32"},
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	tests := []struct {
 | 
			
		||||
		name          string
 | 
			
		||||
		remoteAddr    string
 | 
			
		||||
		xForwardedFor string
 | 
			
		||||
		expected      string
 | 
			
		||||
	}{
 | 
			
		||||
		{
 | 
			
		||||
			name:          "no XFF header",
 | 
			
		||||
			remoteAddr:    "198.51.100.1:12345",
 | 
			
		||||
			xForwardedFor: "",
 | 
			
		||||
			expected:      "198.51.100.1",
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:          "trusted proxy with single IP",
 | 
			
		||||
			remoteAddr:    "192.0.2.1:80",
 | 
			
		||||
			xForwardedFor: "198.51.100.1",
 | 
			
		||||
			expected:      "198.51.100.1",
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:          "trusted proxy with multiple IPs",
 | 
			
		||||
			remoteAddr:    "192.0.2.1:80",
 | 
			
		||||
			xForwardedFor: "198.51.100.1, 203.0.113.5",
 | 
			
		||||
			expected:      "198.51.100.1",
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:          "untrusted proxy",
 | 
			
		||||
			remoteAddr:    "198.51.100.1:80",
 | 
			
		||||
			xForwardedFor: "10.0.0.1",
 | 
			
		||||
			expected:      "198.51.100.1",
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:          "empty XFF",
 | 
			
		||||
			remoteAddr:    "192.0.2.1:80",
 | 
			
		||||
			xForwardedFor: "",
 | 
			
		||||
			expected:      "192.0.2.1",
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:          "malformed remote addr",
 | 
			
		||||
			remoteAddr:    "192.0.2.1",
 | 
			
		||||
			xForwardedFor: "198.51.100.1",
 | 
			
		||||
			expected:      "198.51.100.1",
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for _, tt := range tests {
 | 
			
		||||
		t.Run(tt.name, func(t *testing.T) {
 | 
			
		||||
			req := httptest.NewRequest("GET", "/", nil)
 | 
			
		||||
			req.RemoteAddr = tt.remoteAddr
 | 
			
		||||
			if tt.xForwardedFor != "" {
 | 
			
		||||
				req.Header.Set("X-Forwarded-For", tt.xForwardedFor)
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			result := xff.extractRealIP(req)
 | 
			
		||||
			if result != tt.expected {
 | 
			
		||||
				t.Errorf("extractRealIP() = %s, expected %s", result, tt.expected)
 | 
			
		||||
			}
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestGetRealIPWithoutMiddleware(t *testing.T) {
 | 
			
		||||
	req := httptest.NewRequest("GET", "/", nil)
 | 
			
		||||
	req.RemoteAddr = "198.51.100.1:12345"
 | 
			
		||||
 | 
			
		||||
	realIP := GetRealIP(req)
 | 
			
		||||
	expected := "198.51.100.1"
 | 
			
		||||
	if realIP != expected {
 | 
			
		||||
		t.Errorf("GetRealIP() = %s, expected %s", realIP, expected)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestAddTrustedCIDR(t *testing.T) {
 | 
			
		||||
	xff := &FastlyXFF{
 | 
			
		||||
		IPv4: []string{"192.0.2.0/24"},
 | 
			
		||||
		IPv6: []string{"2001:db8::/32"},
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	tests := []struct {
 | 
			
		||||
		name    string
 | 
			
		||||
		cidr    string
 | 
			
		||||
		wantErr bool
 | 
			
		||||
	}{
 | 
			
		||||
		{"valid IPv4 range", "10.0.0.0/8", false},
 | 
			
		||||
		{"valid IPv6 range", "fc00::/7", false},
 | 
			
		||||
		{"valid single IP", "203.0.113.1/32", false},
 | 
			
		||||
		{"invalid CIDR", "not-a-cidr", true},
 | 
			
		||||
		{"invalid format", "10.0.0.0/99", true},
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for _, tt := range tests {
 | 
			
		||||
		t.Run(tt.name, func(t *testing.T) {
 | 
			
		||||
			err := xff.AddTrustedCIDR(tt.cidr)
 | 
			
		||||
			if (err != nil) != tt.wantErr {
 | 
			
		||||
				t.Errorf("AddTrustedCIDR(%s) error = %v, wantErr %v", tt.cidr, err, tt.wantErr)
 | 
			
		||||
			}
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestCustomTrustedCIDRs(t *testing.T) {
 | 
			
		||||
	xff := &FastlyXFF{
 | 
			
		||||
		IPv4: []string{"192.0.2.0/24"},
 | 
			
		||||
		IPv6: []string{"2001:db8::/32"},
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Add custom trusted CIDRs
 | 
			
		||||
	err := xff.AddTrustedCIDR("10.0.0.0/8")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatalf("Failed to add trusted CIDR: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	err = xff.AddTrustedCIDR("172.16.0.0/12")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatalf("Failed to add trusted CIDR: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	tests := []struct {
 | 
			
		||||
		ip       string
 | 
			
		||||
		expected bool
 | 
			
		||||
	}{
 | 
			
		||||
		// Original Fastly ranges
 | 
			
		||||
		{"192.0.2.1", true},
 | 
			
		||||
		{"2001:db8::1", true},
 | 
			
		||||
		// Custom CIDRs
 | 
			
		||||
		{"10.1.2.3", true},
 | 
			
		||||
		{"172.16.1.1", true},
 | 
			
		||||
		// Not trusted
 | 
			
		||||
		{"198.51.100.1", false},
 | 
			
		||||
		{"172.15.1.1", false},
 | 
			
		||||
		{"10.0.0.0", true}, // Network address should still match
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for _, tt := range tests {
 | 
			
		||||
		t.Run(tt.ip, func(t *testing.T) {
 | 
			
		||||
			result := xff.isTrustedProxy(tt.ip)
 | 
			
		||||
			if result != tt.expected {
 | 
			
		||||
				t.Errorf("isTrustedProxy(%s) = %v, expected %v", tt.ip, result, tt.expected)
 | 
			
		||||
			}
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestHTTPMiddlewareWithCustomCIDRs(t *testing.T) {
 | 
			
		||||
	xff := &FastlyXFF{
 | 
			
		||||
		IPv4: []string{"192.0.2.0/24"},
 | 
			
		||||
		IPv6: []string{"2001:db8::/32"},
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Add custom trusted CIDR for internal proxies
 | 
			
		||||
	err := xff.AddTrustedCIDR("10.0.0.0/8")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatalf("Failed to add trusted CIDR: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	middleware := xff.HTTPMiddleware()
 | 
			
		||||
 | 
			
		||||
	tests := []struct {
 | 
			
		||||
		name           string
 | 
			
		||||
		remoteAddr     string
 | 
			
		||||
		xForwardedFor  string
 | 
			
		||||
		expectedRealIP string
 | 
			
		||||
	}{
 | 
			
		||||
		{
 | 
			
		||||
			name:           "custom trusted proxy with XFF",
 | 
			
		||||
			remoteAddr:     "10.1.2.3:80",
 | 
			
		||||
			xForwardedFor:  "198.51.100.1",
 | 
			
		||||
			expectedRealIP: "198.51.100.1",
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:           "fastly proxy with XFF",
 | 
			
		||||
			remoteAddr:     "192.0.2.1:80",
 | 
			
		||||
			xForwardedFor:  "198.51.100.1",
 | 
			
		||||
			expectedRealIP: "198.51.100.1",
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:           "untrusted proxy ignored",
 | 
			
		||||
			remoteAddr:     "172.16.1.1:80",
 | 
			
		||||
			xForwardedFor:  "198.51.100.1",
 | 
			
		||||
			expectedRealIP: "172.16.1.1",
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:           "chain through custom and fastly",
 | 
			
		||||
			remoteAddr:     "192.0.2.1:80",
 | 
			
		||||
			xForwardedFor:  "198.51.100.1, 10.1.2.3",
 | 
			
		||||
			expectedRealIP: "198.51.100.1",
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for _, tt := range tests {
 | 
			
		||||
		t.Run(tt.name, func(t *testing.T) {
 | 
			
		||||
			var capturedIP string
 | 
			
		||||
			handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 | 
			
		||||
				capturedIP = GetRealIP(r)
 | 
			
		||||
				w.WriteHeader(http.StatusOK)
 | 
			
		||||
			})
 | 
			
		||||
 | 
			
		||||
			req := httptest.NewRequest("GET", "/", nil)
 | 
			
		||||
			req.RemoteAddr = tt.remoteAddr
 | 
			
		||||
			if tt.xForwardedFor != "" {
 | 
			
		||||
				req.Header.Set("X-Forwarded-For", tt.xForwardedFor)
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			rr := httptest.NewRecorder()
 | 
			
		||||
			middleware(handler).ServeHTTP(rr, req)
 | 
			
		||||
 | 
			
		||||
			if capturedIP != tt.expectedRealIP {
 | 
			
		||||
				t.Errorf("expected real IP %s, got %s", tt.expectedRealIP, capturedIP)
 | 
			
		||||
			}
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user