From da13a371b48ac59f17132eef62199d81e8ebd49d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ask=20Bj=C3=B8rn=20Hansen?= Date: Sat, 12 Jul 2025 23:52:48 -0700 Subject: [PATCH] feat(database): add shared transaction helpers Add transaction base utilities with Begin, Commit, and Rollback functions supporting both sql.DB and sql.Tx interfaces. --- database/transaction_base.go | 69 ++++++++++++++++++++++++++++++++++++ 1 file changed, 69 insertions(+) create mode 100644 database/transaction_base.go diff --git a/database/transaction_base.go b/database/transaction_base.go new file mode 100644 index 0000000..e8ee875 --- /dev/null +++ b/database/transaction_base.go @@ -0,0 +1,69 @@ +package database + +import ( + "context" + "database/sql" + "fmt" + + "go.ntppool.org/common/logger" +) + +// Shared interface definitions that both packages use identically +type BaseBeginner interface { + Begin(context.Context) (sql.Tx, error) +} + +type BaseTx interface { + BaseBeginner + Commit(ctx context.Context) error + Rollback(ctx context.Context) error +} + +// BeginTransactionForQuerier contains the shared Begin() logic from both packages +func BeginTransactionForQuerier(ctx context.Context, db DBTX) (DBTX, error) { + if sqlDB, ok := db.(*sql.DB); ok { + tx, err := sqlDB.BeginTx(ctx, &sql.TxOptions{}) + if err != nil { + return nil, err + } + return tx, nil + } else { + // Handle transaction case + if beginner, ok := db.(BaseBeginner); ok { + tx, err := beginner.Begin(ctx) + if err != nil { + return nil, err + } + return &tx, nil + } + return nil, fmt.Errorf("database connection does not support transactions") + } +} + +// CommitTransactionForQuerier contains the shared Commit() logic from both packages +func CommitTransactionForQuerier(ctx context.Context, db DBTX) error { + if sqlTx, ok := db.(*sql.Tx); ok { + return sqlTx.Commit() + } + + tx, ok := db.(BaseTx) + if !ok { + log := logger.FromContext(ctx) + log.ErrorContext(ctx, "could not get a Tx", "type", fmt.Sprintf("%T", db)) + return sql.ErrTxDone + } + return tx.Commit(ctx) +} + +// RollbackTransactionForQuerier contains the shared Rollback() logic from both packages +func RollbackTransactionForQuerier(ctx context.Context, db DBTX) error { + if sqlTx, ok := db.(*sql.Tx); ok { + return sqlTx.Rollback() + } + + tx, ok := db.(BaseTx) + if !ok { + return sql.ErrTxDone + } + return tx.Rollback(ctx) +}