Skip to content

Commit

Permalink
pkg/sqlutil: fix wrapped DB transactions (#400)
Browse files Browse the repository at this point in the history
  • Loading branch information
jmank88 authored Mar 14, 2024
1 parent 9bf02a1 commit 049609b
Show file tree
Hide file tree
Showing 3 changed files with 164 additions and 61 deletions.
112 changes: 80 additions & 32 deletions pkg/sqlutil/hook.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@ import (
"github.com/smartcontractkit/chainlink-common/pkg/logger"
)

var _ DB = &WrappedDB{}
var _ DataSource = &wrappedDataSource{}

// WrappedDB is a [DB] which invokes a [QueryHook] on each call.
type WrappedDB struct {
db DB
// wrappedDataSource is a [DataSource] which invokes a [QueryHook] on each call.
type wrappedDataSource struct {
db DataSource
lggr logger.Logger
hook QueryHook
}
Expand All @@ -25,102 +25,150 @@ type WrappedDB struct {
// See [MonitorHook] and [TimeoutHook] for examples.
type QueryHook func(ctx context.Context, lggr logger.Logger, do func(context.Context) error, query string, args ...any) error

// NewWrappedDB returns a new [WrappedDB] that calls each [QueryHook] in the provided order.
func NewWrappedDB(db DB, l logger.Logger, hs ...QueryHook) *WrappedDB {
iq := WrappedDB{db: db,
// WrapDataSource returns a new [DataSource] that calls each [QueryHook] in the provided order.
// If db implements [sqlx.BeginTxx], then the returned DataSource will also.
func WrapDataSource(db DataSource, l logger.Logger, hs ...QueryHook) DataSource {
iq := wrappedDataSource{db: db,
lggr: logger.Helper(logger.Named(l, "WrappedDB"), 2), // skip our own wrapper and one interceptor
hook: noopHook,
}
switch len(hs) {
case 0:
return &iq
case 1:
iq.hook = hs[0]
return &iq
default:
// Nest the QueryHook calls so that they are wrapped from first to last.
// Example:
// [A, B, C] => A(B(C(do())))
for i := len(hs) - 1; i >= 0; i-- {
next := hs[i]
prev := iq.hook
iq.hook = func(ctx context.Context, lggr logger.Logger, do func(context.Context) error, query string, args ...any) error {
// opt: cache the construction of these loggers
lggr = logger.Helper(lggr, 1) // skip one more for this wrapper
return next(ctx, lggr, func(ctx context.Context) error {
lggr = logger.Helper(lggr, 2) // skip two more for do() and this extra wrapper
return prev(ctx, lggr, do, query, args...)
}, query, args...)
}
}
}

// Nest the QueryHook calls so that they are wrapped from first to last.
// Example:
// [A, B, C] => A(B(C(do())))
for i := len(hs) - 1; i >= 0; i-- {
next := hs[i]
prev := iq.hook
iq.hook = func(ctx context.Context, lggr logger.Logger, do func(context.Context) error, query string, args ...any) error {
// opt: cache the construction of these loggers
lggr = logger.Helper(lggr, 1) // skip one more for this wrapper
return next(ctx, lggr, func(ctx context.Context) error {
lggr = logger.Helper(lggr, 2) // skip two more for do() and this extra wrapper
return prev(ctx, lggr, do, query, args...)
}, query, args...)
if txdb, ok := db.(transactional); ok {
// extra wrapper to make BeginTxx available
return &wrappedTransactionalDataSource{
wrappedDataSource: iq,
txdb: txdb,
}
}

return &iq
}

func noopHook(ctx context.Context, lggr logger.Logger, do func(context.Context) error, query string, args ...any) error {
return do(ctx)
}

func (w *WrappedDB) DriverName() string {
func (w *wrappedDataSource) DriverName() string {
return w.db.DriverName()
}

func (w *WrappedDB) Rebind(s string) string {
func (w *wrappedDataSource) Rebind(s string) string {
return w.db.Rebind(s)
}

func (w *WrappedDB) BindNamed(s string, i interface{}) (string, []any, error) {
func (w *wrappedDataSource) BindNamed(s string, i interface{}) (string, []any, error) {
return w.db.BindNamed(s, i)
}

func (w *WrappedDB) QueryContext(ctx context.Context, query string, args ...any) (rows *sql.Rows, err error) {
func (w *wrappedDataSource) QueryContext(ctx context.Context, query string, args ...any) (rows *sql.Rows, err error) {
err = w.hook(ctx, w.lggr, func(ctx context.Context) (err error) {
rows, err = w.db.QueryContext(ctx, query, args...) //nolint
return
}, query, args...)
return
}

func (w *WrappedDB) QueryxContext(ctx context.Context, query string, args ...any) (rows *sqlx.Rows, err error) {
func (w *wrappedDataSource) QueryxContext(ctx context.Context, query string, args ...any) (rows *sqlx.Rows, err error) {
err = w.hook(ctx, w.lggr, func(ctx context.Context) (err error) {
rows, err = w.db.QueryxContext(ctx, query, args...) //nolint:sqlclosecheck
return
}, query, args...)
return
}

func (w *WrappedDB) QueryRowxContext(ctx context.Context, query string, args ...any) (row *sqlx.Row) {
func (w *wrappedDataSource) QueryRowxContext(ctx context.Context, query string, args ...any) (row *sqlx.Row) {
_ = w.hook(ctx, w.lggr, func(ctx context.Context) error {
row = w.db.QueryRowxContext(ctx, query, args...)
return nil
}, query, args...)
return
}

func (w *WrappedDB) ExecContext(ctx context.Context, query string, args ...any) (res sql.Result, err error) {
func (w *wrappedDataSource) ExecContext(ctx context.Context, query string, args ...any) (res sql.Result, err error) {
err = w.hook(ctx, w.lggr, func(ctx context.Context) (err error) {
res, err = w.db.ExecContext(ctx, query, args...)
return
}, query, args...)
return
}

func (w *WrappedDB) PrepareContext(ctx context.Context, query string) (stmt *sql.Stmt, err error) {
func (w *wrappedDataSource) PrepareContext(ctx context.Context, query string) (stmt *sql.Stmt, err error) {
err = w.hook(ctx, w.lggr, func(ctx context.Context) (err error) {
stmt, err = w.db.PrepareContext(ctx, query) //nolint:sqlclosecheck
return
}, query, nil)
return
}

func (w *WrappedDB) GetContext(ctx context.Context, dest interface{}, query string, args ...any) error {
func (w *wrappedDataSource) GetContext(ctx context.Context, dest interface{}, query string, args ...any) error {
return w.hook(ctx, w.lggr, func(ctx context.Context) error {
return w.db.GetContext(ctx, dest, query, args...)
}, query, args...)
}

func (w *WrappedDB) SelectContext(ctx context.Context, dest interface{}, query string, args ...any) error {
func (w *wrappedDataSource) SelectContext(ctx context.Context, dest interface{}, query string, args ...any) error {
return w.hook(ctx, w.lggr, func(ctx context.Context) error {
return w.db.SelectContext(ctx, dest, query, args...)
}, query, args...)
}

// wrappedTransactionalDataSource extends [wrappedDataSource] with BeginTxx and BeginWrappedTxx for initiating transactions.
type wrappedTransactionalDataSource struct {
wrappedDataSource
txdb transactional
}

func (w *wrappedTransactionalDataSource) BeginTxx(ctx context.Context, opts *sql.TxOptions) (tx *sqlx.Tx, err error) {
err = w.hook(ctx, w.lggr, func(ctx context.Context) (err error) {
tx, err = w.txdb.BeginTxx(ctx, opts)
return
}, "START TRANSACTION", nil)
return
}

// BeginWrappedTxx is like BeginTxx, but wraps the returned tx with the same hook.
func (w *wrappedTransactionalDataSource) BeginWrappedTxx(ctx context.Context, opts *sql.TxOptions) (tx transaction, err error) {
tx, err = w.BeginTxx(ctx, opts)
if err != nil {
return nil, err
}
return &wrappedTx{
wrappedDataSource: wrappedDataSource{
db: tx,
lggr: w.lggr,
hook: w.hook,
},
tx: tx,
}, nil
}

// wrappedTx extends [wrappedDataSource] with Commit and Rollback for completing a transaction.
type wrappedTx struct {
wrappedDataSource
tx transaction
}

func (w *wrappedTx) Commit() error { return w.tx.Commit() }

func (w *wrappedTx) Rollback() error { return w.tx.Rollback() }
62 changes: 43 additions & 19 deletions pkg/sqlutil/hook_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,12 @@ const (
selDur = 200 * time.Millisecond
)

func TestNewInterceptedQueryer(t *testing.T) {
func TestWrapDataSource(t *testing.T) {
lggr, ol := logger.TestObserved(t, zapcore.InfoLevel)
var db DB = &database{}
var ds DataSource = &dataSource{}
var sentinelErr = errors.New("intercepted error")
const fakeError = "fake warning"
db = NewWrappedDB(db, lggr, TimeoutHook(selDur/2), noopHook, MonitorHook(func() bool { return true }), noopHook, func(ctx context.Context, lggr logger.Logger, do func(context.Context) error, query string, args ...any) error {
ds = WrapDataSource(ds, lggr, TimeoutHook(selDur/2), noopHook, MonitorHook(func() bool { return true }), noopHook, func(ctx context.Context, lggr logger.Logger, do func(context.Context) error, query string, args ...any) error {
err := do(ctx)
if err != nil {
return err
Expand All @@ -39,7 +39,7 @@ func TestNewInterceptedQueryer(t *testing.T) {
ctx := tests.Context(t)

// Error intercepted
err := db.GetContext(ctx, "test", "foo", 42, "bar")
err := ds.GetContext(ctx, "test", "foo", 42, "bar")
_, file, line, ok := runtime.Caller(0)
require.True(t, ok)
expCaller := fmt.Sprintf("%s:%d", file, line-1)
Expand All @@ -55,57 +55,81 @@ func TestNewInterceptedQueryer(t *testing.T) {
_ = ol.TakeAll()

// Timeout applied
err = db.SelectContext(ctx, "test", "foo", 42, "bar")
err = ds.SelectContext(ctx, "test", "foo", 42, "bar")
require.ErrorIs(t, err, context.DeadlineExceeded)
logs = ol.FilterMessage(slowMsg).All()
require.Len(t, logs, 1)
assert.Equal(t, zapcore.DPanicLevel, logs[0].Level)
_ = ol.TakeAll()

// Without default timeout
err = db.SelectContext(WithoutDefaultTimeout(ctx), "test", "foo", 42, "bar")
err = ds.SelectContext(WithoutDefaultTimeout(ctx), "test", "foo", 42, "bar")
require.ErrorIs(t, err, sentinelErr)

// W/o default, but with our own
ctx2, cancel := context.WithTimeout(WithoutDefaultTimeout(ctx), selDur/100)
t.Cleanup(cancel)
err = db.SelectContext(ctx2, "test", "foo", 42, "bar")
err = ds.SelectContext(ctx2, "test", "foo", 42, "bar")
require.ErrorIs(t, err, context.DeadlineExceeded)
}

var _ DB = &database{}
func TestWrapDataSource_transactional(t *testing.T) {
lggr := logger.Test(t)

txional := (*transactional)(nil)

var ds DataSource = (*sqlx.DB)(nil)
assert.Implements(t, txional, ds)
got := WrapDataSource(ds, lggr)
assert.Implements(t, txional, got)
got = WrapDataSource(ds, lggr, noopHook)
assert.Implements(t, txional, got)
got = WrapDataSource(ds, lggr, noopHook, noopHook)
assert.Implements(t, txional, got)

ds = (*sqlx.Tx)(nil)
assert.NotImplements(t, txional, ds)
got = WrapDataSource(ds, lggr)
assert.NotImplements(t, txional, got)
got = WrapDataSource(ds, lggr, noopHook)
assert.NotImplements(t, txional, got)
got = WrapDataSource(ds, lggr, noopHook, noopHook)
assert.NotImplements(t, txional, got)
}

var _ DataSource = &dataSource{}

type database struct{}
type dataSource struct{}

func (q *database) DriverName() string { return "" }
func (q *dataSource) DriverName() string { return "" }

func (q *database) Rebind(s string) string { return "" }
func (q *dataSource) Rebind(s string) string { return "" }

func (q *database) BindNamed(s string, i interface{}) (string, []interface{}, error) {
func (q *dataSource) BindNamed(s string, i interface{}) (string, []interface{}, error) {
return "", nil, nil
}

func (q *database) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
func (q *dataSource) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
return nil, nil
}

func (q *database) QueryxContext(ctx context.Context, query string, args ...interface{}) (*sqlx.Rows, error) {
func (q *dataSource) QueryxContext(ctx context.Context, query string, args ...interface{}) (*sqlx.Rows, error) {
return nil, nil
}

func (q *database) QueryRowxContext(ctx context.Context, query string, args ...interface{}) *sqlx.Row {
func (q *dataSource) QueryRowxContext(ctx context.Context, query string, args ...interface{}) *sqlx.Row {
return nil
}

func (q *database) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
func (q *dataSource) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
return nil, nil
}

func (q *database) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) {
func (q *dataSource) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) {
return nil, nil
}

func (q *database) GetContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
func (q *dataSource) GetContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
select {
case <-ctx.Done():
return ctx.Err()
Expand All @@ -114,7 +138,7 @@ func (q *database) GetContext(ctx context.Context, dest interface{}, query strin
return nil
}

func (q *database) SelectContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
func (q *dataSource) SelectContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
select {
case <-ctx.Done():
return ctx.Err()
Expand Down
Loading

0 comments on commit 049609b

Please sign in to comment.