Skip to content

Commit

Permalink
pkg/sqlutil: switch TimeoutHook argument to func
Browse files Browse the repository at this point in the history
  • Loading branch information
jmank88 committed Mar 15, 2024
1 parent 03b543a commit 91e05ff
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 10 deletions.
8 changes: 8 additions & 0 deletions pkg/sqlutil/hook.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,14 @@ func (w *wrappedDataSource) SelectContext(ctx context.Context, dest interface{},
}, query, args...)
}

func (w *wrappedDataSource) PrepareNamedContext(ctx context.Context, query string) (stmt *sqlx.NamedStmt, err error) {
err = w.hook(ctx, w.lggr, func(ctx context.Context) (err error) {
stmt, err = w.db.PrepareNamedContext(ctx, query) //nolint:sqlclosecheck
return
}, query, nil)
return
}

// wrappedTransactionalDataSource extends [wrappedDataSource] with BeginTxx and BeginWrappedTxx for initiating transactions.
type wrappedTransactionalDataSource struct {
wrappedDataSource
Expand Down
26 changes: 18 additions & 8 deletions pkg/sqlutil/hook_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,20 @@ func TestWrapDataSource(t *testing.T) {
var ds DataSource = &dataSource{}
var sentinelErr = errors.New("intercepted error")
const fakeError = "fake warning"
ds = WrapDataSource(ds, lggr, TimeoutHook(selDur/2), noopHook, MonitorHook(func() bool { return true }), noopHook, func(ctx context.Context, lggr logger.Logger, do func(context.Context) error, query string, args ...any) error {
err := do(ctx)
if err != nil {
return err
}
lggr.Error(fakeError)
return sentinelErr
})
ds = WrapDataSource(ds, lggr,
TimeoutHook(func() time.Duration { return selDur / 2 }),
noopHook,
MonitorHook(func() bool { return true }),
noopHook,
func(ctx context.Context, lggr logger.Logger, do func(context.Context) error, query string, args ...any) error {
err := do(ctx)
if err != nil {
return err
}
lggr.Error(fakeError)
return sentinelErr
},
)
ctx := tests.Context(t)

// Error intercepted
Expand Down Expand Up @@ -129,6 +135,10 @@ func (q *dataSource) PrepareContext(ctx context.Context, query string) (*sql.Stm
return nil, nil
}

func (q *dataSource) PrepareNamedContext(ctx context.Context, query string) (*sqlx.NamedStmt, error) {
return nil, nil
}

func (q *dataSource) GetContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
select {
case <-ctx.Done():
Expand Down
1 change: 1 addition & 0 deletions pkg/sqlutil/sqlutil.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ type DataSource interface {
sqlx.PreparerContext
GetContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error
SelectContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error
PrepareNamedContext(ctx context.Context, query string) (*sqlx.NamedStmt, error)
}

type TxOptions struct {
Expand Down
4 changes: 2 additions & 2 deletions pkg/sqlutil/timeout.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@ import (

// TimeoutHook returns a [QueryHook] which adds the defaultTimeout to each context.Context,
// unless [WithoutDefaultTimeout] has been applied to bypass intentionally.
func TimeoutHook(defaultTimeout time.Duration) QueryHook {
func TimeoutHook(defaultTimeout func() time.Duration) QueryHook {
return func(ctx context.Context, lggr logger.Logger, do func(context.Context) error, query string, args ...any) error {
if wo := ctx.Value(ctxKeyWithoutDefaultTimeout{}); wo == nil {
var cancel func()
ctx, cancel = context.WithTimeout(ctx, defaultTimeout)
ctx, cancel = context.WithTimeout(ctx, defaultTimeout())
defer cancel()
}

Expand Down

0 comments on commit 91e05ff

Please sign in to comment.