From 75c70f730b08ce07bc4fbf2ec77d619b39d55968 Mon Sep 17 00:00:00 2001 From: Jordan Krage Date: Wed, 15 Nov 2023 06:01:48 -0600 Subject: [PATCH] core/services/pg: simplify API (#11296) --- core/services/pg/connection.go | 2 +- core/services/pg/q.go | 2 +- core/services/pg/sqlx.go | 2 +- core/services/pg/transaction.go | 54 +++++++++++---------------------- 4 files changed, 20 insertions(+), 40 deletions(-) diff --git a/core/services/pg/connection.go b/core/services/pg/connection.go index 0bafd5dcd0f..ee345bb6259 100644 --- a/core/services/pg/connection.go +++ b/core/services/pg/connection.go @@ -42,7 +42,7 @@ func NewConnection(uri string, dialect dialects.DialectName, config ConnectionCo lockTimeout := config.DefaultLockTimeout().Milliseconds() idleInTxSessionTimeout := config.DefaultIdleInTxSessionTimeout().Milliseconds() stmt := fmt.Sprintf(`SET TIME ZONE 'UTC'; SET lock_timeout = %d; SET idle_in_transaction_session_timeout = %d; SET default_transaction_isolation = %q`, - lockTimeout, idleInTxSessionTimeout, DefaultIsolation.String()) + lockTimeout, idleInTxSessionTimeout, defaultIsolation.String()) if _, err = db.Exec(stmt); err != nil { return nil, err } diff --git a/core/services/pg/q.go b/core/services/pg/q.go index 470d39c825c..9c9c15d9838 100644 --- a/core/services/pg/q.go +++ b/core/services/pg/q.go @@ -165,7 +165,7 @@ func (q Q) Context() (context.Context, context.CancelFunc) { return context.WithTimeout(q.ParentCtx, q.QueryTimeout) } -func (q Q) Transaction(fc func(q Queryer) error, txOpts ...TxOptions) error { +func (q Q) Transaction(fc func(q Queryer) error, txOpts ...TxOption) error { ctx, cancel := q.Context() defer cancel() return SqlxTransaction(ctx, q.Queryer, q.originalLogger(), fc, txOpts...) diff --git a/core/services/pg/sqlx.go b/core/services/pg/sqlx.go index 820cd51712e..c371c292138 100644 --- a/core/services/pg/sqlx.go +++ b/core/services/pg/sqlx.go @@ -35,7 +35,7 @@ func WrapDbWithSqlx(rdb *sql.DB) *sqlx.DB { return db } -func SqlxTransaction(ctx context.Context, q Queryer, lggr logger.Logger, fc func(q Queryer) error, txOpts ...TxOptions) (err error) { +func SqlxTransaction(ctx context.Context, q Queryer, lggr logger.Logger, fc func(q Queryer) error, txOpts ...TxOption) (err error) { switch db := q.(type) { case *sqlx.Tx: // nested transaction: just use the outer transaction diff --git a/core/services/pg/transaction.go b/core/services/pg/transaction.go index 92d72b3d81b..74841d010bf 100644 --- a/core/services/pg/transaction.go +++ b/core/services/pg/transaction.go @@ -15,44 +15,21 @@ import ( corelogger "github.com/smartcontractkit/chainlink/v2/core/logger" ) -type TxOptions struct { - sql.TxOptions -} +// NOTE: This is the default level in Postgres anyway, we just make it +// explicit here +const defaultIsolation = sql.LevelReadCommitted -// NOTE: In an ideal world the timeouts below would be set to something sane in -// the postgres configuration by the user. Since we do not live in an ideal -// world, it is necessary to override them here. -// -// They cannot easily be set at a session level due to how Go's connection -// pooling works. -const ( - // NOTE: This is the default level in Postgres anyway, we just make it - // explicit here - DefaultIsolation = sql.LevelReadCommitted -) +// TxOption is a functional option for SQL transactions. +type TxOption func(*sql.TxOptions) -func OptReadOnlyTx() TxOptions { - return TxOptions{TxOptions: sql.TxOptions{ReadOnly: true}} -} - -func applyDefaults(optss []TxOptions) (txOpts sql.TxOptions) { - readOnly := false - if len(optss) > 0 { - opts := optss[0] - readOnly = opts.ReadOnly - } - txOpts = sql.TxOptions{ - ReadOnly: readOnly, +func OptReadOnlyTx() TxOption { + return func(opts *sql.TxOptions) { + opts.ReadOnly = true } - return } -func SqlTransaction(ctx context.Context, rdb *sql.DB, lggr logger.Logger, fn func(tx *sqlx.Tx) error, optss ...TxOptions) (err error) { +func SqlTransaction(ctx context.Context, rdb *sql.DB, lggr logger.Logger, fn func(tx *sqlx.Tx) error, opts ...TxOption) (err error) { db := WrapDbWithSqlx(rdb) - return sqlxTransaction(ctx, db, lggr, fn, optss...) -} - -func sqlxTransaction(ctx context.Context, db *sqlx.DB, lggr logger.Logger, fn func(tx *sqlx.Tx) error, optss ...TxOptions) (err error) { wrapFn := func(q Queryer) error { tx, ok := q.(*sqlx.Tx) if !ok { @@ -60,16 +37,19 @@ func sqlxTransaction(ctx context.Context, db *sqlx.DB, lggr logger.Logger, fn fu } return fn(tx) } - return sqlxTransactionQ(ctx, db, lggr, wrapFn, optss...) + return sqlxTransactionQ(ctx, db, lggr, wrapFn, opts...) } -// TxBeginner can be a db or a conn, anything that implements BeginTxx -type TxBeginner interface { +// txBeginner can be a db or a conn, anything that implements BeginTxx +type txBeginner interface { BeginTxx(context.Context, *sql.TxOptions) (*sqlx.Tx, error) } -func sqlxTransactionQ(ctx context.Context, db TxBeginner, lggr logger.Logger, fn func(q Queryer) error, optss ...TxOptions) (err error) { - txOpts := applyDefaults(optss) +func sqlxTransactionQ(ctx context.Context, db txBeginner, lggr logger.Logger, fn func(q Queryer) error, opts ...TxOption) (err error) { + var txOpts sql.TxOptions + for _, o := range opts { + o(&txOpts) + } var tx *sqlx.Tx tx, err = db.BeginTxx(ctx, &txOpts)