package database

import (
	"context"
	"database/sql"
	"errors"
	"fmt"

	"go.ntppool.org/common/logger"
)

// BaseBeginner is implemented by connections that start a transaction through
// a context-aware Begin method. It is shared by both packages that use the
// transaction helpers in this file.
//
// NOTE(review): Begin returns sql.Tx by value. sql.Tx contains sync
// primitives, so an implementation must copy a live transaction to satisfy
// this signature (go vet's copylocks check flags such copies). Returning
// *sql.Tx would be safer; the signature is kept as-is so existing
// implementations keep compiling.
type BaseBeginner interface {
	Begin(context.Context) (sql.Tx, error)
}

// BaseTx is a transaction handle that exposes Begin (via BaseBeginner) plus
// context-aware Commit and Rollback methods.
type BaseTx interface {
	BaseBeginner
	Commit(ctx context.Context) error
	Rollback(ctx context.Context) error
}

// BeginTransactionForQuerier starts a transaction on db and returns it as a
// DBTX.
//
// Supported inputs, checked in this order:
//   - a BaseBeginner: its Begin method is called and a pointer to the
//     returned sql.Tx is returned;
//   - any value with BeginTx(ctx, *sql.TxOptions) (*sql.Tx, error), which
//     includes *sql.DB and *sql.Conn: a transaction with default options is
//     started.
//
// Any other DBTX (for example an existing *sql.Tx) yields an error. Errors
// from Begin/BeginTx are returned unchanged.
func BeginTransactionForQuerier(ctx context.Context, db DBTX) (DBTX, error) {
	// txStarter matches the standard library's BeginTx method, shared by
	// *sql.DB and *sql.Conn. It is declared locally so it cannot collide with
	// names defined elsewhere in the package.
	type txStarter interface {
		BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error)
	}

	switch conn := db.(type) {
	case BaseBeginner:
		tx, err := conn.Begin(ctx)
		if err != nil {
			return nil, err
		}
		return &tx, nil
	case txStarter:
		tx, err := conn.BeginTx(ctx, &sql.TxOptions{})
		if err != nil {
			return nil, err
		}
		return tx, nil
	default:
		return nil, errors.New("database connection does not support transactions")
	}
}

// CommitTransactionForQuerier commits the transaction held in db.
//
// A *sql.Tx is committed with its context-free Commit; any other BaseTx is
// committed with ctx. If db is neither, the concrete type is logged at error
// level and sql.ErrTxDone is returned.
func CommitTransactionForQuerier(ctx context.Context, db DBTX) error {
	switch tx := db.(type) {
	case *sql.Tx:
		return tx.Commit()
	case BaseTx:
		return tx.Commit(ctx)
	default:
		log := logger.FromContext(ctx)
		log.ErrorContext(ctx, "could not get a Tx", "type", fmt.Sprintf("%T", db))
		return sql.ErrTxDone
	}
}

// RollbackTransactionForQuerier rolls back the transaction held in db.
//
// A *sql.Tx is rolled back with its context-free Rollback; any other BaseTx is
// rolled back with ctx. If db is neither, sql.ErrTxDone is returned without
// logging.
func RollbackTransactionForQuerier(ctx context.Context, db DBTX) error {
	switch tx := db.(type) {
	case *sql.Tx:
		return tx.Rollback()
	case BaseTx:
		return tx.Rollback(ctx)
	default:
		return sql.ErrTxDone
	}
}