diff --git a/.changeset/hot-dryers-flash.md b/.changeset/hot-dryers-flash.md new file mode 100644 index 00000000000..8423420589d --- /dev/null +++ b/.changeset/hot-dryers-flash.md @@ -0,0 +1,5 @@ +--- +"chainlink": patch +--- + +core/services: update llo & versioning to use sqlutil #internal diff --git a/core/cmd/shell.go b/core/cmd/shell.go index 4b1c32d279f..8b22525dfcd 100644 --- a/core/cmd/shell.go +++ b/core/cmd/shell.go @@ -248,7 +248,7 @@ func handleNodeVersioning(ctx context.Context, db *sqlx.DB, appLggr logger.Logge if static.Version != static.Unset { var appv, dbv *semver.Version - appv, dbv, err = versioning.CheckVersion(db, appLggr, static.Version) + appv, dbv, err = versioning.CheckVersion(ctx, db, appLggr, static.Version) if err != nil { // Exit immediately and don't touch the database if the app version is too old return fmt.Errorf("CheckVersion: %w", err) @@ -280,7 +280,7 @@ func handleNodeVersioning(ctx context.Context, db *sqlx.DB, appLggr logger.Logge // Update to latest version if static.Version != static.Unset { version := versioning.NewNodeVersion(static.Version) - if err = verORM.UpsertNodeVersion(version); err != nil { + if err = verORM.UpsertNodeVersion(ctx, version); err != nil { return fmt.Errorf("UpsertNodeVersion: %w", err) } } diff --git a/core/services/llo/orm.go b/core/services/llo/orm.go index e046d62ad89..6b14e543268 100644 --- a/core/services/llo/orm.go +++ b/core/services/llo/orm.go @@ -10,9 +10,8 @@ import ( "github.com/ethereum/go-ethereum/common" + "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" llotypes "github.com/smartcontractkit/chainlink-common/pkg/types/llo" - - "github.com/smartcontractkit/chainlink/v2/core/services/pg" ) type ORM interface { @@ -22,12 +21,12 @@ type ORM interface { var _ ORM = &orm{} type orm struct { - q pg.Queryer + ds sqlutil.DataSource evmChainID *big.Int } -func NewORM(q pg.Queryer, evmChainID *big.Int) ORM { - return &orm{q, evmChainID} +func NewORM(ds sqlutil.DataSource, evmChainID *big.Int) ORM { + return &orm{ds, evmChainID} } func (o *orm) LoadChannelDefinitions(ctx context.Context, addr common.Address) (dfns llotypes.ChannelDefinitions, blockNum int64, err error) { @@ -36,7 +35,7 @@ func (o *orm) LoadChannelDefinitions(ctx context.Context, addr common.Address) ( BlockNum int64 `db:"block_num"` } var scanned scd - err = o.q.GetContext(ctx, &scanned, "SELECT definitions, block_num FROM channel_definitions WHERE evm_chain_id = $1 AND addr = $2", o.evmChainID.String(), addr) + err = o.ds.GetContext(ctx, &scanned, "SELECT definitions, block_num FROM channel_definitions WHERE evm_chain_id = $1 AND addr = $2", o.evmChainID.String(), addr) if errors.Is(err, sql.ErrNoRows) { return dfns, blockNum, nil } else if err != nil { @@ -53,7 +52,7 @@ func (o *orm) LoadChannelDefinitions(ctx context.Context, addr common.Address) ( // TODO: Test this method // https://smartcontract-it.atlassian.net/jira/software/c/projects/MERC/issues/MERC-3653 func (o *orm) StoreChannelDefinitions(ctx context.Context, addr common.Address, dfns llotypes.ChannelDefinitions, blockNum int64) error { - _, err := o.q.ExecContext(ctx, ` + _, err := o.ds.ExecContext(ctx, ` INSERT INTO channel_definitions (evm_chain_id, addr, definitions, block_num, updated_at) VALUES ($1, $2, $3, $4, NOW()) ON CONFLICT (evm_chain_id, addr) DO UPDATE diff --git a/core/services/relay/evm/evm.go b/core/services/relay/evm/evm.go index 95cf9efc944..c8fe1b868a7 100644 --- a/core/services/relay/evm/evm.go +++ b/core/services/relay/evm/evm.go @@ -131,7 +131,7 @@ func NewRelayer(lggr logger.Logger, chain legacyevm.Chain, opts RelayerOpts) (*R lggr = lggr.Named("Relayer") mercuryORM := mercury.NewORM(opts.DB, lggr, opts.QConfig) - lloORM := llo.NewORM(pg.NewQ(opts.DB, lggr, opts.QConfig), chain.ID()) + lloORM := llo.NewORM(opts.DS, chain.ID()) cdcFactory := llo.NewChannelDefinitionCacheFactory(lggr, lloORM, chain.LogPoller()) return &Relayer{ db: opts.DB, diff --git a/core/services/versioning/orm.go b/core/services/versioning/orm.go index 8ed745955dc..5a2472eee8e 100644 --- a/core/services/versioning/orm.go +++ b/core/services/versioning/orm.go @@ -7,11 +7,10 @@ import ( "github.com/Masterminds/semver/v3" "github.com/jackc/pgconn" - "github.com/jmoiron/sqlx" "github.com/pkg/errors" + "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" "github.com/smartcontractkit/chainlink/v2/core/logger" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" ) // Version ORM manages the node_versions table @@ -19,19 +18,19 @@ import ( // The database version is ONLY useful for managing versioning specific to the database e.g. for backups or migrations type ORM interface { - FindLatestNodeVersion() (*NodeVersion, error) - UpsertNodeVersion(version NodeVersion) error + FindLatestNodeVersion(ctx context.Context) (*NodeVersion, error) + UpsertNodeVersion(ctx context.Context, version NodeVersion) error } type orm struct { - db *sqlx.DB + ds sqlutil.DataSource lggr logger.Logger timeout time.Duration } -func NewORM(db *sqlx.DB, lggr logger.Logger, timeout time.Duration) *orm { +func NewORM(ds sqlutil.DataSource, lggr logger.Logger, timeout time.Duration) *orm { return &orm{ - db: db, + ds: ds, lggr: lggr.Named("VersioningORM"), timeout: timeout, } @@ -41,17 +40,17 @@ func NewORM(db *sqlx.DB, lggr logger.Logger, timeout time.Duration) *orm { // version is newer than the current one // NOTE: If you just need the current application version, consider using static.Version instead // The database version is ONLY useful for managing versioning specific to the database e.g. for backups or migrations -func (o *orm) UpsertNodeVersion(version NodeVersion) error { +func (o *orm) UpsertNodeVersion(ctx context.Context, version NodeVersion) error { now := time.Now() if _, err := semver.NewVersion(version.Version); err != nil { return errors.Wrapf(err, "%q is not valid semver", version.Version) } - ctx, cancel := context.WithTimeout(context.Background(), o.timeout) + ctx, cancel := context.WithTimeout(ctx, o.timeout) defer cancel() - return pg.SqlxTransaction(ctx, o.db, o.lggr, func(tx pg.Queryer) error { - if _, _, err := CheckVersion(tx, logger.NullLogger, version.Version); err != nil { + return sqlutil.TransactDataSource(ctx, o.ds, nil, func(tx sqlutil.DataSource) error { + if _, _, err := CheckVersion(ctx, tx, logger.NullLogger, version.Version); err != nil { return err } @@ -63,17 +62,17 @@ version = EXCLUDED.version, created_at = EXCLUDED.created_at ` - _, err := tx.Exec(stmt, version.Version, now) + _, err := tx.ExecContext(ctx, stmt, version.Version, now) return err }) } // CheckVersion returns an error if there is a valid semver version in the // node_versions table that is higher than the current app version -func CheckVersion(q pg.Queryer, lggr logger.Logger, appVersion string) (appv, dbv *semver.Version, err error) { +func CheckVersion(ctx context.Context, ds sqlutil.DataSource, lggr logger.Logger, appVersion string) (appv, dbv *semver.Version, err error) { lggr = lggr.Named("Version") var dbVersion string - err = q.Get(&dbVersion, `SELECT version FROM node_versions ORDER BY created_at DESC LIMIT 1 FOR UPDATE`) + err = ds.GetContext(ctx, &dbVersion, `SELECT version FROM node_versions ORDER BY created_at DESC LIMIT 1 FOR UPDATE`) if errors.Is(err, sql.ErrNoRows) { lggr.Debugw("No previous version set", "appVersion", appVersion) return nil, nil, nil @@ -105,7 +104,7 @@ func CheckVersion(q pg.Queryer, lggr logger.Logger, appVersion string) (appv, db // FindLatestNodeVersion looks up the latest node version // NOTE: If you just need the current application version, consider using static.Version instead // The database version is ONLY useful for managing versioning specific to the database e.g. for backups or migrations -func (o *orm) FindLatestNodeVersion() (*NodeVersion, error) { +func (o *orm) FindLatestNodeVersion(ctx context.Context) (*NodeVersion, error) { stmt := ` SELECT version, created_at FROM node_versions @@ -113,7 +112,7 @@ ORDER BY created_at DESC ` var nodeVersion NodeVersion - err := o.db.Get(&nodeVersion, stmt) + err := o.ds.GetContext(ctx, &nodeVersion, stmt) if err != nil { return nil, err } diff --git a/core/services/versioning/orm_test.go b/core/services/versioning/orm_test.go index fe19a2dcd73..f655c9c47fe 100644 --- a/core/services/versioning/orm_test.go +++ b/core/services/versioning/orm_test.go @@ -7,6 +7,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/smartcontractkit/chainlink/v2/core/internal/testutils" "github.com/smartcontractkit/chainlink/v2/core/internal/testutils/pgtest" "github.com/smartcontractkit/chainlink/v2/core/logger" "github.com/smartcontractkit/chainlink/v2/core/services/pg" @@ -14,13 +15,14 @@ import ( ) func TestORM_NodeVersion_UpsertNodeVersion(t *testing.T) { + ctx := testutils.Context(t) db := pgtest.NewSqlxDB(t) orm := NewORM(db, logger.TestLogger(t), pg.DefaultQueryTimeout) - err := orm.UpsertNodeVersion(NewNodeVersion("9.9.8")) + err := orm.UpsertNodeVersion(ctx, NewNodeVersion("9.9.8")) require.NoError(t, err) - ver, err := orm.FindLatestNodeVersion() + ver, err := orm.FindLatestNodeVersion(ctx) require.NoError(t, err) require.NotNil(t, ver) @@ -28,85 +30,87 @@ func TestORM_NodeVersion_UpsertNodeVersion(t *testing.T) { require.NotZero(t, ver.CreatedAt) // Testing Upsert - require.NoError(t, orm.UpsertNodeVersion(NewNodeVersion("9.9.8"))) + require.NoError(t, orm.UpsertNodeVersion(ctx, NewNodeVersion("9.9.8"))) - err = orm.UpsertNodeVersion(NewNodeVersion("9.9.7")) + err = orm.UpsertNodeVersion(ctx, NewNodeVersion("9.9.7")) require.Error(t, err) assert.Contains(t, err.Error(), "Application version (9.9.7) is lower than database version (9.9.8). Only Chainlink 9.9.8 or higher can be run on this database") - require.NoError(t, orm.UpsertNodeVersion(NewNodeVersion("9.9.9"))) + require.NoError(t, orm.UpsertNodeVersion(ctx, NewNodeVersion("9.9.9"))) var count int err = db.QueryRowx(`SELECT count(*) FROM node_versions`).Scan(&count) require.NoError(t, err) assert.Equal(t, 1, count) - ver, err = orm.FindLatestNodeVersion() + ver, err = orm.FindLatestNodeVersion(ctx) require.NoError(t, err) require.NotNil(t, ver) require.Equal(t, "9.9.9", ver.Version) // invalid semver returns error - err = orm.UpsertNodeVersion(NewNodeVersion("random_12345")) + err = orm.UpsertNodeVersion(ctx, NewNodeVersion("random_12345")) require.Error(t, err) assert.Contains(t, err.Error(), "\"random_12345\" is not valid semver: Invalid Semantic Version") - ver, err = orm.FindLatestNodeVersion() + ver, err = orm.FindLatestNodeVersion(ctx) require.NoError(t, err) require.NotNil(t, ver) require.Equal(t, "9.9.9", ver.Version) } func Test_Version_CheckVersion(t *testing.T) { + ctx := testutils.Context(t) db := pgtest.NewSqlxDB(t) lggr := logger.TestLogger(t) orm := NewORM(db, lggr, pg.DefaultQueryTimeout) - err := orm.UpsertNodeVersion(NewNodeVersion("9.9.8")) + err := orm.UpsertNodeVersion(ctx, NewNodeVersion("9.9.8")) require.NoError(t, err) // invalid app version semver returns error - _, _, err = CheckVersion(db, lggr, static.Unset) + _, _, err = CheckVersion(ctx, db, lggr, static.Unset) require.Error(t, err) assert.Contains(t, err.Error(), `Application version "unset" is not valid semver`) - _, _, err = CheckVersion(db, lggr, "some old bollocks") + _, _, err = CheckVersion(ctx, db, lggr, "some old bollocks") require.Error(t, err) assert.Contains(t, err.Error(), `Application version "some old bollocks" is not valid semver`) // lower version returns error - _, _, err = CheckVersion(db, lggr, "9.9.7") + _, _, err = CheckVersion(ctx, db, lggr, "9.9.7") require.Error(t, err) assert.Contains(t, err.Error(), "Application version (9.9.7) is lower than database version (9.9.8). Only Chainlink 9.9.8 or higher can be run on this database") // equal version is ok var appv, dbv *semver.Version - appv, dbv, err = CheckVersion(db, lggr, "9.9.8") + appv, dbv, err = CheckVersion(ctx, db, lggr, "9.9.8") require.NoError(t, err) assert.Equal(t, "9.9.8", appv.String()) assert.Equal(t, "9.9.8", dbv.String()) // greater version is ok - appv, dbv, err = CheckVersion(db, lggr, "9.9.9") + appv, dbv, err = CheckVersion(ctx, db, lggr, "9.9.9") require.NoError(t, err) assert.Equal(t, "9.9.9", appv.String()) assert.Equal(t, "9.9.8", dbv.String()) } func TestORM_NodeVersion_FindLatestNodeVersion(t *testing.T) { + ctx := testutils.Context(t) db := pgtest.NewSqlxDB(t) orm := NewORM(db, logger.TestLogger(t), pg.DefaultQueryTimeout) // Not Found - _, err := orm.FindLatestNodeVersion() + _, err := orm.FindLatestNodeVersion(ctx) require.Error(t, err) - err = orm.UpsertNodeVersion(NewNodeVersion("9.9.8")) + err = orm.UpsertNodeVersion(ctx, NewNodeVersion("9.9.8")) require.NoError(t, err) - ver, err := orm.FindLatestNodeVersion() + ver, err := orm.FindLatestNodeVersion(ctx) require.NoError(t, err) require.NotNil(t, ver)