Add transaction base utilities: Begin, Commit, and Rollback helpers that work with both *sql.DB and *sql.Tx, as well as with types implementing the shared BaseBeginner and BaseTx interfaces.
70 lines
1.6 KiB
Go
70 lines
1.6 KiB
Go
package database
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"fmt"
|
|
|
|
"go.ntppool.org/common/logger"
|
|
)
|
|
|
|
// Shared interface definitions that both packages use identically.
type (
	// BaseBeginner is implemented by connection types that can open a
	// transaction through a context-aware Begin method.
	//
	// NOTE(review): Begin returns sql.Tx by value, while database/sql hands out
	// and operates on *sql.Tx. Copying the struct duplicates its internal
	// locks and state, and callers that need a *sql.Tx must take the address
	// of that copy. Switching the result to *sql.Tx would break any existing
	// implementations, so it is left unchanged here — confirm implementers
	// account for this.
	BaseBeginner interface {
		Begin(ctx context.Context) (sql.Tx, error)
	}

	// BaseTx is a transaction handle whose Commit and Rollback take a context.
	// It embeds BaseBeginner, so implementations must also provide Begin.
	//
	// *sql.Tx does not satisfy this interface: its Commit and Rollback take no
	// arguments and it has no Begin method.
	BaseTx interface {
		BaseBeginner
		Commit(ctx context.Context) error
		Rollback(ctx context.Context) error
	}
)
|
|
|
|
// BeginTransactionForQuerier contains the shared Begin() logic from both packages
|
|
func BeginTransactionForQuerier(ctx context.Context, db DBTX) (DBTX, error) {
|
|
if sqlDB, ok := db.(*sql.DB); ok {
|
|
tx, err := sqlDB.BeginTx(ctx, &sql.TxOptions{})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return tx, nil
|
|
} else {
|
|
// Handle transaction case
|
|
if beginner, ok := db.(BaseBeginner); ok {
|
|
tx, err := beginner.Begin(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &tx, nil
|
|
}
|
|
return nil, fmt.Errorf("database connection does not support transactions")
|
|
}
|
|
}
|
|
|
|
// CommitTransactionForQuerier contains the shared Commit() logic from both packages
|
|
func CommitTransactionForQuerier(ctx context.Context, db DBTX) error {
|
|
if sqlTx, ok := db.(*sql.Tx); ok {
|
|
return sqlTx.Commit()
|
|
}
|
|
|
|
tx, ok := db.(BaseTx)
|
|
if !ok {
|
|
log := logger.FromContext(ctx)
|
|
log.ErrorContext(ctx, "could not get a Tx", "type", fmt.Sprintf("%T", db))
|
|
return sql.ErrTxDone
|
|
}
|
|
return tx.Commit(ctx)
|
|
}
|
|
|
|
// RollbackTransactionForQuerier contains the shared Rollback() logic from both packages
|
|
func RollbackTransactionForQuerier(ctx context.Context, db DBTX) error {
|
|
if sqlTx, ok := db.(*sql.Tx); ok {
|
|
return sqlTx.Rollback()
|
|
}
|
|
|
|
tx, ok := db.(BaseTx)
|
|
if !ok {
|
|
return sql.ErrTxDone
|
|
}
|
|
return tx.Rollback(ctx)
|
|
}
|