package database import ( "context" "errors" "testing" ) // Mock implementations for testing type mockDB struct { beginError error txMock *mockTX } func (m *mockDB) Begin(ctx context.Context) (*mockTX, error) { if m.beginError != nil { return nil, m.beginError } return m.txMock, nil } type mockTX struct { commitError error rollbackError error commitCalled bool rollbackCalled bool } func (m *mockTX) Commit(ctx context.Context) error { m.commitCalled = true return m.commitError } func (m *mockTX) Rollback(ctx context.Context) error { m.rollbackCalled = true return m.rollbackError } func TestWithTransaction_Success(t *testing.T) { tx := &mockTX{} db := &mockDB{txMock: tx} var functionCalled bool err := WithTransaction(context.Background(), db, func(ctx context.Context, q *mockTX) error { functionCalled = true if q != tx { t.Error("Expected transaction to be passed to function") } return nil }) if err != nil { t.Errorf("Expected no error, got %v", err) } if !functionCalled { t.Error("Expected function to be called") } if !tx.commitCalled { t.Error("Expected commit to be called") } if tx.rollbackCalled { t.Error("Expected rollback NOT to be called on success") } } func TestWithTransaction_FunctionError(t *testing.T) { tx := &mockTX{} db := &mockDB{txMock: tx} expectedError := errors.New("function error") err := WithTransaction(context.Background(), db, func(ctx context.Context, q *mockTX) error { return expectedError }) if err != expectedError { t.Errorf("Expected error %v, got %v", expectedError, err) } if tx.commitCalled { t.Error("Expected commit NOT to be called on function error") } if !tx.rollbackCalled { t.Error("Expected rollback to be called on function error") } } func TestWithTransaction_BeginError(t *testing.T) { expectedError := errors.New("begin error") db := &mockDB{beginError: expectedError} err := WithTransaction(context.Background(), db, func(ctx context.Context, q *mockTX) error { t.Error("Function should not be called when Begin fails") return nil }) if err == nil || !errors.Is(err, expectedError) { t.Errorf("Expected wrapped begin error, got %v", err) } } func TestWithTransaction_CommitError(t *testing.T) { commitError := errors.New("commit error") tx := &mockTX{commitError: commitError} db := &mockDB{txMock: tx} err := WithTransaction(context.Background(), db, func(ctx context.Context, q *mockTX) error { return nil }) if err == nil || !errors.Is(err, commitError) { t.Errorf("Expected wrapped commit error, got %v", err) } if !tx.commitCalled { t.Error("Expected commit to be called") } if tx.rollbackCalled { t.Error("Expected rollback NOT to be called when commit fails") } } func TestWithReadOnlyTransaction_Success(t *testing.T) { tx := &mockTX{} db := &mockDB{txMock: tx} var functionCalled bool err := WithReadOnlyTransaction(context.Background(), db, func(ctx context.Context, q *mockTX) error { functionCalled = true return nil }) if err != nil { t.Errorf("Expected no error, got %v", err) } if !functionCalled { t.Error("Expected function to be called") } if tx.commitCalled { t.Error("Expected commit NOT to be called in read-only transaction") } if !tx.rollbackCalled { t.Error("Expected rollback to be called in read-only transaction") } } func TestWithReadOnlyTransaction_FunctionError(t *testing.T) { tx := &mockTX{} db := &mockDB{txMock: tx} expectedError := errors.New("function error") err := WithReadOnlyTransaction(context.Background(), db, func(ctx context.Context, q *mockTX) error { return expectedError }) if err != expectedError { t.Errorf("Expected error %v, got %v", expectedError, err) } if !tx.rollbackCalled { t.Error("Expected rollback to be called") } }