65 lines
1.2 KiB
Go
65 lines
1.2 KiB
Go
package ntpdb
|
|
|
|
import (
	"context"
	"database/sql"
	"fmt"
)
|
|
|
|
type QuerierTx interface {
|
|
Querier
|
|
|
|
Begin(ctx context.Context) (QuerierTx, error)
|
|
Commit(ctx context.Context) error
|
|
Rollback(ctx context.Context) error
|
|
}
|
|
|
|
// Beginner is implemented by database handles that can start a transaction.
// Queries.Begin type-asserts its underlying handle to this interface.
type Beginner interface {
	// Begin starts a transaction and returns it, or an error.
	Begin(ctx context.Context) (sql.Tx, error)
}
|
|
|
|
// Tx is the set of methods Queries.Commit and Queries.Rollback require from
// the underlying database handle in order to treat it as a transaction.
//
// NOTE(review): the standard library's *sql.Tx declares Commit and Rollback
// without a context parameter, so it does not satisfy this interface as
// written; confirm which concrete type is expected here.
type Tx interface {
	// Begin starts a transaction and returns it, or an error.
	Begin(ctx context.Context) (sql.Tx, error)
	// Commit commits the transaction.
	Commit(ctx context.Context) error
	// Rollback aborts the transaction.
	Rollback(ctx context.Context) error
}
|
|
|
|
func (q *Queries) Begin(ctx context.Context) (QuerierTx, error) {
|
|
tx, err := q.db.(Beginner).Begin(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &Queries{db: &tx}, nil
|
|
}
|
|
|
|
func (q *Queries) Commit(ctx context.Context) error {
|
|
tx, ok := q.db.(Tx)
|
|
if !ok {
|
|
return sql.ErrTxDone
|
|
}
|
|
return tx.Commit(ctx)
|
|
}
|
|
|
|
func (q *Queries) Rollback(ctx context.Context) error {
|
|
tx, ok := q.db.(Tx)
|
|
if !ok {
|
|
return sql.ErrTxDone
|
|
}
|
|
return tx.Rollback(ctx)
|
|
}
|
|
|
|
type WrappedQuerier struct {
|
|
QuerierTxWithTracing
|
|
}
|
|
|
|
func NewWrappedQuerier(q QuerierTx) QuerierTx {
|
|
return &WrappedQuerier{NewQuerierTxWithTracing(q, "")}
|
|
}
|
|
|
|
func (wq *WrappedQuerier) Begin(ctx context.Context) (QuerierTx, error) {
|
|
q, err := wq.QuerierTxWithTracing.Begin(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return NewWrappedQuerier(q), nil
|
|
}
|