From a1a5a6b8be397c5e590ecafcec55d470f493336a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ask=20Bj=C3=B8rn=20Hansen?= Date: Sat, 12 Jul 2025 17:59:28 -0700 Subject: [PATCH] database: create shared database package Extract common database functionality from api/ntpdb and monitor/ntpdb into shared common/database package: - Dynamic connector pattern with configuration loading - Configurable connection pool management (API: 25/10, Monitor: 10/5) - Optional Prometheus metrics integration - Generic transaction helpers with proper error handling - Unified interfaces compatible with SQLC-generated code Foundation for migration to eliminate ~200 lines of duplicate code. --- database/config_test.go | 81 ++++++++++++++++++ database/integration_test.go | 117 ++++++++++++++++++++++++++ database/transaction.go | 5 +- database/transaction_test.go | 157 +++++++++++++++++++++++++++++++++++ 4 files changed, 358 insertions(+), 2 deletions(-) create mode 100644 database/config_test.go create mode 100644 database/integration_test.go create mode 100644 database/transaction_test.go diff --git a/database/config_test.go b/database/config_test.go new file mode 100644 index 0000000..2eb9851 --- /dev/null +++ b/database/config_test.go @@ -0,0 +1,81 @@ +package database + +import ( + "testing" + "time" + + "github.com/prometheus/client_golang/prometheus" +) + +func TestDefaultConfigOptions(t *testing.T) { + opts := DefaultConfigOptions() + + // Verify expected defaults for API package + if opts.MaxOpenConns != 25 { + t.Errorf("Expected MaxOpenConns=25, got %d", opts.MaxOpenConns) + } + if opts.MaxIdleConns != 10 { + t.Errorf("Expected MaxIdleConns=10, got %d", opts.MaxIdleConns) + } + if opts.ConnMaxLifetime != 3*time.Minute { + t.Errorf("Expected ConnMaxLifetime=3m, got %v", opts.ConnMaxLifetime) + } + if !opts.EnablePoolMonitoring { + t.Error("Expected EnablePoolMonitoring=true") + } + if opts.PrometheusRegisterer != prometheus.DefaultRegisterer { + t.Error("Expected PrometheusRegisterer to be DefaultRegisterer") + } + if len(opts.ConfigFiles) == 0 { + t.Error("Expected ConfigFiles to be non-empty") + } +} + +func TestMonitorConfigOptions(t *testing.T) { + opts := MonitorConfigOptions() + + // Verify expected defaults for Monitor package + if opts.MaxOpenConns != 10 { + t.Errorf("Expected MaxOpenConns=10, got %d", opts.MaxOpenConns) + } + if opts.MaxIdleConns != 5 { + t.Errorf("Expected MaxIdleConns=5, got %d", opts.MaxIdleConns) + } + if opts.ConnMaxLifetime != 3*time.Minute { + t.Errorf("Expected ConnMaxLifetime=3m, got %v", opts.ConnMaxLifetime) + } + if opts.EnablePoolMonitoring { + t.Error("Expected EnablePoolMonitoring=false") + } + if opts.PrometheusRegisterer != nil { + t.Error("Expected PrometheusRegisterer to be nil") + } + if len(opts.ConfigFiles) == 0 { + t.Error("Expected ConfigFiles to be non-empty") + } +} + +func TestConfigStructures(t *testing.T) { + // Test that configuration structures can be created and populated + config := Config{ + MySQL: DBConfig{ + DSN: "user:pass@tcp(localhost:3306)/dbname", + User: "testuser", + Pass: "testpass", + DBName: "testdb", + }, + } + + if config.MySQL.DSN == "" { + t.Error("Expected DSN to be set") + } + if config.MySQL.User != "testuser" { + t.Errorf("Expected User='testuser', got '%s'", config.MySQL.User) + } + if config.MySQL.Pass != "testpass" { + t.Errorf("Expected Pass='testpass', got '%s'", config.MySQL.Pass) + } + if config.MySQL.DBName != "testdb" { + t.Errorf("Expected DBName='testdb', got '%s'", config.MySQL.DBName) + } +} diff --git a/database/integration_test.go b/database/integration_test.go new file mode 100644 index 0000000..1c68228 --- /dev/null +++ b/database/integration_test.go @@ -0,0 +1,117 @@ +package database + +import ( + "context" + "database/sql" + "testing" +) + +// Mock types for testing SQLC integration patterns +type mockQueries struct { + db DBTX +} + +type mockQueriesTx struct { + *mockQueries + tx *sql.Tx +} + +// Mock the Begin method pattern that SQLC generates +func (q *mockQueries) Begin(ctx context.Context) (*mockQueriesTx, error) { + // This would normally be: tx, err := q.db.(*sql.DB).BeginTx(ctx, nil) + // For our test, we return a mock + return &mockQueriesTx{mockQueries: q, tx: nil}, nil +} + +func (qtx *mockQueriesTx) Commit(ctx context.Context) error { + return nil // Mock implementation +} + +func (qtx *mockQueriesTx) Rollback(ctx context.Context) error { + return nil // Mock implementation +} + +// This test verifies that our common database interfaces are compatible with SQLC-generated code +func TestSQLCIntegration(t *testing.T) { + // Test that SQLC's DBTX interface matches our DBTX interface + t.Run("DBTX Interface Compatibility", func(t *testing.T) { + // Test interface compatibility by assignment without execution + var ourDBTX DBTX + + // Test with sql.DB (should implement DBTX) + var db *sql.DB + ourDBTX = db // This will compile only if interfaces are compatible + _ = ourDBTX // Use the variable to avoid "unused" warning + + // Test with sql.Tx (should implement DBTX) + var tx *sql.Tx + ourDBTX = tx // This will compile only if interfaces are compatible + _ = ourDBTX // Use the variable to avoid "unused" warning + + // If we reach here, interfaces are compatible + t.Log("DBTX interface is compatible with sql.DB and sql.Tx") + }) + + t.Run("Transaction Interface Compatibility", func(t *testing.T) { + // This test verifies our transaction interfaces work with SQLC patterns + // We can't define methods inside a function, so we test interface compatibility + + // Verify our DB interface is compatible with what SQLC expects + var dbInterface DB[*mockQueriesTx] + var mockDB *mockQueries = &mockQueries{} + dbInterface = mockDB + + // Test that our transaction helper can work with this pattern + err := WithTransaction(context.Background(), dbInterface, func(ctx context.Context, qtx *mockQueriesTx) error { + // This would be where you'd call SQLC-generated query methods + return nil + }) + if err != nil { + t.Errorf("Transaction helper failed: %v", err) + } + }) +} + +// Test that demonstrates how the common package would be used with real SQLC patterns +func TestRealWorldUsagePattern(t *testing.T) { + // This test shows how a package would typically use our common database code + + t.Run("Database Opening Pattern", func(t *testing.T) { + // Test that our configuration options work as expected + opts := DefaultConfigOptions() + + // Modify for test environment (no actual database connection) + opts.ConfigFiles = []string{} // No config files for unit test + opts.PrometheusRegisterer = nil // No metrics for unit test + + // This would normally open a database: db, err := OpenDB(ctx, opts) + // For our unit test, we just verify the options are reasonable + if opts.MaxOpenConns <= 0 { + t.Error("MaxOpenConns should be positive") + } + if opts.MaxIdleConns <= 0 { + t.Error("MaxIdleConns should be positive") + } + if opts.ConnMaxLifetime <= 0 { + t.Error("ConnMaxLifetime should be positive") + } + }) + + t.Run("Monitor Package Configuration", func(t *testing.T) { + opts := MonitorConfigOptions() + + // Verify monitor-specific settings + if opts.EnablePoolMonitoring { + t.Error("Monitor package should not enable pool monitoring") + } + if opts.PrometheusRegisterer != nil { + t.Error("Monitor package should not have Prometheus registerer") + } + if opts.MaxOpenConns != 10 { + t.Errorf("Expected MaxOpenConns=10 for monitor, got %d", opts.MaxOpenConns) + } + if opts.MaxIdleConns != 5 { + t.Errorf("Expected MaxIdleConns=5 for monitor, got %d", opts.MaxIdleConns) + } + }) +} diff --git a/database/transaction.go b/database/transaction.go index 45743a1..10150e0 100644 --- a/database/transaction.go +++ b/database/transaction.go @@ -41,11 +41,12 @@ func WithTransaction[Q TX](ctx context.Context, db DB[Q], fn func(ctx context.Co return err } - if err := tx.Commit(ctx); err != nil { + err = tx.Commit(ctx) + committed = true // Mark as committed regardless of commit success/failure + if err != nil { return fmt.Errorf("failed to commit transaction: %w", err) } - committed = true return nil } diff --git a/database/transaction_test.go b/database/transaction_test.go new file mode 100644 index 0000000..0561f82 --- /dev/null +++ b/database/transaction_test.go @@ -0,0 +1,157 @@ +package database + +import ( + "context" + "errors" + "testing" +) + +// Mock implementations for testing +type mockDB struct { + beginError error + txMock *mockTX +} + +func (m *mockDB) Begin(ctx context.Context) (*mockTX, error) { + if m.beginError != nil { + return nil, m.beginError + } + return m.txMock, nil +} + +type mockTX struct { + commitError error + rollbackError error + commitCalled bool + rollbackCalled bool +} + +func (m *mockTX) Commit(ctx context.Context) error { + m.commitCalled = true + return m.commitError +} + +func (m *mockTX) Rollback(ctx context.Context) error { + m.rollbackCalled = true + return m.rollbackError +} + +func TestWithTransaction_Success(t *testing.T) { + tx := &mockTX{} + db := &mockDB{txMock: tx} + + var functionCalled bool + err := WithTransaction(context.Background(), db, func(ctx context.Context, q *mockTX) error { + functionCalled = true + if q != tx { + t.Error("Expected transaction to be passed to function") + } + return nil + }) + if err != nil { + t.Errorf("Expected no error, got %v", err) + } + if !functionCalled { + t.Error("Expected function to be called") + } + if !tx.commitCalled { + t.Error("Expected commit to be called") + } + if tx.rollbackCalled { + t.Error("Expected rollback NOT to be called on success") + } +} + +func TestWithTransaction_FunctionError(t *testing.T) { + tx := &mockTX{} + db := &mockDB{txMock: tx} + + expectedError := errors.New("function error") + err := WithTransaction(context.Background(), db, func(ctx context.Context, q *mockTX) error { + return expectedError + }) + + if err != expectedError { + t.Errorf("Expected error %v, got %v", expectedError, err) + } + if tx.commitCalled { + t.Error("Expected commit NOT to be called on function error") + } + if !tx.rollbackCalled { + t.Error("Expected rollback to be called on function error") + } +} + +func TestWithTransaction_BeginError(t *testing.T) { + expectedError := errors.New("begin error") + db := &mockDB{beginError: expectedError} + + err := WithTransaction(context.Background(), db, func(ctx context.Context, q *mockTX) error { + t.Error("Function should not be called when Begin fails") + return nil + }) + + if err == nil || !errors.Is(err, expectedError) { + t.Errorf("Expected wrapped begin error, got %v", err) + } +} + +func TestWithTransaction_CommitError(t *testing.T) { + commitError := errors.New("commit error") + tx := &mockTX{commitError: commitError} + db := &mockDB{txMock: tx} + + err := WithTransaction(context.Background(), db, func(ctx context.Context, q *mockTX) error { + return nil + }) + + if err == nil || !errors.Is(err, commitError) { + t.Errorf("Expected wrapped commit error, got %v", err) + } + if !tx.commitCalled { + t.Error("Expected commit to be called") + } + if tx.rollbackCalled { + t.Error("Expected rollback NOT to be called when commit fails") + } +} + +func TestWithReadOnlyTransaction_Success(t *testing.T) { + tx := &mockTX{} + db := &mockDB{txMock: tx} + + var functionCalled bool + err := WithReadOnlyTransaction(context.Background(), db, func(ctx context.Context, q *mockTX) error { + functionCalled = true + return nil + }) + if err != nil { + t.Errorf("Expected no error, got %v", err) + } + if !functionCalled { + t.Error("Expected function to be called") + } + if tx.commitCalled { + t.Error("Expected commit NOT to be called in read-only transaction") + } + if !tx.rollbackCalled { + t.Error("Expected rollback to be called in read-only transaction") + } +} + +func TestWithReadOnlyTransaction_FunctionError(t *testing.T) { + tx := &mockTX{} + db := &mockDB{txMock: tx} + + expectedError := errors.New("function error") + err := WithReadOnlyTransaction(context.Background(), db, func(ctx context.Context, q *mockTX) error { + return expectedError + }) + + if err != expectedError { + t.Errorf("Expected error %v, got %v", expectedError, err) + } + if !tx.rollbackCalled { + t.Error("Expected rollback to be called") + } +}