From 049609b8e1f9bead712a9fb3a958eb70fbac5ed1 Mon Sep 17 00:00:00 2001 From: Jordan Krage Date: Thu, 14 Mar 2024 12:21:56 -0500 Subject: [PATCH] pkg/sqlutil: fix wrapped DB transactions (#400) --- pkg/sqlutil/hook.go | 112 ++++++++++++++++++++++++++++----------- pkg/sqlutil/hook_test.go | 62 +++++++++++++++------- pkg/sqlutil/sqlutil.go | 51 ++++++++++++++---- 3 files changed, 164 insertions(+), 61 deletions(-) diff --git a/pkg/sqlutil/hook.go b/pkg/sqlutil/hook.go index 7f32296de..fce0bd92e 100644 --- a/pkg/sqlutil/hook.go +++ b/pkg/sqlutil/hook.go @@ -9,11 +9,11 @@ import ( "github.com/smartcontractkit/chainlink-common/pkg/logger" ) -var _ DB = &WrappedDB{} +var _ DataSource = &wrappedDataSource{} -// WrappedDB is a [DB] which invokes a [QueryHook] on each call. -type WrappedDB struct { - db DB +// wrappedDataSource is a [DataSource] which invokes a [QueryHook] on each call. +type wrappedDataSource struct { + db DataSource lggr logger.Logger hook QueryHook } @@ -25,35 +25,43 @@ type WrappedDB struct { // See [MonitorHook] and [TimeoutHook] for examples. type QueryHook func(ctx context.Context, lggr logger.Logger, do func(context.Context) error, query string, args ...any) error -// NewWrappedDB returns a new [WrappedDB] that calls each [QueryHook] in the provided order. -func NewWrappedDB(db DB, l logger.Logger, hs ...QueryHook) *WrappedDB { - iq := WrappedDB{db: db, +// WrapDataSource returns a new [DataSource] that calls each [QueryHook] in the provided order. +// If db implements [sqlx.BeginTxx], then the returned DataSource will also. +func WrapDataSource(db DataSource, l logger.Logger, hs ...QueryHook) DataSource { + iq := wrappedDataSource{db: db, lggr: logger.Helper(logger.Named(l, "WrappedDB"), 2), // skip our own wrapper and one interceptor hook: noopHook, } switch len(hs) { case 0: - return &iq case 1: iq.hook = hs[0] - return &iq + default: + // Nest the QueryHook calls so that they are wrapped from first to last. + // Example: + // [A, B, C] => A(B(C(do()))) + for i := len(hs) - 1; i >= 0; i-- { + next := hs[i] + prev := iq.hook + iq.hook = func(ctx context.Context, lggr logger.Logger, do func(context.Context) error, query string, args ...any) error { + // opt: cache the construction of these loggers + lggr = logger.Helper(lggr, 1) // skip one more for this wrapper + return next(ctx, lggr, func(ctx context.Context) error { + lggr = logger.Helper(lggr, 2) // skip two more for do() and this extra wrapper + return prev(ctx, lggr, do, query, args...) + }, query, args...) + } + } } - // Nest the QueryHook calls so that they are wrapped from first to last. - // Example: - // [A, B, C] => A(B(C(do()))) - for i := len(hs) - 1; i >= 0; i-- { - next := hs[i] - prev := iq.hook - iq.hook = func(ctx context.Context, lggr logger.Logger, do func(context.Context) error, query string, args ...any) error { - // opt: cache the construction of these loggers - lggr = logger.Helper(lggr, 1) // skip one more for this wrapper - return next(ctx, lggr, func(ctx context.Context) error { - lggr = logger.Helper(lggr, 2) // skip two more for do() and this extra wrapper - return prev(ctx, lggr, do, query, args...) - }, query, args...) + if txdb, ok := db.(transactional); ok { + // extra wrapper to make BeginTxx available + return &wrappedTransactionalDataSource{ + wrappedDataSource: iq, + txdb: txdb, } } + return &iq } @@ -61,19 +69,19 @@ func noopHook(ctx context.Context, lggr logger.Logger, do func(context.Context) return do(ctx) } -func (w *WrappedDB) DriverName() string { +func (w *wrappedDataSource) DriverName() string { return w.db.DriverName() } -func (w *WrappedDB) Rebind(s string) string { +func (w *wrappedDataSource) Rebind(s string) string { return w.db.Rebind(s) } -func (w *WrappedDB) BindNamed(s string, i interface{}) (string, []any, error) { +func (w *wrappedDataSource) BindNamed(s string, i interface{}) (string, []any, error) { return w.db.BindNamed(s, i) } -func (w *WrappedDB) QueryContext(ctx context.Context, query string, args ...any) (rows *sql.Rows, err error) { +func (w *wrappedDataSource) QueryContext(ctx context.Context, query string, args ...any) (rows *sql.Rows, err error) { err = w.hook(ctx, w.lggr, func(ctx context.Context) (err error) { rows, err = w.db.QueryContext(ctx, query, args...) //nolint return @@ -81,7 +89,7 @@ func (w *WrappedDB) QueryContext(ctx context.Context, query string, args ...any) return } -func (w *WrappedDB) QueryxContext(ctx context.Context, query string, args ...any) (rows *sqlx.Rows, err error) { +func (w *wrappedDataSource) QueryxContext(ctx context.Context, query string, args ...any) (rows *sqlx.Rows, err error) { err = w.hook(ctx, w.lggr, func(ctx context.Context) (err error) { rows, err = w.db.QueryxContext(ctx, query, args...) //nolint:sqlclosecheck return @@ -89,7 +97,7 @@ func (w *WrappedDB) QueryxContext(ctx context.Context, query string, args ...any return } -func (w *WrappedDB) QueryRowxContext(ctx context.Context, query string, args ...any) (row *sqlx.Row) { +func (w *wrappedDataSource) QueryRowxContext(ctx context.Context, query string, args ...any) (row *sqlx.Row) { _ = w.hook(ctx, w.lggr, func(ctx context.Context) error { row = w.db.QueryRowxContext(ctx, query, args...) return nil @@ -97,7 +105,7 @@ func (w *WrappedDB) QueryRowxContext(ctx context.Context, query string, args ... return } -func (w *WrappedDB) ExecContext(ctx context.Context, query string, args ...any) (res sql.Result, err error) { +func (w *wrappedDataSource) ExecContext(ctx context.Context, query string, args ...any) (res sql.Result, err error) { err = w.hook(ctx, w.lggr, func(ctx context.Context) (err error) { res, err = w.db.ExecContext(ctx, query, args...) return @@ -105,7 +113,7 @@ func (w *WrappedDB) ExecContext(ctx context.Context, query string, args ...any) return } -func (w *WrappedDB) PrepareContext(ctx context.Context, query string) (stmt *sql.Stmt, err error) { +func (w *wrappedDataSource) PrepareContext(ctx context.Context, query string) (stmt *sql.Stmt, err error) { err = w.hook(ctx, w.lggr, func(ctx context.Context) (err error) { stmt, err = w.db.PrepareContext(ctx, query) //nolint:sqlclosecheck return @@ -113,14 +121,54 @@ func (w *WrappedDB) PrepareContext(ctx context.Context, query string) (stmt *sql return } -func (w *WrappedDB) GetContext(ctx context.Context, dest interface{}, query string, args ...any) error { +func (w *wrappedDataSource) GetContext(ctx context.Context, dest interface{}, query string, args ...any) error { return w.hook(ctx, w.lggr, func(ctx context.Context) error { return w.db.GetContext(ctx, dest, query, args...) }, query, args...) } -func (w *WrappedDB) SelectContext(ctx context.Context, dest interface{}, query string, args ...any) error { +func (w *wrappedDataSource) SelectContext(ctx context.Context, dest interface{}, query string, args ...any) error { return w.hook(ctx, w.lggr, func(ctx context.Context) error { return w.db.SelectContext(ctx, dest, query, args...) }, query, args...) } + +// wrappedTransactionalDataSource extends [wrappedDataSource] with BeginTxx and BeginWrappedTxx for initiating transactions. +type wrappedTransactionalDataSource struct { + wrappedDataSource + txdb transactional +} + +func (w *wrappedTransactionalDataSource) BeginTxx(ctx context.Context, opts *sql.TxOptions) (tx *sqlx.Tx, err error) { + err = w.hook(ctx, w.lggr, func(ctx context.Context) (err error) { + tx, err = w.txdb.BeginTxx(ctx, opts) + return + }, "START TRANSACTION", nil) + return +} + +// BeginWrappedTxx is like BeginTxx, but wraps the returned tx with the same hook. +func (w *wrappedTransactionalDataSource) BeginWrappedTxx(ctx context.Context, opts *sql.TxOptions) (tx transaction, err error) { + tx, err = w.BeginTxx(ctx, opts) + if err != nil { + return nil, err + } + return &wrappedTx{ + wrappedDataSource: wrappedDataSource{ + db: tx, + lggr: w.lggr, + hook: w.hook, + }, + tx: tx, + }, nil +} + +// wrappedTx extends [wrappedDataSource] with Commit and Rollback for completing a transaction. +type wrappedTx struct { + wrappedDataSource + tx transaction +} + +func (w *wrappedTx) Commit() error { return w.tx.Commit() } + +func (w *wrappedTx) Rollback() error { return w.tx.Rollback() } diff --git a/pkg/sqlutil/hook_test.go b/pkg/sqlutil/hook_test.go index ee632e79d..5bdf9541b 100644 --- a/pkg/sqlutil/hook_test.go +++ b/pkg/sqlutil/hook_test.go @@ -23,12 +23,12 @@ const ( selDur = 200 * time.Millisecond ) -func TestNewInterceptedQueryer(t *testing.T) { +func TestWrapDataSource(t *testing.T) { lggr, ol := logger.TestObserved(t, zapcore.InfoLevel) - var db DB = &database{} + var ds DataSource = &dataSource{} var sentinelErr = errors.New("intercepted error") const fakeError = "fake warning" - db = NewWrappedDB(db, lggr, TimeoutHook(selDur/2), noopHook, MonitorHook(func() bool { return true }), noopHook, func(ctx context.Context, lggr logger.Logger, do func(context.Context) error, query string, args ...any) error { + ds = WrapDataSource(ds, lggr, TimeoutHook(selDur/2), noopHook, MonitorHook(func() bool { return true }), noopHook, func(ctx context.Context, lggr logger.Logger, do func(context.Context) error, query string, args ...any) error { err := do(ctx) if err != nil { return err @@ -39,7 +39,7 @@ func TestNewInterceptedQueryer(t *testing.T) { ctx := tests.Context(t) // Error intercepted - err := db.GetContext(ctx, "test", "foo", 42, "bar") + err := ds.GetContext(ctx, "test", "foo", 42, "bar") _, file, line, ok := runtime.Caller(0) require.True(t, ok) expCaller := fmt.Sprintf("%s:%d", file, line-1) @@ -55,7 +55,7 @@ func TestNewInterceptedQueryer(t *testing.T) { _ = ol.TakeAll() // Timeout applied - err = db.SelectContext(ctx, "test", "foo", 42, "bar") + err = ds.SelectContext(ctx, "test", "foo", 42, "bar") require.ErrorIs(t, err, context.DeadlineExceeded) logs = ol.FilterMessage(slowMsg).All() require.Len(t, logs, 1) @@ -63,49 +63,73 @@ func TestNewInterceptedQueryer(t *testing.T) { _ = ol.TakeAll() // Without default timeout - err = db.SelectContext(WithoutDefaultTimeout(ctx), "test", "foo", 42, "bar") + err = ds.SelectContext(WithoutDefaultTimeout(ctx), "test", "foo", 42, "bar") require.ErrorIs(t, err, sentinelErr) // W/o default, but with our own ctx2, cancel := context.WithTimeout(WithoutDefaultTimeout(ctx), selDur/100) t.Cleanup(cancel) - err = db.SelectContext(ctx2, "test", "foo", 42, "bar") + err = ds.SelectContext(ctx2, "test", "foo", 42, "bar") require.ErrorIs(t, err, context.DeadlineExceeded) } -var _ DB = &database{} +func TestWrapDataSource_transactional(t *testing.T) { + lggr := logger.Test(t) + + txional := (*transactional)(nil) + + var ds DataSource = (*sqlx.DB)(nil) + assert.Implements(t, txional, ds) + got := WrapDataSource(ds, lggr) + assert.Implements(t, txional, got) + got = WrapDataSource(ds, lggr, noopHook) + assert.Implements(t, txional, got) + got = WrapDataSource(ds, lggr, noopHook, noopHook) + assert.Implements(t, txional, got) + + ds = (*sqlx.Tx)(nil) + assert.NotImplements(t, txional, ds) + got = WrapDataSource(ds, lggr) + assert.NotImplements(t, txional, got) + got = WrapDataSource(ds, lggr, noopHook) + assert.NotImplements(t, txional, got) + got = WrapDataSource(ds, lggr, noopHook, noopHook) + assert.NotImplements(t, txional, got) +} + +var _ DataSource = &dataSource{} -type database struct{} +type dataSource struct{} -func (q *database) DriverName() string { return "" } +func (q *dataSource) DriverName() string { return "" } -func (q *database) Rebind(s string) string { return "" } +func (q *dataSource) Rebind(s string) string { return "" } -func (q *database) BindNamed(s string, i interface{}) (string, []interface{}, error) { +func (q *dataSource) BindNamed(s string, i interface{}) (string, []interface{}, error) { return "", nil, nil } -func (q *database) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { +func (q *dataSource) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { return nil, nil } -func (q *database) QueryxContext(ctx context.Context, query string, args ...interface{}) (*sqlx.Rows, error) { +func (q *dataSource) QueryxContext(ctx context.Context, query string, args ...interface{}) (*sqlx.Rows, error) { return nil, nil } -func (q *database) QueryRowxContext(ctx context.Context, query string, args ...interface{}) *sqlx.Row { +func (q *dataSource) QueryRowxContext(ctx context.Context, query string, args ...interface{}) *sqlx.Row { return nil } -func (q *database) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { +func (q *dataSource) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { return nil, nil } -func (q *database) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) { +func (q *dataSource) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) { return nil, nil } -func (q *database) GetContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error { +func (q *dataSource) GetContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error { select { case <-ctx.Done(): return ctx.Err() @@ -114,7 +138,7 @@ func (q *database) GetContext(ctx context.Context, dest interface{}, query strin return nil } -func (q *database) SelectContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error { +func (q *dataSource) SelectContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error { select { case <-ctx.Done(): return ctx.Err() diff --git a/pkg/sqlutil/sqlutil.go b/pkg/sqlutil/sqlutil.go index 2078845ea..2120324d7 100644 --- a/pkg/sqlutil/sqlutil.go +++ b/pkg/sqlutil/sqlutil.go @@ -8,10 +8,13 @@ import ( "github.com/jmoiron/sqlx" ) -type Queryer = DB +type Queryer = DataSource -// DB is implemented by [*sqlx.DB], [*sqlx.Tx], & [*sqlx.Conn]. -type DB interface { +var _ DataSource = (*sqlx.DB)(nil) +var _ DataSource = (*sqlx.Tx)(nil) + +// DataSource is implemented by [*sqlx.DB] & [*sqlx.Tx]. +type DataSource interface { sqlx.ExtContext sqlx.PreparerContext GetContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error @@ -29,19 +32,32 @@ type TxOptions struct { // func (d *MyD) Transaction(ctx context.Context, fn func(*MyD) error) (err error) { // return sqlutil.Transact(ctx, d.new, d.db, nil, fn) // } -func Transact[D any](ctx context.Context, newD func(DB) D, db DB, opts *TxOptions, fn func(D) error) (err error) { - txdb, ok := db.(interface { - // BeginTxx is implemented by *sqlx.DB & *sqlx.Conn, but not *sqlx.Tx. - BeginTxx(context.Context, *sql.TxOptions) (*sqlx.Tx, error) - }) +func Transact[D any](ctx context.Context, newD func(DataSource) D, ds DataSource, opts *TxOptions, fn func(D) error) (err error) { + txds, ok := ds.(transactional) if !ok { // Unsupported or already inside another transaction. - return fn(newD(db)) + return fn(newD(ds)) } if opts == nil { opts = &TxOptions{} } - tx, err := txdb.BeginTxx(ctx, &opts.TxOptions) + // Begin tx + tx, err := func() (transaction, error) { + // Support [DataSource]s wrapped via [WrapDataSource] + if wrapped, ok := ds.(wrappedTransactional); ok { + tx, terr := wrapped.BeginWrappedTxx(ctx, &opts.TxOptions) + if terr != nil { + return nil, terr + } + return tx, nil + } + + tx, terr := txds.BeginTxx(ctx, &opts.TxOptions) + if terr != nil { + return nil, terr + } + return tx, nil + }() if err != nil { return err } @@ -64,3 +80,18 @@ func Transact[D any](ctx context.Context, newD func(DB) D, db DB, opts *TxOption err = fn(newD(tx)) return } + +type transactional interface { + // BeginTxx is implemented by *sqlx.DB but not *sqlx.Tx. + BeginTxx(context.Context, *sql.TxOptions) (*sqlx.Tx, error) +} + +type wrappedTransactional interface { + BeginWrappedTxx(context.Context, *sql.TxOptions) (transaction, error) +} + +type transaction interface { + DataSource + Commit() error + Rollback() error +}