diff --git a/pkg/sqlutil/hook.go b/pkg/sqlutil/hook.go index fce0bd92e..1387487e8 100644 --- a/pkg/sqlutil/hook.go +++ b/pkg/sqlutil/hook.go @@ -113,6 +113,14 @@ func (w *wrappedDataSource) ExecContext(ctx context.Context, query string, args return } +func (w *wrappedDataSource) NamedExecContext(ctx context.Context, query string, arg interface{}) (res sql.Result, err error) { + err = w.hook(ctx, w.lggr, func(ctx context.Context) (err error) { + res, err = w.db.NamedExecContext(ctx, query, arg) + return + }, query, arg) + return +} + func (w *wrappedDataSource) PrepareContext(ctx context.Context, query string) (stmt *sql.Stmt, err error) { err = w.hook(ctx, w.lggr, func(ctx context.Context) (err error) { stmt, err = w.db.PrepareContext(ctx, query) //nolint:sqlclosecheck @@ -121,6 +129,14 @@ func (w *wrappedDataSource) PrepareContext(ctx context.Context, query string) (s return } +func (w *wrappedDataSource) PrepareNamedContext(ctx context.Context, query string) (stmt *sqlx.NamedStmt, err error) { + err = w.hook(ctx, w.lggr, func(ctx context.Context) (err error) { + stmt, err = w.db.PrepareNamedContext(ctx, query) //nolint:sqlclosecheck + return + }, query, nil) + return +} + func (w *wrappedDataSource) GetContext(ctx context.Context, dest interface{}, query string, args ...any) error { return w.hook(ctx, w.lggr, func(ctx context.Context) error { return w.db.GetContext(ctx, dest, query, args...) diff --git a/pkg/sqlutil/hook_test.go b/pkg/sqlutil/hook_test.go index 5bdf9541b..378067282 100644 --- a/pkg/sqlutil/hook_test.go +++ b/pkg/sqlutil/hook_test.go @@ -28,14 +28,20 @@ func TestWrapDataSource(t *testing.T) { var ds DataSource = &dataSource{} var sentinelErr = errors.New("intercepted error") const fakeError = "fake warning" - ds = WrapDataSource(ds, lggr, TimeoutHook(selDur/2), noopHook, MonitorHook(func() bool { return true }), noopHook, func(ctx context.Context, lggr logger.Logger, do func(context.Context) error, query string, args ...any) error { - err := do(ctx) - if err != nil { - return err - } - lggr.Error(fakeError) - return sentinelErr - }) + ds = WrapDataSource(ds, lggr, + TimeoutHook(func() time.Duration { return selDur / 2 }), + noopHook, + MonitorHook(func() bool { return true }), + noopHook, + func(ctx context.Context, lggr logger.Logger, do func(context.Context) error, query string, args ...any) error { + err := do(ctx) + if err != nil { + return err + } + lggr.Error(fakeError) + return sentinelErr + }, + ) ctx := tests.Context(t) // Error intercepted @@ -125,10 +131,18 @@ func (q *dataSource) ExecContext(ctx context.Context, query string, args ...inte return nil, nil } +func (q *dataSource) NamedExecContext(ctx context.Context, query string, arg interface{}) (sql.Result, error) { + return nil, nil +} + func (q *dataSource) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) { return nil, nil } +func (q *dataSource) PrepareNamedContext(ctx context.Context, query string) (*sqlx.NamedStmt, error) { + return nil, nil +} + func (q *dataSource) GetContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error { select { case <-ctx.Done(): diff --git a/pkg/sqlutil/sqlutil.go b/pkg/sqlutil/sqlutil.go index 2120324d7..0b1ec7a79 100644 --- a/pkg/sqlutil/sqlutil.go +++ b/pkg/sqlutil/sqlutil.go @@ -19,6 +19,8 @@ type DataSource interface { sqlx.PreparerContext GetContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error SelectContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error + PrepareNamedContext(ctx context.Context, query string) (*sqlx.NamedStmt, error) + NamedExecContext(ctx context.Context, query string, arg interface{}) (sql.Result, error) } type TxOptions struct { @@ -26,13 +28,21 @@ type TxOptions struct { OnPanic func(recovered any, rollbackErr error) } +// TransactDataSource is a helper for executing transactions. +// This useful for executing raw SQL queries, or when using more than one type which is not supported by Transact. +func TransactDataSource(ctx context.Context, ds DataSource, opts *TxOptions, fn func(tx DataSource) error) error { + return Transact(ctx, func(tx DataSource) DataSource { return tx }, ds, opts, fn) +} + // Transact is a helper for executing transactions with a domain specific type. // A typical use looks like: // -// func (d *MyD) Transaction(ctx context.Context, fn func(*MyD) error) (err error) { +// func (d *MyD) Transact(ctx context.Context, fn func(tx *MyD) error) (err error) { // return sqlutil.Transact(ctx, d.new, d.db, nil, fn) // } -func Transact[D any](ctx context.Context, newD func(DataSource) D, ds DataSource, opts *TxOptions, fn func(D) error) (err error) { +// +// If you need to combine multiple types in one transaction, you can declare a new type, or use TransactDataSource. +func Transact[D any](ctx context.Context, newD func(DataSource) D, ds DataSource, opts *TxOptions, fn func(tx D) error) (err error) { txds, ok := ds.(transactional) if !ok { // Unsupported or already inside another transaction. @@ -81,15 +91,21 @@ func Transact[D any](ctx context.Context, newD func(DataSource) D, ds DataSource return } +var _ transactional = (*sqlx.DB)(nil) + type transactional interface { // BeginTxx is implemented by *sqlx.DB but not *sqlx.Tx. BeginTxx(context.Context, *sql.TxOptions) (*sqlx.Tx, error) } +var _ wrappedTransactional = (*wrappedTransactionalDataSource)(nil) + type wrappedTransactional interface { BeginWrappedTxx(context.Context, *sql.TxOptions) (transaction, error) } +var _ transaction = (*wrappedTx)(nil) + type transaction interface { DataSource Commit() error diff --git a/pkg/sqlutil/timeout.go b/pkg/sqlutil/timeout.go index 4b38aece9..800b3147b 100644 --- a/pkg/sqlutil/timeout.go +++ b/pkg/sqlutil/timeout.go @@ -9,11 +9,11 @@ import ( // TimeoutHook returns a [QueryHook] which adds the defaultTimeout to each context.Context, // unless [WithoutDefaultTimeout] has been applied to bypass intentionally. -func TimeoutHook(defaultTimeout time.Duration) QueryHook { +func TimeoutHook(defaultTimeout func() time.Duration) QueryHook { return func(ctx context.Context, lggr logger.Logger, do func(context.Context) error, query string, args ...any) error { if wo := ctx.Value(ctxKeyWithoutDefaultTimeout{}); wo == nil { var cancel func() - ctx, cancel = context.WithTimeout(ctx, defaultTimeout) + ctx, cancel = context.WithTimeout(ctx, defaultTimeout()) defer cancel() }