Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

pkg/sqlutil: switch TimeoutHook argument to func for dynamic reconfiguration; add TransactDataSource #403

Merged
merged 2 commits into from
Mar 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions pkg/sqlutil/hook.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,14 @@ func (w *wrappedDataSource) ExecContext(ctx context.Context, query string, args
return
}

func (w *wrappedDataSource) NamedExecContext(ctx context.Context, query string, arg interface{}) (res sql.Result, err error) {
err = w.hook(ctx, w.lggr, func(ctx context.Context) (err error) {
res, err = w.db.NamedExecContext(ctx, query, arg)
return
}, query, arg)
return
}

func (w *wrappedDataSource) PrepareContext(ctx context.Context, query string) (stmt *sql.Stmt, err error) {
err = w.hook(ctx, w.lggr, func(ctx context.Context) (err error) {
stmt, err = w.db.PrepareContext(ctx, query) //nolint:sqlclosecheck
Expand All @@ -121,6 +129,14 @@ func (w *wrappedDataSource) PrepareContext(ctx context.Context, query string) (s
return
}

func (w *wrappedDataSource) PrepareNamedContext(ctx context.Context, query string) (stmt *sqlx.NamedStmt, err error) {
err = w.hook(ctx, w.lggr, func(ctx context.Context) (err error) {
stmt, err = w.db.PrepareNamedContext(ctx, query) //nolint:sqlclosecheck
return
}, query, nil)
return
}

func (w *wrappedDataSource) GetContext(ctx context.Context, dest interface{}, query string, args ...any) error {
return w.hook(ctx, w.lggr, func(ctx context.Context) error {
return w.db.GetContext(ctx, dest, query, args...)
Expand Down
30 changes: 22 additions & 8 deletions pkg/sqlutil/hook_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,20 @@ func TestWrapDataSource(t *testing.T) {
var ds DataSource = &dataSource{}
var sentinelErr = errors.New("intercepted error")
const fakeError = "fake warning"
ds = WrapDataSource(ds, lggr, TimeoutHook(selDur/2), noopHook, MonitorHook(func() bool { return true }), noopHook, func(ctx context.Context, lggr logger.Logger, do func(context.Context) error, query string, args ...any) error {
err := do(ctx)
if err != nil {
return err
}
lggr.Error(fakeError)
return sentinelErr
})
ds = WrapDataSource(ds, lggr,
TimeoutHook(func() time.Duration { return selDur / 2 }),
noopHook,
MonitorHook(func() bool { return true }),
noopHook,
func(ctx context.Context, lggr logger.Logger, do func(context.Context) error, query string, args ...any) error {
err := do(ctx)
if err != nil {
return err
}
lggr.Error(fakeError)
return sentinelErr
},
)
ctx := tests.Context(t)

// Error intercepted
Expand Down Expand Up @@ -125,10 +131,18 @@ func (q *dataSource) ExecContext(ctx context.Context, query string, args ...inte
return nil, nil
}

func (q *dataSource) NamedExecContext(ctx context.Context, query string, arg interface{}) (sql.Result, error) {
return nil, nil
}

func (q *dataSource) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) {
return nil, nil
}

func (q *dataSource) PrepareNamedContext(ctx context.Context, query string) (*sqlx.NamedStmt, error) {
return nil, nil
}

func (q *dataSource) GetContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
select {
case <-ctx.Done():
Expand Down
20 changes: 18 additions & 2 deletions pkg/sqlutil/sqlutil.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,20 +19,30 @@ type DataSource interface {
sqlx.PreparerContext
GetContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error
SelectContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error
PrepareNamedContext(ctx context.Context, query string) (*sqlx.NamedStmt, error)
NamedExecContext(ctx context.Context, query string, arg interface{}) (sql.Result, error)
}

type TxOptions struct {
sql.TxOptions
OnPanic func(recovered any, rollbackErr error)
}

// TransactDataSource is a helper for executing transactions.
// This useful for executing raw SQL queries, or when using more than one type which is not supported by Transact.
func TransactDataSource(ctx context.Context, ds DataSource, opts *TxOptions, fn func(tx DataSource) error) error {
return Transact(ctx, func(tx DataSource) DataSource { return tx }, ds, opts, fn)
}

// Transact is a helper for executing transactions with a domain specific type.
// A typical use looks like:
//
// func (d *MyD) Transaction(ctx context.Context, fn func(*MyD) error) (err error) {
// func (d *MyD) Transact(ctx context.Context, fn func(tx *MyD) error) (err error) {
// return sqlutil.Transact(ctx, d.new, d.db, nil, fn)
// }
func Transact[D any](ctx context.Context, newD func(DataSource) D, ds DataSource, opts *TxOptions, fn func(D) error) (err error) {
//
// If you need to combine multiple types in one transaction, you can declare a new type, or use TransactDataSource.
func Transact[D any](ctx context.Context, newD func(DataSource) D, ds DataSource, opts *TxOptions, fn func(tx D) error) (err error) {
txds, ok := ds.(transactional)
if !ok {
// Unsupported or already inside another transaction.
Expand Down Expand Up @@ -81,15 +91,21 @@ func Transact[D any](ctx context.Context, newD func(DataSource) D, ds DataSource
return
}

var _ transactional = (*sqlx.DB)(nil)

type transactional interface {
// BeginTxx is implemented by *sqlx.DB but not *sqlx.Tx.
BeginTxx(context.Context, *sql.TxOptions) (*sqlx.Tx, error)
}

var _ wrappedTransactional = (*wrappedTransactionalDataSource)(nil)

type wrappedTransactional interface {
BeginWrappedTxx(context.Context, *sql.TxOptions) (transaction, error)
}

var _ transaction = (*wrappedTx)(nil)

type transaction interface {
DataSource
Commit() error
Expand Down
4 changes: 2 additions & 2 deletions pkg/sqlutil/timeout.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@ import (

// TimeoutHook returns a [QueryHook] which adds the defaultTimeout to each context.Context,
// unless [WithoutDefaultTimeout] has been applied to bypass intentionally.
func TimeoutHook(defaultTimeout time.Duration) QueryHook {
func TimeoutHook(defaultTimeout func() time.Duration) QueryHook {
return func(ctx context.Context, lggr logger.Logger, do func(context.Context) error, query string, args ...any) error {
if wo := ctx.Value(ctxKeyWithoutDefaultTimeout{}); wo == nil {
var cancel func()
ctx, cancel = context.WithTimeout(ctx, defaultTimeout)
ctx, cancel = context.WithTimeout(ctx, defaultTimeout())
defer cancel()
}

Expand Down
Loading