From 051ee9fea53cba829fb91f57efdf1790f6bb62ee Mon Sep 17 00:00:00 2001 From: Jordan Krage Date: Thu, 14 Mar 2024 14:45:30 -0500 Subject: [PATCH] pkg/sqlutil: switch TimeoutHook argument to func --- pkg/sqlutil/hook.go | 8 ++++++++ pkg/sqlutil/hook_test.go | 26 ++++++++++++++++++-------- pkg/sqlutil/sqlutil.go | 1 + pkg/sqlutil/timeout.go | 4 ++-- 4 files changed, 29 insertions(+), 10 deletions(-) diff --git a/pkg/sqlutil/hook.go b/pkg/sqlutil/hook.go index fce0bd92e2..273a438c3b 100644 --- a/pkg/sqlutil/hook.go +++ b/pkg/sqlutil/hook.go @@ -133,6 +133,14 @@ func (w *wrappedDataSource) SelectContext(ctx context.Context, dest interface{}, }, query, args...) } +func (w *wrappedDataSource) PrepareNamedContext(ctx context.Context, query string) (stmt *sqlx.NamedStmt, err error) { + err = w.hook(ctx, w.lggr, func(ctx context.Context) (err error) { + stmt, err = w.db.PrepareNamedContext(ctx, query) + return + }, query, nil) + return +} + // wrappedTransactionalDataSource extends [wrappedDataSource] with BeginTxx and BeginWrappedTxx for initiating transactions. type wrappedTransactionalDataSource struct { wrappedDataSource diff --git a/pkg/sqlutil/hook_test.go b/pkg/sqlutil/hook_test.go index 5bdf9541bf..7da6829879 100644 --- a/pkg/sqlutil/hook_test.go +++ b/pkg/sqlutil/hook_test.go @@ -28,14 +28,20 @@ func TestWrapDataSource(t *testing.T) { var ds DataSource = &dataSource{} var sentinelErr = errors.New("intercepted error") const fakeError = "fake warning" - ds = WrapDataSource(ds, lggr, TimeoutHook(selDur/2), noopHook, MonitorHook(func() bool { return true }), noopHook, func(ctx context.Context, lggr logger.Logger, do func(context.Context) error, query string, args ...any) error { - err := do(ctx) - if err != nil { - return err - } - lggr.Error(fakeError) - return sentinelErr - }) + ds = WrapDataSource(ds, lggr, + TimeoutHook(func() time.Duration { return selDur / 2 }), + noopHook, + MonitorHook(func() bool { return true }), + noopHook, + func(ctx context.Context, lggr logger.Logger, do func(context.Context) error, query string, args ...any) error { + err := do(ctx) + if err != nil { + return err + } + lggr.Error(fakeError) + return sentinelErr + }, + ) ctx := tests.Context(t) // Error intercepted @@ -129,6 +135,10 @@ func (q *dataSource) PrepareContext(ctx context.Context, query string) (*sql.Stm return nil, nil } +func (q *dataSource) PrepareNamedContext(ctx context.Context, query string) (*sqlx.NamedStmt, error) { + return nil, nil +} + func (q *dataSource) GetContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error { select { case <-ctx.Done(): diff --git a/pkg/sqlutil/sqlutil.go b/pkg/sqlutil/sqlutil.go index 2120324d77..dd92151eec 100644 --- a/pkg/sqlutil/sqlutil.go +++ b/pkg/sqlutil/sqlutil.go @@ -19,6 +19,7 @@ type DataSource interface { sqlx.PreparerContext GetContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error SelectContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error + PrepareNamedContext(ctx context.Context, query string) (*sqlx.NamedStmt, error) } type TxOptions struct { diff --git a/pkg/sqlutil/timeout.go b/pkg/sqlutil/timeout.go index 4b38aece9f..800b3147bc 100644 --- a/pkg/sqlutil/timeout.go +++ b/pkg/sqlutil/timeout.go @@ -9,11 +9,11 @@ import ( // TimeoutHook returns a [QueryHook] which adds the defaultTimeout to each context.Context, // unless [WithoutDefaultTimeout] has been applied to bypass intentionally. -func TimeoutHook(defaultTimeout time.Duration) QueryHook { +func TimeoutHook(defaultTimeout func() time.Duration) QueryHook { return func(ctx context.Context, lggr logger.Logger, do func(context.Context) error, query string, args ...any) error { if wo := ctx.Value(ctxKeyWithoutDefaultTimeout{}); wo == nil { var cancel func() - ctx, cancel = context.WithTimeout(ctx, defaultTimeout) + ctx, cancel = context.WithTimeout(ctx, defaultTimeout()) defer cancel() }