diff --git a/database/transaction_base.go b/database/transaction_base.go new file mode 100644 index 0000000..e8ee875 --- /dev/null +++ b/database/transaction_base.go @@ -0,0 +1,69 @@ +package database + +import ( + "context" + "database/sql" + "fmt" + + "go.ntppool.org/common/logger" +) + +// Shared interface definitions that both packages use identically +type BaseBeginner interface { + Begin(context.Context) (sql.Tx, error) +} + +type BaseTx interface { + BaseBeginner + Commit(ctx context.Context) error + Rollback(ctx context.Context) error +} + +// BeginTransactionForQuerier contains the shared Begin() logic from both packages +func BeginTransactionForQuerier(ctx context.Context, db DBTX) (DBTX, error) { + if sqlDB, ok := db.(*sql.DB); ok { + tx, err := sqlDB.BeginTx(ctx, &sql.TxOptions{}) + if err != nil { + return nil, err + } + return tx, nil + } else { + // Handle transaction case + if beginner, ok := db.(BaseBeginner); ok { + tx, err := beginner.Begin(ctx) + if err != nil { + return nil, err + } + return &tx, nil + } + return nil, fmt.Errorf("database connection does not support transactions") + } +} + +// CommitTransactionForQuerier contains the shared Commit() logic from both packages +func CommitTransactionForQuerier(ctx context.Context, db DBTX) error { + if sqlTx, ok := db.(*sql.Tx); ok { + return sqlTx.Commit() + } + + tx, ok := db.(BaseTx) + if !ok { + log := logger.FromContext(ctx) + log.ErrorContext(ctx, "could not get a Tx", "type", fmt.Sprintf("%T", db)) + return sql.ErrTxDone + } + return tx.Commit(ctx) +} + +// RollbackTransactionForQuerier contains the shared Rollback() logic from both packages +func RollbackTransactionForQuerier(ctx context.Context, db DBTX) error { + if sqlTx, ok := db.(*sql.Tx); ok { + return sqlTx.Rollback() + } + + tx, ok := db.(BaseTx) + if !ok { + return sql.ErrTxDone + } + return tx.Rollback(ctx) +}