common/database/transaction_test.go
Ask Bjørn Hansen a1a5a6b8be database: create shared database package
Extract common database functionality from api/ntpdb and monitor/ntpdb
into shared common/database package:

- Dynamic connector pattern with configuration loading
- Configurable connection pool management (API: 25/10, Monitor: 10/5)
- Optional Prometheus metrics integration
- Generic transaction helpers with proper error handling
- Unified interfaces compatible with SQLC-generated code

Foundation for migration to eliminate ~200 lines of duplicate code.
2025-07-12 17:59:28 -07:00

158 lines
3.7 KiB
Go

package database
import (
"context"
"errors"
"testing"
)
// mockDB is a test double for the database handle passed to the transaction
// helpers under test. Its Begin method either fails with beginError or hands
// out txMock.
type mockDB struct {
	beginError error   // when non-nil, Begin returns this error and a nil transaction
	txMock     *mockTX // transaction returned by Begin when beginError is nil
}

// Begin returns the configured transaction, or beginError when one is set.
// The context argument is ignored.
func (db *mockDB) Begin(_ context.Context) (*mockTX, error) {
	if err := db.beginError; err != nil {
		return nil, err
	}
	return db.txMock, nil
}
// mockTX is a test double for a transaction. It records whether Commit and
// Rollback were invoked and returns the preconfigured error from each.
type mockTX struct {
	commitError    error // returned by Commit
	rollbackError  error // returned by Rollback
	commitCalled   bool  // set to true once Commit has been called
	rollbackCalled bool  // set to true once Rollback has been called
}

// Commit records the call and returns commitError. The context is ignored.
func (tx *mockTX) Commit(_ context.Context) error {
	tx.commitCalled = true
	return tx.commitError
}

// Rollback records the call and returns rollbackError. The context is ignored.
func (tx *mockTX) Rollback(_ context.Context) error {
	tx.rollbackCalled = true
	return tx.rollbackError
}
// TestWithTransaction_Success checks the happy path of WithTransaction: the
// callback runs and receives the transaction produced by Begin, no error is
// returned, Commit is called, and Rollback is not.
func TestWithTransaction_Success(t *testing.T) {
	mock := &mockTX{}
	invoked := false

	work := func(_ context.Context, got *mockTX) error {
		invoked = true
		if got != mock {
			t.Error("Expected transaction to be passed to function")
		}
		return nil
	}

	if err := WithTransaction(context.Background(), &mockDB{txMock: mock}, work); err != nil {
		t.Errorf("Expected no error, got %v", err)
	}

	// Each entry is reported independently, in the same order as listed.
	for _, check := range []struct {
		failed bool
		msg    string
	}{
		{!invoked, "Expected function to be called"},
		{!mock.commitCalled, "Expected commit to be called"},
		{mock.rollbackCalled, "Expected rollback NOT to be called on success"},
	} {
		if check.failed {
			t.Error(check.msg)
		}
	}
}
// TestWithTransaction_FunctionError checks that when the callback fails,
// WithTransaction reports that error, does not call Commit, and calls
// Rollback.
func TestWithTransaction_FunctionError(t *testing.T) {
	tx := &mockTX{}
	db := &mockDB{txMock: tx}
	expectedError := errors.New("function error")

	err := WithTransaction(context.Background(), db, func(ctx context.Context, q *mockTX) error {
		return expectedError
	})

	// errors.Is matches the callback error whether it is returned as-is or
	// wrapped, in line with the begin/commit error tests in this file.
	if !errors.Is(err, expectedError) {
		t.Errorf("Expected error %v, got %v", expectedError, err)
	}
	if tx.commitCalled {
		t.Error("Expected commit NOT to be called on function error")
	}
	if !tx.rollbackCalled {
		t.Error("Expected rollback to be called on function error")
	}
}
// TestWithTransaction_BeginError checks that a failing Begin is surfaced by
// WithTransaction (matched via errors.Is) and that the callback never runs.
func TestWithTransaction_BeginError(t *testing.T) {
	beginErr := errors.New("begin error")
	failing := &mockDB{beginError: beginErr}

	got := WithTransaction(context.Background(), failing, func(_ context.Context, _ *mockTX) error {
		t.Error("Function should not be called when Begin fails")
		return nil
	})

	// errors.Is reports false for a nil error against a non-nil target, so
	// this single condition also covers the "no error returned" case.
	if !errors.Is(got, beginErr) {
		t.Errorf("Expected wrapped begin error, got %v", got)
	}
}
// TestWithTransaction_CommitError checks that a Commit failure is surfaced by
// WithTransaction (matched via errors.Is), that Commit was attempted, and that
// Rollback is not called afterwards.
func TestWithTransaction_CommitError(t *testing.T) {
	commitErr := errors.New("commit error")
	mock := &mockTX{commitError: commitErr}

	noop := func(_ context.Context, _ *mockTX) error { return nil }
	got := WithTransaction(context.Background(), &mockDB{txMock: mock}, noop)

	// errors.Is reports false for a nil error against a non-nil target, so
	// this single condition also covers the "no error returned" case.
	if !errors.Is(got, commitErr) {
		t.Errorf("Expected wrapped commit error, got %v", got)
	}
	if !mock.commitCalled {
		t.Error("Expected commit to be called")
	}
	if mock.rollbackCalled {
		t.Error("Expected rollback NOT to be called when commit fails")
	}
}
// TestWithReadOnlyTransaction_Success checks the happy path of
// WithReadOnlyTransaction: the callback runs, no error is returned, Commit is
// not called, and Rollback is called.
func TestWithReadOnlyTransaction_Success(t *testing.T) {
	mock := &mockTX{}
	invoked := false

	work := func(_ context.Context, _ *mockTX) error {
		invoked = true
		return nil
	}

	if err := WithReadOnlyTransaction(context.Background(), &mockDB{txMock: mock}, work); err != nil {
		t.Errorf("Expected no error, got %v", err)
	}

	// Each entry is reported independently, in the same order as listed.
	for _, check := range []struct {
		failed bool
		msg    string
	}{
		{!invoked, "Expected function to be called"},
		{mock.commitCalled, "Expected commit NOT to be called in read-only transaction"},
		{!mock.rollbackCalled, "Expected rollback to be called in read-only transaction"},
	} {
		if check.failed {
			t.Error(check.msg)
		}
	}
}
// TestWithReadOnlyTransaction_FunctionError checks that when the callback
// fails, WithReadOnlyTransaction reports that error, does not call Commit,
// and calls Rollback.
func TestWithReadOnlyTransaction_FunctionError(t *testing.T) {
	tx := &mockTX{}
	db := &mockDB{txMock: tx}
	expectedError := errors.New("function error")

	err := WithReadOnlyTransaction(context.Background(), db, func(ctx context.Context, q *mockTX) error {
		return expectedError
	})

	// errors.Is matches the callback error whether it is returned as-is or
	// wrapped, in line with the begin/commit error tests in this file.
	if !errors.Is(err, expectedError) {
		t.Errorf("Expected error %v, got %v", expectedError, err)
	}
	// The read-only success test asserts Commit is skipped; the error path
	// asserts the same here.
	if tx.commitCalled {
		t.Error("Expected commit NOT to be called in read-only transaction")
	}
	if !tx.rollbackCalled {
		t.Error("Expected rollback to be called")
	}
}