Skip to content

Commit

Permalink
pkg/sqlutil: switch TimeoutHook argument to func for dynamic reconfig…
Browse files Browse the repository at this point in the history
…uration; add TransactDataSource
  • Loading branch information
jmank88 committed Mar 23, 2024
1 parent c150979 commit 627eed0
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 12 deletions.
16 changes: 16 additions & 0 deletions pkg/sqlutil/hook.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,14 @@ func (w *wrappedDataSource) ExecContext(ctx context.Context, query string, args
return
}

func (w *wrappedDataSource) NamedExecContext(ctx context.Context, query string, arg interface{}) (res sql.Result, err error) {
err = w.hook(ctx, w.lggr, func(ctx context.Context) (err error) {
res, err = w.db.NamedExecContext(ctx, query, arg)
return
}, query, arg)
return
}

func (w *wrappedDataSource) PrepareContext(ctx context.Context, query string) (stmt *sql.Stmt, err error) {
err = w.hook(ctx, w.lggr, func(ctx context.Context) (err error) {
stmt, err = w.db.PrepareContext(ctx, query) //nolint:sqlclosecheck
Expand All @@ -121,6 +129,14 @@ func (w *wrappedDataSource) PrepareContext(ctx context.Context, query string) (s
return
}

func (w *wrappedDataSource) PrepareNamedContext(ctx context.Context, query string) (stmt *sqlx.NamedStmt, err error) {
err = w.hook(ctx, w.lggr, func(ctx context.Context) (err error) {
stmt, err = w.db.PrepareNamedContext(ctx, query) //nolint:sqlclosecheck
return
}, query, nil)
return
}

func (w *wrappedDataSource) GetContext(ctx context.Context, dest interface{}, query string, args ...any) error {
return w.hook(ctx, w.lggr, func(ctx context.Context) error {
return w.db.GetContext(ctx, dest, query, args...)
Expand Down
30 changes: 22 additions & 8 deletions pkg/sqlutil/hook_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,20 @@ func TestWrapDataSource(t *testing.T) {
var ds DataSource = &dataSource{}
var sentinelErr = errors.New("intercepted error")
const fakeError = "fake warning"
ds = WrapDataSource(ds, lggr, TimeoutHook(selDur/2), noopHook, MonitorHook(func() bool { return true }), noopHook, func(ctx context.Context, lggr logger.Logger, do func(context.Context) error, query string, args ...any) error {
err := do(ctx)
if err != nil {
return err
}
lggr.Error(fakeError)
return sentinelErr
})
ds = WrapDataSource(ds, lggr,
TimeoutHook(func() time.Duration { return selDur / 2 }),
noopHook,
MonitorHook(func() bool { return true }),
noopHook,
func(ctx context.Context, lggr logger.Logger, do func(context.Context) error, query string, args ...any) error {
err := do(ctx)
if err != nil {
return err
}
lggr.Error(fakeError)
return sentinelErr
},
)
ctx := tests.Context(t)

// Error intercepted
Expand Down Expand Up @@ -125,10 +131,18 @@ func (q *dataSource) ExecContext(ctx context.Context, query string, args ...inte
return nil, nil
}

func (q *dataSource) NamedExecContext(ctx context.Context, query string, arg interface{}) (sql.Result, error) {
return nil, nil
}

func (q *dataSource) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) {
return nil, nil
}

func (q *dataSource) PrepareNamedContext(ctx context.Context, query string) (*sqlx.NamedStmt, error) {
return nil, nil
}

func (q *dataSource) GetContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
select {
case <-ctx.Done():
Expand Down
20 changes: 18 additions & 2 deletions pkg/sqlutil/sqlutil.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,20 +19,30 @@ type DataSource interface {
sqlx.PreparerContext
GetContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error
SelectContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error
PrepareNamedContext(ctx context.Context, query string) (*sqlx.NamedStmt, error)
NamedExecContext(ctx context.Context, query string, arg interface{}) (sql.Result, error)
}

type TxOptions struct {
sql.TxOptions
OnPanic func(recovered any, rollbackErr error)
}

// TransactDataSource is a helper for executing transactions.
// This useful for executing raw SQL queries, or when using more than one type which is not supported by Transact.
func TransactDataSource(ctx context.Context, ds DataSource, opts *TxOptions, fn func(tx DataSource) error) error {
return Transact(ctx, func(tx DataSource) DataSource { return tx }, ds, opts, fn)
}

// Transact is a helper for executing transactions with a domain specific type.
// A typical use looks like:
//
// func (d *MyD) Transaction(ctx context.Context, fn func(*MyD) error) (err error) {
// func (d *MyD) Transact(ctx context.Context, fn func(tx *MyD) error) (err error) {
// return sqlutil.Transact(ctx, d.new, d.db, nil, fn)
// }
func Transact[D any](ctx context.Context, newD func(DataSource) D, ds DataSource, opts *TxOptions, fn func(D) error) (err error) {
//
// If you need to combine multiple types in one transaction, you can declare a new type, or use TransactDataSource.
func Transact[D any](ctx context.Context, newD func(DataSource) D, ds DataSource, opts *TxOptions, fn func(tx D) error) (err error) {
txds, ok := ds.(transactional)
if !ok {
// Unsupported or already inside another transaction.
Expand Down Expand Up @@ -81,15 +91,21 @@ func Transact[D any](ctx context.Context, newD func(DataSource) D, ds DataSource
return
}

var _ transactional = (*sqlx.DB)(nil)

type transactional interface {
// BeginTxx is implemented by *sqlx.DB but not *sqlx.Tx.
BeginTxx(context.Context, *sql.TxOptions) (*sqlx.Tx, error)
}

var _ wrappedTransactional = (*wrappedTransactionalDataSource)(nil)

type wrappedTransactional interface {
BeginWrappedTxx(context.Context, *sql.TxOptions) (transaction, error)
}

var _ transaction = (*wrappedTx)(nil)

type transaction interface {
DataSource
Commit() error
Expand Down
4 changes: 2 additions & 2 deletions pkg/sqlutil/timeout.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@ import (

// TimeoutHook returns a [QueryHook] which adds the defaultTimeout to each context.Context,
// unless [WithoutDefaultTimeout] has been applied to bypass intentionally.
func TimeoutHook(defaultTimeout time.Duration) QueryHook {
func TimeoutHook(defaultTimeout func() time.Duration) QueryHook {
return func(ctx context.Context, lggr logger.Logger, do func(context.Context) error, query string, args ...any) error {
if wo := ctx.Value(ctxKeyWithoutDefaultTimeout{}); wo == nil {
var cancel func()
ctx, cancel = context.WithTimeout(ctx, defaultTimeout)
ctx, cancel = context.WithTimeout(ctx, defaultTimeout())
defer cancel()
}

Expand Down

0 comments on commit 627eed0

Please sign in to comment.