From 588f99a344e67b3749cea8dbdcdaa7de4d28b6ae Mon Sep 17 00:00:00 2001 From: Jordan Krage Date: Thu, 14 Mar 2024 12:06:37 -0500 Subject: [PATCH 1/2] use sqlutil instead of pg.QOpts --- core/bridges/orm_test.go | 2 +- core/chains/evm/log/helpers_test.go | 13 +- core/chains/evm/log/integration_test.go | 8 +- core/chains/evm/logpoller/observability.go | 4 +- core/chains/evm/logpoller/orm.go | 208 ++- core/chains/evm/logpoller/query.go | 2 +- core/chains/evm/txmgr/broadcaster_test.go | 2 +- core/chains/evm/txmgr/builder.go | 9 +- core/chains/evm/txmgr/evm_tx_store.go | 52 +- core/chains/evm/txmgr/txmgr_test.go | 1 - core/chains/legacyevm/chain.go | 20 +- core/chains/legacyevm/chain_test.go | 3 +- core/chains/legacyevm/evm_txm.go | 8 +- core/cmd/eth_keys_commands_test.go | 4 +- core/cmd/evm_transaction_commands_test.go | 10 +- core/cmd/jobs_commands_test.go | 7 +- core/cmd/shell.go | 24 +- core/cmd/shell_local.go | 14 +- core/cmd/shell_local_test.go | 12 +- core/cmd/shell_remote_test.go | 13 +- core/internal/cltest/cltest.go | 54 +- core/internal/cltest/factories.go | 16 +- core/internal/cltest/job_factories.go | 23 +- core/internal/features/features_test.go | 19 +- .../features/ocr2/features_ocr2_test.go | 10 +- core/internal/mocks/application.go | 22 - core/internal/testutils/evmtest/evmtest.go | 7 +- core/internal/testutils/pgtest/pgtest.go | 25 +- core/internal/testutils/testutils.go | 3 +- core/scripts/gateway/run_gateway.go | 2 +- core/scripts/go.mod | 4 +- core/scripts/go.sum | 8 +- core/services/chainlink/application.go | 73 +- .../relayer_chain_interoperators_test.go | 13 +- core/services/chainlink/relayer_factory.go | 19 +- core/services/cron/cron_test.go | 4 +- core/services/directrequest/delegate.go | 33 +- core/services/directrequest/delegate_test.go | 7 +- core/services/feeds/mocks/orm.go | 1323 +++++++---------- core/services/feeds/mocks/service.go | 221 ++- core/services/feeds/orm.go | 247 +-- core/services/feeds/orm_test.go | 380 ++--- core/services/feeds/service.go | 299 ++-- core/services/feeds/service_test.go | 1057 +++++++------ core/services/fluxmonitorv2/delegate.go | 13 +- core/services/fluxmonitorv2/flux_monitor.go | 18 +- .../fluxmonitorv2/flux_monitor_test.go | 1 + .../fluxmonitorv2/integrations_test.go | 30 +- core/services/fluxmonitorv2/orm_test.go | 4 +- core/services/functions/listener_test.go | 4 +- core/services/gateway/delegate.go | 13 +- core/services/gateway/gateway_test.go | 12 +- core/services/gateway/handler_factory.go | 13 +- .../handlers/functions/handler.functions.go | 9 +- .../functions/handler.functions_test.go | 4 +- .../gateway_integration_test.go | 2 +- core/services/job/common.go | 3 - core/services/job/helpers_test.go | 3 +- core/services/job/job_orm_test.go | 311 ++-- .../job/job_pipeline_orm_integration_test.go | 6 +- core/services/job/kv_orm.go | 14 +- core/services/job/kv_orm_test.go | 9 +- core/services/job/mocks/orm.go | 458 +++--- core/services/job/mocks/spawner.go | 53 +- core/services/job/orm.go | 787 +++++----- core/services/job/orm_test.go | 8 +- core/services/job/runner_integration_test.go | 84 +- core/services/job/spawner.go | 56 +- core/services/job/spawner_test.go | 56 +- core/services/keeper/delegate.go | 19 +- core/services/keeper/helpers_test.go | 4 + core/services/keeper/integration_test.go | 8 +- .../registry_synchronizer_helper_test.go | 2 +- .../registry_synchronizer_process_logs.go | 3 + .../keeper/registry_synchronizer_sync.go | 6 +- core/services/keeper/upkeep_executer_test.go | 2 +- core/services/keystore/eth_test.go | 6 +- core/services/llo/delegate.go | 16 +- core/services/ocr/contract_tracker.go | 5 +- core/services/ocr/contract_tracker_test.go | 10 +- core/services/ocr/database.go | 7 +- core/services/ocr/database_test.go | 4 +- core/services/ocr/delegate.go | 15 +- core/services/ocr/helpers_internal_test.go | 7 +- .../ocr/mocks/ocr_contract_tracker_db.go | 31 +- core/services/ocr2/database.go | 39 +- core/services/ocr2/database_test.go | 36 +- core/services/ocr2/delegate.go | 63 +- .../ocr2/plugins/dkg/persistence/db.go | 19 +- .../ocr2/plugins/dkg/persistence/db_test.go | 2 +- core/services/ocr2/plugins/dkg/plugin.go | 8 +- .../generic/pipeline_runner_adapter_test.go | 4 +- core/services/ocr2/plugins/ocr2keeper/util.go | 13 +- core/services/ocrbootstrap/database.go | 11 +- core/services/ocrbootstrap/database_test.go | 8 +- core/services/ocrbootstrap/delegate.go | 13 +- .../services/ocrcommon/discoverer_database.go | 13 +- .../ocrcommon/discoverer_database_test.go | 2 +- core/services/ocrcommon/peer_wrapper.go | 13 +- core/services/ocrcommon/peer_wrapper_test.go | 14 +- core/services/pg/connection.go | 5 + core/services/pg/lease_lock.go | 9 +- core/services/pg/q.go | 384 ----- core/services/pg/q_test.go | 85 -- core/services/pg/sqlx.go | 24 +- core/services/pg/transaction.go | 95 -- core/services/pg/utils.go | 50 - core/services/pipeline/helpers_test.go | 3 +- core/services/pipeline/mocks/orm.go | 18 +- core/services/pipeline/orm.go | 69 +- core/services/pipeline/orm_test.go | 49 +- core/services/pipeline/runner.go | 2 +- core/services/pipeline/runner_test.go | 19 +- core/services/pipeline/task.bridge_test.go | 54 +- core/services/pipeline/task.http_test.go | 2 +- core/services/pipeline/test_helpers_test.go | 3 +- core/services/promreporter/prom_reporter.go | 10 +- .../promreporter/prom_reporter_test.go | 7 +- core/services/relay/evm/evm.go | 18 +- core/services/relay/evm/evm_test.go | 23 +- .../evm/mercury/wsrpc/pb/mercury_wsrpc.pb.go | 1 - core/services/relay/evm/ocr2keeper.go | 5 +- core/services/relay/evm/ocr2vrf.go | 6 +- core/services/versioning/orm.go | 14 +- core/services/versioning/orm_test.go | 7 +- core/services/vrf/delegate_test.go | 10 +- core/services/vrf/v1/integration_test.go | 28 +- core/services/vrf/v1/listener_v1.go | 9 +- .../vrf/v2/integration_helpers_test.go | 6 +- .../vrf/v2/integration_v2_plus_test.go | 24 +- .../v2/integration_v2_reverted_txns_test.go | 2 +- core/services/vrf/v2/integration_v2_test.go | 22 +- core/services/vrf/vrftesthelpers/helpers.go | 6 +- core/services/webhook/authorizer.go | 12 +- core/services/webhook/authorizer_test.go | 3 +- .../webhook/external_initiator_manager.go | 44 +- .../external_initiator_manager_test.go | 17 +- .../mocks/external_initiator_manager.go | 18 +- core/services/webhook/validate.go | 6 +- core/services/webhook/validate_test.go | 15 +- core/sessions/ldapauth/sync.go | 38 +- core/sessions/localauth/reaper.go | 16 +- core/sessions/localauth/reaper_test.go | 2 +- core/store/migrate/migrate.go | 39 +- core/store/migrate/migrate_test.go | 27 +- .../migrations/0054_remove_legacy_pipeline.go | 10 +- core/web/bridge_types_controller.go | 2 +- core/web/eth_keys_controller_test.go | 3 +- core/web/evm_transactions_controller_test.go | 8 +- core/web/evm_transfer_controller_test.go | 15 +- core/web/evm_tx_attempts_controller_test.go | 2 +- .../external_initiators_controller_test.go | 2 +- core/web/jobs_controller.go | 10 +- core/web/jobs_controller_test.go | 46 +- core/web/loader/feeds_manager.go | 4 +- core/web/loader/feeds_manager_chain_config.go | 4 +- core/web/loader/job.go | 8 +- core/web/loader/job_proposal.go | 4 +- core/web/loader/job_proposal_spec.go | 4 +- core/web/loader/job_run.go | 4 +- core/web/loader/job_spec_errors.go | 3 +- core/web/loader/loader_test.go | 16 +- ...ipeline_job_spec_errors_controller_test.go | 3 +- core/web/pipeline_runs_controller.go | 11 +- core/web/pipeline_runs_controller_test.go | 12 +- core/web/resolver/bridge_test.go | 4 +- .../feeds_manager_chain_config_test.go | 14 +- core/web/resolver/feeds_manager_test.go | 16 +- core/web/resolver/job.go | 4 +- core/web/resolver/job_error_test.go | 14 +- core/web/resolver/job_proposal_spec_test.go | 20 +- core/web/resolver/job_proposal_test.go | 7 +- core/web/resolver/job_run_test.go | 12 +- core/web/resolver/job_test.go | 36 +- core/web/resolver/mutation.go | 26 +- core/web/resolver/query.go | 14 +- core/web/resolver/spec_test.go | 29 +- core/web/sessions_controller_test.go | 26 +- go.mod | 4 +- go.sum | 8 +- integration-tests/go.mod | 4 +- integration-tests/go.sum | 8 +- integration-tests/load/go.mod | 4 +- integration-tests/load/go.sum | 8 +- 184 files changed, 3802 insertions(+), 4592 deletions(-) delete mode 100644 core/services/pg/q.go delete mode 100644 core/services/pg/q_test.go delete mode 100644 core/services/pg/transaction.go delete mode 100644 core/services/pg/utils.go diff --git a/core/bridges/orm_test.go b/core/bridges/orm_test.go index 85e8b9ecdef..b85b6be00dd 100644 --- a/core/bridges/orm_test.go +++ b/core/bridges/orm_test.go @@ -144,7 +144,7 @@ func TestORM_TestCachedResponse(t *testing.T) { orm := bridges.NewORM(db) trORM := pipeline.NewORM(db, logger.TestLogger(t), cfg.JobPipeline().MaxSuccessfulRuns()) - specID, err := trORM.CreateSpec(ctx, nil, pipeline.Pipeline{}, *models.NewInterval(5 * time.Minute)) + specID, err := trORM.CreateSpec(ctx, pipeline.Pipeline{}, *models.NewInterval(5 * time.Minute)) require.NoError(t, err) _, err = orm.GetCachedResponse(ctx, "dot", specID, 1*time.Second) diff --git a/core/chains/evm/log/helpers_test.go b/core/chains/evm/log/helpers_test.go index 0d725b8594b..13aeb8d2338 100644 --- a/core/chains/evm/log/helpers_test.go +++ b/core/chains/evm/log/helpers_test.go @@ -110,7 +110,7 @@ func newBroadcasterHelperWithEthClient(t *testing.T, ethClient evmclient.Client, m[r.Chain().ID().String()] = r.Chain() } legacyChains := legacyevm.NewLegacyChains(m, cc.AppConfig().EVMConfigs()) - pipelineHelper := cltest.NewJobPipelineV2(t, globalConfig.WebServer(), globalConfig.JobPipeline(), globalConfig.Database(), legacyChains, db, kst, nil, nil) + pipelineHelper := cltest.NewJobPipelineV2(t, globalConfig.WebServer(), globalConfig.JobPipeline(), legacyChains, db, kst, nil, nil) return &broadcasterHelper{ t: t, @@ -263,7 +263,7 @@ func (helper *broadcasterHelper) newLogListenerWithJob(name string) *simpleLogLi PipelineSpec: &pipeline.Spec{}, ExternalJobID: uuid.New(), } - err := helper.pipelineHelper.Jrm.CreateJob(jb) + err := helper.pipelineHelper.Jrm.CreateJob(testutils.Context(t), jb) require.NoError(t, err) var rec received @@ -288,7 +288,7 @@ func (listener *simpleLogListener) HandleLog(ctx context.Context, lb log.Broadca listener.received.logs = append(listener.received.logs, lb.RawLog()) listener.received.broadcasts = append(listener.received.broadcasts, lb) - consumed := listener.handleLogBroadcast(lb) + consumed := listener.handleLogBroadcast(ctx, lb) if !consumed { listener.received.uniqueLogs = append(listener.received.uniqueLogs, lb.RawLog()) @@ -321,9 +321,8 @@ func (listener *simpleLogListener) requireAllReceived(t *testing.T, expectedStat }, testutils.WaitTimeout(t), time.Second, "len(received.uniqueLogs): %v is not equal len(expectedState.uniqueLogs): %v", len(received.getUniqueLogs()), len(expectedState.getUniqueLogs())) } -func (listener *simpleLogListener) handleLogBroadcast(lb log.Broadcast) bool { +func (listener *simpleLogListener) handleLogBroadcast(ctx context.Context, lb log.Broadcast) bool { t := listener.t - ctx := testutils.Context(t) consumed, err := listener.WasAlreadyConsumed(ctx, lb) if !assert.NoError(t, err) { return false @@ -354,8 +353,8 @@ type mockListener struct { jobID int32 } -func (l *mockListener) JobID() int32 { return l.jobID } -func (l *mockListener) HandleLog(log.Broadcast) {} +func (l *mockListener) JobID() int32 { return l.jobID } +func (l *mockListener) HandleLog(context.Context, log.Broadcast) {} type mockEthClientExpectedCalls struct { SubscribeFilterLogs int diff --git a/core/chains/evm/log/integration_test.go b/core/chains/evm/log/integration_test.go index fd6b375d80a..e34533b3cfb 100644 --- a/core/chains/evm/log/integration_test.go +++ b/core/chains/evm/log/integration_test.go @@ -250,7 +250,6 @@ func TestBroadcaster_ReplaysLogs(t *testing.T) { func TestBroadcaster_BackfillUnconsumedAfterCrash(t *testing.T) { contract1 := newMockContract(t) contract2 := newMockContract(t) - ctx := testutils.Context(t) blocks := cltest.NewBlocks(t, 10) const ( @@ -267,6 +266,7 @@ func TestBroadcaster_BackfillUnconsumedAfterCrash(t *testing.T) { helper := newBroadcasterHelper(t, 0, 1, logs, func(c *chainlink.Config, s *chainlink.Secrets) { c.EVM[0].FinalityDepth = ptr[uint32](confs) }) + ctx := testutils.Context(t) orm := log.NewORM(helper.db, cltest.FixtureChainID) listener := helper.newLogListenerWithJob("one") @@ -292,6 +292,7 @@ func TestBroadcaster_BackfillUnconsumedAfterCrash(t *testing.T) { helper := newBroadcasterHelper(t, 2, 1, logs, func(c *chainlink.Config, s *chainlink.Secrets) { c.EVM[0].FinalityDepth = ptr[uint32](confs) }) + ctx := testutils.Context(t) orm := log.NewORM(helper.db, cltest.FixtureChainID) contract1.On("ParseLog", log1).Return(flux_aggregator_wrapper.FluxAggregatorNewRound{}, nil) contract2.On("ParseLog", log2).Return(flux_aggregator_wrapper.FluxAggregatorAnswerUpdated{}, nil) @@ -318,6 +319,7 @@ func TestBroadcaster_BackfillUnconsumedAfterCrash(t *testing.T) { helper := newBroadcasterHelper(t, 4, 1, logs, func(c *chainlink.Config, s *chainlink.Secrets) { c.EVM[0].FinalityDepth = ptr[uint32](confs) }) + ctx := testutils.Context(t) orm := log.NewORM(helper.db, cltest.FixtureChainID) listener := helper.newLogListenerWithJob("one") @@ -342,6 +344,7 @@ func TestBroadcaster_BackfillUnconsumedAfterCrash(t *testing.T) { helper := newBroadcasterHelper(t, 7, 1, logs[1:], func(c *chainlink.Config, s *chainlink.Secrets) { c.EVM[0].FinalityDepth = ptr[uint32](confs) }) + ctx := testutils.Context(t) orm := log.NewORM(helper.db, cltest.FixtureChainID) listener := helper.newLogListenerWithJob("one") listener2 := helper.newLogListenerWithJob("two") @@ -377,8 +380,9 @@ func (helper *broadcasterHelper) simulateHeads(t *testing.T, listener, listener2 <-headsDone + ctx := testutils.Context(t) require.Eventually(t, func() bool { - blockNum, err := orm.GetPendingMinBlock(testutils.Context(t)) + blockNum, err := orm.GetPendingMinBlock(ctx) if !assert.NoError(t, err) { return false } diff --git a/core/chains/evm/logpoller/observability.go b/core/chains/evm/logpoller/observability.go index 07a0f58ce78..14dec5274ad 100644 --- a/core/chains/evm/logpoller/observability.go +++ b/core/chains/evm/logpoller/observability.go @@ -76,9 +76,9 @@ type ObservedORM struct { // NewObservedORM creates an observed version of log poller's ORM created by NewORM // Please see ObservedLogPoller for more details on how latencies are measured -func NewObservedORM(chainID *big.Int, db sqlutil.DataSource, lggr logger.Logger) *ObservedORM { +func NewObservedORM(chainID *big.Int, ds sqlutil.DataSource, lggr logger.Logger) *ObservedORM { return &ObservedORM{ - ORM: NewORM(chainID, db, lggr), + ORM: NewORM(chainID, ds, lggr), queryDuration: lpQueryDuration, datasetSize: lpQueryDataSets, logsInserted: lpLogsInserted, diff --git a/core/chains/evm/logpoller/orm.go b/core/chains/evm/logpoller/orm.go index ebba3cffc08..838a38c8ebb 100644 --- a/core/chains/evm/logpoller/orm.go +++ b/core/chains/evm/logpoller/orm.go @@ -59,32 +59,32 @@ type ORM interface { SelectLogsDataWordBetween(ctx context.Context, address common.Address, eventSig common.Hash, wordIndexMin int, wordIndexMax int, wordValue common.Hash, confs Confirmations) ([]Log, error) } -type DbORM struct { +type DSORM struct { chainID *big.Int - db sqlutil.DataSource + ds sqlutil.DataSource lggr logger.Logger } -var _ ORM = &DbORM{} +var _ ORM = &DSORM{} -// NewORM creates an DbORM scoped to chainID. -func NewORM(chainID *big.Int, db sqlutil.DataSource, lggr logger.Logger) *DbORM { - return &DbORM{ +// NewORM creates an DSORM scoped to chainID. +func NewORM(chainID *big.Int, ds sqlutil.DataSource, lggr logger.Logger) *DSORM { + return &DSORM{ chainID: chainID, - db: db, + ds: ds, lggr: lggr, } } -func (o *DbORM) Transaction(ctx context.Context, fn func(*DbORM) error) (err error) { - return sqlutil.Transact(ctx, o.new, o.db, nil, fn) +func (o *DSORM) Transact(ctx context.Context, fn func(*DSORM) error) (err error) { + return sqlutil.Transact(ctx, o.new, o.ds, nil, fn) } -// new returns a NewORM like o, but backed by q. -func (o *DbORM) new(q sqlutil.DataSource) *DbORM { return NewORM(o.chainID, q, o.lggr) } +// new returns a NewORM like o, but backed by ds. +func (o *DSORM) new(ds sqlutil.DataSource) *DSORM { return NewORM(o.chainID, ds, o.lggr) } // InsertBlock is idempotent to support replays. -func (o *DbORM) InsertBlock(ctx context.Context, blockHash common.Hash, blockNumber int64, blockTimestamp time.Time, finalizedBlock int64) error { +func (o *DSORM) InsertBlock(ctx context.Context, blockHash common.Hash, blockNumber int64, blockTimestamp time.Time, finalizedBlock int64) error { args, err := newQueryArgs(o.chainID). withCustomHashArg("block_hash", blockHash). withCustomArg("block_number", blockNumber). @@ -98,12 +98,7 @@ func (o *DbORM) InsertBlock(ctx context.Context, blockHash common.Hash, blockNum (evm_chain_id, block_hash, block_number, block_timestamp, finalized_block_number, created_at) VALUES (:evm_chain_id, :block_hash, :block_number, :block_timestamp, :finalized_block_number, NOW()) ON CONFLICT DO NOTHING` - query, sqlArgs, err := o.db.BindNamed(query, args) - if err != nil { - return err - } - - _, err = o.db.ExecContext(ctx, query, sqlArgs...) + _, err = o.ds.NamedExecContext(ctx, query, args) return err } @@ -111,7 +106,7 @@ func (o *DbORM) InsertBlock(ctx context.Context, blockHash common.Hash, blockNum // // Each address/event pair must have a unique job id, so it may be removed when the job is deleted. // If a second job tries to overwrite the same pair, this should fail. -func (o *DbORM) InsertFilter(ctx context.Context, filter Filter) (err error) { +func (o *DSORM) InsertFilter(ctx context.Context, filter Filter) (err error) { topicArrays := []types.HashArray{filter.Topic2, filter.Topic3, filter.Topic4} args, err := newQueryArgs(o.chainID). withCustomArg("name", filter.Name). @@ -148,18 +143,13 @@ func (o *DbORM) InsertFilter(ctx context.Context, filter Filter) (err error) { topicsColumns.String(), topicsSql.String()) - query, sqlArgs, err := o.db.BindNamed(query, args) - if err != nil { - return err - } - - _, err = o.db.ExecContext(ctx, query, sqlArgs...) + _, err = o.ds.NamedExecContext(ctx, query, args) return err } // DeleteFilter removes all events,address pairs associated with the Filter -func (o *DbORM) DeleteFilter(ctx context.Context, name string) error { - _, err := o.db.ExecContext(ctx, +func (o *DSORM) DeleteFilter(ctx context.Context, name string) error { + _, err := o.ds.ExecContext(ctx, `DELETE FROM evm.log_poller_filters WHERE name = $1 AND evm_chain_id = $2`, name, ubig.New(o.chainID)) return err @@ -167,7 +157,7 @@ func (o *DbORM) DeleteFilter(ctx context.Context, name string) error { } // LoadFilters returns all filters for this chain -func (o *DbORM) LoadFilters(ctx context.Context) (map[string]Filter, error) { +func (o *DSORM) LoadFilters(ctx context.Context) (map[string]Filter, error) { query := `SELECT name, ARRAY_AGG(DISTINCT address)::BYTEA[] AS addresses, ARRAY_AGG(DISTINCT event)::BYTEA[] AS event_sigs, @@ -180,7 +170,7 @@ func (o *DbORM) LoadFilters(ctx context.Context) (map[string]Filter, error) { FROM evm.log_poller_filters WHERE evm_chain_id = $1 GROUP BY name` var rows []Filter - err := o.db.SelectContext(ctx, &rows, query, ubig.New(o.chainID)) + err := o.ds.SelectContext(ctx, &rows, query, ubig.New(o.chainID)) filters := make(map[string]Filter) for _, filter := range rows { filters[filter.Name] = filter @@ -188,31 +178,31 @@ func (o *DbORM) LoadFilters(ctx context.Context) (map[string]Filter, error) { return filters, err } -func (o *DbORM) SelectBlockByHash(ctx context.Context, hash common.Hash) (*LogPollerBlock, error) { +func (o *DSORM) SelectBlockByHash(ctx context.Context, hash common.Hash) (*LogPollerBlock, error) { var b LogPollerBlock - if err := o.db.GetContext(ctx, &b, `SELECT * FROM evm.log_poller_blocks WHERE block_hash = $1 AND evm_chain_id = $2`, hash.Bytes(), ubig.New(o.chainID)); err != nil { + if err := o.ds.GetContext(ctx, &b, `SELECT * FROM evm.log_poller_blocks WHERE block_hash = $1 AND evm_chain_id = $2`, hash.Bytes(), ubig.New(o.chainID)); err != nil { return nil, err } return &b, nil } -func (o *DbORM) SelectBlockByNumber(ctx context.Context, n int64) (*LogPollerBlock, error) { +func (o *DSORM) SelectBlockByNumber(ctx context.Context, n int64) (*LogPollerBlock, error) { var b LogPollerBlock - if err := o.db.GetContext(ctx, &b, `SELECT * FROM evm.log_poller_blocks WHERE block_number = $1 AND evm_chain_id = $2`, n, ubig.New(o.chainID)); err != nil { + if err := o.ds.GetContext(ctx, &b, `SELECT * FROM evm.log_poller_blocks WHERE block_number = $1 AND evm_chain_id = $2`, n, ubig.New(o.chainID)); err != nil { return nil, err } return &b, nil } -func (o *DbORM) SelectLatestBlock(ctx context.Context) (*LogPollerBlock, error) { +func (o *DSORM) SelectLatestBlock(ctx context.Context) (*LogPollerBlock, error) { var b LogPollerBlock - if err := o.db.GetContext(ctx, &b, `SELECT * FROM evm.log_poller_blocks WHERE evm_chain_id = $1 ORDER BY block_number DESC LIMIT 1`, ubig.New(o.chainID)); err != nil { + if err := o.ds.GetContext(ctx, &b, `SELECT * FROM evm.log_poller_blocks WHERE evm_chain_id = $1 ORDER BY block_number DESC LIMIT 1`, ubig.New(o.chainID)); err != nil { return nil, err } return &b, nil } -func (o *DbORM) SelectLatestLogByEventSigWithConfs(ctx context.Context, eventSig common.Hash, address common.Address, confs Confirmations) (*Log, error) { +func (o *DSORM) SelectLatestLogByEventSigWithConfs(ctx context.Context, eventSig common.Hash, address common.Address, confs Confirmations) (*Log, error) { args, err := newQueryArgsForEvent(o.chainID, address, eventSig). withConfs(confs). toArgs() @@ -228,11 +218,11 @@ func (o *DbORM) SelectLatestLogByEventSigWithConfs(ctx context.Context, eventSig ORDER BY (block_number, log_index) DESC LIMIT 1`, nestedBlockNumberQuery(confs)) var l Log - query, sqlArgs, err := o.db.BindNamed(query, args) + query, sqlArgs, err := o.ds.BindNamed(query, args) if err != nil { return nil, err } - if err = o.db.GetContext(ctx, &l, query, sqlArgs...); err != nil { + if err = o.ds.GetContext(ctx, &l, query, sqlArgs...); err != nil { return nil, err } return &l, nil @@ -240,9 +230,9 @@ func (o *DbORM) SelectLatestLogByEventSigWithConfs(ctx context.Context, eventSig // DeleteBlocksBefore delete blocks before and including end. When limit is set, it will delete at most limit blocks. // Otherwise, it will delete all blocks at once. -func (o *DbORM) DeleteBlocksBefore(ctx context.Context, end int64, limit int64) (int64, error) { +func (o *DSORM) DeleteBlocksBefore(ctx context.Context, end int64, limit int64) (int64, error) { if limit > 0 { - result, err := o.db.ExecContext(ctx, + result, err := o.ds.ExecContext(ctx, `DELETE FROM evm.log_poller_blocks WHERE block_number IN ( SELECT block_number FROM evm.log_poller_blocks @@ -257,7 +247,7 @@ func (o *DbORM) DeleteBlocksBefore(ctx context.Context, end int64, limit int64) } return result.RowsAffected() } - result, err := o.db.ExecContext(ctx, `DELETE FROM evm.log_poller_blocks + result, err := o.ds.ExecContext(ctx, `DELETE FROM evm.log_poller_blocks WHERE block_number <= $1 AND evm_chain_id = $2`, end, ubig.New(o.chainID)) if err != nil { return 0, err @@ -265,16 +255,16 @@ func (o *DbORM) DeleteBlocksBefore(ctx context.Context, end int64, limit int64) return result.RowsAffected() } -func (o *DbORM) DeleteLogsAndBlocksAfter(ctx context.Context, start int64) error { +func (o *DSORM) DeleteLogsAndBlocksAfter(ctx context.Context, start int64) error { // These deletes are bounded by reorg depth, so they are // fast and should not slow down the log readers. - return o.Transaction(ctx, func(orm *DbORM) error { + return o.Transact(ctx, func(orm *DSORM) error { // Applying upper bound filter is critical for Postgres performance (especially for evm.logs table) // because it allows the planner to properly estimate the number of rows to be scanned. // If not applied, these queries can become very slow. After some critical number // of logs, Postgres will try to scan all the logs in the index by block_number. // Latency without upper bound filter can be orders of magnitude higher for large number of logs. - _, err := o.db.ExecContext(ctx, `DELETE FROM evm.log_poller_blocks + _, err := o.ds.ExecContext(ctx, `DELETE FROM evm.log_poller_blocks WHERE evm_chain_id = $1 AND block_number >= $2 AND block_number <= (SELECT MAX(block_number) @@ -286,7 +276,7 @@ func (o *DbORM) DeleteLogsAndBlocksAfter(ctx context.Context, start int64) error return err } - _, err = o.db.ExecContext(ctx, `DELETE FROM evm.logs + _, err = o.ds.ExecContext(ctx, `DELETE FROM evm.logs WHERE evm_chain_id = $1 AND block_number >= $2 AND block_number <= (SELECT MAX(block_number) FROM evm.logs WHERE evm_chain_id = $1)`, @@ -307,11 +297,11 @@ type Exp struct { ShouldDelete bool } -func (o *DbORM) DeleteExpiredLogs(ctx context.Context, limit int64) (int64, error) { +func (o *DSORM) DeleteExpiredLogs(ctx context.Context, limit int64) (int64, error) { var err error var result sql.Result if limit > 0 { - result, err = o.db.ExecContext(ctx, ` + result, err = o.ds.ExecContext(ctx, ` DELETE FROM evm.logs WHERE (evm_chain_id, address, event_sig, block_number) IN ( SELECT l.evm_chain_id, l.address, l.event_sig, l.block_number @@ -327,7 +317,7 @@ func (o *DbORM) DeleteExpiredLogs(ctx context.Context, limit int64) (int64, erro LIMIT $2 )`, ubig.New(o.chainID), limit) } else { - result, err = o.db.ExecContext(ctx, `WITH r AS + result, err = o.ds.ExecContext(ctx, `WITH r AS ( SELECT address, event, MAX(retention) AS retention FROM evm.log_poller_filters WHERE evm_chain_id=$1 GROUP BY evm_chain_id,address, event HAVING NOT 0 = ANY(ARRAY_AGG(retention)) @@ -344,16 +334,16 @@ func (o *DbORM) DeleteExpiredLogs(ctx context.Context, limit int64) (int64, erro } // InsertLogs is idempotent to support replays. -func (o *DbORM) InsertLogs(ctx context.Context, logs []Log) error { +func (o *DSORM) InsertLogs(ctx context.Context, logs []Log) error { if err := o.validateLogs(logs); err != nil { return err } - return o.Transaction(ctx, func(orm *DbORM) error { - return orm.insertLogsWithinTx(ctx, logs, orm.db) + return o.Transact(ctx, func(orm *DSORM) error { + return orm.insertLogsWithinTx(ctx, logs, orm.ds) }) } -func (o *DbORM) InsertLogsWithBlock(ctx context.Context, logs []Log, block LogPollerBlock) error { +func (o *DSORM) InsertLogsWithBlock(ctx context.Context, logs []Log, block LogPollerBlock) error { // Optimization, don't open TX when there is only a block to be persisted if len(logs) == 0 { return o.InsertBlock(ctx, block.BlockHash, block.BlockNumber, block.BlockTimestamp, block.FinalizedBlockNumber) @@ -364,16 +354,16 @@ func (o *DbORM) InsertLogsWithBlock(ctx context.Context, logs []Log, block LogPo } // Block and logs goes with the same TX to ensure atomicity - return o.Transaction(ctx, func(orm *DbORM) error { + return o.Transact(ctx, func(orm *DSORM) error { err := orm.InsertBlock(ctx, block.BlockHash, block.BlockNumber, block.BlockTimestamp, block.FinalizedBlockNumber) if err != nil { return err } - return orm.insertLogsWithinTx(ctx, logs, orm.db) + return orm.insertLogsWithinTx(ctx, logs, orm.ds) }) } -func (o *DbORM) insertLogsWithinTx(ctx context.Context, logs []Log, tx sqlutil.DataSource) error { +func (o *DSORM) insertLogsWithinTx(ctx context.Context, logs []Log, tx sqlutil.DataSource) error { batchInsertSize := 4000 for i := 0; i < len(logs); i += batchInsertSize { start, end := i, i+batchInsertSize @@ -387,12 +377,10 @@ func (o *DbORM) insertLogsWithinTx(ctx context.Context, logs []Log, tx sqlutil.D (:evm_chain_id, :log_index, :block_hash, :block_number, :block_timestamp, :address, :event_sig, :topics, :tx_hash, :data, NOW()) ON CONFLICT DO NOTHING` - query, sqlArgs, err := o.db.BindNamed(query, logs[start:end]) + _, err := o.ds.NamedExecContext(ctx, query, logs[start:end]) if err != nil { return err } - - _, err = tx.ExecContext(ctx, query, sqlArgs...) if err != nil { if pkgerrors.Is(err, context.DeadlineExceeded) && batchInsertSize > 500 { // In case of DB timeouts, try to insert again with a smaller batch upto a limit @@ -406,7 +394,7 @@ func (o *DbORM) insertLogsWithinTx(ctx context.Context, logs []Log, tx sqlutil.D return nil } -func (o *DbORM) validateLogs(logs []Log) error { +func (o *DSORM) validateLogs(logs []Log) error { for _, log := range logs { if o.chainID.Cmp(log.EvmChainId.ToInt()) != 0 { return pkgerrors.Errorf("invalid chainID in log got %v want %v", log.EvmChainId.ToInt(), o.chainID) @@ -415,7 +403,7 @@ func (o *DbORM) validateLogs(logs []Log) error { return nil } -func (o *DbORM) SelectLogsByBlockRange(ctx context.Context, start, end int64) ([]Log, error) { +func (o *DSORM) SelectLogsByBlockRange(ctx context.Context, start, end int64) ([]Log, error) { args, err := newQueryArgs(o.chainID). withStartBlock(start). withEndBlock(end). @@ -431,12 +419,12 @@ func (o *DbORM) SelectLogsByBlockRange(ctx context.Context, start, end int64) ([ ORDER BY (block_number, log_index)` var logs []Log - query, sqlArgs, err := o.db.BindNamed(query, args) + query, sqlArgs, err := o.ds.BindNamed(query, args) if err != nil { return nil, err } - err = o.db.SelectContext(ctx, &logs, query, sqlArgs...) + err = o.ds.SelectContext(ctx, &logs, query, sqlArgs...) if err != nil { return nil, err } @@ -444,7 +432,7 @@ func (o *DbORM) SelectLogsByBlockRange(ctx context.Context, start, end int64) ([ } // SelectLogs finds the logs in a given block range. -func (o *DbORM) SelectLogs(ctx context.Context, start, end int64, address common.Address, eventSig common.Hash) ([]Log, error) { +func (o *DSORM) SelectLogs(ctx context.Context, start, end int64, address common.Address, eventSig common.Hash) ([]Log, error) { args, err := newQueryArgsForEvent(o.chainID, address, eventSig). withStartBlock(start). withEndBlock(end). @@ -462,12 +450,12 @@ func (o *DbORM) SelectLogs(ctx context.Context, start, end int64, address common ORDER BY (block_number, log_index)` var logs []Log - query, sqlArgs, err := o.db.BindNamed(query, args) + query, sqlArgs, err := o.ds.BindNamed(query, args) if err != nil { return nil, err } - err = o.db.SelectContext(ctx, &logs, query, sqlArgs...) + err = o.ds.SelectContext(ctx, &logs, query, sqlArgs...) if err != nil { return nil, err } @@ -475,7 +463,7 @@ func (o *DbORM) SelectLogs(ctx context.Context, start, end int64, address common } // SelectLogsCreatedAfter finds logs created after some timestamp. -func (o *DbORM) SelectLogsCreatedAfter(ctx context.Context, address common.Address, eventSig common.Hash, after time.Time, confs Confirmations) ([]Log, error) { +func (o *DSORM) SelectLogsCreatedAfter(ctx context.Context, address common.Address, eventSig common.Hash, after time.Time, confs Confirmations) ([]Log, error) { args, err := newQueryArgsForEvent(o.chainID, address, eventSig). withBlockTimestampAfter(after). withConfs(confs). @@ -494,12 +482,12 @@ func (o *DbORM) SelectLogsCreatedAfter(ctx context.Context, address common.Addre ORDER BY (block_number, log_index)`, nestedBlockNumberQuery(confs)) var logs []Log - query, sqlArgs, err := o.db.BindNamed(query, args) + query, sqlArgs, err := o.ds.BindNamed(query, args) if err != nil { return nil, err } - if err = o.db.SelectContext(ctx, &logs, query, sqlArgs...); err != nil { + if err = o.ds.SelectContext(ctx, &logs, query, sqlArgs...); err != nil { return nil, err } return logs, nil @@ -507,7 +495,7 @@ func (o *DbORM) SelectLogsCreatedAfter(ctx context.Context, address common.Addre // SelectLogsWithSigs finds the logs in the given block range with the given event signatures // emitted from the given address. -func (o *DbORM) SelectLogsWithSigs(ctx context.Context, start, end int64, address common.Address, eventSigs []common.Hash) (logs []Log, err error) { +func (o *DSORM) SelectLogsWithSigs(ctx context.Context, start, end int64, address common.Address, eventSigs []common.Hash) (logs []Log, err error) { args, err := newQueryArgs(o.chainID). withAddress(address). withEventSigArray(eventSigs). @@ -525,19 +513,19 @@ func (o *DbORM) SelectLogsWithSigs(ctx context.Context, start, end int64, addres AND block_number BETWEEN :start_block AND :end_block ORDER BY (block_number, log_index)` - query, sqlArgs, err := o.db.BindNamed(query, args) + query, sqlArgs, err := o.ds.BindNamed(query, args) if err != nil { return nil, err } - err = o.db.SelectContext(ctx, &logs, query, sqlArgs...) + err = o.ds.SelectContext(ctx, &logs, query, sqlArgs...) if pkgerrors.Is(err, sql.ErrNoRows) { return nil, nil } return logs, err } -func (o *DbORM) GetBlocksRange(ctx context.Context, start int64, end int64) ([]LogPollerBlock, error) { +func (o *DSORM) GetBlocksRange(ctx context.Context, start int64, end int64) ([]LogPollerBlock, error) { args, err := newQueryArgs(o.chainID). withStartBlock(start). withEndBlock(end). @@ -553,12 +541,12 @@ func (o *DbORM) GetBlocksRange(ctx context.Context, start int64, end int64) ([]L ORDER BY block_number ASC` var blocks []LogPollerBlock - query, sqlArgs, err := o.db.BindNamed(query, args) + query, sqlArgs, err := o.ds.BindNamed(query, args) if err != nil { return nil, err } - err = o.db.SelectContext(ctx, &blocks, query, sqlArgs...) + err = o.ds.SelectContext(ctx, &blocks, query, sqlArgs...) if err != nil { return nil, err } @@ -566,7 +554,7 @@ func (o *DbORM) GetBlocksRange(ctx context.Context, start int64, end int64) ([]L } // SelectLatestLogEventSigsAddrsWithConfs finds the latest log by (address, event) combination that matches a list of Addresses and list of events -func (o *DbORM) SelectLatestLogEventSigsAddrsWithConfs(ctx context.Context, fromBlock int64, addresses []common.Address, eventSigs []common.Hash, confs Confirmations) ([]Log, error) { +func (o *DSORM) SelectLatestLogEventSigsAddrsWithConfs(ctx context.Context, fromBlock int64, addresses []common.Address, eventSigs []common.Hash, confs Confirmations) ([]Log, error) { args, err := newQueryArgs(o.chainID). withAddressArray(addresses). withEventSigArray(eventSigs). @@ -590,19 +578,19 @@ func (o *DbORM) SelectLatestLogEventSigsAddrsWithConfs(ctx context.Context, from ORDER BY block_number ASC`, nestedBlockNumberQuery(confs)) var logs []Log - query, sqlArgs, err := o.db.BindNamed(query, args) + query, sqlArgs, err := o.ds.BindNamed(query, args) if err != nil { return nil, err } - if err = o.db.SelectContext(ctx, &logs, query, sqlArgs...); err != nil { + if err = o.ds.SelectContext(ctx, &logs, query, sqlArgs...); err != nil { return nil, pkgerrors.Wrap(err, "failed to execute query") } return logs, nil } // SelectLatestBlockByEventSigsAddrsWithConfs finds the latest block number that matches a list of Addresses and list of events. It returns 0 if there is no matching block -func (o *DbORM) SelectLatestBlockByEventSigsAddrsWithConfs(ctx context.Context, fromBlock int64, eventSigs []common.Hash, addresses []common.Address, confs Confirmations) (int64, error) { +func (o *DSORM) SelectLatestBlockByEventSigsAddrsWithConfs(ctx context.Context, fromBlock int64, eventSigs []common.Hash, addresses []common.Address, confs Confirmations) (int64, error) { args, err := newQueryArgs(o.chainID). withEventSigArray(eventSigs). withAddressArray(addresses). @@ -621,18 +609,18 @@ func (o *DbORM) SelectLatestBlockByEventSigsAddrsWithConfs(ctx context.Context, AND block_number <= %s`, nestedBlockNumberQuery(confs)) var blockNumber int64 - query, sqlArgs, err := o.db.BindNamed(query, args) + query, sqlArgs, err := o.ds.BindNamed(query, args) if err != nil { return 0, err } - if err = o.db.GetContext(ctx, &blockNumber, query, sqlArgs...); err != nil { + if err = o.ds.GetContext(ctx, &blockNumber, query, sqlArgs...); err != nil { return 0, err } return blockNumber, nil } -func (o *DbORM) SelectLogsDataWordRange(ctx context.Context, address common.Address, eventSig common.Hash, wordIndex int, wordValueMin, wordValueMax common.Hash, confs Confirmations) ([]Log, error) { +func (o *DSORM) SelectLogsDataWordRange(ctx context.Context, address common.Address, eventSig common.Hash, wordIndex int, wordValueMin, wordValueMax common.Hash, confs Confirmations) ([]Log, error) { args, err := newQueryArgsForEvent(o.chainID, address, eventSig). withWordIndex(wordIndex). withWordValueMin(wordValueMin). @@ -653,18 +641,18 @@ func (o *DbORM) SelectLogsDataWordRange(ctx context.Context, address common.Addr ORDER BY (block_number, log_index)`, nestedBlockNumberQuery(confs)) var logs []Log - query, sqlArgs, err := o.db.BindNamed(query, args) + query, sqlArgs, err := o.ds.BindNamed(query, args) if err != nil { return nil, err } - if err := o.db.SelectContext(ctx, &logs, query, sqlArgs...); err != nil { + if err := o.ds.SelectContext(ctx, &logs, query, sqlArgs...); err != nil { return nil, err } return logs, nil } -func (o *DbORM) SelectLogsDataWordGreaterThan(ctx context.Context, address common.Address, eventSig common.Hash, wordIndex int, wordValueMin common.Hash, confs Confirmations) ([]Log, error) { +func (o *DSORM) SelectLogsDataWordGreaterThan(ctx context.Context, address common.Address, eventSig common.Hash, wordIndex int, wordValueMin common.Hash, confs Confirmations) ([]Log, error) { args, err := newQueryArgsForEvent(o.chainID, address, eventSig). withWordIndex(wordIndex). withWordValueMin(wordValueMin). @@ -684,18 +672,18 @@ func (o *DbORM) SelectLogsDataWordGreaterThan(ctx context.Context, address commo ORDER BY (block_number, log_index)`, nestedBlockNumberQuery(confs)) var logs []Log - query, sqlArgs, err := o.db.BindNamed(query, args) + query, sqlArgs, err := o.ds.BindNamed(query, args) if err != nil { return nil, err } - if err := o.db.SelectContext(ctx, &logs, query, sqlArgs...); err != nil { + if err := o.ds.SelectContext(ctx, &logs, query, sqlArgs...); err != nil { return nil, err } return logs, nil } -func (o *DbORM) SelectLogsDataWordBetween(ctx context.Context, address common.Address, eventSig common.Hash, wordIndexMin int, wordIndexMax int, wordValue common.Hash, confs Confirmations) ([]Log, error) { +func (o *DSORM) SelectLogsDataWordBetween(ctx context.Context, address common.Address, eventSig common.Hash, wordIndexMin int, wordIndexMax int, wordValue common.Hash, confs Confirmations) ([]Log, error) { args, err := newQueryArgsForEvent(o.chainID, address, eventSig). withWordIndexMin(wordIndexMin). withWordIndexMax(wordIndexMax). @@ -716,18 +704,18 @@ func (o *DbORM) SelectLogsDataWordBetween(ctx context.Context, address common.Ad ORDER BY (block_number, log_index)`, nestedBlockNumberQuery(confs)) var logs []Log - query, sqlArgs, err := o.db.BindNamed(query, args) + query, sqlArgs, err := o.ds.BindNamed(query, args) if err != nil { return nil, err } - if err := o.db.SelectContext(ctx, &logs, query, sqlArgs...); err != nil { + if err := o.ds.SelectContext(ctx, &logs, query, sqlArgs...); err != nil { return nil, err } return logs, nil } -func (o *DbORM) SelectIndexedLogsTopicGreaterThan(ctx context.Context, address common.Address, eventSig common.Hash, topicIndex int, topicValueMin common.Hash, confs Confirmations) ([]Log, error) { +func (o *DSORM) SelectIndexedLogsTopicGreaterThan(ctx context.Context, address common.Address, eventSig common.Hash, topicIndex int, topicValueMin common.Hash, confs Confirmations) ([]Log, error) { args, err := newQueryArgsForEvent(o.chainID, address, eventSig). withTopicIndex(topicIndex). withTopicValueMin(topicValueMin). @@ -747,18 +735,18 @@ func (o *DbORM) SelectIndexedLogsTopicGreaterThan(ctx context.Context, address c ORDER BY (block_number, log_index)`, nestedBlockNumberQuery(confs)) var logs []Log - query, sqlArgs, err := o.db.BindNamed(query, args) + query, sqlArgs, err := o.ds.BindNamed(query, args) if err != nil { return nil, err } - if err := o.db.SelectContext(ctx, &logs, query, sqlArgs...); err != nil { + if err := o.ds.SelectContext(ctx, &logs, query, sqlArgs...); err != nil { return nil, err } return logs, nil } -func (o *DbORM) SelectIndexedLogsTopicRange(ctx context.Context, address common.Address, eventSig common.Hash, topicIndex int, topicValueMin, topicValueMax common.Hash, confs Confirmations) ([]Log, error) { +func (o *DSORM) SelectIndexedLogsTopicRange(ctx context.Context, address common.Address, eventSig common.Hash, topicIndex int, topicValueMin, topicValueMax common.Hash, confs Confirmations) ([]Log, error) { args, err := newQueryArgsForEvent(o.chainID, address, eventSig). withTopicIndex(topicIndex). withTopicValueMin(topicValueMin). @@ -780,18 +768,18 @@ func (o *DbORM) SelectIndexedLogsTopicRange(ctx context.Context, address common. ORDER BY (evm.logs.block_number, evm.logs.log_index)`, nestedBlockNumberQuery(confs)) var logs []Log - query, sqlArgs, err := o.db.BindNamed(query, args) + query, sqlArgs, err := o.ds.BindNamed(query, args) if err != nil { return nil, err } - if err := o.db.SelectContext(ctx, &logs, query, sqlArgs...); err != nil { + if err := o.ds.SelectContext(ctx, &logs, query, sqlArgs...); err != nil { return nil, err } return logs, nil } -func (o *DbORM) SelectIndexedLogs(ctx context.Context, address common.Address, eventSig common.Hash, topicIndex int, topicValues []common.Hash, confs Confirmations) ([]Log, error) { +func (o *DSORM) SelectIndexedLogs(ctx context.Context, address common.Address, eventSig common.Hash, topicIndex int, topicValues []common.Hash, confs Confirmations) ([]Log, error) { args, err := newQueryArgsForEvent(o.chainID, address, eventSig). withTopicIndex(topicIndex). withTopicValues(topicValues). @@ -811,19 +799,19 @@ func (o *DbORM) SelectIndexedLogs(ctx context.Context, address common.Address, e ORDER BY (block_number, log_index)`, nestedBlockNumberQuery(confs)) var logs []Log - query, sqlArgs, err := o.db.BindNamed(query, args) + query, sqlArgs, err := o.ds.BindNamed(query, args) if err != nil { return nil, err } - if err := o.db.SelectContext(ctx, &logs, query, sqlArgs...); err != nil { + if err := o.ds.SelectContext(ctx, &logs, query, sqlArgs...); err != nil { return nil, err } return logs, nil } // SelectIndexedLogsByBlockRange finds the indexed logs in a given block range. -func (o *DbORM) SelectIndexedLogsByBlockRange(ctx context.Context, start, end int64, address common.Address, eventSig common.Hash, topicIndex int, topicValues []common.Hash) ([]Log, error) { +func (o *DSORM) SelectIndexedLogsByBlockRange(ctx context.Context, start, end int64, address common.Address, eventSig common.Hash, topicIndex int, topicValues []common.Hash) ([]Log, error) { args, err := newQueryArgsForEvent(o.chainID, address, eventSig). withTopicIndex(topicIndex). withTopicValues(topicValues). @@ -844,19 +832,19 @@ func (o *DbORM) SelectIndexedLogsByBlockRange(ctx context.Context, start, end in ORDER BY (block_number, log_index)` var logs []Log - query, sqlArgs, err := o.db.BindNamed(query, args) + query, sqlArgs, err := o.ds.BindNamed(query, args) if err != nil { return nil, err } - err = o.db.SelectContext(ctx, &logs, query, sqlArgs...) + err = o.ds.SelectContext(ctx, &logs, query, sqlArgs...) if err != nil { return nil, err } return logs, nil } -func (o *DbORM) SelectIndexedLogsCreatedAfter(ctx context.Context, address common.Address, eventSig common.Hash, topicIndex int, topicValues []common.Hash, after time.Time, confs Confirmations) ([]Log, error) { +func (o *DSORM) SelectIndexedLogsCreatedAfter(ctx context.Context, address common.Address, eventSig common.Hash, topicIndex int, topicValues []common.Hash, after time.Time, confs Confirmations) ([]Log, error) { args, err := newQueryArgsForEvent(o.chainID, address, eventSig). withBlockTimestampAfter(after). withConfs(confs). @@ -878,18 +866,18 @@ func (o *DbORM) SelectIndexedLogsCreatedAfter(ctx context.Context, address commo ORDER BY (block_number, log_index)`, nestedBlockNumberQuery(confs)) var logs []Log - query, sqlArgs, err := o.db.BindNamed(query, args) + query, sqlArgs, err := o.ds.BindNamed(query, args) if err != nil { return nil, err } - if err := o.db.SelectContext(ctx, &logs, query, sqlArgs...); err != nil { + if err := o.ds.SelectContext(ctx, &logs, query, sqlArgs...); err != nil { return nil, err } return logs, nil } -func (o *DbORM) SelectIndexedLogsByTxHash(ctx context.Context, address common.Address, eventSig common.Hash, txHash common.Hash) ([]Log, error) { +func (o *DSORM) SelectIndexedLogsByTxHash(ctx context.Context, address common.Address, eventSig common.Hash, txHash common.Hash) ([]Log, error) { args, err := newQueryArgs(o.chainID). withTxHash(txHash). withAddress(address). @@ -907,12 +895,12 @@ func (o *DbORM) SelectIndexedLogsByTxHash(ctx context.Context, address common.Ad ORDER BY (block_number, log_index)` var logs []Log - query, sqlArgs, err := o.db.BindNamed(query, args) + query, sqlArgs, err := o.ds.BindNamed(query, args) if err != nil { return nil, err } - err = o.db.SelectContext(ctx, &logs, query, sqlArgs...) + err = o.ds.SelectContext(ctx, &logs, query, sqlArgs...) if err != nil { return nil, err } @@ -920,7 +908,7 @@ func (o *DbORM) SelectIndexedLogsByTxHash(ctx context.Context, address common.Ad } // SelectIndexedLogsWithSigsExcluding query's for logs that have signature A and exclude logs that have a corresponding signature B, matching is done based on the topic index both logs should be inside the block range and have the minimum number of confirmations -func (o *DbORM) SelectIndexedLogsWithSigsExcluding(ctx context.Context, sigA, sigB common.Hash, topicIndex int, address common.Address, startBlock, endBlock int64, confs Confirmations) ([]Log, error) { +func (o *DSORM) SelectIndexedLogsWithSigsExcluding(ctx context.Context, sigA, sigB common.Hash, topicIndex int, address common.Address, startBlock, endBlock int64, confs Confirmations) ([]Log, error) { args, err := newQueryArgs(o.chainID). withAddress(address). withTopicIndex(topicIndex). @@ -955,12 +943,12 @@ func (o *DbORM) SelectIndexedLogsWithSigsExcluding(ctx context.Context, sigA, si ORDER BY block_number,log_index ASC`, nestedQuery, nestedQuery) var logs []Log - query, sqlArgs, err := o.db.BindNamed(query, args) + query, sqlArgs, err := o.ds.BindNamed(query, args) if err != nil { return nil, err } - if err := o.db.SelectContext(ctx, &logs, query, sqlArgs...); err != nil { + if err := o.ds.SelectContext(ctx, &logs, query, sqlArgs...); err != nil { return nil, err } return logs, nil diff --git a/core/chains/evm/logpoller/query.go b/core/chains/evm/logpoller/query.go index 6aabe59045d..f9d2c45bce1 100644 --- a/core/chains/evm/logpoller/query.go +++ b/core/chains/evm/logpoller/query.go @@ -24,7 +24,7 @@ func concatBytes[T bytesProducer](byteSlice []T) [][]byte { return output } -// queryArgs is a helper for building the arguments to a postgres query created by DbORM +// queryArgs is a helper for building the arguments to a postgres query created by DSORM // Besides the convenience methods, it also keeps track of arguments validation and sanitization. type queryArgs struct { args map[string]interface{} diff --git a/core/chains/evm/txmgr/broadcaster_test.go b/core/chains/evm/txmgr/broadcaster_test.go index 4e19a2ec7da..20c069a46d6 100644 --- a/core/chains/evm/txmgr/broadcaster_test.go +++ b/core/chains/evm/txmgr/broadcaster_test.go @@ -1112,7 +1112,7 @@ func TestEthBroadcaster_ProcessUnstartedEthTxs_Errors(t *testing.T) { } t.Run("with erroring callback bails out", func(t *testing.T) { - require.NoError(t, txStore.InsertTx(ctx, &etx)) + require.NoError(t, txStore.InsertTx(testutils.Context(t), &etx)) fn := func(ctx context.Context, id uuid.UUID, result interface{}, err error) error { return errors.New("something exploded in the callback") } diff --git a/core/chains/evm/txmgr/builder.go b/core/chains/evm/txmgr/builder.go index f13efb2b258..0671f49bb74 100644 --- a/core/chains/evm/txmgr/builder.go +++ b/core/chains/evm/txmgr/builder.go @@ -4,8 +4,6 @@ import ( "math/big" "time" - "github.com/jmoiron/sqlx" - "github.com/smartcontractkit/chainlink-common/pkg/logger" "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" "github.com/smartcontractkit/chainlink/v2/common/txmgr" @@ -21,8 +19,7 @@ import ( // NewTxm constructs the necessary dependencies for the EvmTxm (broadcaster, confirmer, etc) and returns a new EvmTxManager func NewTxm( - sqlxDB *sqlx.DB, - db sqlutil.DataSource, + ds sqlutil.DataSource, chainConfig ChainConfig, fCfg FeeConfig, txConfig config.Transactions, @@ -40,14 +37,14 @@ func NewTxm( var fwdMgr FwdMgr if txConfig.ForwardersEnabled() { - fwdMgr = forwarders.NewFwdMgr(db, client, logPoller, lggr, chainConfig) + fwdMgr = forwarders.NewFwdMgr(ds, client, logPoller, lggr, chainConfig) } else { lggr.Info("EvmForwarderManager: Disabled") } checker := &CheckerFactory{Client: client} // create tx attempt builder txAttemptBuilder := NewEvmTxAttemptBuilder(*client.ConfiguredChainID(), fCfg, keyStore, estimator) - txStore := NewTxStore(sqlxDB, lggr) + txStore := NewTxStore(ds, lggr) txmCfg := NewEvmTxmConfig(chainConfig) // wrap Evm specific config feeCfg := NewEvmTxmFeeConfig(fCfg) // wrap Evm specific config txmClient := NewEvmTxmClient(client, clientErrors) // wrap Evm specific client diff --git a/core/chains/evm/txmgr/evm_tx_store.go b/core/chains/evm/txmgr/evm_tx_store.go index c8e664e8cfe..dedba07b594 100644 --- a/core/chains/evm/txmgr/evm_tx_store.go +++ b/core/chains/evm/txmgr/evm_tx_store.go @@ -115,7 +115,7 @@ func DbReceiptToEvmReceipt(receipt *dbReceipt) *evmtypes.Receipt { // Directly maps to onchain receipt schema. type rawOnchainReceipt = evmtypes.Receipt -func (o *evmTxStore) Transaction(ctx context.Context, readOnly bool, fn func(*evmTxStore) error) (err error) { +func (o *evmTxStore) Transact(ctx context.Context, readOnly bool, fn func(*evmTxStore) error) (err error) { opts := &sqlutil.TxOptions{TxOptions: sql.TxOptions{ReadOnly: readOnly}} return sqlutil.Transact(ctx, o.new, o.q, opts, fn) } @@ -509,7 +509,7 @@ func (o *evmTxStore) FindTxAttemptsByTxIDs(ctx context.Context, ids []int64) ([] func (o *evmTxStore) FindTxByHash(ctx context.Context, hash common.Hash) (*Tx, error) { var dbEtx DbEthTx - err := o.Transaction(ctx, true, func(orm *evmTxStore) error { + err := o.Transact(ctx, true, func(orm *evmTxStore) error { sql := `SELECT evm.txes.* FROM evm.txes WHERE id IN (SELECT DISTINCT eth_tx_id FROM evm.tx_attempts WHERE hash = $1)` if err := orm.q.GetContext(ctx, &dbEtx, sql, hash); err != nil { return pkgerrors.Wrapf(err, "failed to find eth_tx with hash %d", hash) @@ -575,7 +575,7 @@ func (o *evmTxStore) GetFatalTransactions(ctx context.Context) (txes []*Tx, err var cancel context.CancelFunc ctx, cancel = o.mergeContexts(ctx) defer cancel() - err = o.Transaction(ctx, true, func(orm *evmTxStore) error { + err = o.Transact(ctx, true, func(orm *evmTxStore) error { stmt := `SELECT * FROM evm.txes WHERE state = 'fatal_error'` var dbEtxs []DbEthTx if err = orm.q.SelectContext(ctx, &dbEtxs, stmt); err != nil { @@ -595,7 +595,7 @@ func (o *evmTxStore) GetFatalTransactions(ctx context.Context) (txes []*Tx, err // FindTxWithAttempts finds the Tx with its attempts and receipts preloaded func (o *evmTxStore) FindTxWithAttempts(ctx context.Context, etxID int64) (etx Tx, err error) { - err = o.Transaction(ctx, true, func(orm *evmTxStore) error { + err = o.Transact(ctx, true, func(orm *evmTxStore) error { var dbEtx DbEthTx if err = orm.q.GetContext(ctx, &dbEtx, `SELECT * FROM evm.txes WHERE id = $1 ORDER BY created_at ASC, id ASC`, etxID); err != nil { return pkgerrors.Wrapf(err, "failed to find evm.tx with id %d", etxID) @@ -614,7 +614,7 @@ func (o *evmTxStore) FindTxWithAttempts(ctx context.Context, etxID int64) (etx T func (o *evmTxStore) FindTxAttemptConfirmedByTxIDs(ctx context.Context, ids []int64) ([]TxAttempt, error) { var txAttempts []TxAttempt - err := o.Transaction(ctx, true, func(orm *evmTxStore) error { + err := o.Transact(ctx, true, func(orm *evmTxStore) error { var dbAttempts []DbEthTxAttempt if err := orm.q.SelectContext(ctx, &dbAttempts, `SELECT eta.* FROM evm.tx_attempts eta @@ -806,7 +806,7 @@ func (o *evmTxStore) FindTxAttemptsRequiringReceiptFetch(ctx context.Context, ch var cancel context.CancelFunc ctx, cancel = o.mergeContexts(ctx) defer cancel() - err = o.Transaction(ctx, true, func(orm *evmTxStore) error { + err = o.Transact(ctx, true, func(orm *evmTxStore) error { var dbAttempts []DbEthTxAttempt err = orm.q.SelectContext(ctx, &dbAttempts, ` SELECT evm.tx_attempts.* FROM evm.tx_attempts @@ -963,7 +963,7 @@ func (o *evmTxStore) GetInProgressTxAttempts(ctx context.Context, address common var cancel context.CancelFunc ctx, cancel = o.mergeContexts(ctx) defer cancel() - err = o.Transaction(ctx, true, func(orm *evmTxStore) error { + err = o.Transact(ctx, true, func(orm *evmTxStore) error { var dbAttempts []DbEthTxAttempt err = orm.q.SelectContext(ctx, &dbAttempts, ` SELECT evm.tx_attempts.* FROM evm.tx_attempts @@ -1046,7 +1046,7 @@ func (o *evmTxStore) FindTxWithSequence(ctx context.Context, fromAddress common. ctx, cancel = o.mergeContexts(ctx) defer cancel() etx = new(Tx) - err = o.Transaction(ctx, true, func(orm *evmTxStore) error { + err = o.Transact(ctx, true, func(orm *evmTxStore) error { var dbEtx DbEthTx err = orm.q.GetContext(ctx, &dbEtx, ` SELECT * FROM evm.txes WHERE from_address = $1 AND nonce = $2 AND state IN ('confirmed', 'confirmed_missing_receipt', 'unconfirmed') @@ -1094,7 +1094,7 @@ func (o *evmTxStore) UpdateTxForRebroadcast(ctx context.Context, etx Tx, etxAtte var cancel context.CancelFunc ctx, cancel = o.mergeContexts(ctx) defer cancel() - return o.Transaction(ctx, false, func(orm *evmTxStore) error { + return o.Transact(ctx, false, func(orm *evmTxStore) error { if err := deleteEthReceipts(ctx, orm, etx.ID); err != nil { return pkgerrors.Wrapf(err, "deleteEthReceipts failed for etx %v", etx.ID) } @@ -1109,7 +1109,7 @@ func (o *evmTxStore) FindTransactionsConfirmedInBlockRange(ctx context.Context, var cancel context.CancelFunc ctx, cancel = o.mergeContexts(ctx) defer cancel() - err = o.Transaction(ctx, true, func(orm *evmTxStore) error { + err = o.Transact(ctx, true, func(orm *evmTxStore) error { var dbEtxs []DbEthTx err = orm.q.SelectContext(ctx, &dbEtxs, ` SELECT DISTINCT evm.txes.* FROM evm.txes @@ -1136,7 +1136,7 @@ func (o *evmTxStore) FindEarliestUnconfirmedBroadcastTime(ctx context.Context, c var cancel context.CancelFunc ctx, cancel = o.mergeContexts(ctx) defer cancel() - err = o.Transaction(ctx, true, func(orm *evmTxStore) error { + err = o.Transact(ctx, true, func(orm *evmTxStore) error { if err = orm.q.QueryRowxContext(ctx, `SELECT min(initial_broadcast_at) FROM evm.txes WHERE state = 'unconfirmed' AND evm_chain_id = $1`, chainID.String()).Scan(&broadcastAt); err != nil { return fmt.Errorf("failed to query for unconfirmed eth_tx count: %w", err) } @@ -1149,7 +1149,7 @@ func (o *evmTxStore) FindEarliestUnconfirmedTxAttemptBlock(ctx context.Context, var cancel context.CancelFunc ctx, cancel = o.mergeContexts(ctx) defer cancel() - err = o.Transaction(ctx, true, func(orm *evmTxStore) error { + err = o.Transact(ctx, true, func(orm *evmTxStore) error { err = orm.q.QueryRowxContext(ctx, ` SELECT MIN(broadcast_before_block_num) FROM evm.tx_attempts JOIN evm.txes ON evm.txes.id = evm.tx_attempts.eth_tx_id @@ -1184,7 +1184,7 @@ func (o *evmTxStore) IsTxFinalized(ctx context.Context, blockHeight int64, txID func (o *evmTxStore) saveAttemptWithNewState(ctx context.Context, attempt TxAttempt, broadcastAt time.Time) error { var dbAttempt DbEthTxAttempt dbAttempt.FromTxAttempt(&attempt) - return o.Transaction(ctx, false, func(orm *evmTxStore) error { + return o.Transact(ctx, false, func(orm *evmTxStore) error { // In case of null broadcast_at (shouldn't happen) we don't want to // update anyway because it indicates a state where broadcast_at makes // no sense e.g. fatal_error @@ -1230,7 +1230,7 @@ func (o *evmTxStore) SaveConfirmedMissingReceiptAttempt(ctx context.Context, tim var cancel context.CancelFunc ctx, cancel = o.mergeContexts(ctx) defer cancel() - err := o.Transaction(ctx, false, func(orm *evmTxStore) error { + err := o.Transact(ctx, false, func(orm *evmTxStore) error { if err := orm.saveSentAttempt(ctx, timeout, attempt, broadcastAt); err != nil { return err } @@ -1321,7 +1321,7 @@ func (o *evmTxStore) GetTxByID(ctx context.Context, id int64) (txe *Tx, err erro ctx, cancel = o.mergeContexts(ctx) defer cancel() - err = o.Transaction(ctx, true, func(orm *evmTxStore) error { + err = o.Transact(ctx, true, func(orm *evmTxStore) error { stmt := `SELECT * FROM evm.txes WHERE id = $1` var dbEtxs []DbEthTx if err = orm.q.SelectContext(ctx, &dbEtxs, stmt, id); err != nil { @@ -1355,7 +1355,7 @@ func (o *evmTxStore) FindTxsRequiringGasBump(ctx context.Context, address common var cancel context.CancelFunc ctx, cancel = o.mergeContexts(ctx) defer cancel() - err = o.Transaction(ctx, true, func(orm *evmTxStore) error { + err = o.Transact(ctx, true, func(orm *evmTxStore) error { stmt := ` SELECT evm.txes.* FROM evm.txes LEFT JOIN evm.tx_attempts ON evm.txes.id = evm.tx_attempts.eth_tx_id AND (broadcast_before_block_num > $4 OR broadcast_before_block_num IS NULL OR evm.tx_attempts.state != 'broadcast') @@ -1382,7 +1382,7 @@ func (o *evmTxStore) FindTxsRequiringResubmissionDueToInsufficientFunds(ctx cont var cancel context.CancelFunc ctx, cancel = o.mergeContexts(ctx) defer cancel() - err = o.Transaction(ctx, true, func(orm *evmTxStore) error { + err = o.Transact(ctx, true, func(orm *evmTxStore) error { var dbEtxs []DbEthTx err = orm.q.SelectContext(ctx, &dbEtxs, ` SELECT DISTINCT evm.txes.* FROM evm.txes @@ -1423,7 +1423,7 @@ func (o *evmTxStore) MarkOldTxesMissingReceiptAsErrored(ctx context.Context, blo return nil } // note: if QOpt passes in a sql.Tx this will reuse it - return o.Transaction(ctx, false, func(orm *evmTxStore) error { + return o.Transact(ctx, false, func(orm *evmTxStore) error { type etx struct { ID int64 Nonce int64 @@ -1511,7 +1511,7 @@ func (o *evmTxStore) SaveReplacementInProgressAttempt(ctx context.Context, oldAt if oldAttempt.ID == 0 { return errors.New("expected oldAttempt to have an ID") } - return o.Transaction(ctx, false, func(orm *evmTxStore) error { + return o.Transact(ctx, false, func(orm *evmTxStore) error { if _, err := orm.q.ExecContext(ctx, `DELETE FROM evm.tx_attempts WHERE id=$1`, oldAttempt.ID); err != nil { return pkgerrors.Wrap(err, "saveReplacementInProgressAttempt failed to delete from evm.tx_attempts") } @@ -1557,7 +1557,7 @@ func (o *evmTxStore) UpdateTxFatalError(ctx context.Context, etx *Tx) error { etx.Sequence = nil etx.State = txmgr.TxFatalError - return o.Transaction(ctx, false, func(orm *evmTxStore) error { + return o.Transact(ctx, false, func(orm *evmTxStore) error { if _, err := orm.q.ExecContext(ctx, `DELETE FROM evm.tx_attempts WHERE eth_tx_id = $1`, etx.ID); err != nil { return pkgerrors.Wrapf(err, "saveFatallyErroredTransaction failed to delete eth_tx_attempt with eth_tx.ID %v", etx.ID) } @@ -1591,7 +1591,7 @@ func (o *evmTxStore) UpdateTxAttemptInProgressToBroadcast(ctx context.Context, e } etx.State = txmgr.TxUnconfirmed attempt.State = NewAttemptState - return o.Transaction(ctx, false, func(orm *evmTxStore) error { + return o.Transact(ctx, false, func(orm *evmTxStore) error { var dbEtx DbEthTx dbEtx.FromTx(etx) if err := orm.q.GetContext(ctx, &dbEtx, `UPDATE evm.txes SET state=$1, error=$2, broadcast_at=$3, initial_broadcast_at=$4 WHERE id = $5 RETURNING *`, dbEtx.State, dbEtx.Error, dbEtx.BroadcastAt, dbEtx.InitialBroadcastAt, dbEtx.ID); err != nil { @@ -1622,7 +1622,7 @@ func (o *evmTxStore) UpdateTxUnstartedToInProgress(ctx context.Context, etx *Tx, return errors.New("attempt state must be in_progress") } etx.State = txmgr.TxInProgress - return o.Transaction(ctx, false, func(orm *evmTxStore) error { + return o.Transact(ctx, false, func(orm *evmTxStore) error { // If a replay was triggered while unconfirmed transactions were pending, they will be marked as fatal_error => abandoned. // In this case, we must remove the abandoned attempt from evm.tx_attempts before replacing it with a new one. In any other // case, we uphold the constraint, leaving the original tx attempt as-is and returning the constraint violation error. @@ -1688,7 +1688,7 @@ func (o *evmTxStore) GetTxInProgress(ctx context.Context, fromAddress common.Add if err != nil { return etx, pkgerrors.Wrap(err, "getInProgressEthTx failed") } - err = o.Transaction(ctx, true, func(orm *evmTxStore) error { + err = o.Transact(ctx, true, func(orm *evmTxStore) error { var dbEtx DbEthTx err = orm.q.GetContext(ctx, &dbEtx, `SELECT * FROM evm.txes WHERE from_address = $1 and state = 'in_progress'`, fromAddress) if errors.Is(err, sql.ErrNoRows) { @@ -1776,7 +1776,7 @@ func (o *evmTxStore) CreateTransaction(ctx context.Context, txRequest TxRequest, ctx, cancel = o.mergeContexts(ctx) defer cancel() var dbEtx DbEthTx - err = o.Transaction(ctx, false, func(orm *evmTxStore) error { + err = o.Transact(ctx, false, func(orm *evmTxStore) error { if txRequest.PipelineTaskRunID != nil { err = orm.q.GetContext(ctx, &dbEtx, `SELECT * FROM evm.txes WHERE pipeline_task_run_id = $1 AND evm_chain_id = $2`, txRequest.PipelineTaskRunID, chainID.String()) @@ -1810,7 +1810,7 @@ func (o *evmTxStore) PruneUnstartedTxQueue(ctx context.Context, queueSize uint32 var cancel context.CancelFunc ctx, cancel = o.mergeContexts(ctx) defer cancel() - err = o.Transaction(ctx, false, func(orm *evmTxStore) error { + err = o.Transact(ctx, false, func(orm *evmTxStore) error { err := orm.q.SelectContext(ctx, &ids, ` DELETE FROM evm.txes WHERE state = 'unstarted' AND subject = $1 AND @@ -1945,7 +1945,7 @@ func (o *evmTxStore) FindTxesWithAttemptsAndReceiptsByIdsAndState(ctx context.Co var cancel context.CancelFunc ctx, cancel = o.mergeContexts(ctx) defer cancel() - err = o.Transaction(ctx, true, func(orm *evmTxStore) error { + err = o.Transact(ctx, true, func(orm *evmTxStore) error { var dbEtxs []DbEthTx if err = orm.q.SelectContext(ctx, &dbEtxs, `SELECT * FROM evm.txes WHERE id = ANY($1) AND state = ANY($2) AND evm_chain_id = $3`, pq.Array(ids), pq.Array(states), chainID.String()); err != nil { return pkgerrors.Wrapf(err, "failed to find evm.txes") diff --git a/core/chains/evm/txmgr/txmgr_test.go b/core/chains/evm/txmgr/txmgr_test.go index aac9d89c490..85d25d8a70b 100644 --- a/core/chains/evm/txmgr/txmgr_test.go +++ b/core/chains/evm/txmgr/txmgr_test.go @@ -68,7 +68,6 @@ func makeTestEvmTxm( ) return txmgr.NewTxm( - db, db, ccfg, fcfg, diff --git a/core/chains/legacyevm/chain.go b/core/chains/legacyevm/chain.go index 920532518ab..27e0155da52 100644 --- a/core/chains/legacyevm/chain.go +++ b/core/chains/legacyevm/chain.go @@ -9,8 +9,6 @@ import ( gotoml "github.com/pelletier/go-toml/v2" "go.uber.org/multierr" - "github.com/jmoiron/sqlx" - common "github.com/smartcontractkit/chainlink-common/pkg/chains" "github.com/smartcontractkit/chainlink-common/pkg/services" "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" @@ -162,8 +160,7 @@ type ChainOpts struct { MailMon *mailbox.Monitor GasEstimator gas.EvmFeeEstimator - SqlxDB *sqlx.DB // Deprecated: use DB instead - DB sqlutil.DataSource + DS sqlutil.DataSource // TODO BCF-2513 remove test code from the API // Gen-functions are useful for dependency injection by tests @@ -184,11 +181,8 @@ func (o ChainOpts) Validate() error { if o.MailMon == nil { err = errors.Join(err, errors.New("nil MailMon")) } - if o.SqlxDB == nil { - err = errors.Join(err, errors.New("nil SqlxDB")) - } - if o.DB == nil { - err = errors.Join(err, errors.New("nil DB")) + if o.DS == nil { + err = errors.Join(err, errors.New("nil DS")) } if err != nil { err = fmt.Errorf("invalid ChainOpts: %w", err) @@ -229,7 +223,7 @@ func newChain(ctx context.Context, cfg *evmconfig.ChainScoped, nodes []*toml.Nod if !opts.AppConfig.EVMRPCEnabled() { headTracker = headtracker.NullTracker } else if opts.GenHeadTracker == nil { - orm := headtracker.NewORM(*chainID, opts.DB) + orm := headtracker.NewORM(*chainID, opts.DS) headSaver = headtracker.NewHeadSaver(l, orm, cfg.EVM(), cfg.EVM().HeadTracker()) headTracker = headtracker.NewHeadTracker(l, client, cfg.EVM(), cfg.EVM().HeadTracker(), headBroadcaster, headSaver, opts.MailMon) } else { @@ -251,12 +245,12 @@ func newChain(ctx context.Context, cfg *evmconfig.ChainScoped, nodes []*toml.Nod LogPrunePageSize: int64(cfg.EVM().LogPrunePageSize()), BackupPollerBlockDelay: int64(cfg.EVM().BackupLogPollerBlockDelay()), } - logPoller = logpoller.NewLogPoller(logpoller.NewObservedORM(chainID, opts.DB, l), client, l, lpOpts) + logPoller = logpoller.NewLogPoller(logpoller.NewObservedORM(chainID, opts.DS, l), client, l, lpOpts) } } // note: gas estimator is started as a part of the txm - txm, gasEstimator, err := newEvmTxm(opts.SqlxDB, opts.DB, cfg.EVM(), opts.AppConfig.EVMRPCEnabled(), opts.AppConfig.Database(), opts.AppConfig.Database().Listener(), client, l, logPoller, opts) + txm, gasEstimator, err := newEvmTxm(opts.DS, cfg.EVM(), opts.AppConfig.EVMRPCEnabled(), opts.AppConfig.Database(), opts.AppConfig.Database().Listener(), client, l, logPoller, opts) if err != nil { return nil, fmt.Errorf("failed to instantiate EvmTxm for chain with ID %s: %w", chainID.String(), err) } @@ -279,7 +273,7 @@ func newChain(ctx context.Context, cfg *evmconfig.ChainScoped, nodes []*toml.Nod if !opts.AppConfig.EVMRPCEnabled() { logBroadcaster = &log.NullBroadcaster{ErrMsg: fmt.Sprintf("Ethereum is disabled for chain %d", chainID)} } else if opts.GenLogBroadcaster == nil { - logORM := log.NewORM(opts.SqlxDB, *chainID) + logORM := log.NewORM(opts.DS, *chainID) logBroadcaster = log.NewBroadcaster(logORM, client, cfg.EVM(), l, highestSeenHead, opts.MailMon) } else { logBroadcaster = opts.GenLogBroadcaster(chainID) diff --git a/core/chains/legacyevm/chain_test.go b/core/chains/legacyevm/chain_test.go index 5dd7eb1c6ed..c10712d4b6b 100644 --- a/core/chains/legacyevm/chain_test.go +++ b/core/chains/legacyevm/chain_test.go @@ -65,8 +65,7 @@ func TestChainOpts_Validate(t *testing.T) { o := legacyevm.ChainOpts{ AppConfig: tt.fields.AppConfig, MailMon: tt.fields.MailMon, - SqlxDB: tt.fields.DB, - DB: tt.fields.DB, + DS: tt.fields.DB, } if err := o.Validate(); (err != nil) != tt.wantErr { t.Errorf("ChainOpts.Validate() error = %v, wantErr %v", err, tt.wantErr) diff --git a/core/chains/legacyevm/evm_txm.go b/core/chains/legacyevm/evm_txm.go index 6b2d1262ce8..df1f4248ce2 100644 --- a/core/chains/legacyevm/evm_txm.go +++ b/core/chains/legacyevm/evm_txm.go @@ -3,8 +3,6 @@ package legacyevm import ( "fmt" - "github.com/jmoiron/sqlx" - "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" evmclient "github.com/smartcontractkit/chainlink/v2/core/chains/evm/client" evmconfig "github.com/smartcontractkit/chainlink/v2/core/chains/evm/config" @@ -15,8 +13,7 @@ import ( ) func newEvmTxm( - sqlxDB *sqlx.DB, - db sqlutil.DataSource, + ds sqlutil.DataSource, cfg evmconfig.EVM, evmRPCEnabled bool, databaseConfig txmgr.DatabaseConfig, @@ -53,8 +50,7 @@ func newEvmTxm( if opts.GenTxManager == nil { txm, err = txmgr.NewTxm( - sqlxDB, - db, + ds, cfg, txmgr.NewEvmTxmFeeConfig(cfg.GasEstimator()), cfg.Transactions(), diff --git a/core/cmd/eth_keys_commands_test.go b/core/cmd/eth_keys_commands_test.go index 2f22cd1d3ae..64835c7f28b 100644 --- a/core/cmd/eth_keys_commands_test.go +++ b/core/cmd/eth_keys_commands_test.go @@ -182,7 +182,7 @@ func TestShell_CreateETHKey(t *testing.T) { withKey(), withMocks(ethClient), ) - db := app.GetSqlxDB() + db := app.GetDB() client, _ := app.NewShellAndRenderer() cltest.AssertCount(t, db, "evm.key_states", 1) // The initial funding key @@ -306,7 +306,7 @@ func TestShell_ImportExportETHKey_NoChains(t *testing.T) { _, err = ethKeyStore.Get(testutils.Context(t), address) require.Error(t, err) - cltest.AssertCount(t, app.GetSqlxDB(), "evm.key_states", 0) + cltest.AssertCount(t, app.GetDB(), "evm.key_states", 0) // Import the key set = flag.NewFlagSet("test", 0) diff --git a/core/cmd/evm_transaction_commands_test.go b/core/cmd/evm_transaction_commands_test.go index df5d066927a..5375abbacee 100644 --- a/core/cmd/evm_transaction_commands_test.go +++ b/core/cmd/evm_transaction_commands_test.go @@ -32,7 +32,7 @@ func TestShell_IndexTransactions(t *testing.T) { _, from := cltest.MustInsertRandomKey(t, app.KeyStore.Eth()) - txStore := cltest.NewTestTxStore(t, app.GetSqlxDB()) + txStore := cltest.NewTestTxStore(t, app.GetDB()) tx := cltest.MustInsertConfirmedEthTxWithLegacyAttempt(t, txStore, 0, 1, from) attempt := tx.TxAttempts[0] @@ -70,7 +70,7 @@ func TestShell_ShowTransaction(t *testing.T) { app := startNewApplicationV2(t, nil) client, r := app.NewShellAndRenderer() - db := app.GetSqlxDB() + db := app.GetDB() _, from := cltest.MustInsertRandomKey(t, app.KeyStore.Eth()) txStore := cltest.NewTestTxStore(t, db) @@ -97,7 +97,7 @@ func TestShell_IndexTxAttempts(t *testing.T) { _, from := cltest.MustInsertRandomKey(t, app.KeyStore.Eth()) - txStore := cltest.NewTestTxStore(t, app.GetSqlxDB()) + txStore := cltest.NewTestTxStore(t, app.GetDB()) tx := cltest.MustInsertConfirmedEthTxWithLegacyAttempt(t, txStore, 0, 1, from) // page 1 @@ -156,7 +156,7 @@ func TestShell_SendEther_From_Txm(t *testing.T) { withMocks(ethMock, key), ) client, r := app.NewShellAndRenderer() - db := app.GetSqlxDB() + db := app.GetDB() txStore := txmgr.NewTxStore(db, logger.TestLogger(t)) set := flag.NewFlagSet("sendether", 0) flagSetApplyFromAction(client.SendEther, set, "") @@ -221,7 +221,7 @@ func TestShell_SendEther_From_Txm_WEI(t *testing.T) { withMocks(ethMock, key), ) client, r := app.NewShellAndRenderer() - db := app.GetSqlxDB() + db := app.GetDB() txStore := txmgr.NewTxStore(db, logger.TestLogger(t)) set := flag.NewFlagSet("sendether", 0) diff --git a/core/cmd/jobs_commands_test.go b/core/cmd/jobs_commands_test.go index 75e95db84ca..77d2487509a 100644 --- a/core/cmd/jobs_commands_test.go +++ b/core/cmd/jobs_commands_test.go @@ -17,6 +17,7 @@ import ( commonconfig "github.com/smartcontractkit/chainlink-common/pkg/config" "github.com/smartcontractkit/chainlink/v2/core/cmd" "github.com/smartcontractkit/chainlink/v2/core/internal/cltest" + "github.com/smartcontractkit/chainlink/v2/core/internal/testutils" "github.com/smartcontractkit/chainlink/v2/core/services/chainlink" "github.com/smartcontractkit/chainlink/v2/core/services/job" "github.com/smartcontractkit/chainlink/v2/core/store/models" @@ -428,7 +429,8 @@ func TestShell_DeleteJob(t *testing.T) { requireJobsCount(t, app.JobORM(), 1) - jobs, _, err := app.JobORM().FindJobs(0, 1000) + ctx := testutils.Context(t) + jobs, _, err := app.JobORM().FindJobs(ctx, 0, 1000) require.NoError(t, err) jobID := jobs[0].ID cltest.AwaitJobActive(t, app.JobSpawner(), jobID, 3*time.Second) @@ -451,7 +453,8 @@ func TestShell_DeleteJob(t *testing.T) { } func requireJobsCount(t *testing.T, orm job.ORM, expected int) { - jobs, _, err := orm.FindJobs(0, 1000) + ctx := testutils.Context(t) + jobs, _, err := orm.FindJobs(ctx, 0, 1000) require.NoError(t, err) require.Len(t, jobs, expected) } diff --git a/core/cmd/shell.go b/core/cmd/shell.go index bc58c5cab6d..0372148e742 100644 --- a/core/cmd/shell.go +++ b/core/cmd/shell.go @@ -136,7 +136,7 @@ type AppFactory interface { type ChainlinkAppFactory struct{} // NewApplication returns a new instance of the node with the given config. -func (n ChainlinkAppFactory) NewApplication(ctx context.Context, cfg chainlink.GeneralConfig, appLggr logger.Logger, sqlxDB *sqlx.DB) (app chainlink.Application, err error) { +func (n ChainlinkAppFactory) NewApplication(ctx context.Context, cfg chainlink.GeneralConfig, appLggr logger.Logger, db *sqlx.DB) (app chainlink.Application, err error) { err = initGlobals(cfg.Prometheus(), cfg.Tracing(), appLggr) if err != nil { appLggr.Errorf("Failed to initialize globals: %v", err) @@ -147,14 +147,14 @@ func (n ChainlinkAppFactory) NewApplication(ctx context.Context, cfg chainlink.G return nil, err } - db := sqlutil.WrapDataSource(sqlxDB, appLggr, sqlutil.TimeoutHook(cfg.Database().DefaultQueryTimeout), sqlutil.MonitorHook(cfg.Database().LogSQL)) - - err = handleNodeVersioning(ctx, sqlxDB, appLggr, cfg.RootDir(), cfg.Database(), cfg.WebServer().HTTPPort()) + err = handleNodeVersioning(ctx, db, appLggr, cfg.RootDir(), cfg.Database(), cfg.WebServer().HTTPPort()) if err != nil { return nil, err } - keyStore := keystore.New(db, utils.GetScryptParams(cfg), appLggr) + ds := sqlutil.WrapDataSource(db, appLggr, sqlutil.TimeoutHook(cfg.Database().DefaultQueryTimeout), sqlutil.MonitorHook(cfg.Database().LogSQL)) + + keyStore := keystore.New(ds, utils.GetScryptParams(cfg), appLggr) mailMon := mailbox.NewMonitor(cfg.AppID().String(), appLggr.Named("Mailbox")) loopRegistry := plugins.NewLoopRegistry(appLggr, cfg.Tracing()) @@ -175,7 +175,7 @@ func (n ChainlinkAppFactory) NewApplication(ctx context.Context, cfg chainlink.G evmFactoryCfg := chainlink.EVMFactoryConfig{ CSAETHKeystore: keyStore, - ChainOpts: legacyevm.ChainOpts{AppConfig: cfg, MailMon: mailMon, SqlxDB: sqlxDB, DB: sqlxDB}, + ChainOpts: legacyevm.ChainOpts{AppConfig: cfg, MailMon: mailMon, DS: ds}, } // evm always enabled for backward compatibility // TODO BCF-2510 this needs to change in order to clear the path for EVM extraction @@ -185,8 +185,7 @@ func (n ChainlinkAppFactory) NewApplication(ctx context.Context, cfg chainlink.G cosmosCfg := chainlink.CosmosFactoryConfig{ Keystore: keyStore.Cosmos(), TOMLConfigs: cfg.CosmosConfigs(), - DB: sqlxDB, - QConfig: cfg.Database(), + DS: ds, } initOps = append(initOps, chainlink.InitCosmos(ctx, relayerFactory, cosmosCfg)) } @@ -219,11 +218,10 @@ func (n ChainlinkAppFactory) NewApplication(ctx context.Context, cfg chainlink.G restrictedClient := clhttp.NewRestrictedHTTPClient(cfg.Database(), appLggr) unrestrictedClient := clhttp.NewUnrestrictedHTTPClient() - externalInitiatorManager := webhook.NewExternalInitiatorManager(sqlxDB, unrestrictedClient, appLggr, cfg.Database()) + externalInitiatorManager := webhook.NewExternalInitiatorManager(ds, unrestrictedClient) return chainlink.NewApplication(chainlink.ApplicationOpts{ Config: cfg, - SqlxDB: sqlxDB, - DB: db, + DS: ds, KeyStore: keyStore, RelayerChainInteroperators: relayChainInterops, MailMon: mailMon, @@ -244,7 +242,7 @@ func (n ChainlinkAppFactory) NewApplication(ctx context.Context, cfg chainlink.G func handleNodeVersioning(ctx context.Context, db *sqlx.DB, appLggr logger.Logger, rootDir string, cfg config.Database, healthReportPort uint16) error { var err error // Set up the versioning Configs - verORM := versioning.NewORM(db, appLggr, cfg.DefaultQueryTimeout()) + verORM := versioning.NewORM(db, appLggr) if static.Version != static.Unset { var appv, dbv *semver.Version @@ -272,7 +270,7 @@ func handleNodeVersioning(ctx context.Context, db *sqlx.DB, appLggr logger.Logge // Migrate the database if cfg.MigrateDatabase() { - if err = migrate.Migrate(ctx, db.DB, appLggr); err != nil { + if err = migrate.Migrate(ctx, db.DB); err != nil { return fmt.Errorf("initializeORM#Migrate: %w", err) } } diff --git a/core/cmd/shell_local.go b/core/cmd/shell_local.go index a61390a4886..6dbffbe404a 100644 --- a/core/cmd/shell_local.go +++ b/core/cmd/shell_local.go @@ -632,7 +632,7 @@ func (s *Shell) RebroadcastTransactions(c *cli.Context) (err error) { s.Logger.Infof("Rebroadcasting transactions from %v to %v", beginningNonce, endingNonce) - orm := txmgr.NewTxStore(app.GetSqlxDB(), lggr) + orm := txmgr.NewTxStore(app.GetDB(), lggr) txBuilder := txmgr.NewEvmTxAttemptBuilder(*ethClient.ConfiguredChainID(), chain.Config().EVM().GasEstimator(), keyStore.Eth(), nil) cfg := txmgr.NewEvmTxmConfig(chain.Config().EVM()) feeCfg := txmgr.NewEvmTxmFeeConfig(chain.Config().EVM().GasEstimator()) @@ -923,7 +923,7 @@ func (s *Shell) RollbackDatabase(c *cli.Context) error { return fmt.Errorf("failed to initialize orm: %v", err) } - if err := migrate.Rollback(ctx, db.DB, s.Logger, version); err != nil { + if err := migrate.Rollback(ctx, db.DB, version); err != nil { return fmt.Errorf("migrateDB failed: %v", err) } @@ -938,7 +938,7 @@ func (s *Shell) VersionDatabase(_ *cli.Context) error { return fmt.Errorf("failed to initialize orm: %v", err) } - version, err := migrate.Current(ctx, db.DB, s.Logger) + version, err := migrate.Current(ctx, db.DB) if err != nil { return fmt.Errorf("migrateDB failed: %v", err) } @@ -955,7 +955,7 @@ func (s *Shell) StatusDatabase(_ *cli.Context) error { return fmt.Errorf("failed to initialize orm: %v", err) } - if err = migrate.Status(ctx, db.DB, s.Logger); err != nil { + if err = migrate.Status(ctx, db.DB); err != nil { return fmt.Errorf("Status failed: %v", err) } return nil @@ -1099,7 +1099,7 @@ func migrateDB(ctx context.Context, config dbConfig, lggr logger.Logger) error { return fmt.Errorf("failed to initialize orm: %v", err) } - if err = migrate.Migrate(ctx, db.DB, lggr); err != nil { + if err = migrate.Migrate(ctx, db.DB); err != nil { return fmt.Errorf("migrateDB failed: %v", err) } return db.Close() @@ -1110,10 +1110,10 @@ func downAndUpDB(ctx context.Context, cfg dbConfig, lggr logger.Logger, baseVers if err != nil { return fmt.Errorf("failed to initialize orm: %v", err) } - if err = migrate.Rollback(ctx, db.DB, lggr, null.IntFrom(baseVersionID)); err != nil { + if err = migrate.Rollback(ctx, db.DB, null.IntFrom(baseVersionID)); err != nil { return fmt.Errorf("test rollback failed: %v", err) } - if err = migrate.Migrate(ctx, db.DB, lggr); err != nil { + if err = migrate.Migrate(ctx, db.DB); err != nil { return fmt.Errorf("second migrateDB failed: %v", err) } return db.Close() diff --git a/core/cmd/shell_local_test.go b/core/cmd/shell_local_test.go index d608d3931d0..7427e6caedb 100644 --- a/core/cmd/shell_local_test.go +++ b/core/cmd/shell_local_test.go @@ -91,8 +91,7 @@ func TestShell_RunNodeWithPasswords(t *testing.T) { ChainOpts: legacyevm.ChainOpts{ AppConfig: cfg, MailMon: &mailbox.Monitor{}, - SqlxDB: db, - DB: db, + DS: db, }, } testRelayers := genTestEVMRelayers(t, opts, keyStore) @@ -196,8 +195,7 @@ func TestShell_RunNodeWithAPICredentialsFile(t *testing.T) { ChainOpts: legacyevm.ChainOpts{ AppConfig: cfg, MailMon: &mailbox.Monitor{}, - SqlxDB: db, - DB: db, + DS: db, }, } testRelayers := genTestEVMRelayers(t, opts, keyStore) @@ -299,7 +297,7 @@ func TestShell_RebroadcastTransactions_Txm(t *testing.T) { lggr := logger.TestLogger(t) app := mocks.NewApplication(t) - app.On("GetSqlxDB").Return(sqlxDB) + app.On("GetDB").Return(sqlxDB) app.On("GetKeyStore").Return(keyStore) app.On("ID").Maybe().Return(uuid.New()) app.On("GetConfig").Return(config) @@ -381,7 +379,7 @@ func TestShell_RebroadcastTransactions_OutsideRange_Txm(t *testing.T) { lggr := logger.TestLogger(t) app := mocks.NewApplication(t) - app.On("GetSqlxDB").Return(sqlxDB) + app.On("GetDB").Return(sqlxDB) app.On("GetKeyStore").Return(keyStore) app.On("ID").Maybe().Return(uuid.New()) app.On("GetConfig").Return(config) @@ -460,7 +458,7 @@ func TestShell_RebroadcastTransactions_AddressCheck(t *testing.T) { lggr := logger.TestLogger(t) app := mocks.NewApplication(t) - app.On("GetSqlxDB").Maybe().Return(sqlxDB) + app.On("GetDB").Maybe().Return(sqlxDB) app.On("GetKeyStore").Return(keyStore) app.On("ID").Maybe().Return(uuid.New()) ethClient := evmtest.NewEthClientMockWithDefaultChain(t) diff --git a/core/cmd/shell_remote_test.go b/core/cmd/shell_remote_test.go index a8c054cd9be..cdbe12d66b4 100644 --- a/core/cmd/shell_remote_test.go +++ b/core/cmd/shell_remote_test.go @@ -151,6 +151,7 @@ func TestShell_CreateExternalInitiator(t *testing.T) { for _, tt := range tests { test := tt t.Run(test.name, func(t *testing.T) { + ctx := testutils.Context(t) app := startNewApplicationV2(t, func(c *chainlink.Config, s *chainlink.Secrets) { c.JobPipeline.ExternalInitiatorsEnabled = ptr(true) }) @@ -165,7 +166,7 @@ func TestShell_CreateExternalInitiator(t *testing.T) { require.NoError(t, err) var exi bridges.ExternalInitiator - err = app.GetSqlxDB().Get(&exi, `SELECT * FROM external_initiators WHERE name = $1`, test.args[0]) + err = app.GetDB().GetContext(ctx, &exi, `SELECT * FROM external_initiators WHERE name = $1`, test.args[0]) require.NoError(t, err) if len(test.args) > 1 { @@ -195,7 +196,7 @@ func TestShell_CreateExternalInitiator_Errors(t *testing.T) { }) client, _ := app.NewShellAndRenderer() - initialExis := len(cltest.AllExternalInitiators(t, app.GetSqlxDB())) + initialExis := len(cltest.AllExternalInitiators(t, app.GetDB())) set := flag.NewFlagSet("create", 0) flagSetApplyFromAction(client.CreateExternalInitiator, set, "") @@ -206,7 +207,7 @@ func TestShell_CreateExternalInitiator_Errors(t *testing.T) { err := client.CreateExternalInitiator(c) assert.Error(t, err) - exis := cltest.AllExternalInitiators(t, app.GetSqlxDB()) + exis := cltest.AllExternalInitiators(t, app.GetDB()) assert.Len(t, exis, initialExis) }) } @@ -580,8 +581,8 @@ func TestShell_RunOCRJob_HappyPath(t *testing.T) { require.NoError(t, app.KeyStore.OCR().Add(ctx, cltest.DefaultOCRKey)) - _, bridge := cltest.MustCreateBridge(t, app.GetSqlxDB(), cltest.BridgeOpts{}) - _, bridge2 := cltest.MustCreateBridge(t, app.GetSqlxDB(), cltest.BridgeOpts{}) + _, bridge := cltest.MustCreateBridge(t, app.GetDB(), cltest.BridgeOpts{}) + _, bridge2 := cltest.MustCreateBridge(t, app.GetDB(), cltest.BridgeOpts{}) var jb job.Job ocrspec := testspecs.GenerateOCRSpec(testspecs.OCRSpecParams{DS1BridgeName: bridge.Name.String(), DS2BridgeName: bridge2.Name.String()}) @@ -669,7 +670,7 @@ func TestShell_AutoLogin(t *testing.T) { require.NoError(t, err) // Expire the session and then try again - pgtest.MustExec(t, app.GetSqlxDB(), "delete from sessions where email = $1", user.Email) + pgtest.MustExec(t, app.GetDB(), "delete from sessions where email = $1", user.Email) err = client.ListJobs(cli.NewContext(nil, fs, nil)) require.NoError(t, err) } diff --git a/core/internal/cltest/cltest.go b/core/internal/cltest/cltest.go index 8123439dafb..dc7079e44d9 100644 --- a/core/internal/cltest/cltest.go +++ b/core/internal/cltest/cltest.go @@ -181,11 +181,11 @@ type JobPipelineConfig interface { MaxSuccessfulRuns() uint64 } -func NewJobPipelineV2(t testing.TB, cfg pipeline.BridgeConfig, jpcfg JobPipelineConfig, dbCfg pg.QConfig, legacyChains legacyevm.LegacyChainContainer, db *sqlx.DB, keyStore keystore.Master, restrictedHTTPClient, unrestrictedHTTPClient *http.Client) JobPipelineV2TestHelper { +func NewJobPipelineV2(t testing.TB, cfg pipeline.BridgeConfig, jpcfg JobPipelineConfig, legacyChains legacyevm.LegacyChainContainer, db *sqlx.DB, keyStore keystore.Master, restrictedHTTPClient, unrestrictedHTTPClient *http.Client) JobPipelineV2TestHelper { lggr := logger.TestLogger(t) prm := pipeline.NewORM(db, lggr, jpcfg.MaxSuccessfulRuns()) btORM := bridges.NewORM(db) - jrm := job.NewORM(db, prm, btORM, keyStore, lggr, dbCfg) + jrm := job.NewORM(db, prm, btORM, keyStore, lggr) pr := pipeline.NewRunner(prm, btORM, jpcfg, cfg, legacyChains, keyStore.Eth(), keyStore.VRF(), lggr, restrictedHTTPClient, unrestrictedHTTPClient) return JobPipelineV2TestHelper{ prm, @@ -323,6 +323,8 @@ func NewApplicationWithConfig(t testing.TB, cfg chainlink.GeneralConfig, flagsAn require.NoError(t, err) t.Cleanup(func() { assert.NoError(t, db.Close()) }) + ds := sqlutil.WrapDataSource(db, lggr, sqlutil.TimeoutHook(cfg.Database().DefaultQueryTimeout)) + var ethClient evmclient.Client var externalInitiatorManager webhook.ExternalInitiatorManager externalInitiatorManager = &webhook.NullExternalInitiatorManager{} @@ -337,13 +339,13 @@ func NewApplicationWithConfig(t testing.TB, cfg chainlink.GeneralConfig, flagsAn default: switch flag { case UseRealExternalInitiatorManager: - externalInitiatorManager = webhook.NewExternalInitiatorManager(db, clhttptest.NewTestLocalOnlyHTTPClient(), lggr, cfg.Database()) + externalInitiatorManager = webhook.NewExternalInitiatorManager(ds, clhttptest.NewTestLocalOnlyHTTPClient()) } } } - keyStore := keystore.NewInMemory(db, utils.FastScryptParams, lggr) + keyStore := keystore.NewInMemory(ds, utils.FastScryptParams, lggr) mailMon := mailbox.NewMonitor(cfg.AppID().String(), lggr.Named("Mailbox")) loopRegistry := plugins.NewLoopRegistry(lggr, nil) @@ -365,8 +367,7 @@ func NewApplicationWithConfig(t testing.TB, cfg chainlink.GeneralConfig, flagsAn ChainOpts: legacyevm.ChainOpts{ AppConfig: cfg, MailMon: mailMon, - SqlxDB: db, - DB: db, + DS: ds, }, CSAETHKeystore: keyStore, } @@ -392,8 +393,7 @@ func NewApplicationWithConfig(t testing.TB, cfg chainlink.GeneralConfig, flagsAn cosmosCfg := chainlink.CosmosFactoryConfig{ Keystore: keyStore.Cosmos(), TOMLConfigs: cfg.CosmosConfigs(), - DB: db, - QConfig: cfg.Database(), + DS: ds, } initOps = append(initOps, chainlink.InitCosmos(testCtx, relayerFactory, cosmosCfg)) } @@ -420,8 +420,7 @@ func NewApplicationWithConfig(t testing.TB, cfg chainlink.GeneralConfig, flagsAn appInstance, err := chainlink.NewApplication(chainlink.ApplicationOpts{ Config: cfg, MailMon: mailMon, - SqlxDB: db, - DB: db, + DS: ds, KeyStore: keyStore, RelayerChainInteroperators: relayChainInterops, Logger: lggr, @@ -569,9 +568,10 @@ func (ta *TestApplication) Stop() error { } func (ta *TestApplication) MustSeedNewSession(email string) (id string) { + ctx := testutils.Context(ta.t) session := NewSession() ta.Logger.Infof("TestApplication creating session (id: %s, email: %s, last used: %s)", session.ID, email, session.LastUsed.String()) - err := ta.GetSqlxDB().Get(&id, `INSERT INTO sessions (id, email, last_used, created_at) VALUES ($1, $2, $3, NOW()) RETURNING id`, session.ID, email, session.LastUsed) + err := ta.GetDB().GetContext(ctx, &id, `INSERT INTO sessions (id, email, last_used, created_at) VALUES ($1, $2, $3, NOW()) RETURNING id`, session.ID, email, session.LastUsed) require.NoError(ta.t, err) return id } @@ -903,13 +903,14 @@ const ( // WaitForSpecErrorV2 polls until the passed in jobID has count number // of job spec errors. -func WaitForSpecErrorV2(t *testing.T, db *sqlx.DB, jobID int32, count int) []job.SpecError { +func WaitForSpecErrorV2(t *testing.T, ds sqlutil.DataSource, jobID int32, count int) []job.SpecError { t.Helper() + ctx := testutils.Context(t) g := gomega.NewWithT(t) var jse []job.SpecError g.Eventually(func() []job.SpecError { - err := db.Select(&jse, `SELECT * FROM job_spec_errors WHERE job_id = $1`, jobID) + err := ds.SelectContext(ctx, &jse, `SELECT * FROM job_spec_errors WHERE job_id = $1`, jobID) assert.NoError(t, err) return jse }, testutils.WaitTimeout(t), DBPollingInterval).Should(gomega.HaveLen(count)) @@ -930,7 +931,7 @@ func WaitForPipeline(t testing.TB, nodeID int, jobID int32, expectedPipelineRuns var pr []pipeline.Run gomega.NewWithT(t).Eventually(func() bool { - prs, _, err := jo.PipelineRuns(&jobID, 0, 1000) + prs, _, err := jo.PipelineRuns(testutils.Context(t), &jobID, 0, 1000) require.NoError(t, err) var matched []pipeline.Run @@ -964,13 +965,14 @@ func WaitForPipeline(t testing.TB, nodeID int, jobID int32, expectedPipelineRuns } // AssertPipelineRunsStays asserts that the number of pipeline runs for a particular job remains at the provided values -func AssertPipelineRunsStays(t testing.TB, pipelineSpecID int32, db *sqlx.DB, want int) []pipeline.Run { +func AssertPipelineRunsStays(t testing.TB, pipelineSpecID int32, db sqlutil.DataSource, want int) []pipeline.Run { t.Helper() + ctx := testutils.Context(t) g := gomega.NewWithT(t) var prs []pipeline.Run g.Consistently(func() []pipeline.Run { - err := db.Select(&prs, `SELECT * FROM pipeline_runs WHERE pipeline_spec_id = $1`, pipelineSpecID) + err := db.SelectContext(ctx, &prs, `SELECT * FROM pipeline_runs WHERE pipeline_spec_id = $1`, pipelineSpecID) assert.NoError(t, err) return prs }, AssertNoActionTimeout, DBPollingInterval).Should(gomega.HaveLen(want)) @@ -1161,11 +1163,12 @@ func NewSession(optionalSessionID ...string) clsessions.Session { return session } -func AllExternalInitiators(t testing.TB, db *sqlx.DB) []bridges.ExternalInitiator { +func AllExternalInitiators(t testing.TB, ds sqlutil.DataSource) []bridges.ExternalInitiator { t.Helper() + ctx := testutils.Context(t) var all []bridges.ExternalInitiator - err := db.Select(&all, `SELECT * FROM external_initiators`) + err := ds.SelectContext(ctx, &all, `SELECT * FROM external_initiators`) require.NoError(t, err) return all } @@ -1518,35 +1521,38 @@ func AssertCount(t *testing.T, ds sqlutil.DataSource, tableName string, expected testutils.AssertCount(t, ds, tableName, expected) } -func WaitForCount(t *testing.T, db *sqlx.DB, tableName string, want int64) { +func WaitForCount(t *testing.T, ds sqlutil.DataSource, tableName string, want int64) { t.Helper() + ctx := testutils.Context(t) g := gomega.NewWithT(t) var count int64 var err error g.Eventually(func() int64 { - err = db.Get(&count, fmt.Sprintf(`SELECT count(*) FROM %s;`, tableName)) + err = ds.GetContext(ctx, &count, fmt.Sprintf(`SELECT count(*) FROM %s;`, tableName)) assert.NoError(t, err) return count }, testutils.WaitTimeout(t), DBPollingInterval).Should(gomega.Equal(want)) } -func AssertCountStays(t testing.TB, db *sqlx.DB, tableName string, want int64) { +func AssertCountStays(t testing.TB, ds sqlutil.DataSource, tableName string, want int64) { t.Helper() + ctx := testutils.Context(t) g := gomega.NewWithT(t) var count int64 var err error g.Consistently(func() int64 { - err = db.Get(&count, fmt.Sprintf(`SELECT count(*) FROM %s`, tableName)) + err = ds.GetContext(ctx, &count, fmt.Sprintf(`SELECT count(*) FROM %s`, tableName)) assert.NoError(t, err) return count }, AssertNoActionTimeout, DBPollingInterval).Should(gomega.Equal(want)) } -func AssertRecordEventually(t *testing.T, db *sqlx.DB, model interface{}, stmt string, check func() bool) { +func AssertRecordEventually(t *testing.T, ds sqlutil.DataSource, model interface{}, stmt string, check func() bool) { t.Helper() + ctx := testutils.Context(t) g := gomega.NewWithT(t) g.Eventually(func() bool { - err := db.Get(model, stmt) + err := ds.GetContext(ctx, model, stmt) require.NoError(t, err, "unable to find record in DB") return check() }, testutils.WaitTimeout(t), DBPollingInterval).Should(gomega.BeTrue()) diff --git a/core/internal/cltest/factories.go b/core/internal/cltest/factories.go index 43cf902ca8a..cd2fa9d9f63 100644 --- a/core/internal/cltest/factories.go +++ b/core/internal/cltest/factories.go @@ -98,9 +98,9 @@ func NewBridgeType(t testing.TB, opts BridgeOpts) (*bridges.BridgeTypeAuthentica // MustCreateBridge creates a bridge // Be careful not to specify a name here unless you ABSOLUTELY need to // This is because name is a unique index and identical names used across transactional tests will lock/deadlock -func MustCreateBridge(t testing.TB, db *sqlx.DB, opts BridgeOpts) (bta *bridges.BridgeTypeAuthentication, bt *bridges.BridgeType) { +func MustCreateBridge(t testing.TB, ds sqlutil.DataSource, opts BridgeOpts) (bta *bridges.BridgeTypeAuthentication, bt *bridges.BridgeType) { bta, bt = NewBridgeType(t, opts) - orm := bridges.NewORM(db) + orm := bridges.NewORM(ds) err := orm.CreateBridgeType(testutils.Context(t), bt) require.NoError(t, err) return bta, bt @@ -317,9 +317,9 @@ func MustGenerateRandomKeyState(_ testing.TB) ethkey.State { return ethkey.State{Address: NewEIP55Address()} } -func MustInsertHead(t *testing.T, db sqlutil.DataSource, number int64) evmtypes.Head { +func MustInsertHead(t *testing.T, ds sqlutil.DataSource, number int64) evmtypes.Head { h := evmtypes.NewHead(big.NewInt(number), evmutils.NewHash(), evmutils.NewHash(), 0, ubig.New(&FixtureChainID)) - horm := headtracker.NewORM(FixtureChainID, db) + horm := headtracker.NewORM(FixtureChainID, ds) err := horm.IdempotentInsertHead(testutils.Context(t), &h) require.NoError(t, err) @@ -347,8 +347,8 @@ func MustInsertV2JobSpec(t *testing.T, db *sqlx.DB, transmitterAddress common.Ad PipelineSpecID: pipelineSpec.ID, } - jorm := job.NewORM(db, nil, nil, nil, logger.TestLogger(t), configtest.NewTestGeneralConfig(t).Database()) - err = jorm.InsertJob(&jb) + jorm := job.NewORM(db, nil, nil, nil, logger.TestLogger(t)) + err = jorm.InsertJob(testutils.Context(t), &jb) require.NoError(t, err) return jb } @@ -404,8 +404,8 @@ func MustInsertKeeperJob(t *testing.T, db *sqlx.DB, korm *keeper.ORM, from evmty tlg := logger.TestLogger(t) prm := pipeline.NewORM(db, tlg, cfg.JobPipeline().MaxSuccessfulRuns()) btORM := bridges.NewORM(db) - jrm := job.NewORM(db, prm, btORM, nil, tlg, cfg.Database()) - err = jrm.InsertJob(&jb) + jrm := job.NewORM(db, prm, btORM, nil, tlg) + err = jrm.InsertJob(testutils.Context(t), &jb) require.NoError(t, err) jb.PipelineSpec.JobID = jb.ID return jb diff --git a/core/internal/cltest/job_factories.go b/core/internal/cltest/job_factories.go index d78440838b2..6ba13f97726 100644 --- a/core/internal/cltest/job_factories.go +++ b/core/internal/cltest/job_factories.go @@ -7,8 +7,7 @@ import ( "github.com/google/uuid" "github.com/stretchr/testify/require" - "github.com/jmoiron/sqlx" - + "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" "github.com/smartcontractkit/chainlink/v2/core/bridges" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/types" "github.com/smartcontractkit/chainlink/v2/core/internal/testutils" @@ -43,30 +42,30 @@ func MinimalOCRNonBootstrapSpec(contractAddress, transmitterAddress types.EIP55A return fmt.Sprintf(minimalOCRNonBootstrapTemplate, contractAddress, peerID, transmitterAddress.Hex(), keyBundleID) } -func MustInsertWebhookSpec(t *testing.T, db *sqlx.DB) (job.Job, job.WebhookSpec) { +func MustInsertWebhookSpec(t *testing.T, ds sqlutil.DataSource) (job.Job, job.WebhookSpec) { ctx := testutils.Context(t) - jobORM, pipelineORM := getORMs(t, db) + jobORM, pipelineORM := getORMs(t, ds) webhookSpec := job.WebhookSpec{} - require.NoError(t, jobORM.InsertWebhookSpec(&webhookSpec)) + require.NoError(t, jobORM.InsertWebhookSpec(ctx, &webhookSpec)) pSpec := pipeline.Pipeline{} - pipelineSpecID, err := pipelineORM.CreateSpec(ctx, nil, pSpec, 0) + pipelineSpecID, err := pipelineORM.CreateSpec(ctx, pSpec, 0) require.NoError(t, err) createdJob := job.Job{WebhookSpecID: &webhookSpec.ID, WebhookSpec: &webhookSpec, SchemaVersion: 1, Type: "webhook", ExternalJobID: uuid.New(), PipelineSpecID: pipelineSpecID} - require.NoError(t, jobORM.InsertJob(&createdJob)) + require.NoError(t, jobORM.InsertJob(ctx, &createdJob)) return createdJob, webhookSpec } -func getORMs(t *testing.T, db *sqlx.DB) (jobORM job.ORM, pipelineORM pipeline.ORM) { +func getORMs(t *testing.T, ds sqlutil.DataSource) (jobORM job.ORM, pipelineORM pipeline.ORM) { config := configtest.NewTestGeneralConfig(t) - keyStore := NewKeyStore(t, db) + keyStore := NewKeyStore(t, ds) lggr := logger.TestLogger(t) - pipelineORM = pipeline.NewORM(db, lggr, config.JobPipeline().MaxSuccessfulRuns()) - bridgeORM := bridges.NewORM(db) - jobORM = job.NewORM(db, pipelineORM, bridgeORM, keyStore, lggr, config.Database()) + pipelineORM = pipeline.NewORM(ds, lggr, config.JobPipeline().MaxSuccessfulRuns()) + bridgeORM := bridges.NewORM(ds) + jobORM = job.NewORM(ds, pipelineORM, bridgeORM, keyStore, lggr) t.Cleanup(func() { jobORM.Close() }) return } diff --git a/core/internal/features/features_test.go b/core/internal/features/features_test.go index 75ff98d05be..26e7d5eae56 100644 --- a/core/internal/features/features_test.go +++ b/core/internal/features/features_test.go @@ -82,6 +82,7 @@ var oneETH = assets.Eth(*big.NewInt(1000000000000000000)) func TestIntegration_ExternalInitiatorV2(t *testing.T) { t.Parallel() + ctx := testutils.Context(t) ethClient := cltest.NewEthMocksWithStartupAssertions(t) cfg := configtest.NewGeneralConfig(t, func(c *chainlink.Config, s *chainlink.Secrets) { @@ -175,7 +176,7 @@ func TestIntegration_ExternalInitiatorV2(t *testing.T) { require.NoError(t, err) })) u, _ := url.Parse(bridgeServer.URL) - err := app.BridgeORM().CreateBridgeType(testutils.Context(t), &bridges.BridgeType{ + err := app.BridgeORM().CreateBridgeType(ctx, &bridges.BridgeType{ Name: bridges.BridgeName("substrate-adapter1"), URL: models.WebURL(*u), }) @@ -205,7 +206,7 @@ observationSource = """ """ `, jobUUID, eiName, cltest.MustJSONMarshal(t, eiSpec)) - _, err := webhook.ValidatedWebhookSpec(tomlSpec, app.GetExternalInitiatorManager()) + _, err := webhook.ValidatedWebhookSpec(ctx, tomlSpec, app.GetExternalInitiatorManager()) require.NoError(t, err) job := cltest.CreateJobViaWeb(t, app, []byte(cltest.MustJSONMarshal(t, web.CreateJobRequest{TOML: tomlSpec}))) jobID = job.ID @@ -227,7 +228,7 @@ observationSource = """ defer cleanup() cltest.AssertServerResponse(t, resp, 401) - cltest.AssertCountStays(t, app.GetSqlxDB(), "pipeline_runs", 0) + cltest.AssertCountStays(t, app.GetDB(), "pipeline_runs", 0) }) t.Run("calling webhook_spec with matching external_initiator_id works", func(t *testing.T) { @@ -236,9 +237,9 @@ observationSource = """ _ = cltest.CreateJobRunViaExternalInitiatorV2(t, app, jobUUID, *eia, cltest.MustJSONMarshal(t, eiRequest)) - pipelineORM := pipeline.NewORM(app.GetSqlxDB(), logger.TestLogger(t), cfg.JobPipeline().MaxSuccessfulRuns()) - bridgeORM := bridges.NewORM(app.GetSqlxDB()) - jobORM := job.NewORM(app.GetSqlxDB(), pipelineORM, bridgeORM, app.KeyStore, logger.TestLogger(t), cfg.Database()) + pipelineORM := pipeline.NewORM(app.GetDB(), logger.TestLogger(t), cfg.JobPipeline().MaxSuccessfulRuns()) + bridgeORM := bridges.NewORM(app.GetDB()) + jobORM := job.NewORM(app.GetDB(), pipelineORM, bridgeORM, app.KeyStore, logger.TestLogger(t)) runs := cltest.WaitForPipelineComplete(t, 0, jobID, 1, 2, jobORM, 5*time.Second, 300*time.Millisecond) require.Len(t, runs, 1) @@ -992,8 +993,9 @@ observationSource = """ return answer.String() }, testutils.WaitTimeout(t), cltest.DBPollingInterval).Should(gomega.Equal("20")) + ctx := testutils.Context(t) for _, app := range apps { - jobs, _, err := app.JobORM().FindJobs(0, 1000) + jobs, _, err := app.JobORM().FindJobs(ctx, 0, 1000) require.NoError(t, err) // No spec errors for _, j := range jobs { @@ -1220,8 +1222,9 @@ observationSource = """ return answer.String() }, testutils.WaitTimeout(t), cltest.DBPollingInterval).Should(gomega.Equal("20")) + ctx := testutils.Context(t) for _, app := range apps { - jobs, _, err := app.JobORM().FindJobs(0, 1000) + jobs, _, err := app.JobORM().FindJobs(ctx, 0, 1000) require.NoError(t, err) // No spec errors for _, j := range jobs { diff --git a/core/internal/features/ocr2/features_ocr2_test.go b/core/internal/features/ocr2/features_ocr2_test.go index 07e0fc21d9a..fab9d34b4b1 100644 --- a/core/internal/features/ocr2/features_ocr2_test.go +++ b/core/internal/features/ocr2/features_ocr2_test.go @@ -519,6 +519,7 @@ updateInterval = "1m" } }() + ctx := testutils.Context(t) for trial := 0; trial < 2; trial++ { var retVal int @@ -537,7 +538,7 @@ updateInterval = "1m" wg.Add(1) go func() { defer wg.Done() - completedRuns, err2 := apps[ic].JobORM().FindPipelineRunIDsByJobID(jids[ic], 0, 1000) + completedRuns, err2 := apps[ic].JobORM().FindPipelineRunIDsByJobID(ctx, jids[ic], 0, 1000) require.NoError(t, err2) // Want at least 2 runs so we see all the metadata. pr := cltest.WaitForPipelineComplete(t, ic, jids[ic], len(completedRuns)+2, 7, apps[ic].JobORM(), 2*time.Minute, 5*time.Second) @@ -558,7 +559,7 @@ updateInterval = "1m" }, 1*time.Minute, 200*time.Millisecond).Should(gomega.Equal(fmt.Sprintf("%d", 2*retVal))) for _, app := range apps { - jobs, _, err2 := app.JobORM().FindJobs(0, 1000) + jobs, _, err2 := app.JobORM().FindJobs(ctx, 0, 1000) require.NoError(t, err2) // No spec errors for _, j := range jobs { @@ -758,6 +759,7 @@ chainID = 1337 expectedMeta := map[string]struct{}{ "0": {}, "10": {}, "20": {}, "30": {}, } + ctx := testutils.Context(t) for i := 0; i < 4; i++ { s := i require.NoError(t, apps[i].Start(testutils.Context(t))) @@ -790,7 +792,7 @@ chainID = 1337 servers[s].Close() }) u, _ := url.Parse(servers[i].URL) - require.NoError(t, apps[i].BridgeORM().CreateBridgeType(testutils.Context(t), &bridges.BridgeType{ + require.NoError(t, apps[i].BridgeORM().CreateBridgeType(ctx, &bridges.BridgeType{ Name: bridges.BridgeName(fmt.Sprintf("bridge%d", i)), URL: models.WebURL(*u), })) @@ -882,7 +884,7 @@ updateInterval = "1m" }, 1*time.Minute, 200*time.Millisecond).Should(gomega.Equal("20")) for _, app := range apps { - jobs, _, err := app.JobORM().FindJobs(0, 1000) + jobs, _, err := app.JobORM().FindJobs(ctx, 0, 1000) require.NoError(t, err) // No spec errors for _, j := range jobs { diff --git a/core/internal/mocks/application.go b/core/internal/mocks/application.go index 2438eb302c0..c83b37a0e5d 100644 --- a/core/internal/mocks/application.go +++ b/core/internal/mocks/application.go @@ -35,8 +35,6 @@ import ( sqlutil "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" - sqlx "github.com/jmoiron/sqlx" - txmgr "github.com/smartcontractkit/chainlink/v2/core/chains/evm/txmgr" types "github.com/smartcontractkit/chainlink/v2/core/chains/evm/types" @@ -389,26 +387,6 @@ func (_m *Application) GetRelayers() chainlink.RelayerChainInteroperators { return r0 } -// GetSqlxDB provides a mock function with given fields: -func (_m *Application) GetSqlxDB() *sqlx.DB { - ret := _m.Called() - - if len(ret) == 0 { - panic("no return value specified for GetSqlxDB") - } - - var r0 *sqlx.DB - if rf, ok := ret.Get(0).(func() *sqlx.DB); ok { - r0 = rf() - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).(*sqlx.DB) - } - } - - return r0 -} - // GetWebAuthnConfiguration provides a mock function with given fields: func (_m *Application) GetWebAuthnConfiguration() sessions.WebAuthnConfiguration { ret := _m.Called() diff --git a/core/internal/testutils/evmtest/evmtest.go b/core/internal/testutils/evmtest/evmtest.go index eedc6275928..276dea2ac5d 100644 --- a/core/internal/testutils/evmtest/evmtest.go +++ b/core/internal/testutils/evmtest/evmtest.go @@ -9,13 +9,13 @@ import ( "testing" "github.com/ethereum/go-ethereum" - "github.com/jmoiron/sqlx" "github.com/pelletier/go-toml/v2" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "gopkg.in/guregu/null.v4" "github.com/smartcontractkit/chainlink-common/pkg/services/servicetest" + "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" "github.com/smartcontractkit/chainlink-common/pkg/types" "github.com/smartcontractkit/chainlink-common/pkg/utils/mailbox" "github.com/smartcontractkit/chainlink-common/pkg/utils/mailbox/mailboxtest" @@ -62,7 +62,7 @@ type TestChainOpts struct { LogPoller logpoller.LogPoller GeneralConfig legacyevm.AppConfig HeadTracker httypes.HeadTracker - DB *sqlx.DB + DB sqlutil.DataSource TxManager txmgr.TxManager KeyStore keystore.Eth MailMon *mailbox.Monitor @@ -88,8 +88,7 @@ func NewChainRelayExtOpts(t testing.TB, testopts TestChainOpts) legacyevm.ChainR AppConfig: testopts.GeneralConfig, MailMon: testopts.MailMon, GasEstimator: testopts.GasEstimator, - SqlxDB: testopts.DB, - DB: testopts.DB, + DS: testopts.DB, }, } opts.GenEthClient = func(*big.Int) evmclient.Client { diff --git a/core/internal/testutils/pgtest/pgtest.go b/core/internal/testutils/pgtest/pgtest.go index 686483f2d41..8464604b667 100644 --- a/core/internal/testutils/pgtest/pgtest.go +++ b/core/internal/testutils/pgtest/pgtest.go @@ -1,7 +1,6 @@ package pgtest import ( - "database/sql" "testing" "github.com/google/uuid" @@ -10,25 +9,12 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" "github.com/smartcontractkit/chainlink-common/pkg/utils" "github.com/smartcontractkit/chainlink/v2/core/internal/testutils" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" "github.com/smartcontractkit/chainlink/v2/core/store/dialects" ) -func NewQConfig(logSQL bool) pg.QConfig { - return pg.NewQConfig(logSQL) -} - -func NewSqlDB(t *testing.T) *sql.DB { - testutils.SkipShortDB(t) - db, err := sql.Open(string(dialects.TransactionWrappedPostgres), uuid.New().String()) - require.NoError(t, err) - t.Cleanup(func() { assert.NoError(t, db.Close()) }) - - return db -} - func NewSqlxDB(t testing.TB) *sqlx.DB { testutils.SkipShortDB(t) db, err := sqlx.Open(string(dialects.TransactionWrappedPostgres), uuid.New().String()) @@ -39,12 +25,9 @@ func NewSqlxDB(t testing.TB) *sqlx.DB { return db } -func MustExec(t *testing.T, db *sqlx.DB, stmt string, args ...interface{}) { - require.NoError(t, utils.JustError(db.Exec(stmt, args...))) -} - -func MustSelect(t *testing.T, db *sqlx.DB, dest interface{}, stmt string, args ...interface{}) { - require.NoError(t, db.Select(dest, stmt, args...)) +func MustExec(t *testing.T, ds sqlutil.DataSource, stmt string, args ...interface{}) { + ctx := testutils.Context(t) + require.NoError(t, utils.JustError(ds.ExecContext(ctx, stmt, args...))) } func MustCount(t *testing.T, db *sqlx.DB, stmt string, args ...interface{}) (cnt int) { diff --git a/core/internal/testutils/testutils.go b/core/internal/testutils/testutils.go index ba7e697fb62..f4867eda69a 100644 --- a/core/internal/testutils/testutils.go +++ b/core/internal/testutils/testutils.go @@ -420,8 +420,9 @@ func SkipShortDB(tb testing.TB) { func AssertCount(t *testing.T, ds sqlutil.DataSource, tableName string, expected int64) { t.Helper() + ctx := Context(t) var count int64 - err := ds.GetContext(Context(t), &count, fmt.Sprintf(`SELECT count(*) FROM %s;`, tableName)) + err := ds.GetContext(ctx, &count, fmt.Sprintf(`SELECT count(*) FROM %s;`, tableName)) require.NoError(t, err) require.Equal(t, expected, count) } diff --git a/core/scripts/gateway/run_gateway.go b/core/scripts/gateway/run_gateway.go index 5dbcd02bf56..2daca5190a5 100644 --- a/core/scripts/gateway/run_gateway.go +++ b/core/scripts/gateway/run_gateway.go @@ -48,7 +48,7 @@ func main() { lggr, _ := logger.NewLogger() - handlerFactory := gateway.NewHandlerFactory(nil, nil, nil, lggr) + handlerFactory := gateway.NewHandlerFactory(nil, nil, lggr) gw, err := gateway.NewGatewayFromConfig(&cfg, handlerFactory, lggr) if err != nil { fmt.Println("error creating Gateway object:", err) diff --git a/core/scripts/go.mod b/core/scripts/go.mod index 5be5d137696..fff795e8af6 100644 --- a/core/scripts/go.mod +++ b/core/scripts/go.mod @@ -21,7 +21,7 @@ require ( github.com/prometheus/client_golang v1.17.0 github.com/shopspring/decimal v1.3.1 github.com/smartcontractkit/chainlink-automation v1.0.3 - github.com/smartcontractkit/chainlink-common v0.1.7-0.20240419205832-845fa69af8d9 + github.com/smartcontractkit/chainlink-common v0.1.7-0.20240424104752-ed1756cf454c github.com/smartcontractkit/chainlink-vrf v0.0.0-20240222010609-cd67d123c772 github.com/smartcontractkit/chainlink/v2 v2.0.0-00010101000000-000000000000 github.com/smartcontractkit/libocr v0.0.0-20240419185742-fd3cab206b2c @@ -256,7 +256,7 @@ require ( github.com/shirou/gopsutil/v3 v3.24.3 // indirect github.com/sirupsen/logrus v1.9.3 // indirect github.com/smartcontractkit/chain-selectors v1.0.10 // indirect - github.com/smartcontractkit/chainlink-cosmos v0.4.1-0.20240419131812-73d148593d92 // indirect + github.com/smartcontractkit/chainlink-cosmos v0.4.1-0.20240419213354-ea34a948e2ee // indirect github.com/smartcontractkit/chainlink-data-streams v0.0.0-20240220203239-09be0ea34540 // indirect github.com/smartcontractkit/chainlink-feeds v0.0.0-20240422130241-13c17a91b2ab // indirect github.com/smartcontractkit/chainlink-solana v1.0.3-0.20240422172640-59d47c73ba58 // indirect diff --git a/core/scripts/go.sum b/core/scripts/go.sum index b26bad1299e..f66cfce2e07 100644 --- a/core/scripts/go.sum +++ b/core/scripts/go.sum @@ -1185,10 +1185,10 @@ github.com/smartcontractkit/chain-selectors v1.0.10 h1:t9kJeE6B6G+hKD0GYR4kGJSCq github.com/smartcontractkit/chain-selectors v1.0.10/go.mod h1:d4Hi+E1zqjy9HqMkjBE5q1vcG9VGgxf5VxiRHfzi2kE= github.com/smartcontractkit/chainlink-automation v1.0.3 h1:h/ijT0NiyV06VxYVgcNfsE3+8OEzT3Q0Z9au0z1BPWs= github.com/smartcontractkit/chainlink-automation v1.0.3/go.mod h1:RjboV0Qd7YP+To+OrzHGXaxUxoSONveCoAK2TQ1INLU= -github.com/smartcontractkit/chainlink-common v0.1.7-0.20240419205832-845fa69af8d9 h1:elDIBChe7ByPNvCyrSjMLTPKrgY+sKgzzlWe2p3wokY= -github.com/smartcontractkit/chainlink-common v0.1.7-0.20240419205832-845fa69af8d9/go.mod h1:GTDBbovHUSAUk+fuGIySF2A/whhdtHGaWmU61BoERks= -github.com/smartcontractkit/chainlink-cosmos v0.4.1-0.20240419131812-73d148593d92 h1:MvaNzuaQh1vX4CAYLM8qFd99cf0ZF1JNwtDZtLU7WvU= -github.com/smartcontractkit/chainlink-cosmos v0.4.1-0.20240419131812-73d148593d92/go.mod h1:uATrrJ8IsuBkOBJ46USuf73gz9gZy5k5bzGE5/ji/rc= +github.com/smartcontractkit/chainlink-common v0.1.7-0.20240424104752-ed1756cf454c h1:nk3g1il/cG0raV2ymNlytAPvjfYNSvwHP7Gfy6ItmSI= +github.com/smartcontractkit/chainlink-common v0.1.7-0.20240424104752-ed1756cf454c/go.mod h1:GTDBbovHUSAUk+fuGIySF2A/whhdtHGaWmU61BoERks= +github.com/smartcontractkit/chainlink-cosmos v0.4.1-0.20240419213354-ea34a948e2ee h1:eFuBKyEbL2b+eyfgV/Eu9+8HuCEev+IcBi+K9l1dG7g= +github.com/smartcontractkit/chainlink-cosmos v0.4.1-0.20240419213354-ea34a948e2ee/go.mod h1:uATrrJ8IsuBkOBJ46USuf73gz9gZy5k5bzGE5/ji/rc= github.com/smartcontractkit/chainlink-data-streams v0.0.0-20240220203239-09be0ea34540 h1:xFSv8561jsLtF6gYZr/zW2z5qUUAkcFkApin2mnbYTo= github.com/smartcontractkit/chainlink-data-streams v0.0.0-20240220203239-09be0ea34540/go.mod h1:sjAmX8K2kbQhvDarZE1ZZgDgmHJ50s0BBc/66vKY2ek= github.com/smartcontractkit/chainlink-feeds v0.0.0-20240422130241-13c17a91b2ab h1:Ct1oUlyn03HDUVdFHJqtRGRUujMqdoMzvf/Cjhe30Ag= diff --git a/core/services/chainlink/application.go b/core/services/chainlink/application.go index 88a6fadf345..0ac6555aecc 100644 --- a/core/services/chainlink/application.go +++ b/core/services/chainlink/application.go @@ -16,8 +16,6 @@ import ( "go.uber.org/multierr" "go.uber.org/zap/zapcore" - "github.com/jmoiron/sqlx" - "github.com/smartcontractkit/chainlink-common/pkg/loop" commonservices "github.com/smartcontractkit/chainlink-common/pkg/services" "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" @@ -53,7 +51,6 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/services/ocrcommon" externalp2p "github.com/smartcontractkit/chainlink/v2/core/services/p2p/wrapper" "github.com/smartcontractkit/chainlink/v2/core/services/periodicbackup" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" "github.com/smartcontractkit/chainlink/v2/core/services/pipeline" "github.com/smartcontractkit/chainlink/v2/core/services/promreporter" "github.com/smartcontractkit/chainlink/v2/core/services/relay/evm/mercury" @@ -78,7 +75,6 @@ type Application interface { GetLogger() logger.SugaredLogger GetAuditLogger() audit.AuditLogger GetHealthChecker() services.Checker - GetSqlxDB() *sqlx.DB // Deprecated: use GetDB GetDB() sqlutil.DataSource GetConfig() GeneralConfig SetLogLevel(lvl zapcore.Level) error @@ -146,8 +142,7 @@ type ChainlinkApplication struct { logger logger.SugaredLogger AuditLogger audit.AuditLogger closeLogger func() error - sqlxDB *sqlx.DB // Deprecated: use db instead - db sqlutil.DataSource + ds sqlutil.DataSource secretGenerator SecretGenerator profiler *pyroscope.Profiler loopRegistry *plugins.LoopRegistry @@ -161,8 +156,7 @@ type ApplicationOpts struct { Config GeneralConfig Logger logger.Logger MailMon *mailbox.Monitor - SqlxDB *sqlx.DB // Deprecated: use DB instead - DB sqlutil.DataSource + DS sqlutil.DataSource KeyStore keystore.Master RelayerChainInteroperators *CoreRelayerChainInteroperators AuditLogger audit.AuditLogger @@ -185,7 +179,6 @@ type ApplicationOpts struct { func NewApplication(opts ApplicationOpts) (Application, error) { var srvcs []services.ServiceCtx auditLogger := opts.AuditLogger - sqlxDB := opts.SqlxDB cfg := opts.Config relayerChainInterops := opts.RelayerChainInteroperators mailMon := opts.MailMon @@ -277,12 +270,12 @@ func NewApplication(opts ApplicationOpts) (Application, error) { srvcs = append(srvcs, mailMon) srvcs = append(srvcs, relayerChainInterops.Services()...) - promReporter := promreporter.NewPromReporter(sqlxDB.DB, legacyEVMChains, globalLogger) + promReporter := promreporter.NewPromReporter(opts.DS, legacyEVMChains, globalLogger) srvcs = append(srvcs, promReporter) // Initialize Local Users ORM and Authentication Provider specified in config // BasicAdminUsersORM is initialized and required regardless of separate Authentication Provider - localAdminUsersORM := localauth.NewORM(opts.DB, cfg.WebServer().SessionTimeout().Duration(), globalLogger, auditLogger) + localAdminUsersORM := localauth.NewORM(opts.DS, cfg.WebServer().SessionTimeout().Duration(), globalLogger, auditLogger) // Initialize Sessions ORM based on environment configured authenticator // localDB auth or remote LDAP auth @@ -294,26 +287,26 @@ func NewApplication(opts ApplicationOpts) (Application, error) { case sessions.LDAPAuth: var err error authenticationProvider, err = ldapauth.NewLDAPAuthenticator( - opts.DB, cfg.WebServer().LDAP(), cfg.Insecure().DevWebServer(), globalLogger, auditLogger, + opts.DS, cfg.WebServer().LDAP(), cfg.Insecure().DevWebServer(), globalLogger, auditLogger, ) if err != nil { return nil, errors.Wrap(err, "NewApplication: failed to initialize LDAP Authentication module") } - sessionReaper = ldapauth.NewLDAPServerStateSync(sqlxDB, cfg.Database(), cfg.WebServer().LDAP(), globalLogger) + sessionReaper = ldapauth.NewLDAPServerStateSync(opts.DS, cfg.WebServer().LDAP(), globalLogger) case sessions.LocalAuth: - authenticationProvider = localauth.NewORM(opts.DB, cfg.WebServer().SessionTimeout().Duration(), globalLogger, auditLogger) - sessionReaper = localauth.NewSessionReaper(sqlxDB.DB, cfg.WebServer(), globalLogger) + authenticationProvider = localauth.NewORM(opts.DS, cfg.WebServer().SessionTimeout().Duration(), globalLogger, auditLogger) + sessionReaper = localauth.NewSessionReaper(opts.DS, cfg.WebServer(), globalLogger) default: return nil, errors.Errorf("NewApplication: Unexpected 'AuthenticationMethod': %s supported values: %s, %s", authMethod, sessions.LocalAuth, sessions.LDAPAuth) } var ( - pipelineORM = pipeline.NewORM(sqlxDB, globalLogger, cfg.JobPipeline().MaxSuccessfulRuns()) - bridgeORM = bridges.NewORM(sqlxDB) - mercuryORM = mercury.NewORM(opts.DB) + pipelineORM = pipeline.NewORM(opts.DS, globalLogger, cfg.JobPipeline().MaxSuccessfulRuns()) + bridgeORM = bridges.NewORM(opts.DS) + mercuryORM = mercury.NewORM(opts.DS) pipelineRunner = pipeline.NewRunner(pipelineORM, bridgeORM, cfg.JobPipeline(), cfg.WebServer(), legacyEVMChains, keyStore.Eth(), keyStore.VRF(), globalLogger, restrictedHTTPClient, unrestrictedHTTPClient) - jobORM = job.NewORM(sqlxDB, pipelineORM, bridgeORM, keyStore, globalLogger, cfg.Database()) - txmORM = txmgr.NewTxStore(opts.DB, globalLogger) + jobORM = job.NewORM(opts.DS, pipelineORM, bridgeORM, keyStore, globalLogger) + txmORM = txmgr.NewTxStore(opts.DS, globalLogger) streamRegistry = streams.NewRegistry(globalLogger, pipelineRunner) ) @@ -334,14 +327,14 @@ func NewApplication(opts ApplicationOpts) (Application, error) { mailMon), job.Keeper: keeper.NewDelegate( cfg, - sqlxDB, + opts.DS, jobORM, pipelineRunner, globalLogger, legacyEVMChains, mailMon), job.VRF: vrf.NewDelegate( - sqlxDB, + opts.DS, keyStore, pipelineRunner, pipelineORM, @@ -368,8 +361,7 @@ func NewApplication(opts ApplicationOpts) (Application, error) { job.Gateway: gateway.NewDelegate( legacyEVMChains, keyStore.Eth(), - sqlxDB, - cfg.Database(), + opts.DS, globalLogger), job.Stream: streams.NewDelegate( globalLogger, @@ -395,7 +387,7 @@ func NewApplication(opts ApplicationOpts) (Application, error) { jobORM, pipelineORM, pipelineRunner, - sqlxDB, + opts.DS, legacyEVMChains, globalLogger, ) @@ -408,7 +400,7 @@ func NewApplication(opts ApplicationOpts) (Application, error) { if err := ocrcommon.ValidatePeerWrapperConfig(cfg.P2P()); err != nil { return nil, err } - peerWrapper = ocrcommon.NewSingletonPeerWrapper(keyStore, cfg.P2P(), cfg.OCR(), cfg.Database(), sqlxDB, globalLogger) + peerWrapper = ocrcommon.NewSingletonPeerWrapper(keyStore, cfg.P2P(), cfg.OCR(), opts.DS, globalLogger) srvcs = append(srvcs, peerWrapper) } else { return nil, fmt.Errorf("P2P stack required for OCR or OCR2") @@ -416,7 +408,7 @@ func NewApplication(opts ApplicationOpts) (Application, error) { if cfg.OCR().Enabled() { delegates[job.OffchainReporting] = ocr.NewDelegate( - sqlxDB, + opts.DS, jobORM, keyStore, pipelineRunner, @@ -436,11 +428,10 @@ func NewApplication(opts ApplicationOpts) (Application, error) { if cfg.OCR2().Enabled() { globalLogger.Debug("Off-chain reporting v2 enabled") - ocr2DelegateConfig := ocr2.NewDelegateConfig(cfg.OCR2(), cfg.Mercury(), cfg.Threshold(), cfg.Insecure(), cfg.JobPipeline(), cfg.Database(), loopRegistrarConfig) + ocr2DelegateConfig := ocr2.NewDelegateConfig(cfg.OCR2(), cfg.Mercury(), cfg.Threshold(), cfg.Insecure(), cfg.JobPipeline(), loopRegistrarConfig) delegates[job.OffchainReporting2] = ocr2.NewDelegate( - sqlxDB, - opts.DB, + opts.DS, jobORM, bridgeORM, mercuryORM, @@ -460,7 +451,7 @@ func NewApplication(opts ApplicationOpts) (Application, error) { registry, ) delegates[job.Bootstrap] = ocrbootstrap.NewDelegateBootstrap( - sqlxDB, + opts.DS, jobORM, peerWrapper, globalLogger, @@ -478,7 +469,7 @@ func NewApplication(opts ApplicationOpts) (Application, error) { for _, c := range legacyEVMChains.Slice() { lbs = append(lbs, c.LogBroadcaster()) } - jobSpawner := job.NewSpawner(jobORM, cfg.Database(), healthChecker, delegates, sqlxDB, globalLogger, lbs) + jobSpawner := job.NewSpawner(jobORM, cfg.Database(), healthChecker, delegates, globalLogger, lbs) srvcs = append(srvcs, jobSpawner, pipelineRunner) // We start the log poller after the job spawner @@ -491,11 +482,11 @@ func NewApplication(opts ApplicationOpts) (Application, error) { var feedsService feeds.Service if cfg.Feature().FeedsManager() { - feedsORM := feeds.NewORM(sqlxDB, opts.Logger, cfg.Database()) + feedsORM := feeds.NewORM(opts.DS) feedsService = feeds.NewService( feedsORM, jobORM, - sqlxDB, + opts.DS, jobSpawner, keyStore, cfg, @@ -503,7 +494,6 @@ func NewApplication(opts ApplicationOpts) (Application, error) { cfg.JobPipeline(), cfg.OCR(), cfg.OCR2(), - cfg.Database(), legacyEVMChains, globalLogger, opts.Version, @@ -548,8 +538,7 @@ func NewApplication(opts ApplicationOpts) (Application, error) { loopRegistry: loopRegistry, loopRegistrarConfig: loopRegistrarConfig, - sqlxDB: opts.SqlxDB, - db: opts.DB, + ds: opts.DS, // NOTE: Can keep things clean by putting more things in srvcs instead of manually start/closing srvcs: srvcs, @@ -742,7 +731,7 @@ func (app *ChainlinkApplication) WakeSessionReaper() { } func (app *ChainlinkApplication) AddJobV2(ctx context.Context, j *job.Job) error { - return app.jobSpawner.CreateJob(j, pg.WithParentCtx(ctx)) + return app.jobSpawner.CreateJob(ctx, nil, j) } func (app *ChainlinkApplication) DeleteJob(ctx context.Context, jobID int32) error { @@ -756,7 +745,7 @@ func (app *ChainlinkApplication) DeleteJob(ctx context.Context, jobID int32) err return errors.New("job must be deleted in the feeds manager") } - return app.jobSpawner.DeleteJob(jobID, pg.WithParentCtx(ctx)) + return app.jobSpawner.DeleteJob(ctx, nil, jobID) } func (app *ChainlinkApplication) RunWebhookJobV2(ctx context.Context, jobUUID uuid.UUID, requestBody string, meta jsonserializable.JSONSerializable) (int64, error) { @@ -857,12 +846,8 @@ func (app *ChainlinkApplication) GetRelayers() RelayerChainInteroperators { return app.relayers } -func (app *ChainlinkApplication) GetSqlxDB() *sqlx.DB { - return app.sqlxDB -} - func (app *ChainlinkApplication) GetDB() sqlutil.DataSource { - return app.db + return app.ds } // Returns the configuration to use for creating and authenticating diff --git a/core/services/chainlink/relayer_chain_interoperators_test.go b/core/services/chainlink/relayer_chain_interoperators_test.go index 7126c73927c..8111c1f61b4 100644 --- a/core/services/chainlink/relayer_chain_interoperators_test.go +++ b/core/services/chainlink/relayer_chain_interoperators_test.go @@ -212,8 +212,7 @@ func TestCoreRelayerChainInteroperators(t *testing.T) { ChainOpts: legacyevm.ChainOpts{ AppConfig: cfg, MailMon: &mailbox.Monitor{}, - SqlxDB: db, - DB: db, + DS: db, }, CSAETHKeystore: keyStore, }), @@ -265,8 +264,8 @@ func TestCoreRelayerChainInteroperators(t *testing.T) { chainlink.InitCosmos(testctx, factory, chainlink.CosmosFactoryConfig{ Keystore: keyStore.Cosmos(), TOMLConfigs: cfg.CosmosConfigs(), - DB: db, - QConfig: cfg.Database()}), + DS: db, + }), }, expectedCosmosChainCnt: 2, expectedCosmosNodeCnt: 2, @@ -287,8 +286,7 @@ func TestCoreRelayerChainInteroperators(t *testing.T) { AppConfig: cfg, MailMon: &mailbox.Monitor{}, - SqlxDB: db, - DB: db, + DS: db, }, CSAETHKeystore: keyStore, }), @@ -298,8 +296,7 @@ func TestCoreRelayerChainInteroperators(t *testing.T) { chainlink.InitCosmos(testctx, factory, chainlink.CosmosFactoryConfig{ Keystore: keyStore.Cosmos(), TOMLConfigs: cfg.CosmosConfigs(), - DB: db, - QConfig: cfg.Database(), + DS: db, }), }, expectedEVMChainCnt: 2, diff --git a/core/services/chainlink/relayer_factory.go b/core/services/chainlink/relayer_factory.go index 5902555f79c..00db81cce37 100644 --- a/core/services/chainlink/relayer_factory.go +++ b/core/services/chainlink/relayer_factory.go @@ -5,10 +5,10 @@ import ( "errors" "fmt" - "github.com/jmoiron/sqlx" "github.com/pelletier/go-toml/v2" "github.com/smartcontractkit/chainlink-common/pkg/loop" + "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" "github.com/smartcontractkit/chainlink-common/pkg/types" "github.com/smartcontractkit/chainlink-cosmos/pkg/cosmos" coscfg "github.com/smartcontractkit/chainlink-cosmos/pkg/cosmos/config" @@ -22,7 +22,6 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/config/env" "github.com/smartcontractkit/chainlink/v2/core/logger" "github.com/smartcontractkit/chainlink/v2/core/services/keystore" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" "github.com/smartcontractkit/chainlink/v2/core/services/relay" evmrelay "github.com/smartcontractkit/chainlink/v2/core/services/relay/evm" "github.com/smartcontractkit/chainlink/v2/core/services/relay/evm/mercury/wsrpc" @@ -68,9 +67,7 @@ func (r *RelayerFactory) NewEVM(ctx context.Context, config EVMFactoryConfig) (m } relayerOpts := evmrelay.RelayerOpts{ - DB: ccOpts.SqlxDB, - DS: ccOpts.DB, - QConfig: ccOpts.AppConfig.Database(), + DS: ccOpts.DS, CSAETHKeystore: config.CSAETHKeystore, MercuryPool: r.MercuryPool, } @@ -239,8 +236,7 @@ func (r *RelayerFactory) NewStarkNet(ks keystore.StarkNet, chainCfgs config.TOML type CosmosFactoryConfig struct { Keystore keystore.Cosmos coscfg.TOMLConfigs - *sqlx.DB - pg.QConfig + DS sqlutil.DataSource } func (c CosmosFactoryConfig) Validate() error { @@ -251,11 +247,8 @@ func (c CosmosFactoryConfig) Validate() error { if len(c.TOMLConfigs) == 0 { err = errors.Join(err, fmt.Errorf("no CosmosConfigs provided")) } - if c.DB == nil { - err = errors.Join(err, fmt.Errorf("nil DB")) - } - if c.QConfig == nil { - err = errors.Join(err, fmt.Errorf("nil QConfig")) + if c.DS == nil { + err = errors.Join(err, fmt.Errorf("nil DataStore")) } if err != nil { @@ -284,7 +277,7 @@ func (r *RelayerFactory) NewCosmos(config CosmosFactoryConfig) (map[types.RelayI opts := cosmos.ChainOpts{ Logger: lggr, - DB: config.DB, + DS: config.DS, KeyStore: loopKs, } diff --git a/core/services/cron/cron_test.go b/core/services/cron/cron_test.go index b31a06a9591..38684e982d9 100644 --- a/core/services/cron/cron_test.go +++ b/core/services/cron/cron_test.go @@ -29,7 +29,7 @@ func TestCronV2Pipeline(t *testing.T) { lggr := logger.TestLogger(t) orm := pipeline.NewORM(db, lggr, cfg.JobPipeline().MaxSuccessfulRuns()) btORM := bridges.NewORM(db) - jobORM := job.NewORM(db, orm, btORM, keyStore, lggr, cfg.Database()) + jobORM := job.NewORM(db, orm, btORM, keyStore, lggr) jb := &job.Job{ Type: job.Cron, @@ -40,7 +40,7 @@ func TestCronV2Pipeline(t *testing.T) { } delegate := cron.NewDelegate(runner, lggr) - require.NoError(t, jobORM.CreateJob(jb)) + require.NoError(t, jobORM.CreateJob(testutils.Context(t), jb)) serviceArray, err := delegate.ServicesForSpec(testutils.Context(t), *jb) require.NoError(t, err) assert.Len(t, serviceArray, 1) diff --git a/core/services/directrequest/delegate.go b/core/services/directrequest/delegate.go index 33a0a7e73da..26f2c5f9c84 100644 --- a/core/services/directrequest/delegate.go +++ b/core/services/directrequest/delegate.go @@ -215,36 +215,37 @@ func (l *listener) HandleLog(ctx context.Context, lb log.Broadcast) { } func (l *listener) processOracleRequests() { + ctx, cancel := l.chStop.NewCtx() + defer cancel() for { select { case <-l.chStop: l.shutdownWaitGroup.Done() return case <-l.mbOracleRequests.Notify(): - l.handleReceivedLogs(l.mbOracleRequests) + l.handleReceivedLogs(ctx, l.mbOracleRequests) } } } func (l *listener) processCancelOracleRequests() { + ctx, cancel := l.chStop.NewCtx() + defer cancel() for { select { case <-l.chStop: l.shutdownWaitGroup.Done() return case <-l.mbOracleCancelRequests.Notify(): - l.handleReceivedLogs(l.mbOracleCancelRequests) + l.handleReceivedLogs(ctx, l.mbOracleCancelRequests) } } } -func (l *listener) handleReceivedLogs(mailbox *mailbox.Mailbox[log.Broadcast]) { - ctx, cancel := l.chStop.NewCtx() - defer cancel() - +func (l *listener) handleReceivedLogs(ctx context.Context, mailbox *mailbox.Mailbox[log.Broadcast]) { for { select { - case <-l.chStop: + case <-ctx.Done(): return default: } @@ -263,7 +264,7 @@ func (l *listener) handleReceivedLogs(mailbox *mailbox.Mailbox[log.Broadcast]) { logJobSpecID := lb.RawLog().Topics[1] if logJobSpecID == (common.Hash{}) || (logJobSpecID != l.job.ExternalIDEncodeStringToTopic() && logJobSpecID != l.job.ExternalIDEncodeBytesToTopic()) { l.logger.Debugw("Skipping Run for Log with wrong Job ID", "logJobSpecID", logJobSpecID) - l.markLogConsumed(ctx, lb) + l.markLogConsumed(ctx, nil, lb) continue } @@ -277,7 +278,7 @@ func (l *listener) handleReceivedLogs(mailbox *mailbox.Mailbox[log.Broadcast]) { case *operator_wrapper.OperatorOracleRequest: l.handleOracleRequest(ctx, log, lb) case *operator_wrapper.OperatorCancelOracleRequest: - l.handleCancelOracleRequest(ctx, log, lb) + l.handleCancelOracleRequest(ctx, nil, log, lb) default: l.logger.Warnf("Unexpected log type %T", log) } @@ -316,7 +317,7 @@ func (l *listener) handleOracleRequest(ctx context.Context, request *operator_wr "requester", request.Requester, "allowedRequesters", l.requesters.ToStrings(), ) - l.markLogConsumed(ctx, lb) + l.markLogConsumed(ctx, nil, lb) return } @@ -333,7 +334,7 @@ func (l *listener) handleOracleRequest(ctx context.Context, request *operator_wr "minContractPayment", minContractPayment.String(), "requestPayment", requestPayment.String(), ) - l.markLogConsumed(ctx, lb) + l.markLogConsumed(ctx, nil, lb) return } } @@ -375,7 +376,7 @@ func (l *listener) handleOracleRequest(ctx context.Context, request *operator_wr }) run := pipeline.NewRun(*l.job.PipelineSpec, vars) _, err := l.pipelineRunner.Run(ctx, run, l.logger, true, func(tx sqlutil.DataSource) error { - l.markLogConsumed(ctx, lb) + l.markLogConsumed(ctx, tx, lb) return nil }) if ctx.Err() != nil { @@ -398,16 +399,16 @@ func (l *listener) allowRequester(requester common.Address) bool { } // Cancels runs that haven't been started yet, with the given request ID -func (l *listener) handleCancelOracleRequest(ctx context.Context, request *operator_wrapper.OperatorCancelOracleRequest, lb log.Broadcast) { +func (l *listener) handleCancelOracleRequest(ctx context.Context, ds sqlutil.DataSource, request *operator_wrapper.OperatorCancelOracleRequest, lb log.Broadcast) { runCloserChannelIf, loaded := l.runs.LoadAndDelete(formatRequestId(request.RequestId)) if loaded { close(runCloserChannelIf.(services.StopChan)) } - l.markLogConsumed(ctx, lb) + l.markLogConsumed(ctx, ds, lb) } -func (l *listener) markLogConsumed(ctx context.Context, lb log.Broadcast) { - if err := l.logBroadcaster.MarkConsumed(ctx, nil, lb); err != nil { +func (l *listener) markLogConsumed(ctx context.Context, ds sqlutil.DataSource, lb log.Broadcast) { + if err := l.logBroadcaster.MarkConsumed(ctx, ds, lb); err != nil { l.logger.Errorw("Unable to mark log consumed", "err", err, "log", lb.String()) } } diff --git a/core/services/directrequest/delegate_test.go b/core/services/directrequest/delegate_test.go index 08f38865bec..e754713b010 100644 --- a/core/services/directrequest/delegate_test.go +++ b/core/services/directrequest/delegate_test.go @@ -90,7 +90,7 @@ func NewDirectRequestUniverseWithConfig(t *testing.T, cfg chainlink.GeneralConfi lggr := logger.TestLogger(t) orm := pipeline.NewORM(db, lggr, cfg.JobPipeline().MaxSuccessfulRuns()) btORM := bridges.NewORM(db) - jobORM := job.NewORM(db, orm, btORM, keyStore, lggr, cfg.Database()) + jobORM := job.NewORM(db, orm, btORM, keyStore, lggr) legacyChains := evmrelay.NewLegacyChainsFromRelayerExtenders(relayExtenders) delegate := directrequest.NewDelegate(lggr, runner, orm, legacyChains, mailMon) @@ -99,8 +99,9 @@ func NewDirectRequestUniverseWithConfig(t *testing.T, cfg chainlink.GeneralConfi if specF != nil { specF(jb) } - require.NoError(t, jobORM.CreateJob(jb)) - serviceArray, err := delegate.ServicesForSpec(testutils.Context(t), *jb) + ctx := testutils.Context(t) + require.NoError(t, jobORM.CreateJob(ctx, jb)) + serviceArray, err := delegate.ServicesForSpec(ctx, *jb) require.NoError(t, err) assert.Len(t, serviceArray, 1) service := serviceArray[0] diff --git a/core/services/feeds/mocks/orm.go b/core/services/feeds/mocks/orm.go index f84d80a6eb1..625a0b41d9a 100644 --- a/core/services/feeds/mocks/orm.go +++ b/core/services/feeds/mocks/orm.go @@ -3,10 +3,12 @@ package mocks import ( + context "context" + feeds "github.com/smartcontractkit/chainlink/v2/core/services/feeds" mock "github.com/stretchr/testify/mock" - pg "github.com/smartcontractkit/chainlink/v2/core/services/pg" + sqlutil "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" uuid "github.com/google/uuid" ) @@ -24,24 +26,17 @@ func (_m *ORM) EXPECT() *ORM_Expecter { return &ORM_Expecter{mock: &_m.Mock} } -// ApproveSpec provides a mock function with given fields: id, externalJobID, qopts -func (_m *ORM) ApproveSpec(id int64, externalJobID uuid.UUID, qopts ...pg.QOpt) error { - _va := make([]interface{}, len(qopts)) - for _i := range qopts { - _va[_i] = qopts[_i] - } - var _ca []interface{} - _ca = append(_ca, id, externalJobID) - _ca = append(_ca, _va...) - ret := _m.Called(_ca...) +// ApproveSpec provides a mock function with given fields: ctx, id, externalJobID +func (_m *ORM) ApproveSpec(ctx context.Context, id int64, externalJobID uuid.UUID) error { + ret := _m.Called(ctx, id, externalJobID) if len(ret) == 0 { panic("no return value specified for ApproveSpec") } var r0 error - if rf, ok := ret.Get(0).(func(int64, uuid.UUID, ...pg.QOpt) error); ok { - r0 = rf(id, externalJobID, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, int64, uuid.UUID) error); ok { + r0 = rf(ctx, id, externalJobID) } else { r0 = ret.Error(0) } @@ -55,23 +50,16 @@ type ORM_ApproveSpec_Call struct { } // ApproveSpec is a helper method to define mock.On call +// - ctx context.Context // - id int64 // - externalJobID uuid.UUID -// - qopts ...pg.QOpt -func (_e *ORM_Expecter) ApproveSpec(id interface{}, externalJobID interface{}, qopts ...interface{}) *ORM_ApproveSpec_Call { - return &ORM_ApproveSpec_Call{Call: _e.mock.On("ApproveSpec", - append([]interface{}{id, externalJobID}, qopts...)...)} +func (_e *ORM_Expecter) ApproveSpec(ctx interface{}, id interface{}, externalJobID interface{}) *ORM_ApproveSpec_Call { + return &ORM_ApproveSpec_Call{Call: _e.mock.On("ApproveSpec", ctx, id, externalJobID)} } -func (_c *ORM_ApproveSpec_Call) Run(run func(id int64, externalJobID uuid.UUID, qopts ...pg.QOpt)) *ORM_ApproveSpec_Call { +func (_c *ORM_ApproveSpec_Call) Run(run func(ctx context.Context, id int64, externalJobID uuid.UUID)) *ORM_ApproveSpec_Call { _c.Call.Run(func(args mock.Arguments) { - variadicArgs := make([]pg.QOpt, len(args)-2) - for i, a := range args[2:] { - if a != nil { - variadicArgs[i] = a.(pg.QOpt) - } - } - run(args[0].(int64), args[1].(uuid.UUID), variadicArgs...) + run(args[0].(context.Context), args[1].(int64), args[2].(uuid.UUID)) }) return _c } @@ -81,29 +69,22 @@ func (_c *ORM_ApproveSpec_Call) Return(_a0 error) *ORM_ApproveSpec_Call { return _c } -func (_c *ORM_ApproveSpec_Call) RunAndReturn(run func(int64, uuid.UUID, ...pg.QOpt) error) *ORM_ApproveSpec_Call { +func (_c *ORM_ApproveSpec_Call) RunAndReturn(run func(context.Context, int64, uuid.UUID) error) *ORM_ApproveSpec_Call { _c.Call.Return(run) return _c } -// CancelSpec provides a mock function with given fields: id, qopts -func (_m *ORM) CancelSpec(id int64, qopts ...pg.QOpt) error { - _va := make([]interface{}, len(qopts)) - for _i := range qopts { - _va[_i] = qopts[_i] - } - var _ca []interface{} - _ca = append(_ca, id) - _ca = append(_ca, _va...) - ret := _m.Called(_ca...) +// CancelSpec provides a mock function with given fields: ctx, id +func (_m *ORM) CancelSpec(ctx context.Context, id int64) error { + ret := _m.Called(ctx, id) if len(ret) == 0 { panic("no return value specified for CancelSpec") } var r0 error - if rf, ok := ret.Get(0).(func(int64, ...pg.QOpt) error); ok { - r0 = rf(id, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, int64) error); ok { + r0 = rf(ctx, id) } else { r0 = ret.Error(0) } @@ -117,22 +98,15 @@ type ORM_CancelSpec_Call struct { } // CancelSpec is a helper method to define mock.On call +// - ctx context.Context // - id int64 -// - qopts ...pg.QOpt -func (_e *ORM_Expecter) CancelSpec(id interface{}, qopts ...interface{}) *ORM_CancelSpec_Call { - return &ORM_CancelSpec_Call{Call: _e.mock.On("CancelSpec", - append([]interface{}{id}, qopts...)...)} +func (_e *ORM_Expecter) CancelSpec(ctx interface{}, id interface{}) *ORM_CancelSpec_Call { + return &ORM_CancelSpec_Call{Call: _e.mock.On("CancelSpec", ctx, id)} } -func (_c *ORM_CancelSpec_Call) Run(run func(id int64, qopts ...pg.QOpt)) *ORM_CancelSpec_Call { +func (_c *ORM_CancelSpec_Call) Run(run func(ctx context.Context, id int64)) *ORM_CancelSpec_Call { _c.Call.Run(func(args mock.Arguments) { - variadicArgs := make([]pg.QOpt, len(args)-1) - for i, a := range args[1:] { - if a != nil { - variadicArgs[i] = a.(pg.QOpt) - } - } - run(args[0].(int64), variadicArgs...) + run(args[0].(context.Context), args[1].(int64)) }) return _c } @@ -142,14 +116,14 @@ func (_c *ORM_CancelSpec_Call) Return(_a0 error) *ORM_CancelSpec_Call { return _c } -func (_c *ORM_CancelSpec_Call) RunAndReturn(run func(int64, ...pg.QOpt) error) *ORM_CancelSpec_Call { +func (_c *ORM_CancelSpec_Call) RunAndReturn(run func(context.Context, int64) error) *ORM_CancelSpec_Call { _c.Call.Return(run) return _c } -// CountJobProposals provides a mock function with given fields: -func (_m *ORM) CountJobProposals() (int64, error) { - ret := _m.Called() +// CountJobProposals provides a mock function with given fields: ctx +func (_m *ORM) CountJobProposals(ctx context.Context) (int64, error) { + ret := _m.Called(ctx) if len(ret) == 0 { panic("no return value specified for CountJobProposals") @@ -157,17 +131,17 @@ func (_m *ORM) CountJobProposals() (int64, error) { var r0 int64 var r1 error - if rf, ok := ret.Get(0).(func() (int64, error)); ok { - return rf() + if rf, ok := ret.Get(0).(func(context.Context) (int64, error)); ok { + return rf(ctx) } - if rf, ok := ret.Get(0).(func() int64); ok { - r0 = rf() + if rf, ok := ret.Get(0).(func(context.Context) int64); ok { + r0 = rf(ctx) } else { r0 = ret.Get(0).(int64) } - if rf, ok := ret.Get(1).(func() error); ok { - r1 = rf() + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(ctx) } else { r1 = ret.Error(1) } @@ -181,13 +155,14 @@ type ORM_CountJobProposals_Call struct { } // CountJobProposals is a helper method to define mock.On call -func (_e *ORM_Expecter) CountJobProposals() *ORM_CountJobProposals_Call { - return &ORM_CountJobProposals_Call{Call: _e.mock.On("CountJobProposals")} +// - ctx context.Context +func (_e *ORM_Expecter) CountJobProposals(ctx interface{}) *ORM_CountJobProposals_Call { + return &ORM_CountJobProposals_Call{Call: _e.mock.On("CountJobProposals", ctx)} } -func (_c *ORM_CountJobProposals_Call) Run(run func()) *ORM_CountJobProposals_Call { +func (_c *ORM_CountJobProposals_Call) Run(run func(ctx context.Context)) *ORM_CountJobProposals_Call { _c.Call.Run(func(args mock.Arguments) { - run() + run(args[0].(context.Context)) }) return _c } @@ -197,14 +172,14 @@ func (_c *ORM_CountJobProposals_Call) Return(_a0 int64, _a1 error) *ORM_CountJob return _c } -func (_c *ORM_CountJobProposals_Call) RunAndReturn(run func() (int64, error)) *ORM_CountJobProposals_Call { +func (_c *ORM_CountJobProposals_Call) RunAndReturn(run func(context.Context) (int64, error)) *ORM_CountJobProposals_Call { _c.Call.Return(run) return _c } -// CountJobProposalsByStatus provides a mock function with given fields: -func (_m *ORM) CountJobProposalsByStatus() (*feeds.JobProposalCounts, error) { - ret := _m.Called() +// CountJobProposalsByStatus provides a mock function with given fields: ctx +func (_m *ORM) CountJobProposalsByStatus(ctx context.Context) (*feeds.JobProposalCounts, error) { + ret := _m.Called(ctx) if len(ret) == 0 { panic("no return value specified for CountJobProposalsByStatus") @@ -212,19 +187,19 @@ func (_m *ORM) CountJobProposalsByStatus() (*feeds.JobProposalCounts, error) { var r0 *feeds.JobProposalCounts var r1 error - if rf, ok := ret.Get(0).(func() (*feeds.JobProposalCounts, error)); ok { - return rf() + if rf, ok := ret.Get(0).(func(context.Context) (*feeds.JobProposalCounts, error)); ok { + return rf(ctx) } - if rf, ok := ret.Get(0).(func() *feeds.JobProposalCounts); ok { - r0 = rf() + if rf, ok := ret.Get(0).(func(context.Context) *feeds.JobProposalCounts); ok { + r0 = rf(ctx) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*feeds.JobProposalCounts) } } - if rf, ok := ret.Get(1).(func() error); ok { - r1 = rf() + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(ctx) } else { r1 = ret.Error(1) } @@ -238,13 +213,14 @@ type ORM_CountJobProposalsByStatus_Call struct { } // CountJobProposalsByStatus is a helper method to define mock.On call -func (_e *ORM_Expecter) CountJobProposalsByStatus() *ORM_CountJobProposalsByStatus_Call { - return &ORM_CountJobProposalsByStatus_Call{Call: _e.mock.On("CountJobProposalsByStatus")} +// - ctx context.Context +func (_e *ORM_Expecter) CountJobProposalsByStatus(ctx interface{}) *ORM_CountJobProposalsByStatus_Call { + return &ORM_CountJobProposalsByStatus_Call{Call: _e.mock.On("CountJobProposalsByStatus", ctx)} } -func (_c *ORM_CountJobProposalsByStatus_Call) Run(run func()) *ORM_CountJobProposalsByStatus_Call { +func (_c *ORM_CountJobProposalsByStatus_Call) Run(run func(ctx context.Context)) *ORM_CountJobProposalsByStatus_Call { _c.Call.Run(func(args mock.Arguments) { - run() + run(args[0].(context.Context)) }) return _c } @@ -254,14 +230,14 @@ func (_c *ORM_CountJobProposalsByStatus_Call) Return(counts *feeds.JobProposalCo return _c } -func (_c *ORM_CountJobProposalsByStatus_Call) RunAndReturn(run func() (*feeds.JobProposalCounts, error)) *ORM_CountJobProposalsByStatus_Call { +func (_c *ORM_CountJobProposalsByStatus_Call) RunAndReturn(run func(context.Context) (*feeds.JobProposalCounts, error)) *ORM_CountJobProposalsByStatus_Call { _c.Call.Return(run) return _c } -// CountManagers provides a mock function with given fields: -func (_m *ORM) CountManagers() (int64, error) { - ret := _m.Called() +// CountManagers provides a mock function with given fields: ctx +func (_m *ORM) CountManagers(ctx context.Context) (int64, error) { + ret := _m.Called(ctx) if len(ret) == 0 { panic("no return value specified for CountManagers") @@ -269,17 +245,17 @@ func (_m *ORM) CountManagers() (int64, error) { var r0 int64 var r1 error - if rf, ok := ret.Get(0).(func() (int64, error)); ok { - return rf() + if rf, ok := ret.Get(0).(func(context.Context) (int64, error)); ok { + return rf(ctx) } - if rf, ok := ret.Get(0).(func() int64); ok { - r0 = rf() + if rf, ok := ret.Get(0).(func(context.Context) int64); ok { + r0 = rf(ctx) } else { r0 = ret.Get(0).(int64) } - if rf, ok := ret.Get(1).(func() error); ok { - r1 = rf() + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(ctx) } else { r1 = ret.Error(1) } @@ -293,13 +269,14 @@ type ORM_CountManagers_Call struct { } // CountManagers is a helper method to define mock.On call -func (_e *ORM_Expecter) CountManagers() *ORM_CountManagers_Call { - return &ORM_CountManagers_Call{Call: _e.mock.On("CountManagers")} +// - ctx context.Context +func (_e *ORM_Expecter) CountManagers(ctx interface{}) *ORM_CountManagers_Call { + return &ORM_CountManagers_Call{Call: _e.mock.On("CountManagers", ctx)} } -func (_c *ORM_CountManagers_Call) Run(run func()) *ORM_CountManagers_Call { +func (_c *ORM_CountManagers_Call) Run(run func(ctx context.Context)) *ORM_CountManagers_Call { _c.Call.Run(func(args mock.Arguments) { - run() + run(args[0].(context.Context)) }) return _c } @@ -309,21 +286,14 @@ func (_c *ORM_CountManagers_Call) Return(_a0 int64, _a1 error) *ORM_CountManager return _c } -func (_c *ORM_CountManagers_Call) RunAndReturn(run func() (int64, error)) *ORM_CountManagers_Call { +func (_c *ORM_CountManagers_Call) RunAndReturn(run func(context.Context) (int64, error)) *ORM_CountManagers_Call { _c.Call.Return(run) return _c } -// CreateBatchChainConfig provides a mock function with given fields: cfgs, qopts -func (_m *ORM) CreateBatchChainConfig(cfgs []feeds.ChainConfig, qopts ...pg.QOpt) ([]int64, error) { - _va := make([]interface{}, len(qopts)) - for _i := range qopts { - _va[_i] = qopts[_i] - } - var _ca []interface{} - _ca = append(_ca, cfgs) - _ca = append(_ca, _va...) - ret := _m.Called(_ca...) +// CreateBatchChainConfig provides a mock function with given fields: ctx, cfgs +func (_m *ORM) CreateBatchChainConfig(ctx context.Context, cfgs []feeds.ChainConfig) ([]int64, error) { + ret := _m.Called(ctx, cfgs) if len(ret) == 0 { panic("no return value specified for CreateBatchChainConfig") @@ -331,19 +301,19 @@ func (_m *ORM) CreateBatchChainConfig(cfgs []feeds.ChainConfig, qopts ...pg.QOpt var r0 []int64 var r1 error - if rf, ok := ret.Get(0).(func([]feeds.ChainConfig, ...pg.QOpt) ([]int64, error)); ok { - return rf(cfgs, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, []feeds.ChainConfig) ([]int64, error)); ok { + return rf(ctx, cfgs) } - if rf, ok := ret.Get(0).(func([]feeds.ChainConfig, ...pg.QOpt) []int64); ok { - r0 = rf(cfgs, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, []feeds.ChainConfig) []int64); ok { + r0 = rf(ctx, cfgs) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]int64) } } - if rf, ok := ret.Get(1).(func([]feeds.ChainConfig, ...pg.QOpt) error); ok { - r1 = rf(cfgs, qopts...) + if rf, ok := ret.Get(1).(func(context.Context, []feeds.ChainConfig) error); ok { + r1 = rf(ctx, cfgs) } else { r1 = ret.Error(1) } @@ -357,22 +327,15 @@ type ORM_CreateBatchChainConfig_Call struct { } // CreateBatchChainConfig is a helper method to define mock.On call +// - ctx context.Context // - cfgs []feeds.ChainConfig -// - qopts ...pg.QOpt -func (_e *ORM_Expecter) CreateBatchChainConfig(cfgs interface{}, qopts ...interface{}) *ORM_CreateBatchChainConfig_Call { - return &ORM_CreateBatchChainConfig_Call{Call: _e.mock.On("CreateBatchChainConfig", - append([]interface{}{cfgs}, qopts...)...)} +func (_e *ORM_Expecter) CreateBatchChainConfig(ctx interface{}, cfgs interface{}) *ORM_CreateBatchChainConfig_Call { + return &ORM_CreateBatchChainConfig_Call{Call: _e.mock.On("CreateBatchChainConfig", ctx, cfgs)} } -func (_c *ORM_CreateBatchChainConfig_Call) Run(run func(cfgs []feeds.ChainConfig, qopts ...pg.QOpt)) *ORM_CreateBatchChainConfig_Call { +func (_c *ORM_CreateBatchChainConfig_Call) Run(run func(ctx context.Context, cfgs []feeds.ChainConfig)) *ORM_CreateBatchChainConfig_Call { _c.Call.Run(func(args mock.Arguments) { - variadicArgs := make([]pg.QOpt, len(args)-1) - for i, a := range args[1:] { - if a != nil { - variadicArgs[i] = a.(pg.QOpt) - } - } - run(args[0].([]feeds.ChainConfig), variadicArgs...) + run(args[0].(context.Context), args[1].([]feeds.ChainConfig)) }) return _c } @@ -382,21 +345,14 @@ func (_c *ORM_CreateBatchChainConfig_Call) Return(_a0 []int64, _a1 error) *ORM_C return _c } -func (_c *ORM_CreateBatchChainConfig_Call) RunAndReturn(run func([]feeds.ChainConfig, ...pg.QOpt) ([]int64, error)) *ORM_CreateBatchChainConfig_Call { +func (_c *ORM_CreateBatchChainConfig_Call) RunAndReturn(run func(context.Context, []feeds.ChainConfig) ([]int64, error)) *ORM_CreateBatchChainConfig_Call { _c.Call.Return(run) return _c } -// CreateChainConfig provides a mock function with given fields: cfg, qopts -func (_m *ORM) CreateChainConfig(cfg feeds.ChainConfig, qopts ...pg.QOpt) (int64, error) { - _va := make([]interface{}, len(qopts)) - for _i := range qopts { - _va[_i] = qopts[_i] - } - var _ca []interface{} - _ca = append(_ca, cfg) - _ca = append(_ca, _va...) - ret := _m.Called(_ca...) +// CreateChainConfig provides a mock function with given fields: ctx, cfg +func (_m *ORM) CreateChainConfig(ctx context.Context, cfg feeds.ChainConfig) (int64, error) { + ret := _m.Called(ctx, cfg) if len(ret) == 0 { panic("no return value specified for CreateChainConfig") @@ -404,17 +360,17 @@ func (_m *ORM) CreateChainConfig(cfg feeds.ChainConfig, qopts ...pg.QOpt) (int64 var r0 int64 var r1 error - if rf, ok := ret.Get(0).(func(feeds.ChainConfig, ...pg.QOpt) (int64, error)); ok { - return rf(cfg, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, feeds.ChainConfig) (int64, error)); ok { + return rf(ctx, cfg) } - if rf, ok := ret.Get(0).(func(feeds.ChainConfig, ...pg.QOpt) int64); ok { - r0 = rf(cfg, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, feeds.ChainConfig) int64); ok { + r0 = rf(ctx, cfg) } else { r0 = ret.Get(0).(int64) } - if rf, ok := ret.Get(1).(func(feeds.ChainConfig, ...pg.QOpt) error); ok { - r1 = rf(cfg, qopts...) + if rf, ok := ret.Get(1).(func(context.Context, feeds.ChainConfig) error); ok { + r1 = rf(ctx, cfg) } else { r1 = ret.Error(1) } @@ -428,22 +384,15 @@ type ORM_CreateChainConfig_Call struct { } // CreateChainConfig is a helper method to define mock.On call +// - ctx context.Context // - cfg feeds.ChainConfig -// - qopts ...pg.QOpt -func (_e *ORM_Expecter) CreateChainConfig(cfg interface{}, qopts ...interface{}) *ORM_CreateChainConfig_Call { - return &ORM_CreateChainConfig_Call{Call: _e.mock.On("CreateChainConfig", - append([]interface{}{cfg}, qopts...)...)} +func (_e *ORM_Expecter) CreateChainConfig(ctx interface{}, cfg interface{}) *ORM_CreateChainConfig_Call { + return &ORM_CreateChainConfig_Call{Call: _e.mock.On("CreateChainConfig", ctx, cfg)} } -func (_c *ORM_CreateChainConfig_Call) Run(run func(cfg feeds.ChainConfig, qopts ...pg.QOpt)) *ORM_CreateChainConfig_Call { +func (_c *ORM_CreateChainConfig_Call) Run(run func(ctx context.Context, cfg feeds.ChainConfig)) *ORM_CreateChainConfig_Call { _c.Call.Run(func(args mock.Arguments) { - variadicArgs := make([]pg.QOpt, len(args)-1) - for i, a := range args[1:] { - if a != nil { - variadicArgs[i] = a.(pg.QOpt) - } - } - run(args[0].(feeds.ChainConfig), variadicArgs...) + run(args[0].(context.Context), args[1].(feeds.ChainConfig)) }) return _c } @@ -453,14 +402,14 @@ func (_c *ORM_CreateChainConfig_Call) Return(_a0 int64, _a1 error) *ORM_CreateCh return _c } -func (_c *ORM_CreateChainConfig_Call) RunAndReturn(run func(feeds.ChainConfig, ...pg.QOpt) (int64, error)) *ORM_CreateChainConfig_Call { +func (_c *ORM_CreateChainConfig_Call) RunAndReturn(run func(context.Context, feeds.ChainConfig) (int64, error)) *ORM_CreateChainConfig_Call { _c.Call.Return(run) return _c } -// CreateJobProposal provides a mock function with given fields: jp -func (_m *ORM) CreateJobProposal(jp *feeds.JobProposal) (int64, error) { - ret := _m.Called(jp) +// CreateJobProposal provides a mock function with given fields: ctx, jp +func (_m *ORM) CreateJobProposal(ctx context.Context, jp *feeds.JobProposal) (int64, error) { + ret := _m.Called(ctx, jp) if len(ret) == 0 { panic("no return value specified for CreateJobProposal") @@ -468,17 +417,17 @@ func (_m *ORM) CreateJobProposal(jp *feeds.JobProposal) (int64, error) { var r0 int64 var r1 error - if rf, ok := ret.Get(0).(func(*feeds.JobProposal) (int64, error)); ok { - return rf(jp) + if rf, ok := ret.Get(0).(func(context.Context, *feeds.JobProposal) (int64, error)); ok { + return rf(ctx, jp) } - if rf, ok := ret.Get(0).(func(*feeds.JobProposal) int64); ok { - r0 = rf(jp) + if rf, ok := ret.Get(0).(func(context.Context, *feeds.JobProposal) int64); ok { + r0 = rf(ctx, jp) } else { r0 = ret.Get(0).(int64) } - if rf, ok := ret.Get(1).(func(*feeds.JobProposal) error); ok { - r1 = rf(jp) + if rf, ok := ret.Get(1).(func(context.Context, *feeds.JobProposal) error); ok { + r1 = rf(ctx, jp) } else { r1 = ret.Error(1) } @@ -492,14 +441,15 @@ type ORM_CreateJobProposal_Call struct { } // CreateJobProposal is a helper method to define mock.On call +// - ctx context.Context // - jp *feeds.JobProposal -func (_e *ORM_Expecter) CreateJobProposal(jp interface{}) *ORM_CreateJobProposal_Call { - return &ORM_CreateJobProposal_Call{Call: _e.mock.On("CreateJobProposal", jp)} +func (_e *ORM_Expecter) CreateJobProposal(ctx interface{}, jp interface{}) *ORM_CreateJobProposal_Call { + return &ORM_CreateJobProposal_Call{Call: _e.mock.On("CreateJobProposal", ctx, jp)} } -func (_c *ORM_CreateJobProposal_Call) Run(run func(jp *feeds.JobProposal)) *ORM_CreateJobProposal_Call { +func (_c *ORM_CreateJobProposal_Call) Run(run func(ctx context.Context, jp *feeds.JobProposal)) *ORM_CreateJobProposal_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(*feeds.JobProposal)) + run(args[0].(context.Context), args[1].(*feeds.JobProposal)) }) return _c } @@ -509,21 +459,14 @@ func (_c *ORM_CreateJobProposal_Call) Return(_a0 int64, _a1 error) *ORM_CreateJo return _c } -func (_c *ORM_CreateJobProposal_Call) RunAndReturn(run func(*feeds.JobProposal) (int64, error)) *ORM_CreateJobProposal_Call { +func (_c *ORM_CreateJobProposal_Call) RunAndReturn(run func(context.Context, *feeds.JobProposal) (int64, error)) *ORM_CreateJobProposal_Call { _c.Call.Return(run) return _c } -// CreateManager provides a mock function with given fields: ms, qopts -func (_m *ORM) CreateManager(ms *feeds.FeedsManager, qopts ...pg.QOpt) (int64, error) { - _va := make([]interface{}, len(qopts)) - for _i := range qopts { - _va[_i] = qopts[_i] - } - var _ca []interface{} - _ca = append(_ca, ms) - _ca = append(_ca, _va...) - ret := _m.Called(_ca...) +// CreateManager provides a mock function with given fields: ctx, ms +func (_m *ORM) CreateManager(ctx context.Context, ms *feeds.FeedsManager) (int64, error) { + ret := _m.Called(ctx, ms) if len(ret) == 0 { panic("no return value specified for CreateManager") @@ -531,17 +474,17 @@ func (_m *ORM) CreateManager(ms *feeds.FeedsManager, qopts ...pg.QOpt) (int64, e var r0 int64 var r1 error - if rf, ok := ret.Get(0).(func(*feeds.FeedsManager, ...pg.QOpt) (int64, error)); ok { - return rf(ms, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, *feeds.FeedsManager) (int64, error)); ok { + return rf(ctx, ms) } - if rf, ok := ret.Get(0).(func(*feeds.FeedsManager, ...pg.QOpt) int64); ok { - r0 = rf(ms, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, *feeds.FeedsManager) int64); ok { + r0 = rf(ctx, ms) } else { r0 = ret.Get(0).(int64) } - if rf, ok := ret.Get(1).(func(*feeds.FeedsManager, ...pg.QOpt) error); ok { - r1 = rf(ms, qopts...) + if rf, ok := ret.Get(1).(func(context.Context, *feeds.FeedsManager) error); ok { + r1 = rf(ctx, ms) } else { r1 = ret.Error(1) } @@ -555,22 +498,15 @@ type ORM_CreateManager_Call struct { } // CreateManager is a helper method to define mock.On call +// - ctx context.Context // - ms *feeds.FeedsManager -// - qopts ...pg.QOpt -func (_e *ORM_Expecter) CreateManager(ms interface{}, qopts ...interface{}) *ORM_CreateManager_Call { - return &ORM_CreateManager_Call{Call: _e.mock.On("CreateManager", - append([]interface{}{ms}, qopts...)...)} +func (_e *ORM_Expecter) CreateManager(ctx interface{}, ms interface{}) *ORM_CreateManager_Call { + return &ORM_CreateManager_Call{Call: _e.mock.On("CreateManager", ctx, ms)} } -func (_c *ORM_CreateManager_Call) Run(run func(ms *feeds.FeedsManager, qopts ...pg.QOpt)) *ORM_CreateManager_Call { +func (_c *ORM_CreateManager_Call) Run(run func(ctx context.Context, ms *feeds.FeedsManager)) *ORM_CreateManager_Call { _c.Call.Run(func(args mock.Arguments) { - variadicArgs := make([]pg.QOpt, len(args)-1) - for i, a := range args[1:] { - if a != nil { - variadicArgs[i] = a.(pg.QOpt) - } - } - run(args[0].(*feeds.FeedsManager), variadicArgs...) + run(args[0].(context.Context), args[1].(*feeds.FeedsManager)) }) return _c } @@ -580,21 +516,14 @@ func (_c *ORM_CreateManager_Call) Return(_a0 int64, _a1 error) *ORM_CreateManage return _c } -func (_c *ORM_CreateManager_Call) RunAndReturn(run func(*feeds.FeedsManager, ...pg.QOpt) (int64, error)) *ORM_CreateManager_Call { +func (_c *ORM_CreateManager_Call) RunAndReturn(run func(context.Context, *feeds.FeedsManager) (int64, error)) *ORM_CreateManager_Call { _c.Call.Return(run) return _c } -// CreateSpec provides a mock function with given fields: spec, qopts -func (_m *ORM) CreateSpec(spec feeds.JobProposalSpec, qopts ...pg.QOpt) (int64, error) { - _va := make([]interface{}, len(qopts)) - for _i := range qopts { - _va[_i] = qopts[_i] - } - var _ca []interface{} - _ca = append(_ca, spec) - _ca = append(_ca, _va...) - ret := _m.Called(_ca...) +// CreateSpec provides a mock function with given fields: ctx, spec +func (_m *ORM) CreateSpec(ctx context.Context, spec feeds.JobProposalSpec) (int64, error) { + ret := _m.Called(ctx, spec) if len(ret) == 0 { panic("no return value specified for CreateSpec") @@ -602,17 +531,17 @@ func (_m *ORM) CreateSpec(spec feeds.JobProposalSpec, qopts ...pg.QOpt) (int64, var r0 int64 var r1 error - if rf, ok := ret.Get(0).(func(feeds.JobProposalSpec, ...pg.QOpt) (int64, error)); ok { - return rf(spec, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, feeds.JobProposalSpec) (int64, error)); ok { + return rf(ctx, spec) } - if rf, ok := ret.Get(0).(func(feeds.JobProposalSpec, ...pg.QOpt) int64); ok { - r0 = rf(spec, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, feeds.JobProposalSpec) int64); ok { + r0 = rf(ctx, spec) } else { r0 = ret.Get(0).(int64) } - if rf, ok := ret.Get(1).(func(feeds.JobProposalSpec, ...pg.QOpt) error); ok { - r1 = rf(spec, qopts...) + if rf, ok := ret.Get(1).(func(context.Context, feeds.JobProposalSpec) error); ok { + r1 = rf(ctx, spec) } else { r1 = ret.Error(1) } @@ -626,22 +555,15 @@ type ORM_CreateSpec_Call struct { } // CreateSpec is a helper method to define mock.On call +// - ctx context.Context // - spec feeds.JobProposalSpec -// - qopts ...pg.QOpt -func (_e *ORM_Expecter) CreateSpec(spec interface{}, qopts ...interface{}) *ORM_CreateSpec_Call { - return &ORM_CreateSpec_Call{Call: _e.mock.On("CreateSpec", - append([]interface{}{spec}, qopts...)...)} +func (_e *ORM_Expecter) CreateSpec(ctx interface{}, spec interface{}) *ORM_CreateSpec_Call { + return &ORM_CreateSpec_Call{Call: _e.mock.On("CreateSpec", ctx, spec)} } -func (_c *ORM_CreateSpec_Call) Run(run func(spec feeds.JobProposalSpec, qopts ...pg.QOpt)) *ORM_CreateSpec_Call { +func (_c *ORM_CreateSpec_Call) Run(run func(ctx context.Context, spec feeds.JobProposalSpec)) *ORM_CreateSpec_Call { _c.Call.Run(func(args mock.Arguments) { - variadicArgs := make([]pg.QOpt, len(args)-1) - for i, a := range args[1:] { - if a != nil { - variadicArgs[i] = a.(pg.QOpt) - } - } - run(args[0].(feeds.JobProposalSpec), variadicArgs...) + run(args[0].(context.Context), args[1].(feeds.JobProposalSpec)) }) return _c } @@ -651,14 +573,14 @@ func (_c *ORM_CreateSpec_Call) Return(_a0 int64, _a1 error) *ORM_CreateSpec_Call return _c } -func (_c *ORM_CreateSpec_Call) RunAndReturn(run func(feeds.JobProposalSpec, ...pg.QOpt) (int64, error)) *ORM_CreateSpec_Call { +func (_c *ORM_CreateSpec_Call) RunAndReturn(run func(context.Context, feeds.JobProposalSpec) (int64, error)) *ORM_CreateSpec_Call { _c.Call.Return(run) return _c } -// DeleteChainConfig provides a mock function with given fields: id -func (_m *ORM) DeleteChainConfig(id int64) (int64, error) { - ret := _m.Called(id) +// DeleteChainConfig provides a mock function with given fields: ctx, id +func (_m *ORM) DeleteChainConfig(ctx context.Context, id int64) (int64, error) { + ret := _m.Called(ctx, id) if len(ret) == 0 { panic("no return value specified for DeleteChainConfig") @@ -666,17 +588,17 @@ func (_m *ORM) DeleteChainConfig(id int64) (int64, error) { var r0 int64 var r1 error - if rf, ok := ret.Get(0).(func(int64) (int64, error)); ok { - return rf(id) + if rf, ok := ret.Get(0).(func(context.Context, int64) (int64, error)); ok { + return rf(ctx, id) } - if rf, ok := ret.Get(0).(func(int64) int64); ok { - r0 = rf(id) + if rf, ok := ret.Get(0).(func(context.Context, int64) int64); ok { + r0 = rf(ctx, id) } else { r0 = ret.Get(0).(int64) } - if rf, ok := ret.Get(1).(func(int64) error); ok { - r1 = rf(id) + if rf, ok := ret.Get(1).(func(context.Context, int64) error); ok { + r1 = rf(ctx, id) } else { r1 = ret.Error(1) } @@ -690,14 +612,15 @@ type ORM_DeleteChainConfig_Call struct { } // DeleteChainConfig is a helper method to define mock.On call +// - ctx context.Context // - id int64 -func (_e *ORM_Expecter) DeleteChainConfig(id interface{}) *ORM_DeleteChainConfig_Call { - return &ORM_DeleteChainConfig_Call{Call: _e.mock.On("DeleteChainConfig", id)} +func (_e *ORM_Expecter) DeleteChainConfig(ctx interface{}, id interface{}) *ORM_DeleteChainConfig_Call { + return &ORM_DeleteChainConfig_Call{Call: _e.mock.On("DeleteChainConfig", ctx, id)} } -func (_c *ORM_DeleteChainConfig_Call) Run(run func(id int64)) *ORM_DeleteChainConfig_Call { +func (_c *ORM_DeleteChainConfig_Call) Run(run func(ctx context.Context, id int64)) *ORM_DeleteChainConfig_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(int64)) + run(args[0].(context.Context), args[1].(int64)) }) return _c } @@ -707,29 +630,22 @@ func (_c *ORM_DeleteChainConfig_Call) Return(_a0 int64, _a1 error) *ORM_DeleteCh return _c } -func (_c *ORM_DeleteChainConfig_Call) RunAndReturn(run func(int64) (int64, error)) *ORM_DeleteChainConfig_Call { +func (_c *ORM_DeleteChainConfig_Call) RunAndReturn(run func(context.Context, int64) (int64, error)) *ORM_DeleteChainConfig_Call { _c.Call.Return(run) return _c } -// DeleteProposal provides a mock function with given fields: id, qopts -func (_m *ORM) DeleteProposal(id int64, qopts ...pg.QOpt) error { - _va := make([]interface{}, len(qopts)) - for _i := range qopts { - _va[_i] = qopts[_i] - } - var _ca []interface{} - _ca = append(_ca, id) - _ca = append(_ca, _va...) - ret := _m.Called(_ca...) +// DeleteProposal provides a mock function with given fields: ctx, id +func (_m *ORM) DeleteProposal(ctx context.Context, id int64) error { + ret := _m.Called(ctx, id) if len(ret) == 0 { panic("no return value specified for DeleteProposal") } var r0 error - if rf, ok := ret.Get(0).(func(int64, ...pg.QOpt) error); ok { - r0 = rf(id, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, int64) error); ok { + r0 = rf(ctx, id) } else { r0 = ret.Error(0) } @@ -743,22 +659,15 @@ type ORM_DeleteProposal_Call struct { } // DeleteProposal is a helper method to define mock.On call +// - ctx context.Context // - id int64 -// - qopts ...pg.QOpt -func (_e *ORM_Expecter) DeleteProposal(id interface{}, qopts ...interface{}) *ORM_DeleteProposal_Call { - return &ORM_DeleteProposal_Call{Call: _e.mock.On("DeleteProposal", - append([]interface{}{id}, qopts...)...)} +func (_e *ORM_Expecter) DeleteProposal(ctx interface{}, id interface{}) *ORM_DeleteProposal_Call { + return &ORM_DeleteProposal_Call{Call: _e.mock.On("DeleteProposal", ctx, id)} } -func (_c *ORM_DeleteProposal_Call) Run(run func(id int64, qopts ...pg.QOpt)) *ORM_DeleteProposal_Call { +func (_c *ORM_DeleteProposal_Call) Run(run func(ctx context.Context, id int64)) *ORM_DeleteProposal_Call { _c.Call.Run(func(args mock.Arguments) { - variadicArgs := make([]pg.QOpt, len(args)-1) - for i, a := range args[1:] { - if a != nil { - variadicArgs[i] = a.(pg.QOpt) - } - } - run(args[0].(int64), variadicArgs...) + run(args[0].(context.Context), args[1].(int64)) }) return _c } @@ -768,21 +677,14 @@ func (_c *ORM_DeleteProposal_Call) Return(_a0 error) *ORM_DeleteProposal_Call { return _c } -func (_c *ORM_DeleteProposal_Call) RunAndReturn(run func(int64, ...pg.QOpt) error) *ORM_DeleteProposal_Call { +func (_c *ORM_DeleteProposal_Call) RunAndReturn(run func(context.Context, int64) error) *ORM_DeleteProposal_Call { _c.Call.Return(run) return _c } -// ExistsSpecByJobProposalIDAndVersion provides a mock function with given fields: jpID, version, qopts -func (_m *ORM) ExistsSpecByJobProposalIDAndVersion(jpID int64, version int32, qopts ...pg.QOpt) (bool, error) { - _va := make([]interface{}, len(qopts)) - for _i := range qopts { - _va[_i] = qopts[_i] - } - var _ca []interface{} - _ca = append(_ca, jpID, version) - _ca = append(_ca, _va...) - ret := _m.Called(_ca...) +// ExistsSpecByJobProposalIDAndVersion provides a mock function with given fields: ctx, jpID, version +func (_m *ORM) ExistsSpecByJobProposalIDAndVersion(ctx context.Context, jpID int64, version int32) (bool, error) { + ret := _m.Called(ctx, jpID, version) if len(ret) == 0 { panic("no return value specified for ExistsSpecByJobProposalIDAndVersion") @@ -790,17 +692,17 @@ func (_m *ORM) ExistsSpecByJobProposalIDAndVersion(jpID int64, version int32, qo var r0 bool var r1 error - if rf, ok := ret.Get(0).(func(int64, int32, ...pg.QOpt) (bool, error)); ok { - return rf(jpID, version, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, int64, int32) (bool, error)); ok { + return rf(ctx, jpID, version) } - if rf, ok := ret.Get(0).(func(int64, int32, ...pg.QOpt) bool); ok { - r0 = rf(jpID, version, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, int64, int32) bool); ok { + r0 = rf(ctx, jpID, version) } else { r0 = ret.Get(0).(bool) } - if rf, ok := ret.Get(1).(func(int64, int32, ...pg.QOpt) error); ok { - r1 = rf(jpID, version, qopts...) + if rf, ok := ret.Get(1).(func(context.Context, int64, int32) error); ok { + r1 = rf(ctx, jpID, version) } else { r1 = ret.Error(1) } @@ -814,23 +716,16 @@ type ORM_ExistsSpecByJobProposalIDAndVersion_Call struct { } // ExistsSpecByJobProposalIDAndVersion is a helper method to define mock.On call +// - ctx context.Context // - jpID int64 // - version int32 -// - qopts ...pg.QOpt -func (_e *ORM_Expecter) ExistsSpecByJobProposalIDAndVersion(jpID interface{}, version interface{}, qopts ...interface{}) *ORM_ExistsSpecByJobProposalIDAndVersion_Call { - return &ORM_ExistsSpecByJobProposalIDAndVersion_Call{Call: _e.mock.On("ExistsSpecByJobProposalIDAndVersion", - append([]interface{}{jpID, version}, qopts...)...)} +func (_e *ORM_Expecter) ExistsSpecByJobProposalIDAndVersion(ctx interface{}, jpID interface{}, version interface{}) *ORM_ExistsSpecByJobProposalIDAndVersion_Call { + return &ORM_ExistsSpecByJobProposalIDAndVersion_Call{Call: _e.mock.On("ExistsSpecByJobProposalIDAndVersion", ctx, jpID, version)} } -func (_c *ORM_ExistsSpecByJobProposalIDAndVersion_Call) Run(run func(jpID int64, version int32, qopts ...pg.QOpt)) *ORM_ExistsSpecByJobProposalIDAndVersion_Call { +func (_c *ORM_ExistsSpecByJobProposalIDAndVersion_Call) Run(run func(ctx context.Context, jpID int64, version int32)) *ORM_ExistsSpecByJobProposalIDAndVersion_Call { _c.Call.Run(func(args mock.Arguments) { - variadicArgs := make([]pg.QOpt, len(args)-2) - for i, a := range args[2:] { - if a != nil { - variadicArgs[i] = a.(pg.QOpt) - } - } - run(args[0].(int64), args[1].(int32), variadicArgs...) + run(args[0].(context.Context), args[1].(int64), args[2].(int32)) }) return _c } @@ -840,21 +735,14 @@ func (_c *ORM_ExistsSpecByJobProposalIDAndVersion_Call) Return(exists bool, err return _c } -func (_c *ORM_ExistsSpecByJobProposalIDAndVersion_Call) RunAndReturn(run func(int64, int32, ...pg.QOpt) (bool, error)) *ORM_ExistsSpecByJobProposalIDAndVersion_Call { +func (_c *ORM_ExistsSpecByJobProposalIDAndVersion_Call) RunAndReturn(run func(context.Context, int64, int32) (bool, error)) *ORM_ExistsSpecByJobProposalIDAndVersion_Call { _c.Call.Return(run) return _c } -// GetApprovedSpec provides a mock function with given fields: jpID, qopts -func (_m *ORM) GetApprovedSpec(jpID int64, qopts ...pg.QOpt) (*feeds.JobProposalSpec, error) { - _va := make([]interface{}, len(qopts)) - for _i := range qopts { - _va[_i] = qopts[_i] - } - var _ca []interface{} - _ca = append(_ca, jpID) - _ca = append(_ca, _va...) - ret := _m.Called(_ca...) +// GetApprovedSpec provides a mock function with given fields: ctx, jpID +func (_m *ORM) GetApprovedSpec(ctx context.Context, jpID int64) (*feeds.JobProposalSpec, error) { + ret := _m.Called(ctx, jpID) if len(ret) == 0 { panic("no return value specified for GetApprovedSpec") @@ -862,19 +750,19 @@ func (_m *ORM) GetApprovedSpec(jpID int64, qopts ...pg.QOpt) (*feeds.JobProposal var r0 *feeds.JobProposalSpec var r1 error - if rf, ok := ret.Get(0).(func(int64, ...pg.QOpt) (*feeds.JobProposalSpec, error)); ok { - return rf(jpID, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, int64) (*feeds.JobProposalSpec, error)); ok { + return rf(ctx, jpID) } - if rf, ok := ret.Get(0).(func(int64, ...pg.QOpt) *feeds.JobProposalSpec); ok { - r0 = rf(jpID, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, int64) *feeds.JobProposalSpec); ok { + r0 = rf(ctx, jpID) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*feeds.JobProposalSpec) } } - if rf, ok := ret.Get(1).(func(int64, ...pg.QOpt) error); ok { - r1 = rf(jpID, qopts...) + if rf, ok := ret.Get(1).(func(context.Context, int64) error); ok { + r1 = rf(ctx, jpID) } else { r1 = ret.Error(1) } @@ -888,22 +776,15 @@ type ORM_GetApprovedSpec_Call struct { } // GetApprovedSpec is a helper method to define mock.On call +// - ctx context.Context // - jpID int64 -// - qopts ...pg.QOpt -func (_e *ORM_Expecter) GetApprovedSpec(jpID interface{}, qopts ...interface{}) *ORM_GetApprovedSpec_Call { - return &ORM_GetApprovedSpec_Call{Call: _e.mock.On("GetApprovedSpec", - append([]interface{}{jpID}, qopts...)...)} +func (_e *ORM_Expecter) GetApprovedSpec(ctx interface{}, jpID interface{}) *ORM_GetApprovedSpec_Call { + return &ORM_GetApprovedSpec_Call{Call: _e.mock.On("GetApprovedSpec", ctx, jpID)} } -func (_c *ORM_GetApprovedSpec_Call) Run(run func(jpID int64, qopts ...pg.QOpt)) *ORM_GetApprovedSpec_Call { +func (_c *ORM_GetApprovedSpec_Call) Run(run func(ctx context.Context, jpID int64)) *ORM_GetApprovedSpec_Call { _c.Call.Run(func(args mock.Arguments) { - variadicArgs := make([]pg.QOpt, len(args)-1) - for i, a := range args[1:] { - if a != nil { - variadicArgs[i] = a.(pg.QOpt) - } - } - run(args[0].(int64), variadicArgs...) + run(args[0].(context.Context), args[1].(int64)) }) return _c } @@ -913,14 +794,14 @@ func (_c *ORM_GetApprovedSpec_Call) Return(_a0 *feeds.JobProposalSpec, _a1 error return _c } -func (_c *ORM_GetApprovedSpec_Call) RunAndReturn(run func(int64, ...pg.QOpt) (*feeds.JobProposalSpec, error)) *ORM_GetApprovedSpec_Call { +func (_c *ORM_GetApprovedSpec_Call) RunAndReturn(run func(context.Context, int64) (*feeds.JobProposalSpec, error)) *ORM_GetApprovedSpec_Call { _c.Call.Return(run) return _c } -// GetChainConfig provides a mock function with given fields: id -func (_m *ORM) GetChainConfig(id int64) (*feeds.ChainConfig, error) { - ret := _m.Called(id) +// GetChainConfig provides a mock function with given fields: ctx, id +func (_m *ORM) GetChainConfig(ctx context.Context, id int64) (*feeds.ChainConfig, error) { + ret := _m.Called(ctx, id) if len(ret) == 0 { panic("no return value specified for GetChainConfig") @@ -928,19 +809,19 @@ func (_m *ORM) GetChainConfig(id int64) (*feeds.ChainConfig, error) { var r0 *feeds.ChainConfig var r1 error - if rf, ok := ret.Get(0).(func(int64) (*feeds.ChainConfig, error)); ok { - return rf(id) + if rf, ok := ret.Get(0).(func(context.Context, int64) (*feeds.ChainConfig, error)); ok { + return rf(ctx, id) } - if rf, ok := ret.Get(0).(func(int64) *feeds.ChainConfig); ok { - r0 = rf(id) + if rf, ok := ret.Get(0).(func(context.Context, int64) *feeds.ChainConfig); ok { + r0 = rf(ctx, id) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*feeds.ChainConfig) } } - if rf, ok := ret.Get(1).(func(int64) error); ok { - r1 = rf(id) + if rf, ok := ret.Get(1).(func(context.Context, int64) error); ok { + r1 = rf(ctx, id) } else { r1 = ret.Error(1) } @@ -954,14 +835,15 @@ type ORM_GetChainConfig_Call struct { } // GetChainConfig is a helper method to define mock.On call +// - ctx context.Context // - id int64 -func (_e *ORM_Expecter) GetChainConfig(id interface{}) *ORM_GetChainConfig_Call { - return &ORM_GetChainConfig_Call{Call: _e.mock.On("GetChainConfig", id)} +func (_e *ORM_Expecter) GetChainConfig(ctx interface{}, id interface{}) *ORM_GetChainConfig_Call { + return &ORM_GetChainConfig_Call{Call: _e.mock.On("GetChainConfig", ctx, id)} } -func (_c *ORM_GetChainConfig_Call) Run(run func(id int64)) *ORM_GetChainConfig_Call { +func (_c *ORM_GetChainConfig_Call) Run(run func(ctx context.Context, id int64)) *ORM_GetChainConfig_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(int64)) + run(args[0].(context.Context), args[1].(int64)) }) return _c } @@ -971,21 +853,14 @@ func (_c *ORM_GetChainConfig_Call) Return(_a0 *feeds.ChainConfig, _a1 error) *OR return _c } -func (_c *ORM_GetChainConfig_Call) RunAndReturn(run func(int64) (*feeds.ChainConfig, error)) *ORM_GetChainConfig_Call { +func (_c *ORM_GetChainConfig_Call) RunAndReturn(run func(context.Context, int64) (*feeds.ChainConfig, error)) *ORM_GetChainConfig_Call { _c.Call.Return(run) return _c } -// GetJobProposal provides a mock function with given fields: id, qopts -func (_m *ORM) GetJobProposal(id int64, qopts ...pg.QOpt) (*feeds.JobProposal, error) { - _va := make([]interface{}, len(qopts)) - for _i := range qopts { - _va[_i] = qopts[_i] - } - var _ca []interface{} - _ca = append(_ca, id) - _ca = append(_ca, _va...) - ret := _m.Called(_ca...) +// GetJobProposal provides a mock function with given fields: ctx, id +func (_m *ORM) GetJobProposal(ctx context.Context, id int64) (*feeds.JobProposal, error) { + ret := _m.Called(ctx, id) if len(ret) == 0 { panic("no return value specified for GetJobProposal") @@ -993,19 +868,19 @@ func (_m *ORM) GetJobProposal(id int64, qopts ...pg.QOpt) (*feeds.JobProposal, e var r0 *feeds.JobProposal var r1 error - if rf, ok := ret.Get(0).(func(int64, ...pg.QOpt) (*feeds.JobProposal, error)); ok { - return rf(id, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, int64) (*feeds.JobProposal, error)); ok { + return rf(ctx, id) } - if rf, ok := ret.Get(0).(func(int64, ...pg.QOpt) *feeds.JobProposal); ok { - r0 = rf(id, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, int64) *feeds.JobProposal); ok { + r0 = rf(ctx, id) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*feeds.JobProposal) } } - if rf, ok := ret.Get(1).(func(int64, ...pg.QOpt) error); ok { - r1 = rf(id, qopts...) + if rf, ok := ret.Get(1).(func(context.Context, int64) error); ok { + r1 = rf(ctx, id) } else { r1 = ret.Error(1) } @@ -1019,22 +894,15 @@ type ORM_GetJobProposal_Call struct { } // GetJobProposal is a helper method to define mock.On call +// - ctx context.Context // - id int64 -// - qopts ...pg.QOpt -func (_e *ORM_Expecter) GetJobProposal(id interface{}, qopts ...interface{}) *ORM_GetJobProposal_Call { - return &ORM_GetJobProposal_Call{Call: _e.mock.On("GetJobProposal", - append([]interface{}{id}, qopts...)...)} +func (_e *ORM_Expecter) GetJobProposal(ctx interface{}, id interface{}) *ORM_GetJobProposal_Call { + return &ORM_GetJobProposal_Call{Call: _e.mock.On("GetJobProposal", ctx, id)} } -func (_c *ORM_GetJobProposal_Call) Run(run func(id int64, qopts ...pg.QOpt)) *ORM_GetJobProposal_Call { +func (_c *ORM_GetJobProposal_Call) Run(run func(ctx context.Context, id int64)) *ORM_GetJobProposal_Call { _c.Call.Run(func(args mock.Arguments) { - variadicArgs := make([]pg.QOpt, len(args)-1) - for i, a := range args[1:] { - if a != nil { - variadicArgs[i] = a.(pg.QOpt) - } - } - run(args[0].(int64), variadicArgs...) + run(args[0].(context.Context), args[1].(int64)) }) return _c } @@ -1044,14 +912,14 @@ func (_c *ORM_GetJobProposal_Call) Return(_a0 *feeds.JobProposal, _a1 error) *OR return _c } -func (_c *ORM_GetJobProposal_Call) RunAndReturn(run func(int64, ...pg.QOpt) (*feeds.JobProposal, error)) *ORM_GetJobProposal_Call { +func (_c *ORM_GetJobProposal_Call) RunAndReturn(run func(context.Context, int64) (*feeds.JobProposal, error)) *ORM_GetJobProposal_Call { _c.Call.Return(run) return _c } -// GetJobProposalByRemoteUUID provides a mock function with given fields: _a0 -func (_m *ORM) GetJobProposalByRemoteUUID(_a0 uuid.UUID) (*feeds.JobProposal, error) { - ret := _m.Called(_a0) +// GetJobProposalByRemoteUUID provides a mock function with given fields: ctx, _a1 +func (_m *ORM) GetJobProposalByRemoteUUID(ctx context.Context, _a1 uuid.UUID) (*feeds.JobProposal, error) { + ret := _m.Called(ctx, _a1) if len(ret) == 0 { panic("no return value specified for GetJobProposalByRemoteUUID") @@ -1059,19 +927,19 @@ func (_m *ORM) GetJobProposalByRemoteUUID(_a0 uuid.UUID) (*feeds.JobProposal, er var r0 *feeds.JobProposal var r1 error - if rf, ok := ret.Get(0).(func(uuid.UUID) (*feeds.JobProposal, error)); ok { - return rf(_a0) + if rf, ok := ret.Get(0).(func(context.Context, uuid.UUID) (*feeds.JobProposal, error)); ok { + return rf(ctx, _a1) } - if rf, ok := ret.Get(0).(func(uuid.UUID) *feeds.JobProposal); ok { - r0 = rf(_a0) + if rf, ok := ret.Get(0).(func(context.Context, uuid.UUID) *feeds.JobProposal); ok { + r0 = rf(ctx, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*feeds.JobProposal) } } - if rf, ok := ret.Get(1).(func(uuid.UUID) error); ok { - r1 = rf(_a0) + if rf, ok := ret.Get(1).(func(context.Context, uuid.UUID) error); ok { + r1 = rf(ctx, _a1) } else { r1 = ret.Error(1) } @@ -1085,14 +953,15 @@ type ORM_GetJobProposalByRemoteUUID_Call struct { } // GetJobProposalByRemoteUUID is a helper method to define mock.On call -// - _a0 uuid.UUID -func (_e *ORM_Expecter) GetJobProposalByRemoteUUID(_a0 interface{}) *ORM_GetJobProposalByRemoteUUID_Call { - return &ORM_GetJobProposalByRemoteUUID_Call{Call: _e.mock.On("GetJobProposalByRemoteUUID", _a0)} +// - ctx context.Context +// - _a1 uuid.UUID +func (_e *ORM_Expecter) GetJobProposalByRemoteUUID(ctx interface{}, _a1 interface{}) *ORM_GetJobProposalByRemoteUUID_Call { + return &ORM_GetJobProposalByRemoteUUID_Call{Call: _e.mock.On("GetJobProposalByRemoteUUID", ctx, _a1)} } -func (_c *ORM_GetJobProposalByRemoteUUID_Call) Run(run func(_a0 uuid.UUID)) *ORM_GetJobProposalByRemoteUUID_Call { +func (_c *ORM_GetJobProposalByRemoteUUID_Call) Run(run func(ctx context.Context, _a1 uuid.UUID)) *ORM_GetJobProposalByRemoteUUID_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(uuid.UUID)) + run(args[0].(context.Context), args[1].(uuid.UUID)) }) return _c } @@ -1102,14 +971,14 @@ func (_c *ORM_GetJobProposalByRemoteUUID_Call) Return(_a0 *feeds.JobProposal, _a return _c } -func (_c *ORM_GetJobProposalByRemoteUUID_Call) RunAndReturn(run func(uuid.UUID) (*feeds.JobProposal, error)) *ORM_GetJobProposalByRemoteUUID_Call { +func (_c *ORM_GetJobProposalByRemoteUUID_Call) RunAndReturn(run func(context.Context, uuid.UUID) (*feeds.JobProposal, error)) *ORM_GetJobProposalByRemoteUUID_Call { _c.Call.Return(run) return _c } -// GetLatestSpec provides a mock function with given fields: jpID -func (_m *ORM) GetLatestSpec(jpID int64) (*feeds.JobProposalSpec, error) { - ret := _m.Called(jpID) +// GetLatestSpec provides a mock function with given fields: ctx, jpID +func (_m *ORM) GetLatestSpec(ctx context.Context, jpID int64) (*feeds.JobProposalSpec, error) { + ret := _m.Called(ctx, jpID) if len(ret) == 0 { panic("no return value specified for GetLatestSpec") @@ -1117,19 +986,19 @@ func (_m *ORM) GetLatestSpec(jpID int64) (*feeds.JobProposalSpec, error) { var r0 *feeds.JobProposalSpec var r1 error - if rf, ok := ret.Get(0).(func(int64) (*feeds.JobProposalSpec, error)); ok { - return rf(jpID) + if rf, ok := ret.Get(0).(func(context.Context, int64) (*feeds.JobProposalSpec, error)); ok { + return rf(ctx, jpID) } - if rf, ok := ret.Get(0).(func(int64) *feeds.JobProposalSpec); ok { - r0 = rf(jpID) + if rf, ok := ret.Get(0).(func(context.Context, int64) *feeds.JobProposalSpec); ok { + r0 = rf(ctx, jpID) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*feeds.JobProposalSpec) } } - if rf, ok := ret.Get(1).(func(int64) error); ok { - r1 = rf(jpID) + if rf, ok := ret.Get(1).(func(context.Context, int64) error); ok { + r1 = rf(ctx, jpID) } else { r1 = ret.Error(1) } @@ -1143,14 +1012,15 @@ type ORM_GetLatestSpec_Call struct { } // GetLatestSpec is a helper method to define mock.On call +// - ctx context.Context // - jpID int64 -func (_e *ORM_Expecter) GetLatestSpec(jpID interface{}) *ORM_GetLatestSpec_Call { - return &ORM_GetLatestSpec_Call{Call: _e.mock.On("GetLatestSpec", jpID)} +func (_e *ORM_Expecter) GetLatestSpec(ctx interface{}, jpID interface{}) *ORM_GetLatestSpec_Call { + return &ORM_GetLatestSpec_Call{Call: _e.mock.On("GetLatestSpec", ctx, jpID)} } -func (_c *ORM_GetLatestSpec_Call) Run(run func(jpID int64)) *ORM_GetLatestSpec_Call { +func (_c *ORM_GetLatestSpec_Call) Run(run func(ctx context.Context, jpID int64)) *ORM_GetLatestSpec_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(int64)) + run(args[0].(context.Context), args[1].(int64)) }) return _c } @@ -1160,14 +1030,14 @@ func (_c *ORM_GetLatestSpec_Call) Return(_a0 *feeds.JobProposalSpec, _a1 error) return _c } -func (_c *ORM_GetLatestSpec_Call) RunAndReturn(run func(int64) (*feeds.JobProposalSpec, error)) *ORM_GetLatestSpec_Call { +func (_c *ORM_GetLatestSpec_Call) RunAndReturn(run func(context.Context, int64) (*feeds.JobProposalSpec, error)) *ORM_GetLatestSpec_Call { _c.Call.Return(run) return _c } -// GetManager provides a mock function with given fields: id -func (_m *ORM) GetManager(id int64) (*feeds.FeedsManager, error) { - ret := _m.Called(id) +// GetManager provides a mock function with given fields: ctx, id +func (_m *ORM) GetManager(ctx context.Context, id int64) (*feeds.FeedsManager, error) { + ret := _m.Called(ctx, id) if len(ret) == 0 { panic("no return value specified for GetManager") @@ -1175,19 +1045,19 @@ func (_m *ORM) GetManager(id int64) (*feeds.FeedsManager, error) { var r0 *feeds.FeedsManager var r1 error - if rf, ok := ret.Get(0).(func(int64) (*feeds.FeedsManager, error)); ok { - return rf(id) + if rf, ok := ret.Get(0).(func(context.Context, int64) (*feeds.FeedsManager, error)); ok { + return rf(ctx, id) } - if rf, ok := ret.Get(0).(func(int64) *feeds.FeedsManager); ok { - r0 = rf(id) + if rf, ok := ret.Get(0).(func(context.Context, int64) *feeds.FeedsManager); ok { + r0 = rf(ctx, id) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*feeds.FeedsManager) } } - if rf, ok := ret.Get(1).(func(int64) error); ok { - r1 = rf(id) + if rf, ok := ret.Get(1).(func(context.Context, int64) error); ok { + r1 = rf(ctx, id) } else { r1 = ret.Error(1) } @@ -1201,14 +1071,15 @@ type ORM_GetManager_Call struct { } // GetManager is a helper method to define mock.On call +// - ctx context.Context // - id int64 -func (_e *ORM_Expecter) GetManager(id interface{}) *ORM_GetManager_Call { - return &ORM_GetManager_Call{Call: _e.mock.On("GetManager", id)} +func (_e *ORM_Expecter) GetManager(ctx interface{}, id interface{}) *ORM_GetManager_Call { + return &ORM_GetManager_Call{Call: _e.mock.On("GetManager", ctx, id)} } -func (_c *ORM_GetManager_Call) Run(run func(id int64)) *ORM_GetManager_Call { +func (_c *ORM_GetManager_Call) Run(run func(ctx context.Context, id int64)) *ORM_GetManager_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(int64)) + run(args[0].(context.Context), args[1].(int64)) }) return _c } @@ -1218,21 +1089,14 @@ func (_c *ORM_GetManager_Call) Return(_a0 *feeds.FeedsManager, _a1 error) *ORM_G return _c } -func (_c *ORM_GetManager_Call) RunAndReturn(run func(int64) (*feeds.FeedsManager, error)) *ORM_GetManager_Call { +func (_c *ORM_GetManager_Call) RunAndReturn(run func(context.Context, int64) (*feeds.FeedsManager, error)) *ORM_GetManager_Call { _c.Call.Return(run) return _c } -// GetSpec provides a mock function with given fields: id, qopts -func (_m *ORM) GetSpec(id int64, qopts ...pg.QOpt) (*feeds.JobProposalSpec, error) { - _va := make([]interface{}, len(qopts)) - for _i := range qopts { - _va[_i] = qopts[_i] - } - var _ca []interface{} - _ca = append(_ca, id) - _ca = append(_ca, _va...) - ret := _m.Called(_ca...) +// GetSpec provides a mock function with given fields: ctx, id +func (_m *ORM) GetSpec(ctx context.Context, id int64) (*feeds.JobProposalSpec, error) { + ret := _m.Called(ctx, id) if len(ret) == 0 { panic("no return value specified for GetSpec") @@ -1240,19 +1104,19 @@ func (_m *ORM) GetSpec(id int64, qopts ...pg.QOpt) (*feeds.JobProposalSpec, erro var r0 *feeds.JobProposalSpec var r1 error - if rf, ok := ret.Get(0).(func(int64, ...pg.QOpt) (*feeds.JobProposalSpec, error)); ok { - return rf(id, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, int64) (*feeds.JobProposalSpec, error)); ok { + return rf(ctx, id) } - if rf, ok := ret.Get(0).(func(int64, ...pg.QOpt) *feeds.JobProposalSpec); ok { - r0 = rf(id, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, int64) *feeds.JobProposalSpec); ok { + r0 = rf(ctx, id) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*feeds.JobProposalSpec) } } - if rf, ok := ret.Get(1).(func(int64, ...pg.QOpt) error); ok { - r1 = rf(id, qopts...) + if rf, ok := ret.Get(1).(func(context.Context, int64) error); ok { + r1 = rf(ctx, id) } else { r1 = ret.Error(1) } @@ -1266,22 +1130,15 @@ type ORM_GetSpec_Call struct { } // GetSpec is a helper method to define mock.On call +// - ctx context.Context // - id int64 -// - qopts ...pg.QOpt -func (_e *ORM_Expecter) GetSpec(id interface{}, qopts ...interface{}) *ORM_GetSpec_Call { - return &ORM_GetSpec_Call{Call: _e.mock.On("GetSpec", - append([]interface{}{id}, qopts...)...)} +func (_e *ORM_Expecter) GetSpec(ctx interface{}, id interface{}) *ORM_GetSpec_Call { + return &ORM_GetSpec_Call{Call: _e.mock.On("GetSpec", ctx, id)} } -func (_c *ORM_GetSpec_Call) Run(run func(id int64, qopts ...pg.QOpt)) *ORM_GetSpec_Call { +func (_c *ORM_GetSpec_Call) Run(run func(ctx context.Context, id int64)) *ORM_GetSpec_Call { _c.Call.Run(func(args mock.Arguments) { - variadicArgs := make([]pg.QOpt, len(args)-1) - for i, a := range args[1:] { - if a != nil { - variadicArgs[i] = a.(pg.QOpt) - } - } - run(args[0].(int64), variadicArgs...) + run(args[0].(context.Context), args[1].(int64)) }) return _c } @@ -1291,21 +1148,14 @@ func (_c *ORM_GetSpec_Call) Return(_a0 *feeds.JobProposalSpec, _a1 error) *ORM_G return _c } -func (_c *ORM_GetSpec_Call) RunAndReturn(run func(int64, ...pg.QOpt) (*feeds.JobProposalSpec, error)) *ORM_GetSpec_Call { +func (_c *ORM_GetSpec_Call) RunAndReturn(run func(context.Context, int64) (*feeds.JobProposalSpec, error)) *ORM_GetSpec_Call { _c.Call.Return(run) return _c } -// IsJobManaged provides a mock function with given fields: jobID, qopts -func (_m *ORM) IsJobManaged(jobID int64, qopts ...pg.QOpt) (bool, error) { - _va := make([]interface{}, len(qopts)) - for _i := range qopts { - _va[_i] = qopts[_i] - } - var _ca []interface{} - _ca = append(_ca, jobID) - _ca = append(_ca, _va...) - ret := _m.Called(_ca...) +// IsJobManaged provides a mock function with given fields: ctx, jobID +func (_m *ORM) IsJobManaged(ctx context.Context, jobID int64) (bool, error) { + ret := _m.Called(ctx, jobID) if len(ret) == 0 { panic("no return value specified for IsJobManaged") @@ -1313,17 +1163,17 @@ func (_m *ORM) IsJobManaged(jobID int64, qopts ...pg.QOpt) (bool, error) { var r0 bool var r1 error - if rf, ok := ret.Get(0).(func(int64, ...pg.QOpt) (bool, error)); ok { - return rf(jobID, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, int64) (bool, error)); ok { + return rf(ctx, jobID) } - if rf, ok := ret.Get(0).(func(int64, ...pg.QOpt) bool); ok { - r0 = rf(jobID, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, int64) bool); ok { + r0 = rf(ctx, jobID) } else { r0 = ret.Get(0).(bool) } - if rf, ok := ret.Get(1).(func(int64, ...pg.QOpt) error); ok { - r1 = rf(jobID, qopts...) + if rf, ok := ret.Get(1).(func(context.Context, int64) error); ok { + r1 = rf(ctx, jobID) } else { r1 = ret.Error(1) } @@ -1337,22 +1187,15 @@ type ORM_IsJobManaged_Call struct { } // IsJobManaged is a helper method to define mock.On call +// - ctx context.Context // - jobID int64 -// - qopts ...pg.QOpt -func (_e *ORM_Expecter) IsJobManaged(jobID interface{}, qopts ...interface{}) *ORM_IsJobManaged_Call { - return &ORM_IsJobManaged_Call{Call: _e.mock.On("IsJobManaged", - append([]interface{}{jobID}, qopts...)...)} +func (_e *ORM_Expecter) IsJobManaged(ctx interface{}, jobID interface{}) *ORM_IsJobManaged_Call { + return &ORM_IsJobManaged_Call{Call: _e.mock.On("IsJobManaged", ctx, jobID)} } -func (_c *ORM_IsJobManaged_Call) Run(run func(jobID int64, qopts ...pg.QOpt)) *ORM_IsJobManaged_Call { +func (_c *ORM_IsJobManaged_Call) Run(run func(ctx context.Context, jobID int64)) *ORM_IsJobManaged_Call { _c.Call.Run(func(args mock.Arguments) { - variadicArgs := make([]pg.QOpt, len(args)-1) - for i, a := range args[1:] { - if a != nil { - variadicArgs[i] = a.(pg.QOpt) - } - } - run(args[0].(int64), variadicArgs...) + run(args[0].(context.Context), args[1].(int64)) }) return _c } @@ -1362,14 +1205,14 @@ func (_c *ORM_IsJobManaged_Call) Return(_a0 bool, _a1 error) *ORM_IsJobManaged_C return _c } -func (_c *ORM_IsJobManaged_Call) RunAndReturn(run func(int64, ...pg.QOpt) (bool, error)) *ORM_IsJobManaged_Call { +func (_c *ORM_IsJobManaged_Call) RunAndReturn(run func(context.Context, int64) (bool, error)) *ORM_IsJobManaged_Call { _c.Call.Return(run) return _c } -// ListChainConfigsByManagerIDs provides a mock function with given fields: mgrIDs -func (_m *ORM) ListChainConfigsByManagerIDs(mgrIDs []int64) ([]feeds.ChainConfig, error) { - ret := _m.Called(mgrIDs) +// ListChainConfigsByManagerIDs provides a mock function with given fields: ctx, mgrIDs +func (_m *ORM) ListChainConfigsByManagerIDs(ctx context.Context, mgrIDs []int64) ([]feeds.ChainConfig, error) { + ret := _m.Called(ctx, mgrIDs) if len(ret) == 0 { panic("no return value specified for ListChainConfigsByManagerIDs") @@ -1377,19 +1220,19 @@ func (_m *ORM) ListChainConfigsByManagerIDs(mgrIDs []int64) ([]feeds.ChainConfig var r0 []feeds.ChainConfig var r1 error - if rf, ok := ret.Get(0).(func([]int64) ([]feeds.ChainConfig, error)); ok { - return rf(mgrIDs) + if rf, ok := ret.Get(0).(func(context.Context, []int64) ([]feeds.ChainConfig, error)); ok { + return rf(ctx, mgrIDs) } - if rf, ok := ret.Get(0).(func([]int64) []feeds.ChainConfig); ok { - r0 = rf(mgrIDs) + if rf, ok := ret.Get(0).(func(context.Context, []int64) []feeds.ChainConfig); ok { + r0 = rf(ctx, mgrIDs) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]feeds.ChainConfig) } } - if rf, ok := ret.Get(1).(func([]int64) error); ok { - r1 = rf(mgrIDs) + if rf, ok := ret.Get(1).(func(context.Context, []int64) error); ok { + r1 = rf(ctx, mgrIDs) } else { r1 = ret.Error(1) } @@ -1403,14 +1246,15 @@ type ORM_ListChainConfigsByManagerIDs_Call struct { } // ListChainConfigsByManagerIDs is a helper method to define mock.On call +// - ctx context.Context // - mgrIDs []int64 -func (_e *ORM_Expecter) ListChainConfigsByManagerIDs(mgrIDs interface{}) *ORM_ListChainConfigsByManagerIDs_Call { - return &ORM_ListChainConfigsByManagerIDs_Call{Call: _e.mock.On("ListChainConfigsByManagerIDs", mgrIDs)} +func (_e *ORM_Expecter) ListChainConfigsByManagerIDs(ctx interface{}, mgrIDs interface{}) *ORM_ListChainConfigsByManagerIDs_Call { + return &ORM_ListChainConfigsByManagerIDs_Call{Call: _e.mock.On("ListChainConfigsByManagerIDs", ctx, mgrIDs)} } -func (_c *ORM_ListChainConfigsByManagerIDs_Call) Run(run func(mgrIDs []int64)) *ORM_ListChainConfigsByManagerIDs_Call { +func (_c *ORM_ListChainConfigsByManagerIDs_Call) Run(run func(ctx context.Context, mgrIDs []int64)) *ORM_ListChainConfigsByManagerIDs_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].([]int64)) + run(args[0].(context.Context), args[1].([]int64)) }) return _c } @@ -1420,14 +1264,14 @@ func (_c *ORM_ListChainConfigsByManagerIDs_Call) Return(_a0 []feeds.ChainConfig, return _c } -func (_c *ORM_ListChainConfigsByManagerIDs_Call) RunAndReturn(run func([]int64) ([]feeds.ChainConfig, error)) *ORM_ListChainConfigsByManagerIDs_Call { +func (_c *ORM_ListChainConfigsByManagerIDs_Call) RunAndReturn(run func(context.Context, []int64) ([]feeds.ChainConfig, error)) *ORM_ListChainConfigsByManagerIDs_Call { _c.Call.Return(run) return _c } -// ListJobProposals provides a mock function with given fields: -func (_m *ORM) ListJobProposals() ([]feeds.JobProposal, error) { - ret := _m.Called() +// ListJobProposals provides a mock function with given fields: ctx +func (_m *ORM) ListJobProposals(ctx context.Context) ([]feeds.JobProposal, error) { + ret := _m.Called(ctx) if len(ret) == 0 { panic("no return value specified for ListJobProposals") @@ -1435,19 +1279,19 @@ func (_m *ORM) ListJobProposals() ([]feeds.JobProposal, error) { var r0 []feeds.JobProposal var r1 error - if rf, ok := ret.Get(0).(func() ([]feeds.JobProposal, error)); ok { - return rf() + if rf, ok := ret.Get(0).(func(context.Context) ([]feeds.JobProposal, error)); ok { + return rf(ctx) } - if rf, ok := ret.Get(0).(func() []feeds.JobProposal); ok { - r0 = rf() + if rf, ok := ret.Get(0).(func(context.Context) []feeds.JobProposal); ok { + r0 = rf(ctx) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]feeds.JobProposal) } } - if rf, ok := ret.Get(1).(func() error); ok { - r1 = rf() + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(ctx) } else { r1 = ret.Error(1) } @@ -1461,13 +1305,14 @@ type ORM_ListJobProposals_Call struct { } // ListJobProposals is a helper method to define mock.On call -func (_e *ORM_Expecter) ListJobProposals() *ORM_ListJobProposals_Call { - return &ORM_ListJobProposals_Call{Call: _e.mock.On("ListJobProposals")} +// - ctx context.Context +func (_e *ORM_Expecter) ListJobProposals(ctx interface{}) *ORM_ListJobProposals_Call { + return &ORM_ListJobProposals_Call{Call: _e.mock.On("ListJobProposals", ctx)} } -func (_c *ORM_ListJobProposals_Call) Run(run func()) *ORM_ListJobProposals_Call { +func (_c *ORM_ListJobProposals_Call) Run(run func(ctx context.Context)) *ORM_ListJobProposals_Call { _c.Call.Run(func(args mock.Arguments) { - run() + run(args[0].(context.Context)) }) return _c } @@ -1477,21 +1322,14 @@ func (_c *ORM_ListJobProposals_Call) Return(jps []feeds.JobProposal, err error) return _c } -func (_c *ORM_ListJobProposals_Call) RunAndReturn(run func() ([]feeds.JobProposal, error)) *ORM_ListJobProposals_Call { +func (_c *ORM_ListJobProposals_Call) RunAndReturn(run func(context.Context) ([]feeds.JobProposal, error)) *ORM_ListJobProposals_Call { _c.Call.Return(run) return _c } -// ListJobProposalsByManagersIDs provides a mock function with given fields: ids, qopts -func (_m *ORM) ListJobProposalsByManagersIDs(ids []int64, qopts ...pg.QOpt) ([]feeds.JobProposal, error) { - _va := make([]interface{}, len(qopts)) - for _i := range qopts { - _va[_i] = qopts[_i] - } - var _ca []interface{} - _ca = append(_ca, ids) - _ca = append(_ca, _va...) - ret := _m.Called(_ca...) +// ListJobProposalsByManagersIDs provides a mock function with given fields: ctx, ids +func (_m *ORM) ListJobProposalsByManagersIDs(ctx context.Context, ids []int64) ([]feeds.JobProposal, error) { + ret := _m.Called(ctx, ids) if len(ret) == 0 { panic("no return value specified for ListJobProposalsByManagersIDs") @@ -1499,19 +1337,19 @@ func (_m *ORM) ListJobProposalsByManagersIDs(ids []int64, qopts ...pg.QOpt) ([]f var r0 []feeds.JobProposal var r1 error - if rf, ok := ret.Get(0).(func([]int64, ...pg.QOpt) ([]feeds.JobProposal, error)); ok { - return rf(ids, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, []int64) ([]feeds.JobProposal, error)); ok { + return rf(ctx, ids) } - if rf, ok := ret.Get(0).(func([]int64, ...pg.QOpt) []feeds.JobProposal); ok { - r0 = rf(ids, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, []int64) []feeds.JobProposal); ok { + r0 = rf(ctx, ids) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]feeds.JobProposal) } } - if rf, ok := ret.Get(1).(func([]int64, ...pg.QOpt) error); ok { - r1 = rf(ids, qopts...) + if rf, ok := ret.Get(1).(func(context.Context, []int64) error); ok { + r1 = rf(ctx, ids) } else { r1 = ret.Error(1) } @@ -1525,22 +1363,15 @@ type ORM_ListJobProposalsByManagersIDs_Call struct { } // ListJobProposalsByManagersIDs is a helper method to define mock.On call +// - ctx context.Context // - ids []int64 -// - qopts ...pg.QOpt -func (_e *ORM_Expecter) ListJobProposalsByManagersIDs(ids interface{}, qopts ...interface{}) *ORM_ListJobProposalsByManagersIDs_Call { - return &ORM_ListJobProposalsByManagersIDs_Call{Call: _e.mock.On("ListJobProposalsByManagersIDs", - append([]interface{}{ids}, qopts...)...)} +func (_e *ORM_Expecter) ListJobProposalsByManagersIDs(ctx interface{}, ids interface{}) *ORM_ListJobProposalsByManagersIDs_Call { + return &ORM_ListJobProposalsByManagersIDs_Call{Call: _e.mock.On("ListJobProposalsByManagersIDs", ctx, ids)} } -func (_c *ORM_ListJobProposalsByManagersIDs_Call) Run(run func(ids []int64, qopts ...pg.QOpt)) *ORM_ListJobProposalsByManagersIDs_Call { +func (_c *ORM_ListJobProposalsByManagersIDs_Call) Run(run func(ctx context.Context, ids []int64)) *ORM_ListJobProposalsByManagersIDs_Call { _c.Call.Run(func(args mock.Arguments) { - variadicArgs := make([]pg.QOpt, len(args)-1) - for i, a := range args[1:] { - if a != nil { - variadicArgs[i] = a.(pg.QOpt) - } - } - run(args[0].([]int64), variadicArgs...) + run(args[0].(context.Context), args[1].([]int64)) }) return _c } @@ -1550,14 +1381,14 @@ func (_c *ORM_ListJobProposalsByManagersIDs_Call) Return(_a0 []feeds.JobProposal return _c } -func (_c *ORM_ListJobProposalsByManagersIDs_Call) RunAndReturn(run func([]int64, ...pg.QOpt) ([]feeds.JobProposal, error)) *ORM_ListJobProposalsByManagersIDs_Call { +func (_c *ORM_ListJobProposalsByManagersIDs_Call) RunAndReturn(run func(context.Context, []int64) ([]feeds.JobProposal, error)) *ORM_ListJobProposalsByManagersIDs_Call { _c.Call.Return(run) return _c } -// ListManagers provides a mock function with given fields: -func (_m *ORM) ListManagers() ([]feeds.FeedsManager, error) { - ret := _m.Called() +// ListManagers provides a mock function with given fields: ctx +func (_m *ORM) ListManagers(ctx context.Context) ([]feeds.FeedsManager, error) { + ret := _m.Called(ctx) if len(ret) == 0 { panic("no return value specified for ListManagers") @@ -1565,19 +1396,19 @@ func (_m *ORM) ListManagers() ([]feeds.FeedsManager, error) { var r0 []feeds.FeedsManager var r1 error - if rf, ok := ret.Get(0).(func() ([]feeds.FeedsManager, error)); ok { - return rf() + if rf, ok := ret.Get(0).(func(context.Context) ([]feeds.FeedsManager, error)); ok { + return rf(ctx) } - if rf, ok := ret.Get(0).(func() []feeds.FeedsManager); ok { - r0 = rf() + if rf, ok := ret.Get(0).(func(context.Context) []feeds.FeedsManager); ok { + r0 = rf(ctx) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]feeds.FeedsManager) } } - if rf, ok := ret.Get(1).(func() error); ok { - r1 = rf() + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(ctx) } else { r1 = ret.Error(1) } @@ -1591,13 +1422,14 @@ type ORM_ListManagers_Call struct { } // ListManagers is a helper method to define mock.On call -func (_e *ORM_Expecter) ListManagers() *ORM_ListManagers_Call { - return &ORM_ListManagers_Call{Call: _e.mock.On("ListManagers")} +// - ctx context.Context +func (_e *ORM_Expecter) ListManagers(ctx interface{}) *ORM_ListManagers_Call { + return &ORM_ListManagers_Call{Call: _e.mock.On("ListManagers", ctx)} } -func (_c *ORM_ListManagers_Call) Run(run func()) *ORM_ListManagers_Call { +func (_c *ORM_ListManagers_Call) Run(run func(ctx context.Context)) *ORM_ListManagers_Call { _c.Call.Run(func(args mock.Arguments) { - run() + run(args[0].(context.Context)) }) return _c } @@ -1607,14 +1439,14 @@ func (_c *ORM_ListManagers_Call) Return(mgrs []feeds.FeedsManager, err error) *O return _c } -func (_c *ORM_ListManagers_Call) RunAndReturn(run func() ([]feeds.FeedsManager, error)) *ORM_ListManagers_Call { +func (_c *ORM_ListManagers_Call) RunAndReturn(run func(context.Context) ([]feeds.FeedsManager, error)) *ORM_ListManagers_Call { _c.Call.Return(run) return _c } -// ListManagersByIDs provides a mock function with given fields: ids -func (_m *ORM) ListManagersByIDs(ids []int64) ([]feeds.FeedsManager, error) { - ret := _m.Called(ids) +// ListManagersByIDs provides a mock function with given fields: ctx, ids +func (_m *ORM) ListManagersByIDs(ctx context.Context, ids []int64) ([]feeds.FeedsManager, error) { + ret := _m.Called(ctx, ids) if len(ret) == 0 { panic("no return value specified for ListManagersByIDs") @@ -1622,19 +1454,19 @@ func (_m *ORM) ListManagersByIDs(ids []int64) ([]feeds.FeedsManager, error) { var r0 []feeds.FeedsManager var r1 error - if rf, ok := ret.Get(0).(func([]int64) ([]feeds.FeedsManager, error)); ok { - return rf(ids) + if rf, ok := ret.Get(0).(func(context.Context, []int64) ([]feeds.FeedsManager, error)); ok { + return rf(ctx, ids) } - if rf, ok := ret.Get(0).(func([]int64) []feeds.FeedsManager); ok { - r0 = rf(ids) + if rf, ok := ret.Get(0).(func(context.Context, []int64) []feeds.FeedsManager); ok { + r0 = rf(ctx, ids) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]feeds.FeedsManager) } } - if rf, ok := ret.Get(1).(func([]int64) error); ok { - r1 = rf(ids) + if rf, ok := ret.Get(1).(func(context.Context, []int64) error); ok { + r1 = rf(ctx, ids) } else { r1 = ret.Error(1) } @@ -1648,14 +1480,15 @@ type ORM_ListManagersByIDs_Call struct { } // ListManagersByIDs is a helper method to define mock.On call +// - ctx context.Context // - ids []int64 -func (_e *ORM_Expecter) ListManagersByIDs(ids interface{}) *ORM_ListManagersByIDs_Call { - return &ORM_ListManagersByIDs_Call{Call: _e.mock.On("ListManagersByIDs", ids)} +func (_e *ORM_Expecter) ListManagersByIDs(ctx interface{}, ids interface{}) *ORM_ListManagersByIDs_Call { + return &ORM_ListManagersByIDs_Call{Call: _e.mock.On("ListManagersByIDs", ctx, ids)} } -func (_c *ORM_ListManagersByIDs_Call) Run(run func(ids []int64)) *ORM_ListManagersByIDs_Call { +func (_c *ORM_ListManagersByIDs_Call) Run(run func(ctx context.Context, ids []int64)) *ORM_ListManagersByIDs_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].([]int64)) + run(args[0].(context.Context), args[1].([]int64)) }) return _c } @@ -1665,21 +1498,14 @@ func (_c *ORM_ListManagersByIDs_Call) Return(_a0 []feeds.FeedsManager, _a1 error return _c } -func (_c *ORM_ListManagersByIDs_Call) RunAndReturn(run func([]int64) ([]feeds.FeedsManager, error)) *ORM_ListManagersByIDs_Call { +func (_c *ORM_ListManagersByIDs_Call) RunAndReturn(run func(context.Context, []int64) ([]feeds.FeedsManager, error)) *ORM_ListManagersByIDs_Call { _c.Call.Return(run) return _c } -// ListSpecsByJobProposalIDs provides a mock function with given fields: ids, qopts -func (_m *ORM) ListSpecsByJobProposalIDs(ids []int64, qopts ...pg.QOpt) ([]feeds.JobProposalSpec, error) { - _va := make([]interface{}, len(qopts)) - for _i := range qopts { - _va[_i] = qopts[_i] - } - var _ca []interface{} - _ca = append(_ca, ids) - _ca = append(_ca, _va...) - ret := _m.Called(_ca...) +// ListSpecsByJobProposalIDs provides a mock function with given fields: ctx, ids +func (_m *ORM) ListSpecsByJobProposalIDs(ctx context.Context, ids []int64) ([]feeds.JobProposalSpec, error) { + ret := _m.Called(ctx, ids) if len(ret) == 0 { panic("no return value specified for ListSpecsByJobProposalIDs") @@ -1687,19 +1513,19 @@ func (_m *ORM) ListSpecsByJobProposalIDs(ids []int64, qopts ...pg.QOpt) ([]feeds var r0 []feeds.JobProposalSpec var r1 error - if rf, ok := ret.Get(0).(func([]int64, ...pg.QOpt) ([]feeds.JobProposalSpec, error)); ok { - return rf(ids, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, []int64) ([]feeds.JobProposalSpec, error)); ok { + return rf(ctx, ids) } - if rf, ok := ret.Get(0).(func([]int64, ...pg.QOpt) []feeds.JobProposalSpec); ok { - r0 = rf(ids, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, []int64) []feeds.JobProposalSpec); ok { + r0 = rf(ctx, ids) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]feeds.JobProposalSpec) } } - if rf, ok := ret.Get(1).(func([]int64, ...pg.QOpt) error); ok { - r1 = rf(ids, qopts...) + if rf, ok := ret.Get(1).(func(context.Context, []int64) error); ok { + r1 = rf(ctx, ids) } else { r1 = ret.Error(1) } @@ -1713,22 +1539,15 @@ type ORM_ListSpecsByJobProposalIDs_Call struct { } // ListSpecsByJobProposalIDs is a helper method to define mock.On call +// - ctx context.Context // - ids []int64 -// - qopts ...pg.QOpt -func (_e *ORM_Expecter) ListSpecsByJobProposalIDs(ids interface{}, qopts ...interface{}) *ORM_ListSpecsByJobProposalIDs_Call { - return &ORM_ListSpecsByJobProposalIDs_Call{Call: _e.mock.On("ListSpecsByJobProposalIDs", - append([]interface{}{ids}, qopts...)...)} +func (_e *ORM_Expecter) ListSpecsByJobProposalIDs(ctx interface{}, ids interface{}) *ORM_ListSpecsByJobProposalIDs_Call { + return &ORM_ListSpecsByJobProposalIDs_Call{Call: _e.mock.On("ListSpecsByJobProposalIDs", ctx, ids)} } -func (_c *ORM_ListSpecsByJobProposalIDs_Call) Run(run func(ids []int64, qopts ...pg.QOpt)) *ORM_ListSpecsByJobProposalIDs_Call { +func (_c *ORM_ListSpecsByJobProposalIDs_Call) Run(run func(ctx context.Context, ids []int64)) *ORM_ListSpecsByJobProposalIDs_Call { _c.Call.Run(func(args mock.Arguments) { - variadicArgs := make([]pg.QOpt, len(args)-1) - for i, a := range args[1:] { - if a != nil { - variadicArgs[i] = a.(pg.QOpt) - } - } - run(args[0].([]int64), variadicArgs...) + run(args[0].(context.Context), args[1].([]int64)) }) return _c } @@ -1738,29 +1557,22 @@ func (_c *ORM_ListSpecsByJobProposalIDs_Call) Return(_a0 []feeds.JobProposalSpec return _c } -func (_c *ORM_ListSpecsByJobProposalIDs_Call) RunAndReturn(run func([]int64, ...pg.QOpt) ([]feeds.JobProposalSpec, error)) *ORM_ListSpecsByJobProposalIDs_Call { +func (_c *ORM_ListSpecsByJobProposalIDs_Call) RunAndReturn(run func(context.Context, []int64) ([]feeds.JobProposalSpec, error)) *ORM_ListSpecsByJobProposalIDs_Call { _c.Call.Return(run) return _c } -// RejectSpec provides a mock function with given fields: id, qopts -func (_m *ORM) RejectSpec(id int64, qopts ...pg.QOpt) error { - _va := make([]interface{}, len(qopts)) - for _i := range qopts { - _va[_i] = qopts[_i] - } - var _ca []interface{} - _ca = append(_ca, id) - _ca = append(_ca, _va...) - ret := _m.Called(_ca...) +// RejectSpec provides a mock function with given fields: ctx, id +func (_m *ORM) RejectSpec(ctx context.Context, id int64) error { + ret := _m.Called(ctx, id) if len(ret) == 0 { panic("no return value specified for RejectSpec") } var r0 error - if rf, ok := ret.Get(0).(func(int64, ...pg.QOpt) error); ok { - r0 = rf(id, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, int64) error); ok { + r0 = rf(ctx, id) } else { r0 = ret.Error(0) } @@ -1774,22 +1586,15 @@ type ORM_RejectSpec_Call struct { } // RejectSpec is a helper method to define mock.On call +// - ctx context.Context // - id int64 -// - qopts ...pg.QOpt -func (_e *ORM_Expecter) RejectSpec(id interface{}, qopts ...interface{}) *ORM_RejectSpec_Call { - return &ORM_RejectSpec_Call{Call: _e.mock.On("RejectSpec", - append([]interface{}{id}, qopts...)...)} +func (_e *ORM_Expecter) RejectSpec(ctx interface{}, id interface{}) *ORM_RejectSpec_Call { + return &ORM_RejectSpec_Call{Call: _e.mock.On("RejectSpec", ctx, id)} } -func (_c *ORM_RejectSpec_Call) Run(run func(id int64, qopts ...pg.QOpt)) *ORM_RejectSpec_Call { +func (_c *ORM_RejectSpec_Call) Run(run func(ctx context.Context, id int64)) *ORM_RejectSpec_Call { _c.Call.Run(func(args mock.Arguments) { - variadicArgs := make([]pg.QOpt, len(args)-1) - for i, a := range args[1:] { - if a != nil { - variadicArgs[i] = a.(pg.QOpt) - } - } - run(args[0].(int64), variadicArgs...) + run(args[0].(context.Context), args[1].(int64)) }) return _c } @@ -1799,29 +1604,22 @@ func (_c *ORM_RejectSpec_Call) Return(_a0 error) *ORM_RejectSpec_Call { return _c } -func (_c *ORM_RejectSpec_Call) RunAndReturn(run func(int64, ...pg.QOpt) error) *ORM_RejectSpec_Call { +func (_c *ORM_RejectSpec_Call) RunAndReturn(run func(context.Context, int64) error) *ORM_RejectSpec_Call { _c.Call.Return(run) return _c } -// RevokeSpec provides a mock function with given fields: id, qopts -func (_m *ORM) RevokeSpec(id int64, qopts ...pg.QOpt) error { - _va := make([]interface{}, len(qopts)) - for _i := range qopts { - _va[_i] = qopts[_i] - } - var _ca []interface{} - _ca = append(_ca, id) - _ca = append(_ca, _va...) - ret := _m.Called(_ca...) +// RevokeSpec provides a mock function with given fields: ctx, id +func (_m *ORM) RevokeSpec(ctx context.Context, id int64) error { + ret := _m.Called(ctx, id) if len(ret) == 0 { panic("no return value specified for RevokeSpec") } var r0 error - if rf, ok := ret.Get(0).(func(int64, ...pg.QOpt) error); ok { - r0 = rf(id, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, int64) error); ok { + r0 = rf(ctx, id) } else { r0 = ret.Error(0) } @@ -1835,22 +1633,15 @@ type ORM_RevokeSpec_Call struct { } // RevokeSpec is a helper method to define mock.On call +// - ctx context.Context // - id int64 -// - qopts ...pg.QOpt -func (_e *ORM_Expecter) RevokeSpec(id interface{}, qopts ...interface{}) *ORM_RevokeSpec_Call { - return &ORM_RevokeSpec_Call{Call: _e.mock.On("RevokeSpec", - append([]interface{}{id}, qopts...)...)} +func (_e *ORM_Expecter) RevokeSpec(ctx interface{}, id interface{}) *ORM_RevokeSpec_Call { + return &ORM_RevokeSpec_Call{Call: _e.mock.On("RevokeSpec", ctx, id)} } -func (_c *ORM_RevokeSpec_Call) Run(run func(id int64, qopts ...pg.QOpt)) *ORM_RevokeSpec_Call { +func (_c *ORM_RevokeSpec_Call) Run(run func(ctx context.Context, id int64)) *ORM_RevokeSpec_Call { _c.Call.Run(func(args mock.Arguments) { - variadicArgs := make([]pg.QOpt, len(args)-1) - for i, a := range args[1:] { - if a != nil { - variadicArgs[i] = a.(pg.QOpt) - } - } - run(args[0].(int64), variadicArgs...) + run(args[0].(context.Context), args[1].(int64)) }) return _c } @@ -1860,14 +1651,61 @@ func (_c *ORM_RevokeSpec_Call) Return(_a0 error) *ORM_RevokeSpec_Call { return _c } -func (_c *ORM_RevokeSpec_Call) RunAndReturn(run func(int64, ...pg.QOpt) error) *ORM_RevokeSpec_Call { +func (_c *ORM_RevokeSpec_Call) RunAndReturn(run func(context.Context, int64) error) *ORM_RevokeSpec_Call { _c.Call.Return(run) return _c } -// UpdateChainConfig provides a mock function with given fields: cfg -func (_m *ORM) UpdateChainConfig(cfg feeds.ChainConfig) (int64, error) { - ret := _m.Called(cfg) +// Transact provides a mock function with given fields: _a0, _a1 +func (_m *ORM) Transact(_a0 context.Context, _a1 func(feeds.ORM) error) error { + ret := _m.Called(_a0, _a1) + + if len(ret) == 0 { + panic("no return value specified for Transact") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, func(feeds.ORM) error) error); ok { + r0 = rf(_a0, _a1) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// ORM_Transact_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Transact' +type ORM_Transact_Call struct { + *mock.Call +} + +// Transact is a helper method to define mock.On call +// - _a0 context.Context +// - _a1 func(feeds.ORM) error +func (_e *ORM_Expecter) Transact(_a0 interface{}, _a1 interface{}) *ORM_Transact_Call { + return &ORM_Transact_Call{Call: _e.mock.On("Transact", _a0, _a1)} +} + +func (_c *ORM_Transact_Call) Run(run func(_a0 context.Context, _a1 func(feeds.ORM) error)) *ORM_Transact_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(func(feeds.ORM) error)) + }) + return _c +} + +func (_c *ORM_Transact_Call) Return(_a0 error) *ORM_Transact_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *ORM_Transact_Call) RunAndReturn(run func(context.Context, func(feeds.ORM) error) error) *ORM_Transact_Call { + _c.Call.Return(run) + return _c +} + +// UpdateChainConfig provides a mock function with given fields: ctx, cfg +func (_m *ORM) UpdateChainConfig(ctx context.Context, cfg feeds.ChainConfig) (int64, error) { + ret := _m.Called(ctx, cfg) if len(ret) == 0 { panic("no return value specified for UpdateChainConfig") @@ -1875,17 +1713,17 @@ func (_m *ORM) UpdateChainConfig(cfg feeds.ChainConfig) (int64, error) { var r0 int64 var r1 error - if rf, ok := ret.Get(0).(func(feeds.ChainConfig) (int64, error)); ok { - return rf(cfg) + if rf, ok := ret.Get(0).(func(context.Context, feeds.ChainConfig) (int64, error)); ok { + return rf(ctx, cfg) } - if rf, ok := ret.Get(0).(func(feeds.ChainConfig) int64); ok { - r0 = rf(cfg) + if rf, ok := ret.Get(0).(func(context.Context, feeds.ChainConfig) int64); ok { + r0 = rf(ctx, cfg) } else { r0 = ret.Get(0).(int64) } - if rf, ok := ret.Get(1).(func(feeds.ChainConfig) error); ok { - r1 = rf(cfg) + if rf, ok := ret.Get(1).(func(context.Context, feeds.ChainConfig) error); ok { + r1 = rf(ctx, cfg) } else { r1 = ret.Error(1) } @@ -1899,14 +1737,15 @@ type ORM_UpdateChainConfig_Call struct { } // UpdateChainConfig is a helper method to define mock.On call +// - ctx context.Context // - cfg feeds.ChainConfig -func (_e *ORM_Expecter) UpdateChainConfig(cfg interface{}) *ORM_UpdateChainConfig_Call { - return &ORM_UpdateChainConfig_Call{Call: _e.mock.On("UpdateChainConfig", cfg)} +func (_e *ORM_Expecter) UpdateChainConfig(ctx interface{}, cfg interface{}) *ORM_UpdateChainConfig_Call { + return &ORM_UpdateChainConfig_Call{Call: _e.mock.On("UpdateChainConfig", ctx, cfg)} } -func (_c *ORM_UpdateChainConfig_Call) Run(run func(cfg feeds.ChainConfig)) *ORM_UpdateChainConfig_Call { +func (_c *ORM_UpdateChainConfig_Call) Run(run func(ctx context.Context, cfg feeds.ChainConfig)) *ORM_UpdateChainConfig_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(feeds.ChainConfig)) + run(args[0].(context.Context), args[1].(feeds.ChainConfig)) }) return _c } @@ -1916,29 +1755,22 @@ func (_c *ORM_UpdateChainConfig_Call) Return(_a0 int64, _a1 error) *ORM_UpdateCh return _c } -func (_c *ORM_UpdateChainConfig_Call) RunAndReturn(run func(feeds.ChainConfig) (int64, error)) *ORM_UpdateChainConfig_Call { +func (_c *ORM_UpdateChainConfig_Call) RunAndReturn(run func(context.Context, feeds.ChainConfig) (int64, error)) *ORM_UpdateChainConfig_Call { _c.Call.Return(run) return _c } -// UpdateJobProposalStatus provides a mock function with given fields: id, status, qopts -func (_m *ORM) UpdateJobProposalStatus(id int64, status feeds.JobProposalStatus, qopts ...pg.QOpt) error { - _va := make([]interface{}, len(qopts)) - for _i := range qopts { - _va[_i] = qopts[_i] - } - var _ca []interface{} - _ca = append(_ca, id, status) - _ca = append(_ca, _va...) - ret := _m.Called(_ca...) +// UpdateJobProposalStatus provides a mock function with given fields: ctx, id, status +func (_m *ORM) UpdateJobProposalStatus(ctx context.Context, id int64, status feeds.JobProposalStatus) error { + ret := _m.Called(ctx, id, status) if len(ret) == 0 { panic("no return value specified for UpdateJobProposalStatus") } var r0 error - if rf, ok := ret.Get(0).(func(int64, feeds.JobProposalStatus, ...pg.QOpt) error); ok { - r0 = rf(id, status, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, int64, feeds.JobProposalStatus) error); ok { + r0 = rf(ctx, id, status) } else { r0 = ret.Error(0) } @@ -1952,23 +1784,16 @@ type ORM_UpdateJobProposalStatus_Call struct { } // UpdateJobProposalStatus is a helper method to define mock.On call +// - ctx context.Context // - id int64 // - status feeds.JobProposalStatus -// - qopts ...pg.QOpt -func (_e *ORM_Expecter) UpdateJobProposalStatus(id interface{}, status interface{}, qopts ...interface{}) *ORM_UpdateJobProposalStatus_Call { - return &ORM_UpdateJobProposalStatus_Call{Call: _e.mock.On("UpdateJobProposalStatus", - append([]interface{}{id, status}, qopts...)...)} +func (_e *ORM_Expecter) UpdateJobProposalStatus(ctx interface{}, id interface{}, status interface{}) *ORM_UpdateJobProposalStatus_Call { + return &ORM_UpdateJobProposalStatus_Call{Call: _e.mock.On("UpdateJobProposalStatus", ctx, id, status)} } -func (_c *ORM_UpdateJobProposalStatus_Call) Run(run func(id int64, status feeds.JobProposalStatus, qopts ...pg.QOpt)) *ORM_UpdateJobProposalStatus_Call { +func (_c *ORM_UpdateJobProposalStatus_Call) Run(run func(ctx context.Context, id int64, status feeds.JobProposalStatus)) *ORM_UpdateJobProposalStatus_Call { _c.Call.Run(func(args mock.Arguments) { - variadicArgs := make([]pg.QOpt, len(args)-2) - for i, a := range args[2:] { - if a != nil { - variadicArgs[i] = a.(pg.QOpt) - } - } - run(args[0].(int64), args[1].(feeds.JobProposalStatus), variadicArgs...) + run(args[0].(context.Context), args[1].(int64), args[2].(feeds.JobProposalStatus)) }) return _c } @@ -1978,29 +1803,22 @@ func (_c *ORM_UpdateJobProposalStatus_Call) Return(_a0 error) *ORM_UpdateJobProp return _c } -func (_c *ORM_UpdateJobProposalStatus_Call) RunAndReturn(run func(int64, feeds.JobProposalStatus, ...pg.QOpt) error) *ORM_UpdateJobProposalStatus_Call { +func (_c *ORM_UpdateJobProposalStatus_Call) RunAndReturn(run func(context.Context, int64, feeds.JobProposalStatus) error) *ORM_UpdateJobProposalStatus_Call { _c.Call.Return(run) return _c } -// UpdateManager provides a mock function with given fields: mgr, qopts -func (_m *ORM) UpdateManager(mgr feeds.FeedsManager, qopts ...pg.QOpt) error { - _va := make([]interface{}, len(qopts)) - for _i := range qopts { - _va[_i] = qopts[_i] - } - var _ca []interface{} - _ca = append(_ca, mgr) - _ca = append(_ca, _va...) - ret := _m.Called(_ca...) +// UpdateManager provides a mock function with given fields: ctx, mgr +func (_m *ORM) UpdateManager(ctx context.Context, mgr feeds.FeedsManager) error { + ret := _m.Called(ctx, mgr) if len(ret) == 0 { panic("no return value specified for UpdateManager") } var r0 error - if rf, ok := ret.Get(0).(func(feeds.FeedsManager, ...pg.QOpt) error); ok { - r0 = rf(mgr, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, feeds.FeedsManager) error); ok { + r0 = rf(ctx, mgr) } else { r0 = ret.Error(0) } @@ -2014,22 +1832,15 @@ type ORM_UpdateManager_Call struct { } // UpdateManager is a helper method to define mock.On call +// - ctx context.Context // - mgr feeds.FeedsManager -// - qopts ...pg.QOpt -func (_e *ORM_Expecter) UpdateManager(mgr interface{}, qopts ...interface{}) *ORM_UpdateManager_Call { - return &ORM_UpdateManager_Call{Call: _e.mock.On("UpdateManager", - append([]interface{}{mgr}, qopts...)...)} +func (_e *ORM_Expecter) UpdateManager(ctx interface{}, mgr interface{}) *ORM_UpdateManager_Call { + return &ORM_UpdateManager_Call{Call: _e.mock.On("UpdateManager", ctx, mgr)} } -func (_c *ORM_UpdateManager_Call) Run(run func(mgr feeds.FeedsManager, qopts ...pg.QOpt)) *ORM_UpdateManager_Call { +func (_c *ORM_UpdateManager_Call) Run(run func(ctx context.Context, mgr feeds.FeedsManager)) *ORM_UpdateManager_Call { _c.Call.Run(func(args mock.Arguments) { - variadicArgs := make([]pg.QOpt, len(args)-1) - for i, a := range args[1:] { - if a != nil { - variadicArgs[i] = a.(pg.QOpt) - } - } - run(args[0].(feeds.FeedsManager), variadicArgs...) + run(args[0].(context.Context), args[1].(feeds.FeedsManager)) }) return _c } @@ -2039,29 +1850,22 @@ func (_c *ORM_UpdateManager_Call) Return(_a0 error) *ORM_UpdateManager_Call { return _c } -func (_c *ORM_UpdateManager_Call) RunAndReturn(run func(feeds.FeedsManager, ...pg.QOpt) error) *ORM_UpdateManager_Call { +func (_c *ORM_UpdateManager_Call) RunAndReturn(run func(context.Context, feeds.FeedsManager) error) *ORM_UpdateManager_Call { _c.Call.Return(run) return _c } -// UpdateSpecDefinition provides a mock function with given fields: id, spec, qopts -func (_m *ORM) UpdateSpecDefinition(id int64, spec string, qopts ...pg.QOpt) error { - _va := make([]interface{}, len(qopts)) - for _i := range qopts { - _va[_i] = qopts[_i] - } - var _ca []interface{} - _ca = append(_ca, id, spec) - _ca = append(_ca, _va...) - ret := _m.Called(_ca...) +// UpdateSpecDefinition provides a mock function with given fields: ctx, id, spec +func (_m *ORM) UpdateSpecDefinition(ctx context.Context, id int64, spec string) error { + ret := _m.Called(ctx, id, spec) if len(ret) == 0 { panic("no return value specified for UpdateSpecDefinition") } var r0 error - if rf, ok := ret.Get(0).(func(int64, string, ...pg.QOpt) error); ok { - r0 = rf(id, spec, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, int64, string) error); ok { + r0 = rf(ctx, id, spec) } else { r0 = ret.Error(0) } @@ -2075,23 +1879,16 @@ type ORM_UpdateSpecDefinition_Call struct { } // UpdateSpecDefinition is a helper method to define mock.On call +// - ctx context.Context // - id int64 // - spec string -// - qopts ...pg.QOpt -func (_e *ORM_Expecter) UpdateSpecDefinition(id interface{}, spec interface{}, qopts ...interface{}) *ORM_UpdateSpecDefinition_Call { - return &ORM_UpdateSpecDefinition_Call{Call: _e.mock.On("UpdateSpecDefinition", - append([]interface{}{id, spec}, qopts...)...)} +func (_e *ORM_Expecter) UpdateSpecDefinition(ctx interface{}, id interface{}, spec interface{}) *ORM_UpdateSpecDefinition_Call { + return &ORM_UpdateSpecDefinition_Call{Call: _e.mock.On("UpdateSpecDefinition", ctx, id, spec)} } -func (_c *ORM_UpdateSpecDefinition_Call) Run(run func(id int64, spec string, qopts ...pg.QOpt)) *ORM_UpdateSpecDefinition_Call { +func (_c *ORM_UpdateSpecDefinition_Call) Run(run func(ctx context.Context, id int64, spec string)) *ORM_UpdateSpecDefinition_Call { _c.Call.Run(func(args mock.Arguments) { - variadicArgs := make([]pg.QOpt, len(args)-2) - for i, a := range args[2:] { - if a != nil { - variadicArgs[i] = a.(pg.QOpt) - } - } - run(args[0].(int64), args[1].(string), variadicArgs...) + run(args[0].(context.Context), args[1].(int64), args[2].(string)) }) return _c } @@ -2101,21 +1898,14 @@ func (_c *ORM_UpdateSpecDefinition_Call) Return(_a0 error) *ORM_UpdateSpecDefini return _c } -func (_c *ORM_UpdateSpecDefinition_Call) RunAndReturn(run func(int64, string, ...pg.QOpt) error) *ORM_UpdateSpecDefinition_Call { +func (_c *ORM_UpdateSpecDefinition_Call) RunAndReturn(run func(context.Context, int64, string) error) *ORM_UpdateSpecDefinition_Call { _c.Call.Return(run) return _c } -// UpsertJobProposal provides a mock function with given fields: jp, qopts -func (_m *ORM) UpsertJobProposal(jp *feeds.JobProposal, qopts ...pg.QOpt) (int64, error) { - _va := make([]interface{}, len(qopts)) - for _i := range qopts { - _va[_i] = qopts[_i] - } - var _ca []interface{} - _ca = append(_ca, jp) - _ca = append(_ca, _va...) - ret := _m.Called(_ca...) +// UpsertJobProposal provides a mock function with given fields: ctx, jp +func (_m *ORM) UpsertJobProposal(ctx context.Context, jp *feeds.JobProposal) (int64, error) { + ret := _m.Called(ctx, jp) if len(ret) == 0 { panic("no return value specified for UpsertJobProposal") @@ -2123,17 +1913,17 @@ func (_m *ORM) UpsertJobProposal(jp *feeds.JobProposal, qopts ...pg.QOpt) (int64 var r0 int64 var r1 error - if rf, ok := ret.Get(0).(func(*feeds.JobProposal, ...pg.QOpt) (int64, error)); ok { - return rf(jp, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, *feeds.JobProposal) (int64, error)); ok { + return rf(ctx, jp) } - if rf, ok := ret.Get(0).(func(*feeds.JobProposal, ...pg.QOpt) int64); ok { - r0 = rf(jp, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, *feeds.JobProposal) int64); ok { + r0 = rf(ctx, jp) } else { r0 = ret.Get(0).(int64) } - if rf, ok := ret.Get(1).(func(*feeds.JobProposal, ...pg.QOpt) error); ok { - r1 = rf(jp, qopts...) + if rf, ok := ret.Get(1).(func(context.Context, *feeds.JobProposal) error); ok { + r1 = rf(ctx, jp) } else { r1 = ret.Error(1) } @@ -2147,22 +1937,15 @@ type ORM_UpsertJobProposal_Call struct { } // UpsertJobProposal is a helper method to define mock.On call +// - ctx context.Context // - jp *feeds.JobProposal -// - qopts ...pg.QOpt -func (_e *ORM_Expecter) UpsertJobProposal(jp interface{}, qopts ...interface{}) *ORM_UpsertJobProposal_Call { - return &ORM_UpsertJobProposal_Call{Call: _e.mock.On("UpsertJobProposal", - append([]interface{}{jp}, qopts...)...)} +func (_e *ORM_Expecter) UpsertJobProposal(ctx interface{}, jp interface{}) *ORM_UpsertJobProposal_Call { + return &ORM_UpsertJobProposal_Call{Call: _e.mock.On("UpsertJobProposal", ctx, jp)} } -func (_c *ORM_UpsertJobProposal_Call) Run(run func(jp *feeds.JobProposal, qopts ...pg.QOpt)) *ORM_UpsertJobProposal_Call { +func (_c *ORM_UpsertJobProposal_Call) Run(run func(ctx context.Context, jp *feeds.JobProposal)) *ORM_UpsertJobProposal_Call { _c.Call.Run(func(args mock.Arguments) { - variadicArgs := make([]pg.QOpt, len(args)-1) - for i, a := range args[1:] { - if a != nil { - variadicArgs[i] = a.(pg.QOpt) - } - } - run(args[0].(*feeds.JobProposal), variadicArgs...) + run(args[0].(context.Context), args[1].(*feeds.JobProposal)) }) return _c } @@ -2172,7 +1955,55 @@ func (_c *ORM_UpsertJobProposal_Call) Return(_a0 int64, _a1 error) *ORM_UpsertJo return _c } -func (_c *ORM_UpsertJobProposal_Call) RunAndReturn(run func(*feeds.JobProposal, ...pg.QOpt) (int64, error)) *ORM_UpsertJobProposal_Call { +func (_c *ORM_UpsertJobProposal_Call) RunAndReturn(run func(context.Context, *feeds.JobProposal) (int64, error)) *ORM_UpsertJobProposal_Call { + _c.Call.Return(run) + return _c +} + +// WithDataSource provides a mock function with given fields: _a0 +func (_m *ORM) WithDataSource(_a0 sqlutil.DataSource) feeds.ORM { + ret := _m.Called(_a0) + + if len(ret) == 0 { + panic("no return value specified for WithDataSource") + } + + var r0 feeds.ORM + if rf, ok := ret.Get(0).(func(sqlutil.DataSource) feeds.ORM); ok { + r0 = rf(_a0) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(feeds.ORM) + } + } + + return r0 +} + +// ORM_WithDataSource_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'WithDataSource' +type ORM_WithDataSource_Call struct { + *mock.Call +} + +// WithDataSource is a helper method to define mock.On call +// - _a0 sqlutil.DataSource +func (_e *ORM_Expecter) WithDataSource(_a0 interface{}) *ORM_WithDataSource_Call { + return &ORM_WithDataSource_Call{Call: _e.mock.On("WithDataSource", _a0)} +} + +func (_c *ORM_WithDataSource_Call) Run(run func(_a0 sqlutil.DataSource)) *ORM_WithDataSource_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(sqlutil.DataSource)) + }) + return _c +} + +func (_c *ORM_WithDataSource_Call) Return(_a0 feeds.ORM) *ORM_WithDataSource_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *ORM_WithDataSource_Call) RunAndReturn(run func(sqlutil.DataSource) feeds.ORM) *ORM_WithDataSource_Call { _c.Call.Return(run) return _c } diff --git a/core/services/feeds/mocks/service.go b/core/services/feeds/mocks/service.go index 05ede181f44..1e2e6393276 100644 --- a/core/services/feeds/mocks/service.go +++ b/core/services/feeds/mocks/service.go @@ -68,9 +68,9 @@ func (_m *Service) Close() error { return r0 } -// CountJobProposalsByStatus provides a mock function with given fields: -func (_m *Service) CountJobProposalsByStatus() (*feeds.JobProposalCounts, error) { - ret := _m.Called() +// CountJobProposalsByStatus provides a mock function with given fields: ctx +func (_m *Service) CountJobProposalsByStatus(ctx context.Context) (*feeds.JobProposalCounts, error) { + ret := _m.Called(ctx) if len(ret) == 0 { panic("no return value specified for CountJobProposalsByStatus") @@ -78,19 +78,19 @@ func (_m *Service) CountJobProposalsByStatus() (*feeds.JobProposalCounts, error) var r0 *feeds.JobProposalCounts var r1 error - if rf, ok := ret.Get(0).(func() (*feeds.JobProposalCounts, error)); ok { - return rf() + if rf, ok := ret.Get(0).(func(context.Context) (*feeds.JobProposalCounts, error)); ok { + return rf(ctx) } - if rf, ok := ret.Get(0).(func() *feeds.JobProposalCounts); ok { - r0 = rf() + if rf, ok := ret.Get(0).(func(context.Context) *feeds.JobProposalCounts); ok { + r0 = rf(ctx) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*feeds.JobProposalCounts) } } - if rf, ok := ret.Get(1).(func() error); ok { - r1 = rf() + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(ctx) } else { r1 = ret.Error(1) } @@ -98,9 +98,9 @@ func (_m *Service) CountJobProposalsByStatus() (*feeds.JobProposalCounts, error) return r0, r1 } -// CountManagers provides a mock function with given fields: -func (_m *Service) CountManagers() (int64, error) { - ret := _m.Called() +// CountManagers provides a mock function with given fields: ctx +func (_m *Service) CountManagers(ctx context.Context) (int64, error) { + ret := _m.Called(ctx) if len(ret) == 0 { panic("no return value specified for CountManagers") @@ -108,17 +108,17 @@ func (_m *Service) CountManagers() (int64, error) { var r0 int64 var r1 error - if rf, ok := ret.Get(0).(func() (int64, error)); ok { - return rf() + if rf, ok := ret.Get(0).(func(context.Context) (int64, error)); ok { + return rf(ctx) } - if rf, ok := ret.Get(0).(func() int64); ok { - r0 = rf() + if rf, ok := ret.Get(0).(func(context.Context) int64); ok { + r0 = rf(ctx) } else { r0 = ret.Get(0).(int64) } - if rf, ok := ret.Get(1).(func() error); ok { - r1 = rf() + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(ctx) } else { r1 = ret.Error(1) } @@ -210,9 +210,9 @@ func (_m *Service) DeleteJob(ctx context.Context, args *feeds.DeleteJobArgs) (in return r0, r1 } -// GetChainConfig provides a mock function with given fields: id -func (_m *Service) GetChainConfig(id int64) (*feeds.ChainConfig, error) { - ret := _m.Called(id) +// GetChainConfig provides a mock function with given fields: ctx, id +func (_m *Service) GetChainConfig(ctx context.Context, id int64) (*feeds.ChainConfig, error) { + ret := _m.Called(ctx, id) if len(ret) == 0 { panic("no return value specified for GetChainConfig") @@ -220,19 +220,19 @@ func (_m *Service) GetChainConfig(id int64) (*feeds.ChainConfig, error) { var r0 *feeds.ChainConfig var r1 error - if rf, ok := ret.Get(0).(func(int64) (*feeds.ChainConfig, error)); ok { - return rf(id) + if rf, ok := ret.Get(0).(func(context.Context, int64) (*feeds.ChainConfig, error)); ok { + return rf(ctx, id) } - if rf, ok := ret.Get(0).(func(int64) *feeds.ChainConfig); ok { - r0 = rf(id) + if rf, ok := ret.Get(0).(func(context.Context, int64) *feeds.ChainConfig); ok { + r0 = rf(ctx, id) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*feeds.ChainConfig) } } - if rf, ok := ret.Get(1).(func(int64) error); ok { - r1 = rf(id) + if rf, ok := ret.Get(1).(func(context.Context, int64) error); ok { + r1 = rf(ctx, id) } else { r1 = ret.Error(1) } @@ -240,9 +240,9 @@ func (_m *Service) GetChainConfig(id int64) (*feeds.ChainConfig, error) { return r0, r1 } -// GetJobProposal provides a mock function with given fields: id -func (_m *Service) GetJobProposal(id int64) (*feeds.JobProposal, error) { - ret := _m.Called(id) +// GetJobProposal provides a mock function with given fields: ctx, id +func (_m *Service) GetJobProposal(ctx context.Context, id int64) (*feeds.JobProposal, error) { + ret := _m.Called(ctx, id) if len(ret) == 0 { panic("no return value specified for GetJobProposal") @@ -250,19 +250,19 @@ func (_m *Service) GetJobProposal(id int64) (*feeds.JobProposal, error) { var r0 *feeds.JobProposal var r1 error - if rf, ok := ret.Get(0).(func(int64) (*feeds.JobProposal, error)); ok { - return rf(id) + if rf, ok := ret.Get(0).(func(context.Context, int64) (*feeds.JobProposal, error)); ok { + return rf(ctx, id) } - if rf, ok := ret.Get(0).(func(int64) *feeds.JobProposal); ok { - r0 = rf(id) + if rf, ok := ret.Get(0).(func(context.Context, int64) *feeds.JobProposal); ok { + r0 = rf(ctx, id) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*feeds.JobProposal) } } - if rf, ok := ret.Get(1).(func(int64) error); ok { - r1 = rf(id) + if rf, ok := ret.Get(1).(func(context.Context, int64) error); ok { + r1 = rf(ctx, id) } else { r1 = ret.Error(1) } @@ -270,9 +270,9 @@ func (_m *Service) GetJobProposal(id int64) (*feeds.JobProposal, error) { return r0, r1 } -// GetManager provides a mock function with given fields: id -func (_m *Service) GetManager(id int64) (*feeds.FeedsManager, error) { - ret := _m.Called(id) +// GetManager provides a mock function with given fields: ctx, id +func (_m *Service) GetManager(ctx context.Context, id int64) (*feeds.FeedsManager, error) { + ret := _m.Called(ctx, id) if len(ret) == 0 { panic("no return value specified for GetManager") @@ -280,19 +280,19 @@ func (_m *Service) GetManager(id int64) (*feeds.FeedsManager, error) { var r0 *feeds.FeedsManager var r1 error - if rf, ok := ret.Get(0).(func(int64) (*feeds.FeedsManager, error)); ok { - return rf(id) + if rf, ok := ret.Get(0).(func(context.Context, int64) (*feeds.FeedsManager, error)); ok { + return rf(ctx, id) } - if rf, ok := ret.Get(0).(func(int64) *feeds.FeedsManager); ok { - r0 = rf(id) + if rf, ok := ret.Get(0).(func(context.Context, int64) *feeds.FeedsManager); ok { + r0 = rf(ctx, id) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*feeds.FeedsManager) } } - if rf, ok := ret.Get(1).(func(int64) error); ok { - r1 = rf(id) + if rf, ok := ret.Get(1).(func(context.Context, int64) error); ok { + r1 = rf(ctx, id) } else { r1 = ret.Error(1) } @@ -300,9 +300,9 @@ func (_m *Service) GetManager(id int64) (*feeds.FeedsManager, error) { return r0, r1 } -// GetSpec provides a mock function with given fields: id -func (_m *Service) GetSpec(id int64) (*feeds.JobProposalSpec, error) { - ret := _m.Called(id) +// GetSpec provides a mock function with given fields: ctx, id +func (_m *Service) GetSpec(ctx context.Context, id int64) (*feeds.JobProposalSpec, error) { + ret := _m.Called(ctx, id) if len(ret) == 0 { panic("no return value specified for GetSpec") @@ -310,19 +310,19 @@ func (_m *Service) GetSpec(id int64) (*feeds.JobProposalSpec, error) { var r0 *feeds.JobProposalSpec var r1 error - if rf, ok := ret.Get(0).(func(int64) (*feeds.JobProposalSpec, error)); ok { - return rf(id) + if rf, ok := ret.Get(0).(func(context.Context, int64) (*feeds.JobProposalSpec, error)); ok { + return rf(ctx, id) } - if rf, ok := ret.Get(0).(func(int64) *feeds.JobProposalSpec); ok { - r0 = rf(id) + if rf, ok := ret.Get(0).(func(context.Context, int64) *feeds.JobProposalSpec); ok { + r0 = rf(ctx, id) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*feeds.JobProposalSpec) } } - if rf, ok := ret.Get(1).(func(int64) error); ok { - r1 = rf(id) + if rf, ok := ret.Get(1).(func(context.Context, int64) error); ok { + r1 = rf(ctx, id) } else { r1 = ret.Error(1) } @@ -358,9 +358,9 @@ func (_m *Service) IsJobManaged(ctx context.Context, jobID int64) (bool, error) return r0, r1 } -// ListChainConfigsByManagerIDs provides a mock function with given fields: mgrIDs -func (_m *Service) ListChainConfigsByManagerIDs(mgrIDs []int64) ([]feeds.ChainConfig, error) { - ret := _m.Called(mgrIDs) +// ListChainConfigsByManagerIDs provides a mock function with given fields: ctx, mgrIDs +func (_m *Service) ListChainConfigsByManagerIDs(ctx context.Context, mgrIDs []int64) ([]feeds.ChainConfig, error) { + ret := _m.Called(ctx, mgrIDs) if len(ret) == 0 { panic("no return value specified for ListChainConfigsByManagerIDs") @@ -368,19 +368,19 @@ func (_m *Service) ListChainConfigsByManagerIDs(mgrIDs []int64) ([]feeds.ChainCo var r0 []feeds.ChainConfig var r1 error - if rf, ok := ret.Get(0).(func([]int64) ([]feeds.ChainConfig, error)); ok { - return rf(mgrIDs) + if rf, ok := ret.Get(0).(func(context.Context, []int64) ([]feeds.ChainConfig, error)); ok { + return rf(ctx, mgrIDs) } - if rf, ok := ret.Get(0).(func([]int64) []feeds.ChainConfig); ok { - r0 = rf(mgrIDs) + if rf, ok := ret.Get(0).(func(context.Context, []int64) []feeds.ChainConfig); ok { + r0 = rf(ctx, mgrIDs) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]feeds.ChainConfig) } } - if rf, ok := ret.Get(1).(func([]int64) error); ok { - r1 = rf(mgrIDs) + if rf, ok := ret.Get(1).(func(context.Context, []int64) error); ok { + r1 = rf(ctx, mgrIDs) } else { r1 = ret.Error(1) } @@ -388,9 +388,9 @@ func (_m *Service) ListChainConfigsByManagerIDs(mgrIDs []int64) ([]feeds.ChainCo return r0, r1 } -// ListJobProposals provides a mock function with given fields: -func (_m *Service) ListJobProposals() ([]feeds.JobProposal, error) { - ret := _m.Called() +// ListJobProposals provides a mock function with given fields: ctx +func (_m *Service) ListJobProposals(ctx context.Context) ([]feeds.JobProposal, error) { + ret := _m.Called(ctx) if len(ret) == 0 { panic("no return value specified for ListJobProposals") @@ -398,19 +398,19 @@ func (_m *Service) ListJobProposals() ([]feeds.JobProposal, error) { var r0 []feeds.JobProposal var r1 error - if rf, ok := ret.Get(0).(func() ([]feeds.JobProposal, error)); ok { - return rf() + if rf, ok := ret.Get(0).(func(context.Context) ([]feeds.JobProposal, error)); ok { + return rf(ctx) } - if rf, ok := ret.Get(0).(func() []feeds.JobProposal); ok { - r0 = rf() + if rf, ok := ret.Get(0).(func(context.Context) []feeds.JobProposal); ok { + r0 = rf(ctx) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]feeds.JobProposal) } } - if rf, ok := ret.Get(1).(func() error); ok { - r1 = rf() + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(ctx) } else { r1 = ret.Error(1) } @@ -418,9 +418,9 @@ func (_m *Service) ListJobProposals() ([]feeds.JobProposal, error) { return r0, r1 } -// ListJobProposalsByManagersIDs provides a mock function with given fields: ids -func (_m *Service) ListJobProposalsByManagersIDs(ids []int64) ([]feeds.JobProposal, error) { - ret := _m.Called(ids) +// ListJobProposalsByManagersIDs provides a mock function with given fields: ctx, ids +func (_m *Service) ListJobProposalsByManagersIDs(ctx context.Context, ids []int64) ([]feeds.JobProposal, error) { + ret := _m.Called(ctx, ids) if len(ret) == 0 { panic("no return value specified for ListJobProposalsByManagersIDs") @@ -428,19 +428,19 @@ func (_m *Service) ListJobProposalsByManagersIDs(ids []int64) ([]feeds.JobPropos var r0 []feeds.JobProposal var r1 error - if rf, ok := ret.Get(0).(func([]int64) ([]feeds.JobProposal, error)); ok { - return rf(ids) + if rf, ok := ret.Get(0).(func(context.Context, []int64) ([]feeds.JobProposal, error)); ok { + return rf(ctx, ids) } - if rf, ok := ret.Get(0).(func([]int64) []feeds.JobProposal); ok { - r0 = rf(ids) + if rf, ok := ret.Get(0).(func(context.Context, []int64) []feeds.JobProposal); ok { + r0 = rf(ctx, ids) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]feeds.JobProposal) } } - if rf, ok := ret.Get(1).(func([]int64) error); ok { - r1 = rf(ids) + if rf, ok := ret.Get(1).(func(context.Context, []int64) error); ok { + r1 = rf(ctx, ids) } else { r1 = ret.Error(1) } @@ -448,9 +448,9 @@ func (_m *Service) ListJobProposalsByManagersIDs(ids []int64) ([]feeds.JobPropos return r0, r1 } -// ListManagers provides a mock function with given fields: -func (_m *Service) ListManagers() ([]feeds.FeedsManager, error) { - ret := _m.Called() +// ListManagers provides a mock function with given fields: ctx +func (_m *Service) ListManagers(ctx context.Context) ([]feeds.FeedsManager, error) { + ret := _m.Called(ctx) if len(ret) == 0 { panic("no return value specified for ListManagers") @@ -458,19 +458,19 @@ func (_m *Service) ListManagers() ([]feeds.FeedsManager, error) { var r0 []feeds.FeedsManager var r1 error - if rf, ok := ret.Get(0).(func() ([]feeds.FeedsManager, error)); ok { - return rf() + if rf, ok := ret.Get(0).(func(context.Context) ([]feeds.FeedsManager, error)); ok { + return rf(ctx) } - if rf, ok := ret.Get(0).(func() []feeds.FeedsManager); ok { - r0 = rf() + if rf, ok := ret.Get(0).(func(context.Context) []feeds.FeedsManager); ok { + r0 = rf(ctx) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]feeds.FeedsManager) } } - if rf, ok := ret.Get(1).(func() error); ok { - r1 = rf() + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(ctx) } else { r1 = ret.Error(1) } @@ -478,9 +478,9 @@ func (_m *Service) ListManagers() ([]feeds.FeedsManager, error) { return r0, r1 } -// ListManagersByIDs provides a mock function with given fields: ids -func (_m *Service) ListManagersByIDs(ids []int64) ([]feeds.FeedsManager, error) { - ret := _m.Called(ids) +// ListManagersByIDs provides a mock function with given fields: ctx, ids +func (_m *Service) ListManagersByIDs(ctx context.Context, ids []int64) ([]feeds.FeedsManager, error) { + ret := _m.Called(ctx, ids) if len(ret) == 0 { panic("no return value specified for ListManagersByIDs") @@ -488,19 +488,19 @@ func (_m *Service) ListManagersByIDs(ids []int64) ([]feeds.FeedsManager, error) var r0 []feeds.FeedsManager var r1 error - if rf, ok := ret.Get(0).(func([]int64) ([]feeds.FeedsManager, error)); ok { - return rf(ids) + if rf, ok := ret.Get(0).(func(context.Context, []int64) ([]feeds.FeedsManager, error)); ok { + return rf(ctx, ids) } - if rf, ok := ret.Get(0).(func([]int64) []feeds.FeedsManager); ok { - r0 = rf(ids) + if rf, ok := ret.Get(0).(func(context.Context, []int64) []feeds.FeedsManager); ok { + r0 = rf(ctx, ids) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]feeds.FeedsManager) } } - if rf, ok := ret.Get(1).(func([]int64) error); ok { - r1 = rf(ids) + if rf, ok := ret.Get(1).(func(context.Context, []int64) error); ok { + r1 = rf(ctx, ids) } else { r1 = ret.Error(1) } @@ -508,9 +508,9 @@ func (_m *Service) ListManagersByIDs(ids []int64) ([]feeds.FeedsManager, error) return r0, r1 } -// ListSpecsByJobProposalIDs provides a mock function with given fields: ids -func (_m *Service) ListSpecsByJobProposalIDs(ids []int64) ([]feeds.JobProposalSpec, error) { - ret := _m.Called(ids) +// ListSpecsByJobProposalIDs provides a mock function with given fields: ctx, ids +func (_m *Service) ListSpecsByJobProposalIDs(ctx context.Context, ids []int64) ([]feeds.JobProposalSpec, error) { + ret := _m.Called(ctx, ids) if len(ret) == 0 { panic("no return value specified for ListSpecsByJobProposalIDs") @@ -518,19 +518,19 @@ func (_m *Service) ListSpecsByJobProposalIDs(ids []int64) ([]feeds.JobProposalSp var r0 []feeds.JobProposalSpec var r1 error - if rf, ok := ret.Get(0).(func([]int64) ([]feeds.JobProposalSpec, error)); ok { - return rf(ids) + if rf, ok := ret.Get(0).(func(context.Context, []int64) ([]feeds.JobProposalSpec, error)); ok { + return rf(ctx, ids) } - if rf, ok := ret.Get(0).(func([]int64) []feeds.JobProposalSpec); ok { - r0 = rf(ids) + if rf, ok := ret.Get(0).(func(context.Context, []int64) []feeds.JobProposalSpec); ok { + r0 = rf(ctx, ids) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]feeds.JobProposalSpec) } } - if rf, ok := ret.Get(1).(func([]int64) error); ok { - r1 = rf(ids) + if rf, ok := ret.Get(1).(func(context.Context, []int64) error); ok { + r1 = rf(ctx, ids) } else { r1 = ret.Error(1) } @@ -676,11 +676,6 @@ func (_m *Service) SyncNodeInfo(ctx context.Context, id int64) error { return r0 } -// Unsafe_SetConnectionsManager provides a mock function with given fields: _a0 -func (_m *Service) Unsafe_SetConnectionsManager(_a0 feeds.ConnectionsManager) { - _m.Called(_a0) -} - // UpdateChainConfig provides a mock function with given fields: ctx, cfg func (_m *Service) UpdateChainConfig(ctx context.Context, cfg feeds.ChainConfig) (int64, error) { ret := _m.Called(ctx, cfg) diff --git a/core/services/feeds/orm.go b/core/services/feeds/orm.go index 24ed7b8b369..bf77051dad7 100644 --- a/core/services/feeds/orm.go +++ b/core/services/feeds/orm.go @@ -1,6 +1,7 @@ package feeds import ( + "context" "database/sql" "fmt" "strings" @@ -9,99 +10,104 @@ import ( "github.com/lib/pq" "github.com/pkg/errors" - "github.com/jmoiron/sqlx" - - "github.com/smartcontractkit/chainlink/v2/core/logger" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" + "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" ) //go:generate mockery --with-expecter=true --quiet --name ORM --output ./mocks/ --case=underscore type ORM interface { - CountManagers() (int64, error) - CreateManager(ms *FeedsManager, qopts ...pg.QOpt) (int64, error) - GetManager(id int64) (*FeedsManager, error) - ListManagers() (mgrs []FeedsManager, err error) - ListManagersByIDs(ids []int64) ([]FeedsManager, error) - UpdateManager(mgr FeedsManager, qopts ...pg.QOpt) error - - CreateBatchChainConfig(cfgs []ChainConfig, qopts ...pg.QOpt) ([]int64, error) - CreateChainConfig(cfg ChainConfig, qopts ...pg.QOpt) (int64, error) - DeleteChainConfig(id int64) (int64, error) - GetChainConfig(id int64) (*ChainConfig, error) - ListChainConfigsByManagerIDs(mgrIDs []int64) ([]ChainConfig, error) - UpdateChainConfig(cfg ChainConfig) (int64, error) - - CountJobProposals() (int64, error) - CountJobProposalsByStatus() (counts *JobProposalCounts, err error) - CreateJobProposal(jp *JobProposal) (int64, error) - DeleteProposal(id int64, qopts ...pg.QOpt) error - GetJobProposal(id int64, qopts ...pg.QOpt) (*JobProposal, error) - GetJobProposalByRemoteUUID(uuid uuid.UUID) (*JobProposal, error) - ListJobProposals() (jps []JobProposal, err error) - ListJobProposalsByManagersIDs(ids []int64, qopts ...pg.QOpt) ([]JobProposal, error) - UpdateJobProposalStatus(id int64, status JobProposalStatus, qopts ...pg.QOpt) error // NEEDED? - UpsertJobProposal(jp *JobProposal, qopts ...pg.QOpt) (int64, error) - - ApproveSpec(id int64, externalJobID uuid.UUID, qopts ...pg.QOpt) error - CancelSpec(id int64, qopts ...pg.QOpt) error - CreateSpec(spec JobProposalSpec, qopts ...pg.QOpt) (int64, error) - ExistsSpecByJobProposalIDAndVersion(jpID int64, version int32, qopts ...pg.QOpt) (exists bool, err error) - GetApprovedSpec(jpID int64, qopts ...pg.QOpt) (*JobProposalSpec, error) - GetLatestSpec(jpID int64) (*JobProposalSpec, error) - GetSpec(id int64, qopts ...pg.QOpt) (*JobProposalSpec, error) - ListSpecsByJobProposalIDs(ids []int64, qopts ...pg.QOpt) ([]JobProposalSpec, error) - RejectSpec(id int64, qopts ...pg.QOpt) error - RevokeSpec(id int64, qopts ...pg.QOpt) error - UpdateSpecDefinition(id int64, spec string, qopts ...pg.QOpt) error - - IsJobManaged(jobID int64, qopts ...pg.QOpt) (bool, error) + CountManagers(ctx context.Context) (int64, error) + CreateManager(ctx context.Context, ms *FeedsManager) (int64, error) + GetManager(ctx context.Context, id int64) (*FeedsManager, error) + ListManagers(ctx context.Context) (mgrs []FeedsManager, err error) + ListManagersByIDs(ctx context.Context, ids []int64) ([]FeedsManager, error) + UpdateManager(ctx context.Context, mgr FeedsManager) error + + CreateBatchChainConfig(ctx context.Context, cfgs []ChainConfig) ([]int64, error) + CreateChainConfig(ctx context.Context, cfg ChainConfig) (int64, error) + DeleteChainConfig(ctx context.Context, id int64) (int64, error) + GetChainConfig(ctx context.Context, id int64) (*ChainConfig, error) + ListChainConfigsByManagerIDs(ctx context.Context, mgrIDs []int64) ([]ChainConfig, error) + UpdateChainConfig(ctx context.Context, cfg ChainConfig) (int64, error) + + CountJobProposals(ctx context.Context) (int64, error) + CountJobProposalsByStatus(ctx context.Context) (counts *JobProposalCounts, err error) + CreateJobProposal(ctx context.Context, jp *JobProposal) (int64, error) + DeleteProposal(ctx context.Context, id int64) error + GetJobProposal(ctx context.Context, id int64) (*JobProposal, error) + GetJobProposalByRemoteUUID(ctx context.Context, uuid uuid.UUID) (*JobProposal, error) + ListJobProposals(ctx context.Context) (jps []JobProposal, err error) + ListJobProposalsByManagersIDs(ctx context.Context, ids []int64) ([]JobProposal, error) + UpdateJobProposalStatus(ctx context.Context, id int64, status JobProposalStatus) error // NEEDED? + UpsertJobProposal(ctx context.Context, jp *JobProposal) (int64, error) + + ApproveSpec(ctx context.Context, id int64, externalJobID uuid.UUID) error + CancelSpec(ctx context.Context, id int64) error + CreateSpec(ctx context.Context, spec JobProposalSpec) (int64, error) + ExistsSpecByJobProposalIDAndVersion(ctx context.Context, jpID int64, version int32) (exists bool, err error) + GetApprovedSpec(ctx context.Context, jpID int64) (*JobProposalSpec, error) + GetLatestSpec(ctx context.Context, jpID int64) (*JobProposalSpec, error) + GetSpec(ctx context.Context, id int64) (*JobProposalSpec, error) + ListSpecsByJobProposalIDs(ctx context.Context, ids []int64) ([]JobProposalSpec, error) + RejectSpec(ctx context.Context, id int64) error + RevokeSpec(ctx context.Context, id int64) error + UpdateSpecDefinition(ctx context.Context, id int64, spec string) error + + IsJobManaged(ctx context.Context, jobID int64) (bool, error) + + Transact(context.Context, func(ORM) error) error + WithDataSource(sqlutil.DataSource) ORM } var _ ORM = &orm{} type orm struct { - q pg.Q + ds sqlutil.DataSource } -func NewORM(db *sqlx.DB, lggr logger.Logger, cfg pg.QConfig) *orm { - return &orm{ - q: pg.NewQ(db, lggr, cfg), - } +func NewORM(ds sqlutil.DataSource) *orm { + return &orm{ds: ds} } +func (o *orm) Transact(ctx context.Context, fn func(ORM) error) error { + return sqlutil.Transact(ctx, o.WithDataSource, o.ds, nil, fn) +} + +func (o *orm) WithDataSource(ds sqlutil.DataSource) ORM { return &orm{ds} } + // Count counts the number of feeds manager records. -func (o *orm) CountManagers() (count int64, err error) { +func (o *orm) CountManagers(ctx context.Context) (count int64, err error) { stmt := ` SELECT COUNT(*) FROM feeds_managers ` - err = o.q.Get(&count, stmt) + err = o.ds.GetContext(ctx, &count, stmt) return count, errors.Wrap(err, "CountManagers failed") } // CreateManager creates a feeds manager. -func (o *orm) CreateManager(ms *FeedsManager, qopts ...pg.QOpt) (id int64, err error) { +func (o *orm) CreateManager(ctx context.Context, ms *FeedsManager) (id int64, err error) { stmt := ` INSERT INTO feeds_managers (name, uri, public_key, created_at, updated_at) VALUES ($1,$2,$3,NOW(),NOW()) RETURNING id; ` - err = o.q.WithOpts(qopts...).Get(&id, stmt, ms.Name, ms.URI, ms.PublicKey) + err = o.ds.GetContext(ctx, &id, stmt, ms.Name, ms.URI, ms.PublicKey) return id, errors.Wrap(err, "CreateManager failed") } // CreateChainConfig creates a new chain config. -func (o *orm) CreateChainConfig(cfg ChainConfig, qopts ...pg.QOpt) (id int64, err error) { +func (o *orm) CreateChainConfig(ctx context.Context, cfg ChainConfig) (id int64, err error) { stmt := ` INSERT INTO feeds_manager_chain_configs (feeds_manager_id, chain_id, chain_type, account_address, admin_address, flux_monitor_config, ocr1_config, ocr2_config, created_at, updated_at) VALUES ($1,$2,$3,$4,$5,$6,$7,$8,NOW(),NOW()) RETURNING id; ` - err = o.q.WithOpts(qopts...).Get(&id, + err = o.ds.GetContext(ctx, + &id, stmt, cfg.FeedsManagerID, cfg.ChainID, @@ -117,7 +123,7 @@ RETURNING id; } // CreateBatchChainConfig creates multiple chain configs. -func (o *orm) CreateBatchChainConfig(cfgs []ChainConfig, qopts ...pg.QOpt) (ids []int64, err error) { +func (o *orm) CreateBatchChainConfig(ctx context.Context, cfgs []ChainConfig) (ids []int64, err error) { if len(cfgs) == 0 { return } @@ -160,7 +166,8 @@ RETURNING id; ) } - err = o.q.WithOpts(qopts...).Select(&ids, + err = o.ds.SelectContext(ctx, + &ids, fmt.Sprintf(stmt, strings.Join(vStrs, ",")), vArgs..., ) @@ -169,7 +176,7 @@ RETURNING id; } // DeleteChainConfig deletes a chain config. -func (o *orm) DeleteChainConfig(id int64) (int64, error) { +func (o *orm) DeleteChainConfig(ctx context.Context, id int64) (int64, error) { stmt := ` DELETE FROM feeds_manager_chain_configs WHERE id = $1 @@ -177,13 +184,13 @@ RETURNING id; ` var ccid int64 - err := o.q.Get(&ccid, stmt, id) + err := o.ds.GetContext(ctx, &ccid, stmt, id) return ccid, errors.Wrap(err, "DeleteChainConfig failed") } // GetChainConfig fetches a chain config. -func (o *orm) GetChainConfig(id int64) (*ChainConfig, error) { +func (o *orm) GetChainConfig(ctx context.Context, id int64) (*ChainConfig, error) { stmt := ` SELECT id, feeds_manager_id, chain_id, chain_type, account_address, admin_address, flux_monitor_config, ocr1_config, ocr2_config, created_at, updated_at FROM feeds_manager_chain_configs @@ -191,14 +198,14 @@ WHERE id = $1; ` var cfg ChainConfig - err := o.q.Get(&cfg, stmt, id) + err := o.ds.GetContext(ctx, &cfg, stmt, id) return &cfg, errors.Wrap(err, "GetChainConfig failed") } // ListChainConfigsByManagerIDs fetches the chain configs matching all manager // ids. -func (o *orm) ListChainConfigsByManagerIDs(mgrIDs []int64) ([]ChainConfig, error) { +func (o *orm) ListChainConfigsByManagerIDs(ctx context.Context, mgrIDs []int64) ([]ChainConfig, error) { stmt := ` SELECT id, feeds_manager_id, chain_id, chain_type, account_address, admin_address, flux_monitor_config, ocr1_config, ocr2_config, created_at, updated_at FROM feeds_manager_chain_configs @@ -206,13 +213,13 @@ WHERE feeds_manager_id = ANY($1) ` var cfgs []ChainConfig - err := o.q.Select(&cfgs, stmt, mgrIDs) + err := o.ds.SelectContext(ctx, &cfgs, stmt, mgrIDs) return cfgs, errors.Wrap(err, "ListJobProposalsByManagersIDs failed") } // UpdateChainConfig updates a chain config. -func (o *orm) UpdateChainConfig(cfg ChainConfig) (int64, error) { +func (o *orm) UpdateChainConfig(ctx context.Context, cfg ChainConfig) (int64, error) { stmt := ` UPDATE feeds_manager_chain_configs SET account_address = $1, @@ -226,7 +233,7 @@ RETURNING id; ` var cfgID int64 - err := o.q.Get(&cfgID, stmt, + err := o.ds.GetContext(ctx, &cfgID, stmt, cfg.AccountAddress, cfg.AdminAddress, cfg.FluxMonitorConfig, @@ -239,7 +246,7 @@ RETURNING id; } // GetManager gets a feeds manager by id. -func (o *orm) GetManager(id int64) (mgr *FeedsManager, err error) { +func (o *orm) GetManager(ctx context.Context, id int64) (mgr *FeedsManager, err error) { stmt := ` SELECT id, name, uri, public_key, created_at, updated_at FROM feeds_managers @@ -247,23 +254,23 @@ WHERE id = $1 ` mgr = new(FeedsManager) - err = o.q.Get(mgr, stmt, id) + err = o.ds.GetContext(ctx, mgr, stmt, id) return mgr, errors.Wrap(err, "GetManager failed") } // ListManager lists all feeds managers. -func (o *orm) ListManagers() (mgrs []FeedsManager, err error) { +func (o *orm) ListManagers(ctx context.Context) (mgrs []FeedsManager, err error) { stmt := ` SELECT id, name, uri, public_key, created_at, updated_at FROM feeds_managers; ` - err = o.q.Select(&mgrs, stmt) + err = o.ds.SelectContext(ctx, &mgrs, stmt) return mgrs, errors.Wrap(err, "ListManagers failed") } // ListManagersByIDs gets feeds managers by ids. -func (o *orm) ListManagersByIDs(ids []int64) (managers []FeedsManager, err error) { +func (o *orm) ListManagersByIDs(ctx context.Context, ids []int64) (managers []FeedsManager, err error) { stmt := ` SELECT id, name, uri, public_key, created_at, updated_at FROM feeds_managers @@ -271,20 +278,20 @@ WHERE id = ANY($1) ORDER BY created_at, id;` mgrIds := pq.Array(ids) - err = o.q.Select(&managers, stmt, mgrIds) + err = o.ds.SelectContext(ctx, &managers, stmt, mgrIds) return managers, errors.Wrap(err, "GetManagers failed") } // UpdateManager updates the manager details. -func (o *orm) UpdateManager(mgr FeedsManager, qopts ...pg.QOpt) (err error) { +func (o *orm) UpdateManager(ctx context.Context, mgr FeedsManager) (err error) { stmt := ` UPDATE feeds_managers SET name = $1, uri = $2, public_key = $3, updated_at = NOW() WHERE id = $4; ` - res, err := o.q.WithOpts(qopts...).Exec(stmt, mgr.Name, mgr.URI, mgr.PublicKey, mgr.ID) + res, err := o.ds.ExecContext(ctx, stmt, mgr.Name, mgr.URI, mgr.PublicKey, mgr.ID) if err != nil { return errors.Wrap(err, "UpdateManager failed to update feeds_managers") } @@ -299,27 +306,27 @@ WHERE id = $4; } // CreateJobProposal creates a job proposal. -func (o *orm) CreateJobProposal(jp *JobProposal) (id int64, err error) { +func (o *orm) CreateJobProposal(ctx context.Context, jp *JobProposal) (id int64, err error) { stmt := ` INSERT INTO job_proposals (name, remote_uuid, status, feeds_manager_id, multiaddrs, created_at, updated_at) VALUES ($1, $2, $3, $4, $5, NOW(), NOW()) RETURNING id; ` - err = o.q.Get(&id, stmt, jp.Name, jp.RemoteUUID, jp.Status, jp.FeedsManagerID, jp.Multiaddrs) + err = o.ds.GetContext(ctx, &id, stmt, jp.Name, jp.RemoteUUID, jp.Status, jp.FeedsManagerID, jp.Multiaddrs) return id, errors.Wrap(err, "CreateJobProposal failed") } // CountJobProposals counts the number of job proposal records. -func (o *orm) CountJobProposals() (count int64, err error) { +func (o *orm) CountJobProposals(ctx context.Context) (count int64, err error) { stmt := `SELECT COUNT(*) FROM job_proposals` - err = o.q.Get(&count, stmt) + err = o.ds.GetContext(ctx, &count, stmt) return count, errors.Wrap(err, "CountJobProposals failed") } // CountJobProposals counts the number of job proposal records. -func (o *orm) CountJobProposalsByStatus() (counts *JobProposalCounts, err error) { +func (o *orm) CountJobProposalsByStatus(ctx context.Context) (counts *JobProposalCounts, err error) { stmt := ` SELECT COUNT(*) filter (where job_proposals.status = 'pending' OR job_proposals.pending_update = TRUE) as pending, @@ -332,26 +339,26 @@ FROM job_proposals; ` counts = new(JobProposalCounts) - err = o.q.Get(counts, stmt) + err = o.ds.GetContext(ctx, counts, stmt) return counts, errors.Wrap(err, "CountJobProposalsByStatus failed") } // GetJobProposal gets a job proposal by id. -func (o *orm) GetJobProposal(id int64, qopts ...pg.QOpt) (jp *JobProposal, err error) { +func (o *orm) GetJobProposal(ctx context.Context, id int64) (jp *JobProposal, err error) { stmt := ` SELECT * FROM job_proposals WHERE id = $1 ` jp = new(JobProposal) - err = o.q.WithOpts(qopts...).Get(jp, stmt, id) + err = o.ds.GetContext(ctx, jp, stmt, id) return jp, errors.Wrap(err, "GetJobProposal failed") } // GetJobProposalByRemoteUUID gets a job proposal by the remote FMS uuid. This // method will filter out the deleted job proposals. To get all job proposals, // use the GetJobProposal get by id method. -func (o *orm) GetJobProposalByRemoteUUID(id uuid.UUID) (jp *JobProposal, err error) { +func (o *orm) GetJobProposalByRemoteUUID(ctx context.Context, id uuid.UUID) (jp *JobProposal, err error) { stmt := ` SELECT * FROM job_proposals @@ -360,35 +367,35 @@ AND status <> $2; ` jp = new(JobProposal) - err = o.q.Get(jp, stmt, id, JobProposalStatusDeleted) + err = o.ds.GetContext(ctx, jp, stmt, id, JobProposalStatusDeleted) return jp, errors.Wrap(err, "GetJobProposalByRemoteUUID failed") } // ListJobProposals lists all job proposals. -func (o *orm) ListJobProposals() (jps []JobProposal, err error) { +func (o *orm) ListJobProposals(ctx context.Context) (jps []JobProposal, err error) { stmt := ` SELECT * FROM job_proposals; ` - err = o.q.Select(&jps, stmt) + err = o.ds.SelectContext(ctx, &jps, stmt) return jps, errors.Wrap(err, "ListJobProposals failed") } // ListJobProposalsByManagersIDs gets job proposals by feeds managers IDs. -func (o *orm) ListJobProposalsByManagersIDs(ids []int64, qopts ...pg.QOpt) ([]JobProposal, error) { +func (o *orm) ListJobProposalsByManagersIDs(ctx context.Context, ids []int64) ([]JobProposal, error) { stmt := ` SELECT * FROM job_proposals WHERE feeds_manager_id = ANY($1) ` var jps []JobProposal - err := o.q.WithOpts(qopts...).Select(&jps, stmt, ids) + err := o.ds.SelectContext(ctx, &jps, stmt, ids) return jps, errors.Wrap(err, "ListJobProposalsByManagersIDs failed") } // UpdateJobProposalStatus updates the status of a job proposal by id. -func (o *orm) UpdateJobProposalStatus(id int64, status JobProposalStatus, qopts ...pg.QOpt) error { +func (o *orm) UpdateJobProposalStatus(ctx context.Context, id int64, status JobProposalStatus) error { stmt := ` UPDATE job_proposals SET status = $1, @@ -396,7 +403,7 @@ SET status = $1, WHERE id = $2; ` - result, err := o.q.WithOpts(qopts...).Exec(stmt, status, id) + result, err := o.ds.ExecContext(ctx, stmt, status, id) if err != nil { return err } @@ -415,7 +422,7 @@ WHERE id = $2; // UpsertJobProposal creates a job proposal if it does not exist. If it does exist, // then we update the details of the existing job proposal only if the provided // feeds manager id exists. -func (o *orm) UpsertJobProposal(jp *JobProposal, qopts ...pg.QOpt) (id int64, err error) { +func (o *orm) UpsertJobProposal(ctx context.Context, jp *JobProposal) (id int64, err error) { stmt := ` INSERT INTO job_proposals (name, remote_uuid, status, feeds_manager_id, multiaddrs, created_at, updated_at) VALUES ($1, $2, $3, $4, $5, NOW(), NOW()) @@ -436,13 +443,13 @@ DO RETURNING id; ` - err = o.q.WithOpts(qopts...).Get(&id, stmt, jp.Name, jp.RemoteUUID, jp.Status, jp.FeedsManagerID, jp.Multiaddrs) + err = o.ds.GetContext(ctx, &id, stmt, jp.Name, jp.RemoteUUID, jp.Status, jp.FeedsManagerID, jp.Multiaddrs) return id, errors.Wrap(err, "UpsertJobProposal") } // ApproveSpec approves the spec and sets the external job ID on the associated // job proposal. -func (o *orm) ApproveSpec(id int64, externalJobID uuid.UUID, qopts ...pg.QOpt) error { +func (o *orm) ApproveSpec(ctx context.Context, id int64, externalJobID uuid.UUID) error { // Update the status of the approval stmt := ` UPDATE job_proposal_specs @@ -454,7 +461,7 @@ RETURNING job_proposal_id; ` var jpID int64 - if err := o.q.WithOpts(qopts...).Get(&jpID, stmt, JobProposalStatusApproved, id); err != nil { + if err := o.ds.GetContext(ctx, &jpID, stmt, JobProposalStatusApproved, id); err != nil { return err } @@ -468,7 +475,7 @@ SET status = $1, WHERE id = $3; ` - result, err := o.q.WithOpts(qopts...).Exec(stmt, JobProposalStatusApproved, externalJobID, jpID) + result, err := o.ds.ExecContext(ctx, stmt, JobProposalStatusApproved, externalJobID, jpID) if err != nil { return err } @@ -487,7 +494,7 @@ WHERE id = $3; // CancelSpec cancels the spec and removes the external job id from the associated job proposal. It // sets the status of the spec and the proposal to cancelled, except in the case of deleted // proposals. -func (o *orm) CancelSpec(id int64, qopts ...pg.QOpt) error { +func (o *orm) CancelSpec(ctx context.Context, id int64) error { // Update the status of the approval stmt := ` UPDATE job_proposal_specs @@ -499,7 +506,7 @@ RETURNING job_proposal_id; ` var jpID int64 - if err := o.q.WithOpts(qopts...).Get(&jpID, stmt, SpecStatusCancelled, id); err != nil { + if err := o.ds.GetContext(ctx, &jpID, stmt, SpecStatusCancelled, id); err != nil { return err } @@ -516,7 +523,7 @@ SET status = ( updated_at = NOW() WHERE id = $1; ` - result, err := o.q.WithOpts(qopts...).Exec(stmt, jpID, nil) + result, err := o.ds.ExecContext(ctx, stmt, jpID, nil) if err != nil { return err } @@ -533,7 +540,7 @@ WHERE id = $1; } // CreateSpec creates a new job proposal spec -func (o *orm) CreateSpec(spec JobProposalSpec, qopts ...pg.QOpt) (int64, error) { +func (o *orm) CreateSpec(ctx context.Context, spec JobProposalSpec) (int64, error) { stmt := ` INSERT INTO job_proposal_specs (definition, version, status, job_proposal_id, status_updated_at, created_at, updated_at) VALUES ($1, $2, $3, $4, NOW(), NOW(), NOW()) @@ -541,14 +548,14 @@ RETURNING id; ` var id int64 - err := o.q.WithOpts(qopts...).Get(&id, stmt, spec.Definition, spec.Version, spec.Status, spec.JobProposalID) + err := o.ds.GetContext(ctx, &id, stmt, spec.Definition, spec.Version, spec.Status, spec.JobProposalID) return id, errors.Wrap(err, "CreateJobProposalSpec failed") } // ExistsSpecByJobProposalIDAndVersion checks if a job proposal spec exists for a specific job // proposal and version. -func (o *orm) ExistsSpecByJobProposalIDAndVersion(jpID int64, version int32, qopts ...pg.QOpt) (exists bool, err error) { +func (o *orm) ExistsSpecByJobProposalIDAndVersion(ctx context.Context, jpID int64, version int32) (exists bool, err error) { stmt := ` SELECT exists ( SELECT 1 @@ -557,12 +564,12 @@ SELECT exists ( ); ` - err = o.q.WithOpts(qopts...).Get(&exists, stmt, jpID, version) + err = o.ds.GetContext(ctx, &exists, stmt, jpID, version) return exists, errors.Wrap(err, "JobProposalSpecVersionExists failed") } // DeleteProposal performs a soft delete of the job proposal by setting the status to deleted -func (o *orm) DeleteProposal(id int64, qopts ...pg.QOpt) error { +func (o *orm) DeleteProposal(ctx context.Context, id int64) error { // Get the latest spec for the proposal. stmt := ` SELECT id, definition, version, status, job_proposal_id, status_updated_at, created_at, updated_at @@ -577,7 +584,7 @@ AND job_proposal_id = $1 ` var spec JobProposalSpec - err := o.q.WithOpts(qopts...).Get(&spec, stmt, id) + err := o.ds.GetContext(ctx, &spec, stmt, id) if err != nil { return err } @@ -593,7 +600,7 @@ SET status = $1, WHERE id = $2; ` - result, err := o.q.WithOpts(qopts...).Exec(stmt, JobProposalStatusDeleted, id, pendingUpdate) + result, err := o.ds.ExecContext(ctx, stmt, JobProposalStatusDeleted, id, pendingUpdate) if err != nil { return err } @@ -610,20 +617,20 @@ WHERE id = $2; } // GetSpec fetches the job proposal spec by id -func (o *orm) GetSpec(id int64, qopts ...pg.QOpt) (*JobProposalSpec, error) { +func (o *orm) GetSpec(ctx context.Context, id int64) (*JobProposalSpec, error) { stmt := ` SELECT id, definition, version, status, job_proposal_id, status_updated_at, created_at, updated_at FROM job_proposal_specs WHERE id = $1; ` var spec JobProposalSpec - err := o.q.WithOpts(qopts...).Get(&spec, stmt, id) + err := o.ds.GetContext(ctx, &spec, stmt, id) return &spec, errors.Wrap(err, "CreateJobProposalSpec failed") } // GetApprovedSpec gets the approved spec for a job proposal -func (o *orm) GetApprovedSpec(jpID int64, qopts ...pg.QOpt) (*JobProposalSpec, error) { +func (o *orm) GetApprovedSpec(ctx context.Context, jpID int64) (*JobProposalSpec, error) { stmt := ` SELECT id, definition, version, status, job_proposal_id, status_updated_at, created_at, updated_at FROM job_proposal_specs @@ -632,13 +639,13 @@ AND job_proposal_id = $2 ` var spec JobProposalSpec - err := o.q.WithOpts(qopts...).Get(&spec, stmt, SpecStatusApproved, jpID) + err := o.ds.GetContext(ctx, &spec, stmt, SpecStatusApproved, jpID) return &spec, errors.Wrap(err, "GetApprovedSpec failed") } // GetLatestSpec gets the latest spec for a job proposal. -func (o *orm) GetLatestSpec(jpID int64) (*JobProposalSpec, error) { +func (o *orm) GetLatestSpec(ctx context.Context, jpID int64) (*JobProposalSpec, error) { stmt := ` SELECT id, definition, version, status, job_proposal_id, status_updated_at, created_at, updated_at FROM job_proposal_specs @@ -652,26 +659,26 @@ AND job_proposal_id = $1 ` var spec JobProposalSpec - err := o.q.Get(&spec, stmt, jpID) + err := o.ds.GetContext(ctx, &spec, stmt, jpID) return &spec, errors.Wrap(err, "GetLatestSpec failed") } // ListSpecsByJobProposalIDs lists the specs which belong to any of job proposal // ids. -func (o *orm) ListSpecsByJobProposalIDs(ids []int64, qopts ...pg.QOpt) ([]JobProposalSpec, error) { +func (o *orm) ListSpecsByJobProposalIDs(ctx context.Context, ids []int64) ([]JobProposalSpec, error) { stmt := ` SELECT id, definition, version, status, job_proposal_id, status_updated_at, created_at, updated_at FROM job_proposal_specs WHERE job_proposal_id = ANY($1) ` var specs []JobProposalSpec - err := o.q.WithOpts(qopts...).Select(&specs, stmt, ids) + err := o.ds.SelectContext(ctx, &specs, stmt, ids) return specs, errors.Wrap(err, "GetJobProposalsByManagersIDs failed") } // RejectSpec rejects the spec and updates the job proposal -func (o *orm) RejectSpec(id int64, qopts ...pg.QOpt) error { +func (o *orm) RejectSpec(ctx context.Context, id int64) error { stmt := ` UPDATE job_proposal_specs SET status = $1, @@ -682,7 +689,7 @@ RETURNING job_proposal_id; ` var jpID int64 - if err := o.q.WithOpts(qopts...).Get(&jpID, stmt, SpecStatusRejected, id); err != nil { + if err := o.ds.GetContext(ctx, &jpID, stmt, SpecStatusRejected, id); err != nil { return err } @@ -700,7 +707,7 @@ SET status = ( WHERE id = $1 ` - result, err := o.q.WithOpts(qopts...).Exec(stmt, jpID) + result, err := o.ds.ExecContext(ctx, stmt, jpID) if err != nil { return err } @@ -719,7 +726,7 @@ WHERE id = $1 // RevokeSpec revokes a job proposal with a pending job spec. An approved // proposal cannot be revoked. A revoked proposal's job spec cannot be approved // or edited, but the job can be reproposed by FMS. -func (o *orm) RevokeSpec(id int64, qopts ...pg.QOpt) error { +func (o *orm) RevokeSpec(ctx context.Context, id int64) error { // Update the status of the spec stmt := ` UPDATE job_proposal_specs @@ -736,7 +743,7 @@ RETURNING job_proposal_id; ` var jpID int64 - if err := o.q.WithOpts(qopts...).Get(&jpID, stmt, id, SpecStatusRevoked); err != nil { + if err := o.ds.GetContext(ctx, &jpID, stmt, id, SpecStatusRevoked); err != nil { return err } @@ -760,7 +767,7 @@ SET status = ( WHERE id = $1 ` - result, err := o.q.WithOpts(qopts...).Exec(stmt, jpID, nil, JobProposalStatusRevoked) + result, err := o.ds.ExecContext(ctx, stmt, jpID, nil, JobProposalStatusRevoked) if err != nil { return err } @@ -777,7 +784,7 @@ WHERE id = $1 } // UpdateSpecDefinition updates the definition of a job proposal spec by id. -func (o *orm) UpdateSpecDefinition(id int64, spec string, qopts ...pg.QOpt) error { +func (o *orm) UpdateSpecDefinition(ctx context.Context, id int64, spec string) error { stmt := ` UPDATE job_proposal_specs SET definition = $1, @@ -785,7 +792,7 @@ SET definition = $1, WHERE id = $2; ` - res, err := o.q.WithOpts(qopts...).Exec(stmt, spec, id) + res, err := o.ds.ExecContext(ctx, stmt, spec, id) if err != nil { return errors.Wrap(err, "UpdateSpecDefinition failed to update definition") } @@ -803,7 +810,7 @@ WHERE id = $2; } // IsJobManaged determines if a job is managed by the feeds manager. -func (o *orm) IsJobManaged(jobID int64, qopts ...pg.QOpt) (exists bool, err error) { +func (o *orm) IsJobManaged(ctx context.Context, jobID int64) (exists bool, err error) { stmt := ` SELECT exists ( SELECT 1 @@ -813,6 +820,6 @@ SELECT exists ( ); ` - err = o.q.WithOpts(qopts...).Get(&exists, stmt, jobID) + err = o.ds.GetContext(ctx, &exists, stmt, jobID) return exists, errors.Wrap(err, "IsJobManaged failed") } diff --git a/core/services/feeds/orm_test.go b/core/services/feeds/orm_test.go index 3a0a17c99e0..df2624319f5 100644 --- a/core/services/feeds/orm_test.go +++ b/core/services/feeds/orm_test.go @@ -5,13 +5,12 @@ import ( "testing" "github.com/google/uuid" + "github.com/jmoiron/sqlx" "github.com/lib/pq" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "gopkg.in/guregu/null.v4" - "github.com/jmoiron/sqlx" - "github.com/smartcontractkit/chainlink-common/pkg/utils" "github.com/smartcontractkit/chainlink/v2/core/bridges" "github.com/smartcontractkit/chainlink/v2/core/internal/cltest" @@ -45,9 +44,8 @@ func setupORM(t *testing.T) *TestORM { t.Helper() var ( - db = pgtest.NewSqlxDB(t) - lggr = logger.TestLogger(t) - orm = feeds.NewORM(db, lggr, pgtest.NewQConfig(true)) + db = pgtest.NewSqlxDB(t) + orm = feeds.NewORM(db) ) return &TestORM{ORM: orm, db: db} @@ -57,6 +55,7 @@ func setupORM(t *testing.T) *TestORM { func Test_ORM_CreateManager(t *testing.T) { t.Parallel() + ctx := testutils.Context(t) var ( orm = setupORM(t) @@ -67,14 +66,14 @@ func Test_ORM_CreateManager(t *testing.T) { } ) - count, err := orm.CountManagers() + count, err := orm.CountManagers(ctx) require.NoError(t, err) require.Equal(t, int64(0), count) - id, err := orm.CreateManager(mgr) + id, err := orm.CreateManager(ctx, mgr) require.NoError(t, err) - count, err = orm.CountManagers() + count, err = orm.CountManagers(ctx) require.NoError(t, err) require.Equal(t, int64(1), count) @@ -83,6 +82,7 @@ func Test_ORM_CreateManager(t *testing.T) { func Test_ORM_GetManager(t *testing.T) { t.Parallel() + ctx := testutils.Context(t) var ( orm = setupORM(t) @@ -93,10 +93,10 @@ func Test_ORM_GetManager(t *testing.T) { } ) - id, err := orm.CreateManager(mgr) + id, err := orm.CreateManager(ctx, mgr) require.NoError(t, err) - actual, err := orm.GetManager(id) + actual, err := orm.GetManager(ctx, id) require.NoError(t, err) assert.Equal(t, id, actual.ID) @@ -104,12 +104,13 @@ func Test_ORM_GetManager(t *testing.T) { assert.Equal(t, name, actual.Name) assert.Equal(t, publicKey, actual.PublicKey) - _, err = orm.GetManager(-1) + _, err = orm.GetManager(ctx, -1) require.Error(t, err) } func Test_ORM_ListManagers(t *testing.T) { t.Parallel() + ctx := testutils.Context(t) var ( orm = setupORM(t) @@ -120,10 +121,10 @@ func Test_ORM_ListManagers(t *testing.T) { } ) - id, err := orm.CreateManager(mgr) + id, err := orm.CreateManager(ctx, mgr) require.NoError(t, err) - mgrs, err := orm.ListManagers() + mgrs, err := orm.ListManagers(ctx) require.NoError(t, err) require.Len(t, mgrs, 1) @@ -136,6 +137,7 @@ func Test_ORM_ListManagers(t *testing.T) { func Test_ORM_ListManagersByIDs(t *testing.T) { t.Parallel() + ctx := testutils.Context(t) var ( orm = setupORM(t) @@ -146,10 +148,10 @@ func Test_ORM_ListManagersByIDs(t *testing.T) { } ) - id, err := orm.CreateManager(mgr) + id, err := orm.CreateManager(ctx, mgr) require.NoError(t, err) - mgrs, err := orm.ListManagersByIDs([]int64{id}) + mgrs, err := orm.ListManagersByIDs(ctx, []int64{id}) require.NoError(t, err) require.Equal(t, 1, len(mgrs)) @@ -162,6 +164,7 @@ func Test_ORM_ListManagersByIDs(t *testing.T) { func Test_ORM_UpdateManager(t *testing.T) { t.Parallel() + ctx := testutils.Context(t) var ( orm = setupORM(t) @@ -172,7 +175,7 @@ func Test_ORM_UpdateManager(t *testing.T) { } ) - id, err := orm.CreateManager(mgr) + id, err := orm.CreateManager(ctx, mgr) require.NoError(t, err) updatedMgr := feeds.FeedsManager{ @@ -182,10 +185,10 @@ func Test_ORM_UpdateManager(t *testing.T) { PublicKey: crypto.PublicKey([]byte("22222222222222222222222222222222")), } - err = orm.UpdateManager(updatedMgr) + err = orm.UpdateManager(ctx, updatedMgr) require.NoError(t, err) - actual, err := orm.GetManager(id) + actual, err := orm.GetManager(ctx, id) require.NoError(t, err) assert.Equal(t, updatedMgr.URI, actual.URI) @@ -197,6 +200,7 @@ func Test_ORM_UpdateManager(t *testing.T) { func Test_ORM_CreateChainConfig(t *testing.T) { t.Parallel() + ctx := testutils.Context(t) var ( orm = setupORM(t) @@ -224,10 +228,10 @@ func Test_ORM_CreateChainConfig(t *testing.T) { } ) - id, err := orm.CreateChainConfig(cfg1) + id, err := orm.CreateChainConfig(ctx, cfg1) require.NoError(t, err) - actual, err := orm.GetChainConfig(id) + actual, err := orm.GetChainConfig(ctx, id) require.NoError(t, err) assertChainConfigEqual(t, map[string]interface{}{ @@ -244,6 +248,7 @@ func Test_ORM_CreateChainConfig(t *testing.T) { func Test_ORM_CreateBatchChainConfig(t *testing.T) { t.Parallel() + ctx := testutils.Context(t) var ( orm = setupORM(t) @@ -264,12 +269,12 @@ func Test_ORM_CreateBatchChainConfig(t *testing.T) { } ) - ids, err := orm.CreateBatchChainConfig([]feeds.ChainConfig{cfg1, cfg2}) + ids, err := orm.CreateBatchChainConfig(ctx, []feeds.ChainConfig{cfg1, cfg2}) require.NoError(t, err) assert.Len(t, ids, 2) - actual, err := orm.GetChainConfig(ids[0]) + actual, err := orm.GetChainConfig(ctx, ids[0]) require.NoError(t, err) assertChainConfigEqual(t, map[string]interface{}{ @@ -283,7 +288,7 @@ func Test_ORM_CreateBatchChainConfig(t *testing.T) { "ocr2Config": cfg1.OCR2Config, }, *actual) - actual, err = orm.GetChainConfig(ids[1]) + actual, err = orm.GetChainConfig(ctx, ids[1]) require.NoError(t, err) assertChainConfigEqual(t, map[string]interface{}{ @@ -298,13 +303,14 @@ func Test_ORM_CreateBatchChainConfig(t *testing.T) { }, *actual) // Test empty configs - ids, err = orm.CreateBatchChainConfig([]feeds.ChainConfig{}) + ids, err = orm.CreateBatchChainConfig(ctx, []feeds.ChainConfig{}) require.NoError(t, err) require.Empty(t, ids) } func Test_ORM_DeleteChainConfig(t *testing.T) { t.Parallel() + ctx := testutils.Context(t) var ( orm = setupORM(t) @@ -318,22 +324,23 @@ func Test_ORM_DeleteChainConfig(t *testing.T) { } ) - id, err := orm.CreateChainConfig(cfg1) + id, err := orm.CreateChainConfig(ctx, cfg1) require.NoError(t, err) - _, err = orm.GetChainConfig(id) + _, err = orm.GetChainConfig(ctx, id) require.NoError(t, err) - actual, err := orm.DeleteChainConfig(id) + actual, err := orm.DeleteChainConfig(ctx, id) require.NoError(t, err) require.Equal(t, id, actual) - _, err = orm.GetChainConfig(id) + _, err = orm.GetChainConfig(ctx, id) require.Error(t, err) } func Test_ORM_ListChainConfigsByManagerIDs(t *testing.T) { t.Parallel() + ctx := testutils.Context(t) var ( orm = setupORM(t) @@ -361,10 +368,10 @@ func Test_ORM_ListChainConfigsByManagerIDs(t *testing.T) { } ) - _, err := orm.CreateChainConfig(cfg1) + _, err := orm.CreateChainConfig(ctx, cfg1) require.NoError(t, err) - actual, err := orm.ListChainConfigsByManagerIDs([]int64{fmID}) + actual, err := orm.ListChainConfigsByManagerIDs(ctx, []int64{fmID}) require.NoError(t, err) require.Len(t, actual, 1) @@ -382,6 +389,7 @@ func Test_ORM_ListChainConfigsByManagerIDs(t *testing.T) { func Test_ORM_UpdateChainConfig(t *testing.T) { t.Parallel() + ctx := testutils.Context(t) var ( orm = setupORM(t) @@ -414,15 +422,15 @@ func Test_ORM_UpdateChainConfig(t *testing.T) { } ) - id, err := orm.CreateChainConfig(cfg1) + id, err := orm.CreateChainConfig(ctx, cfg1) require.NoError(t, err) updateCfg.ID = id - id, err = orm.UpdateChainConfig(updateCfg) + id, err = orm.UpdateChainConfig(ctx, updateCfg) require.NoError(t, err) - actual, err := orm.GetChainConfig(id) + actual, err := orm.GetChainConfig(ctx, id) require.NoError(t, err) assertChainConfigEqual(t, map[string]interface{}{ @@ -441,6 +449,7 @@ func Test_ORM_UpdateChainConfig(t *testing.T) { func Test_ORM_CreateJobProposal(t *testing.T) { t.Parallel() + ctx := testutils.Context(t) orm := setupORM(t) fmID := createFeedsManager(t, orm) @@ -452,14 +461,14 @@ func Test_ORM_CreateJobProposal(t *testing.T) { FeedsManagerID: fmID, } - count, err := orm.CountJobProposals() + count, err := orm.CountJobProposals(ctx) require.NoError(t, err) require.Equal(t, int64(0), count) - id, err := orm.CreateJobProposal(jp) + id, err := orm.CreateJobProposal(ctx, jp) require.NoError(t, err) - actual, err := orm.GetJobProposal(id) + actual, err := orm.GetJobProposal(ctx, id) require.NoError(t, err) require.Equal(t, jp.Name, actual.Name) require.Equal(t, jp.RemoteUUID, actual.RemoteUUID) @@ -474,6 +483,7 @@ func Test_ORM_CreateJobProposal(t *testing.T) { func Test_ORM_GetJobProposal(t *testing.T) { t.Parallel() + ctx := testutils.Context(t) orm := setupORM(t) fmID := createFeedsManager(t, orm) @@ -495,10 +505,10 @@ func Test_ORM_GetJobProposal(t *testing.T) { FeedsManagerID: fmID, } - id, err := orm.CreateJobProposal(jp) + id, err := orm.CreateJobProposal(ctx, jp) require.NoError(t, err) - _, err = orm.CreateJobProposal(deletedJp) + _, err = orm.CreateJobProposal(ctx, deletedJp) require.NoError(t, err) assertJobEquals := func(actual *feeds.JobProposal) { @@ -512,32 +522,33 @@ func Test_ORM_GetJobProposal(t *testing.T) { } t.Run("by id", func(t *testing.T) { - actual, err := orm.GetJobProposal(id) + actual, err := orm.GetJobProposal(ctx, id) require.NoError(t, err) assert.Equal(t, id, actual.ID) assertJobEquals(actual) - _, err = orm.GetJobProposal(int64(0)) + _, err = orm.GetJobProposal(ctx, int64(0)) require.Error(t, err) }) t.Run("by remote uuid", func(t *testing.T) { - actual, err := orm.GetJobProposalByRemoteUUID(remoteUUID) + actual, err := orm.GetJobProposalByRemoteUUID(ctx, remoteUUID) require.NoError(t, err) assertJobEquals(actual) - _, err = orm.GetJobProposalByRemoteUUID(deletedUUID) + _, err = orm.GetJobProposalByRemoteUUID(ctx, deletedUUID) require.Error(t, err) - _, err = orm.GetJobProposalByRemoteUUID(uuid.New()) + _, err = orm.GetJobProposalByRemoteUUID(ctx, uuid.New()) require.Error(t, err) }) } func Test_ORM_ListJobProposals(t *testing.T) { t.Parallel() + ctx := testutils.Context(t) orm := setupORM(t) fmID := createFeedsManager(t, orm) @@ -551,10 +562,10 @@ func Test_ORM_ListJobProposals(t *testing.T) { FeedsManagerID: fmID, } - id, err := orm.CreateJobProposal(jp) + id, err := orm.CreateJobProposal(ctx, jp) require.NoError(t, err) - jps, err := orm.ListJobProposals() + jps, err := orm.ListJobProposals(ctx) require.NoError(t, err) require.Len(t, jps, 1) @@ -573,13 +584,14 @@ func Test_ORM_CountJobProposalsByStatus(t *testing.T) { testCases := []struct { name string - before func(orm *TestORM) *feeds.JobProposalCounts + before func(t *testing.T, orm *TestORM) *feeds.JobProposalCounts wantApproved, wantRejected, wantDeleted, wantRevoked, wantPending, wantCancelled int64 }{ { name: "correctly counts when there are no job proposals", - before: func(orm *TestORM) *feeds.JobProposalCounts { - counts, err := orm.CountJobProposalsByStatus() + before: func(t *testing.T, orm *TestORM) *feeds.JobProposalCounts { + ctx := testutils.Context(t) + counts, err := orm.CountJobProposalsByStatus(ctx) require.NoError(t, err) return counts @@ -587,12 +599,13 @@ func Test_ORM_CountJobProposalsByStatus(t *testing.T) { }, { name: "correctly counts a pending and cancelled job proposal by status", - before: func(orm *TestORM) *feeds.JobProposalCounts { + before: func(t *testing.T, orm *TestORM) *feeds.JobProposalCounts { + ctx := testutils.Context(t) fmID := createFeedsManager(t, orm) createJobProposal(t, orm, feeds.JobProposalStatusPending, fmID) createJobProposal(t, orm, feeds.JobProposalStatusCancelled, fmID) - counts, err := orm.CountJobProposalsByStatus() + counts, err := orm.CountJobProposalsByStatus(ctx) require.NoError(t, err) return counts @@ -604,12 +617,13 @@ func Test_ORM_CountJobProposalsByStatus(t *testing.T) { // Verify that the counts are correct even if the proposal status is not pending. A // spec is considered pending if its status is pending OR pending_update is TRUE name: "correctly counts the pending specs when pending_update is true but the status itself is not pending", - before: func(orm *TestORM) *feeds.JobProposalCounts { + before: func(t *testing.T, orm *TestORM) *feeds.JobProposalCounts { + ctx := testutils.Context(t) fmID := createFeedsManager(t, orm) // Create a pending job proposal. jUUID := uuid.New() - jpID, err := orm.CreateJobProposal(&feeds.JobProposal{ + jpID, err := orm.CreateJobProposal(ctx, &feeds.JobProposal{ RemoteUUID: jUUID, Status: feeds.JobProposalStatusPending, FeedsManagerID: fmID, @@ -617,7 +631,7 @@ func Test_ORM_CountJobProposalsByStatus(t *testing.T) { require.NoError(t, err) // Upsert the proposal and change its status to rejected - _, err = orm.UpsertJobProposal(&feeds.JobProposal{ + _, err = orm.UpsertJobProposal(ctx, &feeds.JobProposal{ RemoteUUID: jUUID, Status: feeds.JobProposalStatusRejected, FeedsManagerID: fmID, @@ -625,11 +639,11 @@ func Test_ORM_CountJobProposalsByStatus(t *testing.T) { require.NoError(t, err) // Assert that the upserted job proposal is now pending update. - jp, err := orm.GetJobProposal(jpID) + jp, err := orm.GetJobProposal(ctx, jpID) require.NoError(t, err) assert.Equal(t, true, jp.PendingUpdate) - counts, err := orm.CountJobProposalsByStatus() + counts, err := orm.CountJobProposalsByStatus(ctx) require.NoError(t, err) return counts @@ -638,7 +652,8 @@ func Test_ORM_CountJobProposalsByStatus(t *testing.T) { }, { name: "correctly counts when approving a job proposal", - before: func(orm *TestORM) *feeds.JobProposalCounts { + before: func(t *testing.T, orm *TestORM) *feeds.JobProposalCounts { + ctx := testutils.Context(t) fmID := createFeedsManager(t, orm) // Create a pending job proposal. @@ -649,15 +664,15 @@ func Test_ORM_CountJobProposalsByStatus(t *testing.T) { specID := createJobSpec(t, orm, jpID) // Defer the FK requirement of an existing job for a job proposal to be approved - require.NoError(t, utils.JustError(orm.db.Exec( + require.NoError(t, utils.JustError(orm.db.ExecContext(ctx, `SET CONSTRAINTS job_proposals_job_id_fkey DEFERRED`, ))) // Approve the pending job proposal. - err := orm.ApproveSpec(specID, jUUID) + err := orm.ApproveSpec(ctx, specID, jUUID) require.NoError(t, err) - counts, err := orm.CountJobProposalsByStatus() + counts, err := orm.CountJobProposalsByStatus(ctx) require.NoError(t, err) return counts @@ -666,16 +681,17 @@ func Test_ORM_CountJobProposalsByStatus(t *testing.T) { }, { name: "correctly counts when revoking a job proposal", - before: func(orm *TestORM) *feeds.JobProposalCounts { + before: func(t *testing.T, orm *TestORM) *feeds.JobProposalCounts { + ctx := testutils.Context(t) fmID := createFeedsManager(t, orm) jpID := createJobProposal(t, orm, feeds.JobProposalStatusPending, fmID) specID := createJobSpec(t, orm, jpID) // Revoke the pending job proposal. - err := orm.RevokeSpec(specID) + err := orm.RevokeSpec(ctx, specID) require.NoError(t, err) - counts, err := orm.CountJobProposalsByStatus() + counts, err := orm.CountJobProposalsByStatus(ctx) require.NoError(t, err) return counts @@ -684,16 +700,17 @@ func Test_ORM_CountJobProposalsByStatus(t *testing.T) { }, { name: "correctly counts when deleting a job proposal", - before: func(orm *TestORM) *feeds.JobProposalCounts { + before: func(t *testing.T, orm *TestORM) *feeds.JobProposalCounts { + ctx := testutils.Context(t) fmID := createFeedsManager(t, orm) jpID := createJobProposal(t, orm, feeds.JobProposalStatusPending, fmID) createJobSpec(t, orm, jpID) // Delete the pending job proposal. - err := orm.DeleteProposal(jpID) + err := orm.DeleteProposal(ctx, jpID) require.NoError(t, err) - counts, err := orm.CountJobProposalsByStatus() + counts, err := orm.CountJobProposalsByStatus(ctx) require.NoError(t, err) return counts @@ -702,12 +719,13 @@ func Test_ORM_CountJobProposalsByStatus(t *testing.T) { }, { name: "correctly counts when deleting a job proposal with an approved spec", - before: func(orm *TestORM) *feeds.JobProposalCounts { + before: func(t *testing.T, orm *TestORM) *feeds.JobProposalCounts { + ctx := testutils.Context(t) fmID := createFeedsManager(t, orm) // Create a pending job proposal. jUUID := uuid.New() - jpID, err := orm.CreateJobProposal(&feeds.JobProposal{ + jpID, err := orm.CreateJobProposal(ctx, &feeds.JobProposal{ RemoteUUID: jUUID, Status: feeds.JobProposalStatusPending, FeedsManagerID: fmID, @@ -719,18 +737,18 @@ func Test_ORM_CountJobProposalsByStatus(t *testing.T) { specID := createJobSpec(t, orm, jpID) // Defer the FK requirement of an existing job for a job proposal to be approved - require.NoError(t, utils.JustError(orm.db.Exec( + require.NoError(t, utils.JustError(orm.db.ExecContext(ctx, `SET CONSTRAINTS job_proposals_job_id_fkey DEFERRED`, ))) - err = orm.ApproveSpec(specID, jUUID) + err = orm.ApproveSpec(ctx, specID, jUUID) require.NoError(t, err) // Delete the pending job proposal. - err = orm.DeleteProposal(jpID) + err = orm.DeleteProposal(ctx, jpID) require.NoError(t, err) - counts, err := orm.CountJobProposalsByStatus() + counts, err := orm.CountJobProposalsByStatus(ctx) require.NoError(t, err) return counts @@ -745,7 +763,7 @@ func Test_ORM_CountJobProposalsByStatus(t *testing.T) { t.Run(tc.name, func(t *testing.T) { orm := setupORM(t) - counts := tc.before(orm) + counts := tc.before(t, orm) assert.Equal(t, tc.wantPending, counts.Pending) assert.Equal(t, tc.wantApproved, counts.Approved) @@ -759,6 +777,7 @@ func Test_ORM_CountJobProposalsByStatus(t *testing.T) { func Test_ORM_ListJobProposalByManagersIDs(t *testing.T) { t.Parallel() + ctx := testutils.Context(t) orm := setupORM(t) fmID := createFeedsManager(t, orm) @@ -772,10 +791,10 @@ func Test_ORM_ListJobProposalByManagersIDs(t *testing.T) { FeedsManagerID: fmID, } - id, err := orm.CreateJobProposal(jp) + id, err := orm.CreateJobProposal(ctx, jp) require.NoError(t, err) - jps, err := orm.ListJobProposalsByManagersIDs([]int64{fmID}) + jps, err := orm.ListJobProposalsByManagersIDs(ctx, []int64{fmID}) require.NoError(t, err) require.Len(t, jps, 1) @@ -791,18 +810,19 @@ func Test_ORM_ListJobProposalByManagersIDs(t *testing.T) { func Test_ORM_UpdateJobProposalStatus(t *testing.T) { t.Parallel() + ctx := testutils.Context(t) orm := setupORM(t) fmID := createFeedsManager(t, orm) jpID := createJobProposal(t, orm, feeds.JobProposalStatusPending, fmID) - actualCreated, err := orm.GetJobProposal(jpID) + actualCreated, err := orm.GetJobProposal(ctx, jpID) require.NoError(t, err) - err = orm.UpdateJobProposalStatus(jpID, feeds.JobProposalStatusRejected) + err = orm.UpdateJobProposalStatus(ctx, jpID, feeds.JobProposalStatusRejected) require.NoError(t, err) - actual, err := orm.GetJobProposal(jpID) + actual, err := orm.GetJobProposal(ctx, jpID) require.NoError(t, err) assert.Equal(t, jpID, actual.ID) @@ -812,6 +832,7 @@ func Test_ORM_UpdateJobProposalStatus(t *testing.T) { func Test_ORM_UpsertJobProposal(t *testing.T) { t.Parallel() + ctx := testutils.Context(t) var ( orm = setupORM(t) @@ -833,19 +854,19 @@ func Test_ORM_UpsertJobProposal(t *testing.T) { // from pending to approved, and then approved to pending, and pending to deleted and so forth. // Create - count, err := orm.CountJobProposals() + count, err := orm.CountJobProposals(ctx) require.NoError(t, err) require.Equal(t, int64(0), count) - jpID, err := orm.UpsertJobProposal(jp) + jpID, err := orm.UpsertJobProposal(ctx, jp) require.NoError(t, err) - createdActual, err := orm.GetJobProposal(jpID) + createdActual, err := orm.GetJobProposal(ctx, jpID) require.NoError(t, err) assert.False(t, createdActual.PendingUpdate) - count, err = orm.CountJobProposals() + count, err = orm.CountJobProposals(ctx) require.NoError(t, err) require.Equal(t, int64(1), count) @@ -855,10 +876,10 @@ func Test_ORM_UpsertJobProposal(t *testing.T) { jp.Multiaddrs = pq.StringArray{"dns/example.com"} jp.Name = null.StringFrom("jp1_updated") - jpID, err = orm.UpsertJobProposal(jp) + jpID, err = orm.UpsertJobProposal(ctx, jp) require.NoError(t, err) - actual, err := orm.GetJobProposal(jpID) + actual, err := orm.GetJobProposal(ctx, jpID) require.NoError(t, err) assert.Equal(t, jp.Name, actual.Name) assert.Equal(t, jp.Status, actual.Status) @@ -874,14 +895,14 @@ func Test_ORM_UpsertJobProposal(t *testing.T) { specID := createJobSpec(t, orm, jpID) // Defer the FK requirement of an existing job for a job proposal. - require.NoError(t, utils.JustError(orm.db.Exec( + require.NoError(t, utils.JustError(orm.db.ExecContext(ctx, `SET CONSTRAINTS job_proposals_job_id_fkey DEFERRED`, ))) - err = orm.ApproveSpec(specID, externalJobID.UUID) + err = orm.ApproveSpec(ctx, specID, externalJobID.UUID) require.NoError(t, err) - actual, err = orm.GetJobProposal(jpID) + actual, err = orm.GetJobProposal(ctx, jpID) require.NoError(t, err) // Assert that the job proposal is now approved. @@ -893,10 +914,10 @@ func Test_ORM_UpsertJobProposal(t *testing.T) { jp.Name = null.StringFrom("jp1_updated_again") jp.Status = feeds.JobProposalStatusPending - _, err = orm.UpsertJobProposal(jp) + _, err = orm.UpsertJobProposal(ctx, jp) require.NoError(t, err) - actual, err = orm.GetJobProposal(jpID) + actual, err = orm.GetJobProposal(ctx, jpID) require.NoError(t, err) assert.Equal(t, feeds.JobProposalStatusApproved, actual.Status) @@ -904,10 +925,10 @@ func Test_ORM_UpsertJobProposal(t *testing.T) { assert.True(t, actual.PendingUpdate) // Delete the proposal - err = orm.DeleteProposal(jpID) + err = orm.DeleteProposal(ctx, jpID) require.NoError(t, err) - actual, err = orm.GetJobProposal(jpID) + actual, err = orm.GetJobProposal(ctx, jpID) require.NoError(t, err) assert.Equal(t, feeds.JobProposalStatusDeleted, actual.Status) @@ -915,11 +936,11 @@ func Test_ORM_UpsertJobProposal(t *testing.T) { // Update deleted proposal jp.Status = feeds.JobProposalStatusRejected - jpID, err = orm.UpsertJobProposal(jp) + jpID, err = orm.UpsertJobProposal(ctx, jp) require.NoError(t, err) // Ensure the deleted proposal does not get updated - actual, err = orm.GetJobProposal(jpID) + actual, err = orm.GetJobProposal(ctx, jpID) require.NoError(t, err) assert.NotEqual(t, jp.Status, actual.Status) assert.Equal(t, feeds.JobProposalStatusDeleted, actual.Status) @@ -929,6 +950,7 @@ func Test_ORM_UpsertJobProposal(t *testing.T) { func Test_ORM_ApproveSpec(t *testing.T) { t.Parallel() + ctx := testutils.Context(t) var ( orm = setupORM(t) @@ -937,7 +959,7 @@ func Test_ORM_ApproveSpec(t *testing.T) { ) // Manually create the job proposal to set pending update - jpID, err := orm.CreateJobProposal(&feeds.JobProposal{ + jpID, err := orm.CreateJobProposal(ctx, &feeds.JobProposal{ RemoteUUID: uuid.New(), Status: feeds.JobProposalStatusPending, FeedsManagerID: fmID, @@ -947,20 +969,20 @@ func Test_ORM_ApproveSpec(t *testing.T) { specID := createJobSpec(t, orm, jpID) // Defer the FK requirement of an existing job for a job proposal. - require.NoError(t, utils.JustError(orm.db.Exec( + require.NoError(t, utils.JustError(orm.db.ExecContext(ctx, `SET CONSTRAINTS job_proposals_job_id_fkey DEFERRED`, ))) - err = orm.ApproveSpec(specID, externalJobID.UUID) + err = orm.ApproveSpec(ctx, specID, externalJobID.UUID) require.NoError(t, err) - actual, err := orm.GetSpec(specID) + actual, err := orm.GetSpec(ctx, specID) require.NoError(t, err) assert.Equal(t, specID, actual.ID) assert.Equal(t, feeds.SpecStatusApproved, actual.Status) - actualJP, err := orm.GetJobProposal(jpID) + actualJP, err := orm.GetJobProposal(ctx, jpID) require.NoError(t, err) assert.Equal(t, externalJobID, actualJP.ExternalJobID) @@ -973,14 +995,14 @@ func Test_ORM_CancelSpec(t *testing.T) { testCases := []struct { name string - before func(orm *TestORM) (int64, int64) + before func(t *testing.T, orm *TestORM) (int64, int64) wantSpecStatus feeds.SpecStatus wantProposalStatus feeds.JobProposalStatus wantErr string }{ { name: "pending proposal", - before: func(orm *TestORM) (int64, int64) { + before: func(t *testing.T, orm *TestORM) (int64, int64) { fmID := createFeedsManager(t, orm) jpID := createJobProposal(t, orm, feeds.JobProposalStatusPending, fmID) specID := createJobSpec(t, orm, jpID) @@ -992,7 +1014,7 @@ func Test_ORM_CancelSpec(t *testing.T) { }, { name: "deleted proposal", - before: func(orm *TestORM) (int64, int64) { + before: func(t *testing.T, orm *TestORM) (int64, int64) { fmID := createFeedsManager(t, orm) jpID := createJobProposal(t, orm, feeds.JobProposalStatusDeleted, fmID) specID := createJobSpec(t, orm, jpID) @@ -1004,7 +1026,7 @@ func Test_ORM_CancelSpec(t *testing.T) { }, { name: "not found", - before: func(orm *TestORM) (int64, int64) { + before: func(t *testing.T, orm *TestORM) (int64, int64) { return 0, 0 }, wantErr: "sql: no rows in result set", @@ -1015,24 +1037,25 @@ func Test_ORM_CancelSpec(t *testing.T) { tc := tc t.Run(tc.name, func(t *testing.T) { + ctx := testutils.Context(t) orm := setupORM(t) - jpID, specID := tc.before(orm) + jpID, specID := tc.before(t, orm) - err := orm.CancelSpec(specID) + err := orm.CancelSpec(ctx, specID) if tc.wantErr != "" { require.EqualError(t, err, tc.wantErr) } else { require.NoError(t, err) - actual, err := orm.GetSpec(specID) + actual, err := orm.GetSpec(ctx, specID) require.NoError(t, err) assert.Equal(t, specID, actual.ID) assert.Equal(t, tc.wantSpecStatus, actual.Status) - actualJP, err := orm.GetJobProposal(jpID) + actualJP, err := orm.GetJobProposal(ctx, jpID) require.NoError(t, err) assert.Equal(t, tc.wantProposalStatus, actualJP.Status) @@ -1047,14 +1070,14 @@ func Test_ORM_DeleteProposal(t *testing.T) { testCases := []struct { name string - before func(orm *TestORM) int64 + before func(t *testing.T, orm *TestORM) int64 wantProposalStatus feeds.JobProposalStatus wantProposalPendingUpdate bool wantErr string }{ { name: "pending proposal", - before: func(orm *TestORM) int64 { + before: func(t *testing.T, orm *TestORM) int64 { fmID := createFeedsManager(t, orm) jpID := createJobProposal(t, orm, feeds.JobProposalStatusPending, fmID) createJobSpec(t, orm, jpID) @@ -1066,7 +1089,8 @@ func Test_ORM_DeleteProposal(t *testing.T) { }, { name: "approved proposal with approved spec", - before: func(orm *TestORM) int64 { + before: func(t *testing.T, orm *TestORM) int64 { + ctx := testutils.Context(t) fmID := createFeedsManager(t, orm) jpID := createJobProposal(t, orm, feeds.JobProposalStatusPending, fmID) specID := createJobSpec(t, orm, jpID) @@ -1074,11 +1098,11 @@ func Test_ORM_DeleteProposal(t *testing.T) { externalJobID := uuid.NullUUID{UUID: uuid.New(), Valid: true} // Defer the FK requirement of an existing job for a job proposal. - require.NoError(t, utils.JustError(orm.db.Exec( + require.NoError(t, utils.JustError(orm.db.ExecContext(ctx, `SET CONSTRAINTS job_proposals_job_id_fkey DEFERRED`, ))) - err := orm.ApproveSpec(specID, externalJobID.UUID) + err := orm.ApproveSpec(ctx, specID, externalJobID.UUID) require.NoError(t, err) return jpID @@ -1088,7 +1112,8 @@ func Test_ORM_DeleteProposal(t *testing.T) { }, { name: "approved proposal with pending spec", - before: func(orm *TestORM) int64 { + before: func(t *testing.T, orm *TestORM) int64 { + ctx := testutils.Context(t) fmID := createFeedsManager(t, orm) jpID := createJobProposal(t, orm, feeds.JobProposalStatusPending, fmID) specID := createJobSpec(t, orm, jpID) @@ -1096,25 +1121,25 @@ func Test_ORM_DeleteProposal(t *testing.T) { externalJobID := uuid.NullUUID{UUID: uuid.New(), Valid: true} // Defer the FK requirement of an existing job for a job proposal. - require.NoError(t, utils.JustError(orm.db.Exec( + require.NoError(t, utils.JustError(orm.db.ExecContext(ctx, `SET CONSTRAINTS job_proposals_job_id_fkey DEFERRED`, ))) - err := orm.ApproveSpec(specID, externalJobID.UUID) + err := orm.ApproveSpec(ctx, specID, externalJobID.UUID) require.NoError(t, err) - jp, err := orm.GetJobProposal(jpID) + jp, err := orm.GetJobProposal(ctx, jpID) require.NoError(t, err) // Update the proposal to pending and create a new pending spec - _, err = orm.UpsertJobProposal(&feeds.JobProposal{ + _, err = orm.UpsertJobProposal(ctx, &feeds.JobProposal{ RemoteUUID: jp.RemoteUUID, Status: feeds.JobProposalStatusPending, FeedsManagerID: fmID, }) require.NoError(t, err) - _, err = orm.CreateSpec(feeds.JobProposalSpec{ + _, err = orm.CreateSpec(ctx, feeds.JobProposalSpec{ Definition: "spec data", Version: 2, Status: feeds.SpecStatusPending, @@ -1129,7 +1154,7 @@ func Test_ORM_DeleteProposal(t *testing.T) { }, { name: "cancelled proposal", - before: func(orm *TestORM) int64 { + before: func(t *testing.T, orm *TestORM) int64 { fmID := createFeedsManager(t, orm) jpID := createJobProposal(t, orm, feeds.JobProposalStatusCancelled, fmID) createJobSpec(t, orm, jpID) @@ -1141,7 +1166,7 @@ func Test_ORM_DeleteProposal(t *testing.T) { }, { name: "rejected proposal", - before: func(orm *TestORM) int64 { + before: func(t *testing.T, orm *TestORM) int64 { fmID := createFeedsManager(t, orm) jpID := createJobProposal(t, orm, feeds.JobProposalStatusRejected, fmID) createJobSpec(t, orm, jpID) @@ -1153,7 +1178,7 @@ func Test_ORM_DeleteProposal(t *testing.T) { }, { name: "not found spec", - before: func(orm *TestORM) int64 { + before: func(t *testing.T, orm *TestORM) int64 { fmID := createFeedsManager(t, orm) jpID := createJobProposal(t, orm, feeds.JobProposalStatusRejected, fmID) @@ -1163,7 +1188,7 @@ func Test_ORM_DeleteProposal(t *testing.T) { }, { name: "not found proposal", - before: func(orm *TestORM) int64 { + before: func(t *testing.T, orm *TestORM) int64 { return 0 }, wantErr: "sql: no rows in result set", @@ -1174,18 +1199,19 @@ func Test_ORM_DeleteProposal(t *testing.T) { tc := tc t.Run(tc.name, func(t *testing.T) { + ctx := testutils.Context(t) orm := setupORM(t) - jpID := tc.before(orm) + jpID := tc.before(t, orm) - err := orm.DeleteProposal(jpID) + err := orm.DeleteProposal(ctx, jpID) if tc.wantErr != "" { require.EqualError(t, err, tc.wantErr) } else { require.NoError(t, err) - actual, err := orm.GetJobProposal(jpID) + actual, err := orm.GetJobProposal(ctx, jpID) require.NoError(t, err) assert.Equal(t, jpID, actual.ID) @@ -1201,7 +1227,7 @@ func Test_ORM_RevokeSpec(t *testing.T) { testCases := []struct { name string - before func(orm *TestORM) (int64, int64) + before func(t *testing.T, orm *TestORM) (int64, int64) wantProposalStatus feeds.JobProposalStatus wantSpecStatus feeds.SpecStatus wantErr string @@ -1209,7 +1235,7 @@ func Test_ORM_RevokeSpec(t *testing.T) { }{ { name: "pending proposal", - before: func(orm *TestORM) (int64, int64) { + before: func(t *testing.T, orm *TestORM) (int64, int64) { fmID := createFeedsManager(t, orm) jpID := createJobProposal(t, orm, feeds.JobProposalStatusPending, fmID) specID := createJobSpec(t, orm, jpID) @@ -1222,7 +1248,8 @@ func Test_ORM_RevokeSpec(t *testing.T) { }, { name: "approved proposal", - before: func(orm *TestORM) (int64, int64) { + before: func(t *testing.T, orm *TestORM) (int64, int64) { + ctx := testutils.Context(t) fmID := createFeedsManager(t, orm) jpID := createJobProposal(t, orm, feeds.JobProposalStatusPending, fmID) specID := createJobSpec(t, orm, jpID) @@ -1230,11 +1257,11 @@ func Test_ORM_RevokeSpec(t *testing.T) { externalJobID := uuid.NullUUID{UUID: uuid.New(), Valid: true} // Defer the FK requirement of an existing job for a job proposal. - require.NoError(t, utils.JustError(orm.db.Exec( + require.NoError(t, utils.JustError(orm.db.ExecContext(ctx, `SET CONSTRAINTS job_proposals_job_id_fkey DEFERRED`, ))) - err := orm.ApproveSpec(specID, externalJobID.UUID) + err := orm.ApproveSpec(ctx, specID, externalJobID.UUID) require.NoError(t, err) return jpID, specID @@ -1244,7 +1271,7 @@ func Test_ORM_RevokeSpec(t *testing.T) { }, { name: "cancelled proposal", - before: func(orm *TestORM) (int64, int64) { + before: func(t *testing.T, orm *TestORM) (int64, int64) { fmID := createFeedsManager(t, orm) jpID := createJobProposal(t, orm, feeds.JobProposalStatusCancelled, fmID) specID := createJobSpec(t, orm, jpID) @@ -1256,7 +1283,7 @@ func Test_ORM_RevokeSpec(t *testing.T) { }, { name: "rejected proposal", - before: func(orm *TestORM) (int64, int64) { + before: func(t *testing.T, orm *TestORM) (int64, int64) { fmID := createFeedsManager(t, orm) jpID := createJobProposal(t, orm, feeds.JobProposalStatusRejected, fmID) specID := createJobSpec(t, orm, jpID) @@ -1268,7 +1295,7 @@ func Test_ORM_RevokeSpec(t *testing.T) { }, { name: "deleted proposal", - before: func(orm *TestORM) (int64, int64) { + before: func(t *testing.T, orm *TestORM) (int64, int64) { fmID := createFeedsManager(t, orm) jpID := createJobProposal(t, orm, feeds.JobProposalStatusDeleted, fmID) specID := createJobSpec(t, orm, jpID) @@ -1280,7 +1307,7 @@ func Test_ORM_RevokeSpec(t *testing.T) { }, { name: "not found", - before: func(orm *TestORM) (int64, int64) { + before: func(t *testing.T, orm *TestORM) (int64, int64) { return 0, 0 }, wantErr: "sql: no rows in result set", @@ -1291,18 +1318,19 @@ func Test_ORM_RevokeSpec(t *testing.T) { tc := tc t.Run(tc.name, func(t *testing.T) { + ctx := testutils.Context(t) orm := setupORM(t) - jpID, specID := tc.before(orm) + jpID, specID := tc.before(t, orm) - err := orm.RevokeSpec(specID) + err := orm.RevokeSpec(ctx, specID) if tc.wantErr != "" { require.EqualError(t, err, tc.wantErr) } else { require.NoError(t, err) - actualJP, err := orm.GetJobProposal(jpID) + actualJP, err := orm.GetJobProposal(ctx, jpID) require.NoError(t, err) assert.Equal(t, tc.wantProposalStatus, actualJP.Status) @@ -1318,6 +1346,7 @@ func Test_ORM_RevokeSpec(t *testing.T) { func Test_ORM_ExistsSpecByJobProposalIDAndVersion(t *testing.T) { t.Parallel() + ctx := testutils.Context(t) var ( orm = setupORM(t) @@ -1327,17 +1356,18 @@ func Test_ORM_ExistsSpecByJobProposalIDAndVersion(t *testing.T) { createJobSpec(t, orm, jpID) - exists, err := orm.ExistsSpecByJobProposalIDAndVersion(jpID, 1) + exists, err := orm.ExistsSpecByJobProposalIDAndVersion(ctx, jpID, 1) require.NoError(t, err) require.True(t, exists) - exists, err = orm.ExistsSpecByJobProposalIDAndVersion(jpID, 2) + exists, err = orm.ExistsSpecByJobProposalIDAndVersion(ctx, jpID, 2) require.NoError(t, err) require.False(t, exists) } func Test_ORM_GetSpec(t *testing.T) { t.Parallel() + ctx := testutils.Context(t) var ( orm = setupORM(t) @@ -1346,7 +1376,7 @@ func Test_ORM_GetSpec(t *testing.T) { specID = createJobSpec(t, orm, jpID) ) - actual, err := orm.GetSpec(specID) + actual, err := orm.GetSpec(ctx, specID) require.NoError(t, err) assert.Equal(t, "spec data", actual.Definition) @@ -1357,6 +1387,7 @@ func Test_ORM_GetSpec(t *testing.T) { func Test_ORM_GetApprovedSpec(t *testing.T) { t.Parallel() + ctx := testutils.Context(t) var ( orm = setupORM(t) @@ -1368,23 +1399,23 @@ func Test_ORM_GetApprovedSpec(t *testing.T) { // Defer the FK requirement of a job proposal so we don't have to setup a // real job. - require.NoError(t, utils.JustError(orm.db.Exec( + require.NoError(t, utils.JustError(orm.db.ExecContext(ctx, `SET CONSTRAINTS job_proposals_job_id_fkey DEFERRED`, ))) - err := orm.ApproveSpec(specID, externalJobID.UUID) + err := orm.ApproveSpec(ctx, specID, externalJobID.UUID) require.NoError(t, err) - actual, err := orm.GetApprovedSpec(jpID) + actual, err := orm.GetApprovedSpec(ctx, jpID) require.NoError(t, err) assert.Equal(t, specID, actual.ID) assert.Equal(t, feeds.SpecStatusApproved, actual.Status) - err = orm.CancelSpec(specID) + err = orm.CancelSpec(ctx, specID) require.NoError(t, err) - _, err = orm.GetApprovedSpec(jpID) + _, err = orm.GetApprovedSpec(ctx, jpID) require.Error(t, err) assert.ErrorIs(t, err, sql.ErrNoRows) @@ -1392,6 +1423,7 @@ func Test_ORM_GetApprovedSpec(t *testing.T) { func Test_ORM_GetLatestSpec(t *testing.T) { t.Parallel() + ctx := testutils.Context(t) var ( orm = setupORM(t) @@ -1400,7 +1432,7 @@ func Test_ORM_GetLatestSpec(t *testing.T) { ) _ = createJobSpec(t, orm, jpID) - spec2ID, err := orm.CreateSpec(feeds.JobProposalSpec{ + spec2ID, err := orm.CreateSpec(ctx, feeds.JobProposalSpec{ Definition: "spec data", Version: 2, Status: feeds.SpecStatusPending, @@ -1408,7 +1440,7 @@ func Test_ORM_GetLatestSpec(t *testing.T) { }) require.NoError(t, err) - actual, err := orm.GetSpec(spec2ID) + actual, err := orm.GetSpec(ctx, spec2ID) require.NoError(t, err) assert.Equal(t, spec2ID, actual.ID) @@ -1420,6 +1452,7 @@ func Test_ORM_GetLatestSpec(t *testing.T) { func Test_ORM_ListSpecsByJobProposalIDs(t *testing.T) { t.Parallel() + ctx := testutils.Context(t) var ( orm = setupORM(t) @@ -1433,7 +1466,7 @@ func Test_ORM_ListSpecsByJobProposalIDs(t *testing.T) { createJobSpec(t, orm, jp1ID) createJobSpec(t, orm, jp2ID) - specs, err := orm.ListSpecsByJobProposalIDs([]int64{jp1ID, jp2ID}) + specs, err := orm.ListSpecsByJobProposalIDs(ctx, []int64{jp1ID, jp2ID}) require.NoError(t, err) require.Len(t, specs, 2) @@ -1457,14 +1490,14 @@ func Test_ORM_RejectSpec(t *testing.T) { testCases := []struct { name string - before func(orm *TestORM) (int64, int64) + before func(t *testing.T, orm *TestORM) (int64, int64) wantSpecStatus feeds.SpecStatus wantProposalStatus feeds.JobProposalStatus wantErr string }{ { name: "pending proposal", - before: func(orm *TestORM) (int64, int64) { + before: func(t *testing.T, orm *TestORM) (int64, int64) { fmID := createFeedsManager(t, orm) jpID := createJobProposal(t, orm, feeds.JobProposalStatusPending, fmID) specID := createJobSpec(t, orm, jpID) @@ -1476,7 +1509,8 @@ func Test_ORM_RejectSpec(t *testing.T) { }, { name: "approved proposal", - before: func(orm *TestORM) (int64, int64) { + before: func(t *testing.T, orm *TestORM) (int64, int64) { + ctx := testutils.Context(t) fmID := createFeedsManager(t, orm) jpID := createJobProposal(t, orm, feeds.JobProposalStatusPending, fmID) specID := createJobSpec(t, orm, jpID) @@ -1484,11 +1518,11 @@ func Test_ORM_RejectSpec(t *testing.T) { externalJobID := uuid.NullUUID{UUID: uuid.New(), Valid: true} // Defer the FK requirement of an existing job for a job proposal. - require.NoError(t, utils.JustError(orm.db.Exec( + require.NoError(t, utils.JustError(orm.db.ExecContext(ctx, `SET CONSTRAINTS job_proposals_job_id_fkey DEFERRED`, ))) - err := orm.ApproveSpec(specID, externalJobID.UUID) + err := orm.ApproveSpec(ctx, specID, externalJobID.UUID) require.NoError(t, err) return jpID, specID @@ -1498,7 +1532,7 @@ func Test_ORM_RejectSpec(t *testing.T) { }, { name: "cancelled proposal", - before: func(orm *TestORM) (int64, int64) { + before: func(t *testing.T, orm *TestORM) (int64, int64) { fmID := createFeedsManager(t, orm) jpID := createJobProposal(t, orm, feeds.JobProposalStatusCancelled, fmID) specID := createJobSpec(t, orm, jpID) @@ -1510,7 +1544,7 @@ func Test_ORM_RejectSpec(t *testing.T) { }, { name: "deleted proposal", - before: func(orm *TestORM) (int64, int64) { + before: func(t *testing.T, orm *TestORM) (int64, int64) { fmID := createFeedsManager(t, orm) jpID := createJobProposal(t, orm, feeds.JobProposalStatusDeleted, fmID) specID := createJobSpec(t, orm, jpID) @@ -1522,7 +1556,7 @@ func Test_ORM_RejectSpec(t *testing.T) { }, { name: "not found", - before: func(orm *TestORM) (int64, int64) { + before: func(t *testing.T, orm *TestORM) (int64, int64) { return 0, 0 }, wantErr: "sql: no rows in result set", @@ -1533,24 +1567,25 @@ func Test_ORM_RejectSpec(t *testing.T) { tc := tc t.Run(tc.name, func(t *testing.T) { + ctx := testutils.Context(t) orm := setupORM(t) - jpID, specID := tc.before(orm) + jpID, specID := tc.before(t, orm) - err := orm.RejectSpec(specID) + err := orm.RejectSpec(ctx, specID) if tc.wantErr != "" { require.EqualError(t, err, tc.wantErr) } else { require.NoError(t, err) - actual, err := orm.GetSpec(specID) + actual, err := orm.GetSpec(ctx, specID) require.NoError(t, err) assert.Equal(t, specID, actual.ID) assert.Equal(t, tc.wantSpecStatus, actual.Status) - actualJP, err := orm.GetJobProposal(jpID) + actualJP, err := orm.GetJobProposal(ctx, jpID) require.NoError(t, err) assert.Equal(t, tc.wantProposalStatus, actualJP.Status) @@ -1562,6 +1597,7 @@ func Test_ORM_RejectSpec(t *testing.T) { func Test_ORM_UpdateSpecDefinition(t *testing.T) { t.Parallel() + ctx := testutils.Context(t) var ( orm = setupORM(t) @@ -1570,13 +1606,13 @@ func Test_ORM_UpdateSpecDefinition(t *testing.T) { specID = createJobSpec(t, orm, jpID) ) - prev, err := orm.GetSpec(specID) + prev, err := orm.GetSpec(ctx, specID) require.NoError(t, err) - err = orm.UpdateSpecDefinition(specID, "updated spec") + err = orm.UpdateSpecDefinition(ctx, specID, "updated spec") require.NoError(t, err) - actual, err := orm.GetSpec(specID) + actual, err := orm.GetSpec(ctx, specID) require.NoError(t, err) assert.Equal(t, specID, actual.ID) @@ -1584,7 +1620,7 @@ func Test_ORM_UpdateSpecDefinition(t *testing.T) { require.Equal(t, "updated spec", actual.Definition) // Not found - err = orm.UpdateSpecDefinition(-1, "updated spec") + err = orm.UpdateSpecDefinition(ctx, -1, "updated spec") require.Error(t, err) } @@ -1592,6 +1628,7 @@ func Test_ORM_UpdateSpecDefinition(t *testing.T) { func Test_ORM_IsJobManaged(t *testing.T) { t.Parallel() + ctx := testutils.Context(t) var ( orm = setupORM(t) @@ -1603,14 +1640,14 @@ func Test_ORM_IsJobManaged(t *testing.T) { j := createJob(t, orm.db, externalJobID.UUID) - isManaged, err := orm.IsJobManaged(int64(j.ID)) + isManaged, err := orm.IsJobManaged(ctx, int64(j.ID)) require.NoError(t, err) assert.False(t, isManaged) - err = orm.ApproveSpec(specID, externalJobID.UUID) + err = orm.ApproveSpec(ctx, specID, externalJobID.UUID) require.NoError(t, err) - isManaged, err = orm.IsJobManaged(int64(j.ID)) + isManaged, err = orm.IsJobManaged(ctx, int64(j.ID)) require.NoError(t, err) assert.True(t, isManaged) } @@ -1640,7 +1677,8 @@ func createFeedsManager(t *testing.T, orm feeds.ORM) int64 { PublicKey: publicKey, } - id, err := orm.CreateManager(mgr) + ctx := testutils.Context(t) + id, err := orm.CreateManager(ctx, mgr) require.NoError(t, err) return id @@ -1658,7 +1696,7 @@ func createJob(t *testing.T, db *sqlx.DB, externalJobID uuid.UUID) *job.Job { bridgeORM = bridges.NewORM(db) relayExtenders = evmtest.NewChainRelayExtenders(t, evmtest.TestChainOpts{DB: db, GeneralConfig: config, KeyStore: keyStore.Eth()}) ) - orm := job.NewORM(db, pipelineORM, bridgeORM, keyStore, lggr, config.Database()) + orm := job.NewORM(db, pipelineORM, bridgeORM, keyStore, lggr) require.NoError(t, keyStore.OCR().Add(ctx, cltest.DefaultOCRKey)) require.NoError(t, keyStore.P2P().Add(ctx, cltest.DefaultP2PKey)) @@ -1679,7 +1717,7 @@ func createJob(t *testing.T, db *sqlx.DB, externalJobID uuid.UUID) *job.Job { ) require.NoError(t, err) - err = orm.CreateJob(&jb) + err = orm.CreateJob(ctx, &jb) require.NoError(t, err) return &jb @@ -1688,7 +1726,8 @@ func createJob(t *testing.T, db *sqlx.DB, externalJobID uuid.UUID) *job.Job { func createJobProposal(t *testing.T, orm feeds.ORM, status feeds.JobProposalStatus, fmID int64) int64 { t.Helper() - id, err := orm.CreateJobProposal(&feeds.JobProposal{ + ctx := testutils.Context(t) + id, err := orm.CreateJobProposal(ctx, &feeds.JobProposal{ RemoteUUID: uuid.New(), Status: status, FeedsManagerID: fmID, @@ -1702,7 +1741,8 @@ func createJobProposal(t *testing.T, orm feeds.ORM, status feeds.JobProposalStat func createJobSpec(t *testing.T, orm feeds.ORM, jpID int64) int64 { t.Helper() - id, err := orm.CreateSpec(feeds.JobProposalSpec{ + ctx := testutils.Context(t) + id, err := orm.CreateSpec(ctx, feeds.JobProposalSpec{ Definition: "spec data", Version: 1, Status: feeds.SpecStatusPending, diff --git a/core/services/feeds/service.go b/core/services/feeds/service.go index 8c4ea7a36bf..d6032befbdc 100644 --- a/core/services/feeds/service.go +++ b/core/services/feeds/service.go @@ -15,9 +15,8 @@ import ( "github.com/prometheus/client_golang/prometheus/promauto" "gopkg.in/guregu/null.v4" - "github.com/jmoiron/sqlx" - "github.com/smartcontractkit/chainlink-common/pkg/services" + "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" "github.com/smartcontractkit/chainlink/v2/plugins" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/types" @@ -33,7 +32,6 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/services/ocr" ocr2 "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/validate" "github.com/smartcontractkit/chainlink/v2/core/services/ocrbootstrap" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" "github.com/smartcontractkit/chainlink/v2/core/utils/crypto" ) @@ -66,17 +64,17 @@ type Service interface { Start(ctx context.Context) error Close() error - CountManagers() (int64, error) - GetManager(id int64) (*FeedsManager, error) - ListManagers() ([]FeedsManager, error) - ListManagersByIDs(ids []int64) ([]FeedsManager, error) + CountManagers(ctx context.Context) (int64, error) + GetManager(ctx context.Context, id int64) (*FeedsManager, error) + ListManagers(ctx context.Context) ([]FeedsManager, error) + ListManagersByIDs(ctx context.Context, ids []int64) ([]FeedsManager, error) RegisterManager(ctx context.Context, params RegisterManagerParams) (int64, error) UpdateManager(ctx context.Context, mgr FeedsManager) error CreateChainConfig(ctx context.Context, cfg ChainConfig) (int64, error) DeleteChainConfig(ctx context.Context, id int64) (int64, error) - GetChainConfig(id int64) (*ChainConfig, error) - ListChainConfigsByManagerIDs(mgrIDs []int64) ([]ChainConfig, error) + GetChainConfig(ctx context.Context, id int64) (*ChainConfig, error) + ListChainConfigsByManagerIDs(ctx context.Context, mgrIDs []int64) ([]ChainConfig, error) UpdateChainConfig(ctx context.Context, cfg ChainConfig) (int64, error) DeleteJob(ctx context.Context, args *DeleteJobArgs) (int64, error) @@ -85,19 +83,17 @@ type Service interface { RevokeJob(ctx context.Context, args *RevokeJobArgs) (int64, error) SyncNodeInfo(ctx context.Context, id int64) error - CountJobProposalsByStatus() (*JobProposalCounts, error) - GetJobProposal(id int64) (*JobProposal, error) - ListJobProposals() ([]JobProposal, error) - ListJobProposalsByManagersIDs(ids []int64) ([]JobProposal, error) + CountJobProposalsByStatus(ctx context.Context) (*JobProposalCounts, error) + GetJobProposal(ctx context.Context, id int64) (*JobProposal, error) + ListJobProposals(ctx context.Context) ([]JobProposal, error) + ListJobProposalsByManagersIDs(ctx context.Context, ids []int64) ([]JobProposal, error) ApproveSpec(ctx context.Context, id int64, force bool) error CancelSpec(ctx context.Context, id int64) error - GetSpec(id int64) (*JobProposalSpec, error) - ListSpecsByJobProposalIDs(ids []int64) ([]JobProposalSpec, error) + GetSpec(ctx context.Context, id int64) (*JobProposalSpec, error) + ListSpecsByJobProposalIDs(ctx context.Context, ids []int64) ([]JobProposalSpec, error) RejectSpec(ctx context.Context, id int64) error UpdateSpecDefinition(ctx context.Context, id int64, spec string) error - - Unsafe_SetConnectionsManager(ConnectionsManager) } type service struct { @@ -105,7 +101,7 @@ type service struct { orm ORM jobORM job.ORM - q pg.Q + ds sqlutil.DataSource csaKeyStore keystore.CSA p2pKeyStore keystore.P2P ocr1KeyStore keystore.OCR @@ -127,7 +123,7 @@ type service struct { func NewService( orm ORM, jobORM job.ORM, - db *sqlx.DB, + ds sqlutil.DataSource, jobSpawner job.Spawner, keyStore keystore.Master, gCfg GeneralConfig, @@ -135,7 +131,6 @@ func NewService( jobCfg JobConfig, ocrCfg OCRConfig, ocr2Cfg OCR2Config, - dbCfg pg.QConfig, legacyChains legacyevm.LegacyChainContainer, lggr logger.Logger, version string, @@ -145,7 +140,7 @@ func NewService( svc := &service{ orm: orm, jobORM: jobORM, - q: pg.NewQ(db, lggr, dbCfg), + ds: ds, jobSpawner: jobSpawner, p2pKeyStore: keyStore.P2P(), csaKeyStore: keyStore.CSA(), @@ -178,7 +173,7 @@ type RegisterManagerParams struct { // // Only a single feeds manager is currently supported. func (s *service) RegisterManager(ctx context.Context, params RegisterManagerParams) (int64, error) { - count, err := s.CountManagers() + count, err := s.CountManagers(ctx) if err != nil { return 0, err } @@ -193,16 +188,16 @@ func (s *service) RegisterManager(ctx context.Context, params RegisterManagerPar } var id int64 - q := s.q.WithOpts(pg.WithParentCtx(ctx)) - err = q.Transaction(func(tx pg.Queryer) error { + + err = s.orm.Transact(ctx, func(tx ORM) error { var txerr error - id, txerr = s.orm.CreateManager(&mgr, pg.WithQueryer(tx)) + id, txerr = tx.CreateManager(ctx, &mgr) if err != nil { return txerr } - if _, txerr = s.orm.CreateBatchChainConfig(params.ChainConfigs, pg.WithQueryer(tx)); txerr != nil { + if _, txerr = tx.CreateBatchChainConfig(ctx, params.ChainConfigs); txerr != nil { return txerr } @@ -229,7 +224,7 @@ func (s *service) SyncNodeInfo(ctx context.Context, id int64) error { return errors.Wrap(err, "could not fetch client") } - cfgs, err := s.orm.ListChainConfigsByManagerIDs([]int64{id}) + cfgs, err := s.orm.ListChainConfigsByManagerIDs(ctx, []int64{id}) if err != nil { return errors.Wrap(err, "could not fetch chain configs") } @@ -259,17 +254,9 @@ func (s *service) SyncNodeInfo(ctx context.Context, id int64) error { // UpdateManager updates the feed manager details, takes down the // connection and reestablishes a new connection with the updated public key. func (s *service) UpdateManager(ctx context.Context, mgr FeedsManager) error { - q := s.q.WithOpts(pg.WithParentCtx(ctx)) - err := q.Transaction(func(tx pg.Queryer) error { - txerr := s.orm.UpdateManager(mgr, pg.WithQueryer(tx)) - if txerr != nil { - return errors.Wrap(txerr, "could not update manager") - } - - return nil - }) + err := s.orm.UpdateManager(ctx, mgr) if err != nil { - return err + return errors.Wrap(err, "could not update manager") } if err := s.restartConnection(ctx, mgr); err != nil { @@ -280,8 +267,8 @@ func (s *service) UpdateManager(ctx context.Context, mgr FeedsManager) error { } // ListManagerServices lists all the manager services. -func (s *service) ListManagers() ([]FeedsManager, error) { - managers, err := s.orm.ListManagers() +func (s *service) ListManagers(ctx context.Context) ([]FeedsManager, error) { + managers, err := s.orm.ListManagers(ctx) if err != nil { return nil, errors.Wrap(err, "failed to get a list of managers") } @@ -294,8 +281,8 @@ func (s *service) ListManagers() ([]FeedsManager, error) { } // GetManager gets a manager service by id. -func (s *service) GetManager(id int64) (*FeedsManager, error) { - manager, err := s.orm.GetManager(id) +func (s *service) GetManager(ctx context.Context, id int64) (*FeedsManager, error) { + manager, err := s.orm.GetManager(ctx, id) if err != nil { return nil, errors.Wrap(err, "failed to get manager by ID") } @@ -305,8 +292,8 @@ func (s *service) GetManager(id int64) (*FeedsManager, error) { } // ListManagersByIDs get managers services by ids. -func (s *service) ListManagersByIDs(ids []int64) ([]FeedsManager, error) { - managers, err := s.orm.ListManagersByIDs(ids) +func (s *service) ListManagersByIDs(ctx context.Context, ids []int64) ([]FeedsManager, error) { + managers, err := s.orm.ListManagersByIDs(ctx, ids) if err != nil { return nil, errors.Wrap(err, "failed to list managers by IDs") } @@ -319,8 +306,8 @@ func (s *service) ListManagersByIDs(ids []int64) ([]FeedsManager, error) { } // CountManagers gets the total number of manager services -func (s *service) CountManagers() (int64, error) { - return s.orm.CountManagers() +func (s *service) CountManagers(ctx context.Context) (int64, error) { + return s.orm.CountManagers(ctx) } // CreateChainConfig creates a chain config. @@ -333,12 +320,12 @@ func (s *service) CreateChainConfig(ctx context.Context, cfg ChainConfig) (int64 } } - id, err := s.orm.CreateChainConfig(cfg) + id, err := s.orm.CreateChainConfig(ctx, cfg) if err != nil { return 0, errors.Wrap(err, "CreateChainConfig failed") } - mgr, err := s.orm.GetManager(cfg.FeedsManagerID) + mgr, err := s.orm.GetManager(ctx, cfg.FeedsManagerID) if err != nil { return 0, errors.Wrap(err, "CreateChainConfig: failed to fetch manager") } @@ -352,17 +339,17 @@ func (s *service) CreateChainConfig(ctx context.Context, cfg ChainConfig) (int64 // DeleteChainConfig deletes the chain config by id. func (s *service) DeleteChainConfig(ctx context.Context, id int64) (int64, error) { - cfg, err := s.orm.GetChainConfig(id) + cfg, err := s.orm.GetChainConfig(ctx, id) if err != nil { return 0, errors.Wrap(err, "DeleteChainConfig failed: could not get chain config") } - _, err = s.orm.DeleteChainConfig(id) + _, err = s.orm.DeleteChainConfig(ctx, id) if err != nil { return 0, errors.Wrap(err, "DeleteChainConfig failed") } - mgr, err := s.orm.GetManager(cfg.FeedsManagerID) + mgr, err := s.orm.GetManager(ctx, cfg.FeedsManagerID) if err != nil { return 0, errors.Wrap(err, "DeleteChainConfig: failed to fetch manager") } @@ -374,8 +361,8 @@ func (s *service) DeleteChainConfig(ctx context.Context, id int64) (int64, error return id, nil } -func (s *service) GetChainConfig(id int64) (*ChainConfig, error) { - cfg, err := s.orm.GetChainConfig(id) +func (s *service) GetChainConfig(ctx context.Context, id int64) (*ChainConfig, error) { + cfg, err := s.orm.GetChainConfig(ctx, id) if err != nil { return nil, errors.Wrap(err, "GetChainConfig failed") } @@ -383,8 +370,8 @@ func (s *service) GetChainConfig(id int64) (*ChainConfig, error) { return cfg, nil } -func (s *service) ListChainConfigsByManagerIDs(mgrIDs []int64) ([]ChainConfig, error) { - cfgs, err := s.orm.ListChainConfigsByManagerIDs(mgrIDs) +func (s *service) ListChainConfigsByManagerIDs(ctx context.Context, mgrIDs []int64) ([]ChainConfig, error) { + cfgs, err := s.orm.ListChainConfigsByManagerIDs(ctx, mgrIDs) return cfgs, errors.Wrap(err, "ListChainConfigsByManagerIDs failed") } @@ -398,12 +385,12 @@ func (s *service) UpdateChainConfig(ctx context.Context, cfg ChainConfig) (int64 } } - id, err := s.orm.UpdateChainConfig(cfg) + id, err := s.orm.UpdateChainConfig(ctx, cfg) if err != nil { return 0, errors.Wrap(err, "UpdateChainConfig failed") } - ccfg, err := s.orm.GetChainConfig(cfg.ID) + ccfg, err := s.orm.GetChainConfig(ctx, cfg.ID) if err != nil { return 0, errors.Wrap(err, "UpdateChainConfig failed: could not get chain config") } @@ -419,13 +406,13 @@ func (s *service) UpdateChainConfig(ctx context.Context, cfg ChainConfig) (int64 // // When we support multiple feed managers, we will need to change this to filter // by feeds manager -func (s *service) ListJobProposals() ([]JobProposal, error) { - return s.orm.ListJobProposals() +func (s *service) ListJobProposals(ctx context.Context) ([]JobProposal, error) { + return s.orm.ListJobProposals(ctx) } // ListJobProposalsByManagersIDs gets job proposals by feeds managers IDs -func (s *service) ListJobProposalsByManagersIDs(ids []int64) ([]JobProposal, error) { - return s.orm.ListJobProposalsByManagersIDs(ids) +func (s *service) ListJobProposalsByManagersIDs(ctx context.Context, ids []int64) ([]JobProposal, error) { + return s.orm.ListJobProposalsByManagersIDs(ctx, ids) } // DeleteJobArgs are the arguments to provide to the DeleteJob method. @@ -437,7 +424,7 @@ type DeleteJobArgs struct { // DeleteJob deletes a job proposal if it exist. The feeds manager id check // ensures that only the intended feed manager can make this request. func (s *service) DeleteJob(ctx context.Context, args *DeleteJobArgs) (int64, error) { - proposal, err := s.orm.GetJobProposalByRemoteUUID(args.RemoteUUID) + proposal, err := s.orm.GetJobProposalByRemoteUUID(ctx, args.RemoteUUID) if err != nil { if !errors.Is(err, sql.ErrNoRows) { return 0, errors.Wrap(err, "GetJobProposalByRemoteUUID failed to check existence of job proposal") @@ -456,14 +443,13 @@ func (s *service) DeleteJob(ctx context.Context, args *DeleteJobArgs) (int64, er return 0, errors.New("cannot delete a job proposal belonging to another feeds manager") } - pctx := pg.WithParentCtx(ctx) - if err = s.orm.DeleteProposal(proposal.ID, pctx); err != nil { + if err = s.orm.DeleteProposal(ctx, proposal.ID); err != nil { s.lggr.Errorw("Failed to delete the proposal", "err", err) return 0, errors.Wrap(err, "DeleteProposal failed") } - if err = s.observeJobProposalCounts(); err != nil { + if err = s.observeJobProposalCounts(ctx); err != nil { logger.Errorw("Failed to push metrics for job proposal deletion", err) } @@ -479,7 +465,7 @@ type RevokeJobArgs struct { // RevokeJob revokes a pending job proposal if it exist. The feeds manager // id check ensures that only the intended feed manager can make this request. func (s *service) RevokeJob(ctx context.Context, args *RevokeJobArgs) (int64, error) { - proposal, err := s.orm.GetJobProposalByRemoteUUID(args.RemoteUUID) + proposal, err := s.orm.GetJobProposalByRemoteUUID(ctx, args.RemoteUUID) if err != nil { if !errors.Is(err, sql.ErrNoRows) { return 0, errors.Wrap(err, "GetJobProposalByRemoteUUID failed to check existence of job proposal") @@ -495,7 +481,7 @@ func (s *service) RevokeJob(ctx context.Context, args *RevokeJobArgs) (int64, er } // get the latest spec for the proposal - latest, err := s.orm.GetLatestSpec(proposal.ID) + latest, err := s.orm.GetLatestSpec(ctx, proposal.ID) if err != nil { return 0, errors.Wrap(err, "GetLatestSpec failed to get latest spec") } @@ -504,8 +490,7 @@ func (s *service) RevokeJob(ctx context.Context, args *RevokeJobArgs) (int64, er return 0, errors.New("only pending job specs can be revoked") } - pctx := pg.WithParentCtx(ctx) - if err = s.orm.RevokeSpec(latest.ID, pctx); err != nil { + if err = s.orm.RevokeSpec(ctx, latest.ID); err != nil { s.lggr.Errorw("Failed to revoke the proposal", "err", err) return 0, errors.Wrap(err, "RevokeSpec failed") @@ -516,7 +501,7 @@ func (s *service) RevokeJob(ctx context.Context, args *RevokeJobArgs) (int64, er "job_proposal_spec_id", latest.ID, ) - if err = s.observeJobProposalCounts(); err != nil { + if err = s.observeJobProposalCounts(ctx); err != nil { logger.Errorw("Failed to push metrics for revoke job", err) } @@ -545,7 +530,7 @@ func (s *service) ProposeJob(ctx context.Context, args *ProposeJobArgs) (int64, return 0, err } - existing, err := s.orm.GetJobProposalByRemoteUUID(args.RemoteUUID) + existing, err := s.orm.GetJobProposalByRemoteUUID(ctx, args.RemoteUUID) if err != nil { if !errors.Is(err, sql.ErrNoRows) { return 0, errors.Wrap(err, "failed to check existence of job proposal") @@ -562,7 +547,7 @@ func (s *service) ProposeJob(ctx context.Context, args *ProposeJobArgs) (int64, // Check the version being proposed has not been previously proposed. var exists bool - exists, err = s.orm.ExistsSpecByJobProposalIDAndVersion(existing.ID, args.Version) + exists, err = s.orm.ExistsSpecByJobProposalIDAndVersion(ctx, existing.ID, args.Version) if err != nil { return 0, errors.Wrap(err, "failed to check existence of spec") } @@ -577,32 +562,31 @@ func (s *service) ProposeJob(ctx context.Context, args *ProposeJobArgs) (int64, ) var id int64 - q := s.q.WithOpts(pg.WithParentCtx(ctx)) - err = q.Transaction(func(tx pg.Queryer) error { + err = s.orm.Transact(ctx, func(tx ORM) error { var txerr error // Parse the Job Spec TOML to extract the name name := extractName(args.Spec) // Upsert job proposal - id, txerr = s.orm.UpsertJobProposal(&JobProposal{ + id, txerr = tx.UpsertJobProposal(ctx, &JobProposal{ Name: name, RemoteUUID: args.RemoteUUID, Status: JobProposalStatusPending, FeedsManagerID: args.FeedsManagerID, Multiaddrs: args.Multiaddrs, - }, pg.WithQueryer(tx)) + }) if txerr != nil { return errors.Wrap(txerr, "failed to upsert job proposal") } // Create the spec version - _, txerr = s.orm.CreateSpec(JobProposalSpec{ + _, txerr = tx.CreateSpec(ctx, JobProposalSpec{ Definition: args.Spec, Status: SpecStatusPending, Version: args.Version, JobProposalID: id, - }, pg.WithQueryer(tx)) + }) if txerr != nil { return errors.Wrap(txerr, "failed to create spec") } @@ -616,7 +600,7 @@ func (s *service) ProposeJob(ctx context.Context, args *ProposeJobArgs) (int64, // Track the given job proposal request promJobProposalRequest.Inc() - if err = s.observeJobProposalCounts(); err != nil { + if err = s.observeJobProposalCounts(ctx); err != nil { logger.Errorw("Failed to push metrics for propose job", err) } @@ -624,20 +608,18 @@ func (s *service) ProposeJob(ctx context.Context, args *ProposeJobArgs) (int64, } // GetJobProposal gets a job proposal by id. -func (s *service) GetJobProposal(id int64) (*JobProposal, error) { - return s.orm.GetJobProposal(id) +func (s *service) GetJobProposal(ctx context.Context, id int64) (*JobProposal, error) { + return s.orm.GetJobProposal(ctx, id) } // CountJobProposalsByStatus returns the count of job proposals with a given status. -func (s *service) CountJobProposalsByStatus() (*JobProposalCounts, error) { - return s.orm.CountJobProposalsByStatus() +func (s *service) CountJobProposalsByStatus(ctx context.Context) (*JobProposalCounts, error) { + return s.orm.CountJobProposalsByStatus(ctx) } // RejectSpec rejects a spec. func (s *service) RejectSpec(ctx context.Context, id int64) error { - pctx := pg.WithParentCtx(ctx) - - spec, err := s.orm.GetSpec(id, pctx) + spec, err := s.orm.GetSpec(ctx, id) if err != nil { return errors.Wrap(err, "orm: job proposal spec") } @@ -647,7 +629,7 @@ func (s *service) RejectSpec(ctx context.Context, id int64) error { return errors.New("must be a pending job proposal spec") } - proposal, err := s.orm.GetJobProposal(spec.JobProposalID, pctx) + proposal, err := s.orm.GetJobProposal(ctx, spec.JobProposalID) if err != nil { return errors.Wrap(err, "orm: job proposal") } @@ -662,9 +644,8 @@ func (s *service) RejectSpec(ctx context.Context, id int64) error { "job_proposal_spec_id", id, ) - q := s.q.WithOpts(pctx) - err = q.Transaction(func(tx pg.Queryer) error { - if err = s.orm.RejectSpec(id, pg.WithQueryer(tx)); err != nil { + err = s.orm.Transact(ctx, func(tx ORM) error { + if err = tx.RejectSpec(ctx, id); err != nil { return err } @@ -681,7 +662,7 @@ func (s *service) RejectSpec(ctx context.Context, id int64) error { return errors.Wrap(err, "could not reject job proposal") } - if err = s.observeJobProposalCounts(); err != nil { + if err = s.observeJobProposalCounts(ctx); err != nil { logger.Errorw("Failed to push metrics for job rejection", err) } @@ -690,25 +671,23 @@ func (s *service) RejectSpec(ctx context.Context, id int64) error { // IsJobManaged determines is a job is managed by the Feeds Manager. func (s *service) IsJobManaged(ctx context.Context, jobID int64) (bool, error) { - return s.orm.IsJobManaged(jobID, pg.WithParentCtx(ctx)) + return s.orm.IsJobManaged(ctx, jobID) } // ApproveSpec approves a spec for a job proposal and creates a job with the // spec. func (s *service) ApproveSpec(ctx context.Context, id int64, force bool) error { - pctx := pg.WithParentCtx(ctx) - - spec, err := s.orm.GetSpec(id, pctx) + spec, err := s.orm.GetSpec(ctx, id) if err != nil { return errors.Wrap(err, "orm: job proposal spec") } - proposal, err := s.orm.GetJobProposal(spec.JobProposalID, pctx) + proposal, err := s.orm.GetJobProposal(ctx, spec.JobProposalID) if err != nil { return errors.Wrap(err, "orm: job proposal") } - if err = s.isApprovable(proposal.Status, proposal.ID, spec.Status, spec.ID); err != nil { + if err = s.isApprovable(ctx, proposal.Status, proposal.ID, spec.Status, spec.ID); err != nil { return err } @@ -741,17 +720,14 @@ func (s *service) ApproveSpec(ctx context.Context, id int64, force bool) error { return errors.Wrap(err, "failed to approve job spec due to bridge check") } - q := s.q.WithOpts(pctx) - err = q.Transaction(func(tx pg.Queryer) error { + err = s.transact(ctx, func(tx datasources) error { var ( txerr error existingJobID int32 - - pgOpts = pg.WithQueryer(tx) ) // Use the external job id to check if a job already exists - foundJob, txerr := s.jobORM.FindJobByExternalJobID(j.ExternalJobID, pgOpts) + foundJob, txerr := tx.jobORM.FindJobByExternalJobID(ctx, j.ExternalJobID) if txerr != nil { // Return an error if the repository errors. If there is a not found // error we want to continue with approving the job. @@ -768,7 +744,7 @@ func (s *service) ApproveSpec(ctx context.Context, id int64, force bool) error { if existingJobID == 0 { switch j.Type { case job.OffchainReporting, job.FluxMonitor: - existingJobID, txerr = s.findExistingJobForOCRFlux(j, pgOpts) + existingJobID, txerr = findExistingJobForOCRFlux(ctx, j, tx.jobORM) if txerr != nil { // Return an error if the repository errors. If there is a not found // error we want to continue with approving the job. @@ -777,7 +753,7 @@ func (s *service) ApproveSpec(ctx context.Context, id int64, force bool) error { } } case job.OffchainReporting2, job.Bootstrap: - existingJobID, txerr = s.findExistingJobForOCR2(j, pgOpts) + existingJobID, txerr = findExistingJobForOCR2(ctx, j, tx.jobORM) if txerr != nil { // Return an error if the repository errors. If there is a not found // error we want to continue with approving the job. @@ -798,7 +774,7 @@ func (s *service) ApproveSpec(ctx context.Context, id int64, force bool) error { } // Check if the job is managed by FMS - approvedSpec, serr := s.orm.GetApprovedSpec(proposal.ID, pgOpts) + approvedSpec, serr := tx.orm.GetApprovedSpec(ctx, proposal.ID) if serr != nil { if !errors.Is(serr, sql.ErrNoRows) { logger.Errorw("Failed to get approved spec", "err", serr) @@ -811,7 +787,7 @@ func (s *service) ApproveSpec(ctx context.Context, id int64, force bool) error { // If a spec is found, cancel the existing job spec if serr == nil { - if cerr := s.orm.CancelSpec(approvedSpec.ID, pgOpts); cerr != nil { + if cerr := tx.orm.CancelSpec(ctx, approvedSpec.ID); cerr != nil { logger.Errorw("Failed to delete the cancel the spec", "err", cerr) return cerr @@ -819,7 +795,7 @@ func (s *service) ApproveSpec(ctx context.Context, id int64, force bool) error { } // Delete the job - if serr = s.jobSpawner.DeleteJob(existingJobID, pgOpts); serr != nil { + if serr = s.jobSpawner.DeleteJob(ctx, tx.ds, existingJobID); serr != nil { logger.Errorw("Failed to delete the job", "err", serr) return errors.Wrap(serr, "DeleteJob failed") @@ -827,14 +803,14 @@ func (s *service) ApproveSpec(ctx context.Context, id int64, force bool) error { } // Create the job - if txerr = s.jobSpawner.CreateJob(j, pgOpts); txerr != nil { + if txerr = s.jobSpawner.CreateJob(ctx, tx.ds, j); txerr != nil { logger.Errorw("Failed to create job", "err", txerr) return txerr } // Approve the job proposal spec - if txerr = s.orm.ApproveSpec(id, j.ExternalJobID, pgOpts); txerr != nil { + if txerr = tx.orm.ApproveSpec(ctx, id, j.ExternalJobID); txerr != nil { logger.Errorw("Failed to approve spec", "err", txerr) return txerr @@ -856,18 +832,32 @@ func (s *service) ApproveSpec(ctx context.Context, id int64, force bool) error { return errors.Wrap(err, "could not approve job proposal") } - if err = s.observeJobProposalCounts(); err != nil { + if err = s.observeJobProposalCounts(ctx); err != nil { logger.Errorw("Failed to push metrics for job approval", err) } return nil } +type datasources struct { + ds sqlutil.DataSource + orm ORM + jobORM job.ORM +} + +func (s *service) transact(ctx context.Context, fn func(datasources) error) error { + return sqlutil.Transact(ctx, func(tx sqlutil.DataSource) datasources { + return datasources{ + ds: tx, + orm: s.orm.WithDataSource(tx), + jobORM: s.jobORM.WithDataSource(tx), + } + }, s.ds, nil, fn) +} + // CancelSpec cancels a spec for a job proposal. func (s *service) CancelSpec(ctx context.Context, id int64) error { - pctx := pg.WithParentCtx(ctx) - - spec, err := s.orm.GetSpec(id, pctx) + spec, err := s.orm.GetSpec(ctx, id) if err != nil { return errors.Wrap(err, "orm: job proposal spec") } @@ -876,7 +866,7 @@ func (s *service) CancelSpec(ctx context.Context, id int64) error { return errors.New("must be an approved job proposal spec") } - jp, err := s.orm.GetJobProposal(spec.JobProposalID, pg.WithParentCtx(ctx)) + jp, err := s.orm.GetJobProposal(ctx, spec.JobProposalID) if err != nil { return errors.Wrap(err, "orm: job proposal") } @@ -891,20 +881,18 @@ func (s *service) CancelSpec(ctx context.Context, id int64) error { "job_proposal_spec_id", id, ) - q := s.q.WithOpts(pctx) - err = q.Transaction(func(tx pg.Queryer) error { + err = s.transact(ctx, func(tx datasources) error { var ( - txerr error - pgOpts = pg.WithQueryer(tx) + txerr error ) - if txerr = s.orm.CancelSpec(id, pgOpts); txerr != nil { + if txerr = tx.orm.CancelSpec(ctx, id); txerr != nil { return txerr } // Delete the job if jp.ExternalJobID.Valid { - j, txerr := s.jobORM.FindJobByExternalJobID(jp.ExternalJobID.UUID, pgOpts) + j, txerr := tx.jobORM.FindJobByExternalJobID(ctx, jp.ExternalJobID.UUID) if txerr != nil { // Return an error if the repository errors. If there is a not found error we want // to continue with cancelling the spec but we won't have to cancel any jobs. @@ -914,7 +902,7 @@ func (s *service) CancelSpec(ctx context.Context, id int64) error { } if txerr == nil { - if serr := s.jobSpawner.DeleteJob(j.ID, pgOpts); serr != nil { + if serr := s.jobSpawner.DeleteJob(ctx, tx.ds, j.ID); serr != nil { return errors.Wrap(serr, "DeleteJob failed") } } @@ -934,7 +922,7 @@ func (s *service) CancelSpec(ctx context.Context, id int64) error { return err } - if err = s.observeJobProposalCounts(); err != nil { + if err = s.observeJobProposalCounts(ctx); err != nil { logger.Errorw("Failed to push metrics for job cancellation", err) } @@ -942,20 +930,18 @@ func (s *service) CancelSpec(ctx context.Context, id int64) error { } // ListSpecsByJobProposalIDs gets the specs which belong to the job proposal ids. -func (s *service) ListSpecsByJobProposalIDs(ids []int64) ([]JobProposalSpec, error) { - return s.orm.ListSpecsByJobProposalIDs(ids) +func (s *service) ListSpecsByJobProposalIDs(ctx context.Context, ids []int64) ([]JobProposalSpec, error) { + return s.orm.ListSpecsByJobProposalIDs(ctx, ids) } // GetSpec gets the spec details by id. -func (s *service) GetSpec(id int64) (*JobProposalSpec, error) { - return s.orm.GetSpec(id) +func (s *service) GetSpec(ctx context.Context, id int64) (*JobProposalSpec, error) { + return s.orm.GetSpec(ctx, id) } // UpdateSpecDefinition updates the spec's TOML definition. func (s *service) UpdateSpecDefinition(ctx context.Context, id int64, defn string) error { - pctx := pg.WithParentCtx(ctx) - - spec, err := s.orm.GetSpec(id, pctx) + spec, err := s.orm.GetSpec(ctx, id) if err != nil { if errors.Is(err, sql.ErrNoRows) { return errors.Wrap(err, "job proposal spec does not exist") @@ -969,7 +955,7 @@ func (s *service) UpdateSpecDefinition(ctx context.Context, id int64, defn strin } // Update the spec definition - if err = s.orm.UpdateSpecDefinition(id, defn, pctx); err != nil { + if err = s.orm.UpdateSpecDefinition(ctx, id, defn); err != nil { return errors.Wrap(err, "could not update job proposal") } @@ -985,7 +971,7 @@ func (s *service) Start(ctx context.Context) error { } // We only support a single feeds manager right now - mgrs, err := s.ListManagers() + mgrs, err := s.ListManagers(ctx) if err != nil { return err } @@ -998,7 +984,7 @@ func (s *service) Start(ctx context.Context) error { mgr := mgrs[0] s.connectFeedManager(ctx, mgr, privkey) - if err = s.observeJobProposalCounts(); err != nil { + if err = s.observeJobProposalCounts(ctx); err != nil { s.lggr.Error("failed to observe job proposal count when starting service", err) } @@ -1052,8 +1038,8 @@ func (s *service) getCSAPrivateKey() (privkey []byte, err error) { // observeJobProposalCounts is a helper method that queries the repository for the count of // job proposals by status and then updates prometheus gauges. -func (s *service) observeJobProposalCounts() error { - counts, err := s.CountJobProposalsByStatus() +func (s *service) observeJobProposalCounts(ctx context.Context) error { + counts, err := s.CountJobProposalsByStatus(ctx) if err != nil { return errors.Wrap(err, "failed to fetch counts of job proposals") } @@ -1073,18 +1059,8 @@ func (s *service) observeJobProposalCounts() error { return nil } -// Unsafe_SetConnectionsManager sets the ConnectionsManager on the service. -// -// We need to be able to inject a mock for the client to facilitate integration -// tests. -// -// ONLY TO BE USED FOR TESTING. -func (s *service) Unsafe_SetConnectionsManager(connMgr ConnectionsManager) { - s.connMgr = connMgr -} - // findExistingJobForOCR2 looks for existing job for OCR2 -func (s *service) findExistingJobForOCR2(j *job.Job, qopts pg.QOpt) (int32, error) { +func findExistingJobForOCR2(ctx context.Context, j *job.Job, tx job.ORM) (int32, error) { var contractID string var feedID *common.Hash @@ -1103,11 +1079,11 @@ func (s *service) findExistingJobForOCR2(j *job.Job, qopts pg.QOpt) (int32, erro return 0, errors.Errorf("unsupported job type: %s", j.Type) } - return s.jobORM.FindOCR2JobIDByAddress(contractID, feedID, qopts) + return tx.FindOCR2JobIDByAddress(ctx, contractID, feedID) } // findExistingJobForOCRFlux looks for existing job for OCR or flux -func (s *service) findExistingJobForOCRFlux(j *job.Job, qopts pg.QOpt) (int32, error) { +func findExistingJobForOCRFlux(ctx context.Context, j *job.Job, tx job.ORM) (int32, error) { var address types.EIP55Address var evmChainID *big.Big @@ -1124,7 +1100,7 @@ func (s *service) findExistingJobForOCRFlux(j *job.Job, qopts pg.QOpt) (int32, e return 0, errors.Errorf("unsupported job type: %s", j.Type) } - return s.jobORM.FindJobIDByAddress(address, evmChainID, qopts) + return tx.FindJobIDByAddress(ctx, address, evmChainID) } // generateJob validates and generates a job from a spec. @@ -1354,7 +1330,7 @@ func extractName(defn string) null.String { // isApprovable returns nil if a spec can be approved based on the current // proposal and spec status, and if it can't be approved, the reason as an // error. -func (s *service) isApprovable(propStatus JobProposalStatus, proposalID int64, specStatus SpecStatus, specID int64) error { +func (s *service) isApprovable(ctx context.Context, propStatus JobProposalStatus, proposalID int64, specStatus SpecStatus, specID int64) error { if propStatus == JobProposalStatusDeleted { return errors.New("cannot approve spec for a deleted job proposal") } @@ -1372,7 +1348,7 @@ func (s *service) isApprovable(propStatus JobProposalStatus, proposalID int64, s return errors.New("cannot approve a revoked spec") case SpecStatusCancelled: // Allowed to approve a cancelled job if it is the latest job - latest, serr := s.orm.GetLatestSpec(proposalID) + latest, serr := s.orm.GetLatestSpec(ctx, proposalID) if serr != nil { return errors.Wrap(serr, "failed to get latest spec") } @@ -1405,46 +1381,46 @@ func (ns NullService) Close() error { return nil } func (ns NullService) ApproveSpec(ctx context.Context, id int64, force bool) error { return ErrFeedsManagerDisabled } -func (ns NullService) CountManagers() (int64, error) { return 0, nil } -func (ns NullService) CountJobProposalsByStatus() (*JobProposalCounts, error) { +func (ns NullService) CountManagers(ctx context.Context) (int64, error) { return 0, nil } +func (ns NullService) CountJobProposalsByStatus(ctx context.Context) (*JobProposalCounts, error) { return nil, ErrFeedsManagerDisabled } func (ns NullService) CancelSpec(ctx context.Context, id int64) error { return ErrFeedsManagerDisabled } -func (ns NullService) GetJobProposal(id int64) (*JobProposal, error) { +func (ns NullService) GetJobProposal(ctx context.Context, id int64) (*JobProposal, error) { return nil, ErrFeedsManagerDisabled } -func (ns NullService) ListSpecsByJobProposalIDs(ids []int64) ([]JobProposalSpec, error) { +func (ns NullService) ListSpecsByJobProposalIDs(ctx context.Context, ids []int64) ([]JobProposalSpec, error) { return nil, ErrFeedsManagerDisabled } -func (ns NullService) GetManager(id int64) (*FeedsManager, error) { +func (ns NullService) GetManager(ctx context.Context, id int64) (*FeedsManager, error) { return nil, ErrFeedsManagerDisabled } -func (ns NullService) ListManagersByIDs(ids []int64) ([]FeedsManager, error) { +func (ns NullService) ListManagersByIDs(ctx context.Context, ids []int64) ([]FeedsManager, error) { return nil, ErrFeedsManagerDisabled } -func (ns NullService) GetSpec(id int64) (*JobProposalSpec, error) { +func (ns NullService) GetSpec(ctx context.Context, id int64) (*JobProposalSpec, error) { return nil, ErrFeedsManagerDisabled } -func (ns NullService) ListManagers() ([]FeedsManager, error) { return nil, nil } +func (ns NullService) ListManagers(ctx context.Context) ([]FeedsManager, error) { return nil, nil } func (ns NullService) CreateChainConfig(ctx context.Context, cfg ChainConfig) (int64, error) { return 0, ErrFeedsManagerDisabled } -func (ns NullService) GetChainConfig(id int64) (*ChainConfig, error) { +func (ns NullService) GetChainConfig(ctx context.Context, id int64) (*ChainConfig, error) { return nil, ErrFeedsManagerDisabled } func (ns NullService) DeleteChainConfig(ctx context.Context, id int64) (int64, error) { return 0, ErrFeedsManagerDisabled } -func (ns NullService) ListChainConfigsByManagerIDs(mgrIDs []int64) ([]ChainConfig, error) { +func (ns NullService) ListChainConfigsByManagerIDs(ctx context.Context, mgrIDs []int64) ([]ChainConfig, error) { return nil, ErrFeedsManagerDisabled } func (ns NullService) UpdateChainConfig(ctx context.Context, cfg ChainConfig) (int64, error) { return 0, ErrFeedsManagerDisabled } -func (ns NullService) ListJobProposals() ([]JobProposal, error) { return nil, nil } -func (ns NullService) ListJobProposalsByManagersIDs(ids []int64) ([]JobProposal, error) { +func (ns NullService) ListJobProposals(ctx context.Context) ([]JobProposal, error) { return nil, nil } +func (ns NullService) ListJobProposalsByManagersIDs(ctx context.Context, ids []int64) ([]JobProposal, error) { return nil, ErrFeedsManagerDisabled } func (ns NullService) ProposeJob(ctx context.Context, args *ProposeJobArgs) (int64, error) { @@ -1472,6 +1448,5 @@ func (ns NullService) IsJobManaged(ctx context.Context, jobID int64) (bool, erro func (ns NullService) UpdateSpecDefinition(ctx context.Context, id int64, spec string) error { return ErrFeedsManagerDisabled } -func (ns NullService) Unsafe_SetConnectionsManager(_ ConnectionsManager) {} //revive:enable diff --git a/core/services/feeds/service_test.go b/core/services/feeds/service_test.go index f83a98986e2..af656618f78 100644 --- a/core/services/feeds/service_test.go +++ b/core/services/feeds/service_test.go @@ -186,7 +186,7 @@ func setupTestServiceCfg(t *testing.T, overrideCfg func(c *chainlink.Config, s * keyStore.On("P2P").Return(p2pKeystore) keyStore.On("OCR").Return(ocr1Keystore) keyStore.On("OCR2").Return(ocr2Keystore) - svc := feeds.NewService(orm, jobORM, db, spawner, keyStore, gcfg, gcfg.Insecure(), gcfg.JobPipeline(), gcfg.OCR(), gcfg.OCR2(), gcfg.Database(), legacyChains, lggr, "1.0.0", nil) + svc := feeds.NewService(orm, jobORM, db, spawner, keyStore, gcfg, gcfg.Insecure(), gcfg.JobPipeline(), gcfg.OCR(), gcfg.OCR2(), legacyChains, lggr, "1.0.0", nil) svc.SetConnectionsManager(connMgr) return &TestService{ @@ -233,14 +233,19 @@ func Test_Service_RegisterManager(t *testing.T) { svc := setupTestService(t) - svc.orm.On("CountManagers").Return(int64(0), nil) - svc.orm.On("CreateManager", &mgr, mock.Anything). + svc.orm.On("CountManagers", mock.Anything).Return(int64(0), nil) + svc.orm.On("CreateManager", mock.Anything, &mgr, mock.Anything). Return(id, nil) - svc.orm.On("CreateBatchChainConfig", params.ChainConfigs, mock.Anything). + svc.orm.On("CreateBatchChainConfig", mock.Anything, params.ChainConfigs, mock.Anything). Return([]int64{}, nil) svc.csaKeystore.On("GetAll").Return([]csakey.KeyV2{key}, nil) // ListManagers runs in a goroutine so it might be called. svc.orm.On("ListManagers", testutils.Context(t)).Return([]feeds.FeedsManager{mgr}, nil).Maybe() + transactCall := svc.orm.On("Transact", mock.Anything, mock.Anything) + transactCall.Run(func(args mock.Arguments) { + fn := args[1].(func(orm feeds.ORM) error) + transactCall.ReturnArguments = mock.Arguments{fn(svc.orm)} + }) svc.connMgr.On("Connect", mock.IsType(feeds.ConnectOpts{})) actual, err := svc.RegisterManager(testutils.Context(t), params) @@ -254,6 +259,7 @@ func Test_Service_RegisterManager(t *testing.T) { func Test_Service_ListManagers(t *testing.T) { t.Parallel() + ctx := testutils.Context(t) var ( mgr = feeds.FeedsManager{} @@ -261,10 +267,10 @@ func Test_Service_ListManagers(t *testing.T) { ) svc := setupTestService(t) - svc.orm.On("ListManagers").Return(mgrs, nil) + svc.orm.On("ListManagers", mock.Anything).Return(mgrs, nil) svc.connMgr.On("IsConnected", mgr.ID).Return(false) - actual, err := svc.ListManagers() + actual, err := svc.ListManagers(ctx) require.NoError(t, err) assert.Equal(t, mgrs, actual) @@ -272,6 +278,7 @@ func Test_Service_ListManagers(t *testing.T) { func Test_Service_GetManager(t *testing.T) { t.Parallel() + ctx := testutils.Context(t) var ( id = int64(1) @@ -279,11 +286,11 @@ func Test_Service_GetManager(t *testing.T) { ) svc := setupTestService(t) - svc.orm.On("GetManager", id). + svc.orm.On("GetManager", mock.Anything, id). Return(&mgr, nil) svc.connMgr.On("IsConnected", mgr.ID).Return(false) - actual, err := svc.GetManager(id) + actual, err := svc.GetManager(ctx, id) require.NoError(t, err) assert.Equal(t, actual, &mgr) @@ -298,7 +305,7 @@ func Test_Service_UpdateFeedsManager(t *testing.T) { svc := setupTestService(t) - svc.orm.On("UpdateManager", mgr, mock.Anything).Return(nil) + svc.orm.On("UpdateManager", mock.Anything, mgr, mock.Anything).Return(nil) svc.csaKeystore.On("GetAll").Return([]csakey.KeyV2{key}, nil) svc.connMgr.On("Disconnect", mgr.ID).Return(nil) svc.connMgr.On("Connect", mock.IsType(feeds.ConnectOpts{})).Return(nil) @@ -309,6 +316,7 @@ func Test_Service_UpdateFeedsManager(t *testing.T) { func Test_Service_ListManagersByIDs(t *testing.T) { t.Parallel() + ctx := testutils.Context(t) var ( mgr = feeds.FeedsManager{} @@ -316,11 +324,11 @@ func Test_Service_ListManagersByIDs(t *testing.T) { ) svc := setupTestService(t) - svc.orm.On("ListManagersByIDs", []int64{mgr.ID}). + svc.orm.On("ListManagersByIDs", mock.Anything, []int64{mgr.ID}). Return(mgrs, nil) svc.connMgr.On("IsConnected", mgr.ID).Return(false) - actual, err := svc.ListManagersByIDs([]int64{mgr.ID}) + actual, err := svc.ListManagersByIDs(ctx, []int64{mgr.ID}) require.NoError(t, err) assert.Equal(t, mgrs, actual) @@ -328,16 +336,17 @@ func Test_Service_ListManagersByIDs(t *testing.T) { func Test_Service_CountManagers(t *testing.T) { t.Parallel() + ctx := testutils.Context(t) var ( count = int64(1) ) svc := setupTestService(t) - svc.orm.On("CountManagers"). + svc.orm.On("CountManagers", mock.Anything). Return(count, nil) - actual, err := svc.CountManagers() + actual, err := svc.CountManagers(ctx) require.NoError(t, err) assert.Equal(t, count, actual) @@ -369,10 +378,10 @@ func Test_Service_CreateChainConfig(t *testing.T) { svc = setupTestService(t) ) - svc.orm.On("CreateChainConfig", cfg).Return(int64(1), nil) - svc.orm.On("GetManager", mgr.ID).Return(&mgr, nil) + svc.orm.On("CreateChainConfig", mock.Anything, cfg).Return(int64(1), nil) + svc.orm.On("GetManager", mock.Anything, mgr.ID).Return(&mgr, nil) svc.connMgr.On("GetClient", mgr.ID).Return(svc.fmsClient, nil) - svc.orm.On("ListChainConfigsByManagerIDs", []int64{mgr.ID}).Return([]feeds.ChainConfig{cfg}, nil) + svc.orm.On("ListChainConfigsByManagerIDs", mock.Anything, []int64{mgr.ID}).Return([]feeds.ChainConfig{cfg}, nil) svc.fmsClient.On("UpdateNode", mock.Anything, &proto.UpdateNodeRequest{ Version: nodeVersion.Version, ChainConfigs: []*proto.ChainConfig{ @@ -430,11 +439,11 @@ func Test_Service_DeleteChainConfig(t *testing.T) { svc = setupTestService(t) ) - svc.orm.On("GetChainConfig", cfg.ID).Return(&cfg, nil) - svc.orm.On("DeleteChainConfig", cfg.ID).Return(cfg.ID, nil) - svc.orm.On("GetManager", mgr.ID).Return(&mgr, nil) + svc.orm.On("GetChainConfig", mock.Anything, cfg.ID).Return(&cfg, nil) + svc.orm.On("DeleteChainConfig", mock.Anything, cfg.ID).Return(cfg.ID, nil) + svc.orm.On("GetManager", mock.Anything, mgr.ID).Return(&mgr, nil) svc.connMgr.On("GetClient", mgr.ID).Return(svc.fmsClient, nil) - svc.orm.On("ListChainConfigsByManagerIDs", []int64{mgr.ID}).Return([]feeds.ChainConfig{}, nil) + svc.orm.On("ListChainConfigsByManagerIDs", mock.Anything, []int64{mgr.ID}).Return([]feeds.ChainConfig{}, nil) svc.fmsClient.On("UpdateNode", mock.Anything, &proto.UpdateNodeRequest{ Version: nodeVersion.Version, ChainConfigs: []*proto.ChainConfig{}, @@ -446,6 +455,7 @@ func Test_Service_DeleteChainConfig(t *testing.T) { } func Test_Service_ListChainConfigsByManagerIDs(t *testing.T) { + ctx := testutils.Context(t) var ( mgr = feeds.FeedsManager{ID: 1} cfg = feeds.ChainConfig{ @@ -457,9 +467,9 @@ func Test_Service_ListChainConfigsByManagerIDs(t *testing.T) { svc = setupTestService(t) ) - svc.orm.On("ListChainConfigsByManagerIDs", ids).Return([]feeds.ChainConfig{cfg}, nil) + svc.orm.On("ListChainConfigsByManagerIDs", mock.Anything, ids).Return([]feeds.ChainConfig{cfg}, nil) - actual, err := svc.ListChainConfigsByManagerIDs(ids) + actual, err := svc.ListChainConfigsByManagerIDs(ctx, ids) require.NoError(t, err) assert.Equal(t, []feeds.ChainConfig{cfg}, actual) } @@ -484,10 +494,10 @@ func Test_Service_UpdateChainConfig(t *testing.T) { svc = setupTestService(t) ) - svc.orm.On("UpdateChainConfig", cfg).Return(int64(1), nil) - svc.orm.On("GetChainConfig", cfg.ID).Return(&cfg, nil) + svc.orm.On("UpdateChainConfig", mock.Anything, cfg).Return(int64(1), nil) + svc.orm.On("GetChainConfig", mock.Anything, cfg.ID).Return(&cfg, nil) svc.connMgr.On("GetClient", mgr.ID).Return(svc.fmsClient, nil) - svc.orm.On("ListChainConfigsByManagerIDs", []int64{mgr.ID}).Return([]feeds.ChainConfig{cfg}, nil) + svc.orm.On("ListChainConfigsByManagerIDs", mock.Anything, []int64{mgr.ID}).Return([]feeds.ChainConfig{cfg}, nil) svc.fmsClient.On("UpdateNode", mock.Anything, &proto.UpdateNodeRequest{ Version: nodeVersion.Version, ChainConfigs: []*proto.ChainConfig{ @@ -640,10 +650,15 @@ func Test_Service_ProposeJob(t *testing.T) { { name: "Create success (Flux Monitor)", before: func(svc *TestService) { - svc.orm.On("GetJobProposalByRemoteUUID", jpFluxMonitor.RemoteUUID).Return(new(feeds.JobProposal), sql.ErrNoRows) - svc.orm.On("UpsertJobProposal", &jpFluxMonitor, mock.Anything).Return(idFluxMonitor, nil) - svc.orm.On("CreateSpec", specFluxMonitor, mock.Anything).Return(int64(100), nil) - svc.orm.On("CountJobProposalsByStatus").Return(&feeds.JobProposalCounts{}, nil) + svc.orm.On("GetJobProposalByRemoteUUID", mock.Anything, jpFluxMonitor.RemoteUUID).Return(new(feeds.JobProposal), sql.ErrNoRows) + svc.orm.On("UpsertJobProposal", mock.Anything, &jpFluxMonitor).Return(idFluxMonitor, nil) + svc.orm.On("CreateSpec", mock.Anything, specFluxMonitor).Return(int64(100), nil) + svc.orm.On("CountJobProposalsByStatus", mock.Anything).Return(&feeds.JobProposalCounts{}, nil) + transactCall := svc.orm.On("Transact", mock.Anything, mock.Anything) + transactCall.Run(func(args mock.Arguments) { + fn := args[1].(func(orm feeds.ORM) error) + transactCall.ReturnArguments = mock.Arguments{fn(svc.orm)} + }) }, args: argsFluxMonitor, wantID: idFluxMonitor, @@ -651,10 +666,15 @@ func Test_Service_ProposeJob(t *testing.T) { { name: "Create success (OCR1)", before: func(svc *TestService) { - svc.orm.On("GetJobProposalByRemoteUUID", jpOCR1.RemoteUUID).Return(new(feeds.JobProposal), sql.ErrNoRows) - svc.orm.On("UpsertJobProposal", &jpOCR1, mock.Anything).Return(idOCR1, nil) - svc.orm.On("CreateSpec", specOCR1, mock.Anything).Return(int64(100), nil) - svc.orm.On("CountJobProposalsByStatus").Return(&feeds.JobProposalCounts{}, nil) + svc.orm.On("GetJobProposalByRemoteUUID", mock.Anything, jpOCR1.RemoteUUID).Return(new(feeds.JobProposal), sql.ErrNoRows) + svc.orm.On("UpsertJobProposal", mock.Anything, &jpOCR1).Return(idOCR1, nil) + svc.orm.On("CreateSpec", mock.Anything, specOCR1).Return(int64(100), nil) + svc.orm.On("CountJobProposalsByStatus", mock.Anything).Return(&feeds.JobProposalCounts{}, nil) + transactCall := svc.orm.On("Transact", mock.Anything, mock.Anything) + transactCall.Run(func(args mock.Arguments) { + fn := args[1].(func(orm feeds.ORM) error) + transactCall.ReturnArguments = mock.Arguments{fn(svc.orm)} + }) }, args: argsOCR1, wantID: idOCR1, @@ -662,10 +682,15 @@ func Test_Service_ProposeJob(t *testing.T) { { name: "Create success (OCR2)", before: func(svc *TestService) { - svc.orm.On("GetJobProposalByRemoteUUID", jpOCR2.RemoteUUID).Return(new(feeds.JobProposal), sql.ErrNoRows) - svc.orm.On("UpsertJobProposal", &jpOCR2, mock.Anything).Return(idOCR2, nil) - svc.orm.On("CreateSpec", specOCR2, mock.Anything).Return(int64(100), nil) - svc.orm.On("CountJobProposalsByStatus").Return(&feeds.JobProposalCounts{}, nil) + svc.orm.On("GetJobProposalByRemoteUUID", mock.Anything, jpOCR2.RemoteUUID).Return(new(feeds.JobProposal), sql.ErrNoRows) + svc.orm.On("UpsertJobProposal", mock.Anything, &jpOCR2).Return(idOCR2, nil) + svc.orm.On("CreateSpec", mock.Anything, specOCR2).Return(int64(100), nil) + svc.orm.On("CountJobProposalsByStatus", mock.Anything).Return(&feeds.JobProposalCounts{}, nil) + transactCall := svc.orm.On("Transact", mock.Anything, mock.Anything) + transactCall.Run(func(args mock.Arguments) { + fn := args[1].(func(orm feeds.ORM) error) + transactCall.ReturnArguments = mock.Arguments{fn(svc.orm)} + }) }, args: argsOCR2, wantID: idOCR2, @@ -673,10 +698,15 @@ func Test_Service_ProposeJob(t *testing.T) { { name: "Create success (Bootstrap)", before: func(svc *TestService) { - svc.orm.On("GetJobProposalByRemoteUUID", jpBootstrap.RemoteUUID).Return(new(feeds.JobProposal), sql.ErrNoRows) - svc.orm.On("UpsertJobProposal", &jpBootstrap, mock.Anything).Return(idBootstrap, nil) - svc.orm.On("CreateSpec", specBootstrap, mock.Anything).Return(int64(102), nil) - svc.orm.On("CountJobProposalsByStatus").Return(&feeds.JobProposalCounts{}, nil) + svc.orm.On("GetJobProposalByRemoteUUID", mock.Anything, jpBootstrap.RemoteUUID).Return(new(feeds.JobProposal), sql.ErrNoRows) + svc.orm.On("UpsertJobProposal", mock.Anything, &jpBootstrap).Return(idBootstrap, nil) + svc.orm.On("CreateSpec", mock.Anything, specBootstrap).Return(int64(102), nil) + svc.orm.On("CountJobProposalsByStatus", mock.Anything).Return(&feeds.JobProposalCounts{}, nil) + transactCall := svc.orm.On("Transact", mock.Anything, mock.Anything) + transactCall.Run(func(args mock.Arguments) { + fn := args[1].(func(orm feeds.ORM) error) + transactCall.ReturnArguments = mock.Arguments{fn(svc.orm)} + }) }, args: argsBootstrap, wantID: idBootstrap, @@ -685,16 +715,21 @@ func Test_Service_ProposeJob(t *testing.T) { name: "Update success", before: func(svc *TestService) { svc.orm. - On("GetJobProposalByRemoteUUID", jpFluxMonitor.RemoteUUID). + On("GetJobProposalByRemoteUUID", mock.Anything, jpFluxMonitor.RemoteUUID). Return(&feeds.JobProposal{ FeedsManagerID: jpFluxMonitor.FeedsManagerID, RemoteUUID: jpFluxMonitor.RemoteUUID, Status: feeds.JobProposalStatusPending, }, nil) - svc.orm.On("ExistsSpecByJobProposalIDAndVersion", jpFluxMonitor.ID, argsFluxMonitor.Version).Return(false, nil) - svc.orm.On("UpsertJobProposal", &jpFluxMonitor, mock.Anything).Return(idFluxMonitor, nil) - svc.orm.On("CreateSpec", specFluxMonitor, mock.Anything).Return(int64(100), nil) - svc.orm.On("CountJobProposalsByStatus").Return(&feeds.JobProposalCounts{}, nil) + svc.orm.On("ExistsSpecByJobProposalIDAndVersion", mock.Anything, jpFluxMonitor.ID, argsFluxMonitor.Version).Return(false, nil) + svc.orm.On("UpsertJobProposal", mock.Anything, &jpFluxMonitor).Return(idFluxMonitor, nil) + svc.orm.On("CreateSpec", mock.Anything, specFluxMonitor).Return(int64(100), nil) + svc.orm.On("CountJobProposalsByStatus", mock.Anything).Return(&feeds.JobProposalCounts{}, nil) + transactCall := svc.orm.On("Transact", mock.Anything, mock.Anything) + transactCall.Run(func(args mock.Arguments) { + fn := args[1].(func(orm feeds.ORM) error) + transactCall.ReturnArguments = mock.Arguments{fn(svc.orm)} + }) }, args: argsFluxMonitor, wantID: idFluxMonitor, @@ -717,7 +752,7 @@ func Test_Service_ProposeJob(t *testing.T) { name: "ensure an upsert validates the job proposal belongs to the feeds manager", before: func(svc *TestService) { svc.orm. - On("GetJobProposalByRemoteUUID", jpFluxMonitor.RemoteUUID). + On("GetJobProposalByRemoteUUID", mock.Anything, jpFluxMonitor.RemoteUUID). Return(&feeds.JobProposal{ FeedsManagerID: 2, RemoteUUID: jpFluxMonitor.RemoteUUID, @@ -730,13 +765,13 @@ func Test_Service_ProposeJob(t *testing.T) { name: "spec version already exists", before: func(svc *TestService) { svc.orm. - On("GetJobProposalByRemoteUUID", jpFluxMonitor.RemoteUUID). + On("GetJobProposalByRemoteUUID", mock.Anything, jpFluxMonitor.RemoteUUID). Return(&feeds.JobProposal{ FeedsManagerID: jpFluxMonitor.FeedsManagerID, RemoteUUID: jpFluxMonitor.RemoteUUID, Status: feeds.JobProposalStatusPending, }, nil) - svc.orm.On("ExistsSpecByJobProposalIDAndVersion", jpFluxMonitor.ID, argsFluxMonitor.Version).Return(true, nil) + svc.orm.On("ExistsSpecByJobProposalIDAndVersion", mock.Anything, jpFluxMonitor.ID, argsFluxMonitor.Version).Return(true, nil) }, args: argsFluxMonitor, wantErr: "proposed job spec version already exists", @@ -744,8 +779,13 @@ func Test_Service_ProposeJob(t *testing.T) { { name: "upsert error", before: func(svc *TestService) { - svc.orm.On("GetJobProposalByRemoteUUID", jpFluxMonitor.RemoteUUID).Return(new(feeds.JobProposal), sql.ErrNoRows) - svc.orm.On("UpsertJobProposal", &jpFluxMonitor, mock.Anything).Return(int64(0), errors.New("orm error")) + svc.orm.On("GetJobProposalByRemoteUUID", mock.Anything, jpFluxMonitor.RemoteUUID).Return(new(feeds.JobProposal), sql.ErrNoRows) + svc.orm.On("UpsertJobProposal", mock.Anything, &jpFluxMonitor).Return(int64(0), errors.New("orm error")) + transactCall := svc.orm.On("Transact", mock.Anything, mock.Anything) + transactCall.Run(func(args mock.Arguments) { + fn := args[1].(func(orm feeds.ORM) error) + transactCall.ReturnArguments = mock.Arguments{fn(svc.orm)} + }) }, args: argsFluxMonitor, wantErr: "failed to upsert job proposal", @@ -753,9 +793,14 @@ func Test_Service_ProposeJob(t *testing.T) { { name: "Create spec error", before: func(svc *TestService) { - svc.orm.On("GetJobProposalByRemoteUUID", jpFluxMonitor.RemoteUUID).Return(new(feeds.JobProposal), sql.ErrNoRows) - svc.orm.On("UpsertJobProposal", &jpFluxMonitor, mock.Anything).Return(idFluxMonitor, nil) - svc.orm.On("CreateSpec", specFluxMonitor, mock.Anything).Return(int64(0), errors.New("orm error")) + svc.orm.On("GetJobProposalByRemoteUUID", mock.Anything, jpFluxMonitor.RemoteUUID).Return(new(feeds.JobProposal), sql.ErrNoRows) + svc.orm.On("UpsertJobProposal", mock.Anything, &jpFluxMonitor).Return(idFluxMonitor, nil) + svc.orm.On("CreateSpec", mock.Anything, specFluxMonitor).Return(int64(0), errors.New("orm error")) + transactCall := svc.orm.On("Transact", mock.Anything, mock.Anything) + transactCall.Run(func(args mock.Arguments) { + fn := args[1].(func(orm feeds.ORM) error) + transactCall.ReturnArguments = mock.Arguments{fn(svc.orm)} + }) }, args: argsFluxMonitor, wantErr: "failed to create spec", @@ -820,9 +865,9 @@ func Test_Service_DeleteJob(t *testing.T) { { name: "Delete success", before: func(svc *TestService) { - svc.orm.On("GetJobProposalByRemoteUUID", approved.RemoteUUID).Return(&approved, nil) - svc.orm.On("DeleteProposal", approved.ID, mock.Anything).Return(nil) - svc.orm.On("CountJobProposalsByStatus").Return(&feeds.JobProposalCounts{}, nil) + svc.orm.On("GetJobProposalByRemoteUUID", mock.Anything, approved.RemoteUUID).Return(&approved, nil) + svc.orm.On("DeleteProposal", mock.Anything, approved.ID).Return(nil) + svc.orm.On("CountJobProposalsByStatus", mock.Anything).Return(&feeds.JobProposalCounts{}, nil) }, args: args, wantID: approved.ID, @@ -831,7 +876,7 @@ func Test_Service_DeleteJob(t *testing.T) { name: "Job proposal being deleted belongs to the feeds manager", before: func(svc *TestService) { svc.orm. - On("GetJobProposalByRemoteUUID", approved.RemoteUUID). + On("GetJobProposalByRemoteUUID", mock.Anything, approved.RemoteUUID). Return(&feeds.JobProposal{ FeedsManagerID: 2, RemoteUUID: approved.RemoteUUID, @@ -844,7 +889,7 @@ func Test_Service_DeleteJob(t *testing.T) { { name: "Get proposal error", before: func(svc *TestService) { - svc.orm.On("GetJobProposalByRemoteUUID", approved.RemoteUUID).Return(nil, errors.New("orm error")) + svc.orm.On("GetJobProposalByRemoteUUID", mock.Anything, approved.RemoteUUID).Return(nil, errors.New("orm error")) }, args: args, wantErr: "GetJobProposalByRemoteUUID failed", @@ -852,7 +897,7 @@ func Test_Service_DeleteJob(t *testing.T) { { name: "No proposal error", before: func(svc *TestService) { - svc.orm.On("GetJobProposalByRemoteUUID", approved.RemoteUUID).Return(nil, sql.ErrNoRows) + svc.orm.On("GetJobProposalByRemoteUUID", mock.Anything, approved.RemoteUUID).Return(nil, sql.ErrNoRows) }, args: args, wantErr: "GetJobProposalByRemoteUUID did not find any proposals to delete", @@ -860,8 +905,8 @@ func Test_Service_DeleteJob(t *testing.T) { { name: "Delete proposal error", before: func(svc *TestService) { - svc.orm.On("GetJobProposalByRemoteUUID", approved.RemoteUUID).Return(&approved, nil) - svc.orm.On("DeleteProposal", approved.ID, mock.Anything).Return(errors.New("orm error")) + svc.orm.On("GetJobProposalByRemoteUUID", mock.Anything, approved.RemoteUUID).Return(&approved, nil) + svc.orm.On("DeleteProposal", mock.Anything, approved.ID).Return(errors.New("orm error")) }, args: args, wantErr: "DeleteProposal failed", @@ -960,10 +1005,10 @@ answer1 [type=median index=0]; { name: "Revoke success when latest spec status is pending", before: func(svc *TestService) { - svc.orm.On("GetJobProposalByRemoteUUID", pendingProposal.RemoteUUID).Return(pendingProposal, nil) - svc.orm.On("GetLatestSpec", pendingSpec.JobProposalID).Return(pendingSpec, nil) - svc.orm.On("RevokeSpec", pendingSpec.ID, mock.Anything).Return(nil) - svc.orm.On("CountJobProposalsByStatus").Return(&feeds.JobProposalCounts{}, nil) + svc.orm.On("GetJobProposalByRemoteUUID", mock.Anything, pendingProposal.RemoteUUID).Return(pendingProposal, nil) + svc.orm.On("GetLatestSpec", mock.Anything, pendingSpec.JobProposalID).Return(pendingSpec, nil) + svc.orm.On("RevokeSpec", mock.Anything, pendingSpec.ID).Return(nil) + svc.orm.On("CountJobProposalsByStatus", mock.Anything).Return(&feeds.JobProposalCounts{}, nil) }, args: args, wantID: pendingProposal.ID, @@ -971,16 +1016,16 @@ answer1 [type=median index=0]; { name: "Revoke success when latest spec status is cancelled", before: func(svc *TestService) { - svc.orm.On("GetJobProposalByRemoteUUID", pendingProposal.RemoteUUID).Return(pendingProposal, nil) - svc.orm.On("GetLatestSpec", pendingSpec.JobProposalID).Return(&feeds.JobProposalSpec{ + svc.orm.On("GetJobProposalByRemoteUUID", mock.Anything, pendingProposal.RemoteUUID).Return(pendingProposal, nil) + svc.orm.On("GetLatestSpec", mock.Anything, pendingSpec.JobProposalID).Return(&feeds.JobProposalSpec{ ID: 20, Status: feeds.SpecStatusCancelled, JobProposalID: pendingProposal.ID, Version: 1, Definition: defn, }, nil) - svc.orm.On("RevokeSpec", pendingSpec.ID, mock.Anything).Return(nil) - svc.orm.On("CountJobProposalsByStatus").Return(&feeds.JobProposalCounts{}, nil) + svc.orm.On("RevokeSpec", mock.Anything, pendingSpec.ID).Return(nil) + svc.orm.On("CountJobProposalsByStatus", mock.Anything).Return(&feeds.JobProposalCounts{}, nil) }, args: args, wantID: pendingProposal.ID, @@ -989,7 +1034,7 @@ answer1 [type=median index=0]; name: "Job proposal being revoked belongs to the feeds manager", before: func(svc *TestService) { svc.orm. - On("GetJobProposalByRemoteUUID", pendingProposal.RemoteUUID). + On("GetJobProposalByRemoteUUID", mock.Anything, pendingProposal.RemoteUUID). Return(&feeds.JobProposal{ FeedsManagerID: 2, RemoteUUID: pendingProposal.RemoteUUID, @@ -1002,7 +1047,7 @@ answer1 [type=median index=0]; { name: "Get proposal error", before: func(svc *TestService) { - svc.orm.On("GetJobProposalByRemoteUUID", pendingProposal.RemoteUUID).Return(nil, errors.New("orm error")) + svc.orm.On("GetJobProposalByRemoteUUID", mock.Anything, pendingProposal.RemoteUUID).Return(nil, errors.New("orm error")) }, args: args, wantErr: "GetJobProposalByRemoteUUID failed", @@ -1010,7 +1055,7 @@ answer1 [type=median index=0]; { name: "No proposal error", before: func(svc *TestService) { - svc.orm.On("GetJobProposalByRemoteUUID", pendingProposal.RemoteUUID).Return(nil, sql.ErrNoRows) + svc.orm.On("GetJobProposalByRemoteUUID", mock.Anything, pendingProposal.RemoteUUID).Return(nil, sql.ErrNoRows) }, args: args, wantErr: "GetJobProposalByRemoteUUID did not find any proposals to revoke", @@ -1018,8 +1063,8 @@ answer1 [type=median index=0]; { name: "Get latest spec error", before: func(svc *TestService) { - svc.orm.On("GetJobProposalByRemoteUUID", pendingProposal.RemoteUUID).Return(pendingProposal, nil) - svc.orm.On("GetLatestSpec", pendingSpec.JobProposalID).Return(nil, sql.ErrNoRows) + svc.orm.On("GetJobProposalByRemoteUUID", mock.Anything, pendingProposal.RemoteUUID).Return(pendingProposal, nil) + svc.orm.On("GetLatestSpec", mock.Anything, pendingSpec.JobProposalID).Return(nil, sql.ErrNoRows) }, args: args, wantErr: "GetLatestSpec failed to get latest spec", @@ -1027,8 +1072,8 @@ answer1 [type=median index=0]; { name: "Not revokable due to spec status approved", before: func(svc *TestService) { - svc.orm.On("GetJobProposalByRemoteUUID", pendingProposal.RemoteUUID).Return(pendingProposal, nil) - svc.orm.On("GetLatestSpec", pendingSpec.JobProposalID).Return(&feeds.JobProposalSpec{ + svc.orm.On("GetJobProposalByRemoteUUID", mock.Anything, pendingProposal.RemoteUUID).Return(pendingProposal, nil) + svc.orm.On("GetLatestSpec", mock.Anything, pendingSpec.JobProposalID).Return(&feeds.JobProposalSpec{ ID: 20, Status: feeds.SpecStatusApproved, JobProposalID: pendingProposal.ID, @@ -1042,8 +1087,8 @@ answer1 [type=median index=0]; { name: "Not revokable due to spec status rejected", before: func(svc *TestService) { - svc.orm.On("GetJobProposalByRemoteUUID", pendingProposal.RemoteUUID).Return(pendingProposal, nil) - svc.orm.On("GetLatestSpec", pendingSpec.JobProposalID).Return(&feeds.JobProposalSpec{ + svc.orm.On("GetJobProposalByRemoteUUID", mock.Anything, pendingProposal.RemoteUUID).Return(pendingProposal, nil) + svc.orm.On("GetLatestSpec", mock.Anything, pendingSpec.JobProposalID).Return(&feeds.JobProposalSpec{ ID: 20, Status: feeds.SpecStatusRejected, JobProposalID: pendingProposal.ID, @@ -1057,8 +1102,8 @@ answer1 [type=median index=0]; { name: "Not revokable due to spec status already revoked", before: func(svc *TestService) { - svc.orm.On("GetJobProposalByRemoteUUID", pendingProposal.RemoteUUID).Return(pendingProposal, nil) - svc.orm.On("GetLatestSpec", pendingSpec.JobProposalID).Return(&feeds.JobProposalSpec{ + svc.orm.On("GetJobProposalByRemoteUUID", mock.Anything, pendingProposal.RemoteUUID).Return(pendingProposal, nil) + svc.orm.On("GetLatestSpec", mock.Anything, pendingSpec.JobProposalID).Return(&feeds.JobProposalSpec{ ID: 20, Status: feeds.SpecStatusRevoked, JobProposalID: pendingProposal.ID, @@ -1072,13 +1117,13 @@ answer1 [type=median index=0]; { name: "Not revokable due to proposal status deleted", before: func(svc *TestService) { - svc.orm.On("GetJobProposalByRemoteUUID", pendingProposal.RemoteUUID).Return(&feeds.JobProposal{ + svc.orm.On("GetJobProposalByRemoteUUID", mock.Anything, pendingProposal.RemoteUUID).Return(&feeds.JobProposal{ ID: 1, FeedsManagerID: 1, RemoteUUID: remoteUUID, Status: feeds.JobProposalStatusDeleted, }, nil) - svc.orm.On("GetLatestSpec", pendingSpec.JobProposalID).Return(pendingSpec, nil) + svc.orm.On("GetLatestSpec", mock.Anything, pendingSpec.JobProposalID).Return(pendingSpec, nil) }, args: args, wantErr: "only pending job specs can be revoked", @@ -1086,9 +1131,9 @@ answer1 [type=median index=0]; { name: "Revoke proposal error", before: func(svc *TestService) { - svc.orm.On("GetJobProposalByRemoteUUID", pendingProposal.RemoteUUID).Return(pendingProposal, nil) - svc.orm.On("GetLatestSpec", pendingSpec.JobProposalID).Return(pendingSpec, nil) - svc.orm.On("RevokeSpec", pendingSpec.ID, mock.Anything).Return(errors.New("orm error")) + svc.orm.On("GetJobProposalByRemoteUUID", mock.Anything, pendingProposal.RemoteUUID).Return(pendingProposal, nil) + svc.orm.On("GetLatestSpec", mock.Anything, pendingSpec.JobProposalID).Return(pendingSpec, nil) + svc.orm.On("RevokeSpec", mock.Anything, pendingSpec.ID).Return(errors.New("orm error")) }, args: args, wantErr: "RevokeSpec failed", @@ -1168,7 +1213,7 @@ func Test_Service_SyncNodeInfo(t *testing.T) { svc := setupTestService(t) svc.connMgr.On("GetClient", mgr.ID).Return(svc.fmsClient, nil) - svc.orm.On("ListChainConfigsByManagerIDs", []int64{mgr.ID}).Return(chainConfigs, nil) + svc.orm.On("ListChainConfigsByManagerIDs", mock.Anything, []int64{mgr.ID}).Return(chainConfigs, nil) // OCR1 key fetching svc.p2pKeystore.On("Get", p2pKey.PeerID()).Return(p2pKey, nil) @@ -1227,7 +1272,7 @@ func Test_Service_IsJobManaged(t *testing.T) { ctx := testutils.Context(t) jobID := int64(1) - svc.orm.On("IsJobManaged", jobID, mock.Anything).Return(true, nil) + svc.orm.On("IsJobManaged", mock.Anything, jobID).Return(true, nil) isManaged, err := svc.IsJobManaged(ctx, jobID) require.NoError(t, err) @@ -1236,6 +1281,7 @@ func Test_Service_IsJobManaged(t *testing.T) { func Test_Service_ListJobProposals(t *testing.T) { t.Parallel() + ctx := testutils.Context(t) var ( jp = feeds.JobProposal{} @@ -1243,10 +1289,10 @@ func Test_Service_ListJobProposals(t *testing.T) { ) svc := setupTestService(t) - svc.orm.On("ListJobProposals"). + svc.orm.On("ListJobProposals", mock.Anything). Return(jps, nil) - actual, err := svc.ListJobProposals() + actual, err := svc.ListJobProposals(ctx) require.NoError(t, err) assert.Equal(t, actual, jps) @@ -1254,6 +1300,7 @@ func Test_Service_ListJobProposals(t *testing.T) { func Test_Service_ListJobProposalsByManagersIDs(t *testing.T) { t.Parallel() + ctx := testutils.Context(t) var ( jp = feeds.JobProposal{} @@ -1262,10 +1309,10 @@ func Test_Service_ListJobProposalsByManagersIDs(t *testing.T) { ) svc := setupTestService(t) - svc.orm.On("ListJobProposalsByManagersIDs", fmIDs). + svc.orm.On("ListJobProposalsByManagersIDs", mock.Anything, fmIDs). Return(jps, nil) - actual, err := svc.ListJobProposalsByManagersIDs(fmIDs) + actual, err := svc.ListJobProposalsByManagersIDs(ctx, fmIDs) require.NoError(t, err) assert.Equal(t, actual, jps) @@ -1273,6 +1320,7 @@ func Test_Service_ListJobProposalsByManagersIDs(t *testing.T) { func Test_Service_GetJobProposal(t *testing.T) { t.Parallel() + ctx := testutils.Context(t) var ( id = int64(1) @@ -1280,10 +1328,10 @@ func Test_Service_GetJobProposal(t *testing.T) { ) svc := setupTestService(t) - svc.orm.On("GetJobProposal", id). + svc.orm.On("GetJobProposal", mock.Anything, id). Return(&ms, nil) - actual, err := svc.GetJobProposal(id) + actual, err := svc.GetJobProposal(ctx, id) require.NoError(t, err) assert.Equal(t, actual, &ms) @@ -1320,12 +1368,12 @@ func Test_Service_CancelSpec(t *testing.T) { name: "success", before: func(svc *TestService) { svc.connMgr.On("GetClient", jp.FeedsManagerID).Return(svc.fmsClient, nil) - svc.orm.On("GetSpec", spec.ID, mock.Anything).Return(spec, nil) - svc.orm.On("GetJobProposal", jp.ID, mock.Anything).Return(jp, nil) + svc.orm.On("GetSpec", mock.Anything, spec.ID).Return(spec, nil) + svc.orm.On("GetJobProposal", mock.Anything, jp.ID).Return(jp, nil) - svc.orm.On("CancelSpec", spec.ID, mock.Anything).Return(nil) - svc.jobORM.On("FindJobByExternalJobID", externalJobID, mock.Anything).Return(j, nil) - svc.spawner.On("DeleteJob", j.ID, mock.Anything).Return(nil) + svc.orm.On("CancelSpec", mock.Anything, spec.ID).Return(nil) + svc.jobORM.On("FindJobByExternalJobID", mock.Anything, externalJobID).Return(j, nil) + svc.spawner.On("DeleteJob", mock.Anything, mock.Anything, j.ID).Return(nil) svc.fmsClient.On("CancelledJob", mock.MatchedBy(func(ctx context.Context) bool { return true }), @@ -1334,7 +1382,9 @@ func Test_Service_CancelSpec(t *testing.T) { Version: int64(spec.Version), }, ).Return(&proto.CancelledJobResponse{}, nil) - svc.orm.On("CountJobProposalsByStatus").Return(&feeds.JobProposalCounts{}, nil) + svc.orm.On("CountJobProposalsByStatus", mock.Anything).Return(&feeds.JobProposalCounts{}, nil) + svc.orm.On("WithDataSource", mock.Anything).Return(feeds.ORM(svc.orm)) + svc.jobORM.On("WithDataSource", mock.Anything).Return(job.ORM(svc.jobORM)) }, specID: spec.ID, }, @@ -1342,14 +1392,14 @@ func Test_Service_CancelSpec(t *testing.T) { name: "success without external job id", before: func(svc *TestService) { svc.connMgr.On("GetClient", jp.FeedsManagerID).Return(svc.fmsClient, nil) - svc.orm.On("GetSpec", spec.ID, mock.Anything).Return(spec, nil) - svc.orm.On("GetJobProposal", jp.ID, mock.Anything).Return(&feeds.JobProposal{ + svc.orm.On("GetSpec", mock.Anything, spec.ID).Return(spec, nil) + svc.orm.On("GetJobProposal", mock.Anything, jp.ID).Return(&feeds.JobProposal{ ID: 1, RemoteUUID: externalJobID, FeedsManagerID: 100, }, nil) - svc.orm.On("CancelSpec", spec.ID, mock.Anything).Return(nil) + svc.orm.On("CancelSpec", mock.Anything, spec.ID).Return(nil) svc.fmsClient.On("CancelledJob", mock.MatchedBy(func(ctx context.Context) bool { return true }), &proto.CancelledJobRequest{ @@ -1357,7 +1407,9 @@ func Test_Service_CancelSpec(t *testing.T) { Version: int64(spec.Version), }, ).Return(&proto.CancelledJobResponse{}, nil) - svc.orm.On("CountJobProposalsByStatus").Return(&feeds.JobProposalCounts{}, nil) + svc.orm.On("CountJobProposalsByStatus", mock.Anything).Return(&feeds.JobProposalCounts{}, nil) + svc.orm.On("WithDataSource", mock.Anything).Return(feeds.ORM(svc.orm)) + svc.jobORM.On("WithDataSource", mock.Anything).Return(job.ORM(svc.jobORM)) }, specID: spec.ID, }, @@ -1365,11 +1417,11 @@ func Test_Service_CancelSpec(t *testing.T) { name: "success without jobs", before: func(svc *TestService) { svc.connMgr.On("GetClient", jp.FeedsManagerID).Return(svc.fmsClient, nil) - svc.orm.On("GetSpec", spec.ID, mock.Anything).Return(spec, nil) - svc.orm.On("GetJobProposal", jp.ID, mock.Anything).Return(jp, nil) + svc.orm.On("GetSpec", mock.Anything, spec.ID).Return(spec, nil) + svc.orm.On("GetJobProposal", mock.Anything, jp.ID).Return(jp, nil) - svc.orm.On("CancelSpec", spec.ID, mock.Anything).Return(nil) - svc.jobORM.On("FindJobByExternalJobID", externalJobID, mock.Anything).Return(job.Job{}, sql.ErrNoRows) + svc.orm.On("CancelSpec", mock.Anything, spec.ID).Return(nil) + svc.jobORM.On("FindJobByExternalJobID", mock.Anything, externalJobID).Return(job.Job{}, sql.ErrNoRows) svc.fmsClient.On("CancelledJob", mock.MatchedBy(func(ctx context.Context) bool { return true }), &proto.CancelledJobRequest{ @@ -1377,14 +1429,16 @@ func Test_Service_CancelSpec(t *testing.T) { Version: int64(spec.Version), }, ).Return(&proto.CancelledJobResponse{}, nil) - svc.orm.On("CountJobProposalsByStatus").Return(&feeds.JobProposalCounts{}, nil) + svc.orm.On("CountJobProposalsByStatus", mock.Anything).Return(&feeds.JobProposalCounts{}, nil) + svc.orm.On("WithDataSource", mock.Anything).Return(feeds.ORM(svc.orm)) + svc.jobORM.On("WithDataSource", mock.Anything).Return(job.ORM(svc.jobORM)) }, specID: spec.ID, }, { name: "spec does not exist", before: func(svc *TestService) { - svc.orm.On("GetSpec", spec.ID, mock.Anything).Return(nil, errors.New("Not Found")) + svc.orm.On("GetSpec", mock.Anything, spec.ID).Return(nil, errors.New("Not Found")) }, specID: spec.ID, wantErr: "orm: job proposal spec: Not Found", @@ -1396,7 +1450,7 @@ func Test_Service_CancelSpec(t *testing.T) { ID: spec.ID, Status: feeds.SpecStatusPending, } - svc.orm.On("GetSpec", pspec.ID, mock.Anything).Return(pspec, nil) + svc.orm.On("GetSpec", mock.Anything, pspec.ID, mock.Anything).Return(pspec, nil) }, specID: spec.ID, wantErr: "must be an approved job proposal spec", @@ -1404,8 +1458,8 @@ func Test_Service_CancelSpec(t *testing.T) { { name: "job proposal does not exist", before: func(svc *TestService) { - svc.orm.On("GetSpec", spec.ID, mock.Anything).Return(spec, nil) - svc.orm.On("GetJobProposal", jp.ID, mock.Anything).Return(nil, errors.New("Not Found")) + svc.orm.On("GetSpec", mock.Anything, spec.ID).Return(spec, nil) + svc.orm.On("GetJobProposal", mock.Anything, jp.ID).Return(nil, errors.New("Not Found")) }, specID: spec.ID, wantErr: "orm: job proposal: Not Found", @@ -1413,8 +1467,8 @@ func Test_Service_CancelSpec(t *testing.T) { { name: "rpc client not connected", before: func(svc *TestService) { - svc.orm.On("GetSpec", spec.ID, mock.Anything).Return(spec, nil) - svc.orm.On("GetJobProposal", jp.ID, mock.Anything).Return(jp, nil) + svc.orm.On("GetSpec", mock.Anything, spec.ID).Return(spec, nil) + svc.orm.On("GetJobProposal", mock.Anything, jp.ID).Return(jp, nil) svc.connMgr.On("GetClient", jp.FeedsManagerID).Return(nil, errors.New("Not Connected")) }, specID: spec.ID, @@ -1424,9 +1478,11 @@ func Test_Service_CancelSpec(t *testing.T) { name: "cancel spec orm fails", before: func(svc *TestService) { svc.connMgr.On("GetClient", jp.FeedsManagerID).Return(svc.fmsClient, nil) - svc.orm.On("GetSpec", spec.ID, mock.Anything).Return(spec, nil) - svc.orm.On("GetJobProposal", jp.ID, mock.Anything).Return(jp, nil) - svc.orm.On("CancelSpec", spec.ID, mock.Anything).Return(errors.New("failure")) + svc.orm.On("GetSpec", mock.Anything, spec.ID).Return(spec, nil) + svc.orm.On("GetJobProposal", mock.Anything, jp.ID).Return(jp, nil) + svc.orm.On("CancelSpec", mock.Anything, spec.ID).Return(errors.New("failure")) + svc.orm.On("WithDataSource", mock.Anything).Return(feeds.ORM(svc.orm)) + svc.jobORM.On("WithDataSource", mock.Anything).Return(job.ORM(svc.jobORM)) }, specID: spec.ID, wantErr: "failure", @@ -1435,11 +1491,13 @@ func Test_Service_CancelSpec(t *testing.T) { name: "find by external uuid orm fails", before: func(svc *TestService) { svc.connMgr.On("GetClient", jp.FeedsManagerID).Return(svc.fmsClient, nil) - svc.orm.On("GetSpec", spec.ID, mock.Anything).Return(spec, nil) - svc.orm.On("GetJobProposal", jp.ID, mock.Anything).Return(jp, nil) + svc.orm.On("GetSpec", mock.Anything, spec.ID).Return(spec, nil) + svc.orm.On("GetJobProposal", mock.Anything, jp.ID).Return(jp, nil) - svc.orm.On("CancelSpec", spec.ID, mock.Anything).Return(nil) - svc.jobORM.On("FindJobByExternalJobID", externalJobID, mock.Anything).Return(job.Job{}, errors.New("failure")) + svc.orm.On("CancelSpec", mock.Anything, spec.ID).Return(nil) + svc.jobORM.On("FindJobByExternalJobID", mock.Anything, externalJobID).Return(job.Job{}, errors.New("failure")) + svc.orm.On("WithDataSource", mock.Anything).Return(feeds.ORM(svc.orm)) + svc.jobORM.On("WithDataSource", mock.Anything).Return(job.ORM(svc.jobORM)) }, specID: spec.ID, wantErr: "FindJobByExternalJobID failed: failure", @@ -1448,12 +1506,14 @@ func Test_Service_CancelSpec(t *testing.T) { name: "delete job fails", before: func(svc *TestService) { svc.connMgr.On("GetClient", jp.FeedsManagerID).Return(svc.fmsClient, nil) - svc.orm.On("GetSpec", spec.ID, mock.Anything).Return(spec, nil) - svc.orm.On("GetJobProposal", jp.ID, mock.Anything).Return(jp, nil) + svc.orm.On("GetSpec", mock.Anything, spec.ID).Return(spec, nil) + svc.orm.On("GetJobProposal", mock.Anything, jp.ID).Return(jp, nil) - svc.orm.On("CancelSpec", spec.ID, mock.Anything).Return(nil) - svc.jobORM.On("FindJobByExternalJobID", externalJobID, mock.Anything).Return(j, nil) - svc.spawner.On("DeleteJob", j.ID, mock.Anything).Return(errors.New("failure")) + svc.orm.On("CancelSpec", mock.Anything, spec.ID).Return(nil) + svc.jobORM.On("FindJobByExternalJobID", mock.Anything, externalJobID).Return(j, nil) + svc.spawner.On("DeleteJob", mock.Anything, mock.Anything, j.ID).Return(errors.New("failure")) + svc.orm.On("WithDataSource", mock.Anything).Return(feeds.ORM(svc.orm)) + svc.jobORM.On("WithDataSource", mock.Anything).Return(job.ORM(svc.jobORM)) }, specID: spec.ID, wantErr: "DeleteJob failed: failure", @@ -1462,12 +1522,12 @@ func Test_Service_CancelSpec(t *testing.T) { name: "cancelled job rpc call fails", before: func(svc *TestService) { svc.connMgr.On("GetClient", jp.FeedsManagerID).Return(svc.fmsClient, nil) - svc.orm.On("GetSpec", spec.ID, mock.Anything).Return(spec, nil) - svc.orm.On("GetJobProposal", jp.ID, mock.Anything).Return(jp, nil) + svc.orm.On("GetSpec", mock.Anything, spec.ID).Return(spec, nil) + svc.orm.On("GetJobProposal", mock.Anything, jp.ID).Return(jp, nil) - svc.orm.On("CancelSpec", spec.ID, mock.Anything).Return(nil) - svc.jobORM.On("FindJobByExternalJobID", externalJobID, mock.Anything).Return(j, nil) - svc.spawner.On("DeleteJob", j.ID, mock.Anything).Return(nil) + svc.orm.On("CancelSpec", mock.Anything, spec.ID).Return(nil) + svc.jobORM.On("FindJobByExternalJobID", mock.Anything, externalJobID).Return(j, nil) + svc.spawner.On("DeleteJob", mock.Anything, mock.Anything, j.ID).Return(nil) svc.fmsClient.On("CancelledJob", mock.MatchedBy(func(ctx context.Context) bool { return true }), @@ -1476,6 +1536,8 @@ func Test_Service_CancelSpec(t *testing.T) { Version: int64(spec.Version), }, ).Return(nil, errors.New("failure")) + svc.orm.On("WithDataSource", mock.Anything).Return(feeds.ORM(svc.orm)) + svc.jobORM.On("WithDataSource", mock.Anything).Return(job.ORM(svc.jobORM)) }, specID: spec.ID, wantErr: "failure", @@ -1509,6 +1571,7 @@ func Test_Service_CancelSpec(t *testing.T) { func Test_Service_GetSpec(t *testing.T) { t.Parallel() + ctx := testutils.Context(t) var ( id = int64(1) @@ -1516,10 +1579,10 @@ func Test_Service_GetSpec(t *testing.T) { ) svc := setupTestService(t) - svc.orm.On("GetSpec", id). + svc.orm.On("GetSpec", mock.Anything, id). Return(&spec, nil) - actual, err := svc.GetSpec(id) + actual, err := svc.GetSpec(ctx, id) require.NoError(t, err) assert.Equal(t, &spec, actual) @@ -1527,6 +1590,7 @@ func Test_Service_GetSpec(t *testing.T) { func Test_Service_ListSpecsByJobProposalIDs(t *testing.T) { t.Parallel() + ctx := testutils.Context(t) var ( id = int64(1) @@ -1536,10 +1600,10 @@ func Test_Service_ListSpecsByJobProposalIDs(t *testing.T) { ) svc := setupTestService(t) - svc.orm.On("ListSpecsByJobProposalIDs", []int64{jpID}). + svc.orm.On("ListSpecsByJobProposalIDs", mock.Anything, []int64{jpID}). Return(specs, nil) - actual, err := svc.ListSpecsByJobProposalIDs([]int64{jpID}) + actual, err := svc.ListSpecsByJobProposalIDs(ctx, []int64{jpID}) require.NoError(t, err) assert.Equal(t, specs, actual) @@ -1624,26 +1688,27 @@ answer1 [type=median index=0]; httpTimeout: commonconfig.MustNewDuration(1 * time.Minute), before: func(svc *TestService) { svc.connMgr.On("GetClient", jp.FeedsManagerID).Return(svc.fmsClient, nil) - svc.orm.EXPECT().GetSpec(spec.ID, mock.Anything).Return(spec, nil) - svc.orm.EXPECT().GetJobProposal(jp.ID, mock.Anything).Return(jp, nil) + svc.orm.EXPECT().GetSpec(mock.Anything, spec.ID).Return(spec, nil) + svc.orm.EXPECT().GetJobProposal(mock.Anything, jp.ID).Return(jp, nil) svc.jobORM.On("AssertBridgesExist", mock.Anything, mock.IsType(pipeline.Pipeline{})).Return(nil) - svc.jobORM.On("FindJobByExternalJobID", externalJobID, mock.Anything).Return(job.Job{}, sql.ErrNoRows) - svc.jobORM.On("FindJobIDByAddress", address, evmChainID, mock.Anything).Return(int32(0), sql.ErrNoRows) + svc.jobORM.On("FindJobByExternalJobID", mock.Anything, externalJobID).Return(job.Job{}, sql.ErrNoRows) + svc.jobORM.On("FindJobIDByAddress", mock.Anything, address, evmChainID, mock.Anything).Return(int32(0), sql.ErrNoRows) svc.spawner. On("CreateJob", + mock.Anything, + mock.Anything, mock.MatchedBy(func(j *job.Job) bool { return j.Name.String == "LINK / ETH | version 3 | contract 0x0000000000000000000000000000000000000000" }), - mock.Anything, ). - Run(func(args mock.Arguments) { (args.Get(0).(*job.Job)).ID = 1 }). + Run(func(args mock.Arguments) { (args.Get(2).(*job.Job)).ID = 1 }). Return(nil) svc.orm.On("ApproveSpec", + mock.Anything, spec.ID, mock.IsType(uuid.UUID{}), - mock.Anything, ).Return(nil) svc.fmsClient.On("ApprovedJob", mock.MatchedBy(func(ctx context.Context) bool { return true }), @@ -1652,7 +1717,9 @@ answer1 [type=median index=0]; Version: int64(spec.Version), }, ).Return(&proto.ApprovedJobResponse{}, nil) - svc.orm.On("CountJobProposalsByStatus").Return(&feeds.JobProposalCounts{}, nil) + svc.orm.On("CountJobProposalsByStatus", mock.Anything).Return(&feeds.JobProposalCounts{}, nil) + svc.orm.On("WithDataSource", mock.Anything).Return(feeds.ORM(svc.orm)) + svc.jobORM.On("WithDataSource", mock.Anything).Return(job.ORM(svc.jobORM)) }, id: spec.ID, force: false, @@ -1662,27 +1729,28 @@ answer1 [type=median index=0]; httpTimeout: commonconfig.MustNewDuration(1 * time.Minute), before: func(svc *TestService) { svc.connMgr.On("GetClient", jp.FeedsManagerID).Return(svc.fmsClient, nil) - svc.orm.On("GetSpec", cancelledSpec.ID, mock.Anything).Return(cancelledSpec, nil) - svc.orm.On("GetJobProposal", jp.ID, mock.Anything).Return(jp, nil) - svc.orm.On("GetLatestSpec", cancelledSpec.JobProposalID).Return(cancelledSpec, nil) + svc.orm.On("GetSpec", mock.Anything, cancelledSpec.ID).Return(cancelledSpec, nil) + svc.orm.On("GetJobProposal", mock.Anything, jp.ID).Return(jp, nil) + svc.orm.On("GetLatestSpec", mock.Anything, cancelledSpec.JobProposalID).Return(cancelledSpec, nil) svc.jobORM.On("AssertBridgesExist", mock.Anything, mock.IsType(pipeline.Pipeline{})).Return(nil) - svc.jobORM.On("FindJobByExternalJobID", externalJobID, mock.Anything).Return(job.Job{}, sql.ErrNoRows) - svc.jobORM.On("FindJobIDByAddress", address, evmChainID, mock.Anything).Return(int32(0), sql.ErrNoRows) + svc.jobORM.On("FindJobByExternalJobID", mock.Anything, externalJobID).Return(job.Job{}, sql.ErrNoRows) + svc.jobORM.On("FindJobIDByAddress", mock.Anything, address, evmChainID, mock.Anything).Return(int32(0), sql.ErrNoRows) svc.spawner. On("CreateJob", + mock.Anything, + mock.Anything, mock.MatchedBy(func(j *job.Job) bool { return j.Name.String == "LINK / ETH | version 3 | contract 0x0000000000000000000000000000000000000000" }), - mock.Anything, ). - Run(func(args mock.Arguments) { (args.Get(0).(*job.Job)).ID = 1 }). + Run(func(args mock.Arguments) { (args.Get(2).(*job.Job)).ID = 1 }). Return(nil) svc.orm.On("ApproveSpec", + mock.Anything, cancelledSpec.ID, externalJobID, - mock.Anything, ).Return(nil) svc.fmsClient.On("ApprovedJob", mock.MatchedBy(func(ctx context.Context) bool { return true }), @@ -1691,7 +1759,9 @@ answer1 [type=median index=0]; Version: int64(spec.Version), }, ).Return(&proto.ApprovedJobResponse{}, nil) - svc.orm.On("CountJobProposalsByStatus").Return(&feeds.JobProposalCounts{}, nil) + svc.orm.On("CountJobProposalsByStatus", mock.Anything).Return(&feeds.JobProposalCounts{}, nil) + svc.orm.On("WithDataSource", mock.Anything).Return(feeds.ORM(svc.orm)) + svc.jobORM.On("WithDataSource", mock.Anything).Return(job.ORM(svc.jobORM)) }, id: cancelledSpec.ID, force: false, @@ -1701,8 +1771,8 @@ answer1 [type=median index=0]; httpTimeout: commonconfig.MustNewDuration(1 * time.Minute), before: func(svc *TestService) { svc.connMgr.On("GetClient", jp.FeedsManagerID).Return(svc.fmsClient, nil) - svc.orm.EXPECT().GetSpec(spec.ID, mock.Anything).Return(spec2, nil) - svc.orm.EXPECT().GetJobProposal(jp.ID, mock.Anything).Return(jp, nil) + svc.orm.EXPECT().GetSpec(mock.Anything, spec.ID).Return(spec2, nil) + svc.orm.EXPECT().GetJobProposal(mock.Anything, jp.ID).Return(jp, nil) }, id: spec.ID, force: false, @@ -1711,8 +1781,8 @@ answer1 [type=median index=0]; { name: "failed due to proposal being revoked", before: func(svc *TestService) { - svc.orm.On("GetSpec", spec.ID, mock.Anything).Return(spec, nil) - svc.orm.On("GetJobProposal", jp.ID, mock.Anything).Return(&feeds.JobProposal{ + svc.orm.On("GetSpec", mock.Anything, spec.ID).Return(spec, nil) + svc.orm.On("GetJobProposal", mock.Anything, jp.ID).Return(&feeds.JobProposal{ ID: 1, Status: feeds.JobProposalStatusRevoked, }, nil) @@ -1724,8 +1794,8 @@ answer1 [type=median index=0]; { name: "failed due to proposal being deleted", before: func(svc *TestService) { - svc.orm.On("GetSpec", spec.ID, mock.Anything).Return(spec, nil) - svc.orm.On("GetJobProposal", jp.ID, mock.Anything).Return(&feeds.JobProposal{ + svc.orm.On("GetSpec", mock.Anything, spec.ID).Return(spec, nil) + svc.orm.On("GetJobProposal", mock.Anything, jp.ID).Return(&feeds.JobProposal{ ID: jp.ID, Status: feeds.JobProposalStatusDeleted, }, nil) @@ -1742,8 +1812,8 @@ answer1 [type=median index=0]; Status: feeds.SpecStatusApproved, JobProposalID: jp.ID, } - svc.orm.On("GetSpec", aspec.ID, mock.Anything).Return(aspec, nil) - svc.orm.On("GetJobProposal", jp.ID, mock.Anything).Return(jp, nil) + svc.orm.On("GetSpec", mock.Anything, aspec.ID, mock.Anything).Return(aspec, nil) + svc.orm.On("GetJobProposal", mock.Anything, jp.ID).Return(jp, nil) }, id: spec.ID, force: false, @@ -1752,8 +1822,8 @@ answer1 [type=median index=0]; { name: "rejected spec fail", before: func(svc *TestService) { - svc.orm.On("GetSpec", cancelledSpec.ID, mock.Anything).Return(rejectedSpec, nil) - svc.orm.On("GetJobProposal", jp.ID, mock.Anything).Return(jp, nil) + svc.orm.On("GetSpec", mock.Anything, cancelledSpec.ID, mock.Anything).Return(rejectedSpec, nil) + svc.orm.On("GetJobProposal", mock.Anything, jp.ID).Return(jp, nil) }, id: rejectedSpec.ID, force: false, @@ -1762,9 +1832,9 @@ answer1 [type=median index=0]; { name: "cancelled spec failed not latest spec", before: func(svc *TestService) { - svc.orm.On("GetSpec", cancelledSpec.ID, mock.Anything).Return(cancelledSpec, nil) - svc.orm.On("GetJobProposal", jp.ID, mock.Anything).Return(jp, nil) - svc.orm.On("GetLatestSpec", cancelledSpec.JobProposalID).Return(&feeds.JobProposalSpec{ + svc.orm.On("GetSpec", mock.Anything, cancelledSpec.ID, mock.Anything).Return(cancelledSpec, nil) + svc.orm.On("GetJobProposal", mock.Anything, jp.ID).Return(jp, nil) + svc.orm.On("GetLatestSpec", mock.Anything, cancelledSpec.JobProposalID).Return(&feeds.JobProposalSpec{ ID: 21, Status: feeds.SpecStatusPending, JobProposalID: jp.ID, @@ -1781,11 +1851,13 @@ answer1 [type=median index=0]; httpTimeout: commonconfig.MustNewDuration(1 * time.Minute), before: func(svc *TestService) { svc.connMgr.On("GetClient", jp.FeedsManagerID).Return(svc.fmsClient, nil) - svc.orm.On("GetSpec", spec.ID, mock.Anything).Return(spec, nil) - svc.orm.On("GetJobProposal", jp.ID, mock.Anything).Return(jp, nil) + svc.orm.On("GetSpec", mock.Anything, spec.ID).Return(spec, nil) + svc.orm.On("GetJobProposal", mock.Anything, jp.ID).Return(jp, nil) svc.jobORM.On("AssertBridgesExist", mock.Anything, mock.IsType(pipeline.Pipeline{})).Return(nil) - svc.jobORM.On("FindJobByExternalJobID", externalJobID, mock.Anything).Return(j, nil) + svc.jobORM.On("FindJobByExternalJobID", mock.Anything, externalJobID).Return(j, nil) + svc.orm.On("WithDataSource", mock.Anything).Return(feeds.ORM(svc.orm)) + svc.jobORM.On("WithDataSource", mock.Anything).Return(job.ORM(svc.jobORM)) }, id: spec.ID, force: false, @@ -1796,12 +1868,14 @@ answer1 [type=median index=0]; httpTimeout: commonconfig.MustNewDuration(1 * time.Minute), before: func(svc *TestService) { svc.connMgr.On("GetClient", jp.FeedsManagerID).Return(svc.fmsClient, nil) - svc.orm.On("GetSpec", spec.ID, mock.Anything).Return(spec, nil) - svc.orm.On("GetJobProposal", jp.ID, mock.Anything).Return(jp, nil) + svc.orm.On("GetSpec", mock.Anything, spec.ID).Return(spec, nil) + svc.orm.On("GetJobProposal", mock.Anything, jp.ID).Return(jp, nil) svc.jobORM.On("AssertBridgesExist", mock.Anything, mock.IsType(pipeline.Pipeline{})).Return(nil) - svc.jobORM.On("FindJobByExternalJobID", externalJobID, mock.Anything).Return(job.Job{}, sql.ErrNoRows) - svc.jobORM.On("FindJobIDByAddress", address, evmChainID, mock.Anything).Return(j.ID, nil) + svc.jobORM.On("FindJobByExternalJobID", mock.Anything, externalJobID).Return(job.Job{}, sql.ErrNoRows) + svc.jobORM.On("FindJobIDByAddress", mock.Anything, address, evmChainID, mock.Anything).Return(j.ID, nil) + svc.orm.On("WithDataSource", mock.Anything).Return(feeds.ORM(svc.orm)) + svc.jobORM.On("WithDataSource", mock.Anything).Return(job.ORM(svc.jobORM)) }, id: spec.ID, force: false, @@ -1812,27 +1886,28 @@ answer1 [type=median index=0]; httpTimeout: commonconfig.MustNewDuration(1 * time.Minute), before: func(svc *TestService) { svc.connMgr.On("GetClient", jp.FeedsManagerID).Return(svc.fmsClient, nil) - svc.orm.EXPECT().GetSpec(spec.ID, mock.Anything).Return(spec, nil) - svc.orm.EXPECT().GetJobProposal(jp.ID, mock.Anything).Return(jp, nil) + svc.orm.EXPECT().GetSpec(mock.Anything, spec.ID).Return(spec, nil) + svc.orm.EXPECT().GetJobProposal(mock.Anything, jp.ID).Return(jp, nil) svc.jobORM.On("AssertBridgesExist", mock.Anything, mock.IsType(pipeline.Pipeline{})).Return(nil) - svc.jobORM.On("FindJobByExternalJobID", externalJobID, mock.Anything).Return(j, nil) - svc.orm.EXPECT().GetApprovedSpec(jp.ID, mock.Anything).Return(nil, sql.ErrNoRows) + svc.jobORM.On("FindJobByExternalJobID", mock.Anything, externalJobID).Return(j, nil) + svc.orm.EXPECT().GetApprovedSpec(mock.Anything, jp.ID).Return(nil, sql.ErrNoRows) - svc.spawner.On("DeleteJob", j.ID, mock.Anything).Return(nil) + svc.spawner.On("DeleteJob", mock.Anything, mock.Anything, j.ID).Return(nil) svc.spawner. On("CreateJob", + mock.Anything, + mock.Anything, mock.MatchedBy(func(j *job.Job) bool { return j.Name.String == "LINK / ETH | version 3 | contract 0x0000000000000000000000000000000000000000" }), - mock.Anything, ). - Run(func(args mock.Arguments) { (args.Get(0).(*job.Job)).ID = 1 }). + Run(func(args mock.Arguments) { (args.Get(2).(*job.Job)).ID = 1 }). Return(nil) svc.orm.On("ApproveSpec", + mock.Anything, spec.ID, externalJobID, - mock.Anything, ).Return(nil) svc.fmsClient.On("ApprovedJob", mock.MatchedBy(func(ctx context.Context) bool { return true }), @@ -1841,7 +1916,9 @@ answer1 [type=median index=0]; Version: int64(spec.Version), }, ).Return(&proto.ApprovedJobResponse{}, nil) - svc.orm.On("CountJobProposalsByStatus").Return(&feeds.JobProposalCounts{}, nil) + svc.orm.On("CountJobProposalsByStatus", mock.Anything).Return(&feeds.JobProposalCounts{}, nil) + svc.orm.On("WithDataSource", mock.Anything).Return(feeds.ORM(svc.orm)) + svc.jobORM.On("WithDataSource", mock.Anything).Return(job.ORM(svc.jobORM)) }, id: spec.ID, force: true, @@ -1851,28 +1928,29 @@ answer1 [type=median index=0]; httpTimeout: commonconfig.MustNewDuration(1 * time.Minute), before: func(svc *TestService) { svc.connMgr.On("GetClient", jp.FeedsManagerID).Return(svc.fmsClient, nil) - svc.orm.EXPECT().GetSpec(spec.ID, mock.Anything).Return(spec, nil) - svc.orm.EXPECT().GetJobProposal(jp.ID, mock.Anything).Return(jp, nil) + svc.orm.EXPECT().GetSpec(mock.Anything, spec.ID).Return(spec, nil) + svc.orm.EXPECT().GetJobProposal(mock.Anything, jp.ID).Return(jp, nil) svc.jobORM.On("AssertBridgesExist", mock.Anything, mock.IsType(pipeline.Pipeline{})).Return(nil) - svc.jobORM.On("FindJobByExternalJobID", externalJobID, mock.Anything).Return(job.Job{}, sql.ErrNoRows) - svc.jobORM.On("FindJobIDByAddress", address, evmChainID, mock.Anything).Return(j.ID, nil) - svc.orm.EXPECT().GetApprovedSpec(jp.ID, mock.Anything).Return(nil, sql.ErrNoRows) + svc.jobORM.On("FindJobByExternalJobID", mock.Anything, externalJobID).Return(job.Job{}, sql.ErrNoRows) + svc.jobORM.On("FindJobIDByAddress", mock.Anything, address, evmChainID, mock.Anything).Return(j.ID, nil) + svc.orm.EXPECT().GetApprovedSpec(mock.Anything, jp.ID).Return(nil, sql.ErrNoRows) - svc.spawner.On("DeleteJob", j.ID, mock.Anything).Return(nil) + svc.spawner.On("DeleteJob", mock.Anything, mock.Anything, j.ID).Return(nil) svc.spawner. On("CreateJob", + mock.Anything, + mock.Anything, mock.MatchedBy(func(j *job.Job) bool { return j.Name.String == "LINK / ETH | version 3 | contract 0x0000000000000000000000000000000000000000" }), - mock.Anything, ). - Run(func(args mock.Arguments) { (args.Get(0).(*job.Job)).ID = 1 }). + Run(func(args mock.Arguments) { (args.Get(2).(*job.Job)).ID = 1 }). Return(nil) svc.orm.On("ApproveSpec", + mock.Anything, spec.ID, externalJobID, - mock.Anything, ).Return(nil) svc.fmsClient.On("ApprovedJob", mock.MatchedBy(func(ctx context.Context) bool { return true }), @@ -1881,7 +1959,9 @@ answer1 [type=median index=0]; Version: int64(spec.Version), }, ).Return(&proto.ApprovedJobResponse{}, nil) - svc.orm.On("CountJobProposalsByStatus").Return(&feeds.JobProposalCounts{}, nil) + svc.orm.On("CountJobProposalsByStatus", mock.Anything).Return(&feeds.JobProposalCounts{}, nil) + svc.orm.On("WithDataSource", mock.Anything).Return(feeds.ORM(svc.orm)) + svc.jobORM.On("WithDataSource", mock.Anything).Return(job.ORM(svc.jobORM)) }, id: spec.ID, force: true, @@ -1891,29 +1971,30 @@ answer1 [type=median index=0]; httpTimeout: commonconfig.MustNewDuration(1 * time.Minute), before: func(svc *TestService) { svc.connMgr.On("GetClient", jp.FeedsManagerID).Return(svc.fmsClient, nil) - svc.orm.EXPECT().GetSpec(spec.ID, mock.Anything).Return(spec, nil) - svc.orm.EXPECT().GetJobProposal(jp.ID, mock.Anything).Return(jp, nil) + svc.orm.EXPECT().GetSpec(mock.Anything, spec.ID).Return(spec, nil) + svc.orm.EXPECT().GetJobProposal(mock.Anything, jp.ID).Return(jp, nil) svc.jobORM.On("AssertBridgesExist", mock.Anything, mock.IsType(pipeline.Pipeline{})).Return(nil) - svc.jobORM.On("FindJobByExternalJobID", externalJobID, mock.Anything).Return(job.Job{}, sql.ErrNoRows) - svc.jobORM.On("FindJobIDByAddress", address, evmChainID, mock.Anything).Return(j.ID, nil) - svc.orm.EXPECT().GetApprovedSpec(jp.ID, mock.Anything).Return(&feeds.JobProposalSpec{ID: 100}, nil) - svc.orm.EXPECT().CancelSpec(int64(100), mock.Anything).Return(nil) + svc.jobORM.On("FindJobByExternalJobID", mock.Anything, externalJobID).Return(job.Job{}, sql.ErrNoRows) + svc.jobORM.On("FindJobIDByAddress", mock.Anything, address, evmChainID, mock.Anything).Return(j.ID, nil) + svc.orm.EXPECT().GetApprovedSpec(mock.Anything, jp.ID).Return(&feeds.JobProposalSpec{ID: 100}, nil) + svc.orm.EXPECT().CancelSpec(mock.Anything, int64(100)).Return(nil) - svc.spawner.On("DeleteJob", j.ID, mock.Anything).Return(nil) + svc.spawner.On("DeleteJob", mock.Anything, mock.Anything, j.ID).Return(nil) svc.spawner. On("CreateJob", + mock.Anything, + mock.Anything, mock.MatchedBy(func(j *job.Job) bool { return j.Name.String == "LINK / ETH | version 3 | contract 0x0000000000000000000000000000000000000000" }), - mock.Anything, ). - Run(func(args mock.Arguments) { (args.Get(0).(*job.Job)).ID = 1 }). + Run(func(args mock.Arguments) { (args.Get(2).(*job.Job)).ID = 1 }). Return(nil) svc.orm.On("ApproveSpec", + mock.Anything, spec.ID, externalJobID, - mock.Anything, ).Return(nil) svc.fmsClient.On("ApprovedJob", mock.MatchedBy(func(ctx context.Context) bool { return true }), @@ -1922,7 +2003,9 @@ answer1 [type=median index=0]; Version: int64(spec.Version), }, ).Return(&proto.ApprovedJobResponse{}, nil) - svc.orm.On("CountJobProposalsByStatus").Return(&feeds.JobProposalCounts{}, nil) + svc.orm.On("CountJobProposalsByStatus", mock.Anything).Return(&feeds.JobProposalCounts{}, nil) + svc.orm.On("WithDataSource", mock.Anything).Return(feeds.ORM(svc.orm)) + svc.jobORM.On("WithDataSource", mock.Anything).Return(job.ORM(svc.jobORM)) }, id: spec.ID, force: true, @@ -1930,7 +2013,7 @@ answer1 [type=median index=0]; { name: "spec does not exist", before: func(svc *TestService) { - svc.orm.On("GetSpec", spec.ID, mock.Anything).Return(nil, errors.New("Not Found")) + svc.orm.On("GetSpec", mock.Anything, spec.ID).Return(nil, errors.New("Not Found")) }, id: spec.ID, force: false, @@ -1939,8 +2022,8 @@ answer1 [type=median index=0]; { name: "job proposal does not exist", before: func(svc *TestService) { - svc.orm.On("GetSpec", spec.ID, mock.Anything).Return(spec, nil) - svc.orm.On("GetJobProposal", jp.ID, mock.Anything).Return(nil, errors.New("Not Found")) + svc.orm.On("GetSpec", mock.Anything, spec.ID).Return(spec, nil) + svc.orm.On("GetJobProposal", mock.Anything, jp.ID).Return(nil, errors.New("Not Found")) }, id: spec.ID, wantErr: "orm: job proposal: Not Found", @@ -1950,8 +2033,8 @@ answer1 [type=median index=0]; httpTimeout: commonconfig.MustNewDuration(1 * time.Minute), before: func(svc *TestService) { svc.connMgr.On("GetClient", jp.FeedsManagerID).Return(svc.fmsClient, nil) - svc.orm.On("GetSpec", spec.ID, mock.Anything).Return(spec, nil) - svc.orm.On("GetJobProposal", jp.ID, mock.Anything).Return(jp, nil) + svc.orm.On("GetSpec", mock.Anything, spec.ID).Return(spec, nil) + svc.orm.On("GetJobProposal", mock.Anything, jp.ID).Return(jp, nil) svc.jobORM.On("AssertBridgesExist", mock.Anything, mock.IsType(pipeline.Pipeline{})).Return(errors.New("bridges do not exist")) }, id: spec.ID, @@ -1960,8 +2043,8 @@ answer1 [type=median index=0]; { name: "rpc client not connected", before: func(svc *TestService) { - svc.orm.On("GetSpec", spec.ID, mock.Anything).Return(spec, nil) - svc.orm.On("GetJobProposal", jp.ID, mock.Anything).Return(jp, nil) + svc.orm.On("GetSpec", mock.Anything, spec.ID).Return(spec, nil) + svc.orm.On("GetJobProposal", mock.Anything, jp.ID).Return(jp, nil) svc.connMgr.On("GetClient", jp.FeedsManagerID).Return(nil, errors.New("Not Connected")) }, id: spec.ID, @@ -1973,11 +2056,13 @@ answer1 [type=median index=0]; httpTimeout: commonconfig.MustNewDuration(1 * time.Minute), before: func(svc *TestService) { svc.connMgr.On("GetClient", jp.FeedsManagerID).Return(svc.fmsClient, nil) - svc.orm.EXPECT().GetSpec(spec.ID, mock.Anything).Return(spec, nil) - svc.orm.EXPECT().GetJobProposal(jp.ID, mock.Anything).Return(jp, nil) + svc.orm.EXPECT().GetSpec(mock.Anything, spec.ID).Return(spec, nil) + svc.orm.EXPECT().GetJobProposal(mock.Anything, jp.ID).Return(jp, nil) svc.jobORM.On("AssertBridgesExist", mock.Anything, mock.IsType(pipeline.Pipeline{})).Return(nil) - svc.jobORM.On("FindJobByExternalJobID", externalJobID, mock.Anything).Return(j, nil) - svc.orm.EXPECT().GetApprovedSpec(jp.ID, mock.Anything).Return(nil, errors.New("failure")) + svc.jobORM.On("FindJobByExternalJobID", mock.Anything, externalJobID).Return(j, nil) + svc.orm.EXPECT().GetApprovedSpec(mock.Anything, jp.ID).Return(nil, errors.New("failure")) + svc.orm.On("WithDataSource", mock.Anything).Return(feeds.ORM(svc.orm)) + svc.jobORM.On("WithDataSource", mock.Anything).Return(job.ORM(svc.jobORM)) }, id: spec.ID, force: true, @@ -1988,12 +2073,14 @@ answer1 [type=median index=0]; httpTimeout: commonconfig.MustNewDuration(1 * time.Minute), before: func(svc *TestService) { svc.connMgr.On("GetClient", jp.FeedsManagerID).Return(svc.fmsClient, nil) - svc.orm.EXPECT().GetSpec(spec.ID, mock.Anything).Return(spec, nil) - svc.orm.EXPECT().GetJobProposal(jp.ID, mock.Anything).Return(jp, nil) + svc.orm.EXPECT().GetSpec(mock.Anything, spec.ID).Return(spec, nil) + svc.orm.EXPECT().GetJobProposal(mock.Anything, jp.ID).Return(jp, nil) svc.jobORM.On("AssertBridgesExist", mock.Anything, mock.IsType(pipeline.Pipeline{})).Return(nil) - svc.jobORM.On("FindJobByExternalJobID", externalJobID, mock.Anything).Return(job.Job{}, sql.ErrNoRows) - svc.jobORM.On("FindJobIDByAddress", address, evmChainID, mock.Anything).Return(j.ID, nil) - svc.orm.EXPECT().GetApprovedSpec(jp.ID, mock.Anything).Return(nil, errors.New("failure")) + svc.jobORM.On("FindJobByExternalJobID", mock.Anything, externalJobID).Return(job.Job{}, sql.ErrNoRows) + svc.jobORM.On("FindJobIDByAddress", mock.Anything, address, evmChainID, mock.Anything).Return(j.ID, nil) + svc.orm.EXPECT().GetApprovedSpec(mock.Anything, jp.ID).Return(nil, errors.New("failure")) + svc.orm.On("WithDataSource", mock.Anything).Return(feeds.ORM(svc.orm)) + svc.jobORM.On("WithDataSource", mock.Anything).Return(job.ORM(svc.jobORM)) }, id: spec.ID, force: true, @@ -2004,12 +2091,14 @@ answer1 [type=median index=0]; httpTimeout: commonconfig.MustNewDuration(1 * time.Minute), before: func(svc *TestService) { svc.connMgr.On("GetClient", jp.FeedsManagerID).Return(svc.fmsClient, nil) - svc.orm.EXPECT().GetSpec(spec.ID, mock.Anything).Return(spec, nil) - svc.orm.EXPECT().GetJobProposal(jp.ID, mock.Anything).Return(jp, nil) + svc.orm.EXPECT().GetSpec(mock.Anything, spec.ID).Return(spec, nil) + svc.orm.EXPECT().GetJobProposal(mock.Anything, jp.ID).Return(jp, nil) svc.jobORM.On("AssertBridgesExist", mock.Anything, mock.IsType(pipeline.Pipeline{})).Return(nil) - svc.jobORM.On("FindJobByExternalJobID", externalJobID, mock.Anything).Return(j, nil) - svc.orm.EXPECT().GetApprovedSpec(jp.ID, mock.Anything).Return(&feeds.JobProposalSpec{ID: 100}, nil) - svc.orm.EXPECT().CancelSpec(int64(100), mock.Anything).Return(errors.New("failure")) + svc.jobORM.On("FindJobByExternalJobID", mock.Anything, externalJobID).Return(j, nil) + svc.orm.EXPECT().GetApprovedSpec(mock.Anything, jp.ID).Return(&feeds.JobProposalSpec{ID: 100}, nil) + svc.orm.EXPECT().CancelSpec(mock.Anything, int64(100)).Return(errors.New("failure")) + svc.orm.On("WithDataSource", mock.Anything).Return(feeds.ORM(svc.orm)) + svc.jobORM.On("WithDataSource", mock.Anything).Return(job.ORM(svc.jobORM)) }, id: spec.ID, force: true, @@ -2020,13 +2109,15 @@ answer1 [type=median index=0]; httpTimeout: commonconfig.MustNewDuration(1 * time.Minute), before: func(svc *TestService) { svc.connMgr.On("GetClient", jp.FeedsManagerID).Return(svc.fmsClient, nil) - svc.orm.EXPECT().GetSpec(spec.ID, mock.Anything).Return(spec, nil) - svc.orm.EXPECT().GetJobProposal(jp.ID, mock.Anything).Return(jp, nil) + svc.orm.EXPECT().GetSpec(mock.Anything, spec.ID).Return(spec, nil) + svc.orm.EXPECT().GetJobProposal(mock.Anything, jp.ID).Return(jp, nil) svc.jobORM.On("AssertBridgesExist", mock.Anything, mock.IsType(pipeline.Pipeline{})).Return(nil) - svc.jobORM.On("FindJobByExternalJobID", externalJobID, mock.Anything).Return(job.Job{}, sql.ErrNoRows) - svc.jobORM.On("FindJobIDByAddress", address, evmChainID, mock.Anything).Return(j.ID, nil) - svc.orm.EXPECT().GetApprovedSpec(jp.ID, mock.Anything).Return(&feeds.JobProposalSpec{ID: 100}, nil) - svc.orm.EXPECT().CancelSpec(int64(100), mock.Anything).Return(errors.New("failure")) + svc.jobORM.On("FindJobByExternalJobID", mock.Anything, externalJobID).Return(job.Job{}, sql.ErrNoRows) + svc.jobORM.On("FindJobIDByAddress", mock.Anything, address, evmChainID, mock.Anything).Return(j.ID, nil) + svc.orm.EXPECT().GetApprovedSpec(mock.Anything, jp.ID).Return(&feeds.JobProposalSpec{ID: 100}, nil) + svc.orm.EXPECT().CancelSpec(mock.Anything, int64(100)).Return(errors.New("failure")) + svc.orm.On("WithDataSource", mock.Anything).Return(feeds.ORM(svc.orm)) + svc.jobORM.On("WithDataSource", mock.Anything).Return(job.ORM(svc.jobORM)) }, id: spec.ID, force: true, @@ -2036,22 +2127,25 @@ answer1 [type=median index=0]; name: "create job error", httpTimeout: commonconfig.MustNewDuration(1 * time.Minute), before: func(svc *TestService) { - svc.orm.On("GetSpec", spec.ID, mock.Anything).Return(spec, nil) - svc.orm.On("GetJobProposal", jp.ID, mock.Anything).Return(jp, nil) + svc.orm.On("GetSpec", mock.Anything, spec.ID).Return(spec, nil) + svc.orm.On("GetJobProposal", mock.Anything, jp.ID).Return(jp, nil) svc.connMgr.On("GetClient", jp.FeedsManagerID).Return(svc.fmsClient, nil) svc.jobORM.On("AssertBridgesExist", mock.Anything, mock.IsType(pipeline.Pipeline{})).Return(nil) - svc.jobORM.On("FindJobByExternalJobID", externalJobID, mock.Anything).Return(job.Job{}, sql.ErrNoRows) - svc.jobORM.On("FindJobIDByAddress", address, evmChainID, mock.Anything).Return(int32(0), sql.ErrNoRows) + svc.jobORM.On("FindJobByExternalJobID", mock.Anything, externalJobID).Return(job.Job{}, sql.ErrNoRows) + svc.jobORM.On("FindJobIDByAddress", mock.Anything, address, evmChainID, mock.Anything).Return(int32(0), sql.ErrNoRows) svc.spawner. On("CreateJob", + mock.Anything, + mock.Anything, mock.MatchedBy(func(j *job.Job) bool { return j.Name.String == "LINK / ETH | version 3 | contract 0x0000000000000000000000000000000000000000" }), - mock.Anything, ). Return(errors.New("could not save")) + svc.orm.On("WithDataSource", mock.Anything).Return(feeds.ORM(svc.orm)) + svc.jobORM.On("WithDataSource", mock.Anything).Return(job.ORM(svc.jobORM)) }, id: spec.ID, force: false, @@ -2061,28 +2155,31 @@ answer1 [type=median index=0]; name: "approve spec orm error", httpTimeout: commonconfig.MustNewDuration(1 * time.Minute), before: func(svc *TestService) { - svc.orm.On("GetSpec", spec.ID, mock.Anything).Return(spec, nil) - svc.orm.On("GetJobProposal", jp.ID, mock.Anything).Return(jp, nil) + svc.orm.On("GetSpec", mock.Anything, spec.ID).Return(spec, nil) + svc.orm.On("GetJobProposal", mock.Anything, jp.ID).Return(jp, nil) svc.connMgr.On("GetClient", jp.FeedsManagerID).Return(svc.fmsClient, nil) svc.jobORM.On("AssertBridgesExist", mock.Anything, mock.IsType(pipeline.Pipeline{})).Return(nil) - svc.jobORM.On("FindJobByExternalJobID", externalJobID, mock.Anything).Return(job.Job{}, sql.ErrNoRows) - svc.jobORM.On("FindJobIDByAddress", address, evmChainID, mock.Anything).Return(int32(0), sql.ErrNoRows) + svc.jobORM.On("FindJobByExternalJobID", mock.Anything, externalJobID).Return(job.Job{}, sql.ErrNoRows) + svc.jobORM.On("FindJobIDByAddress", mock.Anything, address, evmChainID, mock.Anything).Return(int32(0), sql.ErrNoRows) svc.spawner. On("CreateJob", + mock.Anything, + mock.Anything, mock.MatchedBy(func(j *job.Job) bool { return j.Name.String == "LINK / ETH | version 3 | contract 0x0000000000000000000000000000000000000000" }), - mock.Anything, ). - Run(func(args mock.Arguments) { (args.Get(0).(*job.Job)).ID = 1 }). + Run(func(args mock.Arguments) { (args.Get(2).(*job.Job)).ID = 1 }). Return(nil) svc.orm.On("ApproveSpec", + mock.Anything, spec.ID, externalJobID, - mock.Anything, ).Return(errors.New("failure")) + svc.orm.On("WithDataSource", mock.Anything).Return(feeds.ORM(svc.orm)) + svc.jobORM.On("WithDataSource", mock.Anything).Return(job.ORM(svc.jobORM)) }, id: spec.ID, force: false, @@ -2092,27 +2189,28 @@ answer1 [type=median index=0]; name: "fms call error", httpTimeout: commonconfig.MustNewDuration(1 * time.Minute), before: func(svc *TestService) { - svc.orm.On("GetSpec", spec.ID, mock.Anything).Return(spec, nil) - svc.orm.On("GetJobProposal", jp.ID, mock.Anything).Return(jp, nil) + svc.orm.On("GetSpec", mock.Anything, spec.ID).Return(spec, nil) + svc.orm.On("GetJobProposal", mock.Anything, jp.ID).Return(jp, nil) svc.connMgr.On("GetClient", jp.FeedsManagerID).Return(svc.fmsClient, nil) svc.jobORM.On("AssertBridgesExist", mock.Anything, mock.IsType(pipeline.Pipeline{})).Return(nil) - svc.jobORM.On("FindJobByExternalJobID", externalJobID, mock.Anything).Return(job.Job{}, sql.ErrNoRows) - svc.jobORM.On("FindJobIDByAddress", address, evmChainID, mock.Anything).Return(int32(0), sql.ErrNoRows) + svc.jobORM.On("FindJobByExternalJobID", mock.Anything, externalJobID).Return(job.Job{}, sql.ErrNoRows) + svc.jobORM.On("FindJobIDByAddress", mock.Anything, address, evmChainID, mock.Anything).Return(int32(0), sql.ErrNoRows) svc.spawner. On("CreateJob", + mock.Anything, + mock.Anything, mock.MatchedBy(func(j *job.Job) bool { return j.Name.String == "LINK / ETH | version 3 | contract 0x0000000000000000000000000000000000000000" }), - mock.Anything, ). - Run(func(args mock.Arguments) { (args.Get(0).(*job.Job)).ID = 1 }). + Run(func(args mock.Arguments) { (args.Get(2).(*job.Job)).ID = 1 }). Return(nil) svc.orm.On("ApproveSpec", + mock.Anything, spec.ID, externalJobID, - mock.Anything, ).Return(nil) svc.fmsClient.On("ApprovedJob", mock.MatchedBy(func(ctx context.Context) bool { return true }), @@ -2121,6 +2219,8 @@ answer1 [type=median index=0]; Version: int64(spec.Version), }, ).Return(nil, errors.New("failure")) + svc.orm.On("WithDataSource", mock.Anything).Return(feeds.ORM(svc.orm)) + svc.jobORM.On("WithDataSource", mock.Anything).Return(job.ORM(svc.jobORM)) }, id: spec.ID, force: false, @@ -2268,25 +2368,26 @@ updateInterval = "20m" httpTimeout: commonconfig.MustNewDuration(1 * time.Minute), before: func(svc *TestService) { svc.connMgr.On("GetClient", jp.FeedsManagerID).Return(svc.fmsClient, nil) - svc.orm.On("GetSpec", spec.ID, mock.Anything).Return(spec, nil) - svc.orm.On("GetJobProposal", jp.ID, mock.Anything).Return(jp, nil) + svc.orm.On("GetSpec", mock.Anything, spec.ID).Return(spec, nil) + svc.orm.On("GetJobProposal", mock.Anything, jp.ID).Return(jp, nil) svc.jobORM.On("AssertBridgesExist", mock.Anything, mock.IsType(pipeline.Pipeline{})).Return(nil) - svc.jobORM.On("FindJobByExternalJobID", externalJobID, mock.Anything).Return(job.Job{}, sql.ErrNoRows) - svc.jobORM.On("FindOCR2JobIDByAddress", address, (*common.Hash)(nil), mock.Anything).Return(int32(0), sql.ErrNoRows) + svc.jobORM.On("FindJobByExternalJobID", mock.Anything, externalJobID).Return(job.Job{}, sql.ErrNoRows) + svc.jobORM.On("FindOCR2JobIDByAddress", mock.Anything, address, (*common.Hash)(nil)).Return(int32(0), sql.ErrNoRows) svc.spawner. On("CreateJob", + mock.Anything, + mock.Anything, mock.MatchedBy(func(j *job.Job) bool { return j.Name.String == "LINK / ETH | version 3 | contract 0x0000000000000000000000000000000000000000" }), - mock.Anything, ). - Run(func(args mock.Arguments) { (args.Get(0).(*job.Job)).ID = 1 }). + Run(func(args mock.Arguments) { (args.Get(2).(*job.Job)).ID = 1 }). Return(nil) svc.orm.On("ApproveSpec", + mock.Anything, spec.ID, externalJobID, - mock.Anything, ).Return(nil) svc.fmsClient.On("ApprovedJob", mock.MatchedBy(func(ctx context.Context) bool { return true }), @@ -2295,7 +2396,9 @@ updateInterval = "20m" Version: int64(spec.Version), }, ).Return(&proto.ApprovedJobResponse{}, nil) - svc.orm.On("CountJobProposalsByStatus").Return(&feeds.JobProposalCounts{}, nil) + svc.orm.On("CountJobProposalsByStatus", mock.Anything).Return(&feeds.JobProposalCounts{}, nil) + svc.orm.On("WithDataSource", mock.Anything).Return(feeds.ORM(svc.orm)) + svc.jobORM.On("WithDataSource", mock.Anything).Return(job.ORM(svc.jobORM)) }, id: spec.ID, force: false, @@ -2305,27 +2408,28 @@ updateInterval = "20m" httpTimeout: commonconfig.MustNewDuration(1 * time.Minute), before: func(svc *TestService) { svc.connMgr.On("GetClient", jp.FeedsManagerID).Return(svc.fmsClient, nil) - svc.orm.On("GetSpec", cancelledSpec.ID, mock.Anything).Return(cancelledSpec, nil) - svc.orm.On("GetJobProposal", jp.ID, mock.Anything).Return(jp, nil) - svc.orm.On("GetLatestSpec", cancelledSpec.JobProposalID).Return(cancelledSpec, nil) + svc.orm.On("GetSpec", mock.Anything, cancelledSpec.ID, mock.Anything).Return(cancelledSpec, nil) + svc.orm.On("GetJobProposal", mock.Anything, jp.ID).Return(jp, nil) + svc.orm.On("GetLatestSpec", mock.Anything, cancelledSpec.JobProposalID).Return(cancelledSpec, nil) svc.jobORM.On("AssertBridgesExist", mock.Anything, mock.IsType(pipeline.Pipeline{})).Return(nil) - svc.jobORM.On("FindJobByExternalJobID", externalJobID, mock.Anything).Return(job.Job{}, sql.ErrNoRows) - svc.jobORM.On("FindOCR2JobIDByAddress", address, (*common.Hash)(nil), mock.Anything).Return(int32(0), sql.ErrNoRows) + svc.jobORM.On("FindJobByExternalJobID", mock.Anything, externalJobID).Return(job.Job{}, sql.ErrNoRows) + svc.jobORM.On("FindOCR2JobIDByAddress", mock.Anything, address, (*common.Hash)(nil)).Return(int32(0), sql.ErrNoRows) svc.spawner. On("CreateJob", + mock.Anything, + mock.Anything, mock.MatchedBy(func(j *job.Job) bool { return j.Name.String == "LINK / ETH | version 3 | contract 0x0000000000000000000000000000000000000000" }), - mock.Anything, ). - Run(func(args mock.Arguments) { (args.Get(0).(*job.Job)).ID = 1 }). + Run(func(args mock.Arguments) { (args.Get(2).(*job.Job)).ID = 1 }). Return(nil) svc.orm.On("ApproveSpec", + mock.Anything, cancelledSpec.ID, externalJobID, - mock.Anything, ).Return(nil) svc.fmsClient.On("ApprovedJob", mock.MatchedBy(func(ctx context.Context) bool { return true }), @@ -2334,7 +2438,9 @@ updateInterval = "20m" Version: int64(spec.Version), }, ).Return(&proto.ApprovedJobResponse{}, nil) - svc.orm.On("CountJobProposalsByStatus").Return(&feeds.JobProposalCounts{}, nil) + svc.orm.On("CountJobProposalsByStatus", mock.Anything).Return(&feeds.JobProposalCounts{}, nil) + svc.orm.On("WithDataSource", mock.Anything).Return(feeds.ORM(svc.orm)) + svc.jobORM.On("WithDataSource", mock.Anything).Return(job.ORM(svc.jobORM)) }, id: cancelledSpec.ID, force: false, @@ -2342,9 +2448,9 @@ updateInterval = "20m" { name: "cancelled spec failed not latest spec", before: func(svc *TestService) { - svc.orm.On("GetSpec", cancelledSpec.ID, mock.Anything).Return(cancelledSpec, nil) - svc.orm.On("GetJobProposal", jp.ID, mock.Anything).Return(jp, nil) - svc.orm.On("GetLatestSpec", cancelledSpec.JobProposalID).Return(&feeds.JobProposalSpec{ + svc.orm.On("GetSpec", mock.Anything, cancelledSpec.ID, mock.Anything).Return(cancelledSpec, nil) + svc.orm.On("GetJobProposal", mock.Anything, jp.ID).Return(jp, nil) + svc.orm.On("GetLatestSpec", mock.Anything, cancelledSpec.JobProposalID).Return(&feeds.JobProposalSpec{ ID: 21, Status: feeds.SpecStatusPending, JobProposalID: jp.ID, @@ -2359,8 +2465,8 @@ updateInterval = "20m" { name: "rejected spec failed cannot be approved", before: func(svc *TestService) { - svc.orm.On("GetSpec", cancelledSpec.ID, mock.Anything).Return(rejectedSpec, nil) - svc.orm.On("GetJobProposal", jp.ID, mock.Anything).Return(jp, nil) + svc.orm.On("GetSpec", mock.Anything, cancelledSpec.ID, mock.Anything).Return(rejectedSpec, nil) + svc.orm.On("GetJobProposal", mock.Anything, jp.ID).Return(jp, nil) }, id: rejectedSpec.ID, force: false, @@ -2371,11 +2477,13 @@ updateInterval = "20m" httpTimeout: commonconfig.MustNewDuration(1 * time.Minute), before: func(svc *TestService) { svc.connMgr.On("GetClient", jp.FeedsManagerID).Return(svc.fmsClient, nil) - svc.orm.On("GetSpec", spec.ID, mock.Anything).Return(spec, nil) - svc.orm.On("GetJobProposal", jp.ID, mock.Anything).Return(jp, nil) + svc.orm.On("GetSpec", mock.Anything, spec.ID).Return(spec, nil) + svc.orm.On("GetJobProposal", mock.Anything, jp.ID).Return(jp, nil) svc.jobORM.On("AssertBridgesExist", mock.Anything, mock.IsType(pipeline.Pipeline{})).Return(nil) - svc.jobORM.On("FindJobByExternalJobID", externalJobID, mock.Anything).Return(job.Job{}, sql.ErrNoRows) - svc.jobORM.On("FindOCR2JobIDByAddress", address, (*common.Hash)(nil), mock.Anything).Return(j.ID, nil) + svc.jobORM.On("FindJobByExternalJobID", mock.Anything, externalJobID).Return(job.Job{}, sql.ErrNoRows) + svc.jobORM.On("FindOCR2JobIDByAddress", mock.Anything, address, (*common.Hash)(nil)).Return(j.ID, nil) + svc.orm.On("WithDataSource", mock.Anything).Return(feeds.ORM(svc.orm)) + svc.jobORM.On("WithDataSource", mock.Anything).Return(job.ORM(svc.jobORM)) }, id: spec.ID, force: false, @@ -2386,27 +2494,28 @@ updateInterval = "20m" httpTimeout: commonconfig.MustNewDuration(1 * time.Minute), before: func(svc *TestService) { svc.connMgr.On("GetClient", jp.FeedsManagerID).Return(svc.fmsClient, nil) - svc.orm.On("GetSpec", spec.ID, mock.Anything).Return(spec, nil) - svc.orm.On("GetJobProposal", jp.ID, mock.Anything).Return(jp, nil) + svc.orm.On("GetSpec", mock.Anything, spec.ID).Return(spec, nil) + svc.orm.On("GetJobProposal", mock.Anything, jp.ID).Return(jp, nil) svc.jobORM.On("AssertBridgesExist", mock.Anything, mock.IsType(pipeline.Pipeline{})).Return(nil) - svc.orm.EXPECT().GetApprovedSpec(jp.ID, mock.Anything).Return(nil, sql.ErrNoRows) - svc.jobORM.On("FindJobByExternalJobID", externalJobID, mock.Anything).Return(job.Job{}, sql.ErrNoRows) - svc.jobORM.On("FindOCR2JobIDByAddress", address, (*common.Hash)(nil), mock.Anything).Return(j.ID, nil) - svc.spawner.On("DeleteJob", j.ID, mock.Anything).Return(nil) + svc.orm.EXPECT().GetApprovedSpec(mock.Anything, jp.ID).Return(nil, sql.ErrNoRows) + svc.jobORM.On("FindJobByExternalJobID", mock.Anything, externalJobID).Return(job.Job{}, sql.ErrNoRows) + svc.jobORM.On("FindOCR2JobIDByAddress", mock.Anything, address, (*common.Hash)(nil)).Return(j.ID, nil) + svc.spawner.On("DeleteJob", mock.Anything, mock.Anything, j.ID).Return(nil) svc.spawner. On("CreateJob", + mock.Anything, + mock.Anything, mock.MatchedBy(func(j *job.Job) bool { return j.Name.String == "LINK / ETH | version 3 | contract 0x0000000000000000000000000000000000000000" }), - mock.Anything, ). - Run(func(args mock.Arguments) { (args.Get(0).(*job.Job)).ID = 1 }). + Run(func(args mock.Arguments) { (args.Get(2).(*job.Job)).ID = 1 }). Return(nil) svc.orm.On("ApproveSpec", + mock.Anything, spec.ID, externalJobID, - mock.Anything, ).Return(nil) svc.fmsClient.On("ApprovedJob", mock.MatchedBy(func(ctx context.Context) bool { return true }), @@ -2415,7 +2524,9 @@ updateInterval = "20m" Version: int64(spec.Version), }, ).Return(&proto.ApprovedJobResponse{}, nil) - svc.orm.On("CountJobProposalsByStatus").Return(&feeds.JobProposalCounts{}, nil) + svc.orm.On("CountJobProposalsByStatus", mock.Anything).Return(&feeds.JobProposalCounts{}, nil) + svc.orm.On("WithDataSource", mock.Anything).Return(feeds.ORM(svc.orm)) + svc.jobORM.On("WithDataSource", mock.Anything).Return(job.ORM(svc.jobORM)) }, id: spec.ID, force: true, @@ -2425,33 +2536,34 @@ updateInterval = "20m" httpTimeout: commonconfig.MustNewDuration(1 * time.Minute), before: func(svc *TestService) { svc.connMgr.On("GetClient", jp.FeedsManagerID).Return(svc.fmsClient, nil) - svc.orm.On("GetSpec", spec.ID, mock.Anything).Return(&feeds.JobProposalSpec{ + svc.orm.On("GetSpec", mock.Anything, spec.ID).Return(&feeds.JobProposalSpec{ ID: 20, Status: feeds.SpecStatusPending, JobProposalID: jp.ID, Version: 1, Definition: fmt.Sprintf(defn2, externalJobID.String(), &feedID), }, nil) - svc.orm.On("GetJobProposal", jp.ID, mock.Anything).Return(jp, nil) + svc.orm.On("GetJobProposal", mock.Anything, jp.ID).Return(jp, nil) svc.jobORM.On("AssertBridgesExist", mock.Anything, mock.IsType(pipeline.Pipeline{})).Return(nil) - svc.orm.EXPECT().GetApprovedSpec(jp.ID, mock.Anything).Return(nil, sql.ErrNoRows) - svc.jobORM.On("FindJobByExternalJobID", externalJobID, mock.Anything).Return(job.Job{}, sql.ErrNoRows) - svc.jobORM.On("FindOCR2JobIDByAddress", address, &feedID, mock.Anything).Return(j.ID, nil) - svc.spawner.On("DeleteJob", j.ID, mock.Anything).Return(nil) + svc.orm.EXPECT().GetApprovedSpec(mock.Anything, jp.ID).Return(nil, sql.ErrNoRows) + svc.jobORM.On("FindJobByExternalJobID", mock.Anything, externalJobID).Return(job.Job{}, sql.ErrNoRows) + svc.jobORM.On("FindOCR2JobIDByAddress", mock.Anything, address, &feedID).Return(j.ID, nil) + svc.spawner.On("DeleteJob", mock.Anything, mock.Anything, j.ID).Return(nil) svc.spawner. On("CreateJob", + mock.Anything, + mock.Anything, mock.MatchedBy(func(j *job.Job) bool { return j.Name.String == "LINK / ETH | version 3 | contract 0x0000000000000000000000000000000000000000" }), - mock.Anything, ). - Run(func(args mock.Arguments) { (args.Get(0).(*job.Job)).ID = 1 }). + Run(func(args mock.Arguments) { (args.Get(2).(*job.Job)).ID = 1 }). Return(nil) svc.orm.On("ApproveSpec", + mock.Anything, spec.ID, externalJobID, - mock.Anything, ).Return(nil) svc.fmsClient.On("ApprovedJob", mock.MatchedBy(func(ctx context.Context) bool { return true }), @@ -2460,7 +2572,9 @@ updateInterval = "20m" Version: int64(spec.Version), }, ).Return(&proto.ApprovedJobResponse{}, nil) - svc.orm.On("CountJobProposalsByStatus").Return(&feeds.JobProposalCounts{}, nil) + svc.orm.On("CountJobProposalsByStatus", mock.Anything).Return(&feeds.JobProposalCounts{}, nil) + svc.orm.On("WithDataSource", mock.Anything).Return(feeds.ORM(svc.orm)) + svc.jobORM.On("WithDataSource", mock.Anything).Return(job.ORM(svc.jobORM)) }, id: spec.ID, force: true, @@ -2470,28 +2584,29 @@ updateInterval = "20m" httpTimeout: commonconfig.MustNewDuration(1 * time.Minute), before: func(svc *TestService) { svc.connMgr.On("GetClient", jp.FeedsManagerID).Return(svc.fmsClient, nil) - svc.orm.On("GetSpec", spec.ID, mock.Anything).Return(spec, nil) - svc.orm.On("GetJobProposal", jp.ID, mock.Anything).Return(jp, nil) + svc.orm.On("GetSpec", mock.Anything, spec.ID).Return(spec, nil) + svc.orm.On("GetJobProposal", mock.Anything, jp.ID).Return(jp, nil) svc.jobORM.On("AssertBridgesExist", mock.Anything, mock.IsType(pipeline.Pipeline{})).Return(nil) - svc.orm.EXPECT().GetApprovedSpec(jp.ID, mock.Anything).Return(&feeds.JobProposalSpec{ID: 100}, nil) - svc.orm.EXPECT().CancelSpec(int64(100), mock.Anything).Return(nil) - svc.jobORM.On("FindJobByExternalJobID", externalJobID, mock.Anything).Return(job.Job{}, sql.ErrNoRows) - svc.jobORM.On("FindOCR2JobIDByAddress", address, (*common.Hash)(nil), mock.Anything).Return(j.ID, nil) - svc.spawner.On("DeleteJob", j.ID, mock.Anything).Return(nil) + svc.orm.EXPECT().GetApprovedSpec(mock.Anything, jp.ID).Return(&feeds.JobProposalSpec{ID: 100}, nil) + svc.orm.EXPECT().CancelSpec(mock.Anything, int64(100)).Return(nil) + svc.jobORM.On("FindJobByExternalJobID", mock.Anything, externalJobID).Return(job.Job{}, sql.ErrNoRows) + svc.jobORM.On("FindOCR2JobIDByAddress", mock.Anything, address, (*common.Hash)(nil)).Return(j.ID, nil) + svc.spawner.On("DeleteJob", mock.Anything, mock.Anything, j.ID).Return(nil) svc.spawner. On("CreateJob", + mock.Anything, + mock.Anything, mock.MatchedBy(func(j *job.Job) bool { return j.Name.String == "LINK / ETH | version 3 | contract 0x0000000000000000000000000000000000000000" }), - mock.Anything, ). - Run(func(args mock.Arguments) { (args.Get(0).(*job.Job)).ID = 1 }). + Run(func(args mock.Arguments) { (args.Get(2).(*job.Job)).ID = 1 }). Return(nil) svc.orm.On("ApproveSpec", + mock.Anything, spec.ID, externalJobID, - mock.Anything, ).Return(nil) svc.fmsClient.On("ApprovedJob", mock.MatchedBy(func(ctx context.Context) bool { return true }), @@ -2500,7 +2615,9 @@ updateInterval = "20m" Version: int64(spec.Version), }, ).Return(&proto.ApprovedJobResponse{}, nil) - svc.orm.On("CountJobProposalsByStatus").Return(&feeds.JobProposalCounts{}, nil) + svc.orm.On("CountJobProposalsByStatus", mock.Anything).Return(&feeds.JobProposalCounts{}, nil) + svc.orm.On("WithDataSource", mock.Anything).Return(feeds.ORM(svc.orm)) + svc.jobORM.On("WithDataSource", mock.Anything).Return(job.ORM(svc.jobORM)) }, id: spec.ID, force: true, @@ -2508,7 +2625,7 @@ updateInterval = "20m" { name: "spec does not exist", before: func(svc *TestService) { - svc.orm.On("GetSpec", spec.ID, mock.Anything).Return(nil, errors.New("Not Found")) + svc.orm.On("GetSpec", mock.Anything, spec.ID).Return(nil, errors.New("Not Found")) }, id: spec.ID, force: false, @@ -2522,8 +2639,8 @@ updateInterval = "20m" JobProposalID: jp.ID, Status: feeds.SpecStatusApproved, } - svc.orm.On("GetSpec", spec.ID, mock.Anything).Return(aspec, nil) - svc.orm.On("GetJobProposal", jp.ID, mock.Anything).Return(jp, nil) + svc.orm.On("GetSpec", mock.Anything, spec.ID).Return(aspec, nil) + svc.orm.On("GetJobProposal", mock.Anything, jp.ID).Return(jp, nil) }, id: spec.ID, force: false, @@ -2537,8 +2654,8 @@ updateInterval = "20m" JobProposalID: jp.ID, Status: feeds.SpecStatusRejected, } - svc.orm.On("GetSpec", rspec.ID, mock.Anything).Return(rspec, nil) - svc.orm.On("GetJobProposal", jp.ID, mock.Anything).Return(jp, nil) + svc.orm.On("GetSpec", mock.Anything, rspec.ID, mock.Anything).Return(rspec, nil) + svc.orm.On("GetJobProposal", mock.Anything, jp.ID).Return(jp, nil) }, id: spec.ID, force: false, @@ -2547,8 +2664,8 @@ updateInterval = "20m" { name: "job proposal does not exist", before: func(svc *TestService) { - svc.orm.On("GetSpec", spec.ID, mock.Anything).Return(spec, nil) - svc.orm.On("GetJobProposal", jp.ID, mock.Anything).Return(nil, errors.New("Not Found")) + svc.orm.On("GetSpec", mock.Anything, spec.ID).Return(spec, nil) + svc.orm.On("GetJobProposal", mock.Anything, jp.ID).Return(nil, errors.New("Not Found")) }, id: spec.ID, wantErr: "orm: job proposal: Not Found", @@ -2558,8 +2675,8 @@ updateInterval = "20m" httpTimeout: commonconfig.MustNewDuration(1 * time.Minute), before: func(svc *TestService) { svc.connMgr.On("GetClient", jp.FeedsManagerID).Return(svc.fmsClient, nil) - svc.orm.On("GetSpec", spec.ID, mock.Anything).Return(spec, nil) - svc.orm.On("GetJobProposal", jp.ID, mock.Anything).Return(jp, nil) + svc.orm.On("GetSpec", mock.Anything, spec.ID).Return(spec, nil) + svc.orm.On("GetJobProposal", mock.Anything, jp.ID).Return(jp, nil) svc.jobORM.On("AssertBridgesExist", mock.Anything, mock.IsType(pipeline.Pipeline{})).Return(errors.New("bridges do not exist")) }, id: spec.ID, @@ -2568,8 +2685,8 @@ updateInterval = "20m" { name: "rpc client not connected", before: func(svc *TestService) { - svc.orm.On("GetSpec", spec.ID, mock.Anything).Return(spec, nil) - svc.orm.On("GetJobProposal", jp.ID, mock.Anything).Return(jp, nil) + svc.orm.On("GetSpec", mock.Anything, spec.ID).Return(spec, nil) + svc.orm.On("GetJobProposal", mock.Anything, jp.ID).Return(jp, nil) svc.connMgr.On("GetClient", jp.FeedsManagerID).Return(nil, errors.New("Not Connected")) }, id: spec.ID, @@ -2580,22 +2697,25 @@ updateInterval = "20m" name: "create job error", httpTimeout: commonconfig.MustNewDuration(1 * time.Minute), before: func(svc *TestService) { - svc.orm.On("GetSpec", spec.ID, mock.Anything).Return(spec, nil) - svc.orm.On("GetJobProposal", jp.ID, mock.Anything).Return(jp, nil) + svc.orm.On("GetSpec", mock.Anything, spec.ID).Return(spec, nil) + svc.orm.On("GetJobProposal", mock.Anything, jp.ID).Return(jp, nil) svc.connMgr.On("GetClient", jp.FeedsManagerID).Return(svc.fmsClient, nil) svc.jobORM.On("AssertBridgesExist", mock.Anything, mock.IsType(pipeline.Pipeline{})).Return(nil) - svc.jobORM.On("FindJobByExternalJobID", externalJobID, mock.Anything).Return(job.Job{}, sql.ErrNoRows) - svc.jobORM.On("FindOCR2JobIDByAddress", address, (*common.Hash)(nil), mock.Anything).Return(int32(0), sql.ErrNoRows) + svc.jobORM.On("FindJobByExternalJobID", mock.Anything, externalJobID).Return(job.Job{}, sql.ErrNoRows) + svc.jobORM.On("FindOCR2JobIDByAddress", mock.Anything, address, (*common.Hash)(nil)).Return(int32(0), sql.ErrNoRows) svc.spawner. On("CreateJob", + mock.Anything, + mock.Anything, mock.MatchedBy(func(j *job.Job) bool { return j.Name.String == "LINK / ETH | version 3 | contract 0x0000000000000000000000000000000000000000" }), - mock.Anything, ). Return(errors.New("could not save")) + svc.orm.On("WithDataSource", mock.Anything).Return(feeds.ORM(svc.orm)) + svc.jobORM.On("WithDataSource", mock.Anything).Return(job.ORM(svc.jobORM)) }, id: spec.ID, force: false, @@ -2605,28 +2725,31 @@ updateInterval = "20m" name: "approve spec orm error", httpTimeout: commonconfig.MustNewDuration(1 * time.Minute), before: func(svc *TestService) { - svc.orm.On("GetSpec", spec.ID, mock.Anything).Return(spec, nil) - svc.orm.On("GetJobProposal", jp.ID, mock.Anything).Return(jp, nil) + svc.orm.On("GetSpec", mock.Anything, spec.ID).Return(spec, nil) + svc.orm.On("GetJobProposal", mock.Anything, jp.ID).Return(jp, nil) svc.connMgr.On("GetClient", jp.FeedsManagerID).Return(svc.fmsClient, nil) svc.jobORM.On("AssertBridgesExist", mock.Anything, mock.IsType(pipeline.Pipeline{})).Return(nil) - svc.jobORM.On("FindJobByExternalJobID", externalJobID, mock.Anything).Return(job.Job{}, sql.ErrNoRows) - svc.jobORM.On("FindOCR2JobIDByAddress", address, (*common.Hash)(nil), mock.Anything).Return(int32(0), sql.ErrNoRows) + svc.jobORM.On("FindJobByExternalJobID", mock.Anything, externalJobID).Return(job.Job{}, sql.ErrNoRows) + svc.jobORM.On("FindOCR2JobIDByAddress", mock.Anything, address, (*common.Hash)(nil)).Return(int32(0), sql.ErrNoRows) svc.spawner. On("CreateJob", + mock.Anything, + mock.Anything, mock.MatchedBy(func(j *job.Job) bool { return j.Name.String == "LINK / ETH | version 3 | contract 0x0000000000000000000000000000000000000000" }), - mock.Anything, ). - Run(func(args mock.Arguments) { (args.Get(0).(*job.Job)).ID = 1 }). + Run(func(args mock.Arguments) { (args.Get(2).(*job.Job)).ID = 1 }). Return(nil) svc.orm.On("ApproveSpec", + mock.Anything, spec.ID, externalJobID, - mock.Anything, ).Return(errors.New("failure")) + svc.orm.On("WithDataSource", mock.Anything).Return(feeds.ORM(svc.orm)) + svc.jobORM.On("WithDataSource", mock.Anything).Return(job.ORM(svc.jobORM)) }, id: spec.ID, force: false, @@ -2636,27 +2759,28 @@ updateInterval = "20m" name: "fms call error", httpTimeout: commonconfig.MustNewDuration(1 * time.Minute), before: func(svc *TestService) { - svc.orm.On("GetSpec", spec.ID, mock.Anything).Return(spec, nil) - svc.orm.On("GetJobProposal", jp.ID, mock.Anything).Return(jp, nil) + svc.orm.On("GetSpec", mock.Anything, spec.ID).Return(spec, nil) + svc.orm.On("GetJobProposal", mock.Anything, jp.ID).Return(jp, nil) svc.connMgr.On("GetClient", jp.FeedsManagerID).Return(svc.fmsClient, nil) svc.jobORM.On("AssertBridgesExist", mock.Anything, mock.IsType(pipeline.Pipeline{})).Return(nil) - svc.jobORM.On("FindJobByExternalJobID", externalJobID, mock.Anything).Return(job.Job{}, sql.ErrNoRows) - svc.jobORM.On("FindOCR2JobIDByAddress", address, (*common.Hash)(nil), mock.Anything).Return(int32(0), sql.ErrNoRows) + svc.jobORM.On("FindJobByExternalJobID", mock.Anything, externalJobID).Return(job.Job{}, sql.ErrNoRows) + svc.jobORM.On("FindOCR2JobIDByAddress", mock.Anything, address, (*common.Hash)(nil)).Return(int32(0), sql.ErrNoRows) svc.spawner. On("CreateJob", + mock.Anything, + mock.Anything, mock.MatchedBy(func(j *job.Job) bool { return j.Name.String == "LINK / ETH | version 3 | contract 0x0000000000000000000000000000000000000000" }), - mock.Anything, ). - Run(func(args mock.Arguments) { (args.Get(0).(*job.Job)).ID = 1 }). + Run(func(args mock.Arguments) { (args.Get(2).(*job.Job)).ID = 1 }). Return(nil) svc.orm.On("ApproveSpec", + mock.Anything, spec.ID, externalJobID, - mock.Anything, ).Return(nil) svc.fmsClient.On("ApprovedJob", mock.MatchedBy(func(ctx context.Context) bool { return true }), @@ -2665,6 +2789,8 @@ updateInterval = "20m" Version: int64(spec.Version), }, ).Return(nil, errors.New("failure")) + svc.orm.On("WithDataSource", mock.Anything).Return(feeds.ORM(svc.orm)) + svc.jobORM.On("WithDataSource", mock.Anything).Return(job.ORM(svc.jobORM)) }, id: spec.ID, force: false, @@ -2774,25 +2900,26 @@ chainID = 0 httpTimeout: commonconfig.MustNewDuration(1 * time.Minute), before: func(svc *TestService) { svc.connMgr.On("GetClient", jp.FeedsManagerID).Return(svc.fmsClient, nil) - svc.orm.On("GetSpec", spec.ID, mock.Anything).Return(spec, nil) - svc.orm.On("GetJobProposal", jp.ID, mock.Anything).Return(jp, nil) + svc.orm.On("GetSpec", mock.Anything, spec.ID).Return(spec, nil) + svc.orm.On("GetJobProposal", mock.Anything, jp.ID).Return(jp, nil) svc.jobORM.On("AssertBridgesExist", mock.Anything, mock.IsType(pipeline.Pipeline{})).Return(nil) - svc.jobORM.On("FindJobByExternalJobID", externalJobID, mock.Anything).Return(job.Job{}, sql.ErrNoRows) - svc.jobORM.On("FindOCR2JobIDByAddress", address, (*common.Hash)(nil), mock.Anything).Return(int32(0), sql.ErrNoRows) + svc.jobORM.On("FindJobByExternalJobID", mock.Anything, externalJobID).Return(job.Job{}, sql.ErrNoRows) + svc.jobORM.On("FindOCR2JobIDByAddress", mock.Anything, address, (*common.Hash)(nil)).Return(int32(0), sql.ErrNoRows) svc.spawner. On("CreateJob", + mock.Anything, + mock.Anything, mock.MatchedBy(func(j *job.Job) bool { return j.Name.String == "LINK / ETH | version 3 | contract 0x0000000000000000000000000000000000000000" }), - mock.Anything, ). - Run(func(args mock.Arguments) { (args.Get(0).(*job.Job)).ID = 1 }). + Run(func(args mock.Arguments) { (args.Get(2).(*job.Job)).ID = 1 }). Return(nil) svc.orm.On("ApproveSpec", + mock.Anything, spec.ID, externalJobID, - mock.Anything, ).Return(nil) svc.fmsClient.On("ApprovedJob", mock.MatchedBy(func(ctx context.Context) bool { return true }), @@ -2801,7 +2928,9 @@ chainID = 0 Version: int64(spec.Version), }, ).Return(&proto.ApprovedJobResponse{}, nil) - svc.orm.On("CountJobProposalsByStatus").Return(&feeds.JobProposalCounts{}, nil) + svc.orm.On("CountJobProposalsByStatus", mock.Anything).Return(&feeds.JobProposalCounts{}, nil) + svc.orm.On("WithDataSource", mock.Anything).Return(feeds.ORM(svc.orm)) + svc.jobORM.On("WithDataSource", mock.Anything).Return(job.ORM(svc.jobORM)) }, id: spec.ID, force: false, @@ -2811,27 +2940,28 @@ chainID = 0 httpTimeout: commonconfig.MustNewDuration(1 * time.Minute), before: func(svc *TestService) { svc.connMgr.On("GetClient", jp.FeedsManagerID).Return(svc.fmsClient, nil) - svc.orm.On("GetSpec", cancelledSpec.ID, mock.Anything).Return(cancelledSpec, nil) - svc.orm.On("GetJobProposal", jp.ID, mock.Anything).Return(jp, nil) - svc.orm.On("GetLatestSpec", cancelledSpec.JobProposalID).Return(cancelledSpec, nil) + svc.orm.On("GetSpec", mock.Anything, cancelledSpec.ID, mock.Anything).Return(cancelledSpec, nil) + svc.orm.On("GetJobProposal", mock.Anything, jp.ID).Return(jp, nil) + svc.orm.On("GetLatestSpec", mock.Anything, cancelledSpec.JobProposalID).Return(cancelledSpec, nil) svc.jobORM.On("AssertBridgesExist", mock.Anything, mock.IsType(pipeline.Pipeline{})).Return(nil) - svc.jobORM.On("FindJobByExternalJobID", externalJobID, mock.Anything).Return(job.Job{}, sql.ErrNoRows) - svc.jobORM.On("FindOCR2JobIDByAddress", address, (*common.Hash)(nil), mock.Anything).Return(int32(0), sql.ErrNoRows) + svc.jobORM.On("FindJobByExternalJobID", mock.Anything, externalJobID).Return(job.Job{}, sql.ErrNoRows) + svc.jobORM.On("FindOCR2JobIDByAddress", mock.Anything, address, (*common.Hash)(nil)).Return(int32(0), sql.ErrNoRows) svc.spawner. On("CreateJob", + mock.Anything, + mock.Anything, mock.MatchedBy(func(j *job.Job) bool { return j.Name.String == "LINK / ETH | version 3 | contract 0x0000000000000000000000000000000000000000" }), - mock.Anything, ). - Run(func(args mock.Arguments) { (args.Get(0).(*job.Job)).ID = 1 }). + Run(func(args mock.Arguments) { (args.Get(2).(*job.Job)).ID = 1 }). Return(nil) svc.orm.On("ApproveSpec", + mock.Anything, cancelledSpec.ID, externalJobID, - mock.Anything, ).Return(nil) svc.fmsClient.On("ApprovedJob", mock.MatchedBy(func(ctx context.Context) bool { return true }), @@ -2840,7 +2970,9 @@ chainID = 0 Version: int64(spec.Version), }, ).Return(&proto.ApprovedJobResponse{}, nil) - svc.orm.On("CountJobProposalsByStatus").Return(&feeds.JobProposalCounts{}, nil) + svc.orm.On("CountJobProposalsByStatus", mock.Anything).Return(&feeds.JobProposalCounts{}, nil) + svc.orm.On("WithDataSource", mock.Anything).Return(feeds.ORM(svc.orm)) + svc.jobORM.On("WithDataSource", mock.Anything).Return(job.ORM(svc.jobORM)) }, id: cancelledSpec.ID, force: false, @@ -2848,9 +2980,9 @@ chainID = 0 { name: "cancelled spec failed not latest spec", before: func(svc *TestService) { - svc.orm.On("GetSpec", cancelledSpec.ID, mock.Anything).Return(cancelledSpec, nil) - svc.orm.On("GetJobProposal", jp.ID, mock.Anything).Return(jp, nil) - svc.orm.On("GetLatestSpec", cancelledSpec.JobProposalID).Return(&feeds.JobProposalSpec{ + svc.orm.On("GetSpec", mock.Anything, cancelledSpec.ID, mock.Anything).Return(cancelledSpec, nil) + svc.orm.On("GetJobProposal", mock.Anything, jp.ID).Return(jp, nil) + svc.orm.On("GetLatestSpec", mock.Anything, cancelledSpec.JobProposalID).Return(&feeds.JobProposalSpec{ ID: 21, Status: feeds.SpecStatusPending, JobProposalID: jp.ID, @@ -2865,8 +2997,8 @@ chainID = 0 { name: "rejected spec failed cannot be approved", before: func(svc *TestService) { - svc.orm.On("GetSpec", cancelledSpec.ID, mock.Anything).Return(rejectedSpec, nil) - svc.orm.On("GetJobProposal", jp.ID, mock.Anything).Return(jp, nil) + svc.orm.On("GetSpec", mock.Anything, cancelledSpec.ID, mock.Anything).Return(rejectedSpec, nil) + svc.orm.On("GetJobProposal", mock.Anything, jp.ID).Return(jp, nil) }, id: rejectedSpec.ID, force: false, @@ -2877,11 +3009,13 @@ chainID = 0 httpTimeout: commonconfig.MustNewDuration(1 * time.Minute), before: func(svc *TestService) { svc.connMgr.On("GetClient", jp.FeedsManagerID).Return(svc.fmsClient, nil) - svc.orm.On("GetSpec", spec.ID, mock.Anything).Return(spec, nil) - svc.orm.On("GetJobProposal", jp.ID, mock.Anything).Return(jp, nil) + svc.orm.On("GetSpec", mock.Anything, spec.ID).Return(spec, nil) + svc.orm.On("GetJobProposal", mock.Anything, jp.ID).Return(jp, nil) svc.jobORM.On("AssertBridgesExist", mock.Anything, mock.IsType(pipeline.Pipeline{})).Return(nil) - svc.jobORM.On("FindJobByExternalJobID", externalJobID, mock.Anything).Return(job.Job{}, sql.ErrNoRows) - svc.jobORM.On("FindOCR2JobIDByAddress", address, (*common.Hash)(nil), mock.Anything).Return(j.ID, nil) + svc.jobORM.On("FindJobByExternalJobID", mock.Anything, externalJobID).Return(job.Job{}, sql.ErrNoRows) + svc.jobORM.On("FindOCR2JobIDByAddress", mock.Anything, address, (*common.Hash)(nil)).Return(j.ID, nil) + svc.orm.On("WithDataSource", mock.Anything).Return(feeds.ORM(svc.orm)) + svc.jobORM.On("WithDataSource", mock.Anything).Return(job.ORM(svc.jobORM)) }, id: spec.ID, force: false, @@ -2892,27 +3026,28 @@ chainID = 0 httpTimeout: commonconfig.MustNewDuration(1 * time.Minute), before: func(svc *TestService) { svc.connMgr.On("GetClient", jp.FeedsManagerID).Return(svc.fmsClient, nil) - svc.orm.On("GetSpec", spec.ID, mock.Anything).Return(spec, nil) - svc.orm.On("GetJobProposal", jp.ID, mock.Anything).Return(jp, nil) + svc.orm.On("GetSpec", mock.Anything, spec.ID).Return(spec, nil) + svc.orm.On("GetJobProposal", mock.Anything, jp.ID).Return(jp, nil) svc.jobORM.On("AssertBridgesExist", mock.Anything, mock.IsType(pipeline.Pipeline{})).Return(nil) - svc.orm.EXPECT().GetApprovedSpec(jp.ID, mock.Anything).Return(nil, sql.ErrNoRows) - svc.jobORM.On("FindJobByExternalJobID", externalJobID, mock.Anything).Return(job.Job{}, sql.ErrNoRows) - svc.jobORM.On("FindOCR2JobIDByAddress", address, (*common.Hash)(nil), mock.Anything).Return(j.ID, nil) - svc.spawner.On("DeleteJob", j.ID, mock.Anything).Return(nil) + svc.orm.EXPECT().GetApprovedSpec(mock.Anything, jp.ID).Return(nil, sql.ErrNoRows) + svc.jobORM.On("FindJobByExternalJobID", mock.Anything, externalJobID).Return(job.Job{}, sql.ErrNoRows) + svc.jobORM.On("FindOCR2JobIDByAddress", mock.Anything, address, (*common.Hash)(nil)).Return(j.ID, nil) + svc.spawner.On("DeleteJob", mock.Anything, mock.Anything, j.ID).Return(nil) svc.spawner. On("CreateJob", + mock.Anything, + mock.Anything, mock.MatchedBy(func(j *job.Job) bool { return j.Name.String == "LINK / ETH | version 3 | contract 0x0000000000000000000000000000000000000000" }), - mock.Anything, ). - Run(func(args mock.Arguments) { (args.Get(0).(*job.Job)).ID = 1 }). + Run(func(args mock.Arguments) { (args.Get(2).(*job.Job)).ID = 1 }). Return(nil) svc.orm.On("ApproveSpec", + mock.Anything, spec.ID, externalJobID, - mock.Anything, ).Return(nil) svc.fmsClient.On("ApprovedJob", mock.MatchedBy(func(ctx context.Context) bool { return true }), @@ -2921,7 +3056,9 @@ chainID = 0 Version: int64(spec.Version), }, ).Return(&proto.ApprovedJobResponse{}, nil) - svc.orm.On("CountJobProposalsByStatus").Return(&feeds.JobProposalCounts{}, nil) + svc.orm.On("CountJobProposalsByStatus", mock.Anything).Return(&feeds.JobProposalCounts{}, nil) + svc.orm.On("WithDataSource", mock.Anything).Return(feeds.ORM(svc.orm)) + svc.jobORM.On("WithDataSource", mock.Anything).Return(job.ORM(svc.jobORM)) }, id: spec.ID, force: true, @@ -2931,33 +3068,34 @@ chainID = 0 httpTimeout: commonconfig.MustNewDuration(1 * time.Minute), before: func(svc *TestService) { svc.connMgr.On("GetClient", jp.FeedsManagerID).Return(svc.fmsClient, nil) - svc.orm.On("GetSpec", spec.ID, mock.Anything).Return(&feeds.JobProposalSpec{ + svc.orm.On("GetSpec", mock.Anything, spec.ID).Return(&feeds.JobProposalSpec{ ID: 20, Status: feeds.SpecStatusPending, JobProposalID: jp.ID, Version: 1, Definition: fmt.Sprintf(defn2, externalJobID.String(), feedID), }, nil) - svc.orm.On("GetJobProposal", jp.ID, mock.Anything).Return(jp, nil) + svc.orm.On("GetJobProposal", mock.Anything, jp.ID).Return(jp, nil) svc.jobORM.On("AssertBridgesExist", mock.Anything, mock.IsType(pipeline.Pipeline{})).Return(nil) - svc.orm.EXPECT().GetApprovedSpec(jp.ID, mock.Anything).Return(nil, sql.ErrNoRows) - svc.jobORM.On("FindJobByExternalJobID", externalJobID, mock.Anything).Return(job.Job{}, sql.ErrNoRows) - svc.jobORM.On("FindOCR2JobIDByAddress", address, &feedID, mock.Anything).Return(j.ID, nil) - svc.spawner.On("DeleteJob", j.ID, mock.Anything).Return(nil) + svc.orm.EXPECT().GetApprovedSpec(mock.Anything, jp.ID).Return(nil, sql.ErrNoRows) + svc.jobORM.On("FindJobByExternalJobID", mock.Anything, externalJobID).Return(job.Job{}, sql.ErrNoRows) + svc.jobORM.On("FindOCR2JobIDByAddress", mock.Anything, address, &feedID).Return(j.ID, nil) + svc.spawner.On("DeleteJob", mock.Anything, mock.Anything, j.ID).Return(nil) svc.spawner. On("CreateJob", + mock.Anything, + mock.Anything, mock.MatchedBy(func(j *job.Job) bool { return j.Name.String == "LINK / ETH | version 3 | contract 0x0000000000000000000000000000000000000000" }), - mock.Anything, ). - Run(func(args mock.Arguments) { (args.Get(0).(*job.Job)).ID = 1 }). + Run(func(args mock.Arguments) { (args.Get(2).(*job.Job)).ID = 1 }). Return(nil) svc.orm.On("ApproveSpec", + mock.Anything, spec.ID, externalJobID, - mock.Anything, ).Return(nil) svc.fmsClient.On("ApprovedJob", mock.MatchedBy(func(ctx context.Context) bool { return true }), @@ -2966,7 +3104,9 @@ chainID = 0 Version: int64(spec.Version), }, ).Return(&proto.ApprovedJobResponse{}, nil) - svc.orm.On("CountJobProposalsByStatus").Return(&feeds.JobProposalCounts{}, nil) + svc.orm.On("CountJobProposalsByStatus", mock.Anything).Return(&feeds.JobProposalCounts{}, nil) + svc.orm.On("WithDataSource", mock.Anything).Return(feeds.ORM(svc.orm)) + svc.jobORM.On("WithDataSource", mock.Anything).Return(job.ORM(svc.jobORM)) }, id: spec.ID, force: true, @@ -2976,28 +3116,29 @@ chainID = 0 httpTimeout: commonconfig.MustNewDuration(1 * time.Minute), before: func(svc *TestService) { svc.connMgr.On("GetClient", jp.FeedsManagerID).Return(svc.fmsClient, nil) - svc.orm.On("GetSpec", spec.ID, mock.Anything).Return(spec, nil) - svc.orm.On("GetJobProposal", jp.ID, mock.Anything).Return(jp, nil) + svc.orm.On("GetSpec", mock.Anything, spec.ID).Return(spec, nil) + svc.orm.On("GetJobProposal", mock.Anything, jp.ID).Return(jp, nil) svc.jobORM.On("AssertBridgesExist", mock.Anything, mock.IsType(pipeline.Pipeline{})).Return(nil) - svc.orm.EXPECT().GetApprovedSpec(jp.ID, mock.Anything).Return(&feeds.JobProposalSpec{ID: 100}, nil) - svc.orm.EXPECT().CancelSpec(int64(100), mock.Anything).Return(nil) - svc.jobORM.On("FindJobByExternalJobID", externalJobID, mock.Anything).Return(job.Job{}, sql.ErrNoRows) - svc.jobORM.On("FindOCR2JobIDByAddress", address, (*common.Hash)(nil), mock.Anything).Return(j.ID, nil) - svc.spawner.On("DeleteJob", j.ID, mock.Anything).Return(nil) + svc.orm.EXPECT().GetApprovedSpec(mock.Anything, jp.ID).Return(&feeds.JobProposalSpec{ID: 100}, nil) + svc.orm.EXPECT().CancelSpec(mock.Anything, int64(100)).Return(nil) + svc.jobORM.On("FindJobByExternalJobID", mock.Anything, externalJobID).Return(job.Job{}, sql.ErrNoRows) + svc.jobORM.On("FindOCR2JobIDByAddress", mock.Anything, address, (*common.Hash)(nil)).Return(j.ID, nil) + svc.spawner.On("DeleteJob", mock.Anything, mock.Anything, j.ID).Return(nil) svc.spawner. On("CreateJob", + mock.Anything, + mock.Anything, mock.MatchedBy(func(j *job.Job) bool { return j.Name.String == "LINK / ETH | version 3 | contract 0x0000000000000000000000000000000000000000" }), - mock.Anything, ). - Run(func(args mock.Arguments) { (args.Get(0).(*job.Job)).ID = 1 }). + Run(func(args mock.Arguments) { (args.Get(2).(*job.Job)).ID = 1 }). Return(nil) svc.orm.On("ApproveSpec", + mock.Anything, spec.ID, externalJobID, - mock.Anything, ).Return(nil) svc.fmsClient.On("ApprovedJob", mock.MatchedBy(func(ctx context.Context) bool { return true }), @@ -3006,7 +3147,9 @@ chainID = 0 Version: int64(spec.Version), }, ).Return(&proto.ApprovedJobResponse{}, nil) - svc.orm.On("CountJobProposalsByStatus").Return(&feeds.JobProposalCounts{}, nil) + svc.orm.On("CountJobProposalsByStatus", mock.Anything).Return(&feeds.JobProposalCounts{}, nil) + svc.orm.On("WithDataSource", mock.Anything).Return(feeds.ORM(svc.orm)) + svc.jobORM.On("WithDataSource", mock.Anything).Return(job.ORM(svc.jobORM)) }, id: spec.ID, force: true, @@ -3014,7 +3157,7 @@ chainID = 0 { name: "spec does not exist", before: func(svc *TestService) { - svc.orm.On("GetSpec", spec.ID, mock.Anything).Return(nil, errors.New("Not Found")) + svc.orm.On("GetSpec", mock.Anything, spec.ID).Return(nil, errors.New("Not Found")) }, id: spec.ID, force: false, @@ -3028,8 +3171,8 @@ chainID = 0 JobProposalID: jp.ID, Status: feeds.SpecStatusApproved, } - svc.orm.On("GetSpec", spec.ID, mock.Anything).Return(aspec, nil) - svc.orm.On("GetJobProposal", jp.ID, mock.Anything).Return(jp, nil) + svc.orm.On("GetSpec", mock.Anything, spec.ID).Return(aspec, nil) + svc.orm.On("GetJobProposal", mock.Anything, jp.ID).Return(jp, nil) }, id: spec.ID, force: false, @@ -3043,8 +3186,8 @@ chainID = 0 JobProposalID: jp.ID, Status: feeds.SpecStatusRejected, } - svc.orm.On("GetSpec", rspec.ID, mock.Anything).Return(rspec, nil) - svc.orm.On("GetJobProposal", jp.ID, mock.Anything).Return(jp, nil) + svc.orm.On("GetSpec", mock.Anything, rspec.ID).Return(rspec, nil) + svc.orm.On("GetJobProposal", mock.Anything, jp.ID).Return(jp, nil) }, id: spec.ID, force: false, @@ -3053,8 +3196,8 @@ chainID = 0 { name: "job proposal does not exist", before: func(svc *TestService) { - svc.orm.On("GetSpec", spec.ID, mock.Anything).Return(spec, nil) - svc.orm.On("GetJobProposal", jp.ID, mock.Anything).Return(nil, errors.New("Not Found")) + svc.orm.On("GetSpec", mock.Anything, spec.ID).Return(spec, nil) + svc.orm.On("GetJobProposal", mock.Anything, jp.ID).Return(nil, errors.New("Not Found")) }, id: spec.ID, wantErr: "orm: job proposal: Not Found", @@ -3064,8 +3207,8 @@ chainID = 0 httpTimeout: commonconfig.MustNewDuration(1 * time.Minute), before: func(svc *TestService) { svc.connMgr.On("GetClient", jp.FeedsManagerID).Return(svc.fmsClient, nil) - svc.orm.On("GetSpec", spec.ID, mock.Anything).Return(spec, nil) - svc.orm.On("GetJobProposal", jp.ID, mock.Anything).Return(jp, nil) + svc.orm.On("GetSpec", mock.Anything, spec.ID).Return(spec, nil) + svc.orm.On("GetJobProposal", mock.Anything, jp.ID).Return(jp, nil) svc.jobORM.On("AssertBridgesExist", mock.Anything, mock.IsType(pipeline.Pipeline{})).Return(errors.New("bridges do not exist")) }, id: spec.ID, @@ -3074,8 +3217,8 @@ chainID = 0 { name: "rpc client not connected", before: func(svc *TestService) { - svc.orm.On("GetSpec", spec.ID, mock.Anything).Return(spec, nil) - svc.orm.On("GetJobProposal", jp.ID, mock.Anything).Return(jp, nil) + svc.orm.On("GetSpec", mock.Anything, spec.ID).Return(spec, nil) + svc.orm.On("GetJobProposal", mock.Anything, jp.ID).Return(jp, nil) svc.connMgr.On("GetClient", jp.FeedsManagerID).Return(nil, errors.New("Not Connected")) }, id: spec.ID, @@ -3086,22 +3229,25 @@ chainID = 0 name: "create job error", httpTimeout: commonconfig.MustNewDuration(1 * time.Minute), before: func(svc *TestService) { - svc.orm.On("GetSpec", spec.ID, mock.Anything).Return(spec, nil) - svc.orm.On("GetJobProposal", jp.ID, mock.Anything).Return(jp, nil) + svc.orm.On("GetSpec", mock.Anything, spec.ID).Return(spec, nil) + svc.orm.On("GetJobProposal", mock.Anything, jp.ID).Return(jp, nil) svc.connMgr.On("GetClient", jp.FeedsManagerID).Return(svc.fmsClient, nil) svc.jobORM.On("AssertBridgesExist", mock.Anything, mock.IsType(pipeline.Pipeline{})).Return(nil) - svc.jobORM.On("FindJobByExternalJobID", externalJobID, mock.Anything).Return(job.Job{}, sql.ErrNoRows) - svc.jobORM.On("FindOCR2JobIDByAddress", address, (*common.Hash)(nil), mock.Anything).Return(int32(0), sql.ErrNoRows) + svc.jobORM.On("FindJobByExternalJobID", mock.Anything, externalJobID).Return(job.Job{}, sql.ErrNoRows) + svc.jobORM.On("FindOCR2JobIDByAddress", mock.Anything, address, (*common.Hash)(nil)).Return(int32(0), sql.ErrNoRows) svc.spawner. On("CreateJob", + mock.Anything, + mock.Anything, mock.MatchedBy(func(j *job.Job) bool { return j.Name.String == "LINK / ETH | version 3 | contract 0x0000000000000000000000000000000000000000" }), - mock.Anything, ). Return(errors.New("could not save")) + svc.orm.On("WithDataSource", mock.Anything).Return(feeds.ORM(svc.orm)) + svc.jobORM.On("WithDataSource", mock.Anything).Return(job.ORM(svc.jobORM)) }, id: spec.ID, force: false, @@ -3111,28 +3257,31 @@ chainID = 0 name: "approve spec orm error", httpTimeout: commonconfig.MustNewDuration(1 * time.Minute), before: func(svc *TestService) { - svc.orm.On("GetSpec", spec.ID, mock.Anything).Return(spec, nil) - svc.orm.On("GetJobProposal", jp.ID, mock.Anything).Return(jp, nil) + svc.orm.On("GetSpec", mock.Anything, spec.ID).Return(spec, nil) + svc.orm.On("GetJobProposal", mock.Anything, jp.ID).Return(jp, nil) svc.connMgr.On("GetClient", jp.FeedsManagerID).Return(svc.fmsClient, nil) svc.jobORM.On("AssertBridgesExist", mock.Anything, mock.IsType(pipeline.Pipeline{})).Return(nil) - svc.jobORM.On("FindJobByExternalJobID", externalJobID, mock.Anything).Return(job.Job{}, sql.ErrNoRows) - svc.jobORM.On("FindOCR2JobIDByAddress", address, (*common.Hash)(nil), mock.Anything).Return(int32(0), sql.ErrNoRows) + svc.jobORM.On("FindJobByExternalJobID", mock.Anything, externalJobID).Return(job.Job{}, sql.ErrNoRows) + svc.jobORM.On("FindOCR2JobIDByAddress", mock.Anything, address, (*common.Hash)(nil), mock.Anything).Return(int32(0), sql.ErrNoRows) svc.spawner. On("CreateJob", + mock.Anything, + mock.Anything, mock.MatchedBy(func(j *job.Job) bool { return j.Name.String == "LINK / ETH | version 3 | contract 0x0000000000000000000000000000000000000000" }), - mock.Anything, ). - Run(func(args mock.Arguments) { (args.Get(0).(*job.Job)).ID = 1 }). + Run(func(args mock.Arguments) { (args.Get(2).(*job.Job)).ID = 1 }). Return(nil) svc.orm.On("ApproveSpec", + mock.Anything, spec.ID, externalJobID, - mock.Anything, ).Return(errors.New("failure")) + svc.orm.On("WithDataSource", mock.Anything).Return(feeds.ORM(svc.orm)) + svc.jobORM.On("WithDataSource", mock.Anything).Return(job.ORM(svc.jobORM)) }, id: spec.ID, force: false, @@ -3142,27 +3291,28 @@ chainID = 0 name: "fms call error", httpTimeout: commonconfig.MustNewDuration(1 * time.Minute), before: func(svc *TestService) { - svc.orm.On("GetSpec", spec.ID, mock.Anything).Return(spec, nil) - svc.orm.On("GetJobProposal", jp.ID, mock.Anything).Return(jp, nil) + svc.orm.On("GetSpec", mock.Anything, spec.ID).Return(spec, nil) + svc.orm.On("GetJobProposal", mock.Anything, jp.ID).Return(jp, nil) svc.connMgr.On("GetClient", jp.FeedsManagerID).Return(svc.fmsClient, nil) svc.jobORM.On("AssertBridgesExist", mock.Anything, mock.IsType(pipeline.Pipeline{})).Return(nil) - svc.jobORM.On("FindJobByExternalJobID", externalJobID, mock.Anything).Return(job.Job{}, sql.ErrNoRows) - svc.jobORM.On("FindOCR2JobIDByAddress", address, (*common.Hash)(nil), mock.Anything).Return(int32(0), sql.ErrNoRows) + svc.jobORM.On("FindJobByExternalJobID", mock.Anything, externalJobID).Return(job.Job{}, sql.ErrNoRows) + svc.jobORM.On("FindOCR2JobIDByAddress", mock.Anything, address, (*common.Hash)(nil)).Return(int32(0), sql.ErrNoRows) svc.spawner. On("CreateJob", + mock.Anything, + mock.Anything, mock.MatchedBy(func(j *job.Job) bool { return j.Name.String == "LINK / ETH | version 3 | contract 0x0000000000000000000000000000000000000000" }), - mock.Anything, ). - Run(func(args mock.Arguments) { (args.Get(0).(*job.Job)).ID = 1 }). + Run(func(args mock.Arguments) { (args.Get(2).(*job.Job)).ID = 1 }). Return(nil) svc.orm.On("ApproveSpec", + mock.Anything, spec.ID, externalJobID, - mock.Anything, ).Return(nil) svc.fmsClient.On("ApprovedJob", mock.MatchedBy(func(ctx context.Context) bool { return true }), @@ -3171,6 +3321,8 @@ chainID = 0 Version: int64(spec.Version), }, ).Return(nil, errors.New("failure")) + svc.orm.On("WithDataSource", mock.Anything).Return(feeds.ORM(svc.orm)) + svc.jobORM.On("WithDataSource", mock.Anything).Return(job.ORM(svc.jobORM)) }, id: spec.ID, force: false, @@ -3227,12 +3379,12 @@ func Test_Service_RejectSpec(t *testing.T) { { name: "Success", before: func(svc *TestService) { - svc.orm.On("GetSpec", spec.ID, mock.Anything).Return(spec, nil) - svc.orm.On("GetJobProposal", jp.ID, mock.Anything).Return(jp, nil) + svc.orm.On("GetSpec", mock.Anything, spec.ID).Return(spec, nil) + svc.orm.On("GetJobProposal", mock.Anything, jp.ID).Return(jp, nil) svc.connMgr.On("GetClient", jp.FeedsManagerID).Return(svc.fmsClient, nil) svc.orm.On("RejectSpec", - spec.ID, mock.Anything, + spec.ID, ).Return(nil) svc.fmsClient.On("RejectedJob", mock.MatchedBy(func(ctx context.Context) bool { return true }), @@ -3241,20 +3393,25 @@ func Test_Service_RejectSpec(t *testing.T) { Version: int64(spec.Version), }, ).Return(&proto.RejectedJobResponse{}, nil) - svc.orm.On("CountJobProposalsByStatus").Return(&feeds.JobProposalCounts{}, nil) + svc.orm.On("CountJobProposalsByStatus", mock.Anything).Return(&feeds.JobProposalCounts{}, nil) + transactCall := svc.orm.On("Transact", mock.Anything, mock.Anything) + transactCall.Run(func(args mock.Arguments) { + fn := args[1].(func(orm feeds.ORM) error) + transactCall.ReturnArguments = mock.Arguments{fn(svc.orm)} + }) }, }, { name: "Fails to get spec", before: func(svc *TestService) { - svc.orm.On("GetSpec", spec.ID, mock.Anything).Return(nil, errors.New("failure")) + svc.orm.On("GetSpec", mock.Anything, spec.ID).Return(nil, errors.New("failure")) }, wantErr: "failure", }, { name: "Cannot be a rejected proposal", before: func(svc *TestService) { - svc.orm.On("GetSpec", spec.ID, mock.Anything).Return(&feeds.JobProposalSpec{ + svc.orm.On("GetSpec", mock.Anything, spec.ID).Return(&feeds.JobProposalSpec{ Status: feeds.SpecStatusRejected, }, nil) }, @@ -3263,16 +3420,16 @@ func Test_Service_RejectSpec(t *testing.T) { { name: "Fails to get proposal", before: func(svc *TestService) { - svc.orm.On("GetSpec", spec.ID, mock.Anything).Return(spec, nil) - svc.orm.On("GetJobProposal", jp.ID, mock.Anything).Return(nil, errors.New("failure")) + svc.orm.On("GetSpec", mock.Anything, spec.ID).Return(spec, nil) + svc.orm.On("GetJobProposal", mock.Anything, jp.ID).Return(nil, errors.New("failure")) }, wantErr: "failure", }, { name: "FMS not connected", before: func(svc *TestService) { - svc.orm.On("GetSpec", spec.ID, mock.Anything).Return(spec, nil) - svc.orm.On("GetJobProposal", jp.ID, mock.Anything).Return(jp, nil) + svc.orm.On("GetSpec", mock.Anything, spec.ID).Return(spec, nil) + svc.orm.On("GetJobProposal", mock.Anything, jp.ID).Return(jp, nil) svc.connMgr.On("GetClient", jp.FeedsManagerID).Return(nil, errors.New("disconnected")) }, wantErr: "disconnected", @@ -3280,18 +3437,23 @@ func Test_Service_RejectSpec(t *testing.T) { { name: "Fails to update spec", before: func(svc *TestService) { - svc.orm.On("GetSpec", spec.ID, mock.Anything).Return(spec, nil) - svc.orm.On("GetJobProposal", jp.ID, mock.Anything).Return(jp, nil) + svc.orm.On("GetSpec", mock.Anything, spec.ID).Return(spec, nil) + svc.orm.On("GetJobProposal", mock.Anything, jp.ID).Return(jp, nil) svc.connMgr.On("GetClient", jp.FeedsManagerID).Return(svc.fmsClient, nil) svc.orm.On("RejectSpec", mock.Anything, mock.Anything).Return(errors.New("failure")) + transactCall := svc.orm.On("Transact", mock.Anything, mock.Anything) + transactCall.Run(func(args mock.Arguments) { + fn := args[1].(func(orm feeds.ORM) error) + transactCall.ReturnArguments = mock.Arguments{fn(svc.orm)} + }) }, wantErr: "failure", }, { name: "Fails to update spec", before: func(svc *TestService) { - svc.orm.On("GetSpec", spec.ID, mock.Anything).Return(spec, nil) - svc.orm.On("GetJobProposal", jp.ID, mock.Anything).Return(jp, nil) + svc.orm.On("GetSpec", mock.Anything, spec.ID).Return(spec, nil) + svc.orm.On("GetJobProposal", mock.Anything, jp.ID).Return(jp, nil) svc.connMgr.On("GetClient", jp.FeedsManagerID).Return(svc.fmsClient, nil) svc.orm.On("RejectSpec", mock.Anything, mock.Anything).Return(nil) svc.fmsClient. @@ -3302,6 +3464,11 @@ func Test_Service_RejectSpec(t *testing.T) { Version: int64(spec.Version), }). Return(nil, errors.New("rpc failure")) + transactCall := svc.orm.On("Transact", mock.Anything, mock.Anything) + transactCall.Run(func(args mock.Arguments) { + fn := args[1].(func(orm feeds.ORM) error) + transactCall.ReturnArguments = mock.Arguments{fn(svc.orm)} + }) }, wantErr: "rpc failure", }, @@ -3352,9 +3519,9 @@ func Test_Service_UpdateSpecDefinition(t *testing.T) { name: "success", before: func(svc *TestService) { svc.orm. - On("GetSpec", specID, mock.Anything). + On("GetSpec", mock.Anything, specID, mock.Anything). Return(spec, nil) - svc.orm.On("UpdateSpecDefinition", + svc.orm.On("UpdateSpecDefinition", mock.Anything, specID, updatedSpec, mock.Anything, @@ -3366,7 +3533,7 @@ func Test_Service_UpdateSpecDefinition(t *testing.T) { name: "does not exist", before: func(svc *TestService) { svc.orm. - On("GetSpec", specID, mock.Anything). + On("GetSpec", mock.Anything, specID, mock.Anything). Return(nil, sql.ErrNoRows) }, specID: specID, @@ -3376,7 +3543,7 @@ func Test_Service_UpdateSpecDefinition(t *testing.T) { name: "other get errors", before: func(svc *TestService) { svc.orm. - On("GetSpec", specID, mock.Anything). + On("GetSpec", mock.Anything, specID, mock.Anything). Return(nil, errors.New("other db error")) }, specID: specID, @@ -3391,7 +3558,7 @@ func Test_Service_UpdateSpecDefinition(t *testing.T) { } svc.orm. - On("GetSpec", specID, mock.Anything). + On("GetSpec", mock.Anything, specID, mock.Anything). Return(spec, nil) }, specID: specID, @@ -3443,18 +3610,18 @@ func Test_Service_StartStop(t *testing.T) { name: "success with a feeds manager connection", beforeFunc: func(svc *TestService) { svc.csaKeystore.On("GetAll").Return([]csakey.KeyV2{key}, nil) - svc.orm.On("ListManagers").Return([]feeds.FeedsManager{mgr}, nil) + svc.orm.On("ListManagers", mock.Anything).Return([]feeds.FeedsManager{mgr}, nil) svc.connMgr.On("IsConnected", mgr.ID).Return(false) svc.connMgr.On("Connect", mock.IsType(feeds.ConnectOpts{})) svc.connMgr.On("Close") - svc.orm.On("CountJobProposalsByStatus").Return(&feeds.JobProposalCounts{}, nil) + svc.orm.On("CountJobProposalsByStatus", mock.Anything).Return(&feeds.JobProposalCounts{}, nil) }, }, { name: "success with no registered managers", beforeFunc: func(svc *TestService) { svc.csaKeystore.On("GetAll").Return([]csakey.KeyV2{key}, nil) - svc.orm.On("ListManagers").Return([]feeds.FeedsManager{}, nil) + svc.orm.On("ListManagers", mock.Anything).Return([]feeds.FeedsManager{}, nil) svc.connMgr.On("Close") }, }, diff --git a/core/services/fluxmonitorv2/delegate.go b/core/services/fluxmonitorv2/delegate.go index b7b0df77cc4..1f0b72877ff 100644 --- a/core/services/fluxmonitorv2/delegate.go +++ b/core/services/fluxmonitorv2/delegate.go @@ -5,8 +5,7 @@ import ( "github.com/pkg/errors" - "github.com/jmoiron/sqlx" - + "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" txmgrcommon "github.com/smartcontractkit/chainlink/v2/common/txmgr" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/txmgr" "github.com/smartcontractkit/chainlink/v2/core/chains/legacyevm" @@ -25,7 +24,7 @@ type DelegateConfig interface { // Delegate represents a Flux Monitor delegate type Delegate struct { cfg DelegateConfig - db *sqlx.DB + ds sqlutil.DataSource ethKeyStore keystore.Eth jobORM job.ORM pipelineORM pipeline.ORM @@ -43,13 +42,13 @@ func NewDelegate( jobORM job.ORM, pipelineORM pipeline.ORM, pipelineRunner pipeline.Runner, - db *sqlx.DB, + ds sqlutil.DataSource, legacyChains legacyevm.LegacyChainContainer, lggr logger.Logger, ) *Delegate { return &Delegate{ cfg: cfg, - db: db, + ds: ds, ethKeyStore: ethKeyStore, jobORM: jobORM, pipelineORM: pipelineORM, @@ -86,8 +85,8 @@ func (d *Delegate) ServicesForSpec(ctx context.Context, jb job.Job) (services [] fm, err := NewFromJobSpec( jb, - d.db, - NewORM(d.db, d.lggr, chain.TxManager(), strategy, checker), + d.ds, + NewORM(d.ds, d.lggr, chain.TxManager(), strategy, checker), d.jobORM, d.pipelineORM, NewKeyStore(d.ethKeyStore), diff --git a/core/services/fluxmonitorv2/flux_monitor.go b/core/services/fluxmonitorv2/flux_monitor.go index 5eebb319030..dd30156e15e 100644 --- a/core/services/fluxmonitorv2/flux_monitor.go +++ b/core/services/fluxmonitorv2/flux_monitor.go @@ -13,8 +13,6 @@ import ( "github.com/pkg/errors" "github.com/shopspring/decimal" - "github.com/jmoiron/sqlx" - "github.com/smartcontractkit/chainlink-common/pkg/services" "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" @@ -149,7 +147,7 @@ func NewFluxMonitor( // validation. func NewFromJobSpec( jobSpec job.Job, - db *sqlx.DB, + ds sqlutil.DataSource, orm ORM, jobORM job.ORM, pipelineORM pipeline.ORM, @@ -250,7 +248,7 @@ func NewFromJobSpec( pipelineRunner, jobSpec, *jobSpec.PipelineSpec, - db, + ds, orm, jobORM, pipelineORM, @@ -760,7 +758,7 @@ func (fm *FluxMonitor) respondToNewRoundLog(log flux_aggregator_wrapper.FluxAggr result, err := results.FinalResult(newRoundLogger).SingularResult() if err != nil || result.Error != nil { newRoundLogger.Errorw("can't fetch answer", "err", err, "result", result) - fm.jobORM.TryRecordError(fm.spec.JobID, "Error polling") + fm.jobORM.TryRecordError(ctx, fm.spec.JobID, "Error polling") return } answer, err := utils.ToDecimal(result.Value) @@ -865,7 +863,7 @@ func (fm *FluxMonitor) pollIfEligible(pollReq PollRequestType, deviationChecker roundState, err := fm.roundState(0) if err != nil { l.Errorw("unable to determine eligibility to submit from FluxAggregator contract", "err", err) - fm.jobORM.TryRecordError(fm.spec.JobID, + fm.jobORM.TryRecordError(ctx, fm.spec.JobID, "Unable to call roundState method on provided contract. Check contract address.", ) @@ -885,7 +883,7 @@ func (fm *FluxMonitor) pollIfEligible(pollReq PollRequestType, deviationChecker roundStateNew, err2 := fm.roundState(roundState.RoundId) if err2 != nil { l.Errorw("unable to determine eligibility to submit from FluxAggregator contract", "err", err2) - fm.jobORM.TryRecordError(fm.spec.JobID, + fm.jobORM.TryRecordError(ctx, fm.spec.JobID, "Unable to call roundState method on provided contract. Check contract address.", ) @@ -962,13 +960,13 @@ func (fm *FluxMonitor) pollIfEligible(pollReq PollRequestType, deviationChecker run, results, err := fm.runner.ExecuteRun(ctx, fm.spec, vars, fm.logger) if err != nil { l.Errorw("can't fetch answer", "err", err) - fm.jobORM.TryRecordError(fm.spec.JobID, "Error polling") + fm.jobORM.TryRecordError(ctx, fm.spec.JobID, "Error polling") return } result, err := results.FinalResult(l).SingularResult() if err != nil || result.Error != nil { l.Errorw("can't fetch answer", "err", err, "result", result) - fm.jobORM.TryRecordError(fm.spec.JobID, "Error polling") + fm.jobORM.TryRecordError(ctx, fm.spec.JobID, "Error polling") return } answer, err := utils.ToDecimal(result.Value) @@ -1041,7 +1039,7 @@ func (fm *FluxMonitor) isValidSubmission(ctx context.Context, l logger.Logger, a "max", fm.submissionChecker.Max, "answer", answer, ) - fm.jobORM.TryRecordError(fm.spec.JobID, "Answer is outside acceptable range") + fm.jobORM.TryRecordError(ctx, fm.spec.JobID, "Answer is outside acceptable range") jobId := fm.spec.JobID jobName := fm.spec.JobName diff --git a/core/services/fluxmonitorv2/flux_monitor_test.go b/core/services/fluxmonitorv2/flux_monitor_test.go index 2ac424bceab..b3a5bcee6b9 100644 --- a/core/services/fluxmonitorv2/flux_monitor_test.go +++ b/core/services/fluxmonitorv2/flux_monitor_test.go @@ -515,6 +515,7 @@ func TestFluxMonitor_PollIfEligible_Creates_JobErr(t *testing.T) { tm.jobORM. On("TryRecordError", + mock.Anything, pipelineSpec.JobID, "Unable to call roundState method on provided contract. Check contract address.", ).Once() diff --git a/core/services/fluxmonitorv2/integrations_test.go b/core/services/fluxmonitorv2/integrations_test.go index 7a967300867..6b9dcb99262 100644 --- a/core/services/fluxmonitorv2/integrations_test.go +++ b/core/services/fluxmonitorv2/integrations_test.go @@ -24,9 +24,8 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/jmoiron/sqlx" - commonconfig "github.com/smartcontractkit/chainlink-common/pkg/config" + "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" "github.com/smartcontractkit/chainlink/v2/core/bridges" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/assets" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/log" @@ -43,7 +42,6 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/services/chainlink" "github.com/smartcontractkit/chainlink/v2/core/services/fluxmonitorv2" "github.com/smartcontractkit/chainlink/v2/core/services/keystore/keys/ethkey" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" "github.com/smartcontractkit/chainlink/v2/core/services/pipeline" "github.com/smartcontractkit/chainlink/v2/core/store/models" "github.com/smartcontractkit/chainlink/v2/core/utils" @@ -393,32 +391,34 @@ func assertNoSubmission(t *testing.T, // assertPipelineRunCreated checks that a pipeline exists for a given round and // verifies the answer -func assertPipelineRunCreated(t *testing.T, db *sqlx.DB, roundID int64, result int64) pipeline.Run { +func assertPipelineRunCreated(t *testing.T, ds sqlutil.DataSource, roundID int64, result int64) pipeline.Run { + ctx := testutils.Context(t) // Fetch the stats to extract the run id stats := fluxmonitorv2.FluxMonitorRoundStatsV2{} - require.NoError(t, db.Get(&stats, "SELECT * FROM flux_monitor_round_stats_v2 WHERE round_id = $1", roundID)) + require.NoError(t, ds.GetContext(ctx, &stats, "SELECT * FROM flux_monitor_round_stats_v2 WHERE round_id = $1", roundID)) if stats.ID == 0 { t.Fatalf("Stats for round id: %v not found!", roundID) } require.True(t, stats.PipelineRunID.Valid) // Verify the pipeline run data run := pipeline.Run{} - require.NoError(t, db.Get(&run, `SELECT * FROM pipeline_runs WHERE id = $1`, stats.PipelineRunID.Int64), "runID %v", stats.PipelineRunID) + require.NoError(t, ds.GetContext(ctx, &run, `SELECT * FROM pipeline_runs WHERE id = $1`, stats.PipelineRunID.Int64), "runID %v", stats.PipelineRunID) assert.Equal(t, []interface{}{result}, run.Outputs.Val) return run } -func checkLogWasConsumed(t *testing.T, fa fluxAggregatorUniverse, db *sqlx.DB, pipelineSpecID int32, blockNumber uint64, cfg pg.QConfig) { +func checkLogWasConsumed(t *testing.T, fa fluxAggregatorUniverse, ds sqlutil.DataSource, pipelineSpecID int32, blockNumber uint64) { t.Helper() lggr := logger.TestLogger(t) lggr.Infof("Waiting for log on block: %v, job id: %v", blockNumber, pipelineSpecID) g := gomega.NewWithT(t) g.Eventually(func() bool { + ctx := testutils.Context(t) block := fa.backend.Blockchain().GetBlockByNumber(blockNumber) require.NotNil(t, block) - orm := log.NewORM(db, fa.evmChainID) - consumed, err := orm.WasBroadcastConsumed(testutils.Context(t), block.Hash(), 0, pipelineSpecID) + orm := log.NewORM(ds, fa.evmChainID) + consumed, err := orm.WasBroadcastConsumed(ctx, block.Hash(), 0, pipelineSpecID) require.NoError(t, err) fa.backend.Commit() return consumed @@ -559,12 +559,12 @@ func TestFluxMonitor_Deviation(t *testing.T) { initialBalance, receiptBlock, ) - assertPipelineRunCreated(t, app.GetSqlxDB(), 1, int64(100)) + assertPipelineRunCreated(t, app.GetDB(), 1, int64(100)) // Need to wait until NewRound log is consumed - otherwise there is a chance // it will arrive after the next answer is submitted, and cause // DeleteFluxMonitorRoundsBackThrough to delete previous stats - checkLogWasConsumed(t, fa, app.GetSqlxDB(), jobID, receiptBlock, app.GetConfig().Database()) + checkLogWasConsumed(t, fa, app.GetDB(), jobID, receiptBlock) lggr.Info("Updating price to 103") // Change reported price to a value outside the deviation @@ -588,12 +588,12 @@ func TestFluxMonitor_Deviation(t *testing.T) { initialBalance-fee, receiptBlock, ) - assertPipelineRunCreated(t, app.GetSqlxDB(), 2, int64(103)) + assertPipelineRunCreated(t, app.GetDB(), 2, int64(103)) // Need to wait until NewRound log is consumed - otherwise there is a chance // it will arrive after the next answer is submitted, and cause // DeleteFluxMonitorRoundsBackThrough to delete previous stats - checkLogWasConsumed(t, fa, app.GetSqlxDB(), jobID, receiptBlock, app.GetConfig().Database()) + checkLogWasConsumed(t, fa, app.GetDB(), jobID, receiptBlock) // Should not received a submission as it is inside the deviation reportPrice.Store(104) @@ -795,7 +795,7 @@ ds1 -> ds1_parse // node doesn't submit initial response, because flag is up // Wait here so the next lower flags doesn't trigger immediately - cltest.AssertPipelineRunsStays(t, j.PipelineSpec.ID, app.GetSqlxDB(), 0) + cltest.AssertPipelineRunsStays(t, j.PipelineSpec.ID, app.GetDB(), 0) // lower global kill switch flag - should trigger job run _, err = fa.flagsContract.LowerFlags(fa.sergey, []common.Address{evmutils.ZeroAddress}) @@ -910,7 +910,7 @@ ds1 -> ds1_parse jobID, err := strconv.ParseInt(j.ID, 10, 32) require.NoError(t, err) - jse := cltest.WaitForSpecErrorV2(t, app.GetSqlxDB(), int32(jobID), 1) + jse := cltest.WaitForSpecErrorV2(t, app.GetDB(), int32(jobID), 1) assert.Contains(t, jse[0].Description, "Answer is outside acceptable range") } diff --git a/core/services/fluxmonitorv2/orm_test.go b/core/services/fluxmonitorv2/orm_test.go index db00fabb4ff..9652c0b0e27 100644 --- a/core/services/fluxmonitorv2/orm_test.go +++ b/core/services/fluxmonitorv2/orm_test.go @@ -99,14 +99,14 @@ func TestORM_UpdateFluxMonitorRoundStats(t *testing.T) { // Instantiate a real job ORM because we need to create a job to satisfy // a check in pipeline.CreateRun - jobORM := job.NewORM(db, pipelineORM, bridgeORM, keyStore, lggr, cfg.Database()) + jobORM := job.NewORM(db, pipelineORM, bridgeORM, keyStore, lggr) orm := newORM(t, db, nil) address := testutils.NewAddress() var roundID uint32 = 1 jb := makeJob(t) - require.NoError(t, jobORM.CreateJob(jb)) + require.NoError(t, jobORM.CreateJob(ctx, jb)) for expectedCount := uint64(1); expectedCount < 4; expectedCount++ { f := time.Now() diff --git a/core/services/functions/listener_test.go b/core/services/functions/listener_test.go index f3754bbbc29..a1a29bf2500 100644 --- a/core/services/functions/listener_test.go +++ b/core/services/functions/listener_test.go @@ -186,7 +186,7 @@ func TestFunctionsListener_HandleOffchainRequest_Success(t *testing.T) { uni := NewFunctionsListenerUniverse(t, 0, 1_000_000) - uni.pluginORM.On("CreateRequest", mock.Anything, mock.Anything).Return(nil) + uni.pluginORM.On("CreateRequest", mock.Anything, mock.Anything, mock.Anything).Return(nil) uni.bridgeAccessor.On("NewExternalAdapterClient", mock.Anything).Return(uni.eaClient, nil) uni.eaClient.On("RunComputation", mock.Anything, RequestIDStr, mock.Anything, SubscriptionOwner.Hex(), SubscriptionID, mock.Anything, mock.Anything, mock.Anything).Return(ResultBytes, nil, nil, nil) uni.pluginORM.On("SetResult", mock.Anything, RequestID, ResultBytes, mock.Anything, mock.Anything).Return(nil) @@ -230,7 +230,7 @@ func TestFunctionsListener_HandleOffchainRequest_InternalError(t *testing.T) { testutils.SkipShortDB(t) t.Parallel() uni := NewFunctionsListenerUniverse(t, 0, 1_000_000) - uni.pluginORM.On("CreateRequest", mock.Anything, mock.Anything).Return(nil) + uni.pluginORM.On("CreateRequest", mock.Anything, mock.Anything, mock.Anything).Return(nil) uni.bridgeAccessor.On("NewExternalAdapterClient", mock.Anything).Return(uni.eaClient, nil) uni.eaClient.On("RunComputation", mock.Anything, RequestIDStr, mock.Anything, SubscriptionOwner.Hex(), SubscriptionID, mock.Anything, mock.Anything, mock.Anything).Return(nil, nil, nil, errors.New("error")) uni.pluginORM.On("SetError", mock.Anything, RequestID, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil) diff --git a/core/services/gateway/delegate.go b/core/services/gateway/delegate.go index 8cddc027803..5a30228db4c 100644 --- a/core/services/gateway/delegate.go +++ b/core/services/gateway/delegate.go @@ -5,34 +5,31 @@ import ( "encoding/json" "github.com/google/uuid" - "github.com/jmoiron/sqlx" "github.com/pelletier/go-toml" "github.com/pkg/errors" + "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" "github.com/smartcontractkit/chainlink/v2/core/chains/legacyevm" "github.com/smartcontractkit/chainlink/v2/core/logger" "github.com/smartcontractkit/chainlink/v2/core/services/gateway/config" "github.com/smartcontractkit/chainlink/v2/core/services/job" "github.com/smartcontractkit/chainlink/v2/core/services/keystore" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" ) type Delegate struct { legacyChains legacyevm.LegacyChainContainer ks keystore.Eth - db *sqlx.DB - cfg pg.QConfig + ds sqlutil.DataSource lggr logger.Logger } var _ job.Delegate = (*Delegate)(nil) -func NewDelegate(legacyChains legacyevm.LegacyChainContainer, ks keystore.Eth, db *sqlx.DB, cfg pg.QConfig, lggr logger.Logger) *Delegate { +func NewDelegate(legacyChains legacyevm.LegacyChainContainer, ks keystore.Eth, ds sqlutil.DataSource, lggr logger.Logger) *Delegate { return &Delegate{ legacyChains: legacyChains, ks: ks, - db: db, - cfg: cfg, + ds: ds, lggr: lggr, } } @@ -57,7 +54,7 @@ func (d *Delegate) ServicesForSpec(ctx context.Context, spec job.Job) (services if err2 != nil { return nil, errors.Wrap(err2, "unmarshal gateway config") } - handlerFactory := NewHandlerFactory(d.legacyChains, d.db, d.cfg, d.lggr) + handlerFactory := NewHandlerFactory(d.legacyChains, d.ds, d.lggr) gateway, err := NewGatewayFromConfig(&gatewayConfig, handlerFactory, d.lggr) if err != nil { return nil, err diff --git a/core/services/gateway/gateway_test.go b/core/services/gateway/gateway_test.go index 7a5457c788c..3218c5428a2 100644 --- a/core/services/gateway/gateway_test.go +++ b/core/services/gateway/gateway_test.go @@ -57,7 +57,7 @@ Address = "0x0001020304050607080900010203040506070809" `) lggr := logger.TestLogger(t) - _, err := gateway.NewGatewayFromConfig(parseTOMLConfig(t, tomlConfig), gateway.NewHandlerFactory(nil, nil, nil, lggr), lggr) + _, err := gateway.NewGatewayFromConfig(parseTOMLConfig(t, tomlConfig), gateway.NewHandlerFactory(nil, nil, lggr), lggr) require.NoError(t, err) } @@ -75,7 +75,7 @@ HandlerName = "dummy" `) lggr := logger.TestLogger(t) - _, err := gateway.NewGatewayFromConfig(parseTOMLConfig(t, tomlConfig), gateway.NewHandlerFactory(nil, nil, nil, lggr), lggr) + _, err := gateway.NewGatewayFromConfig(parseTOMLConfig(t, tomlConfig), gateway.NewHandlerFactory(nil, nil, lggr), lggr) require.Error(t, err) } @@ -89,7 +89,7 @@ HandlerName = "no_such_handler" `) lggr := logger.TestLogger(t) - _, err := gateway.NewGatewayFromConfig(parseTOMLConfig(t, tomlConfig), gateway.NewHandlerFactory(nil, nil, nil, lggr), lggr) + _, err := gateway.NewGatewayFromConfig(parseTOMLConfig(t, tomlConfig), gateway.NewHandlerFactory(nil, nil, lggr), lggr) require.Error(t, err) } @@ -103,7 +103,7 @@ SomeOtherField = "abcd" `) lggr := logger.TestLogger(t) - _, err := gateway.NewGatewayFromConfig(parseTOMLConfig(t, tomlConfig), gateway.NewHandlerFactory(nil, nil, nil, lggr), lggr) + _, err := gateway.NewGatewayFromConfig(parseTOMLConfig(t, tomlConfig), gateway.NewHandlerFactory(nil, nil, lggr), lggr) require.Error(t, err) } @@ -121,7 +121,7 @@ Address = "0xnot_an_address" `) lggr := logger.TestLogger(t) - _, err := gateway.NewGatewayFromConfig(parseTOMLConfig(t, tomlConfig), gateway.NewHandlerFactory(nil, nil, nil, lggr), lggr) + _, err := gateway.NewGatewayFromConfig(parseTOMLConfig(t, tomlConfig), gateway.NewHandlerFactory(nil, nil, lggr), lggr) require.Error(t, err) } @@ -129,7 +129,7 @@ func TestGateway_CleanStartAndClose(t *testing.T) { t.Parallel() lggr := logger.TestLogger(t) - gateway, err := gateway.NewGatewayFromConfig(parseTOMLConfig(t, buildConfig("")), gateway.NewHandlerFactory(nil, nil, nil, lggr), lggr) + gateway, err := gateway.NewGatewayFromConfig(parseTOMLConfig(t, buildConfig("")), gateway.NewHandlerFactory(nil, nil, lggr), lggr) require.NoError(t, err) servicetest.Run(t, gateway) } diff --git a/core/services/gateway/handler_factory.go b/core/services/gateway/handler_factory.go index 8ccae8c7c4b..ca6b98e55aa 100644 --- a/core/services/gateway/handler_factory.go +++ b/core/services/gateway/handler_factory.go @@ -4,14 +4,12 @@ import ( "encoding/json" "fmt" - "github.com/jmoiron/sqlx" - + "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" "github.com/smartcontractkit/chainlink/v2/core/chains/legacyevm" "github.com/smartcontractkit/chainlink/v2/core/logger" "github.com/smartcontractkit/chainlink/v2/core/services/gateway/config" "github.com/smartcontractkit/chainlink/v2/core/services/gateway/handlers" "github.com/smartcontractkit/chainlink/v2/core/services/gateway/handlers/functions" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" ) const ( @@ -21,21 +19,20 @@ const ( type handlerFactory struct { legacyChains legacyevm.LegacyChainContainer - db *sqlx.DB - cfg pg.QConfig + ds sqlutil.DataSource lggr logger.Logger } var _ HandlerFactory = (*handlerFactory)(nil) -func NewHandlerFactory(legacyChains legacyevm.LegacyChainContainer, db *sqlx.DB, cfg pg.QConfig, lggr logger.Logger) HandlerFactory { - return &handlerFactory{legacyChains, db, cfg, lggr} +func NewHandlerFactory(legacyChains legacyevm.LegacyChainContainer, ds sqlutil.DataSource, lggr logger.Logger) HandlerFactory { + return &handlerFactory{legacyChains, ds, lggr} } func (hf *handlerFactory) NewHandler(handlerType HandlerType, handlerConfig json.RawMessage, donConfig *config.DONConfig, don handlers.DON) (handlers.Handler, error) { switch handlerType { case FunctionsHandlerType: - return functions.NewFunctionsHandlerFromConfig(handlerConfig, donConfig, don, hf.legacyChains, hf.db, hf.cfg, hf.lggr) + return functions.NewFunctionsHandlerFromConfig(handlerConfig, donConfig, don, hf.legacyChains, hf.ds, hf.lggr) case DummyHandlerType: return handlers.NewDummyHandler(donConfig, don, hf.lggr) default: diff --git a/core/services/gateway/handlers/functions/handler.functions.go b/core/services/gateway/handlers/functions/handler.functions.go index 692534db598..2d55bb23fde 100644 --- a/core/services/gateway/handlers/functions/handler.functions.go +++ b/core/services/gateway/handlers/functions/handler.functions.go @@ -10,13 +10,13 @@ import ( "time" "github.com/ethereum/go-ethereum/common" - "github.com/jmoiron/sqlx" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promauto" "go.uber.org/multierr" "github.com/smartcontractkit/chainlink-common/pkg/assets" "github.com/smartcontractkit/chainlink-common/pkg/services" + "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" "github.com/smartcontractkit/chainlink/v2/core/chains/legacyevm" "github.com/smartcontractkit/chainlink/v2/core/logger" "github.com/smartcontractkit/chainlink/v2/core/services/gateway/api" @@ -25,7 +25,6 @@ import ( hc "github.com/smartcontractkit/chainlink/v2/core/services/gateway/handlers/common" fallow "github.com/smartcontractkit/chainlink/v2/core/services/gateway/handlers/functions/allowlist" fsub "github.com/smartcontractkit/chainlink/v2/core/services/gateway/handlers/functions/subscriptions" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" ) var ( @@ -100,7 +99,7 @@ type PendingRequest struct { var _ handlers.Handler = (*functionsHandler)(nil) -func NewFunctionsHandlerFromConfig(handlerConfig json.RawMessage, donConfig *config.DONConfig, don handlers.DON, legacyChains legacyevm.LegacyChainContainer, db *sqlx.DB, qcfg pg.QConfig, lggr logger.Logger) (handlers.Handler, error) { +func NewFunctionsHandlerFromConfig(handlerConfig json.RawMessage, donConfig *config.DONConfig, don handlers.DON, legacyChains legacyevm.LegacyChainContainer, ds sqlutil.DataSource, lggr logger.Logger) (handlers.Handler, error) { var cfg FunctionsHandlerConfig err := json.Unmarshal(handlerConfig, &cfg) if err != nil { @@ -114,7 +113,7 @@ func NewFunctionsHandlerFromConfig(handlerConfig json.RawMessage, donConfig *con return nil, err2 } - orm, err2 := fallow.NewORM(db, lggr, cfg.OnchainAllowlist.ContractAddress) + orm, err2 := fallow.NewORM(ds, lggr, cfg.OnchainAllowlist.ContractAddress) if err2 != nil { return nil, err2 } @@ -143,7 +142,7 @@ func NewFunctionsHandlerFromConfig(handlerConfig json.RawMessage, donConfig *con return nil, err2 } - orm, err2 := fsub.NewORM(db, lggr, cfg.OnchainSubscriptions.ContractAddress) + orm, err2 := fsub.NewORM(ds, lggr, cfg.OnchainSubscriptions.ContractAddress) if err2 != nil { return nil, err2 } diff --git a/core/services/gateway/handlers/functions/handler.functions_test.go b/core/services/gateway/handlers/functions/handler.functions_test.go index 2e2c1c77caf..b7abeed8a99 100644 --- a/core/services/gateway/handlers/functions/handler.functions_test.go +++ b/core/services/gateway/handlers/functions/handler.functions_test.go @@ -84,7 +84,7 @@ func sendNodeReponses(t *testing.T, handler handlers.Handler, userRequestMsg api func TestFunctionsHandler_Minimal(t *testing.T) { t.Parallel() - handler, err := functions.NewFunctionsHandlerFromConfig(json.RawMessage("{}"), &config.DONConfig{}, nil, nil, nil, nil, logger.TestLogger(t)) + handler, err := functions.NewFunctionsHandlerFromConfig(json.RawMessage("{}"), &config.DONConfig{}, nil, nil, nil, logger.TestLogger(t)) require.NoError(t, err) // empty message should always error out @@ -96,7 +96,7 @@ func TestFunctionsHandler_Minimal(t *testing.T) { func TestFunctionsHandler_CleanStartAndClose(t *testing.T) { t.Parallel() - handler, err := functions.NewFunctionsHandlerFromConfig(json.RawMessage("{}"), &config.DONConfig{}, nil, nil, nil, nil, logger.TestLogger(t)) + handler, err := functions.NewFunctionsHandlerFromConfig(json.RawMessage("{}"), &config.DONConfig{}, nil, nil, nil, logger.TestLogger(t)) require.NoError(t, err) servicetest.Run(t, handler) diff --git a/core/services/gateway/integration_tests/gateway_integration_test.go b/core/services/gateway/integration_tests/gateway_integration_test.go index 7f4a2ab58fa..38a6b6ebbca 100644 --- a/core/services/gateway/integration_tests/gateway_integration_test.go +++ b/core/services/gateway/integration_tests/gateway_integration_test.go @@ -143,7 +143,7 @@ func TestIntegration_Gateway_NoFullNodes_BasicConnectionAndMessage(t *testing.T) // Launch Gateway lggr := logger.TestLogger(t) gatewayConfig := fmt.Sprintf(gatewayConfigTemplate, nodeKeys.Address) - gateway, err := gateway.NewGatewayFromConfig(parseGatewayConfig(t, gatewayConfig), gateway.NewHandlerFactory(nil, nil, nil, lggr), lggr) + gateway, err := gateway.NewGatewayFromConfig(parseGatewayConfig(t, gatewayConfig), gateway.NewHandlerFactory(nil, nil, lggr), lggr) require.NoError(t, err) servicetest.Run(t, gateway) userPort, nodePort := gateway.GetUserPort(), gateway.GetNodePort() diff --git a/core/services/job/common.go b/core/services/job/common.go index 58b1754cbf5..055195440b0 100644 --- a/core/services/job/common.go +++ b/core/services/job/common.go @@ -3,8 +3,6 @@ package job import ( "context" "net/url" - - "github.com/smartcontractkit/chainlink/v2/core/services/pg" ) //go:generate mockery --quiet --name ServiceCtx --output ./mocks/ --case=underscore @@ -22,7 +20,6 @@ type ServiceCtx interface { type Config interface { URL() url.URL - pg.QConfig } // ServiceAdapter is a helper introduced for transitioning from Service to ServiceCtx. diff --git a/core/services/job/helpers_test.go b/core/services/job/helpers_test.go index 22e1b0bef63..7120fe4200c 100644 --- a/core/services/job/helpers_test.go +++ b/core/services/job/helpers_test.go @@ -268,8 +268,7 @@ func makeOCRJobSpecFromToml(t *testing.T, jobSpecToml string) *job.Job { return &jb } -func makeOCR2VRFJobSpec(t testing.TB, ks keystore.Master, cfg chainlink.GeneralConfig, - transmitter common.Address, chainID *big.Int, fromBlock uint64) *job.Job { +func makeOCR2VRFJobSpec(t testing.TB, ks keystore.Master, transmitter common.Address, chainID *big.Int, fromBlock uint64) *job.Job { t.Helper() ctx := testutils.Context(t) diff --git a/core/services/job/job_orm_test.go b/core/services/job/job_orm_test.go index ffbd02c512b..f07b68d9987 100644 --- a/core/services/job/job_orm_test.go +++ b/core/services/job/job_orm_test.go @@ -73,20 +73,21 @@ serverPubKey = '8fa807463ad73f9ee855cfd60ba406dcf98a2855b3dd8af613107b0f6890a707 func TestORM(t *testing.T) { t.Parallel() - ctx := testutils.Context(t) + config := configtest.NewTestGeneralConfig(t) db := pgtest.NewSqlxDB(t) keyStore := cltest.NewKeyStore(t, db) ethKeyStore := keyStore.Eth() - require.NoError(t, keyStore.OCR().Add(ctx, cltest.DefaultOCRKey)) - require.NoError(t, keyStore.P2P().Add(ctx, cltest.DefaultP2PKey)) + func() { + ctx := testutils.Context(t) + require.NoError(t, keyStore.OCR().Add(ctx, cltest.DefaultOCRKey)) + require.NoError(t, keyStore.P2P().Add(ctx, cltest.DefaultP2PKey)) + }() pipelineORM := pipeline.NewORM(db, logger.TestLogger(t), config.JobPipeline().MaxSuccessfulRuns()) - bridgesORM := bridges.NewORM(db) - - orm := NewTestORM(t, db, pipelineORM, bridgesORM, keyStore, config.Database()) borm := bridges.NewORM(db) + orm := NewTestORM(t, db, pipelineORM, borm, keyStore) _, bridge := cltest.MustCreateBridge(t, db, cltest.BridgeOpts{}) _, bridge2 := cltest.MustCreateBridge(t, db, cltest.BridgeOpts{}) @@ -94,7 +95,7 @@ func TestORM(t *testing.T) { jb := makeOCRJobSpec(t, address, bridge.Name.String(), bridge2.Name.String()) t.Run("it creates job specs", func(t *testing.T) { - err := orm.CreateJob(jb) + err := orm.CreateJob(testutils.Context(t), jb) require.NoError(t, err) var returnedSpec job.Job @@ -109,8 +110,9 @@ func TestORM(t *testing.T) { }) t.Run("it correctly mark job_pipeline_specs as primary when creating a job", func(t *testing.T) { + ctx := testutils.Context(t) jb2 := makeOCRJobSpec(t, address, bridge.Name.String(), bridge2.Name.String()) - err := orm.CreateJob(jb2) + err := orm.CreateJob(ctx, jb2) require.NoError(t, err) var pipelineSpec pipeline.Spec @@ -128,7 +130,7 @@ func TestORM(t *testing.T) { t.Run("autogenerates external job ID if missing", func(t *testing.T) { jb2 := makeOCRJobSpec(t, address, bridge.Name.String(), bridge2.Name.String()) jb2.ExternalJobID = uuid.UUID{} - err := orm.CreateJob(jb2) + err := orm.CreateJob(testutils.Context(t), jb2) require.NoError(t, err) var returnedSpec job.Job @@ -145,7 +147,7 @@ func TestORM(t *testing.T) { require.NoError(t, err) require.Len(t, dbSpecs, 3) - err = orm.DeleteJob(jb.ID) + err = orm.DeleteJob(testutils.Context(t), jb.ID) require.NoError(t, err) dbSpecs = []job.Job{} @@ -155,8 +157,9 @@ func TestORM(t *testing.T) { }) t.Run("increase job spec error occurrence", func(t *testing.T) { + ctx := testutils.Context(t) jb3 := makeOCRJobSpec(t, address, bridge.Name.String(), bridge2.Name.String()) - err := orm.CreateJob(jb3) + err := orm.CreateJob(ctx, jb3) require.NoError(t, err) var jobSpec job.Job err = db.Get(&jobSpec, "SELECT * FROM jobs") @@ -164,9 +167,9 @@ func TestORM(t *testing.T) { ocrSpecError1 := "ocr spec 1 errored" ocrSpecError2 := "ocr spec 2 errored" - require.NoError(t, orm.RecordError(jobSpec.ID, ocrSpecError1)) - require.NoError(t, orm.RecordError(jobSpec.ID, ocrSpecError1)) - require.NoError(t, orm.RecordError(jobSpec.ID, ocrSpecError2)) + require.NoError(t, orm.RecordError(ctx, jobSpec.ID, ocrSpecError1)) + require.NoError(t, orm.RecordError(ctx, jobSpec.ID, ocrSpecError1)) + require.NoError(t, orm.RecordError(ctx, jobSpec.ID, ocrSpecError2)) var specErrors []job.SpecError err = db.Select(&specErrors, "SELECT * FROM job_spec_errors") @@ -193,8 +196,9 @@ func TestORM(t *testing.T) { }) t.Run("finds job spec error by ID", func(t *testing.T) { + ctx := testutils.Context(t) jb3 := makeOCRJobSpec(t, address, bridge.Name.String(), bridge2.Name.String()) - err := orm.CreateJob(jb3) + err := orm.CreateJob(ctx, jb3) require.NoError(t, err) var jobSpec job.Job err = db.Get(&jobSpec, "SELECT * FROM jobs") @@ -207,8 +211,8 @@ func TestORM(t *testing.T) { ocrSpecError1 := "ocr spec 3 errored" ocrSpecError2 := "ocr spec 4 errored" - require.NoError(t, orm.RecordError(jobSpec.ID, ocrSpecError1)) - require.NoError(t, orm.RecordError(jobSpec.ID, ocrSpecError2)) + require.NoError(t, orm.RecordError(ctx, jobSpec.ID, ocrSpecError1)) + require.NoError(t, orm.RecordError(ctx, jobSpec.ID, ocrSpecError2)) var updatedSpecError []job.SpecError @@ -221,9 +225,9 @@ func TestORM(t *testing.T) { assert.Equal(t, ocrSpecError1, updatedSpecError[2].Description) assert.Equal(t, ocrSpecError2, updatedSpecError[3].Description) - dbSpecErr1, err := orm.FindSpecError(updatedSpecError[2].ID) + dbSpecErr1, err := orm.FindSpecError(ctx, updatedSpecError[2].ID) require.NoError(t, err) - dbSpecErr2, err := orm.FindSpecError(updatedSpecError[3].ID) + dbSpecErr2, err := orm.FindSpecError(ctx, updatedSpecError[3].ID) require.NoError(t, err) assert.Equal(t, uint(1), dbSpecErr1.Occurrences) @@ -251,11 +255,12 @@ func TestORM(t *testing.T) { drJob, err := directrequest.ValidatedDirectRequestSpec(drSpec) require.NoError(t, err) - err = orm.CreateJob(&drJob) + err = orm.CreateJob(testutils.Context(t), &drJob) require.NoError(t, err) }) t.Run("creates webhook specs along with external_initiator_webhook_specs", func(t *testing.T) { + ctx := testutils.Context(t) eiFoo := cltest.MustInsertExternalInitiator(t, borm) eiBar := cltest.MustInsertExternalInitiator(t, borm) @@ -263,22 +268,23 @@ func TestORM(t *testing.T) { {Name: eiFoo.Name, Spec: cltest.JSONFromString(t, `{}`)}, {Name: eiBar.Name, Spec: cltest.JSONFromString(t, `{"bar": 1}`)}, } - eim := webhook.NewExternalInitiatorManager(db, nil, logger.TestLogger(t), config.Database()) - jb, err := webhook.ValidatedWebhookSpec(testspecs.GenerateWebhookSpec(testspecs.WebhookSpecParams{ExternalInitiators: eiWS}).Toml(), eim) + eim := webhook.NewExternalInitiatorManager(db, nil) + jb, err := webhook.ValidatedWebhookSpec(ctx, testspecs.GenerateWebhookSpec(testspecs.WebhookSpecParams{ExternalInitiators: eiWS}).Toml(), eim) require.NoError(t, err) - err = orm.CreateJob(&jb) + err = orm.CreateJob(testutils.Context(t), &jb) require.NoError(t, err) cltest.AssertCount(t, db, "external_initiator_webhook_specs", 2) }) t.Run("it creates and deletes records for blockhash store jobs", func(t *testing.T) { + ctx := testutils.Context(t) bhsJob, err := blockhashstore.ValidatedSpec( testspecs.GenerateBlockhashStoreSpec(testspecs.BlockhashStoreSpecParams{}).Toml()) require.NoError(t, err) - err = orm.CreateJob(&bhsJob) + err = orm.CreateJob(ctx, &bhsJob) require.NoError(t, err) savedJob, err := orm.FindJob(testutils.Context(t), bhsJob.ID) require.NoError(t, err) @@ -298,18 +304,19 @@ func TestORM(t *testing.T) { require.Equal(t, bhsJob.BlockhashStoreSpec.RunTimeout, savedJob.BlockhashStoreSpec.RunTimeout) require.Equal(t, bhsJob.BlockhashStoreSpec.EVMChainID, savedJob.BlockhashStoreSpec.EVMChainID) require.Equal(t, bhsJob.BlockhashStoreSpec.FromAddresses, savedJob.BlockhashStoreSpec.FromAddresses) - err = orm.DeleteJob(bhsJob.ID) + err = orm.DeleteJob(ctx, bhsJob.ID) require.NoError(t, err) _, err = orm.FindJob(testutils.Context(t), bhsJob.ID) require.Error(t, err) }) t.Run("it creates and deletes records for blockheaderfeeder jobs", func(t *testing.T) { + ctx := testutils.Context(t) bhsJob, err := blockheaderfeeder.ValidatedSpec( testspecs.GenerateBlockHeaderFeederSpec(testspecs.BlockHeaderFeederSpecParams{}).Toml()) require.NoError(t, err) - err = orm.CreateJob(&bhsJob) + err = orm.CreateJob(ctx, &bhsJob) require.NoError(t, err) savedJob, err := orm.FindJob(testutils.Context(t), bhsJob.ID) require.NoError(t, err) @@ -329,7 +336,7 @@ func TestORM(t *testing.T) { require.Equal(t, bhsJob.BlockHeaderFeederSpec.FromAddresses, savedJob.BlockHeaderFeederSpec.FromAddresses) require.Equal(t, bhsJob.BlockHeaderFeederSpec.GetBlockhashesBatchSize, savedJob.BlockHeaderFeederSpec.GetBlockhashesBatchSize) require.Equal(t, bhsJob.BlockHeaderFeederSpec.StoreBlockhashesBatchSize, savedJob.BlockHeaderFeederSpec.StoreBlockhashesBatchSize) - err = orm.DeleteJob(bhsJob.ID) + err = orm.DeleteJob(ctx, bhsJob.ID) require.NoError(t, err) _, err = orm.FindJob(testutils.Context(t), bhsJob.ID) require.Error(t, err) @@ -349,10 +356,11 @@ func TestORM_DeleteJob_DeletesAssociatedRecords(t *testing.T) { lggr := logger.TestLogger(t) pipelineORM := pipeline.NewORM(db, lggr, config.JobPipeline().MaxSuccessfulRuns()) bridgesORM := bridges.NewORM(db) - jobORM := NewTestORM(t, db, pipelineORM, bridgesORM, keyStore, config.Database()) + jobORM := NewTestORM(t, db, pipelineORM, bridgesORM, keyStore) korm := keeper.NewORM(db, logger.TestLogger(t)) t.Run("it deletes records for offchainreporting jobs", func(t *testing.T) { + ctx := testutils.Context(t) _, bridge := cltest.MustCreateBridge(t, db, cltest.BridgeOpts{}) _, bridge2 := cltest.MustCreateBridge(t, db, cltest.BridgeOpts{}) @@ -366,13 +374,13 @@ func TestORM_DeleteJob_DeletesAssociatedRecords(t *testing.T) { }).Toml()) require.NoError(t, err) - err = jobORM.CreateJob(&jb) + err = jobORM.CreateJob(ctx, &jb) require.NoError(t, err) cltest.AssertCount(t, db, "ocr_oracle_specs", 1) cltest.AssertCount(t, db, "pipeline_specs", 1) - err = jobORM.DeleteJob(jb.ID) + err = jobORM.DeleteJob(ctx, jb.ID) require.NoError(t, err) cltest.AssertCount(t, db, "ocr_oracle_specs", 0) cltest.AssertCount(t, db, "pipeline_specs", 0) @@ -380,6 +388,7 @@ func TestORM_DeleteJob_DeletesAssociatedRecords(t *testing.T) { }) t.Run("it deletes records for keeper jobs", func(t *testing.T) { + ctx := testutils.Context(t) registry, keeperJob := cltest.MustInsertKeeperRegistry(t, db, korm, keyStore.Eth(), 0, 1, 20) cltest.MustInsertUpkeepForRegistry(t, db, registry) @@ -387,7 +396,7 @@ func TestORM_DeleteJob_DeletesAssociatedRecords(t *testing.T) { cltest.AssertCount(t, db, "keeper_registries", 1) cltest.AssertCount(t, db, "upkeep_registrations", 1) - err := jobORM.DeleteJob(keeperJob.ID) + err := jobORM.DeleteJob(ctx, keeperJob.ID) require.NoError(t, err) cltest.AssertCount(t, db, "keeper_specs", 0) cltest.AssertCount(t, db, "keeper_registries", 0) @@ -396,29 +405,31 @@ func TestORM_DeleteJob_DeletesAssociatedRecords(t *testing.T) { }) t.Run("it creates and deletes records for vrf jobs", func(t *testing.T) { - key, err := keyStore.VRF().Create(testutils.Context(t)) + ctx := testutils.Context(t) + key, err := keyStore.VRF().Create(ctx) require.NoError(t, err) pk := key.PublicKey jb, err := vrfcommon.ValidatedVRFSpec(testspecs.GenerateVRFSpec(testspecs.VRFSpecParams{PublicKey: pk.String()}).Toml()) require.NoError(t, err) - err = jobORM.CreateJob(&jb) + err = jobORM.CreateJob(ctx, &jb) require.NoError(t, err) cltest.AssertCount(t, db, "vrf_specs", 1) cltest.AssertCount(t, db, "jobs", 1) - err = jobORM.DeleteJob(jb.ID) + err = jobORM.DeleteJob(ctx, jb.ID) require.NoError(t, err) cltest.AssertCount(t, db, "vrf_specs", 0) cltest.AssertCount(t, db, "jobs", 0) }) t.Run("it deletes records for webhook jobs", func(t *testing.T) { + ctx := testutils.Context(t) ei := cltest.MustInsertExternalInitiator(t, bridges.NewORM(db)) jb, webhookSpec := cltest.MustInsertWebhookSpec(t, db) _, err := db.Exec(`INSERT INTO external_initiator_webhook_specs (external_initiator_id, webhook_spec_id, spec) VALUES ($1,$2,$3)`, ei.ID, webhookSpec.ID, `{"ei": "foo", "name": "webhookSpecTwoEIs"}`) require.NoError(t, err) - err = jobORM.DeleteJob(jb.ID) + err = jobORM.DeleteJob(ctx, jb.ID) require.NoError(t, err) cltest.AssertCount(t, db, "webhook_specs", 0) cltest.AssertCount(t, db, "external_initiator_webhook_specs", 0) @@ -449,7 +460,7 @@ func TestORM_CreateJob_VRFV2(t *testing.T) { pipelineORM := pipeline.NewORM(db, lggr, config.JobPipeline().MaxSuccessfulRuns()) bridgesORM := bridges.NewORM(db) - jobORM := NewTestORM(t, db, pipelineORM, bridgesORM, keyStore, config.Database()) + jobORM := NewTestORM(t, db, pipelineORM, bridgesORM, keyStore) fromAddresses := []string{cltest.NewEIP55Address().String(), cltest.NewEIP55Address().String()} jb, err := vrfcommon.ValidatedVRFSpec(testspecs.GenerateVRFSpec( @@ -465,7 +476,7 @@ func TestORM_CreateJob_VRFV2(t *testing.T) { Toml()) require.NoError(t, err) - require.NoError(t, jobORM.CreateJob(&jb)) + require.NoError(t, jobORM.CreateJob(ctx, &jb)) cltest.AssertCount(t, db, "vrf_specs", 1) cltest.AssertCount(t, db, "jobs", 1) var requestedConfsDelay int64 @@ -505,20 +516,20 @@ func TestORM_CreateJob_VRFV2(t *testing.T) { var vrfOwnerAddress evmtypes.EIP55Address require.NoError(t, db.Get(&vrfOwnerAddress, `SELECT vrf_owner_address FROM vrf_specs LIMIT 1`)) require.Equal(t, "0x32891BD79647DC9136Fc0a59AAB48c7825eb624c", vrfOwnerAddress.Address().String()) - require.NoError(t, jobORM.DeleteJob(jb.ID)) + require.NoError(t, jobORM.DeleteJob(ctx, jb.ID)) cltest.AssertCount(t, db, "vrf_specs", 0) cltest.AssertCount(t, db, "jobs", 0) jb, err = vrfcommon.ValidatedVRFSpec(testspecs.GenerateVRFSpec(testspecs.VRFSpecParams{RequestTimeout: 1 * time.Hour}).Toml()) require.NoError(t, err) - require.NoError(t, jobORM.CreateJob(&jb)) + require.NoError(t, jobORM.CreateJob(ctx, &jb)) cltest.AssertCount(t, db, "vrf_specs", 1) cltest.AssertCount(t, db, "jobs", 1) require.NoError(t, db.Get(&requestedConfsDelay, `SELECT requested_confs_delay FROM vrf_specs LIMIT 1`)) require.Equal(t, int64(0), requestedConfsDelay) require.NoError(t, db.Get(&requestTimeout, `SELECT request_timeout FROM vrf_specs LIMIT 1`)) require.Equal(t, 1*time.Hour, requestTimeout) - require.NoError(t, jobORM.DeleteJob(jb.ID)) + require.NoError(t, jobORM.DeleteJob(ctx, jb.ID)) cltest.AssertCount(t, db, "vrf_specs", 0) cltest.AssertCount(t, db, "jobs", 0) } @@ -533,7 +544,7 @@ func TestORM_CreateJob_VRFV2Plus(t *testing.T) { lggr := logger.TestLogger(t) pipelineORM := pipeline.NewORM(db, lggr, config.JobPipeline().MaxSuccessfulRuns()) bridgesORM := bridges.NewORM(db) - jobORM := NewTestORM(t, db, pipelineORM, bridgesORM, keyStore, config.Database()) + jobORM := NewTestORM(t, db, pipelineORM, bridgesORM, keyStore) fromAddresses := []string{cltest.NewEIP55Address().String(), cltest.NewEIP55Address().String()} jb, err := vrfcommon.ValidatedVRFSpec(testspecs.GenerateVRFSpec( @@ -550,7 +561,7 @@ func TestORM_CreateJob_VRFV2Plus(t *testing.T) { Toml()) require.NoError(t, err) - require.NoError(t, jobORM.CreateJob(&jb)) + require.NoError(t, jobORM.CreateJob(ctx, &jb)) cltest.AssertCount(t, db, "vrf_specs", 1) cltest.AssertCount(t, db, "jobs", 1) var requestedConfsDelay int64 @@ -589,7 +600,7 @@ func TestORM_CreateJob_VRFV2Plus(t *testing.T) { require.ElementsMatch(t, fromAddresses, actual) var vrfOwnerAddress evmtypes.EIP55Address require.Error(t, db.Get(&vrfOwnerAddress, `SELECT vrf_owner_address FROM vrf_specs LIMIT 1`)) - require.NoError(t, jobORM.DeleteJob(jb.ID)) + require.NoError(t, jobORM.DeleteJob(ctx, jb.ID)) cltest.AssertCount(t, db, "vrf_specs", 0) cltest.AssertCount(t, db, "jobs", 0) @@ -599,14 +610,14 @@ func TestORM_CreateJob_VRFV2Plus(t *testing.T) { FromAddresses: fromAddresses, }).Toml()) require.NoError(t, err) - require.NoError(t, jobORM.CreateJob(&jb)) + require.NoError(t, jobORM.CreateJob(ctx, &jb)) cltest.AssertCount(t, db, "vrf_specs", 1) cltest.AssertCount(t, db, "jobs", 1) require.NoError(t, db.Get(&requestedConfsDelay, `SELECT requested_confs_delay FROM vrf_specs LIMIT 1`)) require.Equal(t, int64(0), requestedConfsDelay) require.NoError(t, db.Get(&requestTimeout, `SELECT request_timeout FROM vrf_specs LIMIT 1`)) require.Equal(t, 1*time.Hour, requestTimeout) - require.NoError(t, jobORM.DeleteJob(jb.ID)) + require.NoError(t, jobORM.DeleteJob(ctx, jb.ID)) cltest.AssertCount(t, db, "vrf_specs", 0) cltest.AssertCount(t, db, "jobs", 0) } @@ -621,12 +632,12 @@ func TestORM_CreateJob_OCRBootstrap(t *testing.T) { lggr := logger.TestLogger(t) pipelineORM := pipeline.NewORM(db, lggr, config.JobPipeline().MaxSuccessfulRuns()) bridgesORM := bridges.NewORM(db) - jobORM := NewTestORM(t, db, pipelineORM, bridgesORM, keyStore, config.Database()) + jobORM := NewTestORM(t, db, pipelineORM, bridgesORM, keyStore) jb, err := ocrbootstrap.ValidatedBootstrapSpecToml(testspecs.GetOCRBootstrapSpec()) require.NoError(t, err) - err = jobORM.CreateJob(&jb) + err = jobORM.CreateJob(ctx, &jb) require.NoError(t, err) cltest.AssertCount(t, db, "bootstrap_specs", 1) cltest.AssertCount(t, db, "jobs", 1) @@ -634,7 +645,7 @@ func TestORM_CreateJob_OCRBootstrap(t *testing.T) { require.NoError(t, db.Get(&relay, `SELECT relay FROM bootstrap_specs LIMIT 1`)) require.Equal(t, "evm", relay) - require.NoError(t, jobORM.DeleteJob(jb.ID)) + require.NoError(t, jobORM.DeleteJob(ctx, jb.ID)) cltest.AssertCount(t, db, "bootstrap_specs", 0) cltest.AssertCount(t, db, "jobs", 0) } @@ -648,14 +659,14 @@ func TestORM_CreateJob_EVMChainID_Validation(t *testing.T) { pipelineORM := pipeline.NewORM(db, lggr, config.JobPipeline().MaxSuccessfulRuns()) bridgesORM := bridges.NewORM(db) - jobORM := NewTestORM(t, db, pipelineORM, bridgesORM, keyStore, config.Database()) + jobORM := NewTestORM(t, db, pipelineORM, bridgesORM, keyStore) t.Run("evm chain id validation for ocr works", func(t *testing.T) { jb := job.Job{ Type: job.OffchainReporting, OCROracleSpec: &job.OCROracleSpec{}, } - assert.Equal(t, "CreateJobFailed: evm chain id must be defined", jobORM.CreateJob(&jb).Error()) + assert.Equal(t, "CreateJobFailed: evm chain id must be defined", jobORM.CreateJob(testutils.Context(t), &jb).Error()) }) t.Run("evm chain id validation for direct request works", func(t *testing.T) { @@ -663,7 +674,7 @@ func TestORM_CreateJob_EVMChainID_Validation(t *testing.T) { Type: job.DirectRequest, DirectRequestSpec: &job.DirectRequestSpec{}, } - assert.Equal(t, "CreateJobFailed: evm chain id must be defined", jobORM.CreateJob(&jb).Error()) + assert.Equal(t, "CreateJobFailed: evm chain id must be defined", jobORM.CreateJob(testutils.Context(t), &jb).Error()) }) t.Run("evm chain id validation for flux monitor works", func(t *testing.T) { @@ -671,7 +682,7 @@ func TestORM_CreateJob_EVMChainID_Validation(t *testing.T) { Type: job.FluxMonitor, FluxMonitorSpec: &job.FluxMonitorSpec{}, } - assert.Equal(t, "CreateJobFailed: evm chain id must be defined", jobORM.CreateJob(&jb).Error()) + assert.Equal(t, "CreateJobFailed: evm chain id must be defined", jobORM.CreateJob(testutils.Context(t), &jb).Error()) }) t.Run("evm chain id validation for keepers works", func(t *testing.T) { @@ -679,7 +690,7 @@ func TestORM_CreateJob_EVMChainID_Validation(t *testing.T) { Type: job.Keeper, KeeperSpec: &job.KeeperSpec{}, } - assert.Equal(t, "CreateJobFailed: evm chain id must be defined", jobORM.CreateJob(&jb).Error()) + assert.Equal(t, "CreateJobFailed: evm chain id must be defined", jobORM.CreateJob(testutils.Context(t), &jb).Error()) }) t.Run("evm chain id validation for vrf works", func(t *testing.T) { @@ -687,7 +698,7 @@ func TestORM_CreateJob_EVMChainID_Validation(t *testing.T) { Type: job.VRF, VRFSpec: &job.VRFSpec{}, } - assert.Equal(t, "CreateJobFailed: evm chain id must be defined", jobORM.CreateJob(&jb).Error()) + assert.Equal(t, "CreateJobFailed: evm chain id must be defined", jobORM.CreateJob(testutils.Context(t), &jb).Error()) }) t.Run("evm chain id validation for block hash store works", func(t *testing.T) { @@ -695,7 +706,7 @@ func TestORM_CreateJob_EVMChainID_Validation(t *testing.T) { Type: job.BlockhashStore, BlockhashStoreSpec: &job.BlockhashStoreSpec{}, } - assert.Equal(t, "CreateJobFailed: evm chain id must be defined", jobORM.CreateJob(&jb).Error()) + assert.Equal(t, "CreateJobFailed: evm chain id must be defined", jobORM.CreateJob(testutils.Context(t), &jb).Error()) }) t.Run("evm chain id validation for block header feeder works", func(t *testing.T) { @@ -703,7 +714,7 @@ func TestORM_CreateJob_EVMChainID_Validation(t *testing.T) { Type: job.BlockHeaderFeeder, BlockHeaderFeederSpec: &job.BlockHeaderFeederSpec{}, } - assert.Equal(t, "CreateJobFailed: evm chain id must be defined", jobORM.CreateJob(&jb).Error()) + assert.Equal(t, "CreateJobFailed: evm chain id must be defined", jobORM.CreateJob(testutils.Context(t), &jb).Error()) }) t.Run("evm chain id validation for legacy gas station server spec works", func(t *testing.T) { @@ -711,7 +722,7 @@ func TestORM_CreateJob_EVMChainID_Validation(t *testing.T) { Type: job.LegacyGasStationServer, LegacyGasStationServerSpec: &job.LegacyGasStationServerSpec{}, } - assert.Equal(t, "CreateJobFailed: evm chain id must be defined", jobORM.CreateJob(&jb).Error()) + assert.Equal(t, "CreateJobFailed: evm chain id must be defined", jobORM.CreateJob(testutils.Context(t), &jb).Error()) }) t.Run("evm chain id validation for legacy gas station sidecar spec works", func(t *testing.T) { @@ -719,7 +730,7 @@ func TestORM_CreateJob_EVMChainID_Validation(t *testing.T) { Type: job.LegacyGasStationSidecar, LegacyGasStationSidecarSpec: &job.LegacyGasStationSidecarSpec{}, } - assert.Equal(t, "CreateJobFailed: evm chain id must be defined", jobORM.CreateJob(&jb).Error()) + assert.Equal(t, "CreateJobFailed: evm chain id must be defined", jobORM.CreateJob(testutils.Context(t), &jb).Error()) }) } @@ -744,7 +755,7 @@ func TestORM_CreateJob_OCR_DuplicatedContractAddress(t *testing.T) { pipelineORM := pipeline.NewORM(db, lggr, config.JobPipeline().MaxSuccessfulRuns()) bridgesORM := bridges.NewORM(db) - jobORM := NewTestORM(t, db, pipelineORM, bridgesORM, keyStore, config.Database()) + jobORM := NewTestORM(t, db, pipelineORM, bridgesORM, keyStore) // defaultChainID is deprecated defaultChainID := customChainID @@ -768,7 +779,8 @@ func TestORM_CreateJob_OCR_DuplicatedContractAddress(t *testing.T) { require.NoError(t, err) t.Run("with a set chain id", func(t *testing.T) { - err = jobORM.CreateJob(&jb) // Add job with custom chain id + ctx := testutils.Context(t) + err = jobORM.CreateJob(ctx, &jb) // Add job with custom chain id require.NoError(t, err) cltest.AssertCount(t, db, "ocr_oracle_specs", 1) @@ -778,7 +790,7 @@ func TestORM_CreateJob_OCR_DuplicatedContractAddress(t *testing.T) { spec.JobID = externalJobID.UUID.String() jba, err := ocr.ValidatedOracleSpecToml(config, legacyChains, spec.Toml()) require.NoError(t, err) - err = jobORM.CreateJob(&jba) // Try to add duplicate job with default id + err = jobORM.CreateJob(ctx, &jba) // Try to add duplicate job with default id require.Error(t, err) assert.Equal(t, fmt.Sprintf("CreateJobFailed: a job with contract address %s already exists for chain ID %s", jb.OCROracleSpec.ContractAddress, defaultChainID.String()), err.Error()) @@ -787,7 +799,7 @@ func TestORM_CreateJob_OCR_DuplicatedContractAddress(t *testing.T) { jb2, err := ocr.ValidatedOracleSpecToml(config, legacyChains, spec.Toml()) require.NoError(t, err) - err = jobORM.CreateJob(&jb2) // Try to add duplicate job with custom id + err = jobORM.CreateJob(ctx, &jb2) // Try to add duplicate job with custom id require.Error(t, err) assert.Equal(t, fmt.Sprintf("CreateJobFailed: a job with contract address %s already exists for chain ID %s", jb2.OCROracleSpec.ContractAddress, customChainID), err.Error()) }) @@ -814,7 +826,7 @@ func TestORM_CreateJob_OCR2_DuplicatedContractAddress(t *testing.T) { pipelineORM := pipeline.NewORM(db, lggr, config.JobPipeline().MaxSuccessfulRuns()) bridgesORM := bridges.NewORM(db) - jobORM := NewTestORM(t, db, pipelineORM, bridgesORM, keyStore, config.Database()) + jobORM := NewTestORM(t, db, pipelineORM, bridgesORM, keyStore) _, address := cltest.MustInsertRandomKey(t, keyStore.Eth()) @@ -831,7 +843,7 @@ func TestORM_CreateJob_OCR2_DuplicatedContractAddress(t *testing.T) { jb.OCR2OracleSpec.TransmitterID = null.StringFrom(address.String()) jb.OCR2OracleSpec.PluginConfig["juelsPerFeeCoinSource"] = juelsPerFeeCoinSource - err = jobORM.CreateJob(&jb) + err = jobORM.CreateJob(ctx, &jb) require.NoError(t, err) jb2, err := ocr2validate.ValidatedOracleSpecToml(testutils.Context(t), config.OCR2(), config.Insecure(), testspecs.GetOCR2EVMSpecMinimal(), nil) @@ -841,7 +853,7 @@ func TestORM_CreateJob_OCR2_DuplicatedContractAddress(t *testing.T) { jb2.OCR2OracleSpec.TransmitterID = null.StringFrom(address.String()) jb.OCR2OracleSpec.PluginConfig["juelsPerFeeCoinSource"] = juelsPerFeeCoinSource - err = jobORM.CreateJob(&jb2) + err = jobORM.CreateJob(ctx, &jb2) require.Error(t, err) jb3, err := ocr2validate.ValidatedOracleSpecToml(testutils.Context(t), config.OCR2(), config.Insecure(), testspecs.GetOCR2EVMSpecMinimal(), nil) @@ -851,7 +863,7 @@ func TestORM_CreateJob_OCR2_DuplicatedContractAddress(t *testing.T) { jb3.OCR2OracleSpec.RelayConfig["chainID"] = customChainID.Int64() jb.OCR2OracleSpec.PluginConfig["juelsPerFeeCoinSource"] = juelsPerFeeCoinSource - err = jobORM.CreateJob(&jb3) + err = jobORM.CreateJob(ctx, &jb3) require.Error(t, err) } @@ -876,37 +888,41 @@ func TestORM_CreateJob_OCR2_Sending_Keys_Transmitter_Keys_Validations(t *testing pipelineORM := pipeline.NewORM(db, lggr, config.JobPipeline().MaxSuccessfulRuns()) bridgesORM := bridges.NewORM(db) - jobORM := NewTestORM(t, db, pipelineORM, bridgesORM, keyStore, config.Database()) + jobORM := NewTestORM(t, db, pipelineORM, bridgesORM, keyStore) jb, err := ocr2validate.ValidatedOracleSpecToml(testutils.Context(t), config.OCR2(), config.Insecure(), testspecs.GetOCR2EVMSpecMinimal(), nil) require.NoError(t, err) t.Run("sending keys or transmitterID must be defined", func(t *testing.T) { + ctx := testutils.Context(t) jb.OCR2OracleSpec.TransmitterID = null.String{} - assert.Equal(t, "CreateJobFailed: neither sending keys nor transmitter ID is defined", jobORM.CreateJob(&jb).Error()) + assert.Equal(t, "CreateJobFailed: neither sending keys nor transmitter ID is defined", jobORM.CreateJob(ctx, &jb).Error()) }) _, address := cltest.MustInsertRandomKey(t, keyStore.Eth()) t.Run("sending keys validation works properly", func(t *testing.T) { + ctx := testutils.Context(t) jb.OCR2OracleSpec.TransmitterID = null.String{} _, address2 := cltest.MustInsertRandomKey(t, keyStore.Eth()) jb.OCR2OracleSpec.RelayConfig["sendingKeys"] = interface{}([]any{address.String(), address2.String(), common.HexToAddress("0X0").String()}) - assert.Equal(t, "CreateJobFailed: no EVM key matching: \"0x0000000000000000000000000000000000000000\": no such sending key exists", jobORM.CreateJob(&jb).Error()) + assert.Equal(t, "CreateJobFailed: no EVM key matching: \"0x0000000000000000000000000000000000000000\": no such sending key exists", jobORM.CreateJob(ctx, &jb).Error()) jb.OCR2OracleSpec.RelayConfig["sendingKeys"] = interface{}([]any{1, 2, 3}) - assert.Equal(t, "CreateJobFailed: sending keys are of wrong type", jobORM.CreateJob(&jb).Error()) + assert.Equal(t, "CreateJobFailed: sending keys are of wrong type", jobORM.CreateJob(ctx, &jb).Error()) }) t.Run("sending keys and transmitter ID can't both be defined", func(t *testing.T) { + ctx := testutils.Context(t) jb.OCR2OracleSpec.TransmitterID = null.StringFrom(address.String()) jb.OCR2OracleSpec.RelayConfig["sendingKeys"] = interface{}([]any{address.String()}) - assert.Equal(t, "CreateJobFailed: sending keys and transmitter ID can't both be defined", jobORM.CreateJob(&jb).Error()) + assert.Equal(t, "CreateJobFailed: sending keys and transmitter ID can't both be defined", jobORM.CreateJob(ctx, &jb).Error()) }) t.Run("transmitter validation works", func(t *testing.T) { + ctx := testutils.Context(t) jb.OCR2OracleSpec.TransmitterID = null.StringFrom("transmitterID that doesn't have a match in key store") jb.OCR2OracleSpec.RelayConfig["sendingKeys"] = nil - assert.Equal(t, "CreateJobFailed: no EVM key matching: \"transmitterID that doesn't have a match in key store\": no such transmitter key exists", jobORM.CreateJob(&jb).Error()) + assert.Equal(t, "CreateJobFailed: no EVM key matching: \"transmitterID that doesn't have a match in key store\": no such transmitter key exists", jobORM.CreateJob(ctx, &jb).Error()) }) } @@ -998,7 +1014,7 @@ func Test_FindJobs(t *testing.T) { pipelineORM := pipeline.NewORM(db, logger.TestLogger(t), config.JobPipeline().MaxSuccessfulRuns()) bridgesORM := bridges.NewORM(db) - orm := NewTestORM(t, db, pipelineORM, bridgesORM, keyStore, config.Database()) + orm := NewTestORM(t, db, pipelineORM, bridgesORM, keyStore) _, bridge := cltest.MustCreateBridge(t, db, cltest.BridgeOpts{}) _, bridge2 := cltest.MustCreateBridge(t, db, cltest.BridgeOpts{}) @@ -1016,7 +1032,7 @@ func Test_FindJobs(t *testing.T) { ) require.NoError(t, err) - err = orm.CreateJob(&jb1) + err = orm.CreateJob(ctx, &jb1) require.NoError(t, err) jb2, err := directrequest.ValidatedDirectRequestSpec( @@ -1024,11 +1040,11 @@ func Test_FindJobs(t *testing.T) { ) require.NoError(t, err) - err = orm.CreateJob(&jb2) + err = orm.CreateJob(ctx, &jb2) require.NoError(t, err) t.Run("jobs are ordered by latest first", func(t *testing.T) { - jobs, count, err2 := orm.FindJobs(0, 2) + jobs, count, err2 := orm.FindJobs(testutils.Context(t), 0, 2) require.NoError(t, err2) require.Len(t, jobs, 2) assert.Equal(t, count, 2) @@ -1041,7 +1057,7 @@ func Test_FindJobs(t *testing.T) { }) t.Run("jobs respect pagination", func(t *testing.T) { - jobs, count, err2 := orm.FindJobs(0, 1) + jobs, count, err2 := orm.FindJobs(testutils.Context(t), 0, 1) require.NoError(t, err2) require.Len(t, jobs, 1) assert.Equal(t, count, 2) @@ -1080,7 +1096,7 @@ func Test_FindJob(t *testing.T) { pipelineORM := pipeline.NewORM(db, logger.TestLogger(t), config.JobPipeline().MaxSuccessfulRuns()) bridgesORM := bridges.NewORM(db) - orm := NewTestORM(t, db, pipelineORM, bridgesORM, keyStore, config.Database()) + orm := NewTestORM(t, db, pipelineORM, bridgesORM, keyStore) _, bridge := cltest.MustCreateBridge(t, db, cltest.BridgeOpts{}) _, bridge2 := cltest.MustCreateBridge(t, db, cltest.BridgeOpts{}) @@ -1148,20 +1164,20 @@ func Test_FindJob(t *testing.T) { jobOCR2WithFeedID2.Name = null.StringFrom("new name") require.NoError(t, err) - err = orm.CreateJob(&job) + err = orm.CreateJob(ctx, &job) require.NoError(t, err) - err = orm.CreateJob(&jobSameAddress) + err = orm.CreateJob(ctx, &jobSameAddress) require.NoError(t, err) - err = orm.CreateJob(&jobOCR2) + err = orm.CreateJob(ctx, &jobOCR2) require.NoError(t, err) - err = orm.CreateJob(&jobOCR2WithFeedID1) + err = orm.CreateJob(ctx, &jobOCR2WithFeedID1) require.NoError(t, err) // second ocr2 job with same contract id but different feed id - err = orm.CreateJob(&jobOCR2WithFeedID2) + err = orm.CreateJob(ctx, &jobOCR2WithFeedID2) require.NoError(t, err) t.Run("by id", func(t *testing.T) { @@ -1180,7 +1196,8 @@ func Test_FindJob(t *testing.T) { }) t.Run("by external job id", func(t *testing.T) { - jb, err2 := orm.FindJobByExternalJobID(externalJobID) + ctx := testutils.Context(t) + jb, err2 := orm.FindJobByExternalJobID(ctx, externalJobID) require.NoError(t, err2) assert.Equal(t, jb.ID, job.ID) @@ -1193,59 +1210,64 @@ func Test_FindJob(t *testing.T) { }) t.Run("by address", func(t *testing.T) { - jbID, err2 := orm.FindJobIDByAddress(job.OCROracleSpec.ContractAddress, job.OCROracleSpec.EVMChainID) + ctx := testutils.Context(t) + jbID, err2 := orm.FindJobIDByAddress(ctx, job.OCROracleSpec.ContractAddress, job.OCROracleSpec.EVMChainID) require.NoError(t, err2) assert.Equal(t, job.ID, jbID) - _, err2 = orm.FindJobIDByAddress("not-existing", big.NewI(0)) + _, err2 = orm.FindJobIDByAddress(ctx, "not-existing", big.NewI(0)) require.Error(t, err2) require.ErrorIs(t, err2, sql.ErrNoRows) }) t.Run("by address yet chain scoped", func(t *testing.T) { + ctx := testutils.Context(t) commonAddr := jobSameAddress.OCROracleSpec.ContractAddress // Find job ID for job on chain 1337 with common address. - jbID, err2 := orm.FindJobIDByAddress(commonAddr, jobSameAddress.OCROracleSpec.EVMChainID) + jbID, err2 := orm.FindJobIDByAddress(ctx, commonAddr, jobSameAddress.OCROracleSpec.EVMChainID) require.NoError(t, err2) assert.Equal(t, jobSameAddress.ID, jbID) // Find job ID for job on default evm chain with common address. - jbID, err2 = orm.FindJobIDByAddress(commonAddr, job.OCROracleSpec.EVMChainID) + jbID, err2 = orm.FindJobIDByAddress(ctx, commonAddr, job.OCROracleSpec.EVMChainID) require.NoError(t, err2) assert.Equal(t, job.ID, jbID) }) t.Run("by contract id without feed id", func(t *testing.T) { + ctx := testutils.Context(t) contractID := "0x613a38AC1659769640aaE063C651F48E0250454C" // Find job ID for ocr2 job without feedID. - jbID, err2 := orm.FindOCR2JobIDByAddress(contractID, nil) + jbID, err2 := orm.FindOCR2JobIDByAddress(ctx, contractID, nil) require.NoError(t, err2) assert.Equal(t, jobOCR2.ID, jbID) }) t.Run("by contract id with valid feed id", func(t *testing.T) { + ctx := testutils.Context(t) contractID := "0x0000000000000000000000000000000000000006" feedID := common.HexToHash(ocr2WithFeedID1) // Find job ID for ocr2 job with feed ID - jbID, err2 := orm.FindOCR2JobIDByAddress(contractID, &feedID) + jbID, err2 := orm.FindOCR2JobIDByAddress(ctx, contractID, &feedID) require.NoError(t, err2) assert.Equal(t, jobOCR2WithFeedID1.ID, jbID) }) t.Run("with duplicate contract id but different feed id", func(t *testing.T) { + ctx := testutils.Context(t) contractID := "0x0000000000000000000000000000000000000006" feedID := common.HexToHash(ocr2WithFeedID2) // Find job ID for ocr2 job with feed ID - jbID, err2 := orm.FindOCR2JobIDByAddress(contractID, &feedID) + jbID, err2 := orm.FindOCR2JobIDByAddress(ctx, contractID, &feedID) require.NoError(t, err2) assert.Equal(t, jobOCR2WithFeedID2.ID, jbID) @@ -1263,17 +1285,18 @@ func Test_FindJobsByPipelineSpecIDs(t *testing.T) { pipelineORM := pipeline.NewORM(db, logger.TestLogger(t), config.JobPipeline().MaxSuccessfulRuns()) bridgesORM := bridges.NewORM(db) - orm := NewTestORM(t, db, pipelineORM, bridgesORM, keyStore, config.Database()) + orm := NewTestORM(t, db, pipelineORM, bridgesORM, keyStore) jb, err := directrequest.ValidatedDirectRequestSpec(testspecs.GetDirectRequestSpec()) require.NoError(t, err) jb.DirectRequestSpec.EVMChainID = big.NewI(0) - err = orm.CreateJob(&jb) + err = orm.CreateJob(testutils.Context(t), &jb) require.NoError(t, err) t.Run("with jobs", func(t *testing.T) { - jbs, err2 := orm.FindJobsByPipelineSpecIDs([]int32{jb.PipelineSpecID}) + ctx := testutils.Context(t) + jbs, err2 := orm.FindJobsByPipelineSpecIDs(ctx, []int32{jb.PipelineSpecID}) require.NoError(t, err2) assert.Len(t, jbs, 1) @@ -1286,15 +1309,17 @@ func Test_FindJobsByPipelineSpecIDs(t *testing.T) { }) t.Run("without jobs", func(t *testing.T) { - jbs, err2 := orm.FindJobsByPipelineSpecIDs([]int32{-1}) + ctx := testutils.Context(t) + jbs, err2 := orm.FindJobsByPipelineSpecIDs(ctx, []int32{-1}) require.NoError(t, err2) assert.Len(t, jbs, 0) }) t.Run("with chainID disabled", func(t *testing.T) { - orm2 := NewTestORM(t, db, pipelineORM, bridgesORM, keyStore, config.Database()) + ctx := testutils.Context(t) + orm2 := NewTestORM(t, db, pipelineORM, bridgesORM, keyStore) - jbs, err2 := orm2.FindJobsByPipelineSpecIDs([]int32{jb.PipelineSpecID}) + jbs, err2 := orm2.FindJobsByPipelineSpecIDs(ctx, []int32{jb.PipelineSpecID}) require.NoError(t, err2) assert.Len(t, jbs, 1) }) @@ -1314,7 +1339,7 @@ func Test_FindPipelineRuns(t *testing.T) { bridgesORM := bridges.NewORM(db) relayExtenders := evmtest.NewChainRelayExtenders(t, evmtest.TestChainOpts{DB: db, GeneralConfig: config, KeyStore: keyStore.Eth()}) legacyChains := evmrelay.NewLegacyChainsFromRelayerExtenders(relayExtenders) - orm := NewTestORM(t, db, pipelineORM, bridgesORM, keyStore, config.Database()) + orm := NewTestORM(t, db, pipelineORM, bridgesORM, keyStore) _, bridge := cltest.MustCreateBridge(t, db, cltest.BridgeOpts{}) _, bridge2 := cltest.MustCreateBridge(t, db, cltest.BridgeOpts{}) @@ -1331,20 +1356,22 @@ func Test_FindPipelineRuns(t *testing.T) { ) require.NoError(t, err) - err = orm.CreateJob(&jb) + err = orm.CreateJob(testutils.Context(t), &jb) require.NoError(t, err) t.Run("with no pipeline runs", func(t *testing.T) { - runs, count, err2 := orm.PipelineRuns(nil, 0, 10) + ctx := testutils.Context(t) + runs, count, err2 := orm.PipelineRuns(ctx, nil, 0, 10) require.NoError(t, err2) assert.Equal(t, count, 0) assert.Empty(t, runs) }) t.Run("with a pipeline run", func(t *testing.T) { + ctx := testutils.Context(t) run := mustInsertPipelineRun(t, pipelineORM, jb) - runs, count, err2 := orm.PipelineRuns(nil, 0, 10) + runs, count, err2 := orm.PipelineRuns(ctx, nil, 0, 10) require.NoError(t, err2) assert.Equal(t, count, 1) @@ -1376,7 +1403,7 @@ func Test_PipelineRunsByJobID(t *testing.T) { bridgesORM := bridges.NewORM(db) relayExtenders := evmtest.NewChainRelayExtenders(t, evmtest.TestChainOpts{DB: db, GeneralConfig: config, KeyStore: keyStore.Eth()}) legacyChains := evmrelay.NewLegacyChainsFromRelayerExtenders(relayExtenders) - orm := NewTestORM(t, db, pipelineORM, bridgesORM, keyStore, config.Database()) + orm := NewTestORM(t, db, pipelineORM, bridgesORM, keyStore) _, bridge := cltest.MustCreateBridge(t, db, cltest.BridgeOpts{}) _, bridge2 := cltest.MustCreateBridge(t, db, cltest.BridgeOpts{}) @@ -1393,20 +1420,22 @@ func Test_PipelineRunsByJobID(t *testing.T) { ) require.NoError(t, err) - err = orm.CreateJob(&jb) + err = orm.CreateJob(testutils.Context(t), &jb) require.NoError(t, err) t.Run("with no pipeline runs", func(t *testing.T) { - runs, count, err2 := orm.PipelineRuns(&jb.ID, 0, 10) + ctx := testutils.Context(t) + runs, count, err2 := orm.PipelineRuns(ctx, &jb.ID, 0, 10) require.NoError(t, err2) assert.Equal(t, count, 0) assert.Empty(t, runs) }) t.Run("with a pipeline run", func(t *testing.T) { + ctx := testutils.Context(t) run := mustInsertPipelineRun(t, pipelineORM, jb) - runs, count, err2 := orm.PipelineRuns(&jb.ID, 0, 10) + runs, count, err2 := orm.PipelineRuns(ctx, &jb.ID, 0, 10) require.NoError(t, err2) assert.Equal(t, 1, count) @@ -1437,7 +1466,7 @@ func Test_FindPipelineRunIDsByJobID(t *testing.T) { bridgesORM := bridges.NewORM(db) relayExtenders := evmtest.NewChainRelayExtenders(t, evmtest.TestChainOpts{DB: db, GeneralConfig: config, KeyStore: keyStore.Eth()}) legacyChains := evmrelay.NewLegacyChainsFromRelayerExtenders(relayExtenders) - orm := NewTestORM(t, db, pipelineORM, bridgesORM, keyStore, config.Database()) + orm := NewTestORM(t, db, pipelineORM, bridgesORM, keyStore) _, address := cltest.MustInsertRandomKey(t, keyStore.Eth()) @@ -1461,7 +1490,7 @@ func Test_FindPipelineRunIDsByJobID(t *testing.T) { require.NoError(t, err) - err = orm.CreateJob(&jb) + err = orm.CreateJob(testutils.Context(t), &jb) require.NoError(t, err) jobs[j] = jb } @@ -1475,15 +1504,17 @@ func Test_FindPipelineRunIDsByJobID(t *testing.T) { } t.Run("with no pipeline runs", func(t *testing.T) { - runIDs, err := orm.FindPipelineRunIDsByJobID(jb.ID, 0, 10) + ctx := testutils.Context(t) + runIDs, err := orm.FindPipelineRunIDsByJobID(ctx, jb.ID, 0, 10) require.NoError(t, err) assert.Empty(t, runIDs) }) t.Run("with a pipeline run", func(t *testing.T) { + ctx := testutils.Context(t) run := mustInsertPipelineRun(t, pipelineORM, jb) - runIDs, err := orm.FindPipelineRunIDsByJobID(jb.ID, 0, 10) + runIDs, err := orm.FindPipelineRunIDsByJobID(ctx, jb.ID, 0, 10) require.NoError(t, err) require.Len(t, runIDs, 1) @@ -1493,7 +1524,8 @@ func Test_FindPipelineRunIDsByJobID(t *testing.T) { // Internally these queries are batched by 1000, this tests case requiring concatenation // of more than 1 batch t.Run("with batch concatenation limit 10", func(t *testing.T) { - runIDs, err := orm.FindPipelineRunIDsByJobID(jobs[3].ID, 95, 10) + ctx := testutils.Context(t) + runIDs, err := orm.FindPipelineRunIDsByJobID(ctx, jobs[3].ID, 95, 10) require.NoError(t, err) require.Len(t, runIDs, 10) assert.Equal(t, int64(4*(len(jobs)-1)), runIDs[3]-runIDs[7]) @@ -1502,7 +1534,8 @@ func Test_FindPipelineRunIDsByJobID(t *testing.T) { // Internally these queries are batched by 1000, this tests case requiring concatenation // of more than 1 batch t.Run("with batch concatenation limit 100", func(t *testing.T) { - runIDs, err := orm.FindPipelineRunIDsByJobID(jobs[3].ID, 95, 100) + ctx := testutils.Context(t) + runIDs, err := orm.FindPipelineRunIDsByJobID(ctx, jobs[3].ID, 95, 100) require.NoError(t, err) require.Len(t, runIDs, 100) assert.Equal(t, int64(67*(len(jobs)-1)), runIDs[12]-runIDs[79]) @@ -1516,7 +1549,8 @@ func Test_FindPipelineRunIDsByJobID(t *testing.T) { // returns empty. This can happen if the job id being requested hasn't run in a while, // but many other jobs have run since. t.Run("with first batch empty, over limit", func(t *testing.T) { - runIDs, err := orm.FindPipelineRunIDsByJobID(jobs[3].ID, 0, 25) + ctx := testutils.Context(t) + runIDs, err := orm.FindPipelineRunIDsByJobID(ctx, jobs[3].ID, 0, 25) require.NoError(t, err) require.Len(t, runIDs, 25) assert.Equal(t, int64(16*(len(jobs)-1)), runIDs[7]-runIDs[23]) @@ -1524,7 +1558,8 @@ func Test_FindPipelineRunIDsByJobID(t *testing.T) { // Same as previous, but where there are fewer matching jobs than the limit t.Run("with first batch empty, under limit", func(t *testing.T) { - runIDs, err := orm.FindPipelineRunIDsByJobID(jobs[3].ID, 143, 190) + ctx := testutils.Context(t) + runIDs, err := orm.FindPipelineRunIDsByJobID(ctx, jobs[3].ID, 143, 190) require.NoError(t, err) require.Len(t, runIDs, 107) assert.Equal(t, int64(16*(len(jobs)-1)), runIDs[7]-runIDs[23]) @@ -1546,7 +1581,7 @@ func Test_FindPipelineRunsByIDs(t *testing.T) { bridgesORM := bridges.NewORM(db) relayExtenders := evmtest.NewChainRelayExtenders(t, evmtest.TestChainOpts{DB: db, GeneralConfig: config, KeyStore: keyStore.Eth()}) legacyChains := evmrelay.NewLegacyChainsFromRelayerExtenders(relayExtenders) - orm := NewTestORM(t, db, pipelineORM, bridgesORM, keyStore, config.Database()) + orm := NewTestORM(t, db, pipelineORM, bridgesORM, keyStore) _, bridge := cltest.MustCreateBridge(t, db, cltest.BridgeOpts{}) _, bridge2 := cltest.MustCreateBridge(t, db, cltest.BridgeOpts{}) @@ -1563,19 +1598,21 @@ func Test_FindPipelineRunsByIDs(t *testing.T) { ) require.NoError(t, err) - err = orm.CreateJob(&jb) + err = orm.CreateJob(testutils.Context(t), &jb) require.NoError(t, err) t.Run("with no pipeline runs", func(t *testing.T) { - runs, err2 := orm.FindPipelineRunsByIDs([]int64{-1}) + ctx := testutils.Context(t) + runs, err2 := orm.FindPipelineRunsByIDs(ctx, []int64{-1}) require.NoError(t, err2) assert.Empty(t, runs) }) t.Run("with a pipeline run", func(t *testing.T) { + ctx := testutils.Context(t) run := mustInsertPipelineRun(t, pipelineORM, jb) - actual, err2 := orm.FindPipelineRunsByIDs([]int64{run.ID}) + actual, err2 := orm.FindPipelineRunsByIDs(ctx, []int64{run.ID}) require.NoError(t, err2) require.Len(t, actual, 1) @@ -1603,24 +1640,26 @@ func Test_FindPipelineRunByID(t *testing.T) { pipelineORM := pipeline.NewORM(db, logger.TestLogger(t), config.JobPipeline().MaxSuccessfulRuns()) bridgesORM := bridges.NewORM(db) - orm := NewTestORM(t, db, pipelineORM, bridgesORM, keyStore, config.Database()) + orm := NewTestORM(t, db, pipelineORM, bridgesORM, keyStore) jb, err := directrequest.ValidatedDirectRequestSpec(testspecs.GetDirectRequestSpec()) require.NoError(t, err) - err = orm.CreateJob(&jb) + err = orm.CreateJob(testutils.Context(t), &jb) require.NoError(t, err) t.Run("with no pipeline run", func(t *testing.T) { - run, err2 := orm.FindPipelineRunByID(-1) + ctx := testutils.Context(t) + run, err2 := orm.FindPipelineRunByID(ctx, -1) assert.Equal(t, run, pipeline.Run{}) require.ErrorIs(t, err2, sql.ErrNoRows) }) t.Run("with a pipeline run", func(t *testing.T) { + ctx := testutils.Context(t) run := mustInsertPipelineRun(t, pipelineORM, jb) - actual, err2 := orm.FindPipelineRunByID(run.ID) + actual, err2 := orm.FindPipelineRunByID(ctx, run.ID) require.NoError(t, err2) actualRun := actual @@ -1647,12 +1686,12 @@ func Test_FindJobWithoutSpecErrors(t *testing.T) { pipelineORM := pipeline.NewORM(db, logger.TestLogger(t), config.JobPipeline().MaxSuccessfulRuns()) bridgesORM := bridges.NewORM(db) - orm := NewTestORM(t, db, pipelineORM, bridgesORM, keyStore, config.Database()) + orm := NewTestORM(t, db, pipelineORM, bridgesORM, keyStore) jb, err := directrequest.ValidatedDirectRequestSpec(testspecs.GetDirectRequestSpec()) require.NoError(t, err) - err = orm.CreateJob(&jb) + err = orm.CreateJob(ctx, &jb) require.NoError(t, err) var jobSpec job.Job err = db.Get(&jobSpec, "SELECT * FROM jobs") @@ -1660,10 +1699,10 @@ func Test_FindJobWithoutSpecErrors(t *testing.T) { ocrSpecError1 := "ocr spec 1 errored" ocrSpecError2 := "ocr spec 2 errored" - require.NoError(t, orm.RecordError(jobSpec.ID, ocrSpecError1)) - require.NoError(t, orm.RecordError(jobSpec.ID, ocrSpecError2)) + require.NoError(t, orm.RecordError(ctx, jobSpec.ID, ocrSpecError1)) + require.NoError(t, orm.RecordError(ctx, jobSpec.ID, ocrSpecError2)) - jb, err = orm.FindJobWithoutSpecErrors(jobSpec.ID) + jb, err = orm.FindJobWithoutSpecErrors(ctx, jobSpec.ID) require.NoError(t, err) jbWithErrors, err := orm.FindJobTx(testutils.Context(t), jobSpec.ID) require.NoError(t, err) @@ -1685,12 +1724,12 @@ func Test_FindSpecErrorsByJobIDs(t *testing.T) { pipelineORM := pipeline.NewORM(db, logger.TestLogger(t), config.JobPipeline().MaxSuccessfulRuns()) bridgesORM := bridges.NewORM(db) - orm := NewTestORM(t, db, pipelineORM, bridgesORM, keyStore, config.Database()) + orm := NewTestORM(t, db, pipelineORM, bridgesORM, keyStore) jb, err := directrequest.ValidatedDirectRequestSpec(testspecs.GetDirectRequestSpec()) require.NoError(t, err) - err = orm.CreateJob(&jb) + err = orm.CreateJob(ctx, &jb) require.NoError(t, err) var jobSpec job.Job err = db.Get(&jobSpec, "SELECT * FROM jobs") @@ -1698,10 +1737,10 @@ func Test_FindSpecErrorsByJobIDs(t *testing.T) { ocrSpecError1 := "ocr spec 1 errored" ocrSpecError2 := "ocr spec 2 errored" - require.NoError(t, orm.RecordError(jobSpec.ID, ocrSpecError1)) - require.NoError(t, orm.RecordError(jobSpec.ID, ocrSpecError2)) + require.NoError(t, orm.RecordError(ctx, jobSpec.ID, ocrSpecError1)) + require.NoError(t, orm.RecordError(ctx, jobSpec.ID, ocrSpecError2)) - specErrs, err := orm.FindSpecErrorsByJobIDs([]int32{jobSpec.ID}) + specErrs, err := orm.FindSpecErrorsByJobIDs(ctx, []int32{jobSpec.ID}) require.NoError(t, err) assert.Equal(t, len(specErrs), 2) @@ -1722,7 +1761,7 @@ func Test_CountPipelineRunsByJobID(t *testing.T) { bridgesORM := bridges.NewORM(db) relayExtenders := evmtest.NewChainRelayExtenders(t, evmtest.TestChainOpts{DB: db, GeneralConfig: config, KeyStore: keyStore.Eth()}) legacyChains := evmrelay.NewLegacyChainsFromRelayerExtenders(relayExtenders) - orm := NewTestORM(t, db, pipelineORM, bridgesORM, keyStore, config.Database()) + orm := NewTestORM(t, db, pipelineORM, bridgesORM, keyStore) _, bridge := cltest.MustCreateBridge(t, db, cltest.BridgeOpts{}) _, bridge2 := cltest.MustCreateBridge(t, db, cltest.BridgeOpts{}) @@ -1739,19 +1778,21 @@ func Test_CountPipelineRunsByJobID(t *testing.T) { ) require.NoError(t, err) - err = orm.CreateJob(&jb) + err = orm.CreateJob(testutils.Context(t), &jb) require.NoError(t, err) t.Run("with no pipeline runs", func(t *testing.T) { - count, err2 := orm.CountPipelineRunsByJobID(jb.ID) + ctx := testutils.Context(t) + count, err2 := orm.CountPipelineRunsByJobID(ctx, jb.ID) require.NoError(t, err2) assert.Equal(t, int32(0), count) }) t.Run("with a pipeline run", func(t *testing.T) { + ctx := testutils.Context(t) mustInsertPipelineRun(t, pipelineORM, jb) - count, err2 := orm.CountPipelineRunsByJobID(jb.ID) + count, err2 := orm.CountPipelineRunsByJobID(ctx, jb.ID) require.NoError(t, err2) require.Equal(t, int32(1), count) }) diff --git a/core/services/job/job_pipeline_orm_integration_test.go b/core/services/job/job_pipeline_orm_integration_test.go index 33ee6dc306c..f8a43bca34d 100644 --- a/core/services/job/job_pipeline_orm_integration_test.go +++ b/core/services/job/job_pipeline_orm_integration_test.go @@ -135,7 +135,7 @@ func TestPipelineORM_Integration(t *testing.T) { p, err := pipeline.Parse(DotStr) require.NoError(t, err) - specID, err = orm.CreateSpec(ctx, nil, *p, models.Interval(0)) + specID, err = orm.CreateSpec(ctx, *p, models.Interval(0)) require.NoError(t, err) var pipelineSpecs []pipeline.Spec @@ -160,12 +160,12 @@ func TestPipelineORM_Integration(t *testing.T) { legacyChains := evmrelay.NewLegacyChainsFromRelayerExtenders(relayExtenders) runner := pipeline.NewRunner(orm, btORM, config.JobPipeline(), cfg.WebServer(), legacyChains, nil, nil, lggr, nil, nil) - jobORM := NewTestORM(t, db, orm, btORM, keyStore, cfg.Database()) + jobORM := NewTestORM(t, db, orm, btORM, keyStore) dbSpec := makeVoterTurnoutOCRJobSpec(t, transmitterAddress, bridge.Name.String(), bridge2.Name.String()) // Need a job in order to create a run - require.NoError(t, jobORM.CreateJob(dbSpec)) + require.NoError(t, jobORM.CreateJob(testutils.Context(t), dbSpec)) var pipelineSpecs []pipeline.Spec sql := `SELECT pipeline_specs.*, job_pipeline_specs.job_id FROM pipeline_specs JOIN job_pipeline_specs ON (pipeline_specs.id = job_pipeline_specs.pipeline_spec_id);` diff --git a/core/services/job/kv_orm.go b/core/services/job/kv_orm.go index 6108c123a62..63384efc25b 100644 --- a/core/services/job/kv_orm.go +++ b/core/services/job/kv_orm.go @@ -5,10 +5,8 @@ import ( "fmt" "time" - "github.com/jmoiron/sqlx" - + "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" "github.com/smartcontractkit/chainlink/v2/core/logger" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" ) // KVStore is a simple KV store that can store and retrieve serializable data. @@ -21,17 +19,17 @@ type KVStore interface { type kVStore struct { jobID int32 - q pg.Q + ds sqlutil.DataSource lggr logger.SugaredLogger } var _ KVStore = (*kVStore)(nil) -func NewKVStore(jobID int32, db *sqlx.DB, cfg pg.QConfig, lggr logger.Logger) kVStore { +func NewKVStore(jobID int32, ds sqlutil.DataSource, lggr logger.Logger) kVStore { namedLogger := logger.Sugared(lggr.Named("JobORM")) return kVStore{ jobID: jobID, - q: pg.NewQ(db, namedLogger, cfg), + ds: ds, lggr: namedLogger, } } @@ -45,7 +43,7 @@ func (kv kVStore) Store(ctx context.Context, key string, val []byte) error { val_bytea = EXCLUDED.val_bytea, updated_at = $4;` - if _, err := kv.q.ExecContext(ctx, sql, kv.jobID, key, val, time.Now()); err != nil { + if _, err := kv.ds.ExecContext(ctx, sql, kv.jobID, key, val, time.Now()); err != nil { return fmt.Errorf("failed to store value: %s for key: %s for jobID: %d : %w", string(val), key, kv.jobID, err) } return nil @@ -55,7 +53,7 @@ func (kv kVStore) Store(ctx context.Context, key string, val []byte) error { func (kv kVStore) Get(ctx context.Context, key string) ([]byte, error) { var val []byte sql := "SELECT val_bytea FROM job_kv_store WHERE job_id = $1 AND key = $2" - if err := kv.q.GetContext(ctx, &val, sql, kv.jobID, key); err != nil { + if err := kv.ds.GetContext(ctx, &val, sql, kv.jobID, key); err != nil { return nil, fmt.Errorf("failed to get value by key: %s for jobID: %d : %w", key, kv.jobID, err) } diff --git a/core/services/job/kv_orm_test.go b/core/services/job/kv_orm_test.go index 156779ffb4d..0f229f09d88 100644 --- a/core/services/job/kv_orm_test.go +++ b/core/services/job/kv_orm_test.go @@ -11,6 +11,7 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/bridges" "github.com/smartcontractkit/chainlink/v2/core/internal/cltest" + "github.com/smartcontractkit/chainlink/v2/core/internal/testutils" "github.com/smartcontractkit/chainlink/v2/core/internal/testutils/configtest" "github.com/smartcontractkit/chainlink/v2/core/internal/testutils/pgtest" "github.com/smartcontractkit/chainlink/v2/core/logger" @@ -33,13 +34,13 @@ func TestJobKVStore(t *testing.T) { bridgesORM := bridges.NewORM(db) jobID := int32(1337) - kvStore := job.NewKVStore(jobID, db, config.Database(), lggr) - jobORM := NewTestORM(t, db, pipelineORM, bridgesORM, cltest.NewKeyStore(t, db), config.Database()) + kvStore := job.NewKVStore(jobID, db, lggr) + jobORM := NewTestORM(t, db, pipelineORM, bridgesORM, cltest.NewKeyStore(t, db)) jb, err := directrequest.ValidatedDirectRequestSpec(testspecs.GetDirectRequestSpec()) require.NoError(t, err) jb.ID = jobID - require.NoError(t, jobORM.CreateJob(&jb)) + require.NoError(t, jobORM.CreateJob(testutils.Context(t), &jb)) var values = [][]byte{ []byte("Hello"), @@ -72,5 +73,5 @@ func TestJobKVStore(t *testing.T) { require.NoError(t, err) require.Equal(t, td2, fetchedBytes) - require.NoError(t, jobORM.DeleteJob(jobID)) + require.NoError(t, jobORM.DeleteJob(ctx, jobID)) } diff --git a/core/services/job/mocks/orm.go b/core/services/job/mocks/orm.go index b8534b9d688..ec60137de93 100644 --- a/core/services/job/mocks/orm.go +++ b/core/services/job/mocks/orm.go @@ -12,10 +12,10 @@ import ( mock "github.com/stretchr/testify/mock" - pg "github.com/smartcontractkit/chainlink/v2/core/services/pg" - pipeline "github.com/smartcontractkit/chainlink/v2/core/services/pipeline" + sqlutil "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" + types "github.com/smartcontractkit/chainlink/v2/core/chains/evm/types" uuid "github.com/google/uuid" @@ -62,9 +62,9 @@ func (_m *ORM) Close() error { return r0 } -// CountPipelineRunsByJobID provides a mock function with given fields: jobID -func (_m *ORM) CountPipelineRunsByJobID(jobID int32) (int32, error) { - ret := _m.Called(jobID) +// CountPipelineRunsByJobID provides a mock function with given fields: ctx, jobID +func (_m *ORM) CountPipelineRunsByJobID(ctx context.Context, jobID int32) (int32, error) { + ret := _m.Called(ctx, jobID) if len(ret) == 0 { panic("no return value specified for CountPipelineRunsByJobID") @@ -72,17 +72,17 @@ func (_m *ORM) CountPipelineRunsByJobID(jobID int32) (int32, error) { var r0 int32 var r1 error - if rf, ok := ret.Get(0).(func(int32) (int32, error)); ok { - return rf(jobID) + if rf, ok := ret.Get(0).(func(context.Context, int32) (int32, error)); ok { + return rf(ctx, jobID) } - if rf, ok := ret.Get(0).(func(int32) int32); ok { - r0 = rf(jobID) + if rf, ok := ret.Get(0).(func(context.Context, int32) int32); ok { + r0 = rf(ctx, jobID) } else { r0 = ret.Get(0).(int32) } - if rf, ok := ret.Get(1).(func(int32) error); ok { - r1 = rf(jobID) + if rf, ok := ret.Get(1).(func(context.Context, int32) error); ok { + r1 = rf(ctx, jobID) } else { r1 = ret.Error(1) } @@ -90,24 +90,17 @@ func (_m *ORM) CountPipelineRunsByJobID(jobID int32) (int32, error) { return r0, r1 } -// CreateJob provides a mock function with given fields: jb, qopts -func (_m *ORM) CreateJob(jb *job.Job, qopts ...pg.QOpt) error { - _va := make([]interface{}, len(qopts)) - for _i := range qopts { - _va[_i] = qopts[_i] - } - var _ca []interface{} - _ca = append(_ca, jb) - _ca = append(_ca, _va...) - ret := _m.Called(_ca...) +// CreateJob provides a mock function with given fields: ctx, jb +func (_m *ORM) CreateJob(ctx context.Context, jb *job.Job) error { + ret := _m.Called(ctx, jb) if len(ret) == 0 { panic("no return value specified for CreateJob") } var r0 error - if rf, ok := ret.Get(0).(func(*job.Job, ...pg.QOpt) error); ok { - r0 = rf(jb, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, *job.Job) error); ok { + r0 = rf(ctx, jb) } else { r0 = ret.Error(0) } @@ -115,24 +108,37 @@ func (_m *ORM) CreateJob(jb *job.Job, qopts ...pg.QOpt) error { return r0 } -// DeleteJob provides a mock function with given fields: id, qopts -func (_m *ORM) DeleteJob(id int32, qopts ...pg.QOpt) error { - _va := make([]interface{}, len(qopts)) - for _i := range qopts { - _va[_i] = qopts[_i] +// DataSource provides a mock function with given fields: +func (_m *ORM) DataSource() sqlutil.DataSource { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for DataSource") + } + + var r0 sqlutil.DataSource + if rf, ok := ret.Get(0).(func() sqlutil.DataSource); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(sqlutil.DataSource) + } } - var _ca []interface{} - _ca = append(_ca, id) - _ca = append(_ca, _va...) - ret := _m.Called(_ca...) + + return r0 +} + +// DeleteJob provides a mock function with given fields: ctx, id +func (_m *ORM) DeleteJob(ctx context.Context, id int32) error { + ret := _m.Called(ctx, id) if len(ret) == 0 { panic("no return value specified for DeleteJob") } var r0 error - if rf, ok := ret.Get(0).(func(int32, ...pg.QOpt) error); ok { - r0 = rf(id, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, int32) error); ok { + r0 = rf(ctx, id) } else { r0 = ret.Error(0) } @@ -186,16 +192,9 @@ func (_m *ORM) FindJob(ctx context.Context, id int32) (job.Job, error) { return r0, r1 } -// FindJobByExternalJobID provides a mock function with given fields: _a0, qopts -func (_m *ORM) FindJobByExternalJobID(_a0 uuid.UUID, qopts ...pg.QOpt) (job.Job, error) { - _va := make([]interface{}, len(qopts)) - for _i := range qopts { - _va[_i] = qopts[_i] - } - var _ca []interface{} - _ca = append(_ca, _a0) - _ca = append(_ca, _va...) - ret := _m.Called(_ca...) +// FindJobByExternalJobID provides a mock function with given fields: ctx, _a1 +func (_m *ORM) FindJobByExternalJobID(ctx context.Context, _a1 uuid.UUID) (job.Job, error) { + ret := _m.Called(ctx, _a1) if len(ret) == 0 { panic("no return value specified for FindJobByExternalJobID") @@ -203,17 +202,17 @@ func (_m *ORM) FindJobByExternalJobID(_a0 uuid.UUID, qopts ...pg.QOpt) (job.Job, var r0 job.Job var r1 error - if rf, ok := ret.Get(0).(func(uuid.UUID, ...pg.QOpt) (job.Job, error)); ok { - return rf(_a0, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, uuid.UUID) (job.Job, error)); ok { + return rf(ctx, _a1) } - if rf, ok := ret.Get(0).(func(uuid.UUID, ...pg.QOpt) job.Job); ok { - r0 = rf(_a0, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, uuid.UUID) job.Job); ok { + r0 = rf(ctx, _a1) } else { r0 = ret.Get(0).(job.Job) } - if rf, ok := ret.Get(1).(func(uuid.UUID, ...pg.QOpt) error); ok { - r1 = rf(_a0, qopts...) + if rf, ok := ret.Get(1).(func(context.Context, uuid.UUID) error); ok { + r1 = rf(ctx, _a1) } else { r1 = ret.Error(1) } @@ -221,16 +220,9 @@ func (_m *ORM) FindJobByExternalJobID(_a0 uuid.UUID, qopts ...pg.QOpt) (job.Job, return r0, r1 } -// FindJobIDByAddress provides a mock function with given fields: address, evmChainID, qopts -func (_m *ORM) FindJobIDByAddress(address types.EIP55Address, evmChainID *big.Big, qopts ...pg.QOpt) (int32, error) { - _va := make([]interface{}, len(qopts)) - for _i := range qopts { - _va[_i] = qopts[_i] - } - var _ca []interface{} - _ca = append(_ca, address, evmChainID) - _ca = append(_ca, _va...) - ret := _m.Called(_ca...) +// FindJobIDByAddress provides a mock function with given fields: ctx, address, evmChainID +func (_m *ORM) FindJobIDByAddress(ctx context.Context, address types.EIP55Address, evmChainID *big.Big) (int32, error) { + ret := _m.Called(ctx, address, evmChainID) if len(ret) == 0 { panic("no return value specified for FindJobIDByAddress") @@ -238,17 +230,17 @@ func (_m *ORM) FindJobIDByAddress(address types.EIP55Address, evmChainID *big.Bi var r0 int32 var r1 error - if rf, ok := ret.Get(0).(func(types.EIP55Address, *big.Big, ...pg.QOpt) (int32, error)); ok { - return rf(address, evmChainID, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, types.EIP55Address, *big.Big) (int32, error)); ok { + return rf(ctx, address, evmChainID) } - if rf, ok := ret.Get(0).(func(types.EIP55Address, *big.Big, ...pg.QOpt) int32); ok { - r0 = rf(address, evmChainID, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, types.EIP55Address, *big.Big) int32); ok { + r0 = rf(ctx, address, evmChainID) } else { r0 = ret.Get(0).(int32) } - if rf, ok := ret.Get(1).(func(types.EIP55Address, *big.Big, ...pg.QOpt) error); ok { - r1 = rf(address, evmChainID, qopts...) + if rf, ok := ret.Get(1).(func(context.Context, types.EIP55Address, *big.Big) error); ok { + r1 = rf(ctx, address, evmChainID) } else { r1 = ret.Error(1) } @@ -256,9 +248,9 @@ func (_m *ORM) FindJobIDByAddress(address types.EIP55Address, evmChainID *big.Bi return r0, r1 } -// FindJobIDsWithBridge provides a mock function with given fields: name -func (_m *ORM) FindJobIDsWithBridge(name string) ([]int32, error) { - ret := _m.Called(name) +// FindJobIDsWithBridge provides a mock function with given fields: ctx, name +func (_m *ORM) FindJobIDsWithBridge(ctx context.Context, name string) ([]int32, error) { + ret := _m.Called(ctx, name) if len(ret) == 0 { panic("no return value specified for FindJobIDsWithBridge") @@ -266,19 +258,19 @@ func (_m *ORM) FindJobIDsWithBridge(name string) ([]int32, error) { var r0 []int32 var r1 error - if rf, ok := ret.Get(0).(func(string) ([]int32, error)); ok { - return rf(name) + if rf, ok := ret.Get(0).(func(context.Context, string) ([]int32, error)); ok { + return rf(ctx, name) } - if rf, ok := ret.Get(0).(func(string) []int32); ok { - r0 = rf(name) + if rf, ok := ret.Get(0).(func(context.Context, string) []int32); ok { + r0 = rf(ctx, name) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]int32) } } - if rf, ok := ret.Get(1).(func(string) error); ok { - r1 = rf(name) + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, name) } else { r1 = ret.Error(1) } @@ -314,9 +306,9 @@ func (_m *ORM) FindJobTx(ctx context.Context, id int32) (job.Job, error) { return r0, r1 } -// FindJobWithoutSpecErrors provides a mock function with given fields: id -func (_m *ORM) FindJobWithoutSpecErrors(id int32) (job.Job, error) { - ret := _m.Called(id) +// FindJobWithoutSpecErrors provides a mock function with given fields: ctx, id +func (_m *ORM) FindJobWithoutSpecErrors(ctx context.Context, id int32) (job.Job, error) { + ret := _m.Called(ctx, id) if len(ret) == 0 { panic("no return value specified for FindJobWithoutSpecErrors") @@ -324,17 +316,17 @@ func (_m *ORM) FindJobWithoutSpecErrors(id int32) (job.Job, error) { var r0 job.Job var r1 error - if rf, ok := ret.Get(0).(func(int32) (job.Job, error)); ok { - return rf(id) + if rf, ok := ret.Get(0).(func(context.Context, int32) (job.Job, error)); ok { + return rf(ctx, id) } - if rf, ok := ret.Get(0).(func(int32) job.Job); ok { - r0 = rf(id) + if rf, ok := ret.Get(0).(func(context.Context, int32) job.Job); ok { + r0 = rf(ctx, id) } else { r0 = ret.Get(0).(job.Job) } - if rf, ok := ret.Get(1).(func(int32) error); ok { - r1 = rf(id) + if rf, ok := ret.Get(1).(func(context.Context, int32) error); ok { + r1 = rf(ctx, id) } else { r1 = ret.Error(1) } @@ -342,9 +334,9 @@ func (_m *ORM) FindJobWithoutSpecErrors(id int32) (job.Job, error) { return r0, r1 } -// FindJobs provides a mock function with given fields: offset, limit -func (_m *ORM) FindJobs(offset int, limit int) ([]job.Job, int, error) { - ret := _m.Called(offset, limit) +// FindJobs provides a mock function with given fields: ctx, offset, limit +func (_m *ORM) FindJobs(ctx context.Context, offset int, limit int) ([]job.Job, int, error) { + ret := _m.Called(ctx, offset, limit) if len(ret) == 0 { panic("no return value specified for FindJobs") @@ -353,25 +345,25 @@ func (_m *ORM) FindJobs(offset int, limit int) ([]job.Job, int, error) { var r0 []job.Job var r1 int var r2 error - if rf, ok := ret.Get(0).(func(int, int) ([]job.Job, int, error)); ok { - return rf(offset, limit) + if rf, ok := ret.Get(0).(func(context.Context, int, int) ([]job.Job, int, error)); ok { + return rf(ctx, offset, limit) } - if rf, ok := ret.Get(0).(func(int, int) []job.Job); ok { - r0 = rf(offset, limit) + if rf, ok := ret.Get(0).(func(context.Context, int, int) []job.Job); ok { + r0 = rf(ctx, offset, limit) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]job.Job) } } - if rf, ok := ret.Get(1).(func(int, int) int); ok { - r1 = rf(offset, limit) + if rf, ok := ret.Get(1).(func(context.Context, int, int) int); ok { + r1 = rf(ctx, offset, limit) } else { r1 = ret.Get(1).(int) } - if rf, ok := ret.Get(2).(func(int, int) error); ok { - r2 = rf(offset, limit) + if rf, ok := ret.Get(2).(func(context.Context, int, int) error); ok { + r2 = rf(ctx, offset, limit) } else { r2 = ret.Error(2) } @@ -379,9 +371,9 @@ func (_m *ORM) FindJobs(offset int, limit int) ([]job.Job, int, error) { return r0, r1, r2 } -// FindJobsByPipelineSpecIDs provides a mock function with given fields: ids -func (_m *ORM) FindJobsByPipelineSpecIDs(ids []int32) ([]job.Job, error) { - ret := _m.Called(ids) +// FindJobsByPipelineSpecIDs provides a mock function with given fields: ctx, ids +func (_m *ORM) FindJobsByPipelineSpecIDs(ctx context.Context, ids []int32) ([]job.Job, error) { + ret := _m.Called(ctx, ids) if len(ret) == 0 { panic("no return value specified for FindJobsByPipelineSpecIDs") @@ -389,19 +381,19 @@ func (_m *ORM) FindJobsByPipelineSpecIDs(ids []int32) ([]job.Job, error) { var r0 []job.Job var r1 error - if rf, ok := ret.Get(0).(func([]int32) ([]job.Job, error)); ok { - return rf(ids) + if rf, ok := ret.Get(0).(func(context.Context, []int32) ([]job.Job, error)); ok { + return rf(ctx, ids) } - if rf, ok := ret.Get(0).(func([]int32) []job.Job); ok { - r0 = rf(ids) + if rf, ok := ret.Get(0).(func(context.Context, []int32) []job.Job); ok { + r0 = rf(ctx, ids) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]job.Job) } } - if rf, ok := ret.Get(1).(func([]int32) error); ok { - r1 = rf(ids) + if rf, ok := ret.Get(1).(func(context.Context, []int32) error); ok { + r1 = rf(ctx, ids) } else { r1 = ret.Error(1) } @@ -409,16 +401,9 @@ func (_m *ORM) FindJobsByPipelineSpecIDs(ids []int32) ([]job.Job, error) { return r0, r1 } -// FindOCR2JobIDByAddress provides a mock function with given fields: contractID, feedID, qopts -func (_m *ORM) FindOCR2JobIDByAddress(contractID string, feedID *common.Hash, qopts ...pg.QOpt) (int32, error) { - _va := make([]interface{}, len(qopts)) - for _i := range qopts { - _va[_i] = qopts[_i] - } - var _ca []interface{} - _ca = append(_ca, contractID, feedID) - _ca = append(_ca, _va...) - ret := _m.Called(_ca...) +// FindOCR2JobIDByAddress provides a mock function with given fields: ctx, contractID, feedID +func (_m *ORM) FindOCR2JobIDByAddress(ctx context.Context, contractID string, feedID *common.Hash) (int32, error) { + ret := _m.Called(ctx, contractID, feedID) if len(ret) == 0 { panic("no return value specified for FindOCR2JobIDByAddress") @@ -426,17 +411,17 @@ func (_m *ORM) FindOCR2JobIDByAddress(contractID string, feedID *common.Hash, qo var r0 int32 var r1 error - if rf, ok := ret.Get(0).(func(string, *common.Hash, ...pg.QOpt) (int32, error)); ok { - return rf(contractID, feedID, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, string, *common.Hash) (int32, error)); ok { + return rf(ctx, contractID, feedID) } - if rf, ok := ret.Get(0).(func(string, *common.Hash, ...pg.QOpt) int32); ok { - r0 = rf(contractID, feedID, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, string, *common.Hash) int32); ok { + r0 = rf(ctx, contractID, feedID) } else { r0 = ret.Get(0).(int32) } - if rf, ok := ret.Get(1).(func(string, *common.Hash, ...pg.QOpt) error); ok { - r1 = rf(contractID, feedID, qopts...) + if rf, ok := ret.Get(1).(func(context.Context, string, *common.Hash) error); ok { + r1 = rf(ctx, contractID, feedID) } else { r1 = ret.Error(1) } @@ -444,9 +429,9 @@ func (_m *ORM) FindOCR2JobIDByAddress(contractID string, feedID *common.Hash, qo return r0, r1 } -// FindPipelineRunByID provides a mock function with given fields: id -func (_m *ORM) FindPipelineRunByID(id int64) (pipeline.Run, error) { - ret := _m.Called(id) +// FindPipelineRunByID provides a mock function with given fields: ctx, id +func (_m *ORM) FindPipelineRunByID(ctx context.Context, id int64) (pipeline.Run, error) { + ret := _m.Called(ctx, id) if len(ret) == 0 { panic("no return value specified for FindPipelineRunByID") @@ -454,17 +439,17 @@ func (_m *ORM) FindPipelineRunByID(id int64) (pipeline.Run, error) { var r0 pipeline.Run var r1 error - if rf, ok := ret.Get(0).(func(int64) (pipeline.Run, error)); ok { - return rf(id) + if rf, ok := ret.Get(0).(func(context.Context, int64) (pipeline.Run, error)); ok { + return rf(ctx, id) } - if rf, ok := ret.Get(0).(func(int64) pipeline.Run); ok { - r0 = rf(id) + if rf, ok := ret.Get(0).(func(context.Context, int64) pipeline.Run); ok { + r0 = rf(ctx, id) } else { r0 = ret.Get(0).(pipeline.Run) } - if rf, ok := ret.Get(1).(func(int64) error); ok { - r1 = rf(id) + if rf, ok := ret.Get(1).(func(context.Context, int64) error); ok { + r1 = rf(ctx, id) } else { r1 = ret.Error(1) } @@ -472,9 +457,9 @@ func (_m *ORM) FindPipelineRunByID(id int64) (pipeline.Run, error) { return r0, r1 } -// FindPipelineRunIDsByJobID provides a mock function with given fields: jobID, offset, limit -func (_m *ORM) FindPipelineRunIDsByJobID(jobID int32, offset int, limit int) ([]int64, error) { - ret := _m.Called(jobID, offset, limit) +// FindPipelineRunIDsByJobID provides a mock function with given fields: ctx, jobID, offset, limit +func (_m *ORM) FindPipelineRunIDsByJobID(ctx context.Context, jobID int32, offset int, limit int) ([]int64, error) { + ret := _m.Called(ctx, jobID, offset, limit) if len(ret) == 0 { panic("no return value specified for FindPipelineRunIDsByJobID") @@ -482,19 +467,19 @@ func (_m *ORM) FindPipelineRunIDsByJobID(jobID int32, offset int, limit int) ([] var r0 []int64 var r1 error - if rf, ok := ret.Get(0).(func(int32, int, int) ([]int64, error)); ok { - return rf(jobID, offset, limit) + if rf, ok := ret.Get(0).(func(context.Context, int32, int, int) ([]int64, error)); ok { + return rf(ctx, jobID, offset, limit) } - if rf, ok := ret.Get(0).(func(int32, int, int) []int64); ok { - r0 = rf(jobID, offset, limit) + if rf, ok := ret.Get(0).(func(context.Context, int32, int, int) []int64); ok { + r0 = rf(ctx, jobID, offset, limit) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]int64) } } - if rf, ok := ret.Get(1).(func(int32, int, int) error); ok { - r1 = rf(jobID, offset, limit) + if rf, ok := ret.Get(1).(func(context.Context, int32, int, int) error); ok { + r1 = rf(ctx, jobID, offset, limit) } else { r1 = ret.Error(1) } @@ -502,9 +487,9 @@ func (_m *ORM) FindPipelineRunIDsByJobID(jobID int32, offset int, limit int) ([] return r0, r1 } -// FindPipelineRunsByIDs provides a mock function with given fields: ids -func (_m *ORM) FindPipelineRunsByIDs(ids []int64) ([]pipeline.Run, error) { - ret := _m.Called(ids) +// FindPipelineRunsByIDs provides a mock function with given fields: ctx, ids +func (_m *ORM) FindPipelineRunsByIDs(ctx context.Context, ids []int64) ([]pipeline.Run, error) { + ret := _m.Called(ctx, ids) if len(ret) == 0 { panic("no return value specified for FindPipelineRunsByIDs") @@ -512,19 +497,19 @@ func (_m *ORM) FindPipelineRunsByIDs(ids []int64) ([]pipeline.Run, error) { var r0 []pipeline.Run var r1 error - if rf, ok := ret.Get(0).(func([]int64) ([]pipeline.Run, error)); ok { - return rf(ids) + if rf, ok := ret.Get(0).(func(context.Context, []int64) ([]pipeline.Run, error)); ok { + return rf(ctx, ids) } - if rf, ok := ret.Get(0).(func([]int64) []pipeline.Run); ok { - r0 = rf(ids) + if rf, ok := ret.Get(0).(func(context.Context, []int64) []pipeline.Run); ok { + r0 = rf(ctx, ids) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]pipeline.Run) } } - if rf, ok := ret.Get(1).(func([]int64) error); ok { - r1 = rf(ids) + if rf, ok := ret.Get(1).(func(context.Context, []int64) error); ok { + r1 = rf(ctx, ids) } else { r1 = ret.Error(1) } @@ -532,16 +517,9 @@ func (_m *ORM) FindPipelineRunsByIDs(ids []int64) ([]pipeline.Run, error) { return r0, r1 } -// FindSpecError provides a mock function with given fields: id, qopts -func (_m *ORM) FindSpecError(id int64, qopts ...pg.QOpt) (job.SpecError, error) { - _va := make([]interface{}, len(qopts)) - for _i := range qopts { - _va[_i] = qopts[_i] - } - var _ca []interface{} - _ca = append(_ca, id) - _ca = append(_ca, _va...) - ret := _m.Called(_ca...) +// FindSpecError provides a mock function with given fields: ctx, id +func (_m *ORM) FindSpecError(ctx context.Context, id int64) (job.SpecError, error) { + ret := _m.Called(ctx, id) if len(ret) == 0 { panic("no return value specified for FindSpecError") @@ -549,17 +527,17 @@ func (_m *ORM) FindSpecError(id int64, qopts ...pg.QOpt) (job.SpecError, error) var r0 job.SpecError var r1 error - if rf, ok := ret.Get(0).(func(int64, ...pg.QOpt) (job.SpecError, error)); ok { - return rf(id, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, int64) (job.SpecError, error)); ok { + return rf(ctx, id) } - if rf, ok := ret.Get(0).(func(int64, ...pg.QOpt) job.SpecError); ok { - r0 = rf(id, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, int64) job.SpecError); ok { + r0 = rf(ctx, id) } else { r0 = ret.Get(0).(job.SpecError) } - if rf, ok := ret.Get(1).(func(int64, ...pg.QOpt) error); ok { - r1 = rf(id, qopts...) + if rf, ok := ret.Get(1).(func(context.Context, int64) error); ok { + r1 = rf(ctx, id) } else { r1 = ret.Error(1) } @@ -567,16 +545,9 @@ func (_m *ORM) FindSpecError(id int64, qopts ...pg.QOpt) (job.SpecError, error) return r0, r1 } -// FindSpecErrorsByJobIDs provides a mock function with given fields: ids, qopts -func (_m *ORM) FindSpecErrorsByJobIDs(ids []int32, qopts ...pg.QOpt) ([]job.SpecError, error) { - _va := make([]interface{}, len(qopts)) - for _i := range qopts { - _va[_i] = qopts[_i] - } - var _ca []interface{} - _ca = append(_ca, ids) - _ca = append(_ca, _va...) - ret := _m.Called(_ca...) +// FindSpecErrorsByJobIDs provides a mock function with given fields: ctx, ids +func (_m *ORM) FindSpecErrorsByJobIDs(ctx context.Context, ids []int32) ([]job.SpecError, error) { + ret := _m.Called(ctx, ids) if len(ret) == 0 { panic("no return value specified for FindSpecErrorsByJobIDs") @@ -584,19 +555,19 @@ func (_m *ORM) FindSpecErrorsByJobIDs(ids []int32, qopts ...pg.QOpt) ([]job.Spec var r0 []job.SpecError var r1 error - if rf, ok := ret.Get(0).(func([]int32, ...pg.QOpt) ([]job.SpecError, error)); ok { - return rf(ids, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, []int32) ([]job.SpecError, error)); ok { + return rf(ctx, ids) } - if rf, ok := ret.Get(0).(func([]int32, ...pg.QOpt) []job.SpecError); ok { - r0 = rf(ids, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, []int32) []job.SpecError); ok { + r0 = rf(ctx, ids) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]job.SpecError) } } - if rf, ok := ret.Get(1).(func([]int32, ...pg.QOpt) error); ok { - r1 = rf(ids, qopts...) + if rf, ok := ret.Get(1).(func(context.Context, []int32) error); ok { + r1 = rf(ctx, ids) } else { r1 = ret.Error(1) } @@ -604,16 +575,9 @@ func (_m *ORM) FindSpecErrorsByJobIDs(ids []int32, qopts ...pg.QOpt) ([]job.Spec return r0, r1 } -// FindTaskResultByRunIDAndTaskName provides a mock function with given fields: runID, taskName, qopts -func (_m *ORM) FindTaskResultByRunIDAndTaskName(runID int64, taskName string, qopts ...pg.QOpt) ([]byte, error) { - _va := make([]interface{}, len(qopts)) - for _i := range qopts { - _va[_i] = qopts[_i] - } - var _ca []interface{} - _ca = append(_ca, runID, taskName) - _ca = append(_ca, _va...) - ret := _m.Called(_ca...) +// FindTaskResultByRunIDAndTaskName provides a mock function with given fields: ctx, runID, taskName +func (_m *ORM) FindTaskResultByRunIDAndTaskName(ctx context.Context, runID int64, taskName string) ([]byte, error) { + ret := _m.Called(ctx, runID, taskName) if len(ret) == 0 { panic("no return value specified for FindTaskResultByRunIDAndTaskName") @@ -621,19 +585,19 @@ func (_m *ORM) FindTaskResultByRunIDAndTaskName(runID int64, taskName string, qo var r0 []byte var r1 error - if rf, ok := ret.Get(0).(func(int64, string, ...pg.QOpt) ([]byte, error)); ok { - return rf(runID, taskName, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, int64, string) ([]byte, error)); ok { + return rf(ctx, runID, taskName) } - if rf, ok := ret.Get(0).(func(int64, string, ...pg.QOpt) []byte); ok { - r0 = rf(runID, taskName, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, int64, string) []byte); ok { + r0 = rf(ctx, runID, taskName) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]byte) } } - if rf, ok := ret.Get(1).(func(int64, string, ...pg.QOpt) error); ok { - r1 = rf(runID, taskName, qopts...) + if rf, ok := ret.Get(1).(func(context.Context, int64, string) error); ok { + r1 = rf(ctx, runID, taskName) } else { r1 = ret.Error(1) } @@ -641,24 +605,17 @@ func (_m *ORM) FindTaskResultByRunIDAndTaskName(runID int64, taskName string, qo return r0, r1 } -// InsertJob provides a mock function with given fields: _a0, qopts -func (_m *ORM) InsertJob(_a0 *job.Job, qopts ...pg.QOpt) error { - _va := make([]interface{}, len(qopts)) - for _i := range qopts { - _va[_i] = qopts[_i] - } - var _ca []interface{} - _ca = append(_ca, _a0) - _ca = append(_ca, _va...) - ret := _m.Called(_ca...) +// InsertJob provides a mock function with given fields: ctx, _a1 +func (_m *ORM) InsertJob(ctx context.Context, _a1 *job.Job) error { + ret := _m.Called(ctx, _a1) if len(ret) == 0 { panic("no return value specified for InsertJob") } var r0 error - if rf, ok := ret.Get(0).(func(*job.Job, ...pg.QOpt) error); ok { - r0 = rf(_a0, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, *job.Job) error); ok { + r0 = rf(ctx, _a1) } else { r0 = ret.Error(0) } @@ -666,24 +623,17 @@ func (_m *ORM) InsertJob(_a0 *job.Job, qopts ...pg.QOpt) error { return r0 } -// InsertWebhookSpec provides a mock function with given fields: webhookSpec, qopts -func (_m *ORM) InsertWebhookSpec(webhookSpec *job.WebhookSpec, qopts ...pg.QOpt) error { - _va := make([]interface{}, len(qopts)) - for _i := range qopts { - _va[_i] = qopts[_i] - } - var _ca []interface{} - _ca = append(_ca, webhookSpec) - _ca = append(_ca, _va...) - ret := _m.Called(_ca...) +// InsertWebhookSpec provides a mock function with given fields: ctx, webhookSpec +func (_m *ORM) InsertWebhookSpec(ctx context.Context, webhookSpec *job.WebhookSpec) error { + ret := _m.Called(ctx, webhookSpec) if len(ret) == 0 { panic("no return value specified for InsertWebhookSpec") } var r0 error - if rf, ok := ret.Get(0).(func(*job.WebhookSpec, ...pg.QOpt) error); ok { - r0 = rf(webhookSpec, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, *job.WebhookSpec) error); ok { + r0 = rf(ctx, webhookSpec) } else { r0 = ret.Error(0) } @@ -691,9 +641,9 @@ func (_m *ORM) InsertWebhookSpec(webhookSpec *job.WebhookSpec, qopts ...pg.QOpt) return r0 } -// PipelineRuns provides a mock function with given fields: jobID, offset, size -func (_m *ORM) PipelineRuns(jobID *int32, offset int, size int) ([]pipeline.Run, int, error) { - ret := _m.Called(jobID, offset, size) +// PipelineRuns provides a mock function with given fields: ctx, jobID, offset, size +func (_m *ORM) PipelineRuns(ctx context.Context, jobID *int32, offset int, size int) ([]pipeline.Run, int, error) { + ret := _m.Called(ctx, jobID, offset, size) if len(ret) == 0 { panic("no return value specified for PipelineRuns") @@ -702,25 +652,25 @@ func (_m *ORM) PipelineRuns(jobID *int32, offset int, size int) ([]pipeline.Run, var r0 []pipeline.Run var r1 int var r2 error - if rf, ok := ret.Get(0).(func(*int32, int, int) ([]pipeline.Run, int, error)); ok { - return rf(jobID, offset, size) + if rf, ok := ret.Get(0).(func(context.Context, *int32, int, int) ([]pipeline.Run, int, error)); ok { + return rf(ctx, jobID, offset, size) } - if rf, ok := ret.Get(0).(func(*int32, int, int) []pipeline.Run); ok { - r0 = rf(jobID, offset, size) + if rf, ok := ret.Get(0).(func(context.Context, *int32, int, int) []pipeline.Run); ok { + r0 = rf(ctx, jobID, offset, size) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]pipeline.Run) } } - if rf, ok := ret.Get(1).(func(*int32, int, int) int); ok { - r1 = rf(jobID, offset, size) + if rf, ok := ret.Get(1).(func(context.Context, *int32, int, int) int); ok { + r1 = rf(ctx, jobID, offset, size) } else { r1 = ret.Get(1).(int) } - if rf, ok := ret.Get(2).(func(*int32, int, int) error); ok { - r2 = rf(jobID, offset, size) + if rf, ok := ret.Get(2).(func(context.Context, *int32, int, int) error); ok { + r2 = rf(ctx, jobID, offset, size) } else { r2 = ret.Error(2) } @@ -728,24 +678,17 @@ func (_m *ORM) PipelineRuns(jobID *int32, offset int, size int) ([]pipeline.Run, return r0, r1, r2 } -// RecordError provides a mock function with given fields: jobID, description, qopts -func (_m *ORM) RecordError(jobID int32, description string, qopts ...pg.QOpt) error { - _va := make([]interface{}, len(qopts)) - for _i := range qopts { - _va[_i] = qopts[_i] - } - var _ca []interface{} - _ca = append(_ca, jobID, description) - _ca = append(_ca, _va...) - ret := _m.Called(_ca...) +// RecordError provides a mock function with given fields: ctx, jobID, description +func (_m *ORM) RecordError(ctx context.Context, jobID int32, description string) error { + ret := _m.Called(ctx, jobID, description) if len(ret) == 0 { panic("no return value specified for RecordError") } var r0 error - if rf, ok := ret.Get(0).(func(int32, string, ...pg.QOpt) error); ok { - r0 = rf(jobID, description, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, int32, string) error); ok { + r0 = rf(ctx, jobID, description) } else { r0 = ret.Error(0) } @@ -753,16 +696,29 @@ func (_m *ORM) RecordError(jobID int32, description string, qopts ...pg.QOpt) er return r0 } -// TryRecordError provides a mock function with given fields: jobID, description, qopts -func (_m *ORM) TryRecordError(jobID int32, description string, qopts ...pg.QOpt) { - _va := make([]interface{}, len(qopts)) - for _i := range qopts { - _va[_i] = qopts[_i] +// TryRecordError provides a mock function with given fields: ctx, jobID, description +func (_m *ORM) TryRecordError(ctx context.Context, jobID int32, description string) { + _m.Called(ctx, jobID, description) +} + +// WithDataSource provides a mock function with given fields: source +func (_m *ORM) WithDataSource(source sqlutil.DataSource) job.ORM { + ret := _m.Called(source) + + if len(ret) == 0 { + panic("no return value specified for WithDataSource") } - var _ca []interface{} - _ca = append(_ca, jobID, description) - _ca = append(_ca, _va...) - _m.Called(_ca...) + + var r0 job.ORM + if rf, ok := ret.Get(0).(func(sqlutil.DataSource) job.ORM); ok { + r0 = rf(source) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(job.ORM) + } + } + + return r0 } // NewORM creates a new instance of ORM. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. diff --git a/core/services/job/mocks/spawner.go b/core/services/job/mocks/spawner.go index 37e883ef3c5..7127636cdbb 100644 --- a/core/services/job/mocks/spawner.go +++ b/core/services/job/mocks/spawner.go @@ -8,7 +8,7 @@ import ( job "github.com/smartcontractkit/chainlink/v2/core/services/job" mock "github.com/stretchr/testify/mock" - pg "github.com/smartcontractkit/chainlink/v2/core/services/pg" + sqlutil "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" ) // Spawner is an autogenerated mock type for the Spawner type @@ -54,24 +54,17 @@ func (_m *Spawner) Close() error { return r0 } -// CreateJob provides a mock function with given fields: jb, qopts -func (_m *Spawner) CreateJob(jb *job.Job, qopts ...pg.QOpt) error { - _va := make([]interface{}, len(qopts)) - for _i := range qopts { - _va[_i] = qopts[_i] - } - var _ca []interface{} - _ca = append(_ca, jb) - _ca = append(_ca, _va...) - ret := _m.Called(_ca...) +// CreateJob provides a mock function with given fields: ctx, ds, jb +func (_m *Spawner) CreateJob(ctx context.Context, ds sqlutil.DataSource, jb *job.Job) error { + ret := _m.Called(ctx, ds, jb) if len(ret) == 0 { panic("no return value specified for CreateJob") } var r0 error - if rf, ok := ret.Get(0).(func(*job.Job, ...pg.QOpt) error); ok { - r0 = rf(jb, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, sqlutil.DataSource, *job.Job) error); ok { + r0 = rf(ctx, ds, jb) } else { r0 = ret.Error(0) } @@ -79,24 +72,17 @@ func (_m *Spawner) CreateJob(jb *job.Job, qopts ...pg.QOpt) error { return r0 } -// DeleteJob provides a mock function with given fields: jobID, qopts -func (_m *Spawner) DeleteJob(jobID int32, qopts ...pg.QOpt) error { - _va := make([]interface{}, len(qopts)) - for _i := range qopts { - _va[_i] = qopts[_i] - } - var _ca []interface{} - _ca = append(_ca, jobID) - _ca = append(_ca, _va...) - ret := _m.Called(_ca...) +// DeleteJob provides a mock function with given fields: ctx, ds, jobID +func (_m *Spawner) DeleteJob(ctx context.Context, ds sqlutil.DataSource, jobID int32) error { + ret := _m.Called(ctx, ds, jobID) if len(ret) == 0 { panic("no return value specified for DeleteJob") } var r0 error - if rf, ok := ret.Get(0).(func(int32, ...pg.QOpt) error); ok { - r0 = rf(jobID, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, sqlutil.DataSource, int32) error); ok { + r0 = rf(ctx, ds, jobID) } else { r0 = ret.Error(0) } @@ -178,24 +164,17 @@ func (_m *Spawner) Start(_a0 context.Context) error { return r0 } -// StartService provides a mock function with given fields: ctx, spec, qopts -func (_m *Spawner) StartService(ctx context.Context, spec job.Job, qopts ...pg.QOpt) error { - _va := make([]interface{}, len(qopts)) - for _i := range qopts { - _va[_i] = qopts[_i] - } - var _ca []interface{} - _ca = append(_ca, ctx, spec) - _ca = append(_ca, _va...) - ret := _m.Called(_ca...) +// StartService provides a mock function with given fields: ctx, spec +func (_m *Spawner) StartService(ctx context.Context, spec job.Job) error { + ret := _m.Called(ctx, spec) if len(ret) == 0 { panic("no return value specified for StartService") } var r0 error - if rf, ok := ret.Get(0).(func(context.Context, job.Job, ...pg.QOpt) error); ok { - r0 = rf(ctx, spec, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, job.Job) error); ok { + r0 = rf(ctx, spec) } else { r0 = ret.Error(0) } diff --git a/core/services/job/orm.go b/core/services/job/orm.go index f7238799634..2b2f73396dc 100644 --- a/core/services/job/orm.go +++ b/core/services/job/orm.go @@ -18,6 +18,7 @@ import ( "github.com/jmoiron/sqlx" + "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" "github.com/smartcontractkit/chainlink-common/pkg/types" "github.com/smartcontractkit/chainlink/v2/core/bridges" @@ -29,7 +30,6 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/null" "github.com/smartcontractkit/chainlink/v2/core/services/keystore" medianconfig "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/median/config" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" "github.com/smartcontractkit/chainlink/v2/core/services/pipeline" "github.com/smartcontractkit/chainlink/v2/core/store/models" ) @@ -44,37 +44,40 @@ var ( //go:generate mockery --quiet --name ORM --output ./mocks/ --case=underscore type ORM interface { - InsertWebhookSpec(webhookSpec *WebhookSpec, qopts ...pg.QOpt) error - InsertJob(job *Job, qopts ...pg.QOpt) error - CreateJob(jb *Job, qopts ...pg.QOpt) error - FindJobs(offset, limit int) ([]Job, int, error) + InsertWebhookSpec(ctx context.Context, webhookSpec *WebhookSpec) error + InsertJob(ctx context.Context, job *Job) error + CreateJob(ctx context.Context, jb *Job) error + FindJobs(ctx context.Context, offset, limit int) ([]Job, int, error) FindJobTx(ctx context.Context, id int32) (Job, error) FindJob(ctx context.Context, id int32) (Job, error) - FindJobByExternalJobID(uuid uuid.UUID, qopts ...pg.QOpt) (Job, error) - FindJobIDByAddress(address evmtypes.EIP55Address, evmChainID *big.Big, qopts ...pg.QOpt) (int32, error) - FindOCR2JobIDByAddress(contractID string, feedID *common.Hash, qopts ...pg.QOpt) (int32, error) - FindJobIDsWithBridge(name string) ([]int32, error) - DeleteJob(id int32, qopts ...pg.QOpt) error - RecordError(jobID int32, description string, qopts ...pg.QOpt) error + FindJobByExternalJobID(ctx context.Context, uuid uuid.UUID) (Job, error) + FindJobIDByAddress(ctx context.Context, address evmtypes.EIP55Address, evmChainID *big.Big) (int32, error) + FindOCR2JobIDByAddress(ctx context.Context, contractID string, feedID *common.Hash) (int32, error) + FindJobIDsWithBridge(ctx context.Context, name string) ([]int32, error) + DeleteJob(ctx context.Context, id int32) error + RecordError(ctx context.Context, jobID int32, description string) error // TryRecordError is a helper which calls RecordError and logs the returned error if present. - TryRecordError(jobID int32, description string, qopts ...pg.QOpt) + TryRecordError(ctx context.Context, jobID int32, description string) DismissError(ctx context.Context, errorID int64) error - FindSpecError(id int64, qopts ...pg.QOpt) (SpecError, error) + FindSpecError(ctx context.Context, id int64) (SpecError, error) Close() error - PipelineRuns(jobID *int32, offset, size int) ([]pipeline.Run, int, error) + PipelineRuns(ctx context.Context, jobID *int32, offset, size int) ([]pipeline.Run, int, error) - FindPipelineRunIDsByJobID(jobID int32, offset, limit int) (ids []int64, err error) - FindPipelineRunsByIDs(ids []int64) (runs []pipeline.Run, err error) - CountPipelineRunsByJobID(jobID int32) (count int32, err error) + FindPipelineRunIDsByJobID(ctx context.Context, jobID int32, offset, limit int) (ids []int64, err error) + FindPipelineRunsByIDs(ctx context.Context, ids []int64) (runs []pipeline.Run, err error) + CountPipelineRunsByJobID(ctx context.Context, jobID int32) (count int32, err error) - FindJobsByPipelineSpecIDs(ids []int32) ([]Job, error) - FindPipelineRunByID(id int64) (pipeline.Run, error) + FindJobsByPipelineSpecIDs(ctx context.Context, ids []int32) ([]Job, error) + FindPipelineRunByID(ctx context.Context, id int64) (pipeline.Run, error) - FindSpecErrorsByJobIDs(ids []int32, qopts ...pg.QOpt) ([]SpecError, error) - FindJobWithoutSpecErrors(id int32) (jb Job, err error) + FindSpecErrorsByJobIDs(ctx context.Context, ids []int32) ([]SpecError, error) + FindJobWithoutSpecErrors(ctx context.Context, id int32) (jb Job, err error) - FindTaskResultByRunIDAndTaskName(runID int64, taskName string, qopts ...pg.QOpt) ([]byte, error) + FindTaskResultByRunIDAndTaskName(ctx context.Context, runID int64, taskName string) ([]byte, error) AssertBridgesExist(ctx context.Context, p pipeline.Pipeline) error + + DataSource() sqlutil.DataSource + WithDataSource(source sqlutil.DataSource) ORM } type ORMConfig interface { @@ -82,31 +85,56 @@ type ORMConfig interface { } type orm struct { - q pg.Q + ds sqlutil.DataSource keyStore keystore.Master pipelineORM pipeline.ORM lggr logger.SugaredLogger - cfg pg.QConfig bridgeORM bridges.ORM } var _ ORM = (*orm)(nil) -func NewORM(db *sqlx.DB, pipelineORM pipeline.ORM, bridgeORM bridges.ORM, keyStore keystore.Master, lggr logger.Logger, cfg pg.QConfig) *orm { +func NewORM(ds sqlutil.DataSource, pipelineORM pipeline.ORM, bridgeORM bridges.ORM, keyStore keystore.Master, lggr logger.Logger) *orm { namedLogger := logger.Sugared(lggr.Named("JobORM")) return &orm{ - q: pg.NewQ(db, namedLogger, cfg), + ds: ds, keyStore: keyStore, pipelineORM: pipelineORM, bridgeORM: bridgeORM, lggr: namedLogger, - cfg: cfg, } } + func (o *orm) Close() error { return nil } +func (o *orm) DataSource() sqlutil.DataSource { + return o.ds +} + +func (o *orm) WithDataSource(ds sqlutil.DataSource) ORM { return o.withDataSource(ds) } + +func (o *orm) withDataSource(ds sqlutil.DataSource) *orm { + n := &orm{ + ds: ds, + lggr: o.lggr, + keyStore: o.keyStore, + } + if o.bridgeORM != nil { + n.bridgeORM = o.bridgeORM.WithDataSource(ds) + } + if o.pipelineORM != nil { + n.pipelineORM = o.pipelineORM.WithDataSource(ds) + } + return n +} + +func (o *orm) transact(ctx context.Context, readOnly bool, fn func(*orm) error) error { + opts := &sqlutil.TxOptions{TxOptions: sql.TxOptions{ReadOnly: readOnly}} + return sqlutil.Transact(ctx, o.withDataSource, o.ds, opts, fn) +} + func (o *orm) AssertBridgesExist(ctx context.Context, p pipeline.Pipeline) error { var bridgeNames = make(map[bridges.BridgeName]struct{}) var uniqueBridges []bridges.BridgeName @@ -137,16 +165,14 @@ func (o *orm) AssertBridgesExist(ctx context.Context, p pipeline.Pipeline) error // CreateJob creates the job, and it's associated spec record. // Expects an unmarshalled job spec as the jb argument i.e. output from ValidatedXX. // Scans all persisted records back into jb -func (o *orm) CreateJob(jb *Job, qopts ...pg.QOpt) error { - q := o.q.WithOpts(qopts...) +func (o *orm) CreateJob(ctx context.Context, jb *Job) error { p := jb.Pipeline - ctx := context.TODO() // TODO https://smartcontract-it.atlassian.net/browse/BCF-2887 if err := o.AssertBridgesExist(ctx, p); err != nil { return err } var jobID int32 - err := q.Transaction(func(tx pg.Queryer) error { + err := o.transact(ctx, false, func(tx *orm) error { // Autogenerate a job ID if not specified if jb.ExternalJobID == (uuid.UUID{}) { jb.ExternalJobID = uuid.New() @@ -157,26 +183,18 @@ func (o *orm) CreateJob(jb *Job, qopts ...pg.QOpt) error { if jb.DirectRequestSpec.EVMChainID == nil { return errors.New("evm chain id must be defined") } - var specID int32 - sql := `INSERT INTO direct_request_specs (contract_address, min_incoming_confirmations, requesters, min_contract_payment, evm_chain_id, created_at, updated_at) - VALUES (:contract_address, :min_incoming_confirmations, :requesters, :min_contract_payment, :evm_chain_id, now(), now()) - RETURNING id;` - if err := pg.PrepareQueryRowx(tx, sql, &specID, jb.DirectRequestSpec); err != nil { - return errors.Wrap(err, "failed to create DirectRequestSpec") + specID, err := tx.insertDirectRequestSpec(ctx, jb.DirectRequestSpec) + if err != nil { + return fmt.Errorf("failed to create DirectRequestSpec for jobSpec: %w", err) } jb.DirectRequestSpecID = &specID case FluxMonitor: if jb.FluxMonitorSpec.EVMChainID == nil { return errors.New("evm chain id must be defined") } - var specID int32 - sql := `INSERT INTO flux_monitor_specs (contract_address, threshold, absolute_threshold, poll_timer_period, poll_timer_disabled, idle_timer_period, idle_timer_disabled, - drumbeat_schedule, drumbeat_random_delay, drumbeat_enabled, min_payment, evm_chain_id, created_at, updated_at) - VALUES (:contract_address, :threshold, :absolute_threshold, :poll_timer_period, :poll_timer_disabled, :idle_timer_period, :idle_timer_disabled, - :drumbeat_schedule, :drumbeat_random_delay, :drumbeat_enabled, :min_payment, :evm_chain_id, NOW(), NOW()) - RETURNING id;` - if err := pg.PrepareQueryRowx(tx, sql, &specID, jb.FluxMonitorSpec); err != nil { - return errors.Wrap(err, "failed to create FluxMonitorSpec") + specID, err := tx.insertFluxMonitorSpec(ctx, jb.FluxMonitorSpec) + if err != nil { + return fmt.Errorf("failed to create FluxMonitorSpec for jobSpec: %w", err) } jb.FluxMonitorSpecID = &specID case OffchainReporting: @@ -184,15 +202,14 @@ func (o *orm) CreateJob(jb *Job, qopts ...pg.QOpt) error { return errors.New("evm chain id must be defined") } - var specID int32 if jb.OCROracleSpec.EncryptedOCRKeyBundleID != nil { - _, err := o.keyStore.OCR().Get(jb.OCROracleSpec.EncryptedOCRKeyBundleID.String()) + _, err := tx.keyStore.OCR().Get(jb.OCROracleSpec.EncryptedOCRKeyBundleID.String()) if err != nil { return errors.Wrapf(ErrNoSuchKeyBundle, "no key bundle with id: %x", jb.OCROracleSpec.EncryptedOCRKeyBundleID) } } if jb.OCROracleSpec.TransmitterAddress != nil { - _, err := o.keyStore.Eth().Get(q.ParentCtx, jb.OCROracleSpec.TransmitterAddress.Hex()) + _, err := tx.keyStore.Eth().Get(ctx, jb.OCROracleSpec.TransmitterAddress.Hex()) if err != nil { return errors.Wrapf(ErrNoSuchTransmitterKey, "no key matching transmitter address: %s", jb.OCROracleSpec.TransmitterAddress.Hex()) } @@ -200,7 +217,7 @@ func (o *orm) CreateJob(jb *Job, qopts ...pg.QOpt) error { newChainID := jb.OCROracleSpec.EVMChainID existingSpec := new(OCROracleSpec) - err := tx.Get(existingSpec, `SELECT * FROM ocr_oracle_specs WHERE contract_address = $1 and (evm_chain_id = $2 or evm_chain_id IS NULL) LIMIT 1;`, + err := tx.ds.GetContext(ctx, existingSpec, `SELECT * FROM ocr_oracle_specs WHERE contract_address = $1 and (evm_chain_id = $2 or evm_chain_id IS NULL) LIMIT 1;`, jb.OCROracleSpec.ContractAddress, newChainID, ) @@ -212,23 +229,14 @@ func (o *orm) CreateJob(jb *Job, qopts ...pg.QOpt) error { return errors.Errorf("a job with contract address %s already exists for chain ID %s", jb.OCROracleSpec.ContractAddress, newChainID) } - sql := `INSERT INTO ocr_oracle_specs (contract_address, p2pv2_bootstrappers, is_bootstrap_peer, encrypted_ocr_key_bundle_id, transmitter_address, - observation_timeout, blockchain_timeout, contract_config_tracker_subscribe_interval, contract_config_tracker_poll_interval, contract_config_confirmations, evm_chain_id, - created_at, updated_at, database_timeout, observation_grace_period, contract_transmitter_transmit_timeout) - VALUES (:contract_address, :p2pv2_bootstrappers, :is_bootstrap_peer, :encrypted_ocr_key_bundle_id, :transmitter_address, - :observation_timeout, :blockchain_timeout, :contract_config_tracker_subscribe_interval, :contract_config_tracker_poll_interval, :contract_config_confirmations, :evm_chain_id, - NOW(), NOW(), :database_timeout, :observation_grace_period, :contract_transmitter_transmit_timeout) - RETURNING id;` - err = pg.PrepareQueryRowx(tx, sql, &specID, jb.OCROracleSpec) + specID, err := tx.insertOCROracleSpec(ctx, jb.OCROracleSpec) if err != nil { - return errors.Wrap(err, "failed to create OffchainreportingOracleSpec") + return fmt.Errorf("failed to create OCROracleSpec for jobSpec: %w", err) } jb.OCROracleSpecID = &specID case OffchainReporting2: - var specID int32 - if jb.OCR2OracleSpec.OCRKeyBundleID.Valid { - _, err := o.keyStore.OCR2().Get(jb.OCR2OracleSpec.OCRKeyBundleID.String) + _, err := tx.keyStore.OCR2().Get(jb.OCR2OracleSpec.OCRKeyBundleID.String) if err != nil { return errors.Wrapf(ErrNoSuchKeyBundle, "no key bundle with id: %q", jb.OCR2OracleSpec.OCRKeyBundleID.ValueOrZero()) } @@ -239,7 +247,7 @@ func (o *orm) CreateJob(jb *Job, qopts ...pg.QOpt) error { } // checks if they are present and if they are valid - sendingKeysDefined, err := areSendingKeysDefined(q.ParentCtx, jb, o.keyStore) + sendingKeysDefined, err := areSendingKeysDefined(ctx, jb, tx.keyStore) if err != nil { return err } @@ -249,7 +257,7 @@ func (o *orm) CreateJob(jb *Job, qopts ...pg.QOpt) error { } if !sendingKeysDefined { - if err = ValidateKeyStoreMatch(q.ParentCtx, jb.OCR2OracleSpec, o.keyStore, jb.OCR2OracleSpec.TransmitterID.String); err != nil { + if err = ValidateKeyStoreMatch(ctx, jb.OCR2OracleSpec, tx.keyStore, jb.OCR2OracleSpec.TransmitterID.String); err != nil { return errors.Wrap(ErrNoSuchTransmitterKey, err.Error()) } } @@ -278,66 +286,36 @@ func (o *orm) CreateJob(jb *Job, qopts ...pg.QOpt) error { if err2 != nil { return err2 } - if err2 = o.AssertBridgesExist(ctx, *feePipeline); err2 != nil { + if err2 = tx.AssertBridgesExist(ctx, *feePipeline); err2 != nil { return err2 } } - sql := `INSERT INTO ocr2_oracle_specs (contract_id, feed_id, relay, relay_config, plugin_type, plugin_config, p2pv2_bootstrappers, ocr_key_bundle_id, transmitter_id, - blockchain_timeout, contract_config_tracker_poll_interval, contract_config_confirmations, - created_at, updated_at) - VALUES (:contract_id, :feed_id, :relay, :relay_config, :plugin_type, :plugin_config, :p2pv2_bootstrappers, :ocr_key_bundle_id, :transmitter_id, - :blockchain_timeout, :contract_config_tracker_poll_interval, :contract_config_confirmations, - NOW(), NOW()) - RETURNING id;` - err = pg.PrepareQueryRowx(tx, sql, &specID, jb.OCR2OracleSpec) + specID, err := tx.insertOCR2OracleSpec(ctx, jb.OCR2OracleSpec) if err != nil { - return errors.Wrap(err, "failed to create Offchainreporting2OracleSpec") + return fmt.Errorf("failed to create OCR2OracleSpec for jobSpec: %w", err) } jb.OCR2OracleSpecID = &specID case Keeper: if jb.KeeperSpec.EVMChainID == nil { return errors.New("evm chain id must be defined") } - var specID int32 - sql := `INSERT INTO keeper_specs (contract_address, from_address, evm_chain_id, created_at, updated_at) - VALUES (:contract_address, :from_address, :evm_chain_id, NOW(), NOW()) - RETURNING id;` - if err := pg.PrepareQueryRowx(tx, sql, &specID, jb.KeeperSpec); err != nil { - return errors.Wrap(err, "failed to create KeeperSpec") + specID, err := tx.insertKeeperSpec(ctx, jb.KeeperSpec) + if err != nil { + return fmt.Errorf("failed to create KeeperSpec for jobSpec: %w", err) } jb.KeeperSpecID = &specID case Cron: - var specID int32 - sql := `INSERT INTO cron_specs (cron_schedule, created_at, updated_at) - VALUES (:cron_schedule, NOW(), NOW()) - RETURNING id;` - if err := pg.PrepareQueryRowx(tx, sql, &specID, jb.CronSpec); err != nil { - return errors.Wrap(err, "failed to create CronSpec") + specID, err := tx.insertCronSpec(ctx, jb.CronSpec) + if err != nil { + return fmt.Errorf("failed to create CronSpec for jobSpec: %w", err) } jb.CronSpecID = &specID case VRF: if jb.VRFSpec.EVMChainID == nil { return errors.New("evm chain id must be defined") } - var specID int32 - sql := `INSERT INTO vrf_specs ( - coordinator_address, public_key, min_incoming_confirmations, - evm_chain_id, from_addresses, poll_period, requested_confs_delay, - request_timeout, chunk_size, batch_coordinator_address, batch_fulfillment_enabled, - batch_fulfillment_gas_multiplier, backoff_initial_delay, backoff_max_delay, gas_lane_price, - vrf_owner_address, custom_reverts_pipeline_enabled, - created_at, updated_at) - VALUES ( - :coordinator_address, :public_key, :min_incoming_confirmations, - :evm_chain_id, :from_addresses, :poll_period, :requested_confs_delay, - :request_timeout, :chunk_size, :batch_coordinator_address, :batch_fulfillment_enabled, - :batch_fulfillment_gas_multiplier, :backoff_initial_delay, :backoff_max_delay, :gas_lane_price, - :vrf_owner_address, :custom_reverts_pipeline_enabled, - NOW(), NOW()) - RETURNING id;` - - err := pg.PrepareQueryRowx(tx, sql, &specID, toVRFSpecRow(jb.VRFSpec)) + specID, err := tx.insertVRFSpec(ctx, jb.VRFSpec) var pqErr *pgconn.PgError ok := errors.As(err, &pqErr) if err != nil && ok && pqErr.Code == "23503" { @@ -346,11 +324,11 @@ func (o *orm) CreateJob(jb *Job, qopts ...pg.QOpt) error { } } if err != nil { - return errors.Wrap(err, "failed to create VRFSpec") + return fmt.Errorf("failed to create VRFSpec for jobSpec: %w", err) } jb.VRFSpecID = &specID case Webhook: - err := o.InsertWebhookSpec(jb.WebhookSpec, pg.WithQueryer(tx)) + err := tx.InsertWebhookSpec(ctx, jb.WebhookSpec) if err != nil { return errors.Wrap(err, "failed to create WebhookSpec") } @@ -362,11 +340,7 @@ func (o *orm) CreateJob(jb *Job, qopts ...pg.QOpt) error { } sql := `INSERT INTO external_initiator_webhook_specs (external_initiator_id, webhook_spec_id, spec) VALUES (:external_initiator_id, :webhook_spec_id, :spec);` - query, args, err := tx.BindNamed(sql, jb.WebhookSpec.ExternalInitiatorWebhookSpecs) - if err != nil { - return errors.Wrap(err, "failed to bindquery for ExternalInitiatorWebhookSpecs") - } - if _, err = tx.Exec(query, args...); err != nil { + if _, err := tx.ds.NamedExecContext(ctx, sql, jb.WebhookSpec.ExternalInitiatorWebhookSpecs); err != nil { return errors.Wrap(err, "failed to create ExternalInitiatorWebhookSpecs") } } @@ -374,80 +348,58 @@ func (o *orm) CreateJob(jb *Job, qopts ...pg.QOpt) error { if jb.BlockhashStoreSpec.EVMChainID == nil { return errors.New("evm chain id must be defined") } - var specID int32 - sql := `INSERT INTO blockhash_store_specs (coordinator_v1_address, coordinator_v2_address, coordinator_v2_plus_address, trusted_blockhash_store_address, trusted_blockhash_store_batch_size, wait_blocks, lookback_blocks, heartbeat_period, blockhash_store_address, poll_period, run_timeout, evm_chain_id, from_addresses, created_at, updated_at) - VALUES (:coordinator_v1_address, :coordinator_v2_address, :coordinator_v2_plus_address, :trusted_blockhash_store_address, :trusted_blockhash_store_batch_size, :wait_blocks, :lookback_blocks, :heartbeat_period, :blockhash_store_address, :poll_period, :run_timeout, :evm_chain_id, :from_addresses, NOW(), NOW()) - RETURNING id;` - if err := pg.PrepareQueryRowx(tx, sql, &specID, toBlockhashStoreSpecRow(jb.BlockhashStoreSpec)); err != nil { - return errors.Wrap(err, "failed to create BlockhashStore spec") + specID, err := tx.insertBlockhashStoreSpec(ctx, jb.BlockhashStoreSpec) + if err != nil { + return fmt.Errorf("failed to create BlockhashStoreSpec for jobSpec: %w", err) } jb.BlockhashStoreSpecID = &specID case BlockHeaderFeeder: if jb.BlockHeaderFeederSpec.EVMChainID == nil { return errors.New("evm chain id must be defined") } - var specID int32 - sql := `INSERT INTO block_header_feeder_specs (coordinator_v1_address, coordinator_v2_address, coordinator_v2_plus_address, wait_blocks, lookback_blocks, blockhash_store_address, batch_blockhash_store_address, poll_period, run_timeout, evm_chain_id, from_addresses, get_blockhashes_batch_size, store_blockhashes_batch_size, created_at, updated_at) - VALUES (:coordinator_v1_address, :coordinator_v2_address, :coordinator_v2_plus_address, :wait_blocks, :lookback_blocks, :blockhash_store_address, :batch_blockhash_store_address, :poll_period, :run_timeout, :evm_chain_id, :from_addresses, :get_blockhashes_batch_size, :store_blockhashes_batch_size, NOW(), NOW()) - RETURNING id;` - if err := pg.PrepareQueryRowx(tx, sql, &specID, toBlockHeaderFeederSpecRow(jb.BlockHeaderFeederSpec)); err != nil { - return errors.Wrap(err, "failed to create BlockHeaderFeeder spec") + specID, err := tx.insertBlockHeaderFeederSpec(ctx, jb.BlockHeaderFeederSpec) + if err != nil { + return fmt.Errorf("failed to create BlockHeaderFeederSpec for jobSpec: %w", err) } jb.BlockHeaderFeederSpecID = &specID case LegacyGasStationServer: if jb.LegacyGasStationServerSpec.EVMChainID == nil { return errors.New("evm chain id must be defined") } - var specID int32 - sql := `INSERT INTO legacy_gas_station_server_specs (forwarder_address, evm_chain_id, ccip_chain_selector, from_addresses, created_at, updated_at) - VALUES (:forwarder_address, :evm_chain_id, :ccip_chain_selector, :from_addresses, NOW(), NOW()) - RETURNING id;` - if err := pg.PrepareQueryRowx(tx, sql, &specID, toLegacyGasStationServerSpecRow(jb.LegacyGasStationServerSpec)); err != nil { - return errors.Wrap(err, "failed to create LegacyGasStationServer spec") + specID, err := tx.insertLegacyGasStationServerSpec(ctx, jb.LegacyGasStationServerSpec) + if err != nil { + return fmt.Errorf("failed to create LegacyGasStationServerSpec for jobSpec: %w", err) } jb.LegacyGasStationServerSpecID = &specID case LegacyGasStationSidecar: if jb.LegacyGasStationSidecarSpec.EVMChainID == nil { return errors.New("evm chain id must be defined") } - var specID int32 - sql := `INSERT INTO legacy_gas_station_sidecar_specs (forwarder_address, off_ramp_address, lookback_blocks, poll_period, run_timeout, evm_chain_id, ccip_chain_selector, created_at, updated_at) - VALUES (:forwarder_address, :off_ramp_address, :lookback_blocks, :poll_period, :run_timeout, :evm_chain_id, :ccip_chain_selector, NOW(), NOW()) - RETURNING id;` - if err := pg.PrepareQueryRowx(tx, sql, &specID, jb.LegacyGasStationSidecarSpec); err != nil { - return errors.Wrap(err, "failed to create LegacyGasStationSidecar spec") + specID, err := tx.insertLegacyGasStationSidecarSpec(ctx, jb.LegacyGasStationSidecarSpec) + if err != nil { + return fmt.Errorf("failed to create LegacyGasStationSidecarSpec for jobSpec: %w", err) } jb.LegacyGasStationSidecarSpecID = &specID case Bootstrap: - var specID int32 - sql := `INSERT INTO bootstrap_specs (contract_id, feed_id, relay, relay_config, monitoring_endpoint, - blockchain_timeout, contract_config_tracker_poll_interval, - contract_config_confirmations, created_at, updated_at) - VALUES (:contract_id, :feed_id, :relay, :relay_config, :monitoring_endpoint, - :blockchain_timeout, :contract_config_tracker_poll_interval, - :contract_config_confirmations, NOW(), NOW()) - RETURNING id;` - if err := pg.PrepareQueryRowx(tx, sql, &specID, jb.BootstrapSpec); err != nil { - return errors.Wrap(err, "failed to create BootstrapSpec for jobSpec") + specID, err := tx.insertBootstrapSpec(ctx, jb.BootstrapSpec) + if err != nil { + return fmt.Errorf("failed to create BootstrapSpec for jobSpec: %w", err) } jb.BootstrapSpecID = &specID case Gateway: - var specID int32 - sql := `INSERT INTO gateway_specs (gateway_config, created_at, updated_at) - VALUES (:gateway_config, NOW(), NOW()) - RETURNING id;` - if err := pg.PrepareQueryRowx(tx, sql, &specID, jb.GatewaySpec); err != nil { - return errors.Wrap(err, "failed to create GatewaySpec for jobSpec") + specID, err := tx.insertGatewaySpec(ctx, jb.GatewaySpec) + if err != nil { + return fmt.Errorf("failed to create GatewaySpec for jobSpec: %w", err) } jb.GatewaySpecID = &specID case Stream: // 'stream' type has no associated spec, nothing to do here case Workflow: - var specID int32 sql := `INSERT INTO workflow_specs (workflow, workflow_id, workflow_owner, created_at, updated_at) VALUES (:workflow, :workflow_id, :workflow_owner, NOW(), NOW()) RETURNING id;` - if err := pg.PrepareQueryRowx(tx, sql, &specID, jb.WorkflowSpec); err != nil { + specID, err := tx.prepareQuerySpecID(ctx, sql, jb.WorkflowSpec) + if err != nil { return errors.Wrap(err, "failed to create WorkflowSpec for jobSpec") } jb.WorkflowSpecID = &specID @@ -455,14 +407,14 @@ func (o *orm) CreateJob(jb *Job, qopts ...pg.QOpt) error { o.lggr.Panicf("Unsupported jb.Type: %v", jb.Type) } - pipelineSpecID, err := o.pipelineORM.CreateSpec(ctx, tx, p, jb.MaxTaskDuration) + pipelineSpecID, err := tx.pipelineORM.CreateSpec(ctx, p, jb.MaxTaskDuration) if err != nil { return errors.Wrap(err, "failed to create pipeline spec") } jb.PipelineSpecID = pipelineSpecID - err = o.InsertJob(jb, pg.WithQueryer(tx)) + err = tx.InsertJob(ctx, jb) jobID = jb.ID return errors.Wrap(err, "failed to insert job") }) @@ -470,7 +422,122 @@ func (o *orm) CreateJob(jb *Job, qopts ...pg.QOpt) error { return errors.Wrap(err, "CreateJobFailed") } - return o.findJob(jb, "id", jobID, qopts...) + return o.findJob(ctx, jb, "id", jobID) +} + +func (o *orm) prepareQuerySpecID(ctx context.Context, sql string, arg any) (specID int32, err error) { + var stmt *sqlx.NamedStmt + stmt, err = o.ds.PrepareNamedContext(ctx, sql) + if err != nil { + return + } + defer stmt.Close() + err = stmt.QueryRowxContext(ctx, arg).Scan(&specID) + return +} + +func (o *orm) insertDirectRequestSpec(ctx context.Context, spec *DirectRequestSpec) (specID int32, err error) { + return o.prepareQuerySpecID(ctx, `INSERT INTO direct_request_specs (contract_address, min_incoming_confirmations, requesters, min_contract_payment, evm_chain_id, created_at, updated_at) + VALUES (:contract_address, :min_incoming_confirmations, :requesters, :min_contract_payment, :evm_chain_id, now(), now()) + RETURNING id;`, spec) +} + +func (o *orm) insertFluxMonitorSpec(ctx context.Context, spec *FluxMonitorSpec) (specID int32, err error) { + return o.prepareQuerySpecID(ctx, `INSERT INTO flux_monitor_specs (contract_address, threshold, absolute_threshold, poll_timer_period, poll_timer_disabled, idle_timer_period, idle_timer_disabled, + drumbeat_schedule, drumbeat_random_delay, drumbeat_enabled, min_payment, evm_chain_id, created_at, updated_at) + VALUES (:contract_address, :threshold, :absolute_threshold, :poll_timer_period, :poll_timer_disabled, :idle_timer_period, :idle_timer_disabled, + :drumbeat_schedule, :drumbeat_random_delay, :drumbeat_enabled, :min_payment, :evm_chain_id, NOW(), NOW()) + RETURNING id;`, spec) +} + +func (o *orm) insertOCROracleSpec(ctx context.Context, spec *OCROracleSpec) (specID int32, err error) { + return o.prepareQuerySpecID(ctx, `INSERT INTO ocr_oracle_specs (contract_address, p2pv2_bootstrappers, is_bootstrap_peer, encrypted_ocr_key_bundle_id, transmitter_address, + observation_timeout, blockchain_timeout, contract_config_tracker_subscribe_interval, contract_config_tracker_poll_interval, contract_config_confirmations, evm_chain_id, + created_at, updated_at, database_timeout, observation_grace_period, contract_transmitter_transmit_timeout) + VALUES (:contract_address, :p2pv2_bootstrappers, :is_bootstrap_peer, :encrypted_ocr_key_bundle_id, :transmitter_address, + :observation_timeout, :blockchain_timeout, :contract_config_tracker_subscribe_interval, :contract_config_tracker_poll_interval, :contract_config_confirmations, :evm_chain_id, + NOW(), NOW(), :database_timeout, :observation_grace_period, :contract_transmitter_transmit_timeout) + RETURNING id;`, spec) +} + +func (o *orm) insertOCR2OracleSpec(ctx context.Context, spec *OCR2OracleSpec) (specID int32, err error) { + return o.prepareQuerySpecID(ctx, `INSERT INTO ocr2_oracle_specs (contract_id, feed_id, relay, relay_config, plugin_type, plugin_config, p2pv2_bootstrappers, ocr_key_bundle_id, transmitter_id, + blockchain_timeout, contract_config_tracker_poll_interval, contract_config_confirmations, + created_at, updated_at) + VALUES (:contract_id, :feed_id, :relay, :relay_config, :plugin_type, :plugin_config, :p2pv2_bootstrappers, :ocr_key_bundle_id, :transmitter_id, + :blockchain_timeout, :contract_config_tracker_poll_interval, :contract_config_confirmations, + NOW(), NOW()) + RETURNING id;`, spec) +} + +func (o *orm) insertKeeperSpec(ctx context.Context, spec *KeeperSpec) (specID int32, err error) { + return o.prepareQuerySpecID(ctx, `INSERT INTO keeper_specs (contract_address, from_address, evm_chain_id, created_at, updated_at) + VALUES (:contract_address, :from_address, :evm_chain_id, NOW(), NOW()) + RETURNING id;`, spec) +} + +func (o *orm) insertCronSpec(ctx context.Context, spec *CronSpec) (specID int32, err error) { + return o.prepareQuerySpecID(ctx, `INSERT INTO cron_specs (cron_schedule, created_at, updated_at) + VALUES (:cron_schedule, NOW(), NOW()) + RETURNING id;`, spec) +} + +func (o *orm) insertVRFSpec(ctx context.Context, spec *VRFSpec) (specID int32, err error) { + return o.prepareQuerySpecID(ctx, `INSERT INTO vrf_specs ( + coordinator_address, public_key, min_incoming_confirmations, + evm_chain_id, from_addresses, poll_period, requested_confs_delay, + request_timeout, chunk_size, batch_coordinator_address, batch_fulfillment_enabled, + batch_fulfillment_gas_multiplier, backoff_initial_delay, backoff_max_delay, gas_lane_price, + vrf_owner_address, custom_reverts_pipeline_enabled, + created_at, updated_at) + VALUES ( + :coordinator_address, :public_key, :min_incoming_confirmations, + :evm_chain_id, :from_addresses, :poll_period, :requested_confs_delay, + :request_timeout, :chunk_size, :batch_coordinator_address, :batch_fulfillment_enabled, + :batch_fulfillment_gas_multiplier, :backoff_initial_delay, :backoff_max_delay, :gas_lane_price, + :vrf_owner_address, :custom_reverts_pipeline_enabled, + NOW(), NOW()) + RETURNING id;`, toVRFSpecRow(spec)) +} + +func (o *orm) insertBlockhashStoreSpec(ctx context.Context, spec *BlockhashStoreSpec) (specID int32, err error) { + return o.prepareQuerySpecID(ctx, `INSERT INTO blockhash_store_specs (coordinator_v1_address, coordinator_v2_address, coordinator_v2_plus_address, trusted_blockhash_store_address, trusted_blockhash_store_batch_size, wait_blocks, lookback_blocks, heartbeat_period, blockhash_store_address, poll_period, run_timeout, evm_chain_id, from_addresses, created_at, updated_at) + VALUES (:coordinator_v1_address, :coordinator_v2_address, :coordinator_v2_plus_address, :trusted_blockhash_store_address, :trusted_blockhash_store_batch_size, :wait_blocks, :lookback_blocks, :heartbeat_period, :blockhash_store_address, :poll_period, :run_timeout, :evm_chain_id, :from_addresses, NOW(), NOW()) + RETURNING id;`, toBlockhashStoreSpecRow(spec)) +} + +func (o *orm) insertBlockHeaderFeederSpec(ctx context.Context, spec *BlockHeaderFeederSpec) (specID int32, err error) { + return o.prepareQuerySpecID(ctx, `INSERT INTO block_header_feeder_specs (coordinator_v1_address, coordinator_v2_address, coordinator_v2_plus_address, wait_blocks, lookback_blocks, blockhash_store_address, batch_blockhash_store_address, poll_period, run_timeout, evm_chain_id, from_addresses, get_blockhashes_batch_size, store_blockhashes_batch_size, created_at, updated_at) + VALUES (:coordinator_v1_address, :coordinator_v2_address, :coordinator_v2_plus_address, :wait_blocks, :lookback_blocks, :blockhash_store_address, :batch_blockhash_store_address, :poll_period, :run_timeout, :evm_chain_id, :from_addresses, :get_blockhashes_batch_size, :store_blockhashes_batch_size, NOW(), NOW()) + RETURNING id;`, toBlockHeaderFeederSpecRow(spec)) +} + +func (o *orm) insertLegacyGasStationServerSpec(ctx context.Context, spec *LegacyGasStationServerSpec) (specID int32, err error) { + return o.prepareQuerySpecID(ctx, `INSERT INTO legacy_gas_station_server_specs (forwarder_address, evm_chain_id, ccip_chain_selector, from_addresses, created_at, updated_at) + VALUES (:forwarder_address, :evm_chain_id, :ccip_chain_selector, :from_addresses, NOW(), NOW()) + RETURNING id;`, toLegacyGasStationServerSpecRow(spec)) +} + +func (o *orm) insertLegacyGasStationSidecarSpec(ctx context.Context, spec *LegacyGasStationSidecarSpec) (specID int32, err error) { + return o.prepareQuerySpecID(ctx, `INSERT INTO legacy_gas_station_sidecar_specs (forwarder_address, off_ramp_address, lookback_blocks, poll_period, run_timeout, evm_chain_id, ccip_chain_selector, created_at, updated_at) + VALUES (:forwarder_address, :off_ramp_address, :lookback_blocks, :poll_period, :run_timeout, :evm_chain_id, :ccip_chain_selector, NOW(), NOW()) + RETURNING id;`, spec) +} + +func (o *orm) insertBootstrapSpec(ctx context.Context, spec *BootstrapSpec) (specID int32, err error) { + return o.prepareQuerySpecID(ctx, `INSERT INTO bootstrap_specs (contract_id, feed_id, relay, relay_config, monitoring_endpoint, + blockchain_timeout, contract_config_tracker_poll_interval, + contract_config_confirmations, created_at, updated_at) + VALUES (:contract_id, :feed_id, :relay, :relay_config, :monitoring_endpoint, + :blockchain_timeout, :contract_config_tracker_poll_interval, + :contract_config_confirmations, NOW(), NOW()) + RETURNING id;`, spec) +} + +func (o *orm) insertGatewaySpec(ctx context.Context, spec *GatewaySpec) (specID int32, err error) { + return o.prepareQuerySpecID(ctx, `INSERT INTO gateway_specs (gateway_config, created_at, updated_at) + VALUES (:gateway_config, NOW(), NOW()) + RETURNING id;`, spec) } // ValidateKeyStoreMatch confirms that the key has a valid match in the keystore @@ -531,17 +598,18 @@ func areSendingKeysDefined(ctx context.Context, jb *Job, keystore keystore.Maste return false, nil } -func (o *orm) InsertWebhookSpec(webhookSpec *WebhookSpec, qopts ...pg.QOpt) error { - q := o.q.WithOpts(qopts...) - query := `INSERT INTO webhook_specs (created_at, updated_at) +func (o *orm) InsertWebhookSpec(ctx context.Context, webhookSpec *WebhookSpec) error { + query, args, err := o.ds.BindNamed(`INSERT INTO webhook_specs (created_at, updated_at) VALUES (NOW(), NOW()) - RETURNING *;` - return q.GetNamed(query, webhookSpec, webhookSpec) + RETURNING *;`, webhookSpec) + if err != nil { + return fmt.Errorf("error binding arg: %w", err) + } + return o.ds.GetContext(ctx, webhookSpec, query, args...) } -func (o *orm) InsertJob(job *Job, qopts ...pg.QOpt) error { - q := o.q.WithOpts(qopts...) - return q.Transaction(func(querier pg.Queryer) error { +func (o *orm) InsertJob(ctx context.Context, job *Job) error { + return o.transact(ctx, false, func(tx *orm) error { var query string // if job has id, emplace otherwise insert with a new id. @@ -562,26 +630,30 @@ func (o *orm) InsertJob(job *Job, qopts ...pg.QOpt) error { :legacy_gas_station_server_spec_id, :legacy_gas_station_sidecar_spec_id, :workflow_spec_id, :external_job_id, :gas_limit, :forwarding_allowed, NOW()) RETURNING *;` } - err := q.GetNamed(query, job, job) + query, args, err := tx.ds.BindNamed(query, job) + if err != nil { + return fmt.Errorf("error binding arg: %w", err) + } + err = tx.ds.GetContext(ctx, job, query, args...) if err != nil { return err } // Always inserts the `job_pipeline_specs` record as primary, since this is the first one for the job. sqlStmt := `INSERT INTO job_pipeline_specs (job_id, pipeline_spec_id, is_primary) VALUES ($1, $2, true)` - _, err = q.Exec(sqlStmt, job.ID, job.PipelineSpecID) + _, err = tx.ds.ExecContext(ctx, sqlStmt, job.ID, job.PipelineSpecID) return errors.Wrap(err, "failed to insert job_pipeline_specs relationship") }) } // DeleteJob removes a job -func (o *orm) DeleteJob(id int32, qopts ...pg.QOpt) error { +func (o *orm) DeleteJob(ctx context.Context, id int32) error { o.lggr.Debugw("Deleting job", "jobID", id) // Added a 1-minute timeout to this query since this can take a long time as data increases. // This was added specifically due to an issue with a database that had a million of pipeline_runs and pipeline_task_runs // and this query was taking ~40secs. - qopts = append(qopts, pg.WithLongQueryTimeout()) - q := o.q.WithOpts(qopts...) + ctx, cancel := context.WithTimeout(sqlutil.WithoutDefaultTimeout(ctx), time.Minute) + defer cancel() query := ` WITH deleted_jobs AS ( DELETE FROM jobs WHERE id = $1 RETURNING @@ -643,8 +715,7 @@ func (o *orm) DeleteJob(id int32, qopts ...pg.QOpt) error { DELETE FROM job_pipeline_specs WHERE job_id IN (SELECT id FROM deleted_jobs) RETURNING pipeline_spec_id ) DELETE FROM pipeline_specs WHERE id IN (SELECT pipeline_spec_id FROM deleted_job_pipeline_specs)` - res, cancel, err := q.ExecQIter(query, id) - defer cancel() + res, err := o.ds.ExecContext(ctx, query, id) if err != nil { return errors.Wrap(err, "DeleteJob failed to delete job") } @@ -659,14 +730,13 @@ func (o *orm) DeleteJob(id int32, qopts ...pg.QOpt) error { return nil } -func (o *orm) RecordError(jobID int32, description string, qopts ...pg.QOpt) error { - q := o.q.WithOpts(qopts...) +func (o *orm) RecordError(ctx context.Context, jobID int32, description string) error { sql := `INSERT INTO job_spec_errors (job_id, description, occurrences, created_at, updated_at) VALUES ($1, $2, 1, $3, $3) ON CONFLICT (job_id, description) DO UPDATE SET occurrences = job_spec_errors.occurrences + 1, updated_at = excluded.updated_at` - err := q.ExecQ(sql, jobID, description, time.Now()) + _, err := o.ds.ExecContext(ctx, sql, jobID, description, time.Now()) // Noop if the job has been deleted. var pqErr *pgconn.PgError ok := errors.As(err, &pqErr) @@ -677,15 +747,13 @@ func (o *orm) RecordError(jobID int32, description string, qopts ...pg.QOpt) err } return err } -func (o *orm) TryRecordError(jobID int32, description string, qopts ...pg.QOpt) { - err := o.RecordError(jobID, description, qopts...) +func (o *orm) TryRecordError(ctx context.Context, jobID int32, description string) { + err := o.RecordError(ctx, jobID, description) o.lggr.ErrorIf(err, fmt.Sprintf("Error creating SpecError %v", description)) } func (o *orm) DismissError(ctx context.Context, ID int64) error { - q := o.q.WithOpts(pg.WithParentCtx(ctx)) - res, cancel, err := q.ExecQIter("DELETE FROM job_spec_errors WHERE id = $1", ID) - defer cancel() + res, err := o.ds.ExecContext(ctx, "DELETE FROM job_spec_errors WHERE id = $1", ID) if err != nil { return errors.Wrap(err, "failed to dismiss error") } @@ -699,35 +767,35 @@ func (o *orm) DismissError(ctx context.Context, ID int64) error { return nil } -func (o *orm) FindSpecError(id int64, qopts ...pg.QOpt) (SpecError, error) { +func (o *orm) FindSpecError(ctx context.Context, id int64) (SpecError, error) { stmt := `SELECT * FROM job_spec_errors WHERE id = $1;` specErr := new(SpecError) - err := o.q.WithOpts(qopts...).Get(specErr, stmt, id) + err := o.ds.GetContext(ctx, specErr, stmt, id) return *specErr, errors.Wrap(err, "FindSpecError failed") } -func (o *orm) FindJobs(offset, limit int) (jobs []Job, count int, err error) { - err = o.q.Transaction(func(tx pg.Queryer) error { +func (o *orm) FindJobs(ctx context.Context, offset, limit int) (jobs []Job, count int, err error) { + err = o.transact(ctx, false, func(tx *orm) error { sql := `SELECT count(*) FROM jobs;` - err = tx.QueryRowx(sql).Scan(&count) + err = tx.ds.QueryRowxContext(ctx, sql).Scan(&count) if err != nil { - return err + return fmt.Errorf("failed to query jobs count: %w", err) } sql = `SELECT jobs.*, job_pipeline_specs.pipeline_spec_id as pipeline_spec_id FROM jobs JOIN job_pipeline_specs ON (jobs.id = job_pipeline_specs.job_id) ORDER BY jobs.created_at DESC, jobs.id DESC OFFSET $1 LIMIT $2;` - err = tx.Select(&jobs, sql, offset, limit) + err = tx.ds.SelectContext(ctx, &jobs, sql, offset, limit) if err != nil { - return err + return fmt.Errorf("failed to select jobs: %w", err) } - err = LoadAllJobsTypes(tx, jobs) + err = tx.loadAllJobsTypes(ctx, jobs) if err != nil { - return err + return fmt.Errorf("failed to load job types: %w", err) } return nil @@ -820,32 +888,30 @@ func LoadConfigVarsOCR(evmOcrCfg evmconfig.OCR, ocrCfg OCRConfig, os OCROracleSp } func (o *orm) FindJobTx(ctx context.Context, id int32) (Job, error) { - ctx, cancel := context.WithTimeout(ctx, o.cfg.DefaultQueryTimeout()) - defer cancel() return o.FindJob(ctx, id) } // FindJob returns job by ID, with all relations preloaded func (o *orm) FindJob(ctx context.Context, id int32) (jb Job, err error) { - err = o.findJob(&jb, "id", id, pg.WithParentCtx(ctx)) + err = o.findJob(ctx, &jb, "id", id) return } // FindJobWithoutSpecErrors returns a job by ID, without loading Spec Errors preloaded -func (o *orm) FindJobWithoutSpecErrors(id int32) (jb Job, err error) { - err = o.q.Transaction(func(tx pg.Queryer) error { +func (o *orm) FindJobWithoutSpecErrors(ctx context.Context, id int32) (jb Job, err error) { + err = o.transact(ctx, true, func(tx *orm) error { stmt := "SELECT jobs.*, job_pipeline_specs.pipeline_spec_id as pipeline_spec_id FROM jobs JOIN job_pipeline_specs ON (jobs.id = job_pipeline_specs.job_id) WHERE jobs.id = $1 LIMIT 1" - err = tx.Get(&jb, stmt, id) + err = tx.ds.GetContext(ctx, &jb, stmt, id) if err != nil { return errors.Wrap(err, "failed to load job") } - if err = LoadAllJobTypes(tx, &jb); err != nil { + if err = tx.loadAllJobTypes(ctx, &jb); err != nil { return errors.Wrap(err, "failed to load job types") } return nil - }, pg.OptReadOnlyTx()) + }) if err != nil { return jb, errors.Wrap(err, "FindJobWithoutSpecErrors failed") } @@ -854,87 +920,76 @@ func (o *orm) FindJobWithoutSpecErrors(id int32) (jb Job, err error) { } // FindSpecErrorsByJobIDs returns all jobs spec errors by jobs IDs -func (o *orm) FindSpecErrorsByJobIDs(ids []int32, qopts ...pg.QOpt) ([]SpecError, error) { +func (o *orm) FindSpecErrorsByJobIDs(ctx context.Context, ids []int32) ([]SpecError, error) { stmt := `SELECT * FROM job_spec_errors WHERE job_id = ANY($1);` var specErrs []SpecError - err := o.q.WithOpts(qopts...).Select(&specErrs, stmt, ids) + err := o.ds.SelectContext(ctx, &specErrs, stmt, ids) return specErrs, errors.Wrap(err, "FindSpecErrorsByJobIDs failed") } -func (o *orm) FindJobByExternalJobID(externalJobID uuid.UUID, qopts ...pg.QOpt) (jb Job, err error) { - err = o.findJob(&jb, "external_job_id", externalJobID, qopts...) +func (o *orm) FindJobByExternalJobID(ctx context.Context, externalJobID uuid.UUID) (jb Job, err error) { + err = o.findJob(ctx, &jb, "external_job_id", externalJobID) return } // FindJobIDByAddress - finds a job id by contract address. Currently only OCR and FM jobs are supported -func (o *orm) FindJobIDByAddress(address evmtypes.EIP55Address, evmChainID *big.Big, qopts ...pg.QOpt) (jobID int32, err error) { - q := o.q.WithOpts(qopts...) - err = q.Transaction(func(tx pg.Queryer) error { - stmt := ` +func (o *orm) FindJobIDByAddress(ctx context.Context, address evmtypes.EIP55Address, evmChainID *big.Big) (jobID int32, err error) { + stmt := ` SELECT jobs.id FROM jobs LEFT JOIN ocr_oracle_specs ocrspec on ocrspec.contract_address = $1 AND (ocrspec.evm_chain_id = $2 OR ocrspec.evm_chain_id IS NULL) AND ocrspec.id = jobs.ocr_oracle_spec_id LEFT JOIN flux_monitor_specs fmspec on fmspec.contract_address = $1 AND (fmspec.evm_chain_id = $2 OR fmspec.evm_chain_id IS NULL) AND fmspec.id = jobs.flux_monitor_spec_id WHERE ocrspec.id IS NOT NULL OR fmspec.id IS NOT NULL ` - err = tx.Get(&jobID, stmt, address, evmChainID) - + err = o.ds.GetContext(ctx, &jobID, stmt, address, evmChainID) + if err != nil { if !errors.Is(err, sql.ErrNoRows) { - if err != nil { - return errors.Wrap(err, "error searching for job by contract address") - } - return nil + err = errors.Wrap(err, "error searching for job by contract address") } + err = errors.Wrap(err, "FindJobIDByAddress failed") + return + } - return err - }) - - return jobID, errors.Wrap(err, "FindJobIDByAddress failed") + return } -func (o *orm) FindOCR2JobIDByAddress(contractID string, feedID *common.Hash, qopts ...pg.QOpt) (jobID int32, err error) { - q := o.q.WithOpts(qopts...) - err = q.Transaction(func(tx pg.Queryer) error { - // NOTE: We want to explicitly match on NULL feed_id hence usage of `IS - // NOT DISTINCT FROM` instead of `=` - stmt := ` +func (o *orm) FindOCR2JobIDByAddress(ctx context.Context, contractID string, feedID *common.Hash) (jobID int32, err error) { + // NOTE: We want to explicitly match on NULL feed_id hence usage of `IS + // NOT DISTINCT FROM` instead of `=` + stmt := ` SELECT jobs.id FROM jobs LEFT JOIN ocr2_oracle_specs ocr2spec on ocr2spec.contract_id = $1 AND ocr2spec.feed_id IS NOT DISTINCT FROM $2 AND ocr2spec.id = jobs.ocr2_oracle_spec_id LEFT JOIN bootstrap_specs bs on bs.contract_id = $1 AND bs.feed_id IS NOT DISTINCT FROM $2 AND bs.id = jobs.bootstrap_spec_id WHERE ocr2spec.id IS NOT NULL OR bs.id IS NOT NULL ` - err = tx.Get(&jobID, stmt, contractID, feedID) - + err = o.ds.GetContext(ctx, &jobID, stmt, contractID, feedID) + if err != nil { if !errors.Is(err, sql.ErrNoRows) { - if err != nil { - return errors.Wrapf(err, "error searching for job by contract id=%s and feed id=%s", contractID, feedID) - } - return nil + err = errors.Wrapf(err, "error searching for job by contract id=%s and feed id=%s", contractID, feedID) } + err = errors.Wrap(err, "FindOCR2JobIDByAddress failed") + return + } - return err - }) - - return jobID, errors.Wrap(err, "FindOCR2JobIDByAddress failed") + return } -func (o *orm) findJob(jb *Job, col string, arg interface{}, qopts ...pg.QOpt) error { - q := o.q.WithOpts(qopts...) - err := q.Transaction(func(tx pg.Queryer) error { +func (o *orm) findJob(ctx context.Context, jb *Job, col string, arg interface{}) error { + err := o.transact(ctx, false, func(tx *orm) error { sql := fmt.Sprintf(`SELECT jobs.*, job_pipeline_specs.pipeline_spec_id FROM jobs JOIN job_pipeline_specs ON (jobs.id = job_pipeline_specs.job_id) WHERE jobs.%s = $1 AND job_pipeline_specs.is_primary = true LIMIT 1`, col) - err := tx.Get(jb, sql, arg) + err := tx.ds.GetContext(ctx, jb, sql, arg) if err != nil { return errors.Wrap(err, "failed to load job") } - if err = LoadAllJobTypes(tx, jb); err != nil { + if err = tx.loadAllJobTypes(ctx, jb); err != nil { return err } - return loadJobSpecErrors(tx, jb) + return tx.loadJobSpecErrors(ctx, jb) }) if err != nil { return errors.Wrap(err, "findJob failed") @@ -942,62 +997,62 @@ func (o *orm) findJob(jb *Job, col string, arg interface{}, qopts ...pg.QOpt) er return nil } -func (o *orm) FindJobIDsWithBridge(name string) (jids []int32, err error) { - err = o.q.Transaction(func(tx pg.Queryer) error { - query := `SELECT +func (o *orm) FindJobIDsWithBridge(ctx context.Context, name string) (jids []int32, err error) { + query := `SELECT jobs.id, pipeline_specs.dot_dag_source FROM jobs JOIN job_pipeline_specs ON job_pipeline_specs.job_id = jobs.id JOIN pipeline_specs ON pipeline_specs.id = job_pipeline_specs.pipeline_spec_id WHERE pipeline_specs.dot_dag_source ILIKE '%' || $1 || '%' ORDER BY id` + var rows *sqlx.Rows + rows, err = o.ds.QueryxContext(ctx, query, name) + if err != nil { + return + } + defer rows.Close() + var ids []int32 + var sources []string + for rows.Next() { + var id int32 + var source string + if err = rows.Scan(&id, &source); err != nil { + return + } + ids = append(jids, id) + sources = append(sources, source) + } + if err = rows.Err(); err != nil { + return + } - var rows *sqlx.Rows - rows, err = tx.Queryx(query, name) + for i, id := range ids { + var p *pipeline.Pipeline + p, err = pipeline.Parse(sources[i]) if err != nil { - return err + return nil, errors.Wrapf(err, "could not parse dag for job %d", id) } - defer rows.Close() - var ids []int32 - var sources []string - for rows.Next() { - var id int32 - var source string - if err = rows.Scan(&id, &source); err != nil { - return err - } - ids = append(jids, id) - sources = append(sources, source) - } - - for i, id := range ids { - var p *pipeline.Pipeline - p, err = pipeline.Parse(sources[i]) - if err != nil { - return errors.Wrapf(err, "could not parse dag for job %d", id) - } - for _, task := range p.Tasks { - if task.Type() == pipeline.TaskTypeBridge { - if task.(*pipeline.BridgeTask).Name == name { - jids = append(jids, id) - } + for _, task := range p.Tasks { + if task.Type() == pipeline.TaskTypeBridge { + if task.(*pipeline.BridgeTask).Name == name { + jids = append(jids, id) } } } - return nil - }) - return jids, errors.Wrap(err, "FindJobIDsWithBridge failed") + } + + return } // PipelineRunsByJobsIDs returns pipeline runs for multiple jobs, not preloading data -func (o *orm) PipelineRunsByJobsIDs(ids []int32) (runs []pipeline.Run, err error) { - err = o.q.Transaction(func(tx pg.Queryer) error { +func (o *orm) PipelineRunsByJobsIDs(ctx context.Context, ids []int32) (runs []pipeline.Run, err error) { + err = o.transact(ctx, false, func(tx *orm) error { stmt := `SELECT pipeline_runs.* FROM pipeline_runs INNER JOIN job_pipeline_specs ON pipeline_runs.pipeline_spec_id = job_pipeline_specs.pipeline_spec_id WHERE jobs.id = ANY($1) ORDER BY pipeline_runs.created_at DESC, pipeline_runs.id DESC;` - if err = tx.Select(&runs, stmt, ids); err != nil { + if err = tx.ds.SelectContext(ctx, &runs, stmt, ids); err != nil { return errors.Wrap(err, "error loading runs") } - runs, err = o.loadPipelineRunsRelations(runs, tx) + runs, err = tx.loadPipelineRunsRelations(ctx, runs) return err }) @@ -1005,11 +1060,11 @@ func (o *orm) PipelineRunsByJobsIDs(ids []int32) (runs []pipeline.Run, err error return runs, errors.Wrap(err, "PipelineRunsByJobsIDs failed") } -func (o *orm) loadPipelineRunIDs(jobID *int32, offset, limit int, tx pg.Queryer) (ids []int64, err error) { +func (o *orm) loadPipelineRunIDs(ctx context.Context, jobID *int32, offset, limit int) (ids []int64, err error) { lggr := logger.Sugared(o.lggr) var res sql.NullInt64 - if err = tx.Get(&res, "SELECT MAX(id) FROM pipeline_runs"); err != nil { + if err = o.ds.GetContext(ctx, &res, "SELECT MAX(id) FROM pipeline_runs"); err != nil { err = errors.Wrap(err, "error while loading runs") return } else if !res.Valid { @@ -1037,7 +1092,7 @@ func (o *orm) loadPipelineRunIDs(jobID *int32, offset, limit int, tx pg.Queryer) for n := int64(1000); maxID > 0 && len(ids) < limit; n *= 2 { var batch []int64 minID := maxID - n - if err = tx.Select(&batch, stmt, offset, limit-len(ids), minID, maxID); err != nil { + if err = o.ds.SelectContext(ctx, &batch, stmt, offset, limit-len(ids), minID, maxID); err != nil { err = errors.Wrap(err, "error loading runs") return } @@ -1050,7 +1105,7 @@ func (o *orm) loadPipelineRunIDs(jobID *int32, offset, limit int, tx pg.Queryer) var skipped int // If no rows were returned, we need to know whether there were any ids skipped // in this batch due to the offset, and reduce it for the next batch - err = tx.Get(&skipped, + err = o.ds.GetContext(ctx, &skipped, fmt.Sprintf( `SELECT COUNT(p.id) FROM pipeline_runs AS p %s p.id >= $1 AND p.id <= $2`, filter, ), minID, maxID, @@ -1074,63 +1129,60 @@ func (o *orm) loadPipelineRunIDs(jobID *int32, offset, limit int, tx pg.Queryer) return } -func (o *orm) FindTaskResultByRunIDAndTaskName(runID int64, taskName string, qopts ...pg.QOpt) (result []byte, err error) { - q := o.q.WithOpts(qopts...) - err = q.Transaction(func(tx pg.Queryer) error { - stmt := fmt.Sprintf("SELECT * FROM pipeline_task_runs WHERE pipeline_run_id = $1 AND dot_id = '%s';", taskName) +func (o *orm) FindTaskResultByRunIDAndTaskName(ctx context.Context, runID int64, taskName string) (result []byte, err error) { + stmt := fmt.Sprintf("SELECT * FROM pipeline_task_runs WHERE pipeline_run_id = $1 AND dot_id = '%s';", taskName) - var taskRuns []pipeline.TaskRun - if errB := tx.Select(&taskRuns, stmt, runID); errB != nil { - return errB - } - if len(taskRuns) == 0 { - return fmt.Errorf("can't find task run with id: %v, taskName: %v", runID, taskName) - } - if len(taskRuns) > 1 { - o.lggr.Errorf("found multiple task runs with id: %v, taskName: %v. Using the first one.", runID, taskName) - } - taskRun := taskRuns[0] - if !taskRun.Error.IsZero() { - return errors.New(taskRun.Error.ValueOrZero()) - } - resBytes, errB := taskRun.Output.MarshalJSON() - if errB != nil { - return errB - } - result = resBytes - return nil - }) - return result, errors.Wrap(err, "failed") + var taskRuns []pipeline.TaskRun + if errB := o.ds.SelectContext(ctx, &taskRuns, stmt, runID); errB != nil { + return nil, errB + } + if len(taskRuns) == 0 { + return nil, fmt.Errorf("can't find task run with id: %v, taskName: %v", runID, taskName) + } + if len(taskRuns) > 1 { + o.lggr.Errorf("found multiple task runs with id: %v, taskName: %v. Using the first one.", runID, taskName) + } + taskRun := taskRuns[0] + if !taskRun.Error.IsZero() { + return nil, errors.New(taskRun.Error.ValueOrZero()) + } + resBytes, errB := taskRun.Output.MarshalJSON() + if errB != nil { + return + } + result = resBytes + + return } // FindPipelineRunIDsByJobID fetches the ids of pipeline runs for a job. -func (o *orm) FindPipelineRunIDsByJobID(jobID int32, offset, limit int) (ids []int64, err error) { - err = o.q.Transaction(func(tx pg.Queryer) error { - ids, err = o.loadPipelineRunIDs(&jobID, offset, limit, tx) +func (o *orm) FindPipelineRunIDsByJobID(ctx context.Context, jobID int32, offset, limit int) (ids []int64, err error) { + err = o.transact(ctx, false, func(tx *orm) error { + ids, err = tx.loadPipelineRunIDs(ctx, &jobID, offset, limit) return err }) return ids, errors.Wrap(err, "FindPipelineRunIDsByJobID failed") } -func (o *orm) loadPipelineRunsByID(ids []int64, tx pg.Queryer) (runs []pipeline.Run, err error) { +func (o *orm) loadPipelineRunsByID(ctx context.Context, ids []int64) (runs []pipeline.Run, err error) { stmt := ` SELECT pipeline_runs.* FROM pipeline_runs WHERE id = ANY($1) ORDER BY created_at DESC, id DESC ` - if err = tx.Select(&runs, stmt, ids); err != nil { + if err = o.ds.SelectContext(ctx, &runs, stmt, ids); err != nil { err = errors.Wrap(err, "error loading runs") return } - return o.loadPipelineRunsRelations(runs, tx) + return o.loadPipelineRunsRelations(ctx, runs) } // FindPipelineRunsByIDs returns pipeline runs with the ids. -func (o *orm) FindPipelineRunsByIDs(ids []int64) (runs []pipeline.Run, err error) { - err = o.q.Transaction(func(tx pg.Queryer) error { - runs, err = o.loadPipelineRunsByID(ids, tx) +func (o *orm) FindPipelineRunsByIDs(ctx context.Context, ids []int64) (runs []pipeline.Run, err error) { + err = o.transact(ctx, false, func(tx *orm) error { + runs, err = tx.loadPipelineRunsByID(ctx, ids) return err }) @@ -1138,21 +1190,21 @@ func (o *orm) FindPipelineRunsByIDs(ids []int64) (runs []pipeline.Run, err error } // FindPipelineRunByID returns pipeline run with the id. -func (o *orm) FindPipelineRunByID(id int64) (pipeline.Run, error) { +func (o *orm) FindPipelineRunByID(ctx context.Context, id int64) (pipeline.Run, error) { var run pipeline.Run - err := o.q.Transaction(func(tx pg.Queryer) error { + err := o.transact(ctx, false, func(tx *orm) error { stmt := ` SELECT pipeline_runs.* FROM pipeline_runs WHERE id = $1 ` - if err := tx.Get(&run, stmt, id); err != nil { + if err := tx.ds.GetContext(ctx, &run, stmt, id); err != nil { return errors.Wrap(err, "error loading run") } - runs, err := o.loadPipelineRunsRelations([]pipeline.Run{run}, tx) + runs, err := tx.loadPipelineRunsRelations(ctx, []pipeline.Run{run}) run = runs[0] @@ -1163,30 +1215,24 @@ WHERE id = $1 } // CountPipelineRunsByJobID returns the total number of pipeline runs for a job. -func (o *orm) CountPipelineRunsByJobID(jobID int32) (count int32, err error) { - err = o.q.Transaction(func(tx pg.Queryer) error { - stmt := "SELECT COUNT(*) FROM pipeline_runs JOIN job_pipeline_specs USING (pipeline_spec_id) WHERE job_pipeline_specs.job_id = $1" - if err = tx.Get(&count, stmt, jobID); err != nil { - return errors.Wrap(err, "error counting runs") - } - - return err - }) +func (o *orm) CountPipelineRunsByJobID(ctx context.Context, jobID int32) (count int32, err error) { + stmt := "SELECT COUNT(*) FROM pipeline_runs JOIN job_pipeline_specs USING (pipeline_spec_id) WHERE job_pipeline_specs.job_id = $1" + err = o.ds.GetContext(ctx, &count, stmt, jobID) return count, errors.Wrap(err, "CountPipelineRunsByJobID failed") } -func (o *orm) FindJobsByPipelineSpecIDs(ids []int32) ([]Job, error) { +func (o *orm) FindJobsByPipelineSpecIDs(ctx context.Context, ids []int32) ([]Job, error) { var jbs []Job - err := o.q.Transaction(func(tx pg.Queryer) error { + err := o.transact(ctx, false, func(tx *orm) error { stmt := `SELECT jobs.*, job_pipeline_specs.pipeline_spec_id FROM jobs JOIN job_pipeline_specs ON (jobs.id = job_pipeline_specs.job_id) WHERE job_pipeline_specs.pipeline_spec_id = ANY($1) ORDER BY jobs.id ASC ` - if err := tx.Select(&jbs, stmt, ids); err != nil { + if err := tx.ds.SelectContext(ctx, &jbs, stmt, ids); err != nil { return errors.Wrap(err, "error fetching jobs by pipeline spec IDs") } - err := LoadAllJobsTypes(tx, jbs) + err := tx.loadAllJobsTypes(ctx, jbs) if err != nil { return err } @@ -1199,20 +1245,20 @@ func (o *orm) FindJobsByPipelineSpecIDs(ids []int32) ([]Job, error) { // PipelineRuns returns pipeline runs for a job, with spec and taskruns loaded, latest first // If jobID is nil, returns all pipeline runs -func (o *orm) PipelineRuns(jobID *int32, offset, size int) (runs []pipeline.Run, count int, err error) { +func (o *orm) PipelineRuns(ctx context.Context, jobID *int32, offset, size int) (runs []pipeline.Run, count int, err error) { var filter string if jobID != nil { filter = fmt.Sprintf("JOIN job_pipeline_specs USING(pipeline_spec_id) WHERE job_pipeline_specs.job_id = %d", *jobID) } - err = o.q.Transaction(func(tx pg.Queryer) error { + err = o.transact(ctx, false, func(tx *orm) error { sql := fmt.Sprintf(`SELECT count(*) FROM pipeline_runs %s`, filter) - if err = tx.QueryRowx(sql).Scan(&count); err != nil { + if err = tx.ds.QueryRowxContext(ctx, sql).Scan(&count); err != nil { return errors.Wrap(err, "error counting runs") } var ids []int64 - ids, err = o.loadPipelineRunIDs(jobID, offset, size, tx) - runs, err = o.loadPipelineRunsByID(ids, tx) + ids, err = tx.loadPipelineRunIDs(ctx, jobID, offset, size) + runs, err = tx.loadPipelineRunsByID(ctx, ids) return err }) @@ -1220,7 +1266,7 @@ func (o *orm) PipelineRuns(jobID *int32, offset, size int) (runs []pipeline.Run, return runs, count, errors.Wrap(err, "PipelineRuns failed") } -func (o *orm) loadPipelineRunsRelations(runs []pipeline.Run, tx pg.Queryer) ([]pipeline.Run, error) { +func (o *orm) loadPipelineRunsRelations(ctx context.Context, runs []pipeline.Run) ([]pipeline.Run, error) { // Postload PipelineSpecs // TODO: We should pull this out into a generic preload function once go has generics specM := make(map[int32]pipeline.Spec) @@ -1235,7 +1281,7 @@ func (o *orm) loadPipelineRunsRelations(runs []pipeline.Run, tx pg.Queryer) ([]p } stmt := `SELECT pipeline_specs.*, job_pipeline_specs.job_id AS job_id FROM pipeline_specs JOIN job_pipeline_specs ON pipeline_specs.id = job_pipeline_specs.pipeline_spec_id WHERE pipeline_specs.id = ANY($1);` var specs []pipeline.Spec - if err := o.q.Select(&specs, stmt, specIDs); err != nil { + if err := o.ds.SelectContext(ctx, &specs, stmt, specIDs); err != nil { return nil, errors.Wrap(err, "error loading specs") } for _, spec := range specs { @@ -1254,7 +1300,7 @@ func (o *orm) loadPipelineRunsRelations(runs []pipeline.Run, tx pg.Queryer) ([]p } var taskRuns []pipeline.TaskRun stmt = `SELECT * FROM pipeline_task_runs WHERE pipeline_run_id = ANY($1) ORDER BY pipeline_run_id, created_at, id;` - if err := tx.Select(&taskRuns, stmt, runIDs); err != nil { + if err := o.ds.SelectContext(ctx, &taskRuns, stmt, runIDs); err != nil { return nil, errors.Wrap(err, "error loading pipeline_task_runs") } for _, taskRun := range taskRuns { @@ -1268,9 +1314,9 @@ func (o *orm) loadPipelineRunsRelations(runs []pipeline.Run, tx pg.Queryer) ([]p // NOTE: N+1 query, be careful of performance // This is not easily fixable without complicating the logic a lot, since we // only use it in the GUI it's probably acceptable -func LoadAllJobsTypes(tx pg.Queryer, jobs []Job) error { +func (o *orm) loadAllJobsTypes(ctx context.Context, jobs []Job) error { for i := range jobs { - err := LoadAllJobTypes(tx, &jobs[i]) + err := o.loadAllJobTypes(ctx, &jobs[i]) if err != nil { return err } @@ -1278,28 +1324,28 @@ func LoadAllJobsTypes(tx pg.Queryer, jobs []Job) error { return nil } -func LoadAllJobTypes(tx pg.Queryer, job *Job) error { +func (o *orm) loadAllJobTypes(ctx context.Context, job *Job) error { return multierr.Combine( - loadJobPipelineSpec(tx, job, &job.PipelineSpecID), - loadJobType(tx, job, "FluxMonitorSpec", "flux_monitor_specs", job.FluxMonitorSpecID), - loadJobType(tx, job, "DirectRequestSpec", "direct_request_specs", job.DirectRequestSpecID), - loadJobType(tx, job, "OCROracleSpec", "ocr_oracle_specs", job.OCROracleSpecID), - loadJobType(tx, job, "OCR2OracleSpec", "ocr2_oracle_specs", job.OCR2OracleSpecID), - loadJobType(tx, job, "KeeperSpec", "keeper_specs", job.KeeperSpecID), - loadJobType(tx, job, "CronSpec", "cron_specs", job.CronSpecID), - loadJobType(tx, job, "WebhookSpec", "webhook_specs", job.WebhookSpecID), - loadVRFJob(tx, job, job.VRFSpecID), - loadBlockhashStoreJob(tx, job, job.BlockhashStoreSpecID), - loadBlockHeaderFeederJob(tx, job, job.BlockHeaderFeederSpecID), - loadLegacyGasStationServerJob(tx, job, job.LegacyGasStationServerSpecID), - loadJobType(tx, job, "LegacyGasStationSidecarSpec", "legacy_gas_station_sidecar_specs", job.LegacyGasStationSidecarSpecID), - loadJobType(tx, job, "BootstrapSpec", "bootstrap_specs", job.BootstrapSpecID), - loadJobType(tx, job, "GatewaySpec", "gateway_specs", job.GatewaySpecID), - loadJobType(tx, job, "WorkflowSpec", "workflow_specs", job.WorkflowSpecID), + o.loadJobPipelineSpec(ctx, job, &job.PipelineSpecID), + o.loadJobType(ctx, job, "FluxMonitorSpec", "flux_monitor_specs", job.FluxMonitorSpecID), + o.loadJobType(ctx, job, "DirectRequestSpec", "direct_request_specs", job.DirectRequestSpecID), + o.loadJobType(ctx, job, "OCROracleSpec", "ocr_oracle_specs", job.OCROracleSpecID), + o.loadJobType(ctx, job, "OCR2OracleSpec", "ocr2_oracle_specs", job.OCR2OracleSpecID), + o.loadJobType(ctx, job, "KeeperSpec", "keeper_specs", job.KeeperSpecID), + o.loadJobType(ctx, job, "CronSpec", "cron_specs", job.CronSpecID), + o.loadJobType(ctx, job, "WebhookSpec", "webhook_specs", job.WebhookSpecID), + o.loadVRFJob(ctx, job, job.VRFSpecID), + o.loadBlockhashStoreJob(ctx, job, job.BlockhashStoreSpecID), + o.loadBlockHeaderFeederJob(ctx, job, job.BlockHeaderFeederSpecID), + o.loadLegacyGasStationServerJob(ctx, job, job.LegacyGasStationServerSpecID), + o.loadJobType(ctx, job, "LegacyGasStationSidecarSpec", "legacy_gas_station_sidecar_specs", job.LegacyGasStationSidecarSpecID), + o.loadJobType(ctx, job, "BootstrapSpec", "bootstrap_specs", job.BootstrapSpecID), + o.loadJobType(ctx, job, "GatewaySpec", "gateway_specs", job.GatewaySpecID), + o.loadJobType(ctx, job, "WorkflowSpec", "workflow_specs", job.WorkflowSpecID), ) } -func loadJobType(tx pg.Queryer, job *Job, field, table string, id *int32) error { +func (o *orm) loadJobType(ctx context.Context, job *Job, field, table string, id *int32) error { if id == nil { return nil } @@ -1312,7 +1358,7 @@ func loadJobType(tx pg.Queryer, job *Job, field, table string, id *int32) error destVal := reflect.New(t) dest := destVal.Interface() - err := tx.Get(dest, fmt.Sprintf(`SELECT * FROM %s WHERE id = $1`, table), *id) + err := o.ds.GetContext(ctx, dest, fmt.Sprintf(`SELECT * FROM %s WHERE id = $1`, table), *id) if err != nil { return errors.Wrapf(err, "failed to load job type %s with id %d", table, *id) @@ -1321,7 +1367,7 @@ func loadJobType(tx pg.Queryer, job *Job, field, table string, id *int32) error return nil } -func loadJobPipelineSpec(tx pg.Queryer, job *Job, id *int32) error { +func (o *orm) loadJobPipelineSpec(ctx context.Context, job *Job, id *int32) error { if id == nil { return nil } @@ -1329,7 +1375,8 @@ func loadJobPipelineSpec(tx pg.Queryer, job *Job, id *int32) error { if job.PipelineSpec != nil { pipelineSpecRow = job.PipelineSpec } - err := tx.Get( + err := o.ds.GetContext( + ctx, pipelineSpecRow, `SELECT pipeline_specs.*, job_pipeline_specs.job_id as job_id FROM pipeline_specs @@ -1344,13 +1391,13 @@ func loadJobPipelineSpec(tx pg.Queryer, job *Job, id *int32) error { return nil } -func loadVRFJob(tx pg.Queryer, job *Job, id *int32) error { +func (o *orm) loadVRFJob(ctx context.Context, job *Job, id *int32) error { if id == nil { return nil } var row vrfSpecRow - err := tx.Get(&row, `SELECT * FROM vrf_specs WHERE id = $1`, *id) + err := o.ds.GetContext(ctx, &row, `SELECT * FROM vrf_specs WHERE id = $1`, *id) if err != nil { return errors.Wrapf(err, `failed to load job type VRFSpec with id %d`, *id) } @@ -1383,13 +1430,13 @@ func (r vrfSpecRow) toVRFSpec() *VRFSpec { return r.VRFSpec } -func loadBlockhashStoreJob(tx pg.Queryer, job *Job, id *int32) error { +func (o *orm) loadBlockhashStoreJob(ctx context.Context, job *Job, id *int32) error { if id == nil { return nil } var row blockhashStoreSpecRow - err := tx.Get(&row, `SELECT * FROM blockhash_store_specs WHERE id = $1`, *id) + err := o.ds.GetContext(ctx, &row, `SELECT * FROM blockhash_store_specs WHERE id = $1`, *id) if err != nil { return errors.Wrapf(err, `failed to load job type BlockhashStoreSpec with id %d`, *id) } @@ -1422,13 +1469,13 @@ func (r blockhashStoreSpecRow) toBlockhashStoreSpec() *BlockhashStoreSpec { return r.BlockhashStoreSpec } -func loadBlockHeaderFeederJob(tx pg.Queryer, job *Job, id *int32) error { +func (o *orm) loadBlockHeaderFeederJob(ctx context.Context, job *Job, id *int32) error { if id == nil { return nil } var row blockHeaderFeederSpecRow - err := tx.Get(&row, `SELECT * FROM block_header_feeder_specs WHERE id = $1`, *id) + err := o.ds.GetContext(ctx, &row, `SELECT * FROM block_header_feeder_specs WHERE id = $1`, *id) if err != nil { return errors.Wrapf(err, `failed to load job type BlockHeaderFeederSpec with id %d`, *id) } @@ -1461,13 +1508,13 @@ func (r blockHeaderFeederSpecRow) toBlockHeaderFeederSpec() *BlockHeaderFeederSp return r.BlockHeaderFeederSpec } -func loadLegacyGasStationServerJob(tx pg.Queryer, job *Job, id *int32) error { +func (o *orm) loadLegacyGasStationServerJob(ctx context.Context, job *Job, id *int32) error { if id == nil { return nil } var row legacyGasStationServerSpecRow - err := tx.Get(&row, `SELECT * FROM legacy_gas_station_server_specs WHERE id = $1`, *id) + err := o.ds.GetContext(ctx, &row, `SELECT * FROM legacy_gas_station_server_specs WHERE id = $1`, *id) if err != nil { return errors.Wrapf(err, `failed to load job type LegacyGasStationServerSpec with id %d`, *id) } @@ -1500,6 +1547,6 @@ func (r legacyGasStationServerSpecRow) toLegacyGasStationServerSpec() *LegacyGas return r.LegacyGasStationServerSpec } -func loadJobSpecErrors(tx pg.Queryer, jb *Job) error { - return errors.Wrapf(tx.Select(&jb.JobSpecErrors, `SELECT * FROM job_spec_errors WHERE job_id = $1`, jb.ID), "failed to load job spec errors for job %d", jb.ID) +func (o *orm) loadJobSpecErrors(ctx context.Context, jb *Job) error { + return errors.Wrapf(o.ds.SelectContext(ctx, &jb.JobSpecErrors, `SELECT * FROM job_spec_errors WHERE job_id = $1`, jb.ID), "failed to load job spec errors for job %d", jb.ID) } diff --git a/core/services/job/orm_test.go b/core/services/job/orm_test.go index fb0e846b9d2..11f3e94f2d4 100644 --- a/core/services/job/orm_test.go +++ b/core/services/job/orm_test.go @@ -6,8 +6,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/jmoiron/sqlx" - + "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" "github.com/smartcontractkit/chainlink/v2/core/bridges" "github.com/smartcontractkit/chainlink/v2/core/internal/testutils/configtest" "github.com/smartcontractkit/chainlink/v2/core/internal/testutils/evmtest" @@ -16,13 +15,12 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/services/chainlink" "github.com/smartcontractkit/chainlink/v2/core/services/job" "github.com/smartcontractkit/chainlink/v2/core/services/keystore" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" "github.com/smartcontractkit/chainlink/v2/core/services/pipeline" "github.com/smartcontractkit/chainlink/v2/core/store/models" ) -func NewTestORM(t *testing.T, db *sqlx.DB, pipelineORM pipeline.ORM, bridgeORM bridges.ORM, keyStore keystore.Master, cfg pg.QConfig) job.ORM { - o := job.NewORM(db, pipelineORM, bridgeORM, keyStore, logger.TestLogger(t), cfg) +func NewTestORM(t *testing.T, ds sqlutil.DataSource, pipelineORM pipeline.ORM, bridgeORM bridges.ORM, keyStore keystore.Master) job.ORM { + o := job.NewORM(ds, pipelineORM, bridgeORM, keyStore, logger.TestLogger(t)) t.Cleanup(func() { assert.NoError(t, o.Close()) }) return o } diff --git a/core/services/job/runner_integration_test.go b/core/services/job/runner_integration_test.go index 26a78a8624e..cdfe39dd17f 100644 --- a/core/services/job/runner_integration_test.go +++ b/core/services/job/runner_integration_test.go @@ -89,7 +89,7 @@ func TestRunner(t *testing.T) { c := clhttptest.NewTestLocalOnlyHTTPClient() runner := pipeline.NewRunner(pipelineORM, btORM, config.JobPipeline(), config.WebServer(), legacyChains, nil, nil, logger.TestLogger(t), c, c) - jobORM := NewTestORM(t, db, pipelineORM, btORM, keyStore, config.Database()) + jobORM := NewTestORM(t, db, pipelineORM, btORM, keyStore) t.Cleanup(func() { assert.NoError(t, jobORM.Close()) }) _, placeHolderAddress := cltest.MustInsertRandomKey(t, keyStore.Eth()) @@ -121,8 +121,9 @@ func TestRunner(t *testing.T) { // Need a job in order to create a run jb := MakeVoterTurnoutOCRJobSpecWithHTTPURL(t, transmitterAddress, httpURL, bridgeVT.Name.String(), bridgeER.Name.String()) - require.NoError(t, jobORM.CreateJob(jb)) + require.NoError(t, jobORM.CreateJob(testutils.Context(t), jb)) require.NotNil(t, jb.PipelineSpec) + require.NotZero(t, jb.PipelineSpec.JobID) m, err := bridges.MarshalBridgeMetaData(big.NewInt(10), big.NewInt(100)) require.NoError(t, err) @@ -169,6 +170,7 @@ func TestRunner(t *testing.T) { }) t.Run("must delete job before deleting bridge", func(t *testing.T) { + ctx := testutils.Context(t) _, bridge := cltest.MustCreateBridge(t, db, cltest.BridgeOpts{}) jb := makeOCRJobSpecFromToml(t, fmt.Sprintf(` type = "offchainreporting" @@ -178,20 +180,21 @@ func TestRunner(t *testing.T) { ds1 [type=bridge name="%s"]; """ `, bridge.Name.String())) - require.NoError(t, jobORM.CreateJob(jb)) + require.NoError(t, jobORM.CreateJob(ctx, jb)) // Should not be able to delete a bridge in use. - jids, err := jobORM.FindJobIDsWithBridge(bridge.Name.String()) + jids, err := jobORM.FindJobIDsWithBridge(ctx, bridge.Name.String()) require.NoError(t, err) require.Equal(t, 1, len(jids)) // But if we delete the job, then we can. - require.NoError(t, jobORM.DeleteJob(jb.ID)) - jids, err = jobORM.FindJobIDsWithBridge(bridge.Name.String()) + require.NoError(t, jobORM.DeleteJob(ctx, jb.ID)) + jids, err = jobORM.FindJobIDsWithBridge(ctx, bridge.Name.String()) require.NoError(t, err) require.Equal(t, 0, len(jids)) }) t.Run("referencing a non-existent bridge should error", func(t *testing.T) { + ctx := testutils.Context(t) // Create a random bridge name _, b := cltest.MustCreateBridge(t, db, cltest.BridgeOpts{}) @@ -223,7 +226,7 @@ func TestRunner(t *testing.T) { `, placeHolderAddress.String())) require.NoError(t, err) // Should error creating it - err = jobORM.CreateJob(&jb) + err = jobORM.CreateJob(ctx, &jb) require.Error(t, err) assert.Contains(t, err.Error(), "not all bridges exist") @@ -258,7 +261,7 @@ answer1 [type=median index=0]; `, placeHolderAddress.String(), b.Name.String()), nil) require.NoError(t, err) // Should error creating it because of the juels per fee coin non-existent bridge - err = jobORM.CreateJob(&jb2) + err = jobORM.CreateJob(ctx, &jb2) require.Error(t, err) assert.Contains(t, err.Error(), "not all bridges exist") @@ -297,11 +300,12 @@ answer1 [type=median index=0]; `, placeHolderAddress, b.Name.String(), b.Name.String(), b.Name.String()), nil) require.NoError(t, err) // Should not error with duplicate bridges - err = jobORM.CreateJob(&jb3) + err = jobORM.CreateJob(ctx, &jb3) require.NoError(t, err) }) t.Run("handles the case where the parsed value is literally null", func(t *testing.T) { + ctx := testutils.Context(t) var httpURL string resp := `{"USD": null}` { @@ -311,7 +315,7 @@ answer1 [type=median index=0]; // Need a job in order to create a run jb := makeSimpleFetchOCRJobSpecWithHTTPURL(t, transmitterAddress, httpURL, false) - err := jobORM.CreateJob(jb) + err := jobORM.CreateJob(ctx, jb) require.NoError(t, err) runID, taskResults, err := runner.ExecuteAndInsertFinishedRun(testutils.Context(t), *jb.PipelineSpec, pipeline.NewVarsFrom(nil), logger.TestLogger(t), true) @@ -357,7 +361,7 @@ answer1 [type=median index=0]; // Need a job in order to create a run jb := makeSimpleFetchOCRJobSpecWithHTTPURL(t, transmitterAddress, httpURL, false) - err := jobORM.CreateJob(jb) + err := jobORM.CreateJob(testutils.Context(t), jb) require.NoError(t, err) runID, taskResults, err := runner.ExecuteAndInsertFinishedRun(testutils.Context(t), *jb.PipelineSpec, pipeline.NewVarsFrom(nil), logger.TestLogger(t), true) @@ -402,7 +406,7 @@ answer1 [type=median index=0]; // Need a job in order to create a run jb := makeSimpleFetchOCRJobSpecWithHTTPURL(t, transmitterAddress, httpURL, true) - err := jobORM.CreateJob(jb) + err := jobORM.CreateJob(testutils.Context(t), jb) require.NoError(t, err) runID, taskResults, err := runner.ExecuteAndInsertFinishedRun(testutils.Context(t), *jb.PipelineSpec, pipeline.NewVarsFrom(nil), logger.TestLogger(t), true) @@ -451,13 +455,13 @@ answer1 [type=median index=0]; err = toml.Unmarshal([]byte(s), &jb) require.NoError(t, err) jb.MaxTaskDuration = models.Interval(cltest.MustParseDuration(t, "1s")) - err = jobORM.CreateJob(&jb) + err = jobORM.CreateJob(testutils.Context(t), &jb) require.NoError(t, err) lggr := logger.TestLogger(t) _, err = keyStore.P2P().Create(ctx) assert.NoError(t, err) - pw := ocrcommon.NewSingletonPeerWrapper(keyStore, config.P2P(), config.OCR(), config.Database(), db, lggr) + pw := ocrcommon.NewSingletonPeerWrapper(keyStore, config.P2P(), config.OCR(), db, lggr) servicetest.Run(t, pw) sd := ocr.NewDelegate( db, @@ -487,12 +491,12 @@ answer1 [type=median index=0]; require.NoError(t, err) jb.MaxTaskDuration = models.Interval(cltest.MustParseDuration(t, "1s")) - err = jobORM.CreateJob(&jb) + err = jobORM.CreateJob(testutils.Context(t), &jb) require.NoError(t, err) assert.Equal(t, jb.MaxTaskDuration, models.Interval(cltest.MustParseDuration(t, "1s"))) lggr := logger.TestLogger(t) - pw := ocrcommon.NewSingletonPeerWrapper(keyStore, config.P2P(), config.OCR(), config.Database(), db, lggr) + pw := ocrcommon.NewSingletonPeerWrapper(keyStore, config.P2P(), config.OCR(), db, lggr) servicetest.Run(t, pw) sd := ocr.NewDelegate( db, @@ -516,11 +520,11 @@ answer1 [type=median index=0]; require.NoError(t, err) err = toml.Unmarshal([]byte(s), &jb) require.NoError(t, err) - err = jobORM.CreateJob(&jb) + err = jobORM.CreateJob(testutils.Context(t), &jb) require.NoError(t, err) lggr := logger.TestLogger(t) - pw := ocrcommon.NewSingletonPeerWrapper(keyStore, config.P2P(), config.OCR(), config.Database(), db, lggr) + pw := ocrcommon.NewSingletonPeerWrapper(keyStore, config.P2P(), config.OCR(), db, lggr) servicetest.Run(t, pw) sd := ocr.NewDelegate( db, @@ -571,12 +575,12 @@ answer1 [type=median index=0]; require.NoError(t, err) jb.MaxTaskDuration = models.Interval(cltest.MustParseDuration(t, "1s")) - err = jobORM.CreateJob(&jb) + err = jobORM.CreateJob(testutils.Context(t), &jb) require.NoError(t, err) assert.Equal(t, jb.MaxTaskDuration, models.Interval(cltest.MustParseDuration(t, "1s"))) lggr := logger.TestLogger(t) - pw := ocrcommon.NewSingletonPeerWrapper(keyStore, config.P2P(), config.OCR(), config.Database(), db, lggr) + pw := ocrcommon.NewSingletonPeerWrapper(keyStore, config.P2P(), config.OCR(), db, lggr) servicetest.Run(t, pw) sd := ocr.NewDelegate( db, @@ -617,11 +621,11 @@ answer1 [type=median index=0]; jb := makeOCRJobSpecFromToml(t, spec) // Create an OCR job - err = jobORM.CreateJob(jb) + err = jobORM.CreateJob(ctx, jb) require.NoError(t, err) lggr := logger.TestLogger(t) - pw := ocrcommon.NewSingletonPeerWrapper(keyStore, config.P2P(), config.OCR(), config.Database(), db, lggr) + pw := ocrcommon.NewSingletonPeerWrapper(keyStore, config.P2P(), config.OCR(), db, lggr) servicetest.Run(t, pw) sd := ocr.NewDelegate( db, @@ -640,6 +644,7 @@ answer1 [type=median index=0]; // Return an error getting the contract code. ethClient.On("CodeAt", mock.Anything, mock.Anything, mock.Anything).Return(nil, errors.New("no such code")) + for _, s := range services { err = s.Start(ctx) require.NoError(t, err) @@ -659,7 +664,7 @@ answer1 [type=median index=0]; } // Ensure we can delete an errored - err = jobORM.DeleteJob(jb.ID) + err = jobORM.DeleteJob(ctx, jb.ID) require.NoError(t, err) se = []job.SpecError{} err = db.Select(&se, `SELECT * FROM job_spec_errors`) @@ -675,6 +680,7 @@ answer1 [type=median index=0]; }) t.Run("timeouts", func(t *testing.T) { + ctx := testutils.Context(t) // There are 4 timeouts: // - ObservationTimeout = how long the whole OCR time needs to run, or it fails (default 10 seconds) // - config.JobPipelineMaxTaskDuration() = node level maximum time for a pipeline task (default 10 minutes) @@ -689,7 +695,7 @@ answer1 [type=median index=0]; defer serv.Close() jb := makeMinimalHTTPOracleSpec(t, db, config, cltest.NewEIP55Address().String(), transmitterAddress.Hex(), cltest.DefaultOCRKeyBundleID, serv.URL, `timeout="1ns"`) - err := jobORM.CreateJob(jb) + err := jobORM.CreateJob(ctx, jb) require.NoError(t, err) _, taskResults, err := runner.ExecuteAndInsertFinishedRun(testutils.Context(t), *jb.PipelineSpec, pipeline.NewVarsFrom(nil), logger.TestLogger(t), true) @@ -700,7 +706,7 @@ answer1 [type=median index=0]; // No task timeout should succeed. jb = makeMinimalHTTPOracleSpec(t, db, config, cltest.NewEIP55Address().String(), transmitterAddress.Hex(), cltest.DefaultOCRKeyBundleID, serv.URL, "") jb.Name = null.NewString("a job 2", true) - err = jobORM.CreateJob(jb) + err = jobORM.CreateJob(ctx, jb) require.NoError(t, err) _, taskResults, err = runner.ExecuteAndInsertFinishedRun(testutils.Context(t), *jb.PipelineSpec, pipeline.NewVarsFrom(nil), logger.TestLogger(t), true) require.NoError(t, err) @@ -712,7 +718,7 @@ answer1 [type=median index=0]; jb = makeMinimalHTTPOracleSpec(t, db, config, cltest.NewEIP55Address().String(), transmitterAddress.Hex(), cltest.DefaultOCRKeyBundleID, serv.URL, "") jb.MaxTaskDuration = models.Interval(time.Duration(1)) jb.Name = null.NewString("a job 3", true) - err = jobORM.CreateJob(jb) + err = jobORM.CreateJob(ctx, jb) require.NoError(t, err) _, taskResults, err = runner.ExecuteAndInsertFinishedRun(testutils.Context(t), *jb.PipelineSpec, pipeline.NewVarsFrom(nil), logger.TestLogger(t), true) @@ -722,6 +728,7 @@ answer1 [type=median index=0]; }) t.Run("deleting jobs", func(t *testing.T) { + ctx := testutils.Context(t) var httpURL string { resp := `{"USD": 42.42}` @@ -731,7 +738,7 @@ answer1 [type=median index=0]; // Need a job in order to create a run jb := makeSimpleFetchOCRJobSpecWithHTTPURL(t, transmitterAddress, httpURL, false) - err := jobORM.CreateJob(jb) + err := jobORM.CreateJob(ctx, jb) require.NoError(t, err) _, taskResults, err := runner.ExecuteAndInsertFinishedRun(testutils.Context(t), *jb.PipelineSpec, pipeline.NewVarsFrom(nil), logger.TestLogger(t), true) @@ -742,7 +749,7 @@ answer1 [type=median index=0]; assert.Equal(t, "4242", results.Values[0].(decimal.Decimal).String()) // Delete the job - err = jobORM.DeleteJob(jb.ID) + err = jobORM.DeleteJob(ctx, jb.ID) require.NoError(t, err) // Create another run, it should fail @@ -849,7 +856,7 @@ func TestRunner_Success_Callback_AsyncJob(t *testing.T) { require.NoError(t, err) bridgeCalled <- struct{}{} })) - _, bridge := cltest.MustCreateBridge(t, app.GetSqlxDB(), cltest.BridgeOpts{URL: bridgeServer.URL}) + _, bridge := cltest.MustCreateBridge(t, app.GetDB(), cltest.BridgeOpts{URL: bridgeServer.URL}) bridgeName = bridge.Name.String() defer bridgeServer.Close() } @@ -879,7 +886,7 @@ func TestRunner_Success_Callback_AsyncJob(t *testing.T) { """ `, jobUUID, eiName, cltest.MustJSONMarshal(t, eiSpec), bridgeName) - _, err := webhook.ValidatedWebhookSpec(tomlSpec, app.GetExternalInitiatorManager()) + _, err := webhook.ValidatedWebhookSpec(testutils.Context(t), tomlSpec, app.GetExternalInitiatorManager()) require.NoError(t, err) job := cltest.CreateJobViaWeb(t, app, []byte(cltest.MustJSONMarshal(t, web.CreateJobRequest{TOML: tomlSpec}))) jobID = job.ID @@ -891,9 +898,9 @@ func TestRunner_Success_Callback_AsyncJob(t *testing.T) { _ = cltest.CreateJobRunViaExternalInitiatorV2(t, app, jobUUID, *eia, cltest.MustJSONMarshal(t, eiRequest)) - pipelineORM := pipeline.NewORM(app.GetSqlxDB(), logger.TestLogger(t), cfg.JobPipeline().MaxSuccessfulRuns()) - bridgesORM := bridges.NewORM(app.GetSqlxDB()) - jobORM := NewTestORM(t, app.GetSqlxDB(), pipelineORM, bridgesORM, app.KeyStore, cfg.Database()) + pipelineORM := pipeline.NewORM(app.GetDB(), logger.TestLogger(t), cfg.JobPipeline().MaxSuccessfulRuns()) + bridgesORM := bridges.NewORM(app.GetDB()) + jobORM := NewTestORM(t, app.GetDB(), pipelineORM, bridgesORM, app.KeyStore) // Trigger v2/resume select { @@ -1028,7 +1035,7 @@ func TestRunner_Error_Callback_AsyncJob(t *testing.T) { require.NoError(t, err) bridgeCalled <- struct{}{} })) - _, bridge := cltest.MustCreateBridge(t, app.GetSqlxDB(), cltest.BridgeOpts{URL: bridgeServer.URL}) + _, bridge := cltest.MustCreateBridge(t, app.GetDB(), cltest.BridgeOpts{URL: bridgeServer.URL}) bridgeName = bridge.Name.String() defer bridgeServer.Close() } @@ -1058,19 +1065,20 @@ func TestRunner_Error_Callback_AsyncJob(t *testing.T) { """ `, jobUUID, eiName, cltest.MustJSONMarshal(t, eiSpec), bridgeName) - _, err := webhook.ValidatedWebhookSpec(tomlSpec, app.GetExternalInitiatorManager()) + _, err := webhook.ValidatedWebhookSpec(testutils.Context(t), tomlSpec, app.GetExternalInitiatorManager()) require.NoError(t, err) job := cltest.CreateJobViaWeb(t, app, []byte(cltest.MustJSONMarshal(t, web.CreateJobRequest{TOML: tomlSpec}))) jobID = job.ID require.Eventually(t, func() bool { return eiNotifiedOfCreate }, 5*time.Second, 10*time.Millisecond, "expected external initiator to be notified of new job") } + t.Run("simulate request from EI -> Core node with erroring callback", func(t *testing.T) { _ = cltest.CreateJobRunViaExternalInitiatorV2(t, app, jobUUID, *eia, cltest.MustJSONMarshal(t, eiRequest)) - pipelineORM := pipeline.NewORM(app.GetSqlxDB(), logger.TestLogger(t), cfg.JobPipeline().MaxSuccessfulRuns()) - bridgesORM := bridges.NewORM(app.GetSqlxDB()) - jobORM := NewTestORM(t, app.GetSqlxDB(), pipelineORM, bridgesORM, app.KeyStore, cfg.Database()) + pipelineORM := pipeline.NewORM(app.GetDB(), logger.TestLogger(t), cfg.JobPipeline().MaxSuccessfulRuns()) + bridgesORM := bridges.NewORM(app.GetDB()) + jobORM := NewTestORM(t, app.GetDB(), pipelineORM, bridgesORM, app.KeyStore) // Trigger v2/resume select { diff --git a/core/services/job/spawner.go b/core/services/job/spawner.go index 8024424226c..6bb2cdbf76b 100644 --- a/core/services/job/spawner.go +++ b/core/services/job/spawner.go @@ -9,13 +9,11 @@ import ( pkgerrors "github.com/pkg/errors" - "github.com/jmoiron/sqlx" - "github.com/smartcontractkit/chainlink-common/pkg/services" + "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" "github.com/smartcontractkit/chainlink-common/pkg/utils" "github.com/smartcontractkit/chainlink/v2/core/logger" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" ) //go:generate mockery --quiet --name Spawner --output ./mocks/ --case=underscore @@ -29,16 +27,16 @@ type ( // CreateJob creates a new job and starts services. // All services must start without errors for the job to be active. - CreateJob(jb *Job, qopts ...pg.QOpt) (err error) + CreateJob(ctx context.Context, ds sqlutil.DataSource, jb *Job) (err error) // DeleteJob deletes a job and stops any active services. - DeleteJob(jobID int32, qopts ...pg.QOpt) error + DeleteJob(ctx context.Context, ds sqlutil.DataSource, jobID int32) error // ActiveJobs returns a map of jobs with active services (started without error). ActiveJobs() map[int32]Job // StartService starts services for the given job spec. // NOTE: Prefer to use CreateJob, this is only publicly exposed for use in tests // to start a job that was previously manually inserted into DB - StartService(ctx context.Context, spec Job, qopts ...pg.QOpt) error + StartService(ctx context.Context, spec Job) error } Checker interface { @@ -54,7 +52,6 @@ type ( jobTypeDelegates map[Type]Delegate activeJobs map[int32]activeJob activeJobsMu sync.RWMutex - q pg.Q lggr logger.Logger chStop services.StopChan @@ -90,14 +87,13 @@ type ( var _ Spawner = (*spawner)(nil) -func NewSpawner(orm ORM, config Config, checker Checker, jobTypeDelegates map[Type]Delegate, db *sqlx.DB, lggr logger.Logger, lbDependentAwaiters []utils.DependentAwaiter) *spawner { +func NewSpawner(orm ORM, config Config, checker Checker, jobTypeDelegates map[Type]Delegate, lggr logger.Logger, lbDependentAwaiters []utils.DependentAwaiter) *spawner { namedLogger := lggr.Named("JobSpawner") s := &spawner{ orm: orm, config: config, checker: checker, jobTypeDelegates: jobTypeDelegates, - q: pg.NewQ(db, namedLogger, config), lggr: namedLogger, activeJobs: make(map[int32]activeJob), chStop: make(services.StopChan), @@ -134,7 +130,7 @@ func (js *spawner) HealthReport() map[string]error { func (js *spawner) startAllServices(ctx context.Context) { // TODO: rename to find AllJobs - specs, _, err := js.orm.FindJobs(0, math.MaxUint32) + specs, _, err := js.orm.FindJobs(ctx, 0, math.MaxUint32) if err != nil { werr := fmt.Errorf("couldn't fetch unclaimed jobs: %v", err) js.lggr.Critical(werr.Error()) @@ -191,7 +187,7 @@ func (js *spawner) stopService(jobID int32) { delete(js.activeJobs, jobID) } -func (js *spawner) StartService(ctx context.Context, jb Job, qopts ...pg.QOpt) error { +func (js *spawner) StartService(ctx context.Context, jb Job) error { lggr := js.lggr.With("jobID", jb.ID) js.activeJobsMu.Lock() defer js.activeJobsMu.Unlock() @@ -220,7 +216,7 @@ func (js *spawner) StartService(ctx context.Context, jb Job, qopts ...pg.QOpt) e lggr.Errorw("Error creating services for job", "err", err) cctx, cancel := js.chStop.NewCtx() defer cancel() - js.orm.TryRecordError(jb.ID, err.Error(), pg.WithParentCtx(cctx)) + js.orm.TryRecordError(cctx, jb.ID, err.Error()) js.activeJobs[jb.ID] = aj return pkgerrors.Wrapf(err, "failed to create services for job: %d", jb.ID) } @@ -249,7 +245,11 @@ func (js *spawner) StartService(ctx context.Context, jb Job, qopts ...pg.QOpt) e } // Should not get called before Start() -func (js *spawner) CreateJob(jb *Job, qopts ...pg.QOpt) (err error) { +func (js *spawner) CreateJob(ctx context.Context, ds sqlutil.DataSource, jb *Job) (err error) { + orm := js.orm + if ds != nil { + orm = orm.WithDataSource(ds) + } delegate, exists := js.jobTypeDelegates[jb.Type] if !exists { js.lggr.Errorf("job type '%s' has not been registered with the job.Spawner", jb.Type) @@ -257,15 +257,7 @@ func (js *spawner) CreateJob(jb *Job, qopts ...pg.QOpt) (err error) { return } - q := js.q.WithOpts(qopts...) - pctx, cancel := js.chStop.Ctx(q.ParentCtx) - defer cancel() - q.ParentCtx = pctx - - ctx, cancel := q.Context() - defer cancel() - - err = js.orm.CreateJob(jb, pg.WithQueryer(q.Queryer), pg.WithParentCtx(ctx)) + err = orm.CreateJob(ctx, jb) if err != nil { js.lggr.Errorw("Error creating job", "type", jb.Type, "err", err) return @@ -273,7 +265,7 @@ func (js *spawner) CreateJob(jb *Job, qopts ...pg.QOpt) (err error) { js.lggr.Infow("Created job", "type", jb.Type, "jobID", jb.ID) delegate.BeforeJobCreated(*jb) - err = js.StartService(pctx, *jb, pg.WithQueryer(q.Queryer)) + err = js.StartService(ctx, *jb) if err != nil { js.lggr.Errorw("Error starting job services", "type", jb.Type, "jobID", jb.ID, "err", err) } else { @@ -286,7 +278,10 @@ func (js *spawner) CreateJob(jb *Job, qopts ...pg.QOpt) (err error) { } // Should not get called before Start() -func (js *spawner) DeleteJob(jobID int32, qopts ...pg.QOpt) error { +func (js *spawner) DeleteJob(ctx context.Context, ds sqlutil.DataSource, jobID int32) error { + if ds == nil { + ds = js.orm.DataSource() + } if jobID == 0 { return pkgerrors.New("will not delete job with 0 ID") } @@ -302,15 +297,8 @@ func (js *spawner) DeleteJob(jobID int32, qopts ...pg.QOpt) error { aj, exists = js.activeJobs[jobID] }() - q := js.q.WithOpts(qopts...) - pctx, cancel := js.chStop.Ctx(q.ParentCtx) - defer cancel() - q.ParentCtx = pctx - ctx, cancel := q.Context() - defer cancel() - if !exists { // inactive, so look up the spec and delegate - jb, err := js.orm.FindJob(ctx, jobID) + jb, err := js.orm.WithDataSource(ds).FindJob(ctx, jobID) if err != nil { return pkgerrors.Wrapf(err, "job %d not found", jobID) } @@ -330,8 +318,8 @@ func (js *spawner) DeleteJob(jobID int32, qopts ...pg.QOpt) error { aj.delegate.BeforeJobDeleted(aj.spec) lggr.Debugw("Callback: BeforeDeleteJob done") - err := q.Transaction(func(tx pg.Queryer) error { - err := js.orm.DeleteJob(jobID, pg.WithQueryer(tx)) + err := sqlutil.Transact(ctx, js.orm.WithDataSource, ds, nil, func(tx ORM) error { + err := tx.DeleteJob(ctx, jobID) if err != nil { js.lggr.Errorw("Error deleting job", "jobID", jobID, "err", err) return err diff --git a/core/services/job/spawner_test.go b/core/services/job/spawner_test.go index ae8ffaa2161..7b4ab138e7c 100644 --- a/core/services/job/spawner_test.go +++ b/core/services/job/spawner_test.go @@ -101,10 +101,10 @@ func TestSpawner_CreateJobDeleteJob(t *testing.T) { legacyChains := evmrelay.NewLegacyChainsFromRelayerExtenders(relayExtenders) t.Run("should respect its dependents", func(t *testing.T) { lggr := logger.TestLogger(t) - orm := NewTestORM(t, db, pipeline.NewORM(db, lggr, config.JobPipeline().MaxSuccessfulRuns()), bridges.NewORM(db), keyStore, config.Database()) + orm := NewTestORM(t, db, pipeline.NewORM(db, lggr, config.JobPipeline().MaxSuccessfulRuns()), bridges.NewORM(db), keyStore) a := utils.NewDependentAwaiter() a.AddDependents(1) - spawner := job.NewSpawner(orm, config.Database(), noopChecker{}, map[job.Type]job.Delegate{}, db, lggr, []utils.DependentAwaiter{a}) + spawner := job.NewSpawner(orm, config.Database(), noopChecker{}, map[job.Type]job.Delegate{}, lggr, []utils.DependentAwaiter{a}) // Starting the spawner should signal to the dependents result := make(chan bool) go func() { @@ -124,7 +124,7 @@ func TestSpawner_CreateJobDeleteJob(t *testing.T) { jobB := makeOCRJobSpec(t, address, bridge.Name.String(), bridge2.Name.String()) lggr := logger.TestLogger(t) - orm := NewTestORM(t, db, pipeline.NewORM(db, lggr, config.JobPipeline().MaxSuccessfulRuns()), bridges.NewORM(db), keyStore, config.Database()) + orm := NewTestORM(t, db, pipeline.NewORM(db, lggr, config.JobPipeline().MaxSuccessfulRuns()), bridges.NewORM(db), keyStore) eventuallyA := cltest.NewAwaiter() serviceA1 := mocks.NewServiceCtx(t) @@ -146,9 +146,10 @@ func TestSpawner_CreateJobDeleteJob(t *testing.T) { spawner := job.NewSpawner(orm, config.Database(), noopChecker{}, map[job.Type]job.Delegate{ jobA.Type: delegateA, jobB.Type: delegateB, - }, db, lggr, nil) - require.NoError(t, spawner.Start(testutils.Context(t))) - err := spawner.CreateJob(jobA) + }, lggr, nil) + ctx := testutils.Context(t) + require.NoError(t, spawner.Start(ctx)) + err := spawner.CreateJob(ctx, nil, jobA) require.NoError(t, err) jobSpecIDA := jobA.ID delegateA.jobID = jobSpecIDA @@ -156,7 +157,7 @@ func TestSpawner_CreateJobDeleteJob(t *testing.T) { eventuallyA.AwaitOrFail(t, 20*time.Second) - err = spawner.CreateJob(jobB) + err = spawner.CreateJob(ctx, nil, jobB) require.NoError(t, err) jobSpecIDB := jobB.ID delegateB.jobID = jobSpecIDB @@ -166,12 +167,12 @@ func TestSpawner_CreateJobDeleteJob(t *testing.T) { serviceA1.On("Close").Return(nil).Once() serviceA2.On("Close").Return(nil).Once() - err = spawner.DeleteJob(jobSpecIDA) + err = spawner.DeleteJob(ctx, nil, jobSpecIDA) require.NoError(t, err) serviceB1.On("Close").Return(nil).Once() serviceB2.On("Close").Return(nil).Once() - err = spawner.DeleteJob(jobSpecIDB) + err = spawner.DeleteJob(ctx, nil, jobSpecIDB) require.NoError(t, err) require.NoError(t, spawner.Close()) @@ -189,19 +190,20 @@ func TestSpawner_CreateJobDeleteJob(t *testing.T) { serviceA2.On("Start", mock.Anything).Return(nil).Once().Run(func(mock.Arguments) { eventually.ItHappened() }) lggr := logger.TestLogger(t) - orm := NewTestORM(t, db, pipeline.NewORM(db, lggr, config.JobPipeline().MaxSuccessfulRuns()), bridges.NewORM(db), keyStore, config.Database()) + orm := NewTestORM(t, db, pipeline.NewORM(db, lggr, config.JobPipeline().MaxSuccessfulRuns()), bridges.NewORM(db), keyStore) mailMon := servicetest.Run(t, mailboxtest.NewMonitor(t)) d := ocr.NewDelegate(nil, orm, nil, nil, nil, monitoringEndpoint, legacyChains, logger.TestLogger(t), config, mailMon) delegateA := &delegate{jobA.Type, []job.ServiceCtx{serviceA1, serviceA2}, 0, nil, d} spawner := job.NewSpawner(orm, config.Database(), noopChecker{}, map[job.Type]job.Delegate{ jobA.Type: delegateA, - }, db, lggr, nil) + }, lggr, nil) - err := orm.CreateJob(jobA) + ctx := testutils.Context(t) + err := orm.CreateJob(ctx, jobA) require.NoError(t, err) delegateA.jobID = jobA.ID - require.NoError(t, spawner.Start(testutils.Context(t))) + require.NoError(t, spawner.Start(ctx)) eventually.AwaitOrFail(t) @@ -223,20 +225,21 @@ func TestSpawner_CreateJobDeleteJob(t *testing.T) { serviceA2.On("Start", mock.Anything).Return(nil).Once().Run(func(mock.Arguments) { eventuallyStart.ItHappened() }) lggr := logger.TestLogger(t) - orm := NewTestORM(t, db, pipeline.NewORM(db, lggr, config.JobPipeline().MaxSuccessfulRuns()), bridges.NewORM(db), keyStore, config.Database()) + orm := NewTestORM(t, db, pipeline.NewORM(db, lggr, config.JobPipeline().MaxSuccessfulRuns()), bridges.NewORM(db), keyStore) mailMon := servicetest.Run(t, mailboxtest.NewMonitor(t)) d := ocr.NewDelegate(nil, orm, nil, nil, nil, monitoringEndpoint, legacyChains, logger.TestLogger(t), config, mailMon) delegateA := &delegate{jobA.Type, []job.ServiceCtx{serviceA1, serviceA2}, 0, nil, d} spawner := job.NewSpawner(orm, config.Database(), noopChecker{}, map[job.Type]job.Delegate{ jobA.Type: delegateA, - }, db, lggr, nil) + }, lggr, nil) - err := orm.CreateJob(jobA) + ctx := testutils.Context(t) + err := orm.CreateJob(ctx, jobA) require.NoError(t, err) jobSpecIDA := jobA.ID delegateA.jobID = jobSpecIDA - require.NoError(t, spawner.Start(testutils.Context(t))) + require.NoError(t, spawner.Start(ctx)) defer func() { assert.NoError(t, spawner.Close()) }() eventuallyStart.AwaitOrFail(t) @@ -252,7 +255,7 @@ func TestSpawner_CreateJobDeleteJob(t *testing.T) { serviceA1.On("Close").Return(nil).Once() serviceA2.On("Close").Return(nil).Once().Run(func(mock.Arguments) { eventuallyClose.ItHappened() }) - err = spawner.DeleteJob(jobSpecIDA) + err = spawner.DeleteJob(ctx, nil, jobSpecIDA) require.NoError(t, err) eventuallyClose.AwaitOrFail(t) @@ -287,9 +290,7 @@ func TestSpawner_CreateJobDeleteJob(t *testing.T) { chain := evmtest.MustGetDefaultChain(t, legacyChains) evmRelayer, err := evmrelayer.NewRelayer(lggr, chain, evmrelayer.RelayerOpts{ - DB: db, DS: db, - QConfig: testopts.GeneralConfig.Database(), CSAETHKeystore: keyStore, }) assert.NoError(t, err) @@ -299,23 +300,24 @@ func TestSpawner_CreateJobDeleteJob(t *testing.T) { r: evmRelayer, } - jobOCR2VRF := makeOCR2VRFJobSpec(t, keyStore, config, address, chain.ID(), 2) + jobOCR2VRF := makeOCR2VRFJobSpec(t, keyStore, address, chain.ID(), 2) - orm := NewTestORM(t, db, pipeline.NewORM(db, lggr, config.JobPipeline().MaxSuccessfulRuns()), bridges.NewORM(db), keyStore, config.Database()) + orm := NewTestORM(t, db, pipeline.NewORM(db, lggr, config.JobPipeline().MaxSuccessfulRuns()), bridges.NewORM(db), keyStore) mailMon := servicetest.Run(t, mailboxtest.NewMonitor(t)) processConfig := plugins.NewRegistrarConfig(loop.GRPCOpts{}, func(name string) (*plugins.RegisteredLoop, error) { return nil, nil }, func(loopId string) {}) - ocr2DelegateConfig := ocr2.NewDelegateConfig(config.OCR2(), config.Mercury(), config.Threshold(), config.Insecure(), config.JobPipeline(), config.Database(), processConfig) + ocr2DelegateConfig := ocr2.NewDelegateConfig(config.OCR2(), config.Mercury(), config.Threshold(), config.Insecure(), config.JobPipeline(), processConfig) - d := ocr2.NewDelegate(nil, nil, orm, nil, nil, nil, nil, nil, monitoringEndpoint, legacyChains, lggr, ocr2DelegateConfig, + d := ocr2.NewDelegate(nil, orm, nil, nil, nil, nil, nil, monitoringEndpoint, legacyChains, lggr, ocr2DelegateConfig, keyStore.OCR2(), keyStore.DKGSign(), keyStore.DKGEncrypt(), ethKeyStore, testRelayGetter, mailMon, capabilities.NewRegistry(lggr)) delegateOCR2 := &delegate{jobOCR2VRF.Type, []job.ServiceCtx{}, 0, nil, d} spawner := job.NewSpawner(orm, config.Database(), noopChecker{}, map[job.Type]job.Delegate{ jobOCR2VRF.Type: delegateOCR2, - }, db, lggr, nil) + }, lggr, nil) - err = spawner.CreateJob(jobOCR2VRF) + ctx := testutils.Context(t) + err = spawner.CreateJob(ctx, nil, jobOCR2VRF) require.NoError(t, err) jobSpecID := jobOCR2VRF.ID delegateOCR2.jobID = jobOCR2VRF.ID @@ -324,7 +326,7 @@ func TestSpawner_CreateJobDeleteJob(t *testing.T) { lggr.Debugf("Got here, with args %v", args) }) - err = spawner.DeleteJob(jobSpecID) + err = spawner.DeleteJob(ctx, nil, jobSpecID) require.NoError(t, err) lp.AssertNumberOfCalls(t, "UnregisterFilter", 3) diff --git a/core/services/keeper/delegate.go b/core/services/keeper/delegate.go index 184a61e1e1a..71a0c5c43a9 100644 --- a/core/services/keeper/delegate.go +++ b/core/services/keeper/delegate.go @@ -5,8 +5,7 @@ import ( "github.com/pkg/errors" - "github.com/jmoiron/sqlx" - + "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" "github.com/smartcontractkit/chainlink-common/pkg/utils/mailbox" "github.com/smartcontractkit/chainlink/v2/core/chains/legacyevm" "github.com/smartcontractkit/chainlink/v2/core/config" @@ -25,7 +24,7 @@ type DelegateConfig interface { type Delegate struct { cfg DelegateConfig logger logger.Logger - db *sqlx.DB + ds sqlutil.DataSource jrm job.ORM pr pipeline.Runner legacyChains legacyevm.LegacyChainContainer @@ -35,7 +34,7 @@ type Delegate struct { // NewDelegate is the constructor of Delegate func NewDelegate( cfg DelegateConfig, - db *sqlx.DB, + ds sqlutil.DataSource, jrm job.ORM, pr pipeline.Runner, logger logger.Logger, @@ -45,7 +44,7 @@ func NewDelegate( return &Delegate{ cfg: cfg, logger: logger, - db: db, + ds: ds, jrm: jrm, pr: pr, legacyChains: legacyChains, @@ -58,10 +57,10 @@ func (d *Delegate) JobType() job.Type { return job.Keeper } -func (d *Delegate) BeforeJobCreated(spec job.Job) {} -func (d *Delegate) AfterJobCreated(spec job.Job) {} -func (d *Delegate) BeforeJobDeleted(spec job.Job) {} -func (d *Delegate) OnDeleteJob(context.Context, job.Job) error { return nil } +func (d *Delegate) BeforeJobCreated(spec job.Job) {} +func (d *Delegate) AfterJobCreated(spec job.Job) {} +func (d *Delegate) BeforeJobDeleted(spec job.Job) {} +func (d *Delegate) OnDeleteJob(ctx context.Context, spec job.Job) error { return nil } // ServicesForSpec satisfies the job.Delegate interface. func (d *Delegate) ServicesForSpec(ctx context.Context, spec job.Job) (services []job.ServiceCtx, err error) { @@ -73,7 +72,7 @@ func (d *Delegate) ServicesForSpec(ctx context.Context, spec job.Job) (services return nil, err } registryAddress := spec.KeeperSpec.ContractAddress - orm := NewORM(d.db, d.logger) + orm := NewORM(d.ds, d.logger) svcLogger := d.logger.With( "jobID", spec.ID, "registryAddress", registryAddress.Hex(), diff --git a/core/services/keeper/helpers_test.go b/core/services/keeper/helpers_test.go index 3fb9d7760a4..a10c3a48aff 100644 --- a/core/services/keeper/helpers_test.go +++ b/core/services/keeper/helpers_test.go @@ -12,6 +12,10 @@ func (rs *RegistrySynchronizer) ExportedFullSync(ctx context.Context) { rs.fullSync(ctx) } +func (rs *RegistrySynchronizer) ExportedProcessLogs(ctx context.Context) { + rs.processLogs(ctx) +} + func (rw *RegistryWrapper) GetUpkeepIdFromRawRegistrationLog(rawLog types.Log) (*big.Int, error) { switch rw.Version { case RegistryVersion_1_0, RegistryVersion_1_1: diff --git a/core/services/keeper/integration_test.go b/core/services/keeper/integration_test.go index 08699d3d835..9e4cf5f9041 100644 --- a/core/services/keeper/integration_test.go +++ b/core/services/keeper/integration_test.go @@ -284,7 +284,7 @@ func TestKeeperEthIntegration(t *testing.T) { require.NoError(t, err) backend.Commit() - cltest.WaitForCount(t, app.GetSqlxDB(), "upkeep_registrations", 0) + cltest.WaitForCount(t, app.GetDB(), "upkeep_registrations", 0) // add new upkeep (same target contract) registrationTx, err = registryWrapper.RegisterUpkeep(steve, upkeepAddr, 2_500_000, carrol.From, []byte{}) @@ -308,8 +308,8 @@ func TestKeeperEthIntegration(t *testing.T) { require.NoError(t, err) var registry keeper.Registry - require.NoError(t, app.GetSqlxDB().Get(®istry, `SELECT * FROM keeper_registries`)) - cltest.AssertRecordEventually(t, app.GetSqlxDB(), ®istry, fmt.Sprintf("SELECT * FROM keeper_registries WHERE id = %d", registry.ID), func() bool { + require.NoError(t, app.GetDB().GetContext(ctx, ®istry, `SELECT * FROM keeper_registries`)) + cltest.AssertRecordEventually(t, app.GetDB(), ®istry, fmt.Sprintf("SELECT * FROM keeper_registries WHERE id = %d", registry.ID), func() bool { return registry.KeeperIndex == -1 }) runs, err := app.PipelineORM().GetAllRuns(ctx) @@ -435,7 +435,7 @@ func TestKeeperForwarderEthIntegration(t *testing.T) { SchemaVersion: 1, ForwardingAllowed: true, } - err = app.JobORM().CreateJob(&jb) + err = app.JobORM().CreateJob(testutils.Context(t), &jb) require.NoError(t, err) registry := keeper.Registry{ diff --git a/core/services/keeper/registry_synchronizer_helper_test.go b/core/services/keeper/registry_synchronizer_helper_test.go index 2726c9a754d..73a3cb88166 100644 --- a/core/services/keeper/registry_synchronizer_helper_test.go +++ b/core/services/keeper/registry_synchronizer_helper_test.go @@ -46,7 +46,7 @@ func setupRegistrySync(t *testing.T, version keeper.RegistryVersion) ( j := cltest.MustInsertKeeperJob(t, db, korm, cltest.NewEIP55Address(), cltest.NewEIP55Address()) relayExtenders := evmtest.NewChainRelayExtenders(t, evmtest.TestChainOpts{DB: db, Client: ethClient, LogBroadcaster: lbMock, GeneralConfig: cfg, KeyStore: keyStore.Eth()}) legacyChains := evmrelay.NewLegacyChainsFromRelayerExtenders(relayExtenders) - jpv2 := cltest.NewJobPipelineV2(t, cfg.WebServer(), cfg.JobPipeline(), cfg.Database(), legacyChains, db, keyStore, nil, nil) + jpv2 := cltest.NewJobPipelineV2(t, cfg.WebServer(), cfg.JobPipeline(), legacyChains, db, keyStore, nil, nil) contractAddress := j.KeeperSpec.ContractAddress.Address() switch version { diff --git a/core/services/keeper/registry_synchronizer_process_logs.go b/core/services/keeper/registry_synchronizer_process_logs.go index a1bdcd8db0b..5aaddb6f1e4 100644 --- a/core/services/keeper/registry_synchronizer_process_logs.go +++ b/core/services/keeper/registry_synchronizer_process_logs.go @@ -82,6 +82,9 @@ func (rs *RegistrySynchronizer) processLogs(ctx context.Context) { } if err != nil { + if ctx.Err() != nil { + return + } rs.logger.Error(err) } diff --git a/core/services/keeper/registry_synchronizer_sync.go b/core/services/keeper/registry_synchronizer_sync.go index 6c0e12d844b..6615d376e2b 100644 --- a/core/services/keeper/registry_synchronizer_sync.go +++ b/core/services/keeper/registry_synchronizer_sync.go @@ -30,7 +30,7 @@ func (rs *RegistrySynchronizer) fullSync(ctx context.Context) { } func (rs *RegistrySynchronizer) syncRegistry(ctx context.Context) (Registry, error) { - registry, err := rs.newRegistryFromChain() + registry, err := rs.newRegistryFromChain(ctx) if err != nil { return Registry{}, errors.Wrap(err, "failed to get new registry from chain") } @@ -138,13 +138,13 @@ func (rs *RegistrySynchronizer) syncUpkeep(ctx context.Context, getter upkeepGet } // newRegistryFromChain returns a Registry struct with fields synched from those on chain -func (rs *RegistrySynchronizer) newRegistryFromChain() (Registry, error) { +func (rs *RegistrySynchronizer) newRegistryFromChain(ctx context.Context) (Registry, error) { fromAddress := rs.effectiveKeeperAddress contractAddress := rs.job.KeeperSpec.ContractAddress registryConfig, err := rs.registryWrapper.GetConfig(nil) if err != nil { - rs.jrm.TryRecordError(rs.job.ID, err.Error()) + rs.jrm.TryRecordError(ctx, rs.job.ID, err.Error()) return Registry{}, errors.Wrap(err, "failed to get contract config") } diff --git a/core/services/keeper/upkeep_executer_test.go b/core/services/keeper/upkeep_executer_test.go index ec23331f904..cd02fc27d11 100644 --- a/core/services/keeper/upkeep_executer_test.go +++ b/core/services/keeper/upkeep_executer_test.go @@ -83,7 +83,7 @@ func setup(t *testing.T, estimator gas.EvmFeeEstimator, overrideFn func(c *chain txm := txmmocks.NewMockEvmTxManager(t) relayExtenders := evmtest.NewChainRelayExtenders(t, evmtest.TestChainOpts{TxManager: txm, DB: db, Client: ethClient, KeyStore: keyStore.Eth(), GeneralConfig: cfg, GasEstimator: estimator}) legacyChains := evmrelay.NewLegacyChainsFromRelayerExtenders(relayExtenders) - jpv2 := cltest.NewJobPipelineV2(t, cfg.WebServer(), cfg.JobPipeline(), cfg.Database(), legacyChains, db, keyStore, nil, nil) + jpv2 := cltest.NewJobPipelineV2(t, cfg.WebServer(), cfg.JobPipeline(), legacyChains, db, keyStore, nil, nil) ch := evmtest.MustGetDefaultChain(t, legacyChains) orm := keeper.NewORM(db, logger.TestLogger(t)) registry, jb := cltest.MustInsertKeeperRegistry(t, db, orm, keyStore.Eth(), 0, 1, 20) diff --git a/core/services/keystore/eth_test.go b/core/services/keystore/eth_test.go index 07a41599f7d..4a9c8a952ff 100644 --- a/core/services/keystore/eth_test.go +++ b/core/services/keystore/eth_test.go @@ -13,8 +13,6 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/smartcontractkit/chainlink-common/pkg/logger" - "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" commonutils "github.com/smartcontractkit/chainlink-common/pkg/utils" evmclient "github.com/smartcontractkit/chainlink/v2/core/chains/evm/client" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/utils" @@ -29,9 +27,7 @@ import ( func Test_EthKeyStore(t *testing.T) { t.Parallel() - db := sqlutil.WrapDataSource(pgtest.NewSqlxDB(t), logger.Test(t), sqlutil.TimeoutHook(func() time.Duration { - return 5 * time.Minute - }), sqlutil.MonitorHook(func() bool { return true })) + db := pgtest.NewSqlxDB(t) keyStore := keystore.ExposedNewMaster(t, db) err := keyStore.Unlock(testutils.Context(t), cltest.Password) diff --git a/core/services/llo/delegate.go b/core/services/llo/delegate.go index 2bab0ab12a2..6cdad6290fc 100644 --- a/core/services/llo/delegate.go +++ b/core/services/llo/delegate.go @@ -13,13 +13,13 @@ import ( "gopkg.in/guregu/null.v4" "github.com/smartcontractkit/chainlink-common/pkg/services" + "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" llotypes "github.com/smartcontractkit/chainlink-common/pkg/types/llo" "github.com/smartcontractkit/chainlink-data-streams/llo" "github.com/smartcontractkit/chainlink/v2/core/logger" "github.com/smartcontractkit/chainlink/v2/core/services/job" "github.com/smartcontractkit/chainlink/v2/core/services/llo/evm" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" "github.com/smartcontractkit/chainlink/v2/core/services/streams" ) @@ -43,11 +43,11 @@ type delegate struct { } type DelegateConfig struct { - Logger logger.Logger - Queryer pg.Queryer - Runner streams.Runner - Registry Registry - JobName null.String + Logger logger.Logger + DataSource sqlutil.DataSource + Runner streams.Runner + Registry Registry + JobName null.String // LLO ChannelDefinitionCache llotypes.ChannelDefinitionCache @@ -67,8 +67,8 @@ type DelegateConfig struct { } func NewDelegate(cfg DelegateConfig) (job.ServiceCtx, error) { - if cfg.Queryer == nil { - return nil, errors.New("Queryer must not be nil") + if cfg.DataSource == nil { + return nil, errors.New("DataSource must not be nil") } if cfg.Runner == nil { return nil, errors.New("Runner must not be nil") diff --git a/core/services/ocr/contract_tracker.go b/core/services/ocr/contract_tracker.go index 5746f97cd38..f7c7a9940b5 100644 --- a/core/services/ocr/contract_tracker.go +++ b/core/services/ocr/contract_tracker.go @@ -90,8 +90,9 @@ type ( } OCRContractTrackerDB interface { - SaveLatestRoundRequested(ctx context.Context, tx sqlutil.DataSource, rr offchainaggregator.OffchainAggregatorRoundRequested) error + SaveLatestRoundRequested(ctx context.Context, rr offchainaggregator.OffchainAggregatorRoundRequested) error LoadLatestRoundRequested(ctx context.Context) (rr offchainaggregator.OffchainAggregatorRoundRequested, err error) + WithDataSource(sqlutil.DataSource) OCRContractTrackerDB } ) @@ -293,7 +294,7 @@ func (t *OCRContractTracker) HandleLog(ctx context.Context, lb log.Broadcast) { } if IsLaterThan(raw, t.latestRoundRequested.Raw) { err = sqlutil.TransactDataSource(ctx, t.ds, nil, func(tx sqlutil.DataSource) error { - if err = t.ocrDB.SaveLatestRoundRequested(ctx, tx, *rr); err != nil { + if err = t.ocrDB.WithDataSource(tx).SaveLatestRoundRequested(ctx, *rr); err != nil { return err } return t.logBroadcaster.MarkConsumed(ctx, tx, lb) diff --git a/core/services/ocr/contract_tracker_test.go b/core/services/ocr/contract_tracker_test.go index 6f8b05c6436..d8fd8a01a11 100644 --- a/core/services/ocr/contract_tracker_test.go +++ b/core/services/ocr/contract_tracker_test.go @@ -224,9 +224,10 @@ func Test_OCRContractTracker_HandleLog_OCRContractLatestRoundRequested(t *testin uni.lb.On("WasAlreadyConsumed", mock.Anything, mock.Anything).Return(false, nil) uni.lb.On("MarkConsumed", mock.Anything, mock.Anything, mock.Anything).Return(nil) - uni.db.On("SaveLatestRoundRequested", mock.Anything, mock.Anything, mock.MatchedBy(func(rr offchainaggregator.OffchainAggregatorRoundRequested) bool { + uni.db.On("SaveLatestRoundRequested", mock.Anything, mock.MatchedBy(func(rr offchainaggregator.OffchainAggregatorRoundRequested) bool { return rr.Epoch == 1 && rr.Round == 1 })).Return(nil) + uni.db.On("WithDataSource", mock.Anything).Return(uni.db) uni.tracker.HandleLog(testutils.Context(t), logBroadcast) @@ -244,7 +245,7 @@ func Test_OCRContractTracker_HandleLog_OCRContractLatestRoundRequested(t *testin uni.lb.On("WasAlreadyConsumed", mock.Anything, mock.Anything).Return(false, nil) uni.lb.On("MarkConsumed", mock.Anything, mock.Anything, mock.Anything).Return(nil) - uni.db.On("SaveLatestRoundRequested", mock.Anything, mock.Anything, mock.MatchedBy(func(rr offchainaggregator.OffchainAggregatorRoundRequested) bool { + uni.db.On("SaveLatestRoundRequested", mock.Anything, mock.MatchedBy(func(rr offchainaggregator.OffchainAggregatorRoundRequested) bool { return rr.Epoch == 1 && rr.Round == 9 })).Return(nil) @@ -273,7 +274,7 @@ func Test_OCRContractTracker_HandleLog_OCRContractLatestRoundRequested(t *testin uni.lb.On("WasAlreadyConsumed", mock.Anything, mock.Anything).Return(false, nil) uni.lb.On("MarkConsumed", mock.Anything, mock.Anything, mock.Anything).Return(nil) - uni.db.On("SaveLatestRoundRequested", mock.Anything, mock.Anything, mock.MatchedBy(func(rr offchainaggregator.OffchainAggregatorRoundRequested) bool { + uni.db.On("SaveLatestRoundRequested", mock.Anything, mock.MatchedBy(func(rr offchainaggregator.OffchainAggregatorRoundRequested) bool { return rr.Epoch == 2 && rr.Round == 1 })).Return(nil) @@ -295,7 +296,8 @@ func Test_OCRContractTracker_HandleLog_OCRContractLatestRoundRequested(t *testin logBroadcast.On("String").Return("").Maybe() uni.lb.On("WasAlreadyConsumed", mock.Anything, mock.Anything).Return(false, nil) - uni.db.On("SaveLatestRoundRequested", mock.Anything, mock.Anything, mock.Anything).Return(errors.New("something exploded")) + uni.db.On("SaveLatestRoundRequested", mock.Anything, mock.Anything).Return(errors.New("something exploded")) + uni.db.On("WithDataSource", mock.Anything).Return(uni.db) uni.tracker.HandleLog(testutils.Context(t), logBroadcast) diff --git a/core/services/ocr/database.go b/core/services/ocr/database.go index 95993de9d5c..b5f890565f1 100644 --- a/core/services/ocr/database.go +++ b/core/services/ocr/database.go @@ -38,6 +38,9 @@ func NewDB(ds sqlutil.DataSource, oracleSpecID int32, lggr logger.Logger) *db { lggr: logger.Sugared(lggr), } } +func (d *db) WithDataSource(ds sqlutil.DataSource) OCRContractTrackerDB { + return NewDB(ds, d.oracleSpecID, d.lggr) +} func (d *db) ReadState(ctx context.Context, cd ocrtypes.ConfigDigest) (ps *ocrtypes.PersistentState, err error) { stmt := ` @@ -293,12 +296,12 @@ WHERE ocr_oracle_spec_id = $1 AND time < $2 return } -func (d *db) SaveLatestRoundRequested(ctx context.Context, tx sqlutil.DataSource, rr offchainaggregator.OffchainAggregatorRoundRequested) error { +func (d *db) SaveLatestRoundRequested(ctx context.Context, rr offchainaggregator.OffchainAggregatorRoundRequested) error { rawLog, err := json.Marshal(rr.Raw) if err != nil { return errors.Wrap(err, "could not marshal log as JSON") } - _, err = tx.ExecContext(ctx, ` + _, err = d.ds.ExecContext(ctx, ` INSERT INTO ocr_latest_round_requested (ocr_oracle_spec_id, requester, config_digest, epoch, round, raw) VALUES ($1,$2,$3,$4,$5,$6) ON CONFLICT (ocr_oracle_spec_id) DO UPDATE SET requester = EXCLUDED.requester, diff --git a/core/services/ocr/database_test.go b/core/services/ocr/database_test.go index a2559ca2a87..12f9309450c 100644 --- a/core/services/ocr/database_test.go +++ b/core/services/ocr/database_test.go @@ -407,7 +407,7 @@ func Test_DB_LatestRoundRequested(t *testing.T) { t.Run("saves latest round requested", func(t *testing.T) { ctx := testutils.Context(t) - err := odb.SaveLatestRoundRequested(ctx, sqlDB, rr) + err := odb.SaveLatestRoundRequested(ctx, rr) require.NoError(t, err) rawLog.Index = 42 @@ -421,7 +421,7 @@ func Test_DB_LatestRoundRequested(t *testing.T) { Raw: rawLog, } - err = odb.SaveLatestRoundRequested(ctx, sqlDB, rr) + err = odb.SaveLatestRoundRequested(ctx, rr) require.NoError(t, err) }) diff --git a/core/services/ocr/delegate.go b/core/services/ocr/delegate.go index 88561bd1c3a..690e9ad7c71 100644 --- a/core/services/ocr/delegate.go +++ b/core/services/ocr/delegate.go @@ -10,13 +10,12 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/pkg/errors" - "github.com/jmoiron/sqlx" - "github.com/smartcontractkit/libocr/gethwrappers/offchainaggregator" ocr "github.com/smartcontractkit/libocr/offchainreporting" ocrtypes "github.com/smartcontractkit/libocr/offchainreporting/types" commonlogger "github.com/smartcontractkit/chainlink-common/pkg/logger" + "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" "github.com/smartcontractkit/chainlink-common/pkg/utils/mailbox" txmgrcommon "github.com/smartcontractkit/chainlink/v2/common/txmgr" @@ -35,7 +34,7 @@ import ( ) type Delegate struct { - db *sqlx.DB + ds sqlutil.DataSource jobORM job.ORM keyStore keystore.Master pipelineRunner pipeline.Runner @@ -52,7 +51,7 @@ var _ job.Delegate = (*Delegate)(nil) const ConfigOverriderPollInterval = 30 * time.Second func NewDelegate( - db *sqlx.DB, + ds sqlutil.DataSource, jobORM job.ORM, keyStore keystore.Master, pipelineRunner pipeline.Runner, @@ -64,7 +63,7 @@ func NewDelegate( mailMon *mailbox.Monitor, ) *Delegate { return &Delegate{ - db: db, + ds: ds, jobORM: jobORM, keyStore: keyStore, pipelineRunner: pipelineRunner, @@ -120,7 +119,7 @@ func (d *Delegate) ServicesForSpec(ctx context.Context, jb job.Job) (services [] return nil, errors.Wrap(err, "could not instantiate NewOffchainAggregatorCaller") } - ocrDB := NewDB(d.db, concreteSpec.ID, lggr) + ocrDB := NewDB(d.ds, concreteSpec.ID, lggr) tracker := NewOCRContractTracker( contract, @@ -130,7 +129,7 @@ func (d *Delegate) ServicesForSpec(ctx context.Context, jb job.Job) (services [] chain.LogBroadcaster(), jb.ID, lggr, - d.db, + d.ds, ocrDB, chain.Config().EVM(), chain.HeadBroadcaster(), @@ -157,7 +156,7 @@ func (d *Delegate) ServicesForSpec(ctx context.Context, jb job.Job) (services [] } ocrLogger := commonlogger.NewOCRWrapper(lggr, d.cfg.OCR().TraceLogging(), func(msg string) { - d.jobORM.TryRecordError(jb.ID, msg) + d.jobORM.TryRecordError(ctx, jb.ID, msg) }) lc := toLocalConfig(chain.Config().EVM(), chain.Config().EVM().OCR(), d.cfg.Insecure(), *concreteSpec, d.cfg.OCR()) diff --git a/core/services/ocr/helpers_internal_test.go b/core/services/ocr/helpers_internal_test.go index c6a3d1ac401..a8c656f636c 100644 --- a/core/services/ocr/helpers_internal_test.go +++ b/core/services/ocr/helpers_internal_test.go @@ -3,8 +3,7 @@ package ocr import ( "testing" - "github.com/jmoiron/sqlx" - + "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" "github.com/smartcontractkit/chainlink/v2/core/logger" ) @@ -12,6 +11,6 @@ func (c *ConfigOverriderImpl) ExportedUpdateFlagsStatus() error { return c.updateFlagsStatus() } -func NewTestDB(t *testing.T, sqldb *sqlx.DB, oracleSpecID int32) *db { - return NewDB(sqldb, oracleSpecID, logger.TestLogger(t)) +func NewTestDB(t *testing.T, ds sqlutil.DataSource, oracleSpecID int32) *db { + return NewDB(ds, oracleSpecID, logger.TestLogger(t)) } diff --git a/core/services/ocr/mocks/ocr_contract_tracker_db.go b/core/services/ocr/mocks/ocr_contract_tracker_db.go index ed47d87cd1e..d3dcce2641b 100644 --- a/core/services/ocr/mocks/ocr_contract_tracker_db.go +++ b/core/services/ocr/mocks/ocr_contract_tracker_db.go @@ -5,6 +5,7 @@ package mocks import ( context "context" + ocr "github.com/smartcontractkit/chainlink/v2/core/services/ocr" mock "github.com/stretchr/testify/mock" offchainaggregator "github.com/smartcontractkit/libocr/gethwrappers/offchainaggregator" @@ -45,17 +46,17 @@ func (_m *OCRContractTrackerDB) LoadLatestRoundRequested(ctx context.Context) (o return r0, r1 } -// SaveLatestRoundRequested provides a mock function with given fields: ctx, tx, rr -func (_m *OCRContractTrackerDB) SaveLatestRoundRequested(ctx context.Context, tx sqlutil.DataSource, rr offchainaggregator.OffchainAggregatorRoundRequested) error { - ret := _m.Called(ctx, tx, rr) +// SaveLatestRoundRequested provides a mock function with given fields: ctx, rr +func (_m *OCRContractTrackerDB) SaveLatestRoundRequested(ctx context.Context, rr offchainaggregator.OffchainAggregatorRoundRequested) error { + ret := _m.Called(ctx, rr) if len(ret) == 0 { panic("no return value specified for SaveLatestRoundRequested") } var r0 error - if rf, ok := ret.Get(0).(func(context.Context, sqlutil.DataSource, offchainaggregator.OffchainAggregatorRoundRequested) error); ok { - r0 = rf(ctx, tx, rr) + if rf, ok := ret.Get(0).(func(context.Context, offchainaggregator.OffchainAggregatorRoundRequested) error); ok { + r0 = rf(ctx, rr) } else { r0 = ret.Error(0) } @@ -63,6 +64,26 @@ func (_m *OCRContractTrackerDB) SaveLatestRoundRequested(ctx context.Context, tx return r0 } +// WithDataSource provides a mock function with given fields: _a0 +func (_m *OCRContractTrackerDB) WithDataSource(_a0 sqlutil.DataSource) ocr.OCRContractTrackerDB { + ret := _m.Called(_a0) + + if len(ret) == 0 { + panic("no return value specified for WithDataSource") + } + + var r0 ocr.OCRContractTrackerDB + if rf, ok := ret.Get(0).(func(sqlutil.DataSource) ocr.OCRContractTrackerDB); ok { + r0 = rf(_a0) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(ocr.OCRContractTrackerDB) + } + } + + return r0 +} + // NewOCRContractTrackerDB creates a new instance of OCRContractTrackerDB. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. // The first argument is typically a *testing.T value. func NewOCRContractTrackerDB(t interface { diff --git a/core/services/ocr2/database.go b/core/services/ocr2/database.go index 1d449047578..83ee3240a46 100644 --- a/core/services/ocr2/database.go +++ b/core/services/ocr2/database.go @@ -6,18 +6,17 @@ import ( "encoding/binary" "time" - "github.com/jmoiron/sqlx" "github.com/lib/pq" "github.com/pkg/errors" ocrcommon "github.com/smartcontractkit/libocr/commontypes" ocrtypes "github.com/smartcontractkit/libocr/offchainreporting2plus/types" + "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" "github.com/smartcontractkit/chainlink/v2/core/logger" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" ) type db struct { - q pg.Q + ds sqlutil.DataSource oracleSpecID int32 pluginID int32 lggr logger.SugaredLogger @@ -28,11 +27,9 @@ var ( ) // NewDB returns a new DB scoped to this oracleSpecID -func NewDB(sqlxDB *sqlx.DB, oracleSpecID int32, pluginID int32, lggr logger.Logger, cfg pg.QConfig) *db { - namedLogger := lggr.Named("OCR2.DB") - +func NewDB(ds sqlutil.DataSource, oracleSpecID int32, pluginID int32, lggr logger.Logger) *db { return &db{ - q: pg.NewQ(sqlxDB, namedLogger, cfg), + ds: ds, oracleSpecID: oracleSpecID, pluginID: pluginID, lggr: logger.Sugared(lggr), @@ -51,7 +48,7 @@ func (d *db) ReadState(ctx context.Context, cd ocrtypes.ConfigDigest) (ps *ocrty var tmp []int64 var highestSentEpochTmp int64 - err = d.q.QueryRowxContext(ctx, stmt, d.oracleSpecID, cd).Scan(&ps.Epoch, &highestSentEpochTmp, pq.Array(&tmp)) + err = d.ds.QueryRowxContext(ctx, stmt, d.oracleSpecID, cd).Scan(&ps.Epoch, &highestSentEpochTmp, pq.Array(&tmp)) if errors.Is(err, sql.ErrNoRows) { return nil, nil } @@ -98,7 +95,9 @@ func (d *db) WriteState(ctx context.Context, cd ocrtypes.ConfigDigest, state ocr NOW() )` - _, err := d.q.WithOpts(pg.WithLongQueryTimeout()).ExecContext( + ctx, cancel := context.WithTimeout(sqlutil.WithoutDefaultTimeout(ctx), time.Minute) + defer cancel() + _, err := d.ds.ExecContext( ctx, stmt, d.oracleSpecID, cd, state.Epoch, state.HighestSentEpoch, pq.Array(&highestReceivedEpoch), ) @@ -126,7 +125,7 @@ func (d *db) ReadConfig(ctx context.Context) (c *ocrtypes.ContractConfig, err er signers := [][]byte{} transmitters := [][]byte{} - err = d.q.QueryRowx(stmt, d.oracleSpecID, d.pluginID).Scan( + err = d.ds.QueryRowxContext(ctx, stmt, d.oracleSpecID, d.pluginID).Scan( &digest, &c.ConfigCount, (*pq.ByteaArray)(&signers), @@ -191,7 +190,7 @@ func (d *db) WriteConfig(ctx context.Context, c ocrtypes.ContractConfig) error { offchain_config = EXCLUDED.offchain_config, updated_at = NOW() ` - _, err := d.q.ExecContext(ctx, stmt, + _, err := d.ds.ExecContext(ctx, stmt, d.oracleSpecID, d.pluginID, c.ConfigDigest, @@ -252,7 +251,7 @@ func (d *db) StorePendingTransmission(ctx context.Context, t ocrtypes.ReportTime updated_at = NOW() ` - _, err := d.q.ExecContext(ctx, stmt, + _, err := d.ds.ExecContext(ctx, stmt, d.oracleSpecID, digest, t.Epoch, @@ -279,7 +278,7 @@ func (d *db) PendingTransmissionsWithConfigDigest(ctx context.Context, cd ocrtyp FROM ocr2_pending_transmissions WHERE ocr2_oracle_spec_id = $1 AND config_digest = $2 ` - rows, err := d.q.QueryxContext(ctx, stmt, d.oracleSpecID, cd) //nolint sqlclosecheck false positive + rows, err := d.ds.QueryxContext(ctx, stmt, d.oracleSpecID, cd) //nolint sqlclosecheck false positive if err != nil { return nil, errors.Wrap(err, "PendingTransmissionsWithConfigDigest failed to query rows") } @@ -325,7 +324,9 @@ func (d *db) PendingTransmissionsWithConfigDigest(ctx context.Context, cd ocrtyp } func (d *db) DeletePendingTransmission(ctx context.Context, t ocrtypes.ReportTimestamp) (err error) { - _, err = d.q.WithOpts(pg.WithLongQueryTimeout()).ExecContext(ctx, ` + ctx, cancel := context.WithTimeout(sqlutil.WithoutDefaultTimeout(ctx), time.Minute) + defer cancel() + _, err = d.ds.ExecContext(ctx, ` DELETE FROM ocr2_pending_transmissions WHERE ocr2_oracle_spec_id = $1 AND config_digest = $2 AND epoch = $3 AND round = $4 `, d.oracleSpecID, t.ConfigDigest, t.Epoch, t.Round) @@ -336,7 +337,9 @@ WHERE ocr2_oracle_spec_id = $1 AND config_digest = $2 AND epoch = $3 AND round } func (d *db) DeletePendingTransmissionsOlderThan(ctx context.Context, t time.Time) (err error) { - _, err = d.q.WithOpts(pg.WithLongQueryTimeout()).ExecContext(ctx, ` + ctx, cancel := context.WithTimeout(sqlutil.WithoutDefaultTimeout(ctx), time.Minute) + defer cancel() + _, err = d.ds.ExecContext(ctx, ` DELETE FROM ocr2_pending_transmissions WHERE ocr2_oracle_spec_id = $1 AND time < $2 `, d.oracleSpecID, t) @@ -347,7 +350,7 @@ WHERE ocr2_oracle_spec_id = $1 AND time < $2 } func (d *db) ReadProtocolState(ctx context.Context, configDigest ocrtypes.ConfigDigest, key string) (value []byte, err error) { - err = d.q.GetContext(ctx, &value, ` + err = d.ds.GetContext(ctx, &value, ` SELECT value FROM ocr_protocol_states WHERE config_digest = $1 AND key = $2; `, configDigest, key) @@ -363,9 +366,9 @@ WHERE config_digest = $1 AND key = $2; func (d *db) WriteProtocolState(ctx context.Context, configDigest ocrtypes.ConfigDigest, key string, value []byte) (err error) { if value == nil { - _, err = d.q.ExecContext(ctx, `DELETE FROM ocr_protocol_states WHERE config_digest = $1 AND key = $2;`, configDigest, key) + _, err = d.ds.ExecContext(ctx, `DELETE FROM ocr_protocol_states WHERE config_digest = $1 AND key = $2;`, configDigest, key) } else { - _, err = d.q.ExecContext(ctx, ` + _, err = d.ds.ExecContext(ctx, ` INSERT INTO ocr_protocol_states (config_digest, key, value) VALUES ($1, $2, $3) ON CONFLICT (config_digest, key) DO UPDATE SET value = $3;`, configDigest, key, value) } diff --git a/core/services/ocr2/database_test.go b/core/services/ocr2/database_test.go index 6e4f8f5dd66..3e78249d087 100644 --- a/core/services/ocr2/database_test.go +++ b/core/services/ocr2/database_test.go @@ -16,7 +16,6 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/internal/cltest" "github.com/smartcontractkit/chainlink/v2/core/internal/testutils" - "github.com/smartcontractkit/chainlink/v2/core/internal/testutils/configtest" "github.com/smartcontractkit/chainlink/v2/core/internal/testutils/pgtest" "github.com/smartcontractkit/chainlink/v2/core/logger" "github.com/smartcontractkit/chainlink/v2/core/services/job" @@ -60,14 +59,13 @@ func Test_DB_ReadWriteState(t *testing.T) { sqlDB := setupDB(t) configDigest := testhelpers.MakeConfigDigest(t) - cfg := configtest.NewTestGeneralConfig(t) ethKeyStore := cltest.NewKeyStore(t, sqlDB).Eth() key, _ := cltest.MustInsertRandomKey(t, ethKeyStore) spec := MustInsertOCROracleSpec(t, sqlDB, key.EIP55Address) lggr := logger.TestLogger(t) t.Run("reads and writes state", func(t *testing.T) { - db := ocr2.NewDB(sqlDB, spec.ID, defaultPluginID, lggr, cfg.Database()) + db := ocr2.NewDB(sqlDB, spec.ID, defaultPluginID, lggr) state := ocrtypes.PersistentState{ Epoch: 1, HighestSentEpoch: 2, @@ -84,7 +82,7 @@ func Test_DB_ReadWriteState(t *testing.T) { }) t.Run("updates state", func(t *testing.T) { - db := ocr2.NewDB(sqlDB, spec.ID, defaultPluginID, lggr, cfg.Database()) + db := ocr2.NewDB(sqlDB, spec.ID, defaultPluginID, lggr) newState := ocrtypes.PersistentState{ Epoch: 2, HighestSentEpoch: 3, @@ -101,7 +99,7 @@ func Test_DB_ReadWriteState(t *testing.T) { }) t.Run("does not return result for wrong spec", func(t *testing.T) { - db := ocr2.NewDB(sqlDB, spec.ID, defaultPluginID, lggr, cfg.Database()) + db := ocr2.NewDB(sqlDB, spec.ID, defaultPluginID, lggr) state := ocrtypes.PersistentState{ Epoch: 3, HighestSentEpoch: 4, @@ -112,7 +110,7 @@ func Test_DB_ReadWriteState(t *testing.T) { require.NoError(t, err) // odb with different spec - db = ocr2.NewDB(sqlDB, -1, defaultPluginID, lggr, cfg.Database()) + db = ocr2.NewDB(sqlDB, -1, defaultPluginID, lggr) readState, err := db.ReadState(testutils.Context(t), configDigest) require.NoError(t, err) @@ -121,7 +119,7 @@ func Test_DB_ReadWriteState(t *testing.T) { }) t.Run("does not return result for wrong config digest", func(t *testing.T) { - db := ocr2.NewDB(sqlDB, spec.ID, defaultPluginID, lggr, cfg.Database()) + db := ocr2.NewDB(sqlDB, spec.ID, defaultPluginID, lggr) state := ocrtypes.PersistentState{ Epoch: 4, HighestSentEpoch: 5, @@ -151,14 +149,13 @@ func Test_DB_ReadWriteConfig(t *testing.T) { OffchainConfigVersion: 111, OffchainConfig: []byte{0x03, 0x04}, } - cfg := configtest.NewTestGeneralConfig(t) ethKeyStore := cltest.NewKeyStore(t, sqlDB).Eth() key, _ := cltest.MustInsertRandomKey(t, ethKeyStore) spec := MustInsertOCROracleSpec(t, sqlDB, key.EIP55Address) lggr := logger.TestLogger(t) t.Run("reads and writes config", func(t *testing.T) { - db := ocr2.NewDB(sqlDB, spec.ID, defaultPluginID, lggr, cfg.Database()) + db := ocr2.NewDB(sqlDB, spec.ID, defaultPluginID, lggr) err := db.WriteConfig(testutils.Context(t), config) require.NoError(t, err) @@ -170,7 +167,7 @@ func Test_DB_ReadWriteConfig(t *testing.T) { }) t.Run("updates config", func(t *testing.T) { - db := ocr2.NewDB(sqlDB, spec.ID, defaultPluginID, lggr, cfg.Database()) + db := ocr2.NewDB(sqlDB, spec.ID, defaultPluginID, lggr) newConfig := ocrtypes.ContractConfig{ ConfigDigest: testhelpers.MakeConfigDigest(t), @@ -188,12 +185,12 @@ func Test_DB_ReadWriteConfig(t *testing.T) { }) t.Run("does not return result for wrong spec", func(t *testing.T) { - db := ocr2.NewDB(sqlDB, spec.ID, defaultPluginID, lggr, cfg.Database()) + db := ocr2.NewDB(sqlDB, spec.ID, defaultPluginID, lggr) err := db.WriteConfig(testutils.Context(t), config) require.NoError(t, err) - db = ocr2.NewDB(sqlDB, -1, defaultPluginID, lggr, cfg.Database()) + db = ocr2.NewDB(sqlDB, -1, defaultPluginID, lggr) readConfig, err := db.ReadConfig(testutils.Context(t)) require.NoError(t, err) @@ -203,8 +200,8 @@ func Test_DB_ReadWriteConfig(t *testing.T) { t.Run("reads and writes config for multiple plugins", func(t *testing.T) { otherPluginID := int32(2) - db1 := ocr2.NewDB(sqlDB, spec.ID, defaultPluginID, lggr, cfg.Database()) - db2 := ocr2.NewDB(sqlDB, spec.ID, otherPluginID, lggr, cfg.Database()) + db1 := ocr2.NewDB(sqlDB, spec.ID, defaultPluginID, lggr) + db2 := ocr2.NewDB(sqlDB, spec.ID, otherPluginID, lggr) otherConfig := ocrtypes.ContractConfig{ ConfigDigest: testhelpers.MakeConfigDigest(t), @@ -238,15 +235,14 @@ func assertPendingTransmissionEqual(t *testing.T, pt1, pt2 ocrtypes.PendingTrans func Test_DB_PendingTransmissions(t *testing.T) { sqlDB := setupDB(t) - cfg := configtest.NewTestGeneralConfig(t) ethKeyStore := cltest.NewKeyStore(t, sqlDB).Eth() key, _ := cltest.MustInsertRandomKey(t, ethKeyStore) lggr := logger.TestLogger(t) spec := MustInsertOCROracleSpec(t, sqlDB, key.EIP55Address) spec2 := MustInsertOCROracleSpec(t, sqlDB, key.EIP55Address) - db := ocr2.NewDB(sqlDB, spec.ID, defaultPluginID, lggr, cfg.Database()) - db2 := ocr2.NewDB(sqlDB, spec2.ID, defaultPluginID, lggr, cfg.Database()) + db := ocr2.NewDB(sqlDB, spec.ID, defaultPluginID, lggr) + db2 := ocr2.NewDB(sqlDB, spec2.ID, defaultPluginID, lggr) configDigest := testhelpers.MakeConfigDigest(t) k := ocrtypes.ReportTimestamp{ @@ -436,7 +432,7 @@ func Test_DB_PendingTransmissions(t *testing.T) { require.Len(t, m, 1) // Didn't affect other oracleSpecIDs - db = ocr2.NewDB(sqlDB, spec2.ID, defaultPluginID, lggr, cfg.Database()) + db = ocr2.NewDB(sqlDB, spec2.ID, defaultPluginID, lggr) m, err = db.PendingTransmissionsWithConfigDigest(testutils.Context(t), configDigest) require.NoError(t, err) require.Len(t, m, 1) @@ -446,10 +442,8 @@ func Test_DB_PendingTransmissions(t *testing.T) { func Test_DB_ReadWriteProtocolState(t *testing.T) { sqlDB := setupDB(t) - cfg := configtest.NewTestGeneralConfig(t) - lggr := logger.TestLogger(t) - db := ocr2.NewDB(sqlDB, 0, defaultPluginID, lggr, cfg.Database()) + db := ocr2.NewDB(sqlDB, 0, defaultPluginID, lggr) cd1 := testhelpers.MakeConfigDigest(t) cd2 := testhelpers.MakeConfigDigest(t) ctx := testutils.Context(t) diff --git a/core/services/ocr2/delegate.go b/core/services/ocr2/delegate.go index 7a2ef532fd2..6c76b0ff9c9 100644 --- a/core/services/ocr2/delegate.go +++ b/core/services/ocr2/delegate.go @@ -9,7 +9,6 @@ import ( "time" "github.com/ethereum/go-ethereum/common" - "github.com/jmoiron/sqlx" "github.com/pkg/errors" "github.com/prometheus/client_golang/prometheus" "google.golang.org/grpc" @@ -70,7 +69,6 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/promwrapper" "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/validate" "github.com/smartcontractkit/chainlink/v2/core/services/ocrcommon" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" "github.com/smartcontractkit/chainlink/v2/core/services/pipeline" "github.com/smartcontractkit/chainlink/v2/core/services/relay" evmrelay "github.com/smartcontractkit/chainlink/v2/core/services/relay/evm" @@ -111,7 +109,6 @@ type RelayGetter interface { Get(id types.RelayID) (loop.Relayer, error) } type Delegate struct { - db *sqlx.DB // legacy: prefer to use ds instead ds sqlutil.DataSource jobORM job.ORM bridgeORM bridges.ORM @@ -138,7 +135,6 @@ type DelegateConfig interface { plugins.RegistrarConfig OCR2() ocr2Config JobPipeline() jobPipelineConfig - Database() pg.QConfig Insecure() insecureConfig Mercury() coreconfig.Mercury Threshold() coreconfig.Threshold @@ -149,7 +145,6 @@ type delegateConfig struct { plugins.RegistrarConfig ocr2 ocr2Config jobPipeline jobPipelineConfig - database pg.QConfig insecure insecureConfig mercury mercuryConfig threshold thresholdConfig @@ -159,10 +154,6 @@ func (d *delegateConfig) JobPipeline() jobPipelineConfig { return d.jobPipeline } -func (d *delegateConfig) Database() pg.QConfig { - return d.database -} - func (d *delegateConfig) Insecure() insecureConfig { return d.insecure } @@ -212,12 +203,11 @@ type thresholdConfig interface { ThresholdKeyShare() string } -func NewDelegateConfig(ocr2Cfg ocr2Config, m coreconfig.Mercury, t coreconfig.Threshold, i insecureConfig, jp jobPipelineConfig, qconf pg.QConfig, pluginProcessCfg plugins.RegistrarConfig) DelegateConfig { +func NewDelegateConfig(ocr2Cfg ocr2Config, m coreconfig.Mercury, t coreconfig.Threshold, i insecureConfig, jp jobPipelineConfig, pluginProcessCfg plugins.RegistrarConfig) DelegateConfig { return &delegateConfig{ ocr2: ocr2Cfg, RegistrarConfig: pluginProcessCfg, jobPipeline: jp, - database: qconf, insecure: i, mercury: m, threshold: t, @@ -227,7 +217,6 @@ func NewDelegateConfig(ocr2Cfg ocr2Config, m coreconfig.Mercury, t coreconfig.Th var _ job.Delegate = (*Delegate)(nil) func NewDelegate( - db *sqlx.DB, ds sqlutil.DataSource, jobORM job.ORM, bridgeORM bridges.ORM, @@ -248,7 +237,6 @@ func NewDelegate( capabilitiesRegistry core.CapabilitiesRegistry, ) *Delegate { return &Delegate{ - db: db, ds: ds, jobORM: jobORM, bridgeORM: bridgeORM, @@ -382,7 +370,7 @@ func (d *Delegate) ServicesForSpec(ctx context.Context, jb job.Job) ([]job.Servi } lggr := logger.Sugared(d.lggr.Named(jb.ExternalJobID.String()).With(lggrCtx.Args()...)) - kvStore := job.NewKVStore(jb.ID, d.db, d.cfg.Database(), lggr) + kvStore := job.NewKVStore(jb.ID, d.ds, lggr) rid, err := spec.RelayID() if err != nil { @@ -404,7 +392,7 @@ func (d *Delegate) ServicesForSpec(ctx context.Context, jb job.Job) ([]job.Servi spec.RelayConfig["effectiveTransmitterID"] = effectiveTransmitterID spec.RelayConfig.ApplyDefaultsOCR2(d.cfg.OCR2()) - ocrDB := NewDB(d.db, spec.ID, 0, lggr, d.cfg.Database()) + ocrDB := NewDB(d.ds, spec.ID, 0, lggr) if d.peerWrapper == nil { return nil, errors.New("cannot setup OCR2 job service, libp2p peer was missing") } else if !d.peerWrapper.IsStarted() { @@ -412,7 +400,7 @@ func (d *Delegate) ServicesForSpec(ctx context.Context, jb job.Job) ([]job.Servi } ocrLogger := commonlogger.NewOCRWrapper(lggr, d.cfg.OCR2().TraceLogging(), func(msg string) { - lggr.ErrorIf(d.jobORM.RecordError(jb.ID, msg), "unable to record error") + lggr.ErrorIf(d.jobORM.RecordError(ctx, jb.ID, msg), "unable to record error") }) lc, err := validate.ToLocalConfig(d.cfg.OCR2(), d.cfg.Insecure(), *spec) @@ -475,8 +463,8 @@ func (d *Delegate) ServicesForSpec(ctx context.Context, jb job.Job) ([]job.Servi thresholdPluginId s4PluginId ) - thresholdPluginDB := NewDB(d.db, spec.ID, thresholdPluginId, lggr, d.cfg.Database()) - s4PluginDB := NewDB(d.db, spec.ID, s4PluginId, lggr, d.cfg.Database()) + thresholdPluginDB := NewDB(d.ds, spec.ID, thresholdPluginId, lggr) + s4PluginDB := NewDB(d.ds, spec.ID, s4PluginId, lggr) return d.newServicesOCR2Functions(ctx, lggr, jb, bootstrapPeers, kb, ocrDB, thresholdPluginDB, s4PluginDB, lc, ocrLogger) case types.GenericPlugin: @@ -904,10 +892,10 @@ func (d *Delegate) newServicesLLO( kr := llo.NewOnchainKeyring(lggr, kbm) cfg := llo.DelegateConfig{ - Logger: lggr, - Queryer: pg.NewQ(d.db, lggr, d.cfg.Database()), - Runner: d.pipelineRunner, - Registry: d.streamRegistry, + Logger: lggr, + DataSource: d.ds, + Runner: d.pipelineRunner, + Registry: d.streamRegistry, JobName: jb.Name, @@ -1010,7 +998,7 @@ func (d *Delegate) newServicesDKG( if err2 != nil { return nil, fmt.Errorf("DKG services: failed to get chain %s: %w", rid.ChainID, err2) } - ocr2vrfRelayer := evmrelay.NewOCR2VRFRelayer(d.db, chain, lggr.Named("OCR2VRFRelayer"), d.ethKs) + ocr2vrfRelayer := evmrelay.NewOCR2VRFRelayer(chain, lggr.Named("OCR2VRFRelayer"), d.ethKs) dkgProvider, err2 := ocr2vrfRelayer.NewDKGProvider( types.RelayArgs{ ExternalJobID: jb.ExternalJobID, @@ -1041,20 +1029,7 @@ func (d *Delegate) newServicesDKG( OnchainKeyring: kb, MetricsRegisterer: prometheus.WrapRegistererWith(map[string]string{"job_name": jb.Name.ValueOrZero()}, prometheus.DefaultRegisterer), } - return dkg.NewDKGServices( - jb, - dkgProvider, - lggr, - ocrLogger, - d.dkgSignKs, - d.dkgEncryptKs, - chain.Client(), - oracleArgsNoPlugin, - d.db, - d.cfg.Database(), - chain.ID(), - spec.Relay, - ) + return dkg.NewDKGServices(jb, dkgProvider, lggr, ocrLogger, d.dkgSignKs, d.dkgEncryptKs, chain.Client(), oracleArgsNoPlugin, d.ds, chain.ID(), spec.Relay) } func (d *Delegate) newServicesOCR2VRF( @@ -1094,7 +1069,7 @@ func (d *Delegate) newServicesOCR2VRF( return nil, errors.Wrap(err2, "validate ocr2vrf plugin config") } - ocr2vrfRelayer := evmrelay.NewOCR2VRFRelayer(d.db, chain, lggr.Named("OCR2VRFRelayer"), d.ethKs) + ocr2vrfRelayer := evmrelay.NewOCR2VRFRelayer(chain, lggr.Named("OCR2VRFRelayer"), d.ethKs) transmitterID := spec.TransmitterID.String vrfProvider, err2 := ocr2vrfRelayer.NewOCR2VRFProvider( @@ -1179,11 +1154,11 @@ func (d *Delegate) newServicesOCR2VRF( ) vrfLogger := commonlogger.NewOCRWrapper(l.With( "vrfContractID", spec.ContractID), d.cfg.OCR2().TraceLogging(), func(msg string) { - lggr.ErrorIf(d.jobORM.RecordError(jb.ID, msg), "unable to record error") + lggr.ErrorIf(d.jobORM.RecordError(ctx, jb.ID, msg), "unable to record error") }) dkgLogger := commonlogger.NewOCRWrapper(l.With( "dkgContractID", cfg.DKGContractAddress), d.cfg.OCR2().TraceLogging(), func(msg string) { - lggr.ErrorIf(d.jobORM.RecordError(jb.ID, msg), "unable to record error") + lggr.ErrorIf(d.jobORM.RecordError(ctx, jb.ID, msg), "unable to record error") }) dkgReportingPluginFactoryDecorator := func(wrapped ocrtypes.ReportingPluginFactory) ocrtypes.ReportingPluginFactory { return promwrapper.NewPromFactory(wrapped, "DKG", string(types.NetworkEVM), chain.ID()) @@ -1222,7 +1197,7 @@ func (d *Delegate) newServicesOCR2VRF( KeyID: keyID, DKGReportingPluginFactoryDecorator: dkgReportingPluginFactoryDecorator, VRFReportingPluginFactoryDecorator: vrfReportingPluginFactoryDecorator, - DKGSharePersistence: persistence.NewShareDB(d.db, lggr.Named("DKGShareDB"), d.cfg.Database(), chain.ID(), spec.Relay), + DKGSharePersistence: persistence.NewShareDB(d.ds, lggr.Named("DKGShareDB"), chain.ID(), spec.Relay), }) if err2 != nil { return nil, errors.Wrap(err2, "new ocr2vrf") @@ -1449,7 +1424,7 @@ func (d *Delegate) newServicesOCR2Keepers20( return nil, fmt.Errorf("keepers2.0 services: failed to get chain (%s): %w", rid.ChainID, err2) } - keeperProvider, rgstry, encoder, logProvider, err2 := ocr2keeper.EVMDependencies20(ctx, jb, d.db, lggr, chain, d.ethKs, d.cfg.Database()) + keeperProvider, rgstry, encoder, logProvider, err2 := ocr2keeper.EVMDependencies20(ctx, jb, d.ds, lggr, chain, d.ethKs) if err2 != nil { return nil, errors.Wrap(err2, "could not build dependencies for ocr2 keepers") } @@ -1699,11 +1674,11 @@ func (d *Delegate) newServicesOCR2Functions( // errorLog implements [loop.ErrorLog] type errorLog struct { jobID int32 - recordError func(jobID int32, description string, qopts ...pg.QOpt) error + recordError func(ctx context.Context, jobID int32, description string) error } func (l *errorLog) SaveError(ctx context.Context, msg string) error { - return l.recordError(l.jobID, msg) + return l.recordError(ctx, l.jobID, msg) } type logWriter struct { diff --git a/core/services/ocr2/plugins/dkg/persistence/db.go b/core/services/ocr2/plugins/dkg/persistence/db.go index b8ecfbaceb4..07dad494ed7 100644 --- a/core/services/ocr2/plugins/dkg/persistence/db.go +++ b/core/services/ocr2/plugins/dkg/persistence/db.go @@ -16,11 +16,11 @@ import ( ocrtypes "github.com/smartcontractkit/libocr/offchainreporting2plus/types" + "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" ocr2vrftypes "github.com/smartcontractkit/chainlink-vrf/types" "github.com/smartcontractkit/chainlink-vrf/types/hash" "github.com/smartcontractkit/chainlink/v2/core/logger" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" ) var ( @@ -52,16 +52,16 @@ var ( ) type shareDB struct { - q pg.Q + ds sqlutil.DataSource lggr logger.Logger chainID *big.Int chainType string } // NewShareDB creates a new DKG share database. -func NewShareDB(db *sqlx.DB, lggr logger.Logger, cfg pg.QConfig, chainID *big.Int, chainType string) ocr2vrftypes.DKGSharePersistence { +func NewShareDB(ds sqlutil.DataSource, lggr logger.Logger, chainID *big.Int, chainType string) ocr2vrftypes.DKGSharePersistence { return &shareDB{ - q: pg.NewQ(db, lggr, cfg), + ds: ds, lggr: lggr, chainID: chainID, chainType: chainType, @@ -134,13 +134,13 @@ func (s *shareDB) WriteShareRecords( // Always upsert because we want the number of rows in the table to match // the number of members of the committee. - query := ` + _, err := s.ds.NamedExecContext(ctx, ` INSERT INTO dkg_shares (config_digest, key_id, dealer, marshaled_share_record, record_hash) VALUES (:config_digest, :key_id, :dealer, :marshaled_share_record, :record_hash) ON CONFLICT ON CONSTRAINT dkg_shares_pkey DO UPDATE SET marshaled_share_record = EXCLUDED.marshaled_share_record, record_hash = EXCLUDED.record_hash -` - return s.q.ExecQNamed(query, named[:]) +`, named[:]) + return err } // ReadShareRecords retrieves any share records in the database that correspond @@ -152,6 +152,7 @@ func (s *shareDB) ReadShareRecords( retrievedShares []ocr2vrftypes.PersistentShareSetRecord, err error, ) { + ctx := context.Background() //TODO https://smartcontract-it.atlassian.net/browse/BCF-2887 lggr := s.lggr.With( "configDigest", hexutil.Encode(cfgDgst[:]), "keyID", hexutil.Encode(keyID[:])) @@ -177,9 +178,9 @@ WHERE config_digest = :config_digest if err != nil { return nil, errors.Wrap(err, "sqlx Named") } - query = s.q.Rebind(query) + query = s.ds.Rebind(query) var dkgShares []dkgShare - err = s.q.Select(&dkgShares, query, args...) + err = s.ds.SelectContext(ctx, &dkgShares, query, args...) if errors.Is(err, sql.ErrNoRows) { return nil, nil } diff --git a/core/services/ocr2/plugins/dkg/persistence/db_test.go b/core/services/ocr2/plugins/dkg/persistence/db_test.go index a3949ea70dc..dbc400a3468 100644 --- a/core/services/ocr2/plugins/dkg/persistence/db_test.go +++ b/core/services/ocr2/plugins/dkg/persistence/db_test.go @@ -22,7 +22,7 @@ import ( func setup(t testing.TB) (ocr2vrftypes.DKGSharePersistence, *sqlx.DB) { db := pgtest.NewSqlxDB(t) lggr := logger.TestLogger(t) - return NewShareDB(db, lggr, pgtest.NewQConfig(true), big.NewInt(1337), types.NetworkEVM), db + return NewShareDB(db, lggr, big.NewInt(1337), types.NetworkEVM), db } func TestShareDB_WriteShareRecords(t *testing.T) { diff --git a/core/services/ocr2/plugins/dkg/plugin.go b/core/services/ocr2/plugins/dkg/plugin.go index 65216ac1115..0364c5d9ab0 100644 --- a/core/services/ocr2/plugins/dkg/plugin.go +++ b/core/services/ocr2/plugins/dkg/plugin.go @@ -6,12 +6,12 @@ import ( "fmt" "math/big" - "github.com/jmoiron/sqlx" "github.com/pkg/errors" "github.com/smartcontractkit/libocr/commontypes" libocr2 "github.com/smartcontractkit/libocr/offchainreporting2plus" + "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" "github.com/smartcontractkit/chainlink-vrf/altbn_128" "github.com/smartcontractkit/chainlink-vrf/dkg" @@ -21,7 +21,6 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/services/keystore" "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/dkg/config" "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/dkg/persistence" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" evmrelay "github.com/smartcontractkit/chainlink/v2/core/services/relay/evm" ) @@ -34,8 +33,7 @@ func NewDKGServices( dkgEncryptKs keystore.DKGEncrypt, ethClient evmclient.Client, oracleArgsNoPlugin libocr2.OCR2OracleArgs, - db *sqlx.DB, - qConfig pg.QConfig, + ds sqlutil.DataSource, chainID *big.Int, network string, ) ([]job.ServiceCtx, error) { @@ -68,7 +66,7 @@ func NewDKGServices( if err != nil { return nil, errors.Wrap(err, "decode key ID") } - shareDB := persistence.NewShareDB(db, lggr.Named("DKGShareDB"), qConfig, chainID, network) + shareDB := persistence.NewShareDB(ds, lggr.Named("DKGShareDB"), chainID, network) oracleArgsNoPlugin.ReportingPluginFactory = dkg.NewReportingPluginFactory( encryptKey.KyberScalar(), signKey.KyberScalar(), diff --git a/core/services/ocr2/plugins/generic/pipeline_runner_adapter_test.go b/core/services/ocr2/plugins/generic/pipeline_runner_adapter_test.go index 0d6bc02e198..a33ab2a1bd2 100644 --- a/core/services/ocr2/plugins/generic/pipeline_runner_adapter_test.go +++ b/core/services/ocr2/plugins/generic/pipeline_runner_adapter_test.go @@ -47,7 +47,7 @@ func TestAdapter_Integration(t *testing.T) { keystore := keystore.NewInMemory(db, utils.FastScryptParams, logger) pipelineORM := pipeline.NewORM(db, logger, cfg.JobPipeline().MaxSuccessfulRuns()) bridgesORM := bridges.NewORM(db) - jobORM := job.NewORM(db, pipelineORM, bridgesORM, keystore, logger, cfg.Database()) + jobORM := job.NewORM(db, pipelineORM, bridgesORM, keystore, logger) pr := pipeline.NewRunner( pipelineORM, bridgesORM, @@ -75,7 +75,7 @@ func TestAdapter_Integration(t *testing.T) { jb.Name = null.StringFrom("Job 1") jb.OCR2OracleSpec.TransmitterID = null.StringFrom(address.String()) jb.OCR2OracleSpec.PluginConfig["juelsPerFeeCoinSource"] = juelsPerFeeCoinSource - err = jobORM.CreateJob(&jb) + err = jobORM.CreateJob(ctx, &jb) require.NoError(t, err) pra := generic.NewPipelineRunnerAdapter(logger, jb, pr) results, err := pra.ExecuteRun(testutils.Context(t), spec, core.Vars{Vars: map[string]interface{}{"val": 1}}, core.Options{}) diff --git a/core/services/ocr2/plugins/ocr2keeper/util.go b/core/services/ocr2/plugins/ocr2keeper/util.go index 35bd62eeed8..339d8a89dfb 100644 --- a/core/services/ocr2/plugins/ocr2keeper/util.go +++ b/core/services/ocr2/plugins/ocr2keeper/util.go @@ -4,12 +4,11 @@ import ( "context" "fmt" + "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" "github.com/smartcontractkit/chainlink-common/pkg/types" "github.com/smartcontractkit/chainlink/v2/core/services/keystore" evmrelay "github.com/smartcontractkit/chainlink/v2/core/services/relay/evm" - "github.com/jmoiron/sqlx" - ocrtypes "github.com/smartcontractkit/libocr/offchainreporting2plus/types" ocr2keepers20 "github.com/smartcontractkit/chainlink-automation/pkg/v2" @@ -25,7 +24,6 @@ import ( evmregistry20 "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/ocr2keeper/evmregistry/v20" evmregistry21 "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/ocr2keeper/evmregistry/v21" evmregistry21transmit "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/ocr2keeper/evmregistry/v21/transmit" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" ) type Encoder20 interface { @@ -44,9 +42,9 @@ var ( ErrNoChainFromSpec = fmt.Errorf("could not create chain from spec") ) -func EVMProvider(db *sqlx.DB, chain legacyevm.Chain, lggr logger.Logger, spec job.Job, ethKeystore keystore.Eth, dbCfg pg.QConfig) (evmrelay.OCR2KeeperProvider, error) { +func EVMProvider(ds sqlutil.DataSource, chain legacyevm.Chain, lggr logger.Logger, spec job.Job, ethKeystore keystore.Eth) (evmrelay.OCR2KeeperProvider, error) { oSpec := spec.OCR2OracleSpec - ocr2keeperRelayer := evmrelay.NewOCR2KeeperRelayer(db, chain, lggr.Named("OCR2KeeperRelayer"), ethKeystore, dbCfg) + ocr2keeperRelayer := evmrelay.NewOCR2KeeperRelayer(ds, chain, lggr.Named("OCR2KeeperRelayer"), ethKeystore) keeperProvider, err := ocr2keeperRelayer.NewOCR2KeeperProvider( types.RelayArgs{ @@ -70,11 +68,10 @@ func EVMProvider(db *sqlx.DB, chain legacyevm.Chain, lggr logger.Logger, spec jo func EVMDependencies20( ctx context.Context, spec job.Job, - db *sqlx.DB, + ds sqlutil.DataSource, lggr logger.Logger, chain legacyevm.Chain, ethKeystore keystore.Eth, - dbCfg pg.QConfig, ) (evmrelay.OCR2KeeperProvider, *evmregistry20.EvmRegistry, Encoder20, *evmregistry20.LogProvider, error) { var err error @@ -82,7 +79,7 @@ func EVMDependencies20( var registry *evmregistry20.EvmRegistry // the provider will be returned as a dependency - if keeperProvider, err = EVMProvider(db, chain, lggr, spec, ethKeystore, dbCfg); err != nil { + if keeperProvider, err = EVMProvider(ds, chain, lggr, spec, ethKeystore); err != nil { return nil, nil, nil, nil, err } diff --git a/core/services/ocrbootstrap/database.go b/core/services/ocrbootstrap/database.go index 86ade39a05d..ef63b75dd39 100644 --- a/core/services/ocrbootstrap/database.go +++ b/core/services/ocrbootstrap/database.go @@ -8,11 +8,12 @@ import ( "github.com/pkg/errors" ocrtypes "github.com/smartcontractkit/libocr/offchainreporting2plus/types" + "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" "github.com/smartcontractkit/chainlink/v2/core/logger" ) type db struct { - *sql.DB + ds sqlutil.DataSource oracleSpecID int32 lggr logger.Logger } @@ -20,12 +21,12 @@ type db struct { var _ ocrtypes.ConfigDatabase = &db{} // NewDB returns a new DB scoped to this oracleSpecID -func NewDB(sqldb *sql.DB, bootstrapSpecID int32, lggr logger.Logger) *db { - return &db{sqldb, bootstrapSpecID, lggr} +func NewDB(ds sqlutil.DataSource, bootstrapSpecID int32, lggr logger.Logger) *db { + return &db{ds, bootstrapSpecID, lggr} } func (d *db) ReadConfig(ctx context.Context) (c *ocrtypes.ContractConfig, err error) { - q := d.QueryRowContext(ctx, ` + q := d.ds.QueryRowxContext(ctx, ` SELECT config_digest, config_count, @@ -82,7 +83,7 @@ func (d *db) WriteConfig(ctx context.Context, c ocrtypes.ContractConfig) error { for _, s := range c.Signers { signers = append(signers, []byte(s)) } - _, err := d.ExecContext(ctx, ` + _, err := d.ds.ExecContext(ctx, ` INSERT INTO bootstrap_contract_configs ( bootstrap_spec_id, config_digest, diff --git a/core/services/ocrbootstrap/database_test.go b/core/services/ocrbootstrap/database_test.go index e00e318c69c..eaad863c88b 100644 --- a/core/services/ocrbootstrap/database_test.go +++ b/core/services/ocrbootstrap/database_test.go @@ -52,7 +52,7 @@ func Test_DB_ReadWriteConfig(t *testing.T) { lggr := logger.TestLogger(t) t.Run("reads and writes config", func(t *testing.T) { - db := ocrbootstrap.NewDB(sqlDB.DB, spec.ID, lggr) + db := ocrbootstrap.NewDB(sqlDB, spec.ID, lggr) err := db.WriteConfig(testutils.Context(t), config) require.NoError(t, err) @@ -64,7 +64,7 @@ func Test_DB_ReadWriteConfig(t *testing.T) { }) t.Run("updates config", func(t *testing.T) { - db := ocrbootstrap.NewDB(sqlDB.DB, spec.ID, lggr) + db := ocrbootstrap.NewDB(sqlDB, spec.ID, lggr) newConfig := ocrtypes.ContractConfig{ ConfigDigest: testhelpers.MakeConfigDigest(t), @@ -82,12 +82,12 @@ func Test_DB_ReadWriteConfig(t *testing.T) { }) t.Run("does not return result for wrong spec", func(t *testing.T) { - db := ocrbootstrap.NewDB(sqlDB.DB, spec.ID, lggr) + db := ocrbootstrap.NewDB(sqlDB, spec.ID, lggr) err := db.WriteConfig(testutils.Context(t), config) require.NoError(t, err) - db = ocrbootstrap.NewDB(sqlDB.DB, -1, lggr) + db = ocrbootstrap.NewDB(sqlDB, -1, lggr) readConfig, err := db.ReadConfig(testutils.Context(t)) require.NoError(t, err) diff --git a/core/services/ocrbootstrap/delegate.go b/core/services/ocrbootstrap/delegate.go index 84f473088bb..0bb7a0ca2ba 100644 --- a/core/services/ocrbootstrap/delegate.go +++ b/core/services/ocrbootstrap/delegate.go @@ -7,12 +7,11 @@ import ( "github.com/pkg/errors" - "github.com/jmoiron/sqlx" - ocr "github.com/smartcontractkit/libocr/offchainreporting2plus" commonlogger "github.com/smartcontractkit/chainlink-common/pkg/logger" "github.com/smartcontractkit/chainlink-common/pkg/loop" + "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" "github.com/smartcontractkit/chainlink-common/pkg/types" "github.com/smartcontractkit/chainlink/v2/core/logger" @@ -27,7 +26,7 @@ type RelayGetter interface { // Delegate creates Bootstrap jobs type Delegate struct { - db *sqlx.DB + ds sqlutil.DataSource jobORM job.ORM peerWrapper *ocrcommon.SingletonPeerWrapper ocr2Cfg validate.OCR2Config @@ -48,7 +47,7 @@ type relayConfig struct { // NewDelegateBootstrap creates a new Delegate func NewDelegateBootstrap( - db *sqlx.DB, + ds sqlutil.DataSource, jobORM job.ORM, peerWrapper *ocrcommon.SingletonPeerWrapper, lggr logger.Logger, @@ -57,7 +56,7 @@ func NewDelegateBootstrap( relayers RelayGetter, ) *Delegate { return &Delegate{ - db: db, + ds: ds, jobORM: jobORM, peerWrapper: peerWrapper, lggr: logger.Sugared(lggr), @@ -166,10 +165,10 @@ func (d *Delegate) ServicesForSpec(ctx context.Context, jb job.Job) (services [] bootstrapNodeArgs := ocr.BootstrapperArgs{ BootstrapperFactory: d.peerWrapper.Peer2, ContractConfigTracker: configProvider.ContractConfigTracker(), - Database: NewDB(d.db.DB, spec.ID, lggr), + Database: NewDB(d.ds, spec.ID, lggr), LocalConfig: lc, Logger: commonlogger.NewOCRWrapper(lggr.Named("OCRBootstrap"), d.ocr2Cfg.TraceLogging(), func(msg string) { - logger.Sugared(lggr).ErrorIf(d.jobORM.RecordError(jb.ID, msg), "unable to record error") + logger.Sugared(lggr).ErrorIf(d.jobORM.RecordError(ctx, jb.ID, msg), "unable to record error") }), OffchainConfigDigester: configProvider.OffchainConfigDigester(), } diff --git a/core/services/ocrcommon/discoverer_database.go b/core/services/ocrcommon/discoverer_database.go index 9413b11ad07..ea75f9e6d21 100644 --- a/core/services/ocrcommon/discoverer_database.go +++ b/core/services/ocrcommon/discoverer_database.go @@ -2,25 +2,26 @@ package ocrcommon import ( "context" - "database/sql" "github.com/lib/pq" "github.com/pkg/errors" "go.uber.org/multierr" ocrnetworking "github.com/smartcontractkit/libocr/networking/types" + + "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" ) var _ ocrnetworking.DiscovererDatabase = &DiscovererDatabase{} type DiscovererDatabase struct { - db *sql.DB + ds sqlutil.DataSource peerID string } -func NewDiscovererDatabase(db *sql.DB, peerID string) *DiscovererDatabase { +func NewDiscovererDatabase(ds sqlutil.DataSource, peerID string) *DiscovererDatabase { return &DiscovererDatabase{ - db, + ds, peerID, } } @@ -28,7 +29,7 @@ func NewDiscovererDatabase(db *sql.DB, peerID string) *DiscovererDatabase { // StoreAnnouncement has key-value-store semantics and stores a peerID (key) and an associated serialized // announcement (value). func (d *DiscovererDatabase) StoreAnnouncement(ctx context.Context, peerID string, ann []byte) error { - _, err := d.db.ExecContext(ctx, ` + _, err := d.ds.ExecContext(ctx, ` INSERT INTO ocr_discoverer_announcements (local_peer_id, remote_peer_id, ann, created_at, updated_at) VALUES ($1,$2,$3,NOW(),NOW()) ON CONFLICT (local_peer_id, remote_peer_id) DO UPDATE SET ann = EXCLUDED.ann, @@ -40,7 +41,7 @@ updated_at = EXCLUDED.updated_at // ReadAnnouncements returns one serialized announcement (if available) for each of the peerIDs in the form of a map // keyed by each announcement's corresponding peer ID. func (d *DiscovererDatabase) ReadAnnouncements(ctx context.Context, peerIDs []string) (results map[string][]byte, err error) { - rows, err := d.db.QueryContext(ctx, ` + rows, err := d.ds.QueryContext(ctx, ` SELECT remote_peer_id, ann FROM ocr_discoverer_announcements WHERE remote_peer_id = ANY($1) AND local_peer_id = $2`, pq.Array(peerIDs), d.peerID) if err != nil { return nil, errors.Wrap(err, "DiscovererDatabase failed to ReadAnnouncements") diff --git a/core/services/ocrcommon/discoverer_database_test.go b/core/services/ocrcommon/discoverer_database_test.go index b7a79e92bce..23d5ad661a4 100644 --- a/core/services/ocrcommon/discoverer_database_test.go +++ b/core/services/ocrcommon/discoverer_database_test.go @@ -16,7 +16,7 @@ import ( ) func Test_DiscovererDatabase(t *testing.T) { - db := pgtest.NewSqlDB(t) + db := pgtest.NewSqlxDB(t) localPeerID1 := mustRandomP2PPeerID(t) localPeerID2 := mustRandomP2PPeerID(t) diff --git a/core/services/ocrcommon/peer_wrapper.go b/core/services/ocrcommon/peer_wrapper.go index 02bdd9cee7d..97c429f9a5f 100644 --- a/core/services/ocrcommon/peer_wrapper.go +++ b/core/services/ocrcommon/peer_wrapper.go @@ -4,7 +4,6 @@ import ( "context" "io" - "github.com/jmoiron/sqlx" "github.com/pkg/errors" "github.com/prometheus/client_golang/prometheus" @@ -14,12 +13,12 @@ import ( commonlogger "github.com/smartcontractkit/chainlink-common/pkg/logger" "github.com/smartcontractkit/chainlink-common/pkg/services" + "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" "github.com/smartcontractkit/chainlink/v2/core/config" "github.com/smartcontractkit/chainlink/v2/core/logger" "github.com/smartcontractkit/chainlink/v2/core/services/keystore" "github.com/smartcontractkit/chainlink/v2/core/services/keystore/keys/p2pkey" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" ) type PeerWrapperOCRConfig interface { @@ -43,8 +42,7 @@ type ( keyStore keystore.Master p2pCfg config.P2P ocrCfg PeerWrapperOCRConfig - dbConfig pg.QConfig - db *sqlx.DB + ds sqlutil.DataSource lggr logger.Logger PeerID p2pkey.PeerID @@ -69,13 +67,12 @@ func ValidatePeerWrapperConfig(config config.P2P) error { // NewSingletonPeerWrapper creates a new peer based on the p2p keys in the keystore // It currently only supports one peerID/key // It should be fairly easy to modify it to support multiple peerIDs/keys using e.g. a map -func NewSingletonPeerWrapper(keyStore keystore.Master, p2pCfg config.P2P, ocrCfg PeerWrapperOCRConfig, dbConfig pg.QConfig, db *sqlx.DB, lggr logger.Logger) *SingletonPeerWrapper { +func NewSingletonPeerWrapper(keyStore keystore.Master, p2pCfg config.P2P, ocrCfg PeerWrapperOCRConfig, ds sqlutil.DataSource, lggr logger.Logger) *SingletonPeerWrapper { return &SingletonPeerWrapper{ keyStore: keyStore, p2pCfg: p2pCfg, ocrCfg: ocrCfg, - dbConfig: dbConfig, - db: db, + ds: ds, lggr: lggr.Named("SingletonPeerWrapper"), } } @@ -120,7 +117,7 @@ func (p *SingletonPeerWrapper) peerConfig() (ocrnetworking.PeerConfig, error) { } p.PeerID = key.PeerID() - discovererDB := NewDiscovererDatabase(p.db.DB, p.PeerID.Raw()) + discovererDB := NewDiscovererDatabase(p.ds, p.PeerID.Raw()) config := p.p2pCfg peerConfig := ocrnetworking.PeerConfig{ diff --git a/core/services/ocrcommon/peer_wrapper_test.go b/core/services/ocrcommon/peer_wrapper_test.go index e87f211fd21..a47ed19ec56 100644 --- a/core/services/ocrcommon/peer_wrapper_test.go +++ b/core/services/ocrcommon/peer_wrapper_test.go @@ -34,7 +34,7 @@ func Test_SingletonPeerWrapper_Start(t *testing.T) { c.P2P.V2.Enabled = ptr(true) }) keyStore := cltest.NewKeyStore(t, db) - pw := ocrcommon.NewSingletonPeerWrapper(keyStore, cfg.P2P(), cfg.OCR(), cfg.Database(), db, logger.TestLogger(t)) + pw := ocrcommon.NewSingletonPeerWrapper(keyStore, cfg.P2P(), cfg.OCR(), db, logger.TestLogger(t)) require.Contains(t, pw.Start(testutils.Context(t)).Error(), "No P2P keys found in keystore. Peer wrapper will not be fully initialized") }) @@ -49,7 +49,7 @@ func Test_SingletonPeerWrapper_Start(t *testing.T) { c.P2P.V2.ListenAddresses = &[]string{fmt.Sprintf("127.0.0.1:%d", freeport.GetOne(t))} c.P2P.PeerID = ptr(k.PeerID()) }) - pw := ocrcommon.NewSingletonPeerWrapper(keyStore, cfg.P2P(), cfg.OCR(), cfg.Database(), db, logger.TestLogger(t)) + pw := ocrcommon.NewSingletonPeerWrapper(keyStore, cfg.P2P(), cfg.OCR(), db, logger.TestLogger(t)) servicetest.Run(t, pw) require.Equal(t, k.PeerID(), pw.PeerID) @@ -66,7 +66,7 @@ func Test_SingletonPeerWrapper_Start(t *testing.T) { _, err := keyStore.P2P().Create(ctx) require.NoError(t, err) - pw := ocrcommon.NewSingletonPeerWrapper(keyStore, cfg.P2P(), cfg.OCR(), cfg.Database(), db, logger.TestLogger(t)) + pw := ocrcommon.NewSingletonPeerWrapper(keyStore, cfg.P2P(), cfg.OCR(), db, logger.TestLogger(t)) require.Contains(t, pw.Start(testutils.Context(t)).Error(), "unable to find P2P key with id") }) @@ -83,7 +83,7 @@ func Test_SingletonPeerWrapper_Start(t *testing.T) { c.P2P.PeerID = ptr(k2.PeerID()) }) - pw := ocrcommon.NewSingletonPeerWrapper(keyStore, cfg.P2P(), cfg.OCR(), cfg.Database(), db, logger.TestLogger(t)) + pw := ocrcommon.NewSingletonPeerWrapper(keyStore, cfg.P2P(), cfg.OCR(), db, logger.TestLogger(t)) servicetest.Run(t, pw) require.Equal(t, k2.PeerID(), pw.PeerID) @@ -101,7 +101,7 @@ func Test_SingletonPeerWrapper_Start(t *testing.T) { _, err := keyStore.P2P().Create(ctx) require.NoError(t, err) - pw := ocrcommon.NewSingletonPeerWrapper(keyStore, cfg.P2P(), cfg.OCR(), cfg.Database(), db, logger.TestLogger(t)) + pw := ocrcommon.NewSingletonPeerWrapper(keyStore, cfg.P2P(), cfg.OCR(), db, logger.TestLogger(t)) require.Contains(t, pw.Start(testutils.Context(t)).Error(), "unable to find P2P key with id") }) @@ -131,7 +131,7 @@ func Test_SingletonPeerWrapper_Close(t *testing.T) { }) - pw := ocrcommon.NewSingletonPeerWrapper(keyStore, cfg.P2P(), cfg.OCR(), cfg.Database(), db, logger.TestLogger(t)) + pw := ocrcommon.NewSingletonPeerWrapper(keyStore, cfg.P2P(), cfg.OCR(), db, logger.TestLogger(t)) require.NoError(t, pw.Start(testutils.Context(t))) require.True(t, pw.IsStarted(), "Should have started successfully") @@ -139,7 +139,7 @@ func Test_SingletonPeerWrapper_Close(t *testing.T) { /* If peer is still stuck in listenLoop, we will get a bind error trying to start on the same port */ require.False(t, pw.IsStarted()) - pw = ocrcommon.NewSingletonPeerWrapper(keyStore, cfg.P2P(), cfg.OCR(), cfg.Database(), db, logger.TestLogger(t)) + pw = ocrcommon.NewSingletonPeerWrapper(keyStore, cfg.P2P(), cfg.OCR(), db, logger.TestLogger(t)) require.NoError(t, pw.Start(testutils.Context(t)), "Should have shut down gracefully, and be able to re-use same port") require.True(t, pw.IsStarted(), "Should have started successfully") require.NoError(t, pw.Close()) diff --git a/core/services/pg/connection.go b/core/services/pg/connection.go index 79d74c6e610..e8b6f3af429 100644 --- a/core/services/pg/connection.go +++ b/core/services/pg/connection.go @@ -1,6 +1,7 @@ package pg import ( + "database/sql" "fmt" "log" "os" @@ -18,6 +19,10 @@ import ( "github.com/XSAM/otelsql" ) +// NOTE: This is the default level in Postgres anyway, we just make it +// explicit here +const defaultIsolation = sql.LevelReadCommitted + var MinRequiredPGVersion = 110000 func init() { diff --git a/core/services/pg/lease_lock.go b/core/services/pg/lease_lock.go index 58ec2781245..885115d26c7 100644 --- a/core/services/pg/lease_lock.go +++ b/core/services/pg/lease_lock.go @@ -12,6 +12,7 @@ import ( "github.com/pkg/errors" "go.uber.org/multierr" + "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" "github.com/smartcontractkit/chainlink-common/pkg/utils" "github.com/smartcontractkit/chainlink/v2/core/logger" ) @@ -248,10 +249,12 @@ func (l *leaseLock) getLease(ctx context.Context, isInitial bool) (gotLease bool // NOTE: Uses database time for all calculations since it's conceivable // that node local times might be skewed compared to each other - err = sqlxTransactionQ(ctx, l.conn, l.logger, func(tx Queryer) error { + err = sqlutil.TransactConn(ctx, func(ds sqlutil.DataSource) sqlutil.DataSource { + return ds + }, l.conn, nil, func(tx sqlutil.DataSource) error { if isInitial { for _, query := range initialSQL { - if _, err = tx.Exec(query); err != nil { + if _, err = tx.ExecContext(ctx, query); err != nil { return errors.Wrap(err, "failed to create initial lease_lock table") } } @@ -259,7 +262,7 @@ func (l *leaseLock) getLease(ctx context.Context, isInitial bool) (gotLease bool // Upsert the lease_lock, only overwriting an existing one if the existing one has expired var res sql.Result - res, err = tx.Exec(` + res, err = tx.ExecContext(ctx, ` INSERT INTO lease_lock (client_id, expires_at) VALUES ($1, NOW()+$2::interval) ON CONFLICT ((client_id IS NOT NULL)) DO UPDATE SET client_id = EXCLUDED.client_id, expires_at = EXCLUDED.expires_at diff --git a/core/services/pg/q.go b/core/services/pg/q.go deleted file mode 100644 index 433023ddbc9..00000000000 --- a/core/services/pg/q.go +++ /dev/null @@ -1,384 +0,0 @@ -package pg - -import ( - "context" - "database/sql" - "fmt" - "strconv" - "strings" - "sync" - "time" - - "github.com/ethereum/go-ethereum/common" - "github.com/jmoiron/sqlx" - "github.com/lib/pq" - "github.com/pkg/errors" - - "github.com/smartcontractkit/chainlink-common/pkg/logger" - "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" -) - -// QOpt is deprecated. Use [sqlutil.DataSource] with [sqlutil.QueryHook]s instead. -// -// QOpt pattern for ORM methods aims to clarify usage and remove some common footguns, notably: -// -// 1. It should be easy and obvious how to pass a parent context or a transaction into an ORM method -// 2. Simple queries should not be cluttered -// 3. It should have compile-time safety and be explicit -// 4. It should enforce some sort of context deadline on all queries by default -// 5. It should optimise for clarity and readability -// 6. It should mandate using sqlx everywhere, gorm is forbidden in new code -// 7. It should make using sqlx a little more convenient by wrapping certain methods -// 8. It allows easier mocking of DB calls (Queryer is an interface) -// -// The two main concepts introduced are: -// -// A `Q` struct that wraps a `sqlx.DB` or `sqlx.Tx` and implements the `pg.Queryer` interface. -// -// This struct is initialised with `QOpts` which define how the queryer should behave. `QOpts` can define a parent context, an open transaction or other options to configure the Queryer. -// -// A sample ORM method looks like this: -// -// func (o *orm) GetFoo(id int64, qopts ...pg.QOpt) (Foo, error) { -// q := pg.NewQ(q, qopts...) -// return q.Exec(...) -// } -// -// Now you can call it like so: -// -// orm.GetFoo(1) // will automatically have default query timeout context set -// orm.GetFoo(1, pg.WithParentCtx(ctx)) // will wrap the supplied parent context with the default query context -// orm.GetFoo(1, pg.WithQueryer(tx)) // allows to pass in a running transaction or anything else that implements Queryer -// orm.GetFoo(q, pg.WithQueryer(tx), pg.WithParentCtx(ctx)) // options can be combined -type QOpt func(*Q) - -// WithQueryer sets the queryer -func WithQueryer(queryer Queryer) QOpt { - return func(q *Q) { - if q.Queryer != nil { - panic("queryer already set") - } - q.Queryer = queryer - } -} - -// WithParentCtx sets or overwrites the parent ctx -func WithParentCtx(ctx context.Context) QOpt { - return func(q *Q) { - q.ParentCtx = ctx - } -} - -// If the parent has a timeout, just use that instead of DefaultTimeout -func WithParentCtxInheritTimeout(ctx context.Context) QOpt { - return func(q *Q) { - q.ParentCtx = ctx - deadline, ok := q.ParentCtx.Deadline() - if ok { - q.QueryTimeout = time.Until(deadline) - } - } -} - -// WithLongQueryTimeout prevents the usage of the `DefaultQueryTimeout` duration and uses `OneMinuteQueryTimeout` instead -// Some queries need to take longer when operating over big chunks of data, like deleting jobs, but we need to keep some upper bound timeout -func WithLongQueryTimeout() QOpt { - return func(q *Q) { - q.QueryTimeout = longQueryTimeout - } -} - -var _ Queryer = Q{} - -type QConfig interface { - LogSQL() bool - DefaultQueryTimeout() time.Duration -} - -// Q wraps an underlying queryer (either a *sqlx.DB or a *sqlx.Tx) -// -// It is designed to make handling *sqlx.Tx or *sqlx.DB a little bit safer by -// preventing footguns such as having no deadline on contexts. -// -// It also handles nesting transactions. -// -// It automatically adds the default context deadline to all non-context -// queries (if you _really_ want to issue a query without a context, use the -// underlying Queryer) -// -// This is not the prettiest construct but without macros its about the best we -// can do. -// Deprecated: Use a `sqlutil.DataSource` with `sqlutil.QueryHook`s instead -type Q struct { - Queryer - ParentCtx context.Context - db *sqlx.DB - logger logger.SugaredLogger - config QConfig - QueryTimeout time.Duration -} - -func NewQ(db *sqlx.DB, lggr logger.Logger, config QConfig, qopts ...QOpt) (q Q) { - for _, opt := range qopts { - opt(&q) - } - - q.db = db - // skip two levels since we use internal helpers and also want to point up the stack to the caller of the Q method. - q.logger = logger.Sugared(logger.Helper(lggr, 2)) - q.config = config - - if q.Queryer == nil { - q.Queryer = db - } - if q.ParentCtx == nil { - q.ParentCtx = context.Background() - } - if q.QueryTimeout <= 0 { - q.QueryTimeout = q.config.DefaultQueryTimeout() - } - return -} - -func (q Q) originalLogger() logger.Logger { - return logger.Helper(q.logger, -2) -} - -func PrepareQueryRowx(q Queryer, sql string, dest interface{}, arg interface{}) error { - stmt, err := q.PrepareNamed(sql) - if err != nil { - return errors.Wrap(err, "error preparing named statement") - } - defer stmt.Close() - return errors.Wrap(stmt.QueryRowx(arg).Scan(dest), "error querying row") -} - -func (q Q) WithOpts(qopts ...QOpt) Q { - return NewQ(q.db, q.originalLogger(), q.config, qopts...) -} - -func (q Q) Context() (context.Context, context.CancelFunc) { - return context.WithTimeout(q.ParentCtx, q.QueryTimeout) -} - -func (q Q) Transaction(fc func(q Queryer) error, txOpts ...TxOption) error { - ctx, cancel := q.Context() - defer cancel() - return SqlxTransaction(ctx, q.Queryer, q.originalLogger(), fc, txOpts...) -} - -// CAUTION: A subtle problem lurks here, because the following code is buggy: -// -// ctx, cancel := context.WithCancel(context.Background()) -// rows, err := db.QueryContext(ctx, "SELECT foo") -// cancel() // canceling here "poisons" the scan below -// for rows.Next() { -// rows.Scan(...) -// } -// -// We must cancel the context only after we have completely finished using the -// returned rows or result from the query/exec -// -// For this reasons, the following functions return a context.CancelFunc and it -// is up to the caller to ensure that cancel is called after it has finished -// -// Generally speaking, it makes more sense to use Get/Select in most cases, -// which avoids this problem -func (q Q) ExecQIter(query string, args ...interface{}) (sql.Result, context.CancelFunc, error) { - ctx, cancel := q.Context() - - ql := q.newQueryLogger(query, args) - ql.logSqlQuery() - defer ql.postSqlLog(ctx, time.Now()) - - res, err := q.Queryer.ExecContext(ctx, query, args...) - return res, cancel, ql.withLogError(err) -} -func (q Q) ExecQWithRowsAffected(query string, args ...interface{}) (int64, error) { - res, cancel, err := q.ExecQIter(query, args...) - defer cancel() - if err != nil { - return 0, err - } - - rowsDeleted, err := res.RowsAffected() - return rowsDeleted, err -} -func (q Q) ExecQ(query string, args ...interface{}) error { - ctx, cancel := q.Context() - defer cancel() - - ql := q.newQueryLogger(query, args) - ql.logSqlQuery() - defer ql.postSqlLog(ctx, time.Now()) - - _, err := q.Queryer.ExecContext(ctx, query, args...) - return ql.withLogError(err) -} -func (q Q) ExecQNamed(query string, arg interface{}) (err error) { - query, args, err := q.BindNamed(query, arg) - if err != nil { - return errors.Wrap(err, "error binding arg") - } - ctx, cancel := q.Context() - defer cancel() - - ql := q.newQueryLogger(query, args) - ql.logSqlQuery() - defer ql.postSqlLog(ctx, time.Now()) - - _, err = q.Queryer.ExecContext(ctx, query, args...) - return ql.withLogError(err) -} - -// Select and Get are safe to wrap the context cancellation because the rows -// are entirely consumed within the call -func (q Q) Select(dest interface{}, query string, args ...interface{}) error { - ctx, cancel := q.Context() - defer cancel() - - ql := q.newQueryLogger(query, args) - ql.logSqlQuery() - defer ql.postSqlLog(ctx, time.Now()) - - return ql.withLogError(q.Queryer.SelectContext(ctx, dest, query, args...)) -} - -func (q Q) SelectNamed(dest interface{}, query string, arg interface{}) error { - query, args, err := q.BindNamed(query, arg) - if err != nil { - return errors.Wrap(err, "error binding arg") - } - return q.Select(dest, query, args...) -} - -func (q Q) Get(dest interface{}, query string, args ...interface{}) error { - ctx, cancel := q.Context() - defer cancel() - - ql := q.newQueryLogger(query, args) - ql.logSqlQuery() - defer ql.postSqlLog(ctx, time.Now()) - - return ql.withLogError(q.Queryer.GetContext(ctx, dest, query, args...)) -} - -func (q Q) GetNamed(sql string, dest interface{}, arg interface{}) error { - query, args, err := q.BindNamed(sql, arg) - if err != nil { - return errors.Wrap(err, "error binding arg") - } - ctx, cancel := q.Context() - defer cancel() - - ql := q.newQueryLogger(query, args) - ql.logSqlQuery() - defer ql.postSqlLog(ctx, time.Now()) - - return ql.withLogError(errors.Wrap(q.GetContext(ctx, dest, query, args...), "error in get query")) -} - -func (q Q) newQueryLogger(query string, args []interface{}) *queryLogger { - return &queryLogger{Q: q, query: query, args: args, str: sync.OnceValue(func() string { - return sprintQ(query, args) - })} -} - -// sprintQ formats the query with the given args and returns the resulting string. -func sprintQ(query string, args []interface{}) string { - if args == nil { - return query - } - var pairs []string - for i, arg := range args { - // We print by type so one can directly take the logged query string and execute it manually in pg. - // Annoyingly it seems as though the logger itself will add an extra \, so you still have to remove that. - switch v := arg.(type) { - case []byte: - pairs = append(pairs, fmt.Sprintf("$%d", i+1), fmt.Sprintf("'\\x%x'", v)) - case common.Address: - pairs = append(pairs, fmt.Sprintf("$%d", i+1), fmt.Sprintf("'\\x%x'", v.Bytes())) - case common.Hash: - pairs = append(pairs, fmt.Sprintf("$%d", i+1), fmt.Sprintf("'\\x%x'", v.Bytes())) - case pq.ByteaArray: - pairs = append(pairs, fmt.Sprintf("$%d", i+1)) - if v == nil { - pairs = append(pairs, "NULL") - continue - } - if len(v) == 0 { - pairs = append(pairs, "ARRAY[]") - continue - } - var s strings.Builder - fmt.Fprintf(&s, "ARRAY['\\x%x'", v[0]) - for j := 1; j < len(v); j++ { - fmt.Fprintf(&s, ",'\\x%x'", v[j]) - } - pairs = append(pairs, fmt.Sprintf("%s]", s.String())) - case string: - pairs = append(pairs, fmt.Sprintf("$%d", i+1), fmt.Sprintf("'%s'", v)) - default: - pairs = append(pairs, fmt.Sprintf("$%d", i+1), fmt.Sprintf("%v", v)) - } - } - replacer := strings.NewReplacer(pairs...) - queryWithVals := replacer.Replace(query) - return strings.ReplaceAll(strings.ReplaceAll(queryWithVals, "\n", " "), "\t", " ") -} - -// queryLogger extends Q with logging helpers for a particular query w/ args. -type queryLogger struct { - Q - - query string - args []interface{} - - str func() string -} - -func (q *queryLogger) String() string { - return q.str() -} - -func (q *queryLogger) logSqlQuery() { - if q.config != nil && q.config.LogSQL() { - q.logger.Debugw("SQL QUERY", "sql", q) - } -} - -func (q *queryLogger) withLogError(err error) error { - if err != nil && !errors.Is(err, sql.ErrNoRows) && q.config != nil && q.config.LogSQL() { - q.logger.Errorw("SQL ERROR", "err", err, "sql", q) - } - return err -} - -// postSqlLog logs about context cancellation and timing after a query returns. -// Queries which use their full timeout log critical level. More than 50% log error, and 10% warn. -func (q *queryLogger) postSqlLog(ctx context.Context, begin time.Time) { - elapsed := time.Since(begin) - if ctx.Err() != nil { - q.logger.Debugw("SQL CONTEXT CANCELLED", "ms", elapsed.Milliseconds(), "err", ctx.Err(), "sql", q) - } - - timeout := q.QueryTimeout - if timeout <= 0 { - timeout = DefaultQueryTimeout - } - - pct := float64(elapsed) / float64(timeout) - pct *= 100 - - kvs := []any{"ms", elapsed.Milliseconds(), "timeout", timeout.Milliseconds(), "percent", strconv.FormatFloat(pct, 'f', 1, 64), "sql", q} - - if elapsed >= timeout { - q.logger.Criticalw("SLOW SQL QUERY", kvs...) - } else if errThreshold := timeout / 5; errThreshold > 0 && elapsed > errThreshold { - q.logger.Errorw("SLOW SQL QUERY", kvs...) - } else if warnThreshold := timeout / 10; warnThreshold > 0 && elapsed > warnThreshold { - q.logger.Warnw("SLOW SQL QUERY", kvs...) - } - - sqlutil.PromSQLQueryTime.Observe(pct) -} diff --git a/core/services/pg/q_test.go b/core/services/pg/q_test.go deleted file mode 100644 index 81a883789df..00000000000 --- a/core/services/pg/q_test.go +++ /dev/null @@ -1,85 +0,0 @@ -package pg - -import ( - "testing" - - "github.com/google/uuid" - "github.com/jmoiron/sqlx" - "github.com/lib/pq" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "github.com/smartcontractkit/chainlink/v2/core/internal/testutils" - "github.com/smartcontractkit/chainlink/v2/core/logger" - "github.com/smartcontractkit/chainlink/v2/core/store/dialects" -) - -func Test_sprintQ(t *testing.T) { - for _, tt := range []struct { - name string - query string - args []interface{} - exp string - }{ - {"none", - "SELECT * FROM table;", - nil, - "SELECT * FROM table;"}, - {"one", - "SELECT $1 FROM table;", - []interface{}{"foo"}, - "SELECT 'foo' FROM table;"}, - {"two", - "SELECT $1 FROM table WHERE bar = $2;", - []interface{}{"foo", 1}, - "SELECT 'foo' FROM table WHERE bar = 1;"}, - {"limit", - "SELECT $1 FROM table LIMIT $2;", - []interface{}{"foo", Limit(10)}, - "SELECT 'foo' FROM table LIMIT 10;"}, - {"limit-all", - "SELECT $1 FROM table LIMIT $2;", - []interface{}{"foo", Limit(-1)}, - "SELECT 'foo' FROM table LIMIT NULL;"}, - {"bytea", - "SELECT $1 FROM table WHERE b = $2;", - []interface{}{"foo", []byte{0x0a}}, - "SELECT 'foo' FROM table WHERE b = '\\x0a';"}, - {"bytea[]", - "SELECT $1 FROM table WHERE b = $2;", - []interface{}{"foo", pq.ByteaArray([][]byte{{0xa}, {0xb}})}, - "SELECT 'foo' FROM table WHERE b = ARRAY['\\x0a','\\x0b'];"}, - } { - t.Run(tt.name, func(t *testing.T) { - got := sprintQ(tt.query, tt.args) - t.Log(tt.query, tt.args) - t.Log(got) - require.Equal(t, tt.exp, got) - }) - } -} - -func Test_ExecQWithRowsAffected(t *testing.T) { - testutils.SkipShortDB(t) - db, err := sqlx.Open(string(dialects.TransactionWrappedPostgres), uuid.New().String()) - require.NoError(t, err) - q := NewQ(db, logger.NullLogger, NewQConfig(false)) - - require.NoError(t, q.ExecQ("CREATE TABLE testtable (a TEXT, b TEXT)")) - - rows, err := q.ExecQWithRowsAffected("INSERT INTO testtable (a, b) VALUES ($1, $2)", "foo", "bar") - require.NoError(t, err) - assert.Equal(t, int64(1), rows) - - rows, err = q.ExecQWithRowsAffected("INSERT INTO testtable (a, b) VALUES ($1, $1), ($2, $2), ($1, $2)", "foo", "bar") - require.NoError(t, err) - assert.Equal(t, int64(3), rows) - - rows, err = q.ExecQWithRowsAffected("delete from testtable") - require.NoError(t, err) - assert.Equal(t, int64(4), rows) - - rows, err = q.ExecQWithRowsAffected("delete from testtable") - require.NoError(t, err) - assert.Equal(t, int64(0), rows) -} diff --git a/core/services/pg/sqlx.go b/core/services/pg/sqlx.go index 76eae792cbf..9c99142e5c9 100644 --- a/core/services/pg/sqlx.go +++ b/core/services/pg/sqlx.go @@ -4,12 +4,8 @@ import ( "context" "database/sql" - "github.com/pkg/errors" - mapper "github.com/scylladb/go-reflectx" - "github.com/jmoiron/sqlx" - - "github.com/smartcontractkit/chainlink-common/pkg/logger" + "github.com/scylladb/go-reflectx" ) // Queryer is deprecated. Use sqlutil.DataSource instead @@ -32,22 +28,6 @@ type Queryer interface { func WrapDbWithSqlx(rdb *sql.DB) *sqlx.DB { db := sqlx.NewDb(rdb, "postgres") - db.MapperFunc(mapper.CamelToSnakeASCII) + db.MapperFunc(reflectx.CamelToSnakeASCII) return db } - -func SqlxTransaction(ctx context.Context, q Queryer, lggr logger.Logger, fc func(q Queryer) error, txOpts ...TxOption) (err error) { - switch db := q.(type) { - case *sqlx.Tx: - // nested transaction: just use the outer transaction - err = fc(db) - case *sqlx.DB: - err = sqlxTransactionQ(ctx, db, lggr, fc, txOpts...) - case Q: - err = sqlxTransactionQ(ctx, db.db, lggr, fc, txOpts...) - default: - err = errors.Errorf("invalid db type: %T", q) - } - - return -} diff --git a/core/services/pg/transaction.go b/core/services/pg/transaction.go deleted file mode 100644 index d60270b4fe8..00000000000 --- a/core/services/pg/transaction.go +++ /dev/null @@ -1,95 +0,0 @@ -package pg - -import ( - "context" - "database/sql" - "fmt" - "time" - - "github.com/getsentry/sentry-go" - "github.com/jmoiron/sqlx" - "github.com/pkg/errors" - "go.uber.org/multierr" - - "github.com/smartcontractkit/chainlink-common/pkg/logger" - corelogger "github.com/smartcontractkit/chainlink/v2/core/logger" -) - -// NOTE: This is the default level in Postgres anyway, we just make it -// explicit here -const defaultIsolation = sql.LevelReadCommitted - -// TxOption is a functional option for SQL transactions. -type TxOption func(*sql.TxOptions) - -func OptReadOnlyTx() TxOption { - return func(opts *sql.TxOptions) { - opts.ReadOnly = true - } -} - -func SqlTransaction(ctx context.Context, rdb *sql.DB, lggr logger.Logger, fn func(tx *sqlx.Tx) error, opts ...TxOption) (err error) { - db := WrapDbWithSqlx(rdb) - wrapFn := func(q Queryer) error { - tx, ok := q.(*sqlx.Tx) - if !ok { - panic(fmt.Sprintf("expected q to be %T but got %T", tx, q)) - } - return fn(tx) - } - return sqlxTransactionQ(ctx, db, lggr, wrapFn, opts...) -} - -// txBeginner can be a db or a conn, anything that implements BeginTxx -type txBeginner interface { - BeginTxx(context.Context, *sql.TxOptions) (*sqlx.Tx, error) -} - -func sqlxTransactionQ(ctx context.Context, db txBeginner, lggr logger.Logger, fn func(q Queryer) error, opts ...TxOption) (err error) { - var txOpts sql.TxOptions - for _, o := range opts { - o(&txOpts) - } - - var tx *sqlx.Tx - tx, err = db.BeginTxx(ctx, &txOpts) - if err != nil { - return errors.Wrap(err, "failed to begin transaction") - } - - defer func() { - if p := recover(); p != nil { - sentry.CurrentHub().Recover(p) - sentry.Flush(corelogger.SentryFlushDeadline) - - // A panic occurred, rollback and repanic - lggr.Errorf("Panic in transaction, rolling back: %s", p) - done := make(chan struct{}) - go func() { - if rerr := tx.Rollback(); rerr != nil { - lggr.Errorf("Failed to rollback on panic: %s", rerr) - } - close(done) - }() - select { - case <-done: - panic(p) - case <-time.After(10 * time.Second): - panic(fmt.Sprintf("panic in transaction; aborting rollback that took longer than 10s: %s", p)) - } - } else if err != nil { - lggr.Warnf("Error in transaction, rolling back: %s", err) - // An error occurred, rollback and return error - if rerr := tx.Rollback(); rerr != nil { - err = multierr.Combine(err, errors.WithStack(rerr)) - } - } else { - // All good! Time to commit. - err = errors.WithStack(tx.Commit()) - } - }() - - err = fn(tx) - - return -} diff --git a/core/services/pg/utils.go b/core/services/pg/utils.go deleted file mode 100644 index eb53c261296..00000000000 --- a/core/services/pg/utils.go +++ /dev/null @@ -1,50 +0,0 @@ -package pg - -import ( - "database/sql/driver" - "strconv" - "time" -) - -const ( - // DefaultQueryTimeout is a reasonable upper bound for how long a SQL query should take. - // The configured value should be used instead of this if possible. - DefaultQueryTimeout = 10 * time.Second - // longQueryTimeout is a bigger upper bound for how long a SQL query should take - longQueryTimeout = 1 * time.Minute -) - -var _ driver.Valuer = Limit(-1) - -// Limit is a helper driver.Valuer for LIMIT queries which uses nil/NULL for negative values. -type Limit int - -func (l Limit) String() string { - if l < 0 { - return "NULL" - } - return strconv.Itoa(int(l)) -} - -func (l Limit) Value() (driver.Value, error) { - if l < 0 { - return nil, nil - } - return l, nil -} - -var _ QConfig = &qConfig{} - -// qConfig implements pg.QCOnfig -type qConfig struct { - logSQL bool - defaultQueryTimeout time.Duration -} - -func NewQConfig(logSQL bool) QConfig { - return &qConfig{logSQL, DefaultQueryTimeout} -} - -func (p *qConfig) LogSQL() bool { return p.logSQL } - -func (p *qConfig) DefaultQueryTimeout() time.Duration { return p.defaultQueryTimeout } diff --git a/core/services/pipeline/helpers_test.go b/core/services/pipeline/helpers_test.go index 0bbdef7a7f2..97d81f56f74 100644 --- a/core/services/pipeline/helpers_test.go +++ b/core/services/pipeline/helpers_test.go @@ -5,7 +5,6 @@ import ( "github.com/google/uuid" - "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" "github.com/smartcontractkit/chainlink/v2/core/bridges" "github.com/smartcontractkit/chainlink/v2/core/chains/legacyevm" ) @@ -65,4 +64,4 @@ func (t *ETHTxTask) HelperSetDependencies(legacyChains legacyevm.LegacyChainCont t.jobType = jobType } -func (o *orm) Prune(ds sqlutil.DataSource, pipelineSpecID int32) { o.prune(ds, pipelineSpecID) } +func (o *orm) Prune(pipelineSpecID int32) { o.prune(o.ds, pipelineSpecID) } diff --git a/core/services/pipeline/mocks/orm.go b/core/services/pipeline/mocks/orm.go index 2fa6d8681e8..9bd5ddfbdea 100644 --- a/core/services/pipeline/mocks/orm.go +++ b/core/services/pipeline/mocks/orm.go @@ -58,9 +58,9 @@ func (_m *ORM) CreateRun(ctx context.Context, run *pipeline.Run) error { return r0 } -// CreateSpec provides a mock function with given fields: ctx, ds, _a2, maxTaskTimeout -func (_m *ORM) CreateSpec(ctx context.Context, ds pipeline.CreateDataSource, _a2 pipeline.Pipeline, maxTaskTimeout models.Interval) (int32, error) { - ret := _m.Called(ctx, ds, _a2, maxTaskTimeout) +// CreateSpec provides a mock function with given fields: ctx, _a1, maxTaskTimeout +func (_m *ORM) CreateSpec(ctx context.Context, _a1 pipeline.Pipeline, maxTaskTimeout models.Interval) (int32, error) { + ret := _m.Called(ctx, _a1, maxTaskTimeout) if len(ret) == 0 { panic("no return value specified for CreateSpec") @@ -68,17 +68,17 @@ func (_m *ORM) CreateSpec(ctx context.Context, ds pipeline.CreateDataSource, _a2 var r0 int32 var r1 error - if rf, ok := ret.Get(0).(func(context.Context, pipeline.CreateDataSource, pipeline.Pipeline, models.Interval) (int32, error)); ok { - return rf(ctx, ds, _a2, maxTaskTimeout) + if rf, ok := ret.Get(0).(func(context.Context, pipeline.Pipeline, models.Interval) (int32, error)); ok { + return rf(ctx, _a1, maxTaskTimeout) } - if rf, ok := ret.Get(0).(func(context.Context, pipeline.CreateDataSource, pipeline.Pipeline, models.Interval) int32); ok { - r0 = rf(ctx, ds, _a2, maxTaskTimeout) + if rf, ok := ret.Get(0).(func(context.Context, pipeline.Pipeline, models.Interval) int32); ok { + r0 = rf(ctx, _a1, maxTaskTimeout) } else { r0 = ret.Get(0).(int32) } - if rf, ok := ret.Get(1).(func(context.Context, pipeline.CreateDataSource, pipeline.Pipeline, models.Interval) error); ok { - r1 = rf(ctx, ds, _a2, maxTaskTimeout) + if rf, ok := ret.Get(1).(func(context.Context, pipeline.Pipeline, models.Interval) error); ok { + r1 = rf(ctx, _a1, maxTaskTimeout) } else { r1 = ret.Error(1) } diff --git a/core/services/pipeline/orm.go b/core/services/pipeline/orm.go index 3bebfb8cbad..0a96a7e08d5 100644 --- a/core/services/pipeline/orm.go +++ b/core/services/pipeline/orm.go @@ -11,8 +11,6 @@ import ( "github.com/google/uuid" "github.com/pkg/errors" - "github.com/jmoiron/sqlx" - "github.com/smartcontractkit/chainlink-common/pkg/services" "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" @@ -81,8 +79,7 @@ type CreateDataSource interface { type ORM interface { services.Service - // ds is optional and to be removed after completing https://smartcontract-it.atlassian.net/browse/BCF-2978 - CreateSpec(ctx context.Context, ds CreateDataSource, pipeline Pipeline, maxTaskTimeout models.Interval) (int32, error) + CreateSpec(ctx context.Context, pipeline Pipeline, maxTaskTimeout models.Interval) (int32, error) CreateRun(ctx context.Context, run *Run) (err error) InsertRun(ctx context.Context, run *Run) error DeleteRun(ctx context.Context, id int64) error @@ -163,6 +160,9 @@ func (o *orm) Transact(ctx context.Context, fn func(ORM) error) error { return sqlutil.Transact(ctx, func(tx sqlutil.DataSource) ORM { return o.withDataSource(tx) }, o.ds, nil, func(tx ORM) error { + if err := tx.Start(ctx); err != nil { + return fmt.Errorf("failed to start tx orm: %w", err) + } defer func() { if err := tx.Close(); err != nil { o.lggr.Warnw("Error closing temporary transactional ORM", "err", err) @@ -191,14 +191,11 @@ func (o *orm) transact(ctx context.Context, fn func(*orm) error) error { return sqlutil.Transact(ctx, o.withDataSource, o.ds, nil, fn) } -func (o *orm) CreateSpec(ctx context.Context, ds CreateDataSource, pipeline Pipeline, maxTaskDuration models.Interval) (id int32, err error) { +func (o *orm) CreateSpec(ctx context.Context, pipeline Pipeline, maxTaskDuration models.Interval) (id int32, err error) { sql := `INSERT INTO pipeline_specs (dot_dag_source, max_task_duration, created_at) VALUES ($1, $2, NOW()) RETURNING id;` - if ds == nil { - ds = o.ds - } - err = ds.GetContext(ctx, &id, sql, pipeline.Source, maxTaskDuration) + err = o.ds.GetContext(ctx, &id, sql, pipeline.Source, maxTaskDuration) return id, errors.WithStack(err) } @@ -254,13 +251,13 @@ func (o *orm) StoreRun(ctx context.Context, run *Run) (restart bool, err error) // Lock the current run. This prevents races with /v2/resume sql := `SELECT id FROM pipeline_runs WHERE id = $1 FOR UPDATE;` if _, err = tx.ds.ExecContext(ctx, sql, run.ID); err != nil { - return errors.Wrap(err, "StoreRun") + return fmt.Errorf("failed to select pipeline run %d: %w", run.ID, err) } taskRuns := []TaskRun{} // Reload task runs, we want to check for any changes while the run was ongoing if err = tx.ds.SelectContext(ctx, &taskRuns, `SELECT * FROM pipeline_task_runs WHERE pipeline_run_id = $1`, run.ID); err != nil { - return errors.Wrap(err, "StoreRun") + return fmt.Errorf("failed to select piepline task run %d: %w", run.ID, err) } // Construct a temporary run so we can use r.ByDotID @@ -287,17 +284,17 @@ func (o *orm) StoreRun(ctx context.Context, run *Run) (restart bool, err error) // Suspend the run run.State = RunStatusSuspended if _, err = tx.ds.NamedExecContext(ctx, `UPDATE pipeline_runs SET state = :state WHERE id = :id`, run); err != nil { - return errors.Wrap(err, "StoreRun") + return fmt.Errorf("failed to update pipeline run %d to %s: %w", run.ID, run.State, err) } } else { defer o.prune(tx.ds, run.PruningKey) // Simply finish the run, no need to do any sort of locking if run.Outputs.Val == nil || len(run.FatalErrors)+len(run.AllErrors) == 0 { - return errors.Errorf("run must have both Outputs and Errors, got Outputs: %#v, FatalErrors: %#v, AllErrors: %#v", run.Outputs.Val, run.FatalErrors, run.AllErrors) + return fmt.Errorf("run must have both Outputs and Errors, got Outputs: %#v, FatalErrors: %#v, AllErrors: %#v", run.Outputs.Val, run.FatalErrors, run.AllErrors) } sql := `UPDATE pipeline_runs SET state = :state, finished_at = :finished_at, all_errors= :all_errors, fatal_errors= :fatal_errors, outputs = :outputs WHERE id = :id` if _, err = tx.ds.NamedExecContext(ctx, sql, run); err != nil { - return errors.Wrap(err, "StoreRun") + return fmt.Errorf("failed to update pipeline run %d: %w", run.ID, err) } } @@ -309,18 +306,15 @@ func (o *orm) StoreRun(ctx context.Context, run *Run) (restart bool, err error) RETURNING *; ` - // NOTE: can't use Select() to auto scan because we're using NamedQuery, - // sqlx.Named + Select is possible but it's about the same amount of code - var rows *sqlx.Rows - rows, err = sqlx.NamedQueryContext(ctx, tx.ds, sql, run.PipelineTaskRuns) + taskRuns := []TaskRun{} + query, args, err := tx.ds.BindNamed(sql, run.PipelineTaskRuns) if err != nil { - return errors.Wrap(err, "StoreRun") + return fmt.Errorf("failed to prepare named query: %w", err) } - taskRuns := []TaskRun{} - if err = sqlx.StructScan(rows, &taskRuns); err != nil { - return errors.Wrap(err, "StoreRun") + err = tx.ds.SelectContext(ctx, &taskRuns, query, args...) + if err != nil { + return fmt.Errorf("failed to insert pipeline task runs: %w", err) } - // replace with new task run data run.PipelineTaskRuns = taskRuns return nil }) @@ -383,19 +377,18 @@ VALUES (:pipeline_spec_id, :pruning_key, :meta, :all_errors, :fatal_errors, :inputs, :outputs, :created_at, :finished_at, :state) RETURNING id ` - rows, errQ := sqlx.NamedQueryContext(ctx, tx.ds, pipelineRunsQuery, runs) - if errQ != nil { - return errors.Wrap(errQ, "inserting finished pipeline runs") - } - defer rows.Close() var runIDs []int64 - for rows.Next() { + err := sqlutil.NamedQueryContext(ctx, tx.ds, pipelineRunsQuery, runs, func(row sqlutil.RowScanner) error { var runID int64 - if errS := rows.Scan(&runID); errS != nil { + if errS := row.Scan(&runID); errS != nil { return errors.Wrap(errS, "scanning pipeline runs id row") } runIDs = append(runIDs, runID) + return nil + }) + if err != nil { + return errors.Wrap(err, "inserting finished pipeline runs") } pruningKeysm := make(map[int32]struct{}) @@ -717,13 +710,13 @@ const syncLimit = 1000 // // Note this does not guarantee the pipeline_runs table is kept to exactly the // max length, rather that it doesn't excessively larger than it. -func (o *orm) prune(ds sqlutil.DataSource, jobID int32) { +func (o *orm) prune(tx sqlutil.DataSource, jobID int32) { if jobID == 0 { o.lggr.Panic("expected a non-zero job ID") } // For small maxSuccessfulRuns its fast enough to prune every time if o.maxSuccessfulRuns < syncLimit { - o.execPrune(o.ctx, ds, jobID) + o.withDataSource(tx).execPrune(o.ctx, jobID) return } // for large maxSuccessfulRuns we do it async on a sampled basis @@ -736,11 +729,11 @@ func (o *orm) prune(ds sqlutil.DataSource, jobID int32) { go func() { o.lggr.Debugw("Pruning runs", "jobID", jobID, "count", val, "every", every, "maxSuccessfulRuns", o.maxSuccessfulRuns) defer o.wg.Done() - // Must not use ds here since it's async and the transaction - // could be stale ctx, cancel := context.WithTimeout(sqlutil.WithoutDefaultTimeout(o.ctx), time.Minute) defer cancel() - o.execPrune(ctx, o.ds, jobID) + + // Must not use tx here since it could be stale by the time we execute async. + o.execPrune(ctx, jobID) }() }) if !ok { @@ -750,8 +743,8 @@ func (o *orm) prune(ds sqlutil.DataSource, jobID int32) { } } -func (o *orm) execPrune(ctx context.Context, ds sqlutil.DataSource, jobID int32) { - res, err := ds.ExecContext(o.ctx, `DELETE FROM pipeline_runs WHERE pruning_key = $1 AND state = $2 AND id NOT IN ( +func (o *orm) execPrune(ctx context.Context, jobID int32) { + res, err := o.ds.ExecContext(o.ctx, `DELETE FROM pipeline_runs WHERE pruning_key = $1 AND state = $2 AND id NOT IN ( SELECT id FROM pipeline_runs WHERE pruning_key = $1 AND state = $2 ORDER BY id DESC @@ -769,7 +762,7 @@ LIMIT $3 if rowsAffected == 0 { // check the spec still exists and garbage collect if necessary var exists bool - if err := ds.GetContext(ctx, &exists, `SELECT EXISTS(SELECT ps.* FROM pipeline_specs ps JOIN job_pipeline_specs jps ON (ps.id=jps.pipeline_spec_id) WHERE jps.job_id = $1)`, jobID); err != nil { + if err := o.ds.GetContext(ctx, &exists, `SELECT EXISTS(SELECT ps.* FROM pipeline_specs ps JOIN job_pipeline_specs jps ON (ps.id=jps.pipeline_spec_id) WHERE jps.job_id = $1)`, jobID); err != nil { o.lggr.Errorw("Failed check existence of pipeline_spec while pruning runs", "err", err, "jobID", jobID) return } diff --git a/core/services/pipeline/orm_test.go b/core/services/pipeline/orm_test.go index bba928534ba..6ff32e15cc7 100644 --- a/core/services/pipeline/orm_test.go +++ b/core/services/pipeline/orm_test.go @@ -1,6 +1,7 @@ package pipeline_test import ( + "context" "testing" "time" @@ -12,6 +13,7 @@ import ( "github.com/jmoiron/sqlx" + "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" "github.com/smartcontractkit/chainlink-common/pkg/utils/hex" "github.com/smartcontractkit/chainlink-common/pkg/utils/jsonserializable" @@ -25,41 +27,34 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/logger" "github.com/smartcontractkit/chainlink/v2/core/services/chainlink" "github.com/smartcontractkit/chainlink/v2/core/services/job" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" "github.com/smartcontractkit/chainlink/v2/core/services/pipeline" "github.com/smartcontractkit/chainlink/v2/core/store/models" ) -type ormconfig struct { - pg.QConfig -} - -func (ormconfig) JobPipelineMaxSuccessfulRuns() uint64 { return 123456 } - type testOnlyORM interface { pipeline.ORM - AddJobPipelineSpecWithoutConstraints(jobID, pipelineSpecID int32) error + AddJobPipelineSpecWithoutConstraints(ctx context.Context, jobID, pipelineSpecID int32) error } type testORM struct { pipeline.ORM - db *sqlx.DB + ds sqlutil.DataSource } -func (torm *testORM) AddJobPipelineSpecWithoutConstraints(jobID, pipelineSpecID int32) error { - _, err := torm.db.Exec(`SET CONSTRAINTS fk_job_pipeline_spec_job DEFERRED`) +func (torm *testORM) AddJobPipelineSpecWithoutConstraints(ctx context.Context, jobID, pipelineSpecID int32) error { + _, err := torm.ds.ExecContext(ctx, `SET CONSTRAINTS fk_job_pipeline_spec_job DEFERRED`) if err != nil { return err } - _, err = torm.db.Exec(`INSERT INTO job_pipeline_specs (job_id,pipeline_spec_id, is_primary) VALUES ($1, $2, false)`, jobID, pipelineSpecID) + _, err = torm.ds.ExecContext(ctx, `INSERT INTO job_pipeline_specs (job_id,pipeline_spec_id, is_primary) VALUES ($1, $2, false)`, jobID, pipelineSpecID) if err != nil { return err } return nil } -func newTestORM(orm pipeline.ORM, db *sqlx.DB) testOnlyORM { - return &testORM{ORM: orm, db: db} +func newTestORM(orm pipeline.ORM, ds sqlutil.DataSource) testOnlyORM { + return &testORM{ORM: orm, ds: ds} } func setupORM(t *testing.T, heavy bool) (db *sqlx.DB, orm pipeline.ORM, jorm job.ORM) { @@ -70,14 +65,12 @@ func setupORM(t *testing.T, heavy bool) (db *sqlx.DB, orm pipeline.ORM, jorm job } else { db = pgtest.NewSqlxDB(t) } - cfg := ormconfig{pgtest.NewQConfig(true)} - orm = pipeline.NewORM(db, logger.TestLogger(t), cfg.JobPipelineMaxSuccessfulRuns()) - config := configtest.NewTestGeneralConfig(t) + orm = pipeline.NewORM(db, logger.TestLogger(t), 123456) lggr := logger.TestLogger(t) keyStore := cltest.NewKeyStore(t, db) bridgeORM := bridges.NewORM(db) - jorm = job.NewORM(db, orm, bridgeORM, keyStore, lggr, config.Database()) + jorm = job.NewORM(db, orm, bridgeORM, keyStore, lggr) return } @@ -103,7 +96,7 @@ func Test_PipelineORM_CreateSpec(t *testing.T) { Source: source, } - id, err := orm.CreateSpec(ctx, nil, p, maxTaskDuration) + id, err := orm.CreateSpec(ctx, p, maxTaskDuration) require.NoError(t, err) actual := pipeline.Spec{} @@ -171,7 +164,7 @@ answer2 [type=bridge name=election_winner index=1]; DotDagSource: s, }, } - err := jobORM.CreateJob(&jb) + err := jobORM.CreateJob(ctx, &jb) require.NoError(t, err) run := &pipeline.Run{ @@ -274,7 +267,7 @@ answer2 [type=bridge name=election_winner index=1]; DotDagSource: s, }, } - err := jorm.CreateJob(&jb) + err := jorm.CreateJob(ctx, &jb) require.NoError(t, err) spec := pipeline.Spec{ DotDagSource: s, @@ -665,7 +658,7 @@ func Test_GetUnfinishedRuns_Keepers(t *testing.T) { porm := pipeline.NewORM(db, lggr, config.JobPipeline().MaxSuccessfulRuns()) bridgeORM := bridges.NewORM(db) - jorm := job.NewORM(db, porm, bridgeORM, keyStore, lggr, config.Database()) + jorm := job.NewORM(db, porm, bridgeORM, keyStore, lggr) defer func() { assert.NoError(t, jorm.Close()) }() timestamp := time.Now() @@ -689,7 +682,7 @@ func Test_GetUnfinishedRuns_Keepers(t *testing.T) { MaxTaskDuration: models.Interval(1 * time.Minute), } - err := jorm.CreateJob(&keeperJob) + err := jorm.CreateJob(ctx, &keeperJob) require.NoError(t, err) require.Equal(t, job.Keeper, keeperJob.Type) @@ -768,7 +761,7 @@ func Test_GetUnfinishedRuns_DirectRequest(t *testing.T) { porm := pipeline.NewORM(db, lggr, config.JobPipeline().MaxSuccessfulRuns()) bridgeORM := bridges.NewORM(db) - jorm := job.NewORM(db, porm, bridgeORM, keyStore, lggr, config.Database()) + jorm := job.NewORM(db, porm, bridgeORM, keyStore, lggr) defer func() { assert.NoError(t, jorm.Close()) }() timestamp := time.Now() @@ -791,7 +784,7 @@ func Test_GetUnfinishedRuns_DirectRequest(t *testing.T) { MaxTaskDuration: models.Interval(1 * time.Minute), } - err := jorm.CreateJob(&drJob) + err := jorm.CreateJob(ctx, &drJob) require.NoError(t, err) require.Equal(t, job.DirectRequest, drJob.Type) @@ -865,13 +858,13 @@ func Test_Prune(t *testing.T) { ps1 := cltest.MustInsertPipelineSpec(t, db) // We need a job_pipeline_specs entry to test the pruning mechanism - err := torm.AddJobPipelineSpecWithoutConstraints(ps1.ID, ps1.ID) + err := torm.AddJobPipelineSpecWithoutConstraints(testutils.Context(t), ps1.ID, ps1.ID) require.NoError(t, err) jobID := ps1.ID t.Run("when there are no runs to prune, does nothing", func(t *testing.T) { - porm.Prune(db, jobID) + porm.Prune(jobID) // no error logs; it did nothing assert.Empty(t, observed.All()) @@ -907,7 +900,7 @@ func Test_Prune(t *testing.T) { cltest.MustInsertPipelineRunWithStatus(t, db, ps2.ID, pipeline.RunStatusSuspended, jobID2) } - porm.Prune(db, jobID2) + porm.Prune(jobID2) cnt := pgtest.MustCount(t, db, "SELECT count(*) FROM pipeline_runs WHERE pipeline_spec_id = $1 AND state = $2", ps1.ID, pipeline.RunStatusCompleted) assert.Equal(t, cnt, 20) diff --git a/core/services/pipeline/runner.go b/core/services/pipeline/runner.go index 862d2f49178..2de27b3d008 100644 --- a/core/services/pipeline/runner.go +++ b/core/services/pipeline/runner.go @@ -661,7 +661,7 @@ func (r *runner) Run(ctx context.Context, run *Run, l logger.Logger, saveSuccess } if err = r.orm.InsertFinishedRun(ctx, run, saveSuccessfulTaskRuns); err != nil { - return false, pkgerrors.Wrapf(err, "error storing run for spec ID %v", run.PipelineSpec.ID) + return false, pkgerrors.Wrapf(err, "error inserting finished run for spec ID %v", run.PipelineSpec.ID) } } diff --git a/core/services/pipeline/runner_test.go b/core/services/pipeline/runner_test.go index e086d5297ef..44d7acadd27 100644 --- a/core/services/pipeline/runner_test.go +++ b/core/services/pipeline/runner_test.go @@ -33,7 +33,6 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/internal/testutils/pgtest" "github.com/smartcontractkit/chainlink/v2/core/logger" "github.com/smartcontractkit/chainlink/v2/core/services/chainlink" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" "github.com/smartcontractkit/chainlink/v2/core/services/pipeline" "github.com/smartcontractkit/chainlink/v2/core/services/pipeline/mocks" "github.com/smartcontractkit/chainlink/v2/core/utils" @@ -42,14 +41,10 @@ import ( ) func newRunner(t testing.TB, db *sqlx.DB, bridgeORM bridges.ORM, cfg chainlink.GeneralConfig) (pipeline.Runner, *mocks.ORM) { - lggr := logger.TestLogger(t) ethKeyStore := cltest.NewKeyStore(t, db).Eth() relayExtenders := evmtest.NewChainRelayExtenders(t, evmtest.TestChainOpts{DB: db, GeneralConfig: cfg, KeyStore: ethKeyStore}) legacyChains := evmrelay.NewLegacyChainsFromRelayerExtenders(relayExtenders) orm := mocks.NewORM(t) - q := pg.NewQ(db, lggr, cfg.Database()) - - orm.On("GetQ").Return(q).Maybe() c := clhttptest.NewTestLocalOnlyHTTPClient() r := pipeline.NewRunner(orm, bridgeORM, cfg.JobPipeline(), cfg.WebServer(), legacyChains, ethKeyStore, nil, logger.TestLogger(t), c, c) return r, orm @@ -250,8 +245,7 @@ func Test_PipelineRunner_ExecuteTaskRunsWithVars(t *testing.T) { "times": "1000000000000000000", }, }, - }, - cfg.Database()) + }) defer ds1.Close() btORM.On("FindBridge", mock.Anything, bridge.Name).Return(bridge, nil).Once() @@ -269,7 +263,7 @@ func Test_PipelineRunner_ExecuteTaskRunsWithVars(t *testing.T) { defer ds4.Close() // 3. Setup final bridge task - submit, submitBt := makeBridge(t, db, expectedRequestSubmit, map[string]interface{}{"ok": true}, cfg.Database()) + submit, submitBt := makeBridge(t, db, expectedRequestSubmit, map[string]interface{}{"ok": true}) defer submit.Close() btORM.On("FindBridge", mock.Anything, submitBt.Name).Return(submitBt, nil).Once() @@ -419,10 +413,7 @@ func Test_PipelineRunner_HandleFaults(t *testing.T) { // but a sufficient number of them still complete within the desired time frame // and so we can still obtain a median. db := pgtest.NewSqlxDB(t) - orm := mocks.NewORM(t) - q := pg.NewQ(db, logger.TestLogger(t), configtest.NewTestGeneralConfig(t).Database()) - orm.On("GetQ").Return(q).Maybe() m1 := httptest.NewServer(http.HandlerFunc(func(res http.ResponseWriter, req *http.Request) { time.Sleep(100 * time.Millisecond) res.WriteHeader(http.StatusOK) @@ -472,8 +463,7 @@ func Test_PipelineRunner_HandleFaultsPersistRun(t *testing.T) { db := pgtest.NewSqlxDB(t) orm := mocks.NewORM(t) btORM := bridgesMocks.NewORM(t) - q := pg.NewQ(db, logger.TestLogger(t), configtest.NewTestGeneralConfig(t).Database()) - orm.On("GetQ").Return(q).Maybe() + orm.On("InsertFinishedRun", mock.Anything, mock.Anything, mock.Anything). Run(func(args mock.Arguments) { args.Get(1).(*pipeline.Run).ID = 1 @@ -513,8 +503,7 @@ func Test_PipelineRunner_ExecuteAndInsertFinishedRun_SavingTheSpec(t *testing.T) db := pgtest.NewSqlxDB(t) orm := mocks.NewORM(t) btORM := bridgesMocks.NewORM(t) - q := pg.NewQ(db, logger.TestLogger(t), configtest.NewTestGeneralConfig(t).Database()) - orm.On("GetQ").Return(q).Maybe() + orm.On("InsertFinishedRunWithSpec", mock.Anything, mock.Anything, mock.Anything). Run(func(args mock.Arguments) { args.Get(1).(*pipeline.Run).ID = 1 diff --git a/core/services/pipeline/task.bridge_test.go b/core/services/pipeline/task.bridge_test.go index 029c6c78ca8..e95aef4984c 100644 --- a/core/services/pipeline/task.bridge_test.go +++ b/core/services/pipeline/task.bridge_test.go @@ -29,7 +29,6 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/internal/testutils/pgtest" "github.com/smartcontractkit/chainlink/v2/core/logger" "github.com/smartcontractkit/chainlink/v2/core/services/chainlink" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" "github.com/smartcontractkit/chainlink/v2/core/services/pipeline" "github.com/smartcontractkit/chainlink/v2/core/services/pipeline/internal/eautils" "github.com/smartcontractkit/chainlink/v2/core/store/models" @@ -217,7 +216,7 @@ func TestBridgeTask_Happy(t *testing.T) { } c := clhttptest.NewTestLocalOnlyHTTPClient() trORM := pipeline.NewORM(db, logger.TestLogger(t), cfg.JobPipeline().MaxSuccessfulRuns()) - specID, err := trORM.CreateSpec(testutils.Context(t), nil, pipeline.Pipeline{}, *models.NewInterval(5 * time.Minute)) + specID, err := trORM.CreateSpec(testutils.Context(t), pipeline.Pipeline{}, *models.NewInterval(5 * time.Minute)) require.NoError(t, err) task.HelperSetDependencies(cfg.JobPipeline(), cfg.WebServer(), orm, specID, uuid.UUID{}, c) @@ -259,7 +258,7 @@ func TestBridgeTask_HandlesIntermittentFailure(t *testing.T) { } c := clhttptest.NewTestLocalOnlyHTTPClient() trORM := pipeline.NewORM(db, logger.TestLogger(t), cfg.JobPipeline().MaxSuccessfulRuns()) - specID, err := trORM.CreateSpec(testutils.Context(t), nil, pipeline.Pipeline{}, *models.NewInterval(5 * time.Minute)) + specID, err := trORM.CreateSpec(testutils.Context(t), pipeline.Pipeline{}, *models.NewInterval(5 * time.Minute)) require.NoError(t, err) task.HelperSetDependencies(cfg.JobPipeline(), cfg.WebServer(), orm, specID, uuid.UUID{}, c) result, runInfo := task.Run(testutils.Context(t), logger.TestLogger(t), @@ -300,18 +299,19 @@ func TestBridgeTask_HandlesIntermittentFailure(t *testing.T) { func TestBridgeTask_DoesNotReturnStaleResults(t *testing.T) { t.Parallel() - db := pgtest.NewSqlxDB(t) + ctx := testutils.Context(t) cfg := configtest.NewGeneralConfig(t, func(c *chainlink.Config, s *chainlink.Secrets) { c.WebServer.BridgeCacheTTL = commonconfig.MustNewDuration(30 * time.Second) }) - queryer := pg.NewQ(db, logger.TestLogger(t), cfg.Database()) + s1 := httptest.NewServer(fakeIntermittentlyFailingPriceResponder(t, utils.MustUnmarshalToMap(btcUSDPairing), decimal.NewFromInt(9700), "", nil)) defer s1.Close() feedURL, err := url.ParseRequestURI(s1.URL) require.NoError(t, err) + db := pgtest.NewSqlxDB(t) orm := bridges.NewORM(db) _, bridge := cltest.MustCreateBridge(t, db, cltest.BridgeOpts{URL: feedURL.String()}) @@ -322,12 +322,12 @@ func TestBridgeTask_DoesNotReturnStaleResults(t *testing.T) { } c := clhttptest.NewTestLocalOnlyHTTPClient() trORM := pipeline.NewORM(db, logger.TestLogger(t), cfg.JobPipeline().MaxSuccessfulRuns()) - specID, err := trORM.CreateSpec(testutils.Context(t), nil, pipeline.Pipeline{}, *models.NewInterval(5 * time.Minute)) + specID, err := trORM.CreateSpec(testutils.Context(t), pipeline.Pipeline{}, *models.NewInterval(5 * time.Minute)) require.NoError(t, err) task.HelperSetDependencies(cfg.JobPipeline(), cfg.WebServer(), orm, specID, uuid.UUID{}, c) // Insert entry 1m in the past, stale value, should not be used in case of EA failure. - err = queryer.ExecQ(`INSERT INTO bridge_last_value(dot_id, spec_id, value, finished_at) + _, err = db.ExecContext(ctx, `INSERT INTO bridge_last_value(dot_id, spec_id, value, finished_at) VALUES($1, $2, $3, $4) ON CONFLICT ON CONSTRAINT bridge_last_value_pkey DO UPDATE SET value = $3, finished_at = $4;`, task.DotID(), specID, big.NewInt(9700).Bytes(), time.Now().Add(-1*time.Minute)) require.NoError(t, err) @@ -348,7 +348,7 @@ func TestBridgeTask_DoesNotReturnStaleResults(t *testing.T) { require.Nil(t, result2.Value) // Insert entry 10s in the past, under 30 seconds and should be used in case of failure. - err = queryer.ExecQ(`INSERT INTO bridge_last_value(dot_id, spec_id, value, finished_at) + _, err = db.ExecContext(ctx, `INSERT INTO bridge_last_value(dot_id, spec_id, value, finished_at) VALUES($1, $2, $3, $4) ON CONFLICT ON CONSTRAINT bridge_last_value_pkey DO UPDATE SET value = $3, finished_at = $4;`, task.DotID(), specID, big.NewInt(9700).Bytes(), time.Now().Add(-10*time.Second)) require.NoError(t, err) @@ -398,7 +398,7 @@ func TestBridgeTask_DoesNotReturnStaleResults(t *testing.T) { task2.HelperSetDependencies(cfg2.JobPipeline(), cfg2.WebServer(), orm, specID, uuid.UUID{}, c) // Insert entry 32m in the past, under cacheTTL of 35m but more than stalenessCap of 30m. - err = queryer.ExecQ(`INSERT INTO bridge_last_value(dot_id, spec_id, value, finished_at) + _, err = db.ExecContext(ctx, `INSERT INTO bridge_last_value(dot_id, spec_id, value, finished_at) VALUES($1, $2, $3, $4) ON CONFLICT ON CONSTRAINT bridge_last_value_pkey DO UPDATE SET value = $3, finished_at = $4;`, task2.DotID(), specID, big.NewInt(9700).Bytes(), time.Now().Add(-32*time.Minute)) require.NoError(t, err) @@ -420,7 +420,7 @@ func TestBridgeTask_DoesNotReturnStaleResults(t *testing.T) { require.Nil(t, result2.Value) // Insert entry 25m in the past, under stalenessCap - err = queryer.ExecQ(`INSERT INTO bridge_last_value(dot_id, spec_id, value, finished_at) + _, err = db.ExecContext(ctx, `INSERT INTO bridge_last_value(dot_id, spec_id, value, finished_at) VALUES($1, $2, $3, $4) ON CONFLICT ON CONSTRAINT bridge_last_value_pkey DO UPDATE SET value = $3, finished_at = $4;`, task2.DotID(), specID, big.NewInt(9700).Bytes(), time.Now().Add(-25*time.Minute)) require.NoError(t, err) @@ -482,7 +482,7 @@ func TestBridgeTask_AsyncJobPendingState(t *testing.T) { } c := clhttptest.NewTestLocalOnlyHTTPClient() trORM := pipeline.NewORM(db, logger.TestLogger(t), cfg.JobPipeline().MaxSuccessfulRuns()) - specID, err := trORM.CreateSpec(testutils.Context(t), nil, pipeline.Pipeline{}, *models.NewInterval(5 * time.Minute)) + specID, err := trORM.CreateSpec(testutils.Context(t), pipeline.Pipeline{}, *models.NewInterval(5 * time.Minute)) require.NoError(t, err) task.HelperSetDependencies(cfg.JobPipeline(), cfg.WebServer(), orm, specID, id, c) @@ -660,7 +660,7 @@ func TestBridgeTask_Variables(t *testing.T) { } c := clhttptest.NewTestLocalOnlyHTTPClient() trORM := pipeline.NewORM(db, logger.TestLogger(t), cfg.JobPipeline().MaxSuccessfulRuns()) - specID, err := trORM.CreateSpec(testutils.Context(t), nil, pipeline.Pipeline{}, *models.NewInterval(5 * time.Minute)) + specID, err := trORM.CreateSpec(testutils.Context(t), pipeline.Pipeline{}, *models.NewInterval(5 * time.Minute)) require.NoError(t, err) task.HelperSetDependencies(cfg.JobPipeline(), cfg.WebServer(), orm, specID, uuid.UUID{}, c) @@ -729,7 +729,7 @@ func TestBridgeTask_Meta(t *testing.T) { } c := clhttptest.NewTestLocalOnlyHTTPClient() trORM := pipeline.NewORM(db, logger.TestLogger(t), cfg.JobPipeline().MaxSuccessfulRuns()) - specID, err := trORM.CreateSpec(testutils.Context(t), nil, pipeline.Pipeline{}, *models.NewInterval(5 * time.Minute)) + specID, err := trORM.CreateSpec(testutils.Context(t), pipeline.Pipeline{}, *models.NewInterval(5 * time.Minute)) require.NoError(t, err) task.HelperSetDependencies(cfg.JobPipeline(), cfg.WebServer(), orm, specID, uuid.UUID{}, c) @@ -783,7 +783,7 @@ func TestBridgeTask_IncludeInputAtKey(t *testing.T) { } c := clhttptest.NewTestLocalOnlyHTTPClient() trORM := pipeline.NewORM(db, logger.TestLogger(t), cfg.JobPipeline().MaxSuccessfulRuns()) - specID, err := trORM.CreateSpec(testutils.Context(t), nil, pipeline.Pipeline{}, *models.NewInterval(5 * time.Minute)) + specID, err := trORM.CreateSpec(testutils.Context(t), pipeline.Pipeline{}, *models.NewInterval(5 * time.Minute)) require.NoError(t, err) task.HelperSetDependencies(cfg.JobPipeline(), cfg.WebServer(), orm, specID, uuid.UUID{}, c) @@ -839,7 +839,7 @@ func TestBridgeTask_ErrorMessage(t *testing.T) { } c := clhttptest.NewTestLocalOnlyHTTPClient() trORM := pipeline.NewORM(db, logger.TestLogger(t), cfg.JobPipeline().MaxSuccessfulRuns()) - specID, err := trORM.CreateSpec(testutils.Context(t), nil, pipeline.Pipeline{}, *models.NewInterval(5 * time.Minute)) + specID, err := trORM.CreateSpec(testutils.Context(t), pipeline.Pipeline{}, *models.NewInterval(5 * time.Minute)) require.NoError(t, err) task.HelperSetDependencies(cfg.JobPipeline(), cfg.WebServer(), orm, specID, uuid.UUID{}, c) @@ -878,7 +878,7 @@ func TestBridgeTask_OnlyErrorMessage(t *testing.T) { } c := clhttptest.NewTestLocalOnlyHTTPClient() trORM := pipeline.NewORM(db, logger.TestLogger(t), cfg.JobPipeline().MaxSuccessfulRuns()) - specID, err := trORM.CreateSpec(testutils.Context(t), nil, pipeline.Pipeline{}, *models.NewInterval(5 * time.Minute)) + specID, err := trORM.CreateSpec(testutils.Context(t), pipeline.Pipeline{}, *models.NewInterval(5 * time.Minute)) require.NoError(t, err) task.HelperSetDependencies(cfg.JobPipeline(), cfg.WebServer(), orm, specID, uuid.UUID{}, c) @@ -903,7 +903,7 @@ func TestBridgeTask_ErrorIfBridgeMissing(t *testing.T) { c := clhttptest.NewTestLocalOnlyHTTPClient() orm := bridges.NewORM(db) trORM := pipeline.NewORM(db, logger.TestLogger(t), cfg.JobPipeline().MaxSuccessfulRuns()) - specID, err := trORM.CreateSpec(testutils.Context(t), nil, pipeline.Pipeline{}, *models.NewInterval(5 * time.Minute)) + specID, err := trORM.CreateSpec(testutils.Context(t), pipeline.Pipeline{}, *models.NewInterval(5 * time.Minute)) require.NoError(t, err) task.HelperSetDependencies(cfg.JobPipeline(), cfg.WebServer(), orm, specID, uuid.UUID{}, c) @@ -993,7 +993,7 @@ func TestBridgeTask_Headers(t *testing.T) { c := clhttptest.NewTestLocalOnlyHTTPClient() trORM := pipeline.NewORM(db, logger.TestLogger(t), cfg.JobPipeline().MaxSuccessfulRuns()) - specID, err := trORM.CreateSpec(testutils.Context(t), nil, pipeline.Pipeline{}, *models.NewInterval(5 * time.Minute)) + specID, err := trORM.CreateSpec(testutils.Context(t), pipeline.Pipeline{}, *models.NewInterval(5 * time.Minute)) require.NoError(t, err) task.HelperSetDependencies(cfg.JobPipeline(), cfg.WebServer(), orm, specID, uuid.UUID{}, c) @@ -1015,7 +1015,7 @@ func TestBridgeTask_Headers(t *testing.T) { c := clhttptest.NewTestLocalOnlyHTTPClient() trORM := pipeline.NewORM(db, logger.TestLogger(t), cfg.JobPipeline().MaxSuccessfulRuns()) - specID, err := trORM.CreateSpec(testutils.Context(t), nil, pipeline.Pipeline{}, *models.NewInterval(5 * time.Minute)) + specID, err := trORM.CreateSpec(testutils.Context(t), pipeline.Pipeline{}, *models.NewInterval(5 * time.Minute)) require.NoError(t, err) task.HelperSetDependencies(cfg.JobPipeline(), cfg.WebServer(), orm, specID, uuid.UUID{}, c) @@ -1037,7 +1037,7 @@ func TestBridgeTask_Headers(t *testing.T) { c := clhttptest.NewTestLocalOnlyHTTPClient() trORM := pipeline.NewORM(db, logger.TestLogger(t), cfg.JobPipeline().MaxSuccessfulRuns()) - specID, err := trORM.CreateSpec(testutils.Context(t), nil, pipeline.Pipeline{}, *models.NewInterval(5 * time.Minute)) + specID, err := trORM.CreateSpec(testutils.Context(t), pipeline.Pipeline{}, *models.NewInterval(5 * time.Minute)) require.NoError(t, err) task.HelperSetDependencies(cfg.JobPipeline(), cfg.WebServer(), orm, specID, uuid.UUID{}, c) @@ -1052,6 +1052,7 @@ func TestBridgeTask_Headers(t *testing.T) { func TestBridgeTask_AdapterResponseStatusFailure(t *testing.T) { t.Parallel() + ctx := testutils.Context(t) db := pgtest.NewSqlxDB(t) cfg := configtest.NewGeneralConfig(t, func(c *chainlink.Config, s *chainlink.Secrets) { @@ -1062,7 +1063,6 @@ func TestBridgeTask_AdapterResponseStatusFailure(t *testing.T) { Data: adapterResponseData{Result: &decimal.Zero}, } - queryer := pg.NewQ(db, logger.TestLogger(t), cfg.Database()) s1 := httptest.NewServer( http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { err := json.NewEncoder(w).Encode(testAdapterResponse) @@ -1083,12 +1083,12 @@ func TestBridgeTask_AdapterResponseStatusFailure(t *testing.T) { } c := clhttptest.NewTestLocalOnlyHTTPClient() trORM := pipeline.NewORM(db, logger.TestLogger(t), cfg.JobPipeline().MaxSuccessfulRuns()) - specID, err := trORM.CreateSpec(testutils.Context(t), nil, pipeline.Pipeline{}, *models.NewInterval(5 * time.Minute)) + specID, err := trORM.CreateSpec(ctx, pipeline.Pipeline{}, *models.NewInterval(5 * time.Minute)) require.NoError(t, err) task.HelperSetDependencies(cfg.JobPipeline(), cfg.WebServer(), orm, specID, uuid.UUID{}, c) // Insert entry 1m in the past, stale value, should not be used in case of EA failure. - err = queryer.ExecQ(`INSERT INTO bridge_last_value(dot_id, spec_id, value, finished_at) + _, err = db.ExecContext(ctx, `INSERT INTO bridge_last_value(dot_id, spec_id, value, finished_at) VALUES($1, $2, $3, $4) ON CONFLICT ON CONSTRAINT bridge_last_value_pkey DO UPDATE SET value = $3, finished_at = $4;`, task.DotID(), specID, big.NewInt(9700).Bytes(), time.Now()) require.NoError(t, err) @@ -1105,7 +1105,7 @@ func TestBridgeTask_AdapterResponseStatusFailure(t *testing.T) { // expect all external adapter response status failures to be served from the cache testAdapterResponse.SetStatusCode(http.StatusBadRequest) - result, runInfo := task.Run(testutils.Context(t), logger.TestLogger(t), vars, nil) + result, runInfo := task.Run(ctx, logger.TestLogger(t), vars, nil) require.NoError(t, result.Error) require.NotNil(t, result.Value) @@ -1114,7 +1114,7 @@ func TestBridgeTask_AdapterResponseStatusFailure(t *testing.T) { testAdapterResponse.SetStatusCode(http.StatusOK) testAdapterResponse.SetProviderStatusCode(http.StatusBadRequest) - result, runInfo = task.Run(testutils.Context(t), logger.TestLogger(t), vars, nil) + result, runInfo = task.Run(ctx, logger.TestLogger(t), vars, nil) require.NoError(t, result.Error) require.NotNil(t, result.Value) @@ -1124,7 +1124,7 @@ func TestBridgeTask_AdapterResponseStatusFailure(t *testing.T) { testAdapterResponse.SetStatusCode(http.StatusOK) testAdapterResponse.SetProviderStatusCode(http.StatusOK) testAdapterResponse.SetError("some error") - result, runInfo = task.Run(testutils.Context(t), logger.TestLogger(t), vars, nil) + result, runInfo = task.Run(ctx, logger.TestLogger(t), vars, nil) require.NoError(t, result.Error) require.NotNil(t, result.Value) @@ -1132,7 +1132,7 @@ func TestBridgeTask_AdapterResponseStatusFailure(t *testing.T) { require.False(t, runInfo.IsPending) testAdapterResponse.SetStatusCode(http.StatusInternalServerError) - result, runInfo = task.Run(testutils.Context(t), logger.TestLogger(t), vars, nil) + result, runInfo = task.Run(ctx, logger.TestLogger(t), vars, nil) require.NoError(t, result.Error) require.NotNil(t, result.Value) diff --git a/core/services/pipeline/task.http_test.go b/core/services/pipeline/task.http_test.go index 6264d1e591b..4098ce50d2a 100644 --- a/core/services/pipeline/task.http_test.go +++ b/core/services/pipeline/task.http_test.go @@ -177,7 +177,7 @@ func TestHTTPTask_Variables(t *testing.T) { } c := clhttptest.NewTestLocalOnlyHTTPClient() trORM := pipeline.NewORM(db, logger.TestLogger(t), cfg.JobPipeline().MaxSuccessfulRuns()) - specID, err := trORM.CreateSpec(testutils.Context(t), nil, pipeline.Pipeline{}, *models.NewInterval(5 * time.Minute)) + specID, err := trORM.CreateSpec(testutils.Context(t), pipeline.Pipeline{}, *models.NewInterval(5 * time.Minute)) require.NoError(t, err) task.HelperSetDependencies(cfg.JobPipeline(), cfg.WebServer(), orm, specID, uuid.UUID{}, c) diff --git a/core/services/pipeline/test_helpers_test.go b/core/services/pipeline/test_helpers_test.go index fc87942e073..d890afc33df 100644 --- a/core/services/pipeline/test_helpers_test.go +++ b/core/services/pipeline/test_helpers_test.go @@ -13,7 +13,6 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/bridges" "github.com/smartcontractkit/chainlink/v2/core/internal/cltest" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" "github.com/smartcontractkit/chainlink/v2/core/services/pipeline" "github.com/jmoiron/sqlx" @@ -39,7 +38,7 @@ func fakeExternalAdapter(t *testing.T, expectedRequest, response interface{}) ht }) } -func makeBridge(t *testing.T, db *sqlx.DB, expectedRequest, response interface{}, cfg pg.QConfig) (*httptest.Server, bridges.BridgeType) { +func makeBridge(t *testing.T, db *sqlx.DB, expectedRequest, response interface{}) (*httptest.Server, bridges.BridgeType) { t.Helper() server := httptest.NewServer(fakeExternalAdapter(t, expectedRequest, response)) diff --git a/core/services/promreporter/prom_reporter.go b/core/services/promreporter/prom_reporter.go index a302a6fa220..92e674aac44 100644 --- a/core/services/promreporter/prom_reporter.go +++ b/core/services/promreporter/prom_reporter.go @@ -2,12 +2,12 @@ package promreporter import ( "context" - "database/sql" "fmt" "math/big" "sync" "time" + "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" txmgrcommon "github.com/smartcontractkit/chainlink/v2/common/txmgr" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/txmgr" @@ -28,7 +28,7 @@ import ( type ( promReporter struct { services.StateMachine - db *sql.DB + ds sqlutil.DataSource chains legacyevm.LegacyChainContainer lggr logger.Logger backend PrometheusBackend @@ -92,7 +92,7 @@ func (defaultBackend) SetPipelineTaskRunsQueued(n int) { promPipelineRunsQueued.Set(float64(n)) } -func NewPromReporter(db *sql.DB, chainContainer legacyevm.LegacyChainContainer, lggr logger.Logger, opts ...interface{}) *promReporter { +func NewPromReporter(ds sqlutil.DataSource, chainContainer legacyevm.LegacyChainContainer, lggr logger.Logger, opts ...interface{}) *promReporter { var backend PrometheusBackend = defaultBackend{} period := 15 * time.Second for _, opt := range opts { @@ -106,7 +106,7 @@ func NewPromReporter(db *sql.DB, chainContainer legacyevm.LegacyChainContainer, chStop := make(chan struct{}) return &promReporter{ - db: db, + ds: ds, chains: chainContainer, lggr: lggr.Named("PromReporter"), backend: backend, @@ -242,7 +242,7 @@ func (pr *promReporter) reportMaxUnconfirmedBlocks(ctx context.Context, head *ev } func (pr *promReporter) reportPipelineRunStats(ctx context.Context) (err error) { - rows, err := pr.db.QueryContext(ctx, ` + rows, err := pr.ds.QueryContext(ctx, ` SELECT pipeline_run_id FROM pipeline_task_runs WHERE finished_at IS NULL `) if err != nil { diff --git a/core/services/promreporter/prom_reporter_test.go b/core/services/promreporter/prom_reporter_test.go index bb09b86df95..f17b4aafed2 100644 --- a/core/services/promreporter/prom_reporter_test.go +++ b/core/services/promreporter/prom_reporter_test.go @@ -48,7 +48,6 @@ func newLegacyChainContainer(t *testing.T, db *sqlx.DB) legacyevm.LegacyChainCon lp := logpoller.NewLogPoller(logpoller.NewORM(testutils.FixtureChainID, db, lggr), ethClient, lggr, lpOpts) txm, err := txmgr.NewTxm( - db, db, evmConfig, evmConfig.GasEstimator(), @@ -72,7 +71,7 @@ func Test_PromReporter_OnNewLongestChain(t *testing.T) { db := pgtest.NewSqlxDB(t) backend := mocks.NewPrometheusBackend(t) - reporter := promreporter.NewPromReporter(db.DB, newLegacyChainContainer(t, db), logger.TestLogger(t), backend, 10*time.Millisecond) + reporter := promreporter.NewPromReporter(db, newLegacyChainContainer(t, db), logger.TestLogger(t), backend, 10*time.Millisecond) var subscribeCalls atomic.Int32 @@ -114,7 +113,7 @@ func Test_PromReporter_OnNewLongestChain(t *testing.T) { subscribeCalls.Add(1) }). Return() - reporter := promreporter.NewPromReporter(db.DB, newLegacyChainContainer(t, db), logger.TestLogger(t), backend, 10*time.Millisecond) + reporter := promreporter.NewPromReporter(db, newLegacyChainContainer(t, db), logger.TestLogger(t), backend, 10*time.Millisecond) servicetest.Run(t, reporter) etx := cltest.MustInsertUnconfirmedEthTxWithBroadcastLegacyAttempt(t, txStore, 0, fromAddress) @@ -133,7 +132,7 @@ func Test_PromReporter_OnNewLongestChain(t *testing.T) { pgtest.MustExec(t, db, `SET CONSTRAINTS pipeline_task_runs_pipeline_run_id_fkey DEFERRED`) backend := mocks.NewPrometheusBackend(t) - reporter := promreporter.NewPromReporter(db.DB, newLegacyChainContainer(t, db), logger.TestLogger(t), backend, 10*time.Millisecond) + reporter := promreporter.NewPromReporter(db, newLegacyChainContainer(t, db), logger.TestLogger(t), backend, 10*time.Millisecond) cltest.MustInsertUnfinishedPipelineTaskRun(t, db, 1) cltest.MustInsertUnfinishedPipelineTaskRun(t, db, 1) diff --git a/core/services/relay/evm/evm.go b/core/services/relay/evm/evm.go index e9aaa7e0a8e..2f30d40e28d 100644 --- a/core/services/relay/evm/evm.go +++ b/core/services/relay/evm/evm.go @@ -11,7 +11,6 @@ import ( "github.com/ethereum/go-ethereum/accounts/abi" "github.com/ethereum/go-ethereum/common" "github.com/google/uuid" - "github.com/jmoiron/sqlx" pkgerrors "github.com/pkg/errors" "golang.org/x/exp/maps" @@ -35,7 +34,6 @@ import ( lloconfig "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/llo/config" mercuryconfig "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/mercury/config" "github.com/smartcontractkit/chainlink/v2/core/services/ocrcommon" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" "github.com/smartcontractkit/chainlink/v2/core/services/relay/evm/functions" "github.com/smartcontractkit/chainlink/v2/core/services/relay/evm/mercury" mercuryutils "github.com/smartcontractkit/chainlink/v2/core/services/relay/evm/mercury/utils" @@ -71,13 +69,11 @@ func init() { var _ commontypes.Relayer = &Relayer{} //nolint:staticcheck type Relayer struct { - db *sqlx.DB // legacy: prefer to use ds instead ds sqlutil.DataSource chain legacyevm.Chain lggr logger.Logger ks CSAETHKeystore mercuryPool wsrpc.Pool - pgCfg pg.QConfig chainReader commontypes.ChainReader codec commontypes.Codec @@ -95,24 +91,16 @@ type CSAETHKeystore interface { } type RelayerOpts struct { - *sqlx.DB // legacy: prefer to use ds instead - DS sqlutil.DataSource - pg.QConfig + DS sqlutil.DataSource CSAETHKeystore MercuryPool wsrpc.Pool } func (c RelayerOpts) Validate() error { var err error - if c.DB == nil { - err = errors.Join(err, errors.New("nil DB")) - } if c.DS == nil { err = errors.Join(err, errors.New("nil DataSource")) } - if c.QConfig == nil { - err = errors.Join(err, errors.New("nil QConfig")) - } if c.CSAETHKeystore == nil { err = errors.Join(err, errors.New("nil Keystore")) } @@ -134,13 +122,11 @@ func NewRelayer(lggr logger.Logger, chain legacyevm.Chain, opts RelayerOpts) (*R lloORM := llo.NewORM(opts.DS, chain.ID()) cdcFactory := llo.NewChannelDefinitionCacheFactory(lggr, lloORM, chain.LogPoller()) return &Relayer{ - db: opts.DB, ds: opts.DS, chain: chain, lggr: lggr, ks: opts.CSAETHKeystore, mercuryPool: opts.MercuryPool, - pgCfg: opts.QConfig, cdcFactory: cdcFactory, lloORM: lloORM, mercuryORM: mercuryORM, @@ -637,7 +623,7 @@ func (r *Relayer) NewMedianProvider(rargs commontypes.RelayArgs, pargs commontyp func (r *Relayer) NewAutomationProvider(rargs commontypes.RelayArgs, pargs commontypes.PluginArgs) (commontypes.AutomationProvider, error) { lggr := r.lggr.Named("AutomationProvider").Named(rargs.ExternalJobID.String()) - ocr2keeperRelayer := NewOCR2KeeperRelayer(r.db, r.chain, lggr.Named("OCR2KeeperRelayer"), r.ks.Eth(), r.pgCfg) + ocr2keeperRelayer := NewOCR2KeeperRelayer(r.ds, r.chain, lggr.Named("OCR2KeeperRelayer"), r.ks.Eth()) return ocr2keeperRelayer.NewOCR2KeeperProvider(rargs, pargs) } diff --git a/core/services/relay/evm/evm_test.go b/core/services/relay/evm/evm_test.go index d53fe910bc3..ab60ff2a128 100644 --- a/core/services/relay/evm/evm_test.go +++ b/core/services/relay/evm/evm_test.go @@ -5,20 +5,13 @@ import ( "github.com/stretchr/testify/assert" - "github.com/jmoiron/sqlx" - "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" - "github.com/smartcontractkit/chainlink/v2/core/internal/testutils/configtest" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" "github.com/smartcontractkit/chainlink/v2/core/services/relay/evm" ) func TestRelayerOpts_Validate(t *testing.T) { - cfg := configtest.NewTestGeneralConfig(t) type fields struct { - DB *sqlx.DB DS sqlutil.DataSource - QConfig pg.QConfig CSAETHKeystore evm.CSAETHKeystore } tests := []struct { @@ -29,33 +22,25 @@ func TestRelayerOpts_Validate(t *testing.T) { { name: "all invalid", fields: fields{ - DB: nil, DS: nil, - QConfig: nil, CSAETHKeystore: nil, }, - wantErrContains: `nil DB -nil DataSource -nil QConfig + wantErrContains: `nil DataSource nil Keystore`, }, { - name: "missing db, ds, keystore", + name: "missing ds, keystore", fields: fields{ - DB: nil, - QConfig: cfg.Database(), + DS: nil, }, - wantErrContains: `nil DB -nil DataSource + wantErrContains: `nil DataSource nil Keystore`, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c := evm.RelayerOpts{ - DB: tt.fields.DB, DS: tt.fields.DS, - QConfig: tt.fields.QConfig, CSAETHKeystore: tt.fields.CSAETHKeystore, } err := c.Validate() diff --git a/core/services/relay/evm/mercury/wsrpc/pb/mercury_wsrpc.pb.go b/core/services/relay/evm/mercury/wsrpc/pb/mercury_wsrpc.pb.go index 0c31a1d7ac9..4d05db4380f 100644 --- a/core/services/relay/evm/mercury/wsrpc/pb/mercury_wsrpc.pb.go +++ b/core/services/relay/evm/mercury/wsrpc/pb/mercury_wsrpc.pb.go @@ -11,7 +11,6 @@ import ( ) // MercuryClient is the client API for Mercury service. -// type MercuryClient interface { Transmit(ctx context.Context, in *TransmitRequest) (*TransmitResponse, error) LatestReport(ctx context.Context, in *LatestReportRequest) (*LatestReportResponse, error) diff --git a/core/services/relay/evm/ocr2keeper.go b/core/services/relay/evm/ocr2keeper.go index 0dd971123c6..78f4b43b43f 100644 --- a/core/services/relay/evm/ocr2keeper.go +++ b/core/services/relay/evm/ocr2keeper.go @@ -27,7 +27,6 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/ocr2keeper/evmregistry/v21/logprovider" "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/ocr2keeper/evmregistry/v21/transmit" "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/ocr2keeper/evmregistry/v21/upkeepstate" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" "github.com/smartcontractkit/chainlink/v2/core/services/relay/evm/types" ) @@ -69,17 +68,15 @@ type ocr2keeperRelayer struct { chain legacyevm.Chain lggr logger.Logger ethKeystore keystore.Eth - dbCfg pg.QConfig } // NewOCR2KeeperRelayer is the constructor of ocr2keeperRelayer -func NewOCR2KeeperRelayer(ds sqlutil.DataSource, chain legacyevm.Chain, lggr logger.Logger, ethKeystore keystore.Eth, dbCfg pg.QConfig) OCR2KeeperRelayer { +func NewOCR2KeeperRelayer(ds sqlutil.DataSource, chain legacyevm.Chain, lggr logger.Logger, ethKeystore keystore.Eth) OCR2KeeperRelayer { return &ocr2keeperRelayer{ ds: ds, chain: chain, lggr: lggr, ethKeystore: ethKeystore, - dbCfg: dbCfg, } } diff --git a/core/services/relay/evm/ocr2vrf.go b/core/services/relay/evm/ocr2vrf.go index 07edd1c5ac6..a108151be47 100644 --- a/core/services/relay/evm/ocr2vrf.go +++ b/core/services/relay/evm/ocr2vrf.go @@ -6,8 +6,6 @@ import ( "fmt" "github.com/ethereum/go-ethereum/common" - "github.com/jmoiron/sqlx" - "github.com/smartcontractkit/libocr/offchainreporting2plus/chains/evmutil" ocrtypes "github.com/smartcontractkit/libocr/offchainreporting2plus/types" @@ -44,15 +42,13 @@ var ( // Relayer with added DKG and OCR2VRF provider functions. type ocr2vrfRelayer struct { - db *sqlx.DB chain legacyevm.Chain lggr logger.Logger ethKeystore keystore.Eth } -func NewOCR2VRFRelayer(db *sqlx.DB, chain legacyevm.Chain, lggr logger.Logger, ethKeystore keystore.Eth) OCR2VRFRelayer { +func NewOCR2VRFRelayer(chain legacyevm.Chain, lggr logger.Logger, ethKeystore keystore.Eth) OCR2VRFRelayer { return &ocr2vrfRelayer{ - db: db, chain: chain, lggr: lggr, ethKeystore: ethKeystore, diff --git a/core/services/versioning/orm.go b/core/services/versioning/orm.go index 5a2472eee8e..5f6e3e60222 100644 --- a/core/services/versioning/orm.go +++ b/core/services/versioning/orm.go @@ -23,16 +23,14 @@ type ORM interface { } type orm struct { - ds sqlutil.DataSource - lggr logger.Logger - timeout time.Duration + ds sqlutil.DataSource + lggr logger.Logger } -func NewORM(ds sqlutil.DataSource, lggr logger.Logger, timeout time.Duration) *orm { +func NewORM(ds sqlutil.DataSource, lggr logger.Logger) *orm { return &orm{ - ds: ds, - lggr: lggr.Named("VersioningORM"), - timeout: timeout, + ds: ds, + lggr: lggr.Named("VersioningORM"), } } @@ -47,8 +45,6 @@ func (o *orm) UpsertNodeVersion(ctx context.Context, version NodeVersion) error return errors.Wrapf(err, "%q is not valid semver", version.Version) } - ctx, cancel := context.WithTimeout(ctx, o.timeout) - defer cancel() return sqlutil.TransactDataSource(ctx, o.ds, nil, func(tx sqlutil.DataSource) error { if _, _, err := CheckVersion(ctx, tx, logger.NullLogger, version.Version); err != nil { return err diff --git a/core/services/versioning/orm_test.go b/core/services/versioning/orm_test.go index f655c9c47fe..3504c2bc772 100644 --- a/core/services/versioning/orm_test.go +++ b/core/services/versioning/orm_test.go @@ -10,14 +10,13 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/internal/testutils" "github.com/smartcontractkit/chainlink/v2/core/internal/testutils/pgtest" "github.com/smartcontractkit/chainlink/v2/core/logger" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" "github.com/smartcontractkit/chainlink/v2/core/static" ) func TestORM_NodeVersion_UpsertNodeVersion(t *testing.T) { ctx := testutils.Context(t) db := pgtest.NewSqlxDB(t) - orm := NewORM(db, logger.TestLogger(t), pg.DefaultQueryTimeout) + orm := NewORM(db, logger.TestLogger(t)) err := orm.UpsertNodeVersion(ctx, NewNodeVersion("9.9.8")) require.NoError(t, err) @@ -66,7 +65,7 @@ func Test_Version_CheckVersion(t *testing.T) { lggr := logger.TestLogger(t) - orm := NewORM(db, lggr, pg.DefaultQueryTimeout) + orm := NewORM(db, lggr) err := orm.UpsertNodeVersion(ctx, NewNodeVersion("9.9.8")) require.NoError(t, err) @@ -101,7 +100,7 @@ func Test_Version_CheckVersion(t *testing.T) { func TestORM_NodeVersion_FindLatestNodeVersion(t *testing.T) { ctx := testutils.Context(t) db := pgtest.NewSqlxDB(t) - orm := NewORM(db, logger.TestLogger(t), pg.DefaultQueryTimeout) + orm := NewORM(db, logger.TestLogger(t)) // Not Found _, err := orm.FindLatestNodeVersion(ctx) diff --git a/core/services/vrf/delegate_test.go b/core/services/vrf/delegate_test.go index a3962977257..889b19d0e04 100644 --- a/core/services/vrf/delegate_test.go +++ b/core/services/vrf/delegate_test.go @@ -83,10 +83,10 @@ func buildVrfUni(t *testing.T, db *sqlx.DB, cfg chainlink.GeneralConfig) vrfUniv btORM := bridges.NewORM(db) ks := keystore.NewInMemory(db, utils.FastScryptParams, lggr) _, dbConfig, evmConfig := txmgr.MakeTestConfigs(t) - txm, err := txmgr.NewTxm(db, db, evmConfig, evmConfig.GasEstimator(), evmConfig.Transactions(), nil, dbConfig, dbConfig.Listener(), ec, logger.TestLogger(t), nil, ks.Eth(), nil) + txm, err := txmgr.NewTxm(db, evmConfig, evmConfig.GasEstimator(), evmConfig.Transactions(), nil, dbConfig, dbConfig.Listener(), ec, logger.TestLogger(t), nil, ks.Eth(), nil) orm := headtracker.NewORM(*testutils.FixtureChainID, db) require.NoError(t, orm.IdempotentInsertHead(testutils.Context(t), cltest.Head(51))) - jrm := job.NewORM(db, prm, btORM, ks, lggr, cfg.Database()) + jrm := job.NewORM(db, prm, btORM, ks, lggr) t.Cleanup(func() { assert.NoError(t, jrm.Close()) }) relayExtenders := evmtest.NewChainRelayExtenders(t, evmtest.TestChainOpts{LogBroadcaster: lb, KeyStore: ks.Eth(), Client: ec, DB: db, GeneralConfig: cfg, TxManager: txm}) legacyChains := evmrelay.NewLegacyChainsFromRelayerExtenders(relayExtenders) @@ -165,7 +165,8 @@ func setup(t *testing.T) (vrfUniverse, *v1.Listener, job.Job) { vs := testspecs.GenerateVRFSpec(testspecs.VRFSpecParams{PublicKey: vuni.vrfkey.PublicKey.String(), EVMChainID: testutils.FixtureChainID.String()}) jb, err := vrfcommon.ValidatedVRFSpec(vs.Toml()) require.NoError(t, err) - err = vuni.jrm.CreateJob(&jb) + ctx := testutils.Context(t) + err = vuni.jrm.CreateJob(ctx, &jb) require.NoError(t, err) vl, err := vd.ServicesForSpec(testutils.Context(t), jb) require.NoError(t, err) @@ -701,7 +702,8 @@ func Test_VRFV2PlusServiceFailsWhenVRFOwnerProvided(t *testing.T) { toml := "vrfOwnerAddress=\"0xF62fEFb54a0af9D32CDF0Db21C52710844c7eddb\"\n" + vs.Toml() jb, err := vrfcommon.ValidatedVRFSpec(toml) require.NoError(t, err) - err = vuni.jrm.CreateJob(&jb) + ctx := testutils.Context(t) + err = vuni.jrm.CreateJob(ctx, &jb) require.NoError(t, err) _, err = vd.ServicesForSpec(testutils.Context(t), jb) require.Error(t, err) diff --git a/core/services/vrf/v1/integration_test.go b/core/services/vrf/v1/integration_test.go index c28ad9ce3d0..74006639c6e 100644 --- a/core/services/vrf/v1/integration_test.go +++ b/core/services/vrf/v1/integration_test.go @@ -55,10 +55,10 @@ func TestIntegration_VRF_JPV2(t *testing.T) { cu := vrftesthelpers.NewVRFCoordinatorUniverse(t, key1, key2) incomingConfs := 2 app := cltest.NewApplicationWithConfigV2AndKeyOnSimulatedBlockchain(t, config, cu.Backend, key1, key2) - require.NoError(t, app.Start(testutils.Context(t))) + require.NoError(t, app.Start(ctx)) jb, vrfKey := createVRFJobRegisterKey(t, cu, app, incomingConfs) - require.NoError(t, app.JobSpawner().CreateJob(&jb)) + require.NoError(t, app.JobSpawner().CreateJob(ctx, nil, &jb)) _, err := cu.ConsumerContract.TestRequestRandomness(cu.Carol, vrfKey.PublicKey.MustHash(), big.NewInt(100)) @@ -93,12 +93,12 @@ func TestIntegration_VRF_JPV2(t *testing.T) { // stop jobs as to not cause a race condition in geth simulated backend // between job creating new tx and fulfillment logs polling below - require.NoError(t, app.JobSpawner().DeleteJob(jb.ID)) + require.NoError(t, app.JobSpawner().DeleteJob(ctx, nil, jb.ID)) // Ensure the eth transaction gets confirmed on chain. gomega.NewWithT(t).Eventually(func() bool { - orm := txmgr.NewTxStore(app.GetSqlxDB(), app.GetLogger()) - uc, err2 := orm.CountUnconfirmedTransactions(testutils.Context(t), key1.Address, testutils.SimulatedChainID) + orm := txmgr.NewTxStore(app.GetDB(), app.GetLogger()) + uc, err2 := orm.CountUnconfirmedTransactions(ctx, key1.Address, testutils.SimulatedChainID) require.NoError(t, err2) return uc == 0 }, testutils.WaitTimeout(t), 100*time.Millisecond).Should(gomega.BeTrue()) @@ -116,11 +116,11 @@ func TestIntegration_VRF_JPV2(t *testing.T) { }, testutils.WaitTimeout(t), 500*time.Millisecond).Should(gomega.BeTrue()) // Check that each sending address sent one transaction - n1, err := cu.Backend.PendingNonceAt(testutils.Context(t), key1.Address) + n1, err := cu.Backend.PendingNonceAt(ctx, key1.Address) require.NoError(t, err) require.EqualValues(t, 1, n1) - n2, err := cu.Backend.PendingNonceAt(testutils.Context(t), key2.Address) + n2, err := cu.Backend.PendingNonceAt(ctx, key2.Address) require.NoError(t, err) require.EqualValues(t, 1, n2) }) @@ -142,7 +142,7 @@ func TestIntegration_VRF_WithBHS(t *testing.T) { cu := vrftesthelpers.NewVRFCoordinatorUniverse(t, key) incomingConfs := 2 app := cltest.NewApplicationWithConfigV2AndKeyOnSimulatedBlockchain(t, config, cu.Backend, key) - require.NoError(t, app.Start(testutils.Context(t))) + require.NoError(t, app.Start(ctx)) // Create VRF Job but do not start it yet jb, vrfKey := createVRFJobRegisterKey(t, cu, app, incomingConfs) @@ -155,7 +155,7 @@ func TestIntegration_VRF_WithBHS(t *testing.T) { // Ensure log poller is ready and has all logs. require.NoError(t, app.GetRelayers().LegacyEVMChains().Slice()[0].LogPoller().Ready()) - require.NoError(t, app.GetRelayers().LegacyEVMChains().Slice()[0].LogPoller().Replay(testutils.Context(t), 1)) + require.NoError(t, app.GetRelayers().LegacyEVMChains().Slice()[0].LogPoller().Replay(ctx, 1)) // Create a VRF request _, err := cu.ConsumerContract.TestRequestRandomness(cu.Carol, @@ -194,7 +194,7 @@ func TestIntegration_VRF_WithBHS(t *testing.T) { } // Start the VRF Job and wait until it's processed - require.NoError(t, app.JobSpawner().CreateJob(&jb)) + require.NoError(t, app.JobSpawner().CreateJob(ctx, nil, &jb)) var runs []pipeline.Run gomega.NewWithT(t).Eventually(func() bool { @@ -209,13 +209,13 @@ func TestIntegration_VRF_WithBHS(t *testing.T) { // stop jobs as to not cause a race condition in geth simulated backend // between job creating new tx and fulfillment logs polling below - require.NoError(t, app.JobSpawner().DeleteJob(jb.ID)) - require.NoError(t, app.JobSpawner().DeleteJob(bhsJob.ID)) + require.NoError(t, app.JobSpawner().DeleteJob(ctx, nil, jb.ID)) + require.NoError(t, app.JobSpawner().DeleteJob(ctx, nil, bhsJob.ID)) // Ensure the eth transaction gets confirmed on chain. gomega.NewWithT(t).Eventually(func() bool { - orm := txmgr.NewTxStore(app.GetSqlxDB(), app.GetLogger()) - uc, err2 := orm.CountUnconfirmedTransactions(testutils.Context(t), key.Address, testutils.SimulatedChainID) + orm := txmgr.NewTxStore(app.GetDB(), app.GetLogger()) + uc, err2 := orm.CountUnconfirmedTransactions(ctx, key.Address, testutils.SimulatedChainID) require.NoError(t, err2) return uc == 0 }, 5*time.Second, 100*time.Millisecond).Should(gomega.BeTrue()) diff --git a/core/services/vrf/v1/listener_v1.go b/core/services/vrf/v1/listener_v1.go index ddf5779deb0..c8029403084 100644 --- a/core/services/vrf/v1/listener_v1.go +++ b/core/services/vrf/v1/listener_v1.go @@ -28,7 +28,6 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/logger" "github.com/smartcontractkit/chainlink/v2/core/recovery" "github.com/smartcontractkit/chainlink/v2/core/services/job" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" "github.com/smartcontractkit/chainlink/v2/core/services/pipeline" "github.com/smartcontractkit/chainlink/v2/core/services/vrf/vrfcommon" "github.com/smartcontractkit/chainlink/v2/core/utils" @@ -57,7 +56,6 @@ type Listener struct { Coordinator *solidity_vrf_coordinator_interface.VRFCoordinator PipelineRunner pipeline.Runner Job job.Job - Q pg.Q GethKs vrfcommon.GethKeyStore MailMon *mailbox.Monitor ReqLogs *mailbox.Mailbox[log.Broadcast] @@ -285,6 +283,8 @@ func (lsn *Listener) RunHeadListener(unsubscribe func()) { } func (lsn *Listener) RunLogListener(unsubscribes []func(), minConfs uint32) { + ctx, cancel := lsn.ChStop.NewCtx() + defer cancel() lsn.L.Infow("Listening for run requests", "gasLimit", lsn.FeeCfg.LimitDefault(), "minConfs", minConfs) @@ -304,8 +304,6 @@ func (lsn *Listener) RunLogListener(unsubscribes []func(), minConfs uint32) { break } recovery.WrapRecover(lsn.L, func() { - ctx, cancel := lsn.ChStop.NewCtx() - defer cancel() lsn.handleLog(ctx, lb, minConfs) }) } @@ -488,8 +486,7 @@ func (lsn *Listener) ProcessRequest(ctx context.Context, req request) bool { // The VRF pipeline has no async tasks, so we don't need to check for `incomplete` if _, err = lsn.PipelineRunner.Run(ctx, run, lggr, true, func(tx sqlutil.DataSource) error { // Always mark consumed regardless of whether the proof failed or not. - //TODO restore tx https://smartcontract-it.atlassian.net/browse/BCF-2978 - if err = lsn.Chain.LogBroadcaster().MarkConsumed(ctx, nil, req.lb); err != nil { + if err = lsn.Chain.LogBroadcaster().MarkConsumed(ctx, tx, req.lb); err != nil { lggr.Errorw("Failed mark consumed", "err", err) } return nil diff --git a/core/services/vrf/v2/integration_helpers_test.go b/core/services/vrf/v2/integration_helpers_test.go index 2e0554fca96..d61779c5714 100644 --- a/core/services/vrf/v2/integration_helpers_test.go +++ b/core/services/vrf/v2/integration_helpers_test.go @@ -1207,7 +1207,6 @@ func testSingleConsumerBigGasCallbackSandwich( ownerKey ethkey.KeyV2, uni coordinatorV2UniverseCommon, batchCoordinatorAddress common.Address, - batchEnabled bool, vrfVersion vrfcommon.Version, nativePayment bool, ) { @@ -1324,7 +1323,6 @@ func testSingleConsumerMultipleGasLanes( ownerKey ethkey.KeyV2, uni coordinatorV2UniverseCommon, batchCoordinatorAddress common.Address, - batchEnabled bool, vrfVersion vrfcommon.Version, nativePayment bool, ) { @@ -1694,7 +1692,7 @@ func testMaliciousConsumer( }).Toml() jb, err := vrfcommon.ValidatedVRFSpec(s) require.NoError(t, err) - err = app.JobSpawner().CreateJob(&jb) + err = app.JobSpawner().CreateJob(ctx, nil, &jb) require.NoError(t, err) time.Sleep(1 * time.Second) @@ -1861,7 +1859,7 @@ func testReplayOldRequestsOnStartUp( jb, err := vrfcommon.ValidatedVRFSpec(spec) require.NoError(t, err) t.Log(jb.VRFSpec.PublicKey.MustHash(), vrfKey.PublicKey.MustHash()) - err = app.JobSpawner().CreateJob(&jb) + err = app.JobSpawner().CreateJob(ctx, nil, &jb) require.NoError(t, err) // Wait until all jobs are active and listening for logs diff --git a/core/services/vrf/v2/integration_v2_plus_test.go b/core/services/vrf/v2/integration_v2_plus_test.go index b885473e488..53baaa0eda3 100644 --- a/core/services/vrf/v2/integration_v2_plus_test.go +++ b/core/services/vrf/v2/integration_v2_plus_test.go @@ -646,29 +646,13 @@ func TestVRFV2PlusIntegration_SingleConsumer_NeedsTopUp(t *testing.T) { func TestVRFV2PlusIntegration_SingleConsumer_BigGasCallback_Sandwich(t *testing.T) { ownerKey := cltest.MustGenerateRandomKey(t) uni := newVRFCoordinatorV2PlusUniverse(t, ownerKey, 1, false) - testSingleConsumerBigGasCallbackSandwich( - t, - ownerKey, - uni.coordinatorV2UniverseCommon, - uni.batchCoordinatorContractAddress, - false, - vrfcommon.V2Plus, - false, - ) + testSingleConsumerBigGasCallbackSandwich(t, ownerKey, uni.coordinatorV2UniverseCommon, uni.batchCoordinatorContractAddress, vrfcommon.V2Plus, false) } func TestVRFV2PlusIntegration_SingleConsumer_MultipleGasLanes(t *testing.T) { ownerKey := cltest.MustGenerateRandomKey(t) uni := newVRFCoordinatorV2PlusUniverse(t, ownerKey, 1, false) - testSingleConsumerMultipleGasLanes( - t, - ownerKey, - uni.coordinatorV2UniverseCommon, - uni.batchCoordinatorContractAddress, - false, - vrfcommon.V2Plus, - false, - ) + testSingleConsumerMultipleGasLanes(t, ownerKey, uni.coordinatorV2UniverseCommon, uni.batchCoordinatorContractAddress, vrfcommon.V2Plus, false) } func TestVRFV2PlusIntegration_SingleConsumer_AlwaysRevertingCallback_StillFulfilled(t *testing.T) { @@ -1178,7 +1162,7 @@ func TestVRFV2PlusIntegration_Migration(t *testing.T) { // Fund gas lane. sendEth(t, ownerKey, uni.backend, key1.Address, 10) - require.NoError(t, app.Start(testutils.Context(t))) + require.NoError(t, app.Start(ctx)) // Create VRF job using key1 and key2 on the same gas lane. jbs := createVRFJobs( @@ -1234,7 +1218,7 @@ func TestVRFV2PlusIntegration_Migration(t *testing.T) { require.NoError(t, err) linkContractBalance, err := uni.linkContract.BalanceOf(nil, uni.migrationTestCoordinatorAddress) require.NoError(t, err) - balance, err := uni.backend.BalanceAt(testutils.Context(t), uni.migrationTestCoordinatorAddress, nil) + balance, err := uni.backend.BalanceAt(ctx, uni.migrationTestCoordinatorAddress, nil) require.NoError(t, err) require.Equal(t, subV1.Balance(), totalLinkBalance) diff --git a/core/services/vrf/v2/integration_v2_reverted_txns_test.go b/core/services/vrf/v2/integration_v2_reverted_txns_test.go index dfee450b6a2..25e3afcf751 100644 --- a/core/services/vrf/v2/integration_v2_reverted_txns_test.go +++ b/core/services/vrf/v2/integration_v2_reverted_txns_test.go @@ -448,7 +448,7 @@ func createVRFJobsNew( jb, err := vrfcommon.ValidatedVRFSpec(s) t.Log(jb.VRFSpec.PublicKey.MustHash(), vrfkey.PublicKey.MustHash()) require.NoError(t, err) - err = app.JobSpawner().CreateJob(&jb) + err = app.JobSpawner().CreateJob(ctx, nil, &jb) require.NoError(t, err) registerProvingKeyHelper(t, uni.coordinatorV2UniverseCommon, coordinator, vrfkey, ptr(gasLanePrices[i].ToInt().Uint64())) jobs = append(jobs, jb) diff --git a/core/services/vrf/v2/integration_v2_test.go b/core/services/vrf/v2/integration_v2_test.go index 543ec943527..e8d4fd255f7 100644 --- a/core/services/vrf/v2/integration_v2_test.go +++ b/core/services/vrf/v2/integration_v2_test.go @@ -574,7 +574,7 @@ func createVRFJobs( jb, err := vrfcommon.ValidatedVRFSpec(spec) require.NoError(t, err) t.Log(jb.VRFSpec.PublicKey.MustHash(), vrfkey.PublicKey.MustHash()) - err = app.JobSpawner().CreateJob(&jb) + err = app.JobSpawner().CreateJob(ctx, nil, &jb) require.NoError(t, err) registerProvingKeyHelper(t, uni, coordinator, vrfkey, ptr(gasLanePrices[i].ToInt().Uint64())) jobs = append(jobs, jb) @@ -1427,29 +1427,13 @@ func TestVRFV2Integration_SingleConsumer_NeedsTopUp(t *testing.T) { func TestVRFV2Integration_SingleConsumer_BigGasCallback_Sandwich(t *testing.T) { ownerKey := cltest.MustGenerateRandomKey(t) uni := newVRFCoordinatorV2Universe(t, ownerKey, 1) - testSingleConsumerBigGasCallbackSandwich( - t, - ownerKey, - uni.coordinatorV2UniverseCommon, - uni.batchCoordinatorContractAddress, - false, - vrfcommon.V2, - false, - ) + testSingleConsumerBigGasCallbackSandwich(t, ownerKey, uni.coordinatorV2UniverseCommon, uni.batchCoordinatorContractAddress, vrfcommon.V2, false) } func TestVRFV2Integration_SingleConsumer_MultipleGasLanes(t *testing.T) { ownerKey := cltest.MustGenerateRandomKey(t) uni := newVRFCoordinatorV2Universe(t, ownerKey, 1) - testSingleConsumerMultipleGasLanes( - t, - ownerKey, - uni.coordinatorV2UniverseCommon, - uni.batchCoordinatorContractAddress, - false, - vrfcommon.V2, - false, - ) + testSingleConsumerMultipleGasLanes(t, ownerKey, uni.coordinatorV2UniverseCommon, uni.batchCoordinatorContractAddress, vrfcommon.V2, false) } func TestVRFV2Integration_SingleConsumer_AlwaysRevertingCallback_StillFulfilled(t *testing.T) { diff --git a/core/services/vrf/vrftesthelpers/helpers.go b/core/services/vrf/vrftesthelpers/helpers.go index 77d3f33a653..33ad8470731 100644 --- a/core/services/vrf/vrftesthelpers/helpers.go +++ b/core/services/vrf/vrftesthelpers/helpers.go @@ -74,7 +74,8 @@ func CreateAndStartBHSJob( jb, err := blockhashstore.ValidatedSpec(s.Toml()) require.NoError(t, err) - require.NoError(t, app.JobSpawner().CreateJob(&jb)) + ctx := testutils.Context(t) + require.NoError(t, app.JobSpawner().CreateJob(ctx, nil, &jb)) gomega.NewWithT(t).Eventually(func() bool { jbs := app.JobSpawner().ActiveJobs() for _, jb := range jbs { @@ -115,7 +116,8 @@ func CreateAndStartBlockHeaderFeederJob( jb, err := blockheaderfeeder.ValidatedSpec(s.Toml()) require.NoError(t, err) - require.NoError(t, app.JobSpawner().CreateJob(&jb)) + ctx := testutils.Context(t) + require.NoError(t, app.JobSpawner().CreateJob(ctx, nil, &jb)) gomega.NewWithT(t).Eventually(func() bool { jbs := app.JobSpawner().ActiveJobs() for _, jb := range jbs { diff --git a/core/services/webhook/authorizer.go b/core/services/webhook/authorizer.go index c745c5d0b09..f5409cfbacc 100644 --- a/core/services/webhook/authorizer.go +++ b/core/services/webhook/authorizer.go @@ -24,29 +24,29 @@ var ( _ Authorizer = &neverAuthorizer{} ) -func NewAuthorizer(db sqlutil.DataSource, user *sessions.User, ei *bridges.ExternalInitiator) Authorizer { +func NewAuthorizer(ds sqlutil.DataSource, user *sessions.User, ei *bridges.ExternalInitiator) Authorizer { if user != nil { return &alwaysAuthorizer{} } else if ei != nil { - return NewEIAuthorizer(db, *ei) + return NewEIAuthorizer(ds, *ei) } return &neverAuthorizer{} } type eiAuthorizer struct { - db sqlutil.DataSource + ds sqlutil.DataSource ei bridges.ExternalInitiator } -func NewEIAuthorizer(db sqlutil.DataSource, ei bridges.ExternalInitiator) *eiAuthorizer { - return &eiAuthorizer{db, ei} +func NewEIAuthorizer(ds sqlutil.DataSource, ei bridges.ExternalInitiator) *eiAuthorizer { + return &eiAuthorizer{ds, ei} } func (ea *eiAuthorizer) CanRun(ctx context.Context, config AuthorizerConfig, jobUUID uuid.UUID) (can bool, err error) { if !config.ExternalInitiatorsEnabled() { return false, nil } - row := ea.db.QueryRowxContext(ctx, ` + row := ea.ds.QueryRowxContext(ctx, ` SELECT EXISTS ( SELECT 1 FROM external_initiator_webhook_specs JOIN jobs ON external_initiator_webhook_specs.webhook_spec_id = jobs.webhook_spec_id diff --git a/core/services/webhook/authorizer_test.go b/core/services/webhook/authorizer_test.go index 82af7c6fcce..202791c26c2 100644 --- a/core/services/webhook/authorizer_test.go +++ b/core/services/webhook/authorizer_test.go @@ -8,9 +8,8 @@ import ( "github.com/stretchr/testify/require" "github.com/smartcontractkit/chainlink/v2/core/bridges" - "github.com/smartcontractkit/chainlink/v2/core/internal/testutils" - "github.com/smartcontractkit/chainlink/v2/core/internal/cltest" + "github.com/smartcontractkit/chainlink/v2/core/internal/testutils" "github.com/smartcontractkit/chainlink/v2/core/internal/testutils/pgtest" "github.com/smartcontractkit/chainlink/v2/core/services/webhook" "github.com/smartcontractkit/chainlink/v2/core/sessions" diff --git a/core/services/webhook/external_initiator_manager.go b/core/services/webhook/external_initiator_manager.go index 0c035abde7a..2f9a176906d 100644 --- a/core/services/webhook/external_initiator_manager.go +++ b/core/services/webhook/external_initiator_manager.go @@ -11,12 +11,9 @@ import ( "github.com/lib/pq" "github.com/pkg/errors" - "github.com/jmoiron/sqlx" - + "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" "github.com/smartcontractkit/chainlink/v2/core/bridges" - "github.com/smartcontractkit/chainlink/v2/core/logger" "github.com/smartcontractkit/chainlink/v2/core/services/job" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" "github.com/smartcontractkit/chainlink/v2/core/static" "github.com/smartcontractkit/chainlink/v2/core/store/models" ) @@ -27,7 +24,7 @@ import ( type ExternalInitiatorManager interface { Notify(ctx context.Context, webhookSpecID int32) error DeleteJob(ctx context.Context, webhookSpecID int32) error - FindExternalInitiatorByName(name string) (bridges.ExternalInitiator, error) + FindExternalInitiatorByName(ctx context.Context, name string) (bridges.ExternalInitiator, error) } //go:generate mockery --quiet --name HTTPClient --output ./mocks/ --case=underscore @@ -36,25 +33,24 @@ type HTTPClient interface { } type externalInitiatorManager struct { - q pg.Q + ds sqlutil.DataSource httpclient HTTPClient } var _ ExternalInitiatorManager = (*externalInitiatorManager)(nil) // NewExternalInitiatorManager returns the concrete externalInitiatorManager -func NewExternalInitiatorManager(db *sqlx.DB, httpclient HTTPClient, lggr logger.Logger, cfg pg.QConfig) *externalInitiatorManager { - namedLogger := lggr.Named("ExternalInitiatorManager") +func NewExternalInitiatorManager(ds sqlutil.DataSource, httpclient HTTPClient) *externalInitiatorManager { return &externalInitiatorManager{ - q: pg.NewQ(db, namedLogger, cfg), + ds: ds, httpclient: httpclient, } } // Notify sends a POST notification to the External Initiator // responsible for initiating the Job Spec. -func (m externalInitiatorManager) Notify(ctx context.Context, webhookSpecID int32) error { - eiWebhookSpecs, jobID, err := m.Load(webhookSpecID) +func (m *externalInitiatorManager) Notify(ctx context.Context, webhookSpecID int32) error { + eiWebhookSpecs, jobID, err := m.Load(ctx, webhookSpecID) if err != nil { return err } @@ -90,19 +86,21 @@ func (m externalInitiatorManager) Notify(ctx context.Context, webhookSpecID int3 return nil } -func (m externalInitiatorManager) Load(webhookSpecID int32) (eiWebhookSpecs []job.ExternalInitiatorWebhookSpec, jobID uuid.UUID, err error) { - err = m.q.Transaction(func(tx pg.Queryer) error { - if err = tx.Get(&jobID, "SELECT external_job_id FROM jobs WHERE webhook_spec_id = $1", webhookSpecID); err != nil { +func (m *externalInitiatorManager) Load(ctx context.Context, webhookSpecID int32) (eiWebhookSpecs []job.ExternalInitiatorWebhookSpec, jobID uuid.UUID, err error) { + err = sqlutil.Transact(ctx, func(ds sqlutil.DataSource) *externalInitiatorManager { + return NewExternalInitiatorManager(ds, m.httpclient) + }, m.ds, nil, func(tx *externalInitiatorManager) error { + if err = tx.ds.GetContext(ctx, &jobID, "SELECT external_job_id FROM jobs WHERE webhook_spec_id = $1", webhookSpecID); err != nil { if err = errors.Wrapf(err, "failed to load job ID from job for webhook spec with ID %d", webhookSpecID); err != nil { return err } } - if err = tx.Select(&eiWebhookSpecs, "SELECT * FROM external_initiator_webhook_specs WHERE external_initiator_webhook_specs.webhook_spec_id = $1", webhookSpecID); err != nil { + if err = tx.ds.SelectContext(ctx, &eiWebhookSpecs, "SELECT * FROM external_initiator_webhook_specs WHERE external_initiator_webhook_specs.webhook_spec_id = $1", webhookSpecID); err != nil { if err = errors.Wrapf(err, "failed to load external_initiator_webhook_specs for webhook_spec_id %d", webhookSpecID); err != nil { return err } } - if err = m.eagerLoadExternalInitiator(tx, eiWebhookSpecs); err != nil { + if err = tx.eagerLoadExternalInitiator(ctx, eiWebhookSpecs); err != nil { if err = errors.Wrapf(err, "failed to preload ExternalInitiator for webhook_spec_id %d", webhookSpecID); err != nil { return err } @@ -113,7 +111,7 @@ func (m externalInitiatorManager) Load(webhookSpecID int32) (eiWebhookSpecs []jo return } -func (m externalInitiatorManager) eagerLoadExternalInitiator(q pg.Queryer, txs []job.ExternalInitiatorWebhookSpec) error { +func (m *externalInitiatorManager) eagerLoadExternalInitiator(ctx context.Context, txs []job.ExternalInitiatorWebhookSpec) error { var ids []int64 for _, tx := range txs { ids = append(ids, tx.ExternalInitiatorID) @@ -122,7 +120,7 @@ func (m externalInitiatorManager) eagerLoadExternalInitiator(q pg.Queryer, txs [ return nil } var externalInitiators []bridges.ExternalInitiator - if err := sqlx.Select(q, &externalInitiators, `SELECT * FROM external_initiators WHERE external_initiators.id = ANY($1);`, pq.Array(ids)); err != nil { + if err := m.ds.SelectContext(ctx, &externalInitiators, `SELECT * FROM external_initiators WHERE external_initiators.id = ANY($1);`, pq.Array(ids)); err != nil { return err } @@ -137,8 +135,8 @@ func (m externalInitiatorManager) eagerLoadExternalInitiator(q pg.Queryer, txs [ return nil } -func (m externalInitiatorManager) DeleteJob(ctx context.Context, webhookSpecID int32) error { - eiWebhookSpecs, jobID, err := m.Load(webhookSpecID) +func (m *externalInitiatorManager) DeleteJob(ctx context.Context, webhookSpecID int32) error { + eiWebhookSpecs, jobID, err := m.Load(ctx, webhookSpecID) if err != nil { return err } @@ -166,9 +164,9 @@ func (m externalInitiatorManager) DeleteJob(ctx context.Context, webhookSpecID i return nil } -func (m externalInitiatorManager) FindExternalInitiatorByName(name string) (bridges.ExternalInitiator, error) { +func (m *externalInitiatorManager) FindExternalInitiatorByName(ctx context.Context, name string) (bridges.ExternalInitiator, error) { var exi bridges.ExternalInitiator - err := m.q.Get(&exi, "SELECT * FROM external_initiators WHERE lower(external_initiators.name) = lower($1)", name) + err := m.ds.GetContext(ctx, &exi, "SELECT * FROM external_initiators WHERE lower(external_initiators.name) = lower($1)", name) return exi, err } @@ -211,6 +209,6 @@ var _ ExternalInitiatorManager = (*NullExternalInitiatorManager)(nil) func (NullExternalInitiatorManager) Notify(context.Context, int32) error { return nil } func (NullExternalInitiatorManager) DeleteJob(context.Context, int32) error { return nil } -func (NullExternalInitiatorManager) FindExternalInitiatorByName(name string) (bridges.ExternalInitiator, error) { +func (NullExternalInitiatorManager) FindExternalInitiatorByName(ctx context.Context, name string) (bridges.ExternalInitiator, error) { return bridges.ExternalInitiator{}, nil } diff --git a/core/services/webhook/external_initiator_manager_test.go b/core/services/webhook/external_initiator_manager_test.go index 22ab50513cf..a2402b4114d 100644 --- a/core/services/webhook/external_initiator_manager_test.go +++ b/core/services/webhook/external_initiator_manager_test.go @@ -17,15 +17,14 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/internal/cltest" "github.com/smartcontractkit/chainlink/v2/core/internal/testutils" "github.com/smartcontractkit/chainlink/v2/core/internal/testutils/pgtest" - "github.com/smartcontractkit/chainlink/v2/core/logger" _ "github.com/smartcontractkit/chainlink/v2/core/services/pg" "github.com/smartcontractkit/chainlink/v2/core/services/webhook" webhookmocks "github.com/smartcontractkit/chainlink/v2/core/services/webhook/mocks" ) func Test_ExternalInitiatorManager_Load(t *testing.T) { + ctx := testutils.Context(t) db := pgtest.NewSqlxDB(t) - cfg := pgtest.NewQConfig(true) borm := bridges.NewORM(db) eiFoo := cltest.MustInsertExternalInitiator(t, borm) @@ -39,21 +38,21 @@ func Test_ExternalInitiatorManager_Load(t *testing.T) { pgtest.MustExec(t, db, `INSERT INTO external_initiator_webhook_specs (external_initiator_id, webhook_spec_id, spec) VALUES ($1,$2,$3)`, eiBar.ID, webhookSpecTwoEIs.ID, `{"ei": "bar", "name": "webhookSpecTwoEIs"}`) pgtest.MustExec(t, db, `INSERT INTO external_initiator_webhook_specs (external_initiator_id, webhook_spec_id, spec) VALUES ($1,$2,$3)`, eiFoo.ID, webhookSpecOneEI.ID, `{"ei": "foo", "name": "webhookSpecOneEI"}`) - eim := webhook.NewExternalInitiatorManager(db, nil, logger.TestLogger(t), cfg) + eim := webhook.NewExternalInitiatorManager(db, nil) - eiWebhookSpecs, jobID, err := eim.Load(webhookSpecNoEIs.ID) + eiWebhookSpecs, jobID, err := eim.Load(ctx, webhookSpecNoEIs.ID) require.NoError(t, err) assert.Len(t, eiWebhookSpecs, 0) assert.Equal(t, jb3.ExternalJobID, jobID) - eiWebhookSpecs, jobID, err = eim.Load(webhookSpecOneEI.ID) + eiWebhookSpecs, jobID, err = eim.Load(ctx, webhookSpecOneEI.ID) require.NoError(t, err) assert.Len(t, eiWebhookSpecs, 1) assert.Equal(t, `{"ei": "foo", "name": "webhookSpecOneEI"}`, eiWebhookSpecs[0].Spec.Raw) assert.Equal(t, eiFoo.ID, eiWebhookSpecs[0].ExternalInitiator.ID) assert.Equal(t, jb1.ExternalJobID, jobID) - eiWebhookSpecs, jobID, err = eim.Load(webhookSpecTwoEIs.ID) + eiWebhookSpecs, jobID, err = eim.Load(ctx, webhookSpecTwoEIs.ID) require.NoError(t, err) assert.Len(t, eiWebhookSpecs, 2) assert.Equal(t, jb2.ExternalJobID, jobID) @@ -62,7 +61,6 @@ func Test_ExternalInitiatorManager_Load(t *testing.T) { func Test_ExternalInitiatorManager_Notify(t *testing.T) { ctx := tests.Context(t) db := pgtest.NewSqlxDB(t) - cfg := pgtest.NewQConfig(true) borm := bridges.NewORM(db) eiWithURL := cltest.MustInsertExternalInitiatorWithOpts(t, borm, cltest.ExternalInitiatorOpts{ @@ -79,7 +77,7 @@ func Test_ExternalInitiatorManager_Notify(t *testing.T) { pgtest.MustExec(t, db, `INSERT INTO external_initiator_webhook_specs (external_initiator_id, webhook_spec_id, spec) VALUES ($1,$2,$3)`, eiNoURL.ID, webhookSpecTwoEIs.ID, `{"ei": "bar", "name": "webhookSpecTwoEIs"}`) client := webhookmocks.NewHTTPClient(t) - eim := webhook.NewExternalInitiatorManager(db, client, logger.TestLogger(t), cfg) + eim := webhook.NewExternalInitiatorManager(db, client) // Does nothing with no EI require.NoError(t, eim.Notify(ctx, webhookSpecNoEIs.ID)) @@ -102,7 +100,6 @@ func Test_ExternalInitiatorManager_Notify(t *testing.T) { func Test_ExternalInitiatorManager_DeleteJob(t *testing.T) { ctx := testutils.Context(t) db := pgtest.NewSqlxDB(t) - cfg := pgtest.NewQConfig(true) borm := bridges.NewORM(db) eiWithURL := cltest.MustInsertExternalInitiatorWithOpts(t, borm, cltest.ExternalInitiatorOpts{ @@ -119,7 +116,7 @@ func Test_ExternalInitiatorManager_DeleteJob(t *testing.T) { pgtest.MustExec(t, db, `INSERT INTO external_initiator_webhook_specs (external_initiator_id, webhook_spec_id, spec) VALUES ($1,$2,$3)`, eiNoURL.ID, webhookSpecTwoEIs.ID, `{"ei": "bar", "name": "webhookSpecTwoEIs"}`) client := webhookmocks.NewHTTPClient(t) - eim := webhook.NewExternalInitiatorManager(db, client, logger.TestLogger(t), cfg) + eim := webhook.NewExternalInitiatorManager(db, client) // Does nothing with no EI require.NoError(t, eim.DeleteJob(ctx, webhookSpecNoEIs.ID)) diff --git a/core/services/webhook/mocks/external_initiator_manager.go b/core/services/webhook/mocks/external_initiator_manager.go index 9711ae686ea..7a3ee29f62f 100644 --- a/core/services/webhook/mocks/external_initiator_manager.go +++ b/core/services/webhook/mocks/external_initiator_manager.go @@ -33,9 +33,9 @@ func (_m *ExternalInitiatorManager) DeleteJob(ctx context.Context, webhookSpecID return r0 } -// FindExternalInitiatorByName provides a mock function with given fields: name -func (_m *ExternalInitiatorManager) FindExternalInitiatorByName(name string) (bridges.ExternalInitiator, error) { - ret := _m.Called(name) +// FindExternalInitiatorByName provides a mock function with given fields: ctx, name +func (_m *ExternalInitiatorManager) FindExternalInitiatorByName(ctx context.Context, name string) (bridges.ExternalInitiator, error) { + ret := _m.Called(ctx, name) if len(ret) == 0 { panic("no return value specified for FindExternalInitiatorByName") @@ -43,17 +43,17 @@ func (_m *ExternalInitiatorManager) FindExternalInitiatorByName(name string) (br var r0 bridges.ExternalInitiator var r1 error - if rf, ok := ret.Get(0).(func(string) (bridges.ExternalInitiator, error)); ok { - return rf(name) + if rf, ok := ret.Get(0).(func(context.Context, string) (bridges.ExternalInitiator, error)); ok { + return rf(ctx, name) } - if rf, ok := ret.Get(0).(func(string) bridges.ExternalInitiator); ok { - r0 = rf(name) + if rf, ok := ret.Get(0).(func(context.Context, string) bridges.ExternalInitiator); ok { + r0 = rf(ctx, name) } else { r0 = ret.Get(0).(bridges.ExternalInitiator) } - if rf, ok := ret.Get(1).(func(string) error); ok { - r1 = rf(name) + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, name) } else { r1 = ret.Error(1) } diff --git a/core/services/webhook/validate.go b/core/services/webhook/validate.go index 6066a863ff7..a6d6aa72e83 100644 --- a/core/services/webhook/validate.go +++ b/core/services/webhook/validate.go @@ -1,6 +1,8 @@ package webhook import ( + "context" + "github.com/pelletier/go-toml" "github.com/pkg/errors" "go.uber.org/multierr" @@ -18,7 +20,7 @@ type TOMLWebhookSpec struct { ExternalInitiators []TOMLWebhookSpecExternalInitiator `toml:"externalInitiators"` } -func ValidatedWebhookSpec(tomlString string, externalInitiatorManager ExternalInitiatorManager) (jb job.Job, err error) { +func ValidatedWebhookSpec(ctx context.Context, tomlString string, externalInitiatorManager ExternalInitiatorManager) (jb job.Job, err error) { var tree *toml.Tree tree, err = toml.Load(tomlString) if err != nil { @@ -40,7 +42,7 @@ func ValidatedWebhookSpec(tomlString string, externalInitiatorManager ExternalIn var externalInitiatorWebhookSpecs []job.ExternalInitiatorWebhookSpec for _, eiSpec := range tomlSpec.ExternalInitiators { - ei, findErr := externalInitiatorManager.FindExternalInitiatorByName(eiSpec.Name) + ei, findErr := externalInitiatorManager.FindExternalInitiatorByName(ctx, eiSpec.Name) if findErr != nil { err = multierr.Combine(err, errors.Wrapf(findErr, "unable to find external initiator named %s", eiSpec.Name)) continue diff --git a/core/services/webhook/validate_test.go b/core/services/webhook/validate_test.go index f6993f6aefe..8d398fc0b68 100644 --- a/core/services/webhook/validate_test.go +++ b/core/services/webhook/validate_test.go @@ -6,9 +6,11 @@ import ( "github.com/manyminds/api2go/jsonapi" "github.com/pkg/errors" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "github.com/smartcontractkit/chainlink/v2/core/bridges" + "github.com/smartcontractkit/chainlink/v2/core/internal/testutils" "github.com/smartcontractkit/chainlink/v2/core/services/job" "github.com/smartcontractkit/chainlink/v2/core/services/webhook" webhookmocks "github.com/smartcontractkit/chainlink/v2/core/services/webhook/mocks" @@ -96,8 +98,8 @@ func TestValidatedWebJobSpec(t *testing.T) { """ `, mock: func(t *testing.T, eim *webhookmocks.ExternalInitiatorManager) { - eim.On("FindExternalInitiatorByName", "foo").Return(bridges.ExternalInitiator{ID: 42}, nil).Once() - eim.On("FindExternalInitiatorByName", "bar").Return(bridges.ExternalInitiator{ID: 43}, nil).Once() + eim.On("FindExternalInitiatorByName", mock.Anything, "foo").Return(bridges.ExternalInitiator{ID: 42}, nil).Once() + eim.On("FindExternalInitiatorByName", mock.Anything, "bar").Return(bridges.ExternalInitiator{ID: 43}, nil).Once() }, assertion: func(t *testing.T, s job.Job, err error) { require.NoError(t, err) @@ -134,9 +136,9 @@ func TestValidatedWebJobSpec(t *testing.T) { """ `, mock: func(t *testing.T, eim *webhookmocks.ExternalInitiatorManager) { - eim.On("FindExternalInitiatorByName", "foo").Return(bridges.ExternalInitiator{ID: 42}, nil).Once() - eim.On("FindExternalInitiatorByName", "bar").Return(bridges.ExternalInitiator{}, errors.New("something exploded")).Once() - eim.On("FindExternalInitiatorByName", "baz").Return(bridges.ExternalInitiator{}, errors.New("something exploded")).Once() + eim.On("FindExternalInitiatorByName", mock.Anything, "foo").Return(bridges.ExternalInitiator{ID: 42}, nil).Once() + eim.On("FindExternalInitiatorByName", mock.Anything, "bar").Return(bridges.ExternalInitiator{}, errors.New("something exploded")).Once() + eim.On("FindExternalInitiatorByName", mock.Anything, "baz").Return(bridges.ExternalInitiator{}, errors.New("something exploded")).Once() }, assertion: func(t *testing.T, s job.Job, err error) { require.EqualError(t, err, "unable to find external initiator named bar: something exploded; unable to find external initiator named baz: something exploded") @@ -147,11 +149,12 @@ func TestValidatedWebJobSpec(t *testing.T) { tc := tc t.Run(tc.name, func(t *testing.T) { t.Parallel() + ctx := testutils.Context(t) eim := new(webhookmocks.ExternalInitiatorManager) if tc.mock != nil { tc.mock(t, eim) } - s, err := webhook.ValidatedWebhookSpec(tc.toml, eim) + s, err := webhook.ValidatedWebhookSpec(ctx, tc.toml, eim) tc.assertion(t, s, err) }) } diff --git a/core/sessions/ldapauth/sync.go b/core/sessions/ldapauth/sync.go index 74c606a9684..a6e0366e21d 100644 --- a/core/sessions/ldapauth/sync.go +++ b/core/sessions/ldapauth/sync.go @@ -1,23 +1,23 @@ package ldapauth import ( + "context" "errors" "fmt" "time" "github.com/go-ldap/ldap/v3" - "github.com/jmoiron/sqlx" "github.com/lib/pq" + "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" "github.com/smartcontractkit/chainlink-common/pkg/utils" "github.com/smartcontractkit/chainlink/v2/core/config" "github.com/smartcontractkit/chainlink/v2/core/logger" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" "github.com/smartcontractkit/chainlink/v2/core/sessions" ) type LDAPServerStateSyncer struct { - q pg.Q + ds sqlutil.DataSource ldapClient LDAPClient config config.LDAP lggr logger.Logger @@ -26,14 +26,13 @@ type LDAPServerStateSyncer struct { // NewLDAPServerStateSync creates a reaper that cleans stale sessions from the store. func NewLDAPServerStateSync( - db *sqlx.DB, - pgCfg pg.QConfig, + ds sqlutil.DataSource, config config.LDAP, lggr logger.Logger, ) *utils.SleeperTask { namedLogger := lggr.Named("LDAPServerStateSync") serverSync := LDAPServerStateSyncer{ - q: pg.NewQ(db, namedLogger, pgCfg), + ds: ds, ldapClient: newLDAPClient(config), config: config, lggr: namedLogger, @@ -65,14 +64,15 @@ func (ldSync *LDAPServerStateSyncer) StartWorkOnTimer() { } func (ldSync *LDAPServerStateSyncer) Work() { + ctx := context.Background() // TODO https://smartcontract-it.atlassian.net/browse/BCF-2887 // Purge expired ldap_sessions and ldap_user_api_tokens recordCreationStaleThreshold := ldSync.config.SessionTimeout().Before(time.Now()) - err := ldSync.deleteStaleSessions(recordCreationStaleThreshold) + err := ldSync.deleteStaleSessions(ctx, recordCreationStaleThreshold) if err != nil { ldSync.lggr.Error("unable to expire local LDAP sessions: ", err) } recordCreationStaleThreshold = ldSync.config.UserAPITokenDuration().Before(time.Now()) - err = ldSync.deleteStaleAPITokens(recordCreationStaleThreshold) + err = ldSync.deleteStaleAPITokens(ctx, recordCreationStaleThreshold) if err != nil { ldSync.lggr.Error("unable to expire user API tokens: ", err) } @@ -160,18 +160,18 @@ func (ldSync *LDAPServerStateSyncer) Work() { // upstreamUserStateMap is now the most up to date source of truth // Now sync database sessions and roles with new data - err = ldSync.q.Transaction(func(tx pg.Queryer) error { + err = sqlutil.TransactDataSource(ctx, ldSync.ds, nil, func(tx sqlutil.DataSource) error { // First, purge users present in the local ldap_sessions table but not in the upstream server type LDAPSession struct { UserEmail string UserRole sessions.UserRole } var existingSessions []LDAPSession - if err = tx.Select(&existingSessions, "SELECT user_email, user_role FROM ldap_sessions WHERE localauth_user = false"); err != nil { + if err = tx.SelectContext(ctx, &existingSessions, "SELECT user_email, user_role FROM ldap_sessions WHERE localauth_user = false"); err != nil { return fmt.Errorf("unable to query ldap_sessions table: %w", err) } var existingAPITokens []LDAPSession - if err = tx.Select(&existingAPITokens, "SELECT user_email, user_role FROM ldap_user_api_tokens WHERE localauth_user = false"); err != nil { + if err = tx.SelectContext(ctx, &existingAPITokens, "SELECT user_email, user_role FROM ldap_user_api_tokens WHERE localauth_user = false"); err != nil { return fmt.Errorf("unable to query ldap_user_api_tokens table: %w", err) } @@ -202,7 +202,7 @@ func (ldSync *LDAPServerStateSyncer) Work() { // Remove any active sessions this user may have if len(emailsToPurge) > 0 { - _, err = ldSync.q.Exec("DELETE FROM ldap_sessions WHERE user_email = ANY($1)", pq.Array(emailsToPurge)) + _, err = tx.ExecContext(ctx, "DELETE FROM ldap_sessions WHERE user_email = ANY($1)", pq.Array(emailsToPurge)) if err != nil { return err } @@ -210,7 +210,7 @@ func (ldSync *LDAPServerStateSyncer) Work() { // Remove any active API tokens this user may have if len(apiTokenEmailsToPurge) > 0 { - _, err = ldSync.q.Exec("DELETE FROM ldap_user_api_tokens WHERE user_email = ANY($1)", pq.Array(apiTokenEmailsToPurge)) + _, err = tx.ExecContext(ctx, "DELETE FROM ldap_user_api_tokens WHERE user_email = ANY($1)", pq.Array(apiTokenEmailsToPurge)) if err != nil { return err } @@ -235,14 +235,14 @@ func (ldSync *LDAPServerStateSyncer) Work() { if len(emailValues) != 0 { // Set new role state for all rows in single Exec query := fmt.Sprintf("UPDATE ldap_sessions SET user_role = CASE %s ELSE user_role END", queryWhenClause) - _, err = ldSync.q.Exec(query, emailValues...) + _, err = tx.ExecContext(ctx, query, emailValues...) if err != nil { return err } // Update role of API tokens as well query = fmt.Sprintf("UPDATE ldap_user_api_tokens SET user_role = CASE %s ELSE user_role END", queryWhenClause) - _, err = ldSync.q.Exec(query, emailValues...) + _, err = tx.ExecContext(ctx, query, emailValues...) if err != nil { return err } @@ -258,14 +258,14 @@ func (ldSync *LDAPServerStateSyncer) Work() { } // deleteStaleSessions deletes all ldap_sessions before the passed time. -func (ldSync *LDAPServerStateSyncer) deleteStaleSessions(before time.Time) error { - _, err := ldSync.q.Exec("DELETE FROM ldap_sessions WHERE created_at < $1", before) +func (ldSync *LDAPServerStateSyncer) deleteStaleSessions(ctx context.Context, before time.Time) error { + _, err := ldSync.ds.ExecContext(ctx, "DELETE FROM ldap_sessions WHERE created_at < $1", before) return err } // deleteStaleAPITokens deletes all ldap_user_api_tokens before the passed time. -func (ldSync *LDAPServerStateSyncer) deleteStaleAPITokens(before time.Time) error { - _, err := ldSync.q.Exec("DELETE FROM ldap_user_api_tokens WHERE created_at < $1", before) +func (ldSync *LDAPServerStateSyncer) deleteStaleAPITokens(ctx context.Context, before time.Time) error { + _, err := ldSync.ds.ExecContext(ctx, "DELETE FROM ldap_user_api_tokens WHERE created_at < $1", before) return err } diff --git a/core/sessions/localauth/reaper.go b/core/sessions/localauth/reaper.go index eef884367aa..7b91e4ce2c0 100644 --- a/core/sessions/localauth/reaper.go +++ b/core/sessions/localauth/reaper.go @@ -1,16 +1,17 @@ package localauth import ( - "database/sql" + "context" "time" commonconfig "github.com/smartcontractkit/chainlink-common/pkg/config" + "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" "github.com/smartcontractkit/chainlink-common/pkg/utils" "github.com/smartcontractkit/chainlink/v2/core/logger" ) type sessionReaper struct { - db *sql.DB + ds sqlutil.DataSource config SessionReaperConfig lggr logger.Logger } @@ -21,9 +22,9 @@ type SessionReaperConfig interface { } // NewSessionReaper creates a reaper that cleans stale sessions from the store. -func NewSessionReaper(db *sql.DB, config SessionReaperConfig, lggr logger.Logger) *utils.SleeperTask { +func NewSessionReaper(ds sqlutil.DataSource, config SessionReaperConfig, lggr logger.Logger) *utils.SleeperTask { return utils.NewSleeperTask(&sessionReaper{ - db, + ds, config, lggr.Named("SessionReaper"), }) @@ -34,16 +35,17 @@ func (sr *sessionReaper) Name() string { } func (sr *sessionReaper) Work() { + ctx := context.Background() //TODO https://smartcontract-it.atlassian.net/browse/BCF-2887 recordCreationStaleThreshold := sr.config.SessionReaperExpiration().Before( sr.config.SessionTimeout().Before(time.Now())) - err := sr.deleteStaleSessions(recordCreationStaleThreshold) + err := sr.deleteStaleSessions(ctx, recordCreationStaleThreshold) if err != nil { sr.lggr.Error("unable to reap stale sessions: ", err) } } // DeleteStaleSessions deletes all sessions before the passed time. -func (sr *sessionReaper) deleteStaleSessions(before time.Time) error { - _, err := sr.db.Exec("DELETE FROM sessions WHERE last_used < $1", before) +func (sr *sessionReaper) deleteStaleSessions(ctx context.Context, before time.Time) error { + _, err := sr.ds.ExecContext(ctx, "DELETE FROM sessions WHERE last_used < $1", before) return err } diff --git a/core/sessions/localauth/reaper_test.go b/core/sessions/localauth/reaper_test.go index 47413c5fc62..806c9448682 100644 --- a/core/sessions/localauth/reaper_test.go +++ b/core/sessions/localauth/reaper_test.go @@ -37,7 +37,7 @@ func TestSessionReaper_ReapSessions(t *testing.T) { lggr := logger.TestLogger(t) orm := localauth.NewORM(db, config.SessionTimeout().Duration(), lggr, audit.NoopLogger) - r := localauth.NewSessionReaper(db.DB, config, lggr) + r := localauth.NewSessionReaper(db, config, lggr) t.Cleanup(func() { assert.NoError(t, r.Stop()) }) diff --git a/core/store/migrate/migrate.go b/core/store/migrate/migrate.go index aff3229e92a..c8f70e87383 100644 --- a/core/store/migrate/migrate.go +++ b/core/store/migrate/migrate.go @@ -9,13 +9,12 @@ import ( "strconv" "strings" - "github.com/jmoiron/sqlx" pkgerrors "github.com/pkg/errors" "github.com/pressly/goose/v3" "gopkg.in/guregu/null.v4" + "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" "github.com/smartcontractkit/chainlink/v2/core/config/env" - "github.com/smartcontractkit/chainlink/v2/core/logger" "github.com/smartcontractkit/chainlink/v2/core/services/chainlink" "github.com/smartcontractkit/chainlink/v2/core/services/pg" "github.com/smartcontractkit/chainlink/v2/core/store/migrate/migrations" // Invoke init() functions within migrations pkg. @@ -36,7 +35,7 @@ func init() { } // Ensure we migrated from v1 migrations to goose_migrations -func ensureMigrated(ctx context.Context, db *sql.DB, lggr logger.Logger) error { +func ensureMigrated(ctx context.Context, db *sql.DB) error { sqlxDB := pg.WrapDbWithSqlx(db) var names []string err := sqlxDB.SelectContext(ctx, &names, `SELECT id FROM migrations`) @@ -44,12 +43,10 @@ func ensureMigrated(ctx context.Context, db *sql.DB, lggr logger.Logger) error { // already migrated return nil } - err = pg.SqlTransaction(ctx, db, lggr, func(tx *sqlx.Tx) error { - // ensure that no legacy job specs are present: we _must_ bail out early if - // so because otherwise we run the risk of dropping working jobs if the - // user has not read the release notes - return migrations.CheckNoLegacyJobs(tx.Tx) - }) + // ensure that no legacy job specs are present: we _must_ bail out early if + // so because otherwise we run the risk of dropping working jobs if the + // user has not read the release notes + err = migrations.CheckNoLegacyJobs(ctx, db) if err != nil { return err } @@ -66,14 +63,14 @@ func ensureMigrated(ctx context.Context, db *sql.DB, lggr logger.Logger) error { } // ensure a goose migrations table exists with it's initial v0 - if _, err = goose.GetDBVersion(db); err != nil { + if _, err = goose.GetDBVersionContext(ctx, db); err != nil { return err } // insert records for existing migrations //nolint sql := fmt.Sprintf(`INSERT INTO %s (version_id, is_applied) VALUES ($1, true);`, goose.TableName()) - return pg.SqlTransaction(ctx, db, lggr, func(tx *sqlx.Tx) error { + return sqlutil.TransactDataSource(ctx, sqlxDB, nil, func(tx sqlutil.DataSource) error { for _, name := range names { var id int64 // the first migration doesn't follow the naming convention @@ -92,18 +89,18 @@ func ensureMigrated(ctx context.Context, db *sql.DB, lggr logger.Logger) error { } } - if _, err = db.Exec(sql, id); err != nil { + if _, err = tx.ExecContext(ctx, sql, id); err != nil { return err } } - _, err = db.Exec("DROP TABLE migrations;") + _, err = tx.ExecContext(ctx, "DROP TABLE migrations;") return err }) } -func Migrate(ctx context.Context, db *sql.DB, lggr logger.Logger) error { - if err := ensureMigrated(ctx, db, lggr); err != nil { +func Migrate(ctx context.Context, db *sql.DB) error { + if err := ensureMigrated(ctx, db); err != nil { return err } // WithAllowMissing is necessary when upgrading from 0.10.14 since it @@ -111,8 +108,8 @@ func Migrate(ctx context.Context, db *sql.DB, lggr logger.Logger) error { return goose.Up(db, MIGRATIONS_DIR, goose.WithAllowMissing()) } -func Rollback(ctx context.Context, db *sql.DB, lggr logger.Logger, version null.Int) error { - if err := ensureMigrated(ctx, db, lggr); err != nil { +func Rollback(ctx context.Context, db *sql.DB, version null.Int) error { + if err := ensureMigrated(ctx, db); err != nil { return err } if version.Valid { @@ -121,15 +118,15 @@ func Rollback(ctx context.Context, db *sql.DB, lggr logger.Logger, version null. return goose.Down(db, MIGRATIONS_DIR) } -func Current(ctx context.Context, db *sql.DB, lggr logger.Logger) (int64, error) { - if err := ensureMigrated(ctx, db, lggr); err != nil { +func Current(ctx context.Context, db *sql.DB) (int64, error) { + if err := ensureMigrated(ctx, db); err != nil { return -1, err } return goose.EnsureDBVersion(db) } -func Status(ctx context.Context, db *sql.DB, lggr logger.Logger) error { - if err := ensureMigrated(ctx, db, lggr); err != nil { +func Status(ctx context.Context, db *sql.DB) error { + if err := ensureMigrated(ctx, db); err != nil { return err } return goose.Status(db, MIGRATIONS_DIR) diff --git a/core/store/migrate/migrate_test.go b/core/store/migrate/migrate_test.go index 3c0c2dc2158..8169368eb1f 100644 --- a/core/store/migrate/migrate_test.go +++ b/core/store/migrate/migrate_test.go @@ -21,11 +21,9 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/internal/cltest/heavyweight" "github.com/smartcontractkit/chainlink/v2/core/internal/testutils" "github.com/smartcontractkit/chainlink/v2/core/internal/testutils/configtest" - "github.com/smartcontractkit/chainlink/v2/core/internal/testutils/pgtest" "github.com/smartcontractkit/chainlink/v2/core/logger" "github.com/smartcontractkit/chainlink/v2/core/services/chainlink" "github.com/smartcontractkit/chainlink/v2/core/services/job" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" "github.com/smartcontractkit/chainlink/v2/core/services/pipeline" "github.com/smartcontractkit/chainlink/v2/core/store/migrate" "github.com/smartcontractkit/chainlink/v2/core/store/models" @@ -79,13 +77,13 @@ func TestMigrate_0100_BootstrapConfigs(t *testing.T) { pipelineORM := pipeline.NewORM(db, lggr, cfg.JobPipeline().MaxSuccessfulRuns()) ctx := testutils.Context(t) - pipelineID, err := pipelineORM.CreateSpec(ctx, nil, pipeline.Pipeline{}, 0) + pipelineID, err := pipelineORM.CreateSpec(ctx, pipeline.Pipeline{}, 0) require.NoError(t, err) - pipelineID2, err := pipelineORM.CreateSpec(ctx, nil, pipeline.Pipeline{}, 0) + pipelineID2, err := pipelineORM.CreateSpec(ctx, pipeline.Pipeline{}, 0) require.NoError(t, err) - nonBootstrapPipelineID, err := pipelineORM.CreateSpec(ctx, nil, pipeline.Pipeline{}, 0) + nonBootstrapPipelineID, err := pipelineORM.CreateSpec(ctx, pipeline.Pipeline{}, 0) require.NoError(t, err) - newFormatBoostrapPipelineID2, err := pipelineORM.CreateSpec(ctx, nil, pipeline.Pipeline{}, 0) + newFormatBoostrapPipelineID2, err := pipelineORM.CreateSpec(ctx, pipeline.Pipeline{}, 0) require.NoError(t, err) // OCR2 struct at migration v0099 @@ -392,25 +390,24 @@ func TestMigrate_101_GenericOCR2(t *testing.T) { func TestMigrate(t *testing.T) { ctx := testutils.Context(t) - lggr := logger.TestLogger(t) _, db := heavyweight.FullTestDBEmptyV2(t, nil) err := goose.UpTo(db.DB, migrationDir, 100) require.NoError(t, err) - err = migrate.Status(ctx, db.DB, lggr) + err = migrate.Status(ctx, db.DB) require.NoError(t, err) - ver, err := migrate.Current(ctx, db.DB, lggr) + ver, err := migrate.Current(ctx, db.DB) require.NoError(t, err) require.Equal(t, int64(100), ver) - err = migrate.Migrate(ctx, db.DB, lggr) + err = migrate.Migrate(ctx, db.DB) require.NoError(t, err) - err = migrate.Rollback(ctx, db.DB, lggr, null.IntFrom(99)) + err = migrate.Rollback(ctx, db.DB, null.IntFrom(99)) require.NoError(t, err) - ver, err = migrate.Current(ctx, db.DB, lggr) + ver, err = migrate.Current(ctx, db.DB) require.NoError(t, err) require.Equal(t, int64(99), ver) } @@ -543,6 +540,7 @@ func TestNoTriggers(t *testing.T) { } func BenchmarkBackfillingRecordsWithMigration202(b *testing.B) { + ctx := testutils.Context(b) previousMigration := int64(201) backfillMigration := int64(202) chainCount := 2 @@ -555,7 +553,6 @@ func BenchmarkBackfillingRecordsWithMigration202(b *testing.B) { err := goose.UpTo(db.DB, migrationDir, previousMigration) require.NoError(b, err) - q := pg.NewQ(db, logger.NullLogger, pgtest.NewQConfig(true)) for j := 0; j < chainCount; j++ { // Insert 100_000 block to database, can't do all at once, so batching by 10k var blocks []logpoller.LogPollerBlock @@ -574,7 +571,7 @@ func BenchmarkBackfillingRecordsWithMigration202(b *testing.B) { end = maxLogsSize } - err = q.ExecQNamed(` + _, err = db.NamedExecContext(ctx, ` INSERT INTO evm.log_poller_blocks (evm_chain_id, block_hash, block_number, finalized_block_number, block_timestamp, created_at) VALUES @@ -600,7 +597,7 @@ func BenchmarkBackfillingRecordsWithMigration202(b *testing.B) { err = goose.DownTo(db.DB, migrationDir, previousMigration) require.NoError(b, err) - err = q.ExecQ(` + _, err = db.ExecContext(ctx, ` UPDATE evm.log_poller_blocks SET finalized_block_number = 0`) require.NoError(b, err) diff --git a/core/store/migrate/migrations/0054_remove_legacy_pipeline.go b/core/store/migrate/migrations/0054_remove_legacy_pipeline.go index 924d32308b8..b5ddcccd89b 100644 --- a/core/store/migrate/migrations/0054_remove_legacy_pipeline.go +++ b/core/store/migrate/migrations/0054_remove_legacy_pipeline.go @@ -32,9 +32,13 @@ func init() { goose.AddMigrationContext(Up54, Down54) } +type queryer interface { + QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row +} + // nolint func Up54(ctx context.Context, tx *sql.Tx) error { - if err := CheckNoLegacyJobs(tx); err != nil { + if err := CheckNoLegacyJobs(ctx, tx); err != nil { return err } if _, err := tx.ExecContext(ctx, up54); err != nil { @@ -49,9 +53,9 @@ func Down54(ctx context.Context, tx *sql.Tx) error { } // CheckNoLegacyJobs ensures that there are no legacy job specs -func CheckNoLegacyJobs(tx *sql.Tx) error { +func CheckNoLegacyJobs(ctx context.Context, ds queryer) error { var count int - if err := tx.QueryRow(`SELECT COUNT(*) FROM job_specs WHERE deleted_at IS NULL`).Scan(&count); err != nil { + if err := ds.QueryRowContext(ctx, `SELECT COUNT(*) FROM job_specs WHERE deleted_at IS NULL`).Scan(&count); err != nil { return err } if count > 0 { diff --git a/core/web/bridge_types_controller.go b/core/web/bridge_types_controller.go index 1c9fd7365f7..c787ca2fb54 100644 --- a/core/web/bridge_types_controller.go +++ b/core/web/bridge_types_controller.go @@ -213,7 +213,7 @@ func (btc *BridgeTypesController) Destroy(c *gin.Context) { jsonAPIError(c, http.StatusInternalServerError, fmt.Errorf("error searching for bridge: %+v", err)) return } - jobsUsingBridge, err := btc.App.JobORM().FindJobIDsWithBridge(name) + jobsUsingBridge, err := btc.App.JobORM().FindJobIDsWithBridge(ctx, name) if err != nil { jsonAPIError(c, http.StatusInternalServerError, fmt.Errorf("error searching for associated v2 jobs: %+v", err)) return diff --git a/core/web/eth_keys_controller_test.go b/core/web/eth_keys_controller_test.go index 09a3eb5fa1e..34cde6f6a64 100644 --- a/core/web/eth_keys_controller_test.go +++ b/core/web/eth_keys_controller_test.go @@ -425,8 +425,7 @@ func TestETHKeysController_ChainSuccess_ResetWithAbandon(t *testing.T) { }) assert.NoError(t, err) - db := app.GetSqlxDB() - txStore := txmgr.NewTxStore(db, logger.TestLogger(t)) + txStore := txmgr.NewTxStore(app.GetDB(), logger.TestLogger(t)) txes, err := txStore.FindTxesByFromAddressAndState(testutils.Context(t), addr, "fatal_error") require.NoError(t, err) diff --git a/core/web/evm_transactions_controller_test.go b/core/web/evm_transactions_controller_test.go index a4dd21c9f03..3eb667bc6f8 100644 --- a/core/web/evm_transactions_controller_test.go +++ b/core/web/evm_transactions_controller_test.go @@ -25,8 +25,8 @@ func TestTransactionsController_Index_Success(t *testing.T) { ctx := testutils.Context(t) require.NoError(t, app.Start(ctx)) - db := app.GetSqlxDB() - txStore := cltest.NewTestTxStore(t, app.GetSqlxDB()) + db := app.GetDB() + txStore := cltest.NewTestTxStore(t, app.GetDB()) ethKeyStore := cltest.NewKeyStore(t, db).Eth() client := app.NewHTTPClient(nil) _, from := cltest.MustInsertRandomKey(t, ethKeyStore) @@ -84,7 +84,7 @@ func TestTransactionsController_Show_Success(t *testing.T) { ctx := testutils.Context(t) require.NoError(t, app.Start(ctx)) - txStore := cltest.NewTestTxStore(t, app.GetSqlxDB()) + txStore := cltest.NewTestTxStore(t, app.GetDB()) client := app.NewHTTPClient(nil) _, from := cltest.MustInsertRandomKey(t, app.KeyStore.Eth()) @@ -118,7 +118,7 @@ func TestTransactionsController_Show_NotFound(t *testing.T) { ctx := testutils.Context(t) require.NoError(t, app.Start(ctx)) - txStore := cltest.NewTestTxStore(t, app.GetSqlxDB()) + txStore := cltest.NewTestTxStore(t, app.GetDB()) client := app.NewHTTPClient(nil) _, from := cltest.MustInsertRandomKey(t, app.KeyStore.Eth()) tx := cltest.MustInsertUnconfirmedEthTxWithBroadcastLegacyAttempt(t, txStore, 1, from) diff --git a/core/web/evm_transfer_controller_test.go b/core/web/evm_transfer_controller_test.go index bfac6752f51..fff30a7df31 100644 --- a/core/web/evm_transfer_controller_test.go +++ b/core/web/evm_transfer_controller_test.go @@ -10,9 +10,8 @@ import ( "testing" "time" - "github.com/jmoiron/sqlx" - commonconfig "github.com/smartcontractkit/chainlink-common/pkg/config" + "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/assets" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/txmgr" evmtypes "github.com/smartcontractkit/chainlink/v2/core/chains/evm/types" @@ -72,7 +71,7 @@ func TestTransfersController_CreateSuccess_From(t *testing.T) { assert.Equal(t, http.StatusOK, resp.StatusCode) assert.Len(t, errors.Errors, 0) - validateTxCount(t, app.GetSqlxDB(), 1) + validateTxCount(t, app.GetDB(), 1) } func TestTransfersController_CreateSuccess_From_WEI(t *testing.T) { @@ -113,7 +112,7 @@ func TestTransfersController_CreateSuccess_From_WEI(t *testing.T) { assert.Equal(t, http.StatusOK, resp.StatusCode) assert.Len(t, errors.Errors, 0) - validateTxCount(t, app.GetSqlxDB(), 1) + validateTxCount(t, app.GetDB(), 1) } func TestTransfersController_CreateSuccess_From_BalanceMonitorDisabled(t *testing.T) { @@ -159,7 +158,7 @@ func TestTransfersController_CreateSuccess_From_BalanceMonitorDisabled(t *testin assert.Equal(t, http.StatusOK, resp.StatusCode) assert.Len(t, errors.Errors, 0) - validateTxCount(t, app.GetSqlxDB(), 1) + validateTxCount(t, app.GetDB(), 1) } func TestTransfersController_TransferZeroAddressError(t *testing.T) { @@ -327,7 +326,7 @@ func TestTransfersController_CreateSuccess_eip1559(t *testing.T) { err = web.ParseJSONAPIResponse(cltest.ParseResponseBody(t, resp), &resource) assert.NoError(t, err) - validateTxCount(t, app.GetSqlxDB(), 1) + validateTxCount(t, app.GetDB(), 1) // check returned data assert.NotEmpty(t, resource.Hash) @@ -398,8 +397,8 @@ func TestTransfersController_FindTxAttempt(t *testing.T) { }) } -func validateTxCount(t *testing.T, db *sqlx.DB, count int) { - txStore := txmgr.NewTxStore(db, logger.TestLogger(t)) +func validateTxCount(t *testing.T, ds sqlutil.DataSource, count int) { + txStore := txmgr.NewTxStore(ds, logger.TestLogger(t)) txes, err := txStore.GetAllTxes(testutils.Context(t)) require.NoError(t, err) diff --git a/core/web/evm_tx_attempts_controller_test.go b/core/web/evm_tx_attempts_controller_test.go index a92c8293a3f..f277f1f37bf 100644 --- a/core/web/evm_tx_attempts_controller_test.go +++ b/core/web/evm_tx_attempts_controller_test.go @@ -20,7 +20,7 @@ func TestTxAttemptsController_Index_Success(t *testing.T) { app := cltest.NewApplicationWithKey(t) require.NoError(t, app.Start(testutils.Context(t))) - txStore := cltest.NewTestTxStore(t, app.GetSqlxDB()) + txStore := cltest.NewTestTxStore(t, app.GetDB()) client := app.NewHTTPClient(nil) _, from := cltest.MustInsertRandomKey(t, app.KeyStore.Eth()) diff --git a/core/web/external_initiators_controller_test.go b/core/web/external_initiators_controller_test.go index 3c1425bcfdd..a79909c5864 100644 --- a/core/web/external_initiators_controller_test.go +++ b/core/web/external_initiators_controller_test.go @@ -75,7 +75,7 @@ func TestExternalInitiatorsController_Index(t *testing.T) { client := app.NewHTTPClient(nil) - db := app.GetSqlxDB() + db := app.GetDB() borm := bridges.NewORM(db) eiFoo := cltest.MustInsertExternalInitiatorWithOpts(t, borm, cltest.ExternalInitiatorOpts{ diff --git a/core/web/jobs_controller.go b/core/web/jobs_controller.go index 0808422cca7..5ca00476007 100644 --- a/core/web/jobs_controller.go +++ b/core/web/jobs_controller.go @@ -26,7 +26,6 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/services/ocr" "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/validate" "github.com/smartcontractkit/chainlink/v2/core/services/ocrbootstrap" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" "github.com/smartcontractkit/chainlink/v2/core/services/streams" "github.com/smartcontractkit/chainlink/v2/core/services/vrf/vrfcommon" "github.com/smartcontractkit/chainlink/v2/core/services/webhook" @@ -48,7 +47,7 @@ func (jc *JobsController) Index(c *gin.Context, size, page, offset int) { size = 1000 } - jobs, count, err := jc.App.JobORM().FindJobs(offset, size) + jobs, count, err := jc.App.JobORM().FindJobs(c.Request.Context(), offset, size) if err != nil { jsonAPIError(c, http.StatusInternalServerError, err) return @@ -66,14 +65,15 @@ func (jc *JobsController) Index(c *gin.Context, size, page, offset int) { // Example: // "GET /jobs/:ID" func (jc *JobsController) Show(c *gin.Context) { + ctx := c.Request.Context() var err error jobSpec := job.Job{} if externalJobID, pErr := uuid.Parse(c.Param("ID")); pErr == nil { // Find a job by external job ID - jobSpec, err = jc.App.JobORM().FindJobByExternalJobID(externalJobID, pg.WithParentCtx(c.Request.Context())) + jobSpec, err = jc.App.JobORM().FindJobByExternalJobID(ctx, externalJobID) } else if pErr = jobSpec.SetID(c.Param("ID")); pErr == nil { // Find a job by job ID - jobSpec, err = jc.App.JobORM().FindJobTx(c, jobSpec.ID) + jobSpec, err = jc.App.JobORM().FindJobTx(ctx, jobSpec.ID) } else { jsonAPIError(c, http.StatusUnprocessableEntity, pErr) return @@ -242,7 +242,7 @@ func (jc *JobsController) validateJobSpec(ctx context.Context, tomlString string case job.VRF: jb, err = vrfcommon.ValidatedVRFSpec(tomlString) case job.Webhook: - jb, err = webhook.ValidatedWebhookSpec(tomlString, jc.App.GetExternalInitiatorManager()) + jb, err = webhook.ValidatedWebhookSpec(ctx, tomlString, jc.App.GetExternalInitiatorManager()) case job.BlockhashStore: jb, err = blockhashstore.ValidatedSpec(tomlString) case job.BlockHeaderFeeder: diff --git a/core/web/jobs_controller_test.go b/core/web/jobs_controller_test.go index 9c7b529b6bf..bade8fe293b 100644 --- a/core/web/jobs_controller_test.go +++ b/core/web/jobs_controller_test.go @@ -24,8 +24,7 @@ import ( "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" - "github.com/jmoiron/sqlx" - + "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" "github.com/smartcontractkit/chainlink-common/pkg/utils" evmclimocks "github.com/smartcontractkit/chainlink/v2/core/chains/evm/client/mocks" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/types" @@ -37,7 +36,6 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/services/job" "github.com/smartcontractkit/chainlink/v2/core/services/keystore/keys/p2pkey" "github.com/smartcontractkit/chainlink/v2/core/services/keystore/keys/vrfkey" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" "github.com/smartcontractkit/chainlink/v2/core/testdata/testspecs" "github.com/smartcontractkit/chainlink/v2/core/utils/tomlutils" "github.com/smartcontractkit/chainlink/v2/core/web" @@ -129,7 +127,7 @@ func TestJobController_Create_DirectRequest_Fast(t *testing.T) { }(i) } wg.Wait() - cltest.AssertCount(t, app.GetSqlxDB(), "direct_request_specs", int64(n)) + cltest.AssertCount(t, app.GetDB(), "direct_request_specs", int64(n)) } func mustInt32FromString(t *testing.T, s string) int32 { @@ -141,7 +139,7 @@ func mustInt32FromString(t *testing.T, s string) int32 { func TestJobController_Create_HappyPath(t *testing.T) { ctx := testutils.Context(t) app, client := setupJobsControllerTests(t) - b1, b2 := setupBridges(t, app.GetSqlxDB(), app.GetConfig().Database()) + b1, b2 := setupBridges(t, app.GetDB()) require.NoError(t, app.KeyStore.OCR().Add(ctx, cltest.DefaultOCRKey)) var pks []vrfkey.KeyV2 var k []p2pkey.KeyV2 @@ -220,7 +218,8 @@ func TestJobController_Create_HappyPath(t *testing.T) { // services failed to start require.Contains(t, errs.Errors[0].Detail, "no contract code at given address") // but the job should still exist - jb, err := jorm.FindJobByExternalJobID(uuid.MustParse(nameAndExternalJobID)) + ctx := testutils.Context(t) + jb, err := jorm.FindJobByExternalJobID(ctx, uuid.MustParse(nameAndExternalJobID)) require.NoError(t, err) require.NotNil(t, jb.KeeperSpec) @@ -333,7 +332,8 @@ func TestJobController_Create_HappyPath(t *testing.T) { // services failed to start require.Contains(t, errs.Errors[0].Detail, "no contract code at given address") // but the job should still exist - jb, err := jorm.FindJobByExternalJobID(uuid.MustParse(nameAndExternalJobID)) + ctx := testutils.Context(t) + jb, err := jorm.FindJobByExternalJobID(ctx, uuid.MustParse(nameAndExternalJobID)) require.NoError(t, err) require.NotNil(t, jb.FluxMonitorSpec) @@ -481,8 +481,8 @@ func TestJobsController_Create_WebhookSpec(t *testing.T) { app := cltest.NewApplicationEVMDisabled(t) require.NoError(t, app.Start(testutils.Context(t))) - _, fetchBridge := cltest.MustCreateBridge(t, app.GetSqlxDB(), cltest.BridgeOpts{}) - _, submitBridge := cltest.MustCreateBridge(t, app.GetSqlxDB(), cltest.BridgeOpts{}) + _, fetchBridge := cltest.MustCreateBridge(t, app.GetDB(), cltest.BridgeOpts{}) + _, submitBridge := cltest.MustCreateBridge(t, app.GetDB(), cltest.BridgeOpts{}) client := app.NewHTTPClient(nil) @@ -620,10 +620,10 @@ func TestJobsController_Update_HappyPath(t *testing.T) { app := cltest.NewApplicationWithConfigAndKey(t, cfg, cltest.DefaultP2PKey) require.NoError(t, app.KeyStore.OCR().Add(ctx, cltest.DefaultOCRKey)) - require.NoError(t, app.Start(testutils.Context(t))) + require.NoError(t, app.Start(ctx)) - _, bridge := cltest.MustCreateBridge(t, app.GetSqlxDB(), cltest.BridgeOpts{}) - _, bridge2 := cltest.MustCreateBridge(t, app.GetSqlxDB(), cltest.BridgeOpts{}) + _, bridge := cltest.MustCreateBridge(t, app.GetDB(), cltest.BridgeOpts{}) + _, bridge2 := cltest.MustCreateBridge(t, app.GetDB(), cltest.BridgeOpts{}) client := app.NewHTTPClient(nil) @@ -639,16 +639,16 @@ func TestJobsController_Update_HappyPath(t *testing.T) { // BCF-2095 // disable fkey checks until the end of the test transaction require.NoError(t, utils.JustError( - app.GetSqlxDB().Exec(`SET CONSTRAINTS job_spec_errors_v2_job_id_fkey DEFERRED`))) + app.GetDB().ExecContext(ctx, `SET CONSTRAINTS job_spec_errors_v2_job_id_fkey DEFERRED`))) var ocrSpec job.OCROracleSpec err = toml.Unmarshal([]byte(ocrspec.Toml()), &ocrSpec) require.NoError(t, err) jb.OCROracleSpec = &ocrSpec jb.OCROracleSpec.TransmitterAddress = &app.Keys[0].EIP55Address - err = app.AddJobV2(testutils.Context(t), &jb) + err = app.AddJobV2(ctx, &jb) require.NoError(t, err) - dbJb, err := app.JobORM().FindJob(testutils.Context(t), jb.ID) + dbJb, err := app.JobORM().FindJob(ctx, jb.ID) require.NoError(t, err) require.Equal(t, dbJb.Name.String, ocrspec.Name) @@ -666,7 +666,7 @@ func TestJobsController_Update_HappyPath(t *testing.T) { response, cleanup := client.Put("/v2/jobs/"+fmt.Sprintf("%v", jb.ID), bytes.NewReader(body)) t.Cleanup(cleanup) - dbJb, err = app.JobORM().FindJob(testutils.Context(t), jb.ID) + dbJb, err = app.JobORM().FindJob(ctx, jb.ID) require.NoError(t, err) require.Equal(t, dbJb.Name.String, updatedSpec.Name) @@ -686,8 +686,8 @@ func TestJobsController_Update_NonExistentID(t *testing.T) { require.NoError(t, app.KeyStore.OCR().Add(ctx, cltest.DefaultOCRKey)) require.NoError(t, app.Start(ctx)) - _, bridge := cltest.MustCreateBridge(t, app.GetSqlxDB(), cltest.BridgeOpts{}) - _, bridge2 := cltest.MustCreateBridge(t, app.GetSqlxDB(), cltest.BridgeOpts{}) + _, bridge := cltest.MustCreateBridge(t, app.GetDB(), cltest.BridgeOpts{}) + _, bridge2 := cltest.MustCreateBridge(t, app.GetDB(), cltest.BridgeOpts{}) client := app.NewHTTPClient(nil) @@ -754,9 +754,9 @@ func runDirectRequestJobSpecAssertions(t *testing.T, ereJobSpecFromFile job.Job, assert.Contains(t, ereJobSpecFromServer.DirectRequestSpec.UpdatedAt.String(), "20") } -func setupBridges(t *testing.T, db *sqlx.DB, cfg pg.QConfig) (b1, b2 string) { - _, bridge := cltest.MustCreateBridge(t, db, cltest.BridgeOpts{}) - _, bridge2 := cltest.MustCreateBridge(t, db, cltest.BridgeOpts{}) +func setupBridges(t *testing.T, ds sqlutil.DataSource) (b1, b2 string) { + _, bridge := cltest.MustCreateBridge(t, ds, cltest.BridgeOpts{}) + _, bridge2 := cltest.MustCreateBridge(t, ds, cltest.BridgeOpts{}) return bridge.Name.String(), bridge2.Name.String() } @@ -800,8 +800,8 @@ func setupJobSpecsControllerTestsWithJobs(t *testing.T) (*cltest.TestApplication require.NoError(t, app.KeyStore.OCR().Add(ctx, cltest.DefaultOCRKey)) require.NoError(t, app.Start(ctx)) - _, bridge := cltest.MustCreateBridge(t, app.GetSqlxDB(), cltest.BridgeOpts{}) - _, bridge2 := cltest.MustCreateBridge(t, app.GetSqlxDB(), cltest.BridgeOpts{}) + _, bridge := cltest.MustCreateBridge(t, app.GetDB(), cltest.BridgeOpts{}) + _, bridge2 := cltest.MustCreateBridge(t, app.GetDB(), cltest.BridgeOpts{}) client := app.NewHTTPClient(nil) diff --git a/core/web/loader/feeds_manager.go b/core/web/loader/feeds_manager.go index fb894d38b6e..a29d510a09d 100644 --- a/core/web/loader/feeds_manager.go +++ b/core/web/loader/feeds_manager.go @@ -14,7 +14,7 @@ type feedsBatcher struct { app chainlink.Application } -func (b *feedsBatcher) loadByIDs(_ context.Context, keys dataloader.Keys) []*dataloader.Result { +func (b *feedsBatcher) loadByIDs(ctx context.Context, keys dataloader.Keys) []*dataloader.Result { // Create a map for remembering the order of keys passed in keyOrder := make(map[string]int, len(keys)) // Collect the keys to search for @@ -28,7 +28,7 @@ func (b *feedsBatcher) loadByIDs(_ context.Context, keys dataloader.Keys) []*dat } // Fetch the feeds managers - managers, err := b.app.GetFeedsService().ListManagersByIDs(managersIDs) + managers, err := b.app.GetFeedsService().ListManagersByIDs(ctx, managersIDs) if err != nil { return []*dataloader.Result{{Data: nil, Error: err}} } diff --git a/core/web/loader/feeds_manager_chain_config.go b/core/web/loader/feeds_manager_chain_config.go index 89d35919fd1..661edc33219 100644 --- a/core/web/loader/feeds_manager_chain_config.go +++ b/core/web/loader/feeds_manager_chain_config.go @@ -14,10 +14,10 @@ type feedsManagerChainConfigBatcher struct { app chainlink.Application } -func (b *feedsManagerChainConfigBatcher) loadByManagerIDs(_ context.Context, keys dataloader.Keys) []*dataloader.Result { +func (b *feedsManagerChainConfigBatcher) loadByManagerIDs(ctx context.Context, keys dataloader.Keys) []*dataloader.Result { ids, keyOrder := keyOrderInt64(keys) - cfgs, err := b.app.GetFeedsService().ListChainConfigsByManagerIDs(ids) + cfgs, err := b.app.GetFeedsService().ListChainConfigsByManagerIDs(ctx, ids) if err != nil { return []*dataloader.Result{{Data: nil, Error: err}} } diff --git a/core/web/loader/job.go b/core/web/loader/job.go index 11eb2c76814..a37c2809625 100644 --- a/core/web/loader/job.go +++ b/core/web/loader/job.go @@ -16,7 +16,7 @@ type jobBatcher struct { app chainlink.Application } -func (b *jobBatcher) loadByExternalJobIDs(_ context.Context, keys dataloader.Keys) []*dataloader.Result { +func (b *jobBatcher) loadByExternalJobIDs(ctx context.Context, keys dataloader.Keys) []*dataloader.Result { // Create a map for remembering the order of keys passed in keyOrder := make(map[string]int, len(keys)) // Collect the keys to search for @@ -33,7 +33,7 @@ func (b *jobBatcher) loadByExternalJobIDs(_ context.Context, keys dataloader.Key // Fetch the jobs var jobs []job.Job for _, id := range jobIDs { - job, err := b.app.JobORM().FindJobByExternalJobID(id) + job, err := b.app.JobORM().FindJobByExternalJobID(ctx, id) if err != nil { return []*dataloader.Result{{Data: nil, Error: err}} @@ -63,7 +63,7 @@ func (b *jobBatcher) loadByExternalJobIDs(_ context.Context, keys dataloader.Key return results } -func (b *jobBatcher) loadByPipelineSpecIDs(_ context.Context, keys dataloader.Keys) []*dataloader.Result { +func (b *jobBatcher) loadByPipelineSpecIDs(ctx context.Context, keys dataloader.Keys) []*dataloader.Result { // Create a map for remembering the order of keys passed in keyOrder := make(map[string]int, len(keys)) // Collect the keys to search for @@ -77,7 +77,7 @@ func (b *jobBatcher) loadByPipelineSpecIDs(_ context.Context, keys dataloader.Ke } // Fetch the jobs - jobs, err := b.app.JobORM().FindJobsByPipelineSpecIDs(plSpecIDs) + jobs, err := b.app.JobORM().FindJobsByPipelineSpecIDs(ctx, plSpecIDs) if err != nil { return []*dataloader.Result{{Data: nil, Error: err}} } diff --git a/core/web/loader/job_proposal.go b/core/web/loader/job_proposal.go index ac24e620f47..17fb6bc67c2 100644 --- a/core/web/loader/job_proposal.go +++ b/core/web/loader/job_proposal.go @@ -14,7 +14,7 @@ type jobProposalBatcher struct { app chainlink.Application } -func (b *jobProposalBatcher) loadByManagersIDs(_ context.Context, keys dataloader.Keys) []*dataloader.Result { +func (b *jobProposalBatcher) loadByManagersIDs(ctx context.Context, keys dataloader.Keys) []*dataloader.Result { // Create a map for remembering the order of keys passed in keyOrder := make(map[string]int, len(keys)) // Collect the keys to search for @@ -28,7 +28,7 @@ func (b *jobProposalBatcher) loadByManagersIDs(_ context.Context, keys dataloade keyOrder[key.String()] = ix } - jps, err := b.app.GetFeedsService().ListJobProposalsByManagersIDs(mgrsIDs) + jps, err := b.app.GetFeedsService().ListJobProposalsByManagersIDs(ctx, mgrsIDs) if err != nil { return []*dataloader.Result{{Data: nil, Error: err}} } diff --git a/core/web/loader/job_proposal_spec.go b/core/web/loader/job_proposal_spec.go index bb6720903dc..bff112268f8 100644 --- a/core/web/loader/job_proposal_spec.go +++ b/core/web/loader/job_proposal_spec.go @@ -14,10 +14,10 @@ type jobProposalSpecBatcher struct { app chainlink.Application } -func (b *jobProposalSpecBatcher) loadByJobProposalsIDs(_ context.Context, keys dataloader.Keys) []*dataloader.Result { +func (b *jobProposalSpecBatcher) loadByJobProposalsIDs(ctx context.Context, keys dataloader.Keys) []*dataloader.Result { ids, keyOrder := keyOrderInt64(keys) - specs, err := b.app.GetFeedsService().ListSpecsByJobProposalIDs(ids) + specs, err := b.app.GetFeedsService().ListSpecsByJobProposalIDs(ctx, ids) if err != nil { return []*dataloader.Result{{Data: nil, Error: err}} } diff --git a/core/web/loader/job_run.go b/core/web/loader/job_run.go index 691b27a3511..4367eec2c78 100644 --- a/core/web/loader/job_run.go +++ b/core/web/loader/job_run.go @@ -14,7 +14,7 @@ type jobRunBatcher struct { app chainlink.Application } -func (b *jobRunBatcher) loadByIDs(_ context.Context, keys dataloader.Keys) []*dataloader.Result { +func (b *jobRunBatcher) loadByIDs(ctx context.Context, keys dataloader.Keys) []*dataloader.Result { // Create a map for remembering the order of keys passed in keyOrder := make(map[string]int, len(keys)) // Collect the keys to search for @@ -29,7 +29,7 @@ func (b *jobRunBatcher) loadByIDs(_ context.Context, keys dataloader.Keys) []*da } // Fetch the runs - runs, err := b.app.JobORM().FindPipelineRunsByIDs(runIDs) + runs, err := b.app.JobORM().FindPipelineRunsByIDs(ctx, runIDs) if err != nil { return []*dataloader.Result{{Data: nil, Error: err}} } diff --git a/core/web/loader/job_spec_errors.go b/core/web/loader/job_spec_errors.go index 5ef23154d2e..5d558c52ec5 100644 --- a/core/web/loader/job_spec_errors.go +++ b/core/web/loader/job_spec_errors.go @@ -7,7 +7,6 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/services/chainlink" "github.com/smartcontractkit/chainlink/v2/core/services/job" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" "github.com/smartcontractkit/chainlink/v2/core/utils/stringutils" ) @@ -29,7 +28,7 @@ func (b *jobSpecErrorsBatcher) loadByJobIDs(ctx context.Context, keys dataloader keyOrder[key.String()] = ix } - specErrors, err := b.app.JobORM().FindSpecErrorsByJobIDs(jobIDs, pg.WithParentCtx(ctx)) + specErrors, err := b.app.JobORM().FindSpecErrorsByJobIDs(ctx, jobIDs) if err != nil { return []*dataloader.Result{{Data: nil, Error: err}} } diff --git a/core/web/loader/loader_test.go b/core/web/loader/loader_test.go index d12a10a9e52..5e22c9afef6 100644 --- a/core/web/loader/loader_test.go +++ b/core/web/loader/loader_test.go @@ -125,7 +125,7 @@ func TestLoader_FeedsManagers(t *testing.T) { Name: "manager 3", } - fsvc.On("ListManagersByIDs", []int64{3, 1, 2, 5}).Return([]feeds.FeedsManager{ + fsvc.On("ListManagersByIDs", mock.Anything, []int64{3, 1, 2, 5}).Return([]feeds.FeedsManager{ mgr1, mgr2, mgr3, }, nil) app.On("GetFeedsService").Return(fsvc) @@ -167,7 +167,7 @@ func TestLoader_JobProposals(t *testing.T) { Status: feeds.JobProposalStatusRejected, } - fsvc.On("ListJobProposalsByManagersIDs", []int64{3, 1, 2}).Return([]feeds.JobProposal{ + fsvc.On("ListJobProposalsByManagersIDs", mock.Anything, []int64{3, 1, 2}).Return([]feeds.JobProposal{ jp1, jp3, jp2, }, nil) app.On("GetFeedsService").Return(fsvc) @@ -194,7 +194,7 @@ func TestLoader_JobRuns(t *testing.T) { run2 := pipeline.Run{ID: int64(2)} run3 := pipeline.Run{ID: int64(3)} - jobsORM.On("FindPipelineRunsByIDs", []int64{3, 1, 2}).Return([]pipeline.Run{ + jobsORM.On("FindPipelineRunsByIDs", mock.Anything, []int64{3, 1, 2}).Return([]pipeline.Run{ run3, run1, run2, }, nil) app.On("JobORM").Return(jobsORM) @@ -224,7 +224,7 @@ func TestLoader_JobsByPipelineSpecIDs(t *testing.T) { job2 := job.Job{ID: int32(3), PipelineSpecID: int32(2)} job3 := job.Job{ID: int32(4), PipelineSpecID: int32(3)} - jobsORM.On("FindJobsByPipelineSpecIDs", []int32{3, 1, 2}).Return([]job.Job{ + jobsORM.On("FindJobsByPipelineSpecIDs", mock.Anything, []int32{3, 1, 2}).Return([]job.Job{ job1, job2, job3, }, nil) app.On("JobORM").Return(jobsORM) @@ -247,7 +247,7 @@ func TestLoader_JobsByPipelineSpecIDs(t *testing.T) { app := coremocks.NewApplication(t) ctx := InjectDataloader(testutils.Context(t), app) - jobsORM.On("FindJobsByPipelineSpecIDs", []int32{3, 1, 2}).Return([]job.Job{}, sql.ErrNoRows) + jobsORM.On("FindJobsByPipelineSpecIDs", mock.Anything, []int32{3, 1, 2}).Return([]job.Job{}, sql.ErrNoRows) app.On("JobORM").Return(jobsORM) batcher := jobBatcher{app} @@ -274,7 +274,7 @@ func TestLoader_JobsByExternalJobIDs(t *testing.T) { ejID := uuid.New() job := job.Job{ID: int32(2), ExternalJobID: ejID} - jobsORM.On("FindJobByExternalJobID", ejID).Return(job, nil) + jobsORM.On("FindJobByExternalJobID", mock.Anything, ejID).Return(job, nil) app.On("JobORM").Return(jobsORM) batcher := jobBatcher{app} @@ -335,7 +335,7 @@ func TestLoader_SpecErrorsByJobID(t *testing.T) { specErr2 := job.SpecError{ID: int64(3), JobID: int32(2)} specErr3 := job.SpecError{ID: int64(4), JobID: int32(3)} - jobsORM.On("FindSpecErrorsByJobIDs", []int32{3, 1, 2}, mock.Anything).Return([]job.SpecError{ + jobsORM.On("FindSpecErrorsByJobIDs", mock.Anything, []int32{3, 1, 2}, mock.Anything).Return([]job.SpecError{ specErr1, specErr2, specErr3, }, nil) app.On("JobORM").Return(jobsORM) @@ -358,7 +358,7 @@ func TestLoader_SpecErrorsByJobID(t *testing.T) { app := coremocks.NewApplication(t) ctx := InjectDataloader(testutils.Context(t), app) - jobsORM.On("FindSpecErrorsByJobIDs", []int32{3, 1, 2}, mock.Anything).Return([]job.SpecError{}, sql.ErrNoRows) + jobsORM.On("FindSpecErrorsByJobIDs", mock.Anything, []int32{3, 1, 2}, mock.Anything).Return([]job.SpecError{}, sql.ErrNoRows) app.On("JobORM").Return(jobsORM) batcher := jobSpecErrorsBatcher{app} diff --git a/core/web/pipeline_job_spec_errors_controller_test.go b/core/web/pipeline_job_spec_errors_controller_test.go index 8ec77a84f05..9729bfac51d 100644 --- a/core/web/pipeline_job_spec_errors_controller_test.go +++ b/core/web/pipeline_job_spec_errors_controller_test.go @@ -17,7 +17,8 @@ func TestPipelineJobSpecErrorsController_Delete_2(t *testing.T) { description := "job spec error description" - require.NoError(t, app.JobORM().RecordError(jID, description)) + ctx := testutils.Context(t) + require.NoError(t, app.JobORM().RecordError(ctx, jID, description)) // FindJob -> find error j, err := app.JobORM().FindJob(testutils.Context(t), jID) diff --git a/core/web/pipeline_runs_controller.go b/core/web/pipeline_runs_controller.go index 1bd52b021c3..099f824e0c3 100644 --- a/core/web/pipeline_runs_controller.go +++ b/core/web/pipeline_runs_controller.go @@ -40,8 +40,9 @@ func (prc *PipelineRunsController) Index(c *gin.Context, size, page, offset int) var count int var err error + ctx := c.Request.Context() if id == "" { - pipelineRuns, count, err = prc.App.JobORM().PipelineRuns(nil, offset, size) + pipelineRuns, count, err = prc.App.JobORM().PipelineRuns(ctx, nil, offset, size) } else { jobSpec := job.Job{} err = jobSpec.SetID(c.Param("ID")) @@ -50,7 +51,7 @@ func (prc *PipelineRunsController) Index(c *gin.Context, size, page, offset int) return } - pipelineRuns, count, err = prc.App.JobORM().PipelineRuns(&jobSpec.ID, offset, size) + pipelineRuns, count, err = prc.App.JobORM().PipelineRuns(ctx, &jobSpec.ID, offset, size) } if err != nil { @@ -113,13 +114,13 @@ func (prc *PipelineRunsController) Create(c *gin.Context) { // Is it a UUID? Then process it as a webhook job jobUUID, err := uuid.Parse(idStr) if err == nil { - canRun, err2 := authorizer.CanRun(c.Request.Context(), prc.App.GetConfig().JobPipeline(), jobUUID) + canRun, err2 := authorizer.CanRun(ctx, prc.App.GetConfig().JobPipeline(), jobUUID) if err2 != nil { jsonAPIError(c, http.StatusInternalServerError, err2) return } if canRun { - jobRunID, err3 := prc.App.RunWebhookJobV2(c.Request.Context(), jobUUID, string(bodyBytes), jsonserializable.JSONSerializable{}) + jobRunID, err3 := prc.App.RunWebhookJobV2(ctx, jobUUID, string(bodyBytes), jsonserializable.JSONSerializable{}) if errors.Is(err3, webhook.ErrJobNotExists) { jsonAPIError(c, http.StatusNotFound, err3) return @@ -141,7 +142,7 @@ func (prc *PipelineRunsController) Create(c *gin.Context) { jobID64, err := strconv.ParseInt(idStr, 10, 32) if err == nil { jobID = int32(jobID64) - jobRunID, err := prc.App.RunJobV2(c.Request.Context(), jobID, nil) + jobRunID, err := prc.App.RunJobV2(ctx, jobID, nil) if err != nil { jsonAPIError(c, http.StatusInternalServerError, err) return diff --git a/core/web/pipeline_runs_controller_test.go b/core/web/pipeline_runs_controller_test.go index f6b4291a34f..e123df2bdb3 100644 --- a/core/web/pipeline_runs_controller_test.go +++ b/core/web/pipeline_runs_controller_test.go @@ -33,6 +33,7 @@ import ( func TestPipelineRunsController_CreateWithBody_HappyPath(t *testing.T) { t.Parallel() + ctx := testutils.Context(t) ethClient := cltest.NewEthMocksWithStartupAssertions(t) cfg := configtest.NewGeneralConfig(t, func(c *chainlink.Config, s *chainlink.Secrets) { c.JobPipeline.HTTPRequest.DefaultTimeout = commonconfig.MustNewDuration(2 * time.Second) @@ -50,13 +51,13 @@ func TestPipelineRunsController_CreateWithBody_HappyPath(t *testing.T) { require.Equal(t, `{"result":"12345"}`, string(bs)) }) - _, bridge := cltest.MustCreateBridge(t, app.GetSqlxDB(), cltest.BridgeOpts{URL: mockServer.URL}) + _, bridge := cltest.MustCreateBridge(t, app.GetDB(), cltest.BridgeOpts{URL: mockServer.URL}) // Add the job uuid := uuid.New() { tomlStr := fmt.Sprintf(testspecs.WebhookSpecWithBodyTemplate, uuid, bridge.Name.String()) - jb, err := webhook.ValidatedWebhookSpec(tomlStr, app.GetExternalInitiatorManager()) + jb, err := webhook.ValidatedWebhookSpec(ctx, tomlStr, app.GetExternalInitiatorManager()) require.NoError(t, err) err = app.AddJobV2(testutils.Context(t), &jb) @@ -88,6 +89,7 @@ func TestPipelineRunsController_CreateWithBody_HappyPath(t *testing.T) { func TestPipelineRunsController_CreateNoBody_HappyPath(t *testing.T) { t.Parallel() + ctx := testutils.Context(t) ethClient := cltest.NewEthMocksWithStartupAssertions(t) cfg := configtest.NewGeneralConfig(t, func(c *chainlink.Config, s *chainlink.Secrets) { c.JobPipeline.HTTPRequest.DefaultTimeout = commonconfig.MustNewDuration(2 * time.Second) @@ -100,7 +102,7 @@ func TestPipelineRunsController_CreateNoBody_HappyPath(t *testing.T) { // Setup the bridges mockServer := cltest.NewHTTPMockServer(t, 200, "POST", `{"data":{"result":"123.45"}}`) - _, bridge := cltest.MustCreateBridge(t, app.GetSqlxDB(), cltest.BridgeOpts{URL: mockServer.URL}) + _, bridge := cltest.MustCreateBridge(t, app.GetDB(), cltest.BridgeOpts{URL: mockServer.URL}) mockServer = cltest.NewHTTPMockServerWithRequest(t, 200, `{}`, func(r *http.Request) { defer r.Body.Close() @@ -109,13 +111,13 @@ func TestPipelineRunsController_CreateNoBody_HappyPath(t *testing.T) { require.Equal(t, `{"result":"12345"}`, string(bs)) }) - _, submitBridge := cltest.MustCreateBridge(t, app.GetSqlxDB(), cltest.BridgeOpts{URL: mockServer.URL}) + _, submitBridge := cltest.MustCreateBridge(t, app.GetDB(), cltest.BridgeOpts{URL: mockServer.URL}) // Add the job uuid := uuid.New() { tomlStr := testspecs.GetWebhookSpecNoBody(uuid, bridge.Name.String(), submitBridge.Name.String()) - jb, err := webhook.ValidatedWebhookSpec(tomlStr, app.GetExternalInitiatorManager()) + jb, err := webhook.ValidatedWebhookSpec(ctx, tomlStr, app.GetExternalInitiatorManager()) require.NoError(t, err) err = app.AddJobV2(testutils.Context(t), &jb) diff --git a/core/web/resolver/bridge_test.go b/core/web/resolver/bridge_test.go index 706d3fb6d5a..2244ddf3dac 100644 --- a/core/web/resolver/bridge_test.go +++ b/core/web/resolver/bridge_test.go @@ -418,7 +418,7 @@ func Test_DeleteBridgeMutation(t *testing.T) { f.Mocks.bridgeORM.On("FindBridge", mock.Anything, name).Return(bridge, nil) f.Mocks.bridgeORM.On("DeleteBridgeType", mock.Anything, &bridge).Return(nil) - f.Mocks.jobORM.On("FindJobIDsWithBridge", name.String()).Return([]int32{}, nil) + f.Mocks.jobORM.On("FindJobIDsWithBridge", mock.Anything, name.String()).Return([]int32{}, nil) f.App.On("JobORM").Return(f.Mocks.jobORM) f.App.On("BridgeORM").Return(f.Mocks.bridgeORM) }, @@ -481,7 +481,7 @@ func Test_DeleteBridgeMutation(t *testing.T) { }, before: func(f *gqlTestFramework) { f.Mocks.bridgeORM.On("FindBridge", mock.Anything, name).Return(bridges.BridgeType{}, nil) - f.Mocks.jobORM.On("FindJobIDsWithBridge", name.String()).Return([]int32{1}, nil) + f.Mocks.jobORM.On("FindJobIDsWithBridge", mock.Anything, name.String()).Return([]int32{1}, nil) f.App.On("BridgeORM").Return(f.Mocks.bridgeORM) f.App.On("JobORM").Return(f.Mocks.jobORM) }, diff --git a/core/web/resolver/feeds_manager_chain_config_test.go b/core/web/resolver/feeds_manager_chain_config_test.go index ae869b50874..31208aa0581 100644 --- a/core/web/resolver/feeds_manager_chain_config_test.go +++ b/core/web/resolver/feeds_manager_chain_config_test.go @@ -101,7 +101,7 @@ func Test_CreateFeedsManagerChainConfig(t *testing.T) { }, }, }).Return(cfgID, nil) - f.Mocks.feedsSvc.On("GetChainConfig", cfgID).Return(&feeds.ChainConfig{ + f.Mocks.feedsSvc.On("GetChainConfig", mock.Anything, cfgID).Return(&feeds.ChainConfig{ ID: cfgID, ChainType: feeds.ChainTypeEVM, ChainID: chainID, @@ -164,7 +164,7 @@ func Test_CreateFeedsManagerChainConfig(t *testing.T) { before: func(f *gqlTestFramework) { f.App.On("GetFeedsService").Return(f.Mocks.feedsSvc) f.Mocks.feedsSvc.On("CreateChainConfig", mock.Anything, mock.IsType(feeds.ChainConfig{})).Return(cfgID, nil) - f.Mocks.feedsSvc.On("GetChainConfig", cfgID).Return(nil, sql.ErrNoRows) + f.Mocks.feedsSvc.On("GetChainConfig", mock.Anything, cfgID).Return(nil, sql.ErrNoRows) }, query: mutation, variables: variables, @@ -211,7 +211,7 @@ func Test_DeleteFeedsManagerChainConfig(t *testing.T) { authenticated: true, before: func(f *gqlTestFramework) { f.App.On("GetFeedsService").Return(f.Mocks.feedsSvc) - f.Mocks.feedsSvc.On("GetChainConfig", cfgID).Return(&feeds.ChainConfig{ + f.Mocks.feedsSvc.On("GetChainConfig", mock.Anything, cfgID).Return(&feeds.ChainConfig{ ID: cfgID, }, nil) f.Mocks.feedsSvc.On("DeleteChainConfig", mock.Anything, cfgID).Return(cfgID, nil) @@ -232,7 +232,7 @@ func Test_DeleteFeedsManagerChainConfig(t *testing.T) { authenticated: true, before: func(f *gqlTestFramework) { f.App.On("GetFeedsService").Return(f.Mocks.feedsSvc) - f.Mocks.feedsSvc.On("GetChainConfig", cfgID).Return(nil, sql.ErrNoRows) + f.Mocks.feedsSvc.On("GetChainConfig", mock.Anything, cfgID).Return(nil, sql.ErrNoRows) }, query: mutation, variables: variables, @@ -249,7 +249,7 @@ func Test_DeleteFeedsManagerChainConfig(t *testing.T) { authenticated: true, before: func(f *gqlTestFramework) { f.App.On("GetFeedsService").Return(f.Mocks.feedsSvc) - f.Mocks.feedsSvc.On("GetChainConfig", cfgID).Return(&feeds.ChainConfig{ + f.Mocks.feedsSvc.On("GetChainConfig", mock.Anything, cfgID).Return(&feeds.ChainConfig{ ID: cfgID, }, nil) f.Mocks.feedsSvc.On("DeleteChainConfig", mock.Anything, cfgID).Return(int64(0), sql.ErrNoRows) @@ -352,7 +352,7 @@ func Test_UpdateFeedsManagerChainConfig(t *testing.T) { }, }, }).Return(cfgID, nil) - f.Mocks.feedsSvc.On("GetChainConfig", cfgID).Return(&feeds.ChainConfig{ + f.Mocks.feedsSvc.On("GetChainConfig", mock.Anything, cfgID).Return(&feeds.ChainConfig{ ID: cfgID, AccountAddress: accountAddr, AdminAddress: adminAddr, @@ -413,7 +413,7 @@ func Test_UpdateFeedsManagerChainConfig(t *testing.T) { before: func(f *gqlTestFramework) { f.App.On("GetFeedsService").Return(f.Mocks.feedsSvc) f.Mocks.feedsSvc.On("UpdateChainConfig", mock.Anything, mock.IsType(feeds.ChainConfig{})).Return(cfgID, nil) - f.Mocks.feedsSvc.On("GetChainConfig", cfgID).Return(nil, sql.ErrNoRows) + f.Mocks.feedsSvc.On("GetChainConfig", mock.Anything, cfgID).Return(nil, sql.ErrNoRows) }, query: mutation, variables: variables, diff --git a/core/web/resolver/feeds_manager_test.go b/core/web/resolver/feeds_manager_test.go index 84558090da5..a3ea80a6443 100644 --- a/core/web/resolver/feeds_manager_test.go +++ b/core/web/resolver/feeds_manager_test.go @@ -42,14 +42,14 @@ func Test_FeedsManagers(t *testing.T) { authenticated: true, before: func(f *gqlTestFramework) { f.App.On("GetFeedsService").Return(f.Mocks.feedsSvc) - f.Mocks.feedsSvc.On("ListJobProposalsByManagersIDs", []int64{1}).Return([]feeds.JobProposal{ + f.Mocks.feedsSvc.On("ListJobProposalsByManagersIDs", mock.Anything, []int64{1}).Return([]feeds.JobProposal{ { ID: int64(100), FeedsManagerID: int64(1), Status: feeds.JobProposalStatusApproved, }, }, nil) - f.Mocks.feedsSvc.On("ListManagers").Return([]feeds.FeedsManager{ + f.Mocks.feedsSvc.On("ListManagers", mock.Anything).Return([]feeds.FeedsManager{ { ID: 1, Name: "manager1", @@ -115,7 +115,7 @@ func Test_FeedsManager(t *testing.T) { authenticated: true, before: func(f *gqlTestFramework) { f.App.On("GetFeedsService").Return(f.Mocks.feedsSvc) - f.Mocks.feedsSvc.On("GetManager", mgrID).Return(&feeds.FeedsManager{ + f.Mocks.feedsSvc.On("GetManager", mock.Anything, mgrID).Return(&feeds.FeedsManager{ ID: mgrID, Name: "manager1", URI: "localhost:2000", @@ -142,7 +142,7 @@ func Test_FeedsManager(t *testing.T) { authenticated: true, before: func(f *gqlTestFramework) { f.App.On("GetFeedsService").Return(f.Mocks.feedsSvc) - f.Mocks.feedsSvc.On("GetManager", mgrID).Return(nil, sql.ErrNoRows) + f.Mocks.feedsSvc.On("GetManager", mock.Anything, mgrID).Return(nil, sql.ErrNoRows) }, query: query, result: ` @@ -219,7 +219,7 @@ func Test_CreateFeedsManager(t *testing.T) { URI: uri, PublicKey: *pubKey, }).Return(mgrID, nil) - f.Mocks.feedsSvc.On("GetManager", mgrID).Return(&feeds.FeedsManager{ + f.Mocks.feedsSvc.On("GetManager", mock.Anything, mgrID).Return(&feeds.FeedsManager{ ID: mgrID, Name: name, URI: uri, @@ -269,7 +269,7 @@ func Test_CreateFeedsManager(t *testing.T) { before: func(f *gqlTestFramework) { f.App.On("GetFeedsService").Return(f.Mocks.feedsSvc) f.Mocks.feedsSvc.On("RegisterManager", mock.Anything, mock.IsType(feeds.RegisterManagerParams{})).Return(mgrID, nil) - f.Mocks.feedsSvc.On("GetManager", mgrID).Return(nil, sql.ErrNoRows) + f.Mocks.feedsSvc.On("GetManager", mock.Anything, mgrID).Return(nil, sql.ErrNoRows) }, query: mutation, variables: variables, @@ -366,7 +366,7 @@ func Test_UpdateFeedsManager(t *testing.T) { URI: uri, PublicKey: *pubKey, }).Return(nil) - f.Mocks.feedsSvc.On("GetManager", mgrID).Return(&feeds.FeedsManager{ + f.Mocks.feedsSvc.On("GetManager", mock.Anything, mgrID).Return(&feeds.FeedsManager{ ID: mgrID, Name: name, URI: uri, @@ -397,7 +397,7 @@ func Test_UpdateFeedsManager(t *testing.T) { before: func(f *gqlTestFramework) { f.App.On("GetFeedsService").Return(f.Mocks.feedsSvc) f.Mocks.feedsSvc.On("UpdateManager", mock.Anything, mock.IsType(feeds.FeedsManager{})).Return(nil) - f.Mocks.feedsSvc.On("GetManager", mgrID).Return(nil, sql.ErrNoRows) + f.Mocks.feedsSvc.On("GetManager", mock.Anything, mgrID).Return(nil, sql.ErrNoRows) }, query: mutation, variables: variables, diff --git a/core/web/resolver/job.go b/core/web/resolver/job.go index cb79b0c6e63..e9855fdb8fe 100644 --- a/core/web/resolver/job.go +++ b/core/web/resolver/job.go @@ -113,7 +113,7 @@ func (r *JobResolver) Runs(ctx context.Context, args struct { limit = 100 } - ids, err := r.app.JobORM().FindPipelineRunIDsByJobID(r.j.ID, offset, limit) + ids, err := r.app.JobORM().FindPipelineRunIDsByJobID(ctx, r.j.ID, offset, limit) if err != nil { return nil, err } @@ -123,7 +123,7 @@ func (r *JobResolver) Runs(ctx context.Context, args struct { return nil, err } - count, err := r.app.JobORM().CountPipelineRunsByJobID(r.j.ID) + count, err := r.app.JobORM().CountPipelineRunsByJobID(ctx, r.j.ID) if err != nil { return nil, err } diff --git a/core/web/resolver/job_error_test.go b/core/web/resolver/job_error_test.go index 30d06289ed6..69899a3ec47 100644 --- a/core/web/resolver/job_error_test.go +++ b/core/web/resolver/job_error_test.go @@ -29,10 +29,10 @@ func TestResolver_JobErrors(t *testing.T) { authenticated: true, before: func(f *gqlTestFramework) { f.App.On("JobORM").Return(f.Mocks.jobORM) - f.Mocks.jobORM.On("FindJobWithoutSpecErrors", id).Return(job.Job{ + f.Mocks.jobORM.On("FindJobWithoutSpecErrors", mock.Anything, id).Return(job.Job{ ID: int32(1), }, nil) - f.Mocks.jobORM.On("FindSpecErrorsByJobIDs", []int32{1}, mock.Anything).Return([]job.SpecError{ + f.Mocks.jobORM.On("FindSpecErrorsByJobIDs", mock.Anything, []int32{1}, mock.Anything).Return([]job.SpecError{ { ID: errorID, Description: "no contract code at given address", @@ -124,7 +124,7 @@ func TestResolver_DismissJobError(t *testing.T) { name: "success", authenticated: true, before: func(f *gqlTestFramework) { - f.Mocks.jobORM.On("FindSpecError", id).Return(job.SpecError{ + f.Mocks.jobORM.On("FindSpecError", mock.Anything, id).Return(job.SpecError{ ID: id, Occurrences: 5, Description: "test-description", @@ -141,7 +141,7 @@ func TestResolver_DismissJobError(t *testing.T) { name: "not found on FindSpecError()", authenticated: true, before: func(f *gqlTestFramework) { - f.Mocks.jobORM.On("FindSpecError", id).Return(job.SpecError{}, sql.ErrNoRows) + f.Mocks.jobORM.On("FindSpecError", mock.Anything, id).Return(job.SpecError{}, sql.ErrNoRows) f.App.On("JobORM").Return(f.Mocks.jobORM) }, query: mutation, @@ -159,7 +159,7 @@ func TestResolver_DismissJobError(t *testing.T) { name: "not found on DismissError()", authenticated: true, before: func(f *gqlTestFramework) { - f.Mocks.jobORM.On("FindSpecError", id).Return(job.SpecError{}, nil) + f.Mocks.jobORM.On("FindSpecError", mock.Anything, id).Return(job.SpecError{}, nil) f.Mocks.jobORM.On("DismissError", mock.Anything, id).Return(sql.ErrNoRows) f.App.On("JobORM").Return(f.Mocks.jobORM) }, @@ -178,7 +178,7 @@ func TestResolver_DismissJobError(t *testing.T) { name: "generic error on FindSpecError()", authenticated: true, before: func(f *gqlTestFramework) { - f.Mocks.jobORM.On("FindSpecError", id).Return(job.SpecError{}, gError) + f.Mocks.jobORM.On("FindSpecError", mock.Anything, id).Return(job.SpecError{}, gError) f.App.On("JobORM").Return(f.Mocks.jobORM) }, query: mutation, @@ -197,7 +197,7 @@ func TestResolver_DismissJobError(t *testing.T) { name: "generic error on DismissError()", authenticated: true, before: func(f *gqlTestFramework) { - f.Mocks.jobORM.On("FindSpecError", id).Return(job.SpecError{}, nil) + f.Mocks.jobORM.On("FindSpecError", mock.Anything, id).Return(job.SpecError{}, nil) f.Mocks.jobORM.On("DismissError", mock.Anything, id).Return(gError) f.App.On("JobORM").Return(f.Mocks.jobORM) }, diff --git a/core/web/resolver/job_proposal_spec_test.go b/core/web/resolver/job_proposal_spec_test.go index 2681129b2ae..c65702c5622 100644 --- a/core/web/resolver/job_proposal_spec_test.go +++ b/core/web/resolver/job_proposal_spec_test.go @@ -53,7 +53,7 @@ func TestResolver_ApproveJobProposalSpec(t *testing.T) { before: func(f *gqlTestFramework) { f.App.On("GetFeedsService").Return(f.Mocks.feedsSvc) f.Mocks.feedsSvc.On("ApproveSpec", mock.Anything, specID, false).Return(nil) - f.Mocks.feedsSvc.On("GetSpec", specID).Return(&feeds.JobProposalSpec{ + f.Mocks.feedsSvc.On("GetSpec", mock.Anything, specID).Return(&feeds.JobProposalSpec{ ID: specID, }, nil) }, @@ -84,7 +84,7 @@ func TestResolver_ApproveJobProposalSpec(t *testing.T) { before: func(f *gqlTestFramework) { f.App.On("GetFeedsService").Return(f.Mocks.feedsSvc) f.Mocks.feedsSvc.On("ApproveSpec", mock.Anything, specID, false).Return(nil) - f.Mocks.feedsSvc.On("GetSpec", specID).Return(nil, sql.ErrNoRows) + f.Mocks.feedsSvc.On("GetSpec", mock.Anything, specID).Return(nil, sql.ErrNoRows) }, query: mutation, variables: variables, @@ -157,7 +157,7 @@ func TestResolver_CancelJobProposalSpec(t *testing.T) { before: func(f *gqlTestFramework) { f.App.On("GetFeedsService").Return(f.Mocks.feedsSvc) f.Mocks.feedsSvc.On("CancelSpec", mock.Anything, specID).Return(nil) - f.Mocks.feedsSvc.On("GetSpec", specID).Return(&feeds.JobProposalSpec{ + f.Mocks.feedsSvc.On("GetSpec", mock.Anything, specID).Return(&feeds.JobProposalSpec{ ID: specID, }, nil) @@ -189,7 +189,7 @@ func TestResolver_CancelJobProposalSpec(t *testing.T) { before: func(f *gqlTestFramework) { f.App.On("GetFeedsService").Return(f.Mocks.feedsSvc) f.Mocks.feedsSvc.On("CancelSpec", mock.Anything, specID).Return(nil) - f.Mocks.feedsSvc.On("GetSpec", specID).Return(nil, sql.ErrNoRows) + f.Mocks.feedsSvc.On("GetSpec", mock.Anything, specID).Return(nil, sql.ErrNoRows) }, query: mutation, variables: variables, @@ -245,7 +245,7 @@ func TestResolver_RejectJobProposalSpec(t *testing.T) { before: func(f *gqlTestFramework) { f.App.On("GetFeedsService").Return(f.Mocks.feedsSvc) f.Mocks.feedsSvc.On("RejectSpec", mock.Anything, specID).Return(nil) - f.Mocks.feedsSvc.On("GetSpec", specID).Return(&feeds.JobProposalSpec{ + f.Mocks.feedsSvc.On("GetSpec", mock.Anything, specID).Return(&feeds.JobProposalSpec{ ID: specID, }, nil) }, @@ -276,7 +276,7 @@ func TestResolver_RejectJobProposalSpec(t *testing.T) { before: func(f *gqlTestFramework) { f.App.On("GetFeedsService").Return(f.Mocks.feedsSvc) f.Mocks.feedsSvc.On("RejectSpec", mock.Anything, specID).Return(nil) - f.Mocks.feedsSvc.On("GetSpec", specID).Return(nil, sql.ErrNoRows) + f.Mocks.feedsSvc.On("GetSpec", mock.Anything, specID).Return(nil, sql.ErrNoRows) }, query: mutation, variables: variables, @@ -335,7 +335,7 @@ func TestResolver_UpdateJobProposalSpecDefinition(t *testing.T) { before: func(f *gqlTestFramework) { f.App.On("GetFeedsService").Return(f.Mocks.feedsSvc) f.Mocks.feedsSvc.On("UpdateSpecDefinition", mock.Anything, specID, "").Return(nil) - f.Mocks.feedsSvc.On("GetSpec", specID).Return(&feeds.JobProposalSpec{ + f.Mocks.feedsSvc.On("GetSpec", mock.Anything, specID).Return(&feeds.JobProposalSpec{ ID: specID, }, nil) }, @@ -367,7 +367,7 @@ func TestResolver_UpdateJobProposalSpecDefinition(t *testing.T) { before: func(f *gqlTestFramework) { f.App.On("GetFeedsService").Return(f.Mocks.feedsSvc) f.Mocks.feedsSvc.On("UpdateSpecDefinition", mock.Anything, specID, "").Return(nil) - f.Mocks.feedsSvc.On("GetSpec", specID).Return(nil, sql.ErrNoRows) + f.Mocks.feedsSvc.On("GetSpec", mock.Anything, specID).Return(nil, sql.ErrNoRows) }, query: mutation, variables: variables, @@ -446,7 +446,7 @@ func TestResolver_GetJobProposal_Spec(t *testing.T) { name: "success", authenticated: true, before: func(f *gqlTestFramework) { - f.Mocks.feedsSvc.On("GetJobProposal", jpID).Return(&feeds.JobProposal{ + f.Mocks.feedsSvc.On("GetJobProposal", mock.Anything, jpID).Return(&feeds.JobProposal{ ID: jpID, Status: feeds.JobProposalStatusApproved, FeedsManagerID: 1, @@ -454,7 +454,7 @@ func TestResolver_GetJobProposal_Spec(t *testing.T) { PendingUpdate: false, }, nil) f.Mocks.feedsSvc. - On("ListSpecsByJobProposalIDs", []int64{jpID}). + On("ListSpecsByJobProposalIDs", mock.Anything, []int64{jpID}). Return(specs, nil) f.App.On("GetFeedsService").Return(f.Mocks.feedsSvc) }, diff --git a/core/web/resolver/job_proposal_test.go b/core/web/resolver/job_proposal_test.go index 466ddd5d8ab..5544b39c936 100644 --- a/core/web/resolver/job_proposal_test.go +++ b/core/web/resolver/job_proposal_test.go @@ -6,6 +6,7 @@ import ( "testing" "github.com/google/uuid" + "github.com/stretchr/testify/mock" "gopkg.in/guregu/null.v4" "github.com/smartcontractkit/chainlink/v2/core/services/feeds" @@ -64,13 +65,13 @@ func TestResolver_GetJobProposal(t *testing.T) { name: "success", authenticated: true, before: func(f *gqlTestFramework) { - f.Mocks.feedsSvc.On("ListManagersByIDs", []int64{1}).Return([]feeds.FeedsManager{ + f.Mocks.feedsSvc.On("ListManagersByIDs", mock.Anything, []int64{1}).Return([]feeds.FeedsManager{ { ID: 1, Name: "manager", }, }, nil) - f.Mocks.feedsSvc.On("GetJobProposal", jpID).Return(&feeds.JobProposal{ + f.Mocks.feedsSvc.On("GetJobProposal", mock.Anything, jpID).Return(&feeds.JobProposal{ ID: jpID, Name: null.StringFrom(name), Status: feeds.JobProposalStatusApproved, @@ -89,7 +90,7 @@ func TestResolver_GetJobProposal(t *testing.T) { name: "not found error", authenticated: true, before: func(f *gqlTestFramework) { - f.Mocks.feedsSvc.On("GetJobProposal", jpID).Return(nil, sql.ErrNoRows) + f.Mocks.feedsSvc.On("GetJobProposal", mock.Anything, jpID).Return(nil, sql.ErrNoRows) f.App.On("GetFeedsService").Return(f.Mocks.feedsSvc) }, query: query, diff --git a/core/web/resolver/job_run_test.go b/core/web/resolver/job_run_test.go index a35a2f66ac5..51631864e8c 100644 --- a/core/web/resolver/job_run_test.go +++ b/core/web/resolver/job_run_test.go @@ -40,7 +40,7 @@ func TestQuery_PaginatedJobRuns(t *testing.T) { name: "success", authenticated: true, before: func(f *gqlTestFramework) { - f.Mocks.jobORM.On("PipelineRuns", (*int32)(nil), PageDefaultOffset, PageDefaultLimit).Return([]pipeline.Run{ + f.Mocks.jobORM.On("PipelineRuns", mock.Anything, (*int32)(nil), PageDefaultOffset, PageDefaultLimit).Return([]pipeline.Run{ { ID: int64(200), }, @@ -64,7 +64,7 @@ func TestQuery_PaginatedJobRuns(t *testing.T) { name: "generic error on PipelineRuns()", authenticated: true, before: func(f *gqlTestFramework) { - f.Mocks.jobORM.On("PipelineRuns", (*int32)(nil), PageDefaultOffset, PageDefaultLimit).Return(nil, 0, gError) + f.Mocks.jobORM.On("PipelineRuns", mock.Anything, (*int32)(nil), PageDefaultOffset, PageDefaultLimit).Return(nil, 0, gError) f.App.On("JobORM").Return(f.Mocks.jobORM) }, query: query, @@ -131,7 +131,7 @@ func TestResolver_JobRun(t *testing.T) { name: "success", authenticated: true, before: func(f *gqlTestFramework) { - f.Mocks.jobORM.On("FindPipelineRunByID", int64(2)).Return(pipeline.Run{ + f.Mocks.jobORM.On("FindPipelineRunByID", mock.Anything, int64(2)).Return(pipeline.Run{ ID: 2, PipelineSpecID: 5, CreatedAt: f.Timestamp(), @@ -142,7 +142,7 @@ func TestResolver_JobRun(t *testing.T) { Outputs: outputs, State: pipeline.RunStatusErrored, }, nil) - f.Mocks.jobORM.On("FindJobsByPipelineSpecIDs", []int32{5}).Return([]job.Job{ + f.Mocks.jobORM.On("FindJobsByPipelineSpecIDs", mock.Anything, []int32{5}).Return([]job.Job{ { ID: 1, PipelineSpecID: 2, @@ -180,7 +180,7 @@ func TestResolver_JobRun(t *testing.T) { name: "not found error", authenticated: true, before: func(f *gqlTestFramework) { - f.Mocks.jobORM.On("FindPipelineRunByID", int64(2)).Return(pipeline.Run{}, sql.ErrNoRows) + f.Mocks.jobORM.On("FindPipelineRunByID", mock.Anything, int64(2)).Return(pipeline.Run{}, sql.ErrNoRows) f.App.On("JobORM").Return(f.Mocks.jobORM) }, query: query, @@ -197,7 +197,7 @@ func TestResolver_JobRun(t *testing.T) { name: "generic error on FindPipelineRunByID()", authenticated: true, before: func(f *gqlTestFramework) { - f.Mocks.jobORM.On("FindPipelineRunByID", int64(2)).Return(pipeline.Run{}, gError) + f.Mocks.jobORM.On("FindPipelineRunByID", mock.Anything, int64(2)).Return(pipeline.Run{}, gError) f.App.On("JobORM").Return(f.Mocks.jobORM) }, query: query, diff --git a/core/web/resolver/job_test.go b/core/web/resolver/job_test.go index e91e37b6903..0615e47a621 100644 --- a/core/web/resolver/job_test.go +++ b/core/web/resolver/job_test.go @@ -72,7 +72,7 @@ func TestResolver_Jobs(t *testing.T) { plnSpecID := int32(12) f.App.On("JobORM").Return(f.Mocks.jobORM) - f.Mocks.jobORM.On("FindJobs", 0, 50).Return([]job.Job{ + f.Mocks.jobORM.On("FindJobs", mock.Anything, 0, 50).Return([]job.Job{ { ID: 1, Name: null.StringFrom("job1"), @@ -89,13 +89,13 @@ func TestResolver_Jobs(t *testing.T) { }, }, 1, nil) f.Mocks.jobORM. - On("FindPipelineRunIDsByJobID", int32(1), 0, 50). + On("FindPipelineRunIDsByJobID", mock.Anything, int32(1), 0, 50). Return([]int64{200}, nil) f.Mocks.jobORM. - On("FindPipelineRunsByIDs", []int64{200}). + On("FindPipelineRunsByIDs", mock.Anything, []int64{200}). Return([]pipeline.Run{{ID: 200}}, nil) f.Mocks.jobORM. - On("CountPipelineRunsByJobID", int32(1)). + On("CountPipelineRunsByJobID", mock.Anything, int32(1)). Return(int32(1), nil) }, query: query, @@ -208,7 +208,7 @@ func TestResolver_Job(t *testing.T) { authenticated: true, before: func(f *gqlTestFramework) { f.App.On("JobORM").Return(f.Mocks.jobORM) - f.Mocks.jobORM.On("FindJobWithoutSpecErrors", id).Return(job.Job{ + f.Mocks.jobORM.On("FindJobWithoutSpecErrors", mock.Anything, id).Return(job.Job{ ID: 1, Name: null.StringFrom("job1"), SchemaVersion: 1, @@ -223,13 +223,13 @@ func TestResolver_Job(t *testing.T) { }, }, nil) f.Mocks.jobORM. - On("FindPipelineRunIDsByJobID", int32(1), 0, 50). + On("FindPipelineRunIDsByJobID", mock.Anything, int32(1), 0, 50). Return([]int64{200}, nil) f.Mocks.jobORM. - On("FindPipelineRunsByIDs", []int64{200}). + On("FindPipelineRunsByIDs", mock.Anything, []int64{200}). Return([]pipeline.Run{{ID: 200}}, nil) f.Mocks.jobORM. - On("CountPipelineRunsByJobID", int32(1)). + On("CountPipelineRunsByJobID", mock.Anything, int32(1)). Return(int32(1), nil) }, query: query, @@ -240,7 +240,7 @@ func TestResolver_Job(t *testing.T) { authenticated: true, before: func(f *gqlTestFramework) { f.App.On("JobORM").Return(f.Mocks.jobORM) - f.Mocks.jobORM.On("FindJobWithoutSpecErrors", id).Return(job.Job{}, sql.ErrNoRows) + f.Mocks.jobORM.On("FindJobWithoutSpecErrors", mock.Anything, id).Return(job.Job{}, sql.ErrNoRows) }, query: query, result: ` @@ -257,7 +257,7 @@ func TestResolver_Job(t *testing.T) { authenticated: true, before: func(f *gqlTestFramework) { f.App.On("JobORM").Return(f.Mocks.jobORM) - f.Mocks.jobORM.On("FindJobWithoutSpecErrors", id).Return(job.Job{ + f.Mocks.jobORM.On("FindJobWithoutSpecErrors", mock.Anything, id).Return(job.Job{ ID: 1, Name: null.StringFrom("job1"), SchemaVersion: 1, @@ -272,13 +272,13 @@ func TestResolver_Job(t *testing.T) { }, }, chains.ErrNoSuchChainID) f.Mocks.jobORM. - On("FindPipelineRunIDsByJobID", int32(1), 0, 50). + On("FindPipelineRunIDsByJobID", mock.Anything, int32(1), 0, 50). Return([]int64{200}, nil) f.Mocks.jobORM. - On("FindPipelineRunsByIDs", []int64{200}). + On("FindPipelineRunsByIDs", mock.Anything, []int64{200}). Return([]pipeline.Run{{ID: 200}}, nil) f.Mocks.jobORM. - On("CountPipelineRunsByJobID", int32(1)). + On("CountPipelineRunsByJobID", mock.Anything, int32(1)). Return(int32(1), nil) }, query: query, @@ -453,7 +453,7 @@ func TestResolver_DeleteJob(t *testing.T) { name: "success", authenticated: true, before: func(f *gqlTestFramework) { - f.Mocks.jobORM.On("FindJobWithoutSpecErrors", id).Return(job.Job{ + f.Mocks.jobORM.On("FindJobWithoutSpecErrors", mock.Anything, id).Return(job.Job{ ID: id, Name: null.StringFrom("test-job"), ExternalJobID: extJID, @@ -471,7 +471,7 @@ func TestResolver_DeleteJob(t *testing.T) { name: "not found on FindJob()", authenticated: true, before: func(f *gqlTestFramework) { - f.Mocks.jobORM.On("FindJobWithoutSpecErrors", id).Return(job.Job{}, sql.ErrNoRows) + f.Mocks.jobORM.On("FindJobWithoutSpecErrors", mock.Anything, id).Return(job.Job{}, sql.ErrNoRows) f.App.On("JobORM").Return(f.Mocks.jobORM) }, query: mutation, @@ -489,7 +489,7 @@ func TestResolver_DeleteJob(t *testing.T) { name: "not found on DeleteJob()", authenticated: true, before: func(f *gqlTestFramework) { - f.Mocks.jobORM.On("FindJobWithoutSpecErrors", id).Return(job.Job{}, nil) + f.Mocks.jobORM.On("FindJobWithoutSpecErrors", mock.Anything, id).Return(job.Job{}, nil) f.App.On("JobORM").Return(f.Mocks.jobORM) f.App.On("DeleteJob", mock.Anything, id).Return(sql.ErrNoRows) }, @@ -508,7 +508,7 @@ func TestResolver_DeleteJob(t *testing.T) { name: "generic error on FindJob()", authenticated: true, before: func(f *gqlTestFramework) { - f.Mocks.jobORM.On("FindJobWithoutSpecErrors", id).Return(job.Job{}, gError) + f.Mocks.jobORM.On("FindJobWithoutSpecErrors", mock.Anything, id).Return(job.Job{}, gError) f.App.On("JobORM").Return(f.Mocks.jobORM) }, query: mutation, @@ -527,7 +527,7 @@ func TestResolver_DeleteJob(t *testing.T) { name: "generic error on DeleteJob()", authenticated: true, before: func(f *gqlTestFramework) { - f.Mocks.jobORM.On("FindJobWithoutSpecErrors", id).Return(job.Job{}, nil) + f.Mocks.jobORM.On("FindJobWithoutSpecErrors", mock.Anything, id).Return(job.Job{}, nil) f.App.On("JobORM").Return(f.Mocks.jobORM) f.App.On("DeleteJob", mock.Anything, id).Return(gError) }, diff --git a/core/web/resolver/mutation.go b/core/web/resolver/mutation.go index 9663f9dfe82..641eabdfd8b 100644 --- a/core/web/resolver/mutation.go +++ b/core/web/resolver/mutation.go @@ -238,7 +238,7 @@ func (r *Resolver) CreateFeedsManagerChainConfig(ctx context.Context, args struc return nil, err } - ccfg, err := fsvc.GetChainConfig(id) + ccfg, err := fsvc.GetChainConfig(ctx, id) if err != nil { if errors.Is(err, sql.ErrNoRows) { return NewCreateFeedsManagerChainConfigPayload(nil, err, nil), nil @@ -267,7 +267,7 @@ func (r *Resolver) DeleteFeedsManagerChainConfig(ctx context.Context, args struc fsvc := r.App.GetFeedsService() - ccfg, err := fsvc.GetChainConfig(id) + ccfg, err := fsvc.GetChainConfig(ctx, id) if err != nil { if errors.Is(err, sql.ErrNoRows) { return NewDeleteFeedsManagerChainConfigPayload(nil, err), nil @@ -367,7 +367,7 @@ func (r *Resolver) UpdateFeedsManagerChainConfig(ctx context.Context, args struc return nil, err } - ccfg, err := fsvc.GetChainConfig(id) + ccfg, err := fsvc.GetChainConfig(ctx, id) if err != nil { if errors.Is(err, sql.ErrNoRows) { return NewUpdateFeedsManagerChainConfigPayload(nil, err, nil), nil @@ -418,7 +418,7 @@ func (r *Resolver) CreateFeedsManager(ctx context.Context, args struct { return nil, err } - mgr, err := feedsService.GetManager(id) + mgr, err := feedsService.GetManager(ctx, id) if err != nil { if errors.Is(err, sql.ErrNoRows) { return NewCreateFeedsManagerPayload(nil, err, nil), nil @@ -541,7 +541,7 @@ func (r *Resolver) UpdateFeedsManager(ctx context.Context, args struct { return nil, err } - mgr, err = feedsService.GetManager(id) + mgr, err = feedsService.GetManager(ctx, id) if err != nil { if errors.Is(err, sql.ErrNoRows) { return NewUpdateFeedsManagerPayload(nil, err, nil), nil @@ -615,7 +615,7 @@ func (r *Resolver) DeleteBridge(ctx context.Context, args struct { return nil, err } - jobsUsingBridge, err := r.App.JobORM().FindJobIDsWithBridge(string(args.ID)) + jobsUsingBridge, err := r.App.JobORM().FindJobIDsWithBridge(ctx, string(args.ID)) if err != nil { return nil, err } @@ -751,7 +751,7 @@ func (r *Resolver) ApproveJobProposalSpec(ctx context.Context, args struct { return nil, err } - spec, err := feedsSvc.GetSpec(id) + spec, err := feedsSvc.GetSpec(ctx, id) if err != nil { if !errors.Is(err, sql.ErrNoRows) { return nil, err @@ -786,7 +786,7 @@ func (r *Resolver) CancelJobProposalSpec(ctx context.Context, args struct { return nil, err } - spec, err := feedsSvc.GetSpec(id) + spec, err := feedsSvc.GetSpec(ctx, id) if err != nil { if !errors.Is(err, sql.ErrNoRows) { return nil, err @@ -821,7 +821,7 @@ func (r *Resolver) RejectJobProposalSpec(ctx context.Context, args struct { return nil, err } - spec, err := feedsSvc.GetSpec(id) + spec, err := feedsSvc.GetSpec(ctx, id) if err != nil { if !errors.Is(err, sql.ErrNoRows) { return nil, err @@ -859,7 +859,7 @@ func (r *Resolver) UpdateJobProposalSpecDefinition(ctx context.Context, args str return nil, err } - spec, err := feedsSvc.GetSpec(id) + spec, err := feedsSvc.GetSpec(ctx, id) if err != nil { if !errors.Is(err, sql.ErrNoRows) { return nil, err @@ -1039,7 +1039,7 @@ func (r *Resolver) CreateJob(ctx context.Context, args struct { case job.VRF: jb, err = vrfcommon.ValidatedVRFSpec(args.Input.TOML) case job.Webhook: - jb, err = webhook.ValidatedWebhookSpec(args.Input.TOML, r.App.GetExternalInitiatorManager()) + jb, err = webhook.ValidatedWebhookSpec(ctx, args.Input.TOML, r.App.GetExternalInitiatorManager()) case job.BlockhashStore: jb, err = blockhashstore.ValidatedSpec(args.Input.TOML) case job.BlockHeaderFeeder: @@ -1085,7 +1085,7 @@ func (r *Resolver) DeleteJob(ctx context.Context, args struct { return nil, err } - j, err := r.App.JobORM().FindJobWithoutSpecErrors(id) + j, err := r.App.JobORM().FindJobWithoutSpecErrors(ctx, id) if err != nil { if errors.Is(err, sql.ErrNoRows) { return NewDeleteJobPayload(r.App, nil, err), nil @@ -1119,7 +1119,7 @@ func (r *Resolver) DismissJobError(ctx context.Context, args struct { return nil, err } - specErr, err := r.App.JobORM().FindSpecError(id) + specErr, err := r.App.JobORM().FindSpecError(ctx, id) if err != nil { if errors.Is(err, sql.ErrNoRows) { return NewDismissJobErrorPayload(nil, err), nil diff --git a/core/web/resolver/query.go b/core/web/resolver/query.go index e24974e765d..9de678adc51 100644 --- a/core/web/resolver/query.go +++ b/core/web/resolver/query.go @@ -132,7 +132,7 @@ func (r *Resolver) FeedsManager(ctx context.Context, args struct{ ID graphql.ID return nil, err } - mgr, err := r.App.GetFeedsService().GetManager(id) + mgr, err := r.App.GetFeedsService().GetManager(ctx, id) if err != nil { if errors.Is(err, sql.ErrNoRows) { return NewFeedsManagerPayload(nil, err), nil @@ -149,7 +149,7 @@ func (r *Resolver) FeedsManagers(ctx context.Context) (*FeedsManagersPayloadReso return nil, err } - mgrs, err := r.App.GetFeedsService().ListManagers() + mgrs, err := r.App.GetFeedsService().ListManagers(ctx) if err != nil { return nil, err } @@ -168,7 +168,7 @@ func (r *Resolver) Job(ctx context.Context, args struct{ ID graphql.ID }) (*JobP return nil, err } - j, err := r.App.JobORM().FindJobWithoutSpecErrors(id) + j, err := r.App.JobORM().FindJobWithoutSpecErrors(ctx, id) if err != nil { if errors.Is(err, sql.ErrNoRows) { return NewJobPayload(r.App, nil, err), nil @@ -197,7 +197,7 @@ func (r *Resolver) Jobs(ctx context.Context, args struct { offset := pageOffset(args.Offset) limit := pageLimit(args.Limit) - jobs, count, err := r.App.JobORM().FindJobs(offset, limit) + jobs, count, err := r.App.JobORM().FindJobs(ctx, offset, limit) if err != nil { return nil, err } @@ -328,7 +328,7 @@ func (r *Resolver) JobProposal(ctx context.Context, args struct { return nil, err } - jp, err := r.App.GetFeedsService().GetJobProposal(id) + jp, err := r.App.GetFeedsService().GetJobProposal(ctx, id) if err != nil { if errors.Is(err, sql.ErrNoRows) { return NewJobProposalPayload(nil, err), nil @@ -377,7 +377,7 @@ func (r *Resolver) JobRuns(ctx context.Context, args struct { limit := pageLimit(args.Limit) offset := pageOffset(args.Offset) - runs, count, err := r.App.JobORM().PipelineRuns(nil, offset, limit) + runs, count, err := r.App.JobORM().PipelineRuns(ctx, nil, offset, limit) if err != nil { return nil, err } @@ -397,7 +397,7 @@ func (r *Resolver) JobRun(ctx context.Context, args struct { return nil, err } - jr, err := r.App.JobORM().FindPipelineRunByID(id) + jr, err := r.App.JobORM().FindPipelineRunByID(ctx, id) if err != nil { if errors.Is(err, sql.ErrNoRows) { return NewJobRunPayload(nil, r.App, err), nil diff --git a/core/web/resolver/spec_test.go b/core/web/resolver/spec_test.go index 7021576fdcf..43682c14ead 100644 --- a/core/web/resolver/spec_test.go +++ b/core/web/resolver/spec_test.go @@ -6,6 +6,7 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/lib/pq" + "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "gopkg.in/guregu/null.v4" @@ -35,7 +36,7 @@ func TestResolver_CronSpec(t *testing.T) { authenticated: true, before: func(f *gqlTestFramework) { f.App.On("JobORM").Return(f.Mocks.jobORM) - f.Mocks.jobORM.On("FindJobWithoutSpecErrors", id).Return(job.Job{ + f.Mocks.jobORM.On("FindJobWithoutSpecErrors", mock.Anything, id).Return(job.Job{ Type: job.Cron, CronSpec: &job.CronSpec{ CronSchedule: "CRON_TZ=UTC 0 0 1 1 *", @@ -89,7 +90,7 @@ func TestResolver_DirectRequestSpec(t *testing.T) { authenticated: true, before: func(f *gqlTestFramework) { f.App.On("JobORM").Return(f.Mocks.jobORM) - f.Mocks.jobORM.On("FindJobWithoutSpecErrors", id).Return(job.Job{ + f.Mocks.jobORM.On("FindJobWithoutSpecErrors", mock.Anything, id).Return(job.Job{ Type: job.DirectRequest, DirectRequestSpec: &job.DirectRequestSpec{ ContractAddress: contractAddress, @@ -154,7 +155,7 @@ func TestResolver_FluxMonitorSpec(t *testing.T) { authenticated: true, before: func(f *gqlTestFramework) { f.App.On("JobORM").Return(f.Mocks.jobORM) - f.Mocks.jobORM.On("FindJobWithoutSpecErrors", id).Return(job.Job{ + f.Mocks.jobORM.On("FindJobWithoutSpecErrors", mock.Anything, id).Return(job.Job{ Type: job.FluxMonitor, FluxMonitorSpec: &job.FluxMonitorSpec{ ContractAddress: contractAddress, @@ -221,7 +222,7 @@ func TestResolver_FluxMonitorSpec(t *testing.T) { authenticated: true, before: func(f *gqlTestFramework) { f.App.On("JobORM").Return(f.Mocks.jobORM) - f.Mocks.jobORM.On("FindJobWithoutSpecErrors", id).Return(job.Job{ + f.Mocks.jobORM.On("FindJobWithoutSpecErrors", mock.Anything, id).Return(job.Job{ Type: job.FluxMonitor, FluxMonitorSpec: &job.FluxMonitorSpec{ ContractAddress: contractAddress, @@ -304,7 +305,7 @@ func TestResolver_KeeperSpec(t *testing.T) { authenticated: true, before: func(f *gqlTestFramework) { f.App.On("JobORM").Return(f.Mocks.jobORM) - f.Mocks.jobORM.On("FindJobWithoutSpecErrors", id).Return(job.Job{ + f.Mocks.jobORM.On("FindJobWithoutSpecErrors", mock.Anything, id).Return(job.Job{ Type: job.Keeper, KeeperSpec: &job.KeeperSpec{ ContractAddress: contractAddress, @@ -368,7 +369,7 @@ func TestResolver_OCRSpec(t *testing.T) { authenticated: true, before: func(f *gqlTestFramework) { f.App.On("JobORM").Return(f.Mocks.jobORM) - f.Mocks.jobORM.On("FindJobWithoutSpecErrors", id).Return(job.Job{ + f.Mocks.jobORM.On("FindJobWithoutSpecErrors", mock.Anything, id).Return(job.Job{ Type: job.OffchainReporting, OCROracleSpec: &job.OCROracleSpec{ BlockchainTimeout: models.Interval(1 * time.Minute), @@ -473,7 +474,7 @@ func TestResolver_OCR2Spec(t *testing.T) { authenticated: true, before: func(f *gqlTestFramework) { f.App.On("JobORM").Return(f.Mocks.jobORM) - f.Mocks.jobORM.On("FindJobWithoutSpecErrors", id).Return(job.Job{ + f.Mocks.jobORM.On("FindJobWithoutSpecErrors", mock.Anything, id).Return(job.Job{ Type: job.OffchainReporting2, OCR2OracleSpec: &job.OCR2OracleSpec{ BlockchainTimeout: models.Interval(1 * time.Minute), @@ -575,7 +576,7 @@ func TestResolver_VRFSpec(t *testing.T) { authenticated: true, before: func(f *gqlTestFramework) { f.App.On("JobORM").Return(f.Mocks.jobORM) - f.Mocks.jobORM.On("FindJobWithoutSpecErrors", id).Return(job.Job{ + f.Mocks.jobORM.On("FindJobWithoutSpecErrors", mock.Anything, id).Return(job.Job{ Type: job.VRF, VRFSpec: &job.VRFSpec{ BatchCoordinatorAddress: &batchCoordinatorAddress, @@ -671,7 +672,7 @@ func TestResolver_WebhookSpec(t *testing.T) { authenticated: true, before: func(f *gqlTestFramework) { f.App.On("JobORM").Return(f.Mocks.jobORM) - f.Mocks.jobORM.On("FindJobWithoutSpecErrors", id).Return(job.Job{ + f.Mocks.jobORM.On("FindJobWithoutSpecErrors", mock.Anything, id).Return(job.Job{ Type: job.Webhook, WebhookSpec: &job.WebhookSpec{ CreatedAt: f.Timestamp(), @@ -740,7 +741,7 @@ func TestResolver_BlockhashStoreSpec(t *testing.T) { authenticated: true, before: func(f *gqlTestFramework) { f.App.On("JobORM").Return(f.Mocks.jobORM) - f.Mocks.jobORM.On("FindJobWithoutSpecErrors", id).Return(job.Job{ + f.Mocks.jobORM.On("FindJobWithoutSpecErrors", mock.Anything, id).Return(job.Job{ Type: job.BlockhashStore, BlockhashStoreSpec: &job.BlockhashStoreSpec{ CoordinatorV1Address: &coordinatorV1Address, @@ -844,7 +845,7 @@ func TestResolver_BlockHeaderFeederSpec(t *testing.T) { authenticated: true, before: func(f *gqlTestFramework) { f.App.On("JobORM").Return(f.Mocks.jobORM) - f.Mocks.jobORM.On("FindJobWithoutSpecErrors", id).Return(job.Job{ + f.Mocks.jobORM.On("FindJobWithoutSpecErrors", mock.Anything, id).Return(job.Job{ Type: job.BlockHeaderFeeder, BlockHeaderFeederSpec: &job.BlockHeaderFeederSpec{ CoordinatorV1Address: &coordinatorV1Address, @@ -931,7 +932,7 @@ func TestResolver_BootstrapSpec(t *testing.T) { authenticated: true, before: func(f *gqlTestFramework) { f.App.On("JobORM").Return(f.Mocks.jobORM) - f.Mocks.jobORM.On("FindJobWithoutSpecErrors", id).Return(job.Job{ + f.Mocks.jobORM.On("FindJobWithoutSpecErrors", mock.Anything, id).Return(job.Job{ Type: job.Bootstrap, BootstrapSpec: &job.BootstrapSpec{ ID: id, @@ -1003,7 +1004,7 @@ func TestResolver_WorkflowSpec(t *testing.T) { authenticated: true, before: func(f *gqlTestFramework) { f.App.On("JobORM").Return(f.Mocks.jobORM) - f.Mocks.jobORM.On("FindJobWithoutSpecErrors", id).Return(job.Job{ + f.Mocks.jobORM.On("FindJobWithoutSpecErrors", mock.Anything, id).Return(job.Job{ Type: job.Workflow, WorkflowSpec: &job.WorkflowSpec{ ID: id, @@ -1061,7 +1062,7 @@ func TestResolver_GatewaySpec(t *testing.T) { authenticated: true, before: func(f *gqlTestFramework) { f.App.On("JobORM").Return(f.Mocks.jobORM) - f.Mocks.jobORM.On("FindJobWithoutSpecErrors", id).Return(job.Job{ + f.Mocks.jobORM.On("FindJobWithoutSpecErrors", mock.Anything, id).Return(job.Job{ Type: job.Gateway, GatewaySpec: &job.GatewaySpec{ ID: id, diff --git a/core/web/sessions_controller_test.go b/core/web/sessions_controller_test.go index cd63628390e..049be4e2b69 100644 --- a/core/web/sessions_controller_test.go +++ b/core/web/sessions_controller_test.go @@ -8,10 +8,10 @@ import ( "testing" "time" + "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" "github.com/smartcontractkit/chainlink/v2/core/internal/cltest" "github.com/smartcontractkit/chainlink/v2/core/internal/testutils" clhttptest "github.com/smartcontractkit/chainlink/v2/core/internal/testutils/httptest" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" "github.com/smartcontractkit/chainlink/v2/core/sessions" "github.com/smartcontractkit/chainlink/v2/core/web" @@ -25,7 +25,7 @@ func TestSessionsController_Create(t *testing.T) { ctx := testutils.Context(t) app := cltest.NewApplicationEVMDisabled(t) - require.NoError(t, app.Start(testutils.Context(t))) + require.NoError(t, app.Start(ctx)) user := cltest.MustRandomUser(t) require.NoError(t, app.AuthenticationProvider().CreateUser(ctx, &user)) @@ -44,6 +44,7 @@ func TestSessionsController_Create(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { + ctx := testutils.Context(t) body := fmt.Sprintf(`{"email":"%s","password":"%s"}`, test.email, test.password) request, err := http.NewRequestWithContext(ctx, "POST", app.Server.URL+"/sessions", bytes.NewBufferString(body)) assert.NoError(t, err) @@ -78,9 +79,10 @@ func TestSessionsController_Create(t *testing.T) { } } -func mustInsertSession(t *testing.T, q pg.Q, session *sessions.Session) { +func mustInsertSession(t *testing.T, ds sqlutil.DataSource, session *sessions.Session) { + ctx := testutils.Context(t) sql := "INSERT INTO sessions (id, email, last_used, created_at) VALUES ($1, $2, $3, $4) RETURNING *" - _, err := q.Exec(sql, session.ID, session.Email, session.LastUsed, session.CreatedAt) + _, err := ds.ExecContext(ctx, sql, session.ID, session.Email, session.LastUsed, session.CreatedAt) require.NoError(t, err) } @@ -97,8 +99,7 @@ func TestSessionsController_Create_ReapSessions(t *testing.T) { staleSession := cltest.NewSession() staleSession.LastUsed = time.Now().Add(-cltest.MustParseDuration(t, "241h")) staleSession.Email = user.Email - q := pg.NewQ(app.GetSqlxDB(), app.GetLogger(), app.GetConfig().Database()) - mustInsertSession(t, q, &staleSession) + mustInsertSession(t, app.GetDB(), &staleSession) body := fmt.Sprintf(`{"email":"%s","password":"%s"}`, user.Email, cltest.Password) req, err := http.NewRequestWithContext(ctx, "POST", app.Server.URL+"/sessions", bytes.NewBufferString(body)) @@ -127,15 +128,14 @@ func TestSessionsController_Destroy(t *testing.T) { ctx := testutils.Context(t) app := cltest.NewApplicationEVMDisabled(t) - require.NoError(t, app.Start(testutils.Context(t))) + require.NoError(t, app.Start(ctx)) user := cltest.MustRandomUser(t) require.NoError(t, app.AuthenticationProvider().CreateUser(ctx, &user)) correctSession := sessions.NewSession() correctSession.Email = user.Email - q := pg.NewQ(app.GetSqlxDB(), app.GetLogger(), app.GetConfig().Database()) - mustInsertSession(t, q, &correctSession) + mustInsertSession(t, app.GetDB(), &correctSession) client := clhttptest.NewTestLocalOnlyHTTPClient() tests := []struct { @@ -148,6 +148,7 @@ func TestSessionsController_Destroy(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { + ctx := testutils.Context(t) cookie := cltest.MustGenerateSessionCookie(t, test.sessionID) request, err := http.NewRequestWithContext(ctx, "DELETE", app.Server.URL+"/sessions", nil) assert.NoError(t, err) @@ -173,8 +174,7 @@ func TestSessionsController_Destroy_ReapSessions(t *testing.T) { client := clhttptest.NewTestLocalOnlyHTTPClient() app := cltest.NewApplicationEVMDisabled(t) - q := pg.NewQ(app.GetSqlxDB(), app.GetLogger(), app.GetConfig().Database()) - require.NoError(t, app.Start(testutils.Context(t))) + require.NoError(t, app.Start(ctx)) user := cltest.MustRandomUser(t) require.NoError(t, app.AuthenticationProvider().CreateUser(ctx, &user)) @@ -182,13 +182,13 @@ func TestSessionsController_Destroy_ReapSessions(t *testing.T) { correctSession := sessions.NewSession() correctSession.Email = user.Email - mustInsertSession(t, q, &correctSession) + mustInsertSession(t, app.GetDB(), &correctSession) cookie := cltest.MustGenerateSessionCookie(t, correctSession.ID) staleSession := cltest.NewSession() staleSession.Email = user.Email staleSession.LastUsed = time.Now().Add(-cltest.MustParseDuration(t, "241h")) - mustInsertSession(t, q, &staleSession) + mustInsertSession(t, app.GetDB(), &staleSession) request, err := http.NewRequestWithContext(ctx, "DELETE", app.Server.URL+"/sessions", nil) assert.NoError(t, err) diff --git a/go.mod b/go.mod index 752dd110c49..6294d81f5cb 100644 --- a/go.mod +++ b/go.mod @@ -72,8 +72,8 @@ require ( github.com/shopspring/decimal v1.3.1 github.com/smartcontractkit/chain-selectors v1.0.10 github.com/smartcontractkit/chainlink-automation v1.0.3 - github.com/smartcontractkit/chainlink-common v0.1.7-0.20240419205832-845fa69af8d9 - github.com/smartcontractkit/chainlink-cosmos v0.4.1-0.20240419131812-73d148593d92 + github.com/smartcontractkit/chainlink-common v0.1.7-0.20240424104752-ed1756cf454c + github.com/smartcontractkit/chainlink-cosmos v0.4.1-0.20240419213354-ea34a948e2ee github.com/smartcontractkit/chainlink-data-streams v0.0.0-20240220203239-09be0ea34540 github.com/smartcontractkit/chainlink-feeds v0.0.0-20240422130241-13c17a91b2ab github.com/smartcontractkit/chainlink-solana v1.0.3-0.20240422172640-59d47c73ba58 diff --git a/go.sum b/go.sum index 5f042f2e4c4..c335fc39642 100644 --- a/go.sum +++ b/go.sum @@ -1180,10 +1180,10 @@ github.com/smartcontractkit/chain-selectors v1.0.10 h1:t9kJeE6B6G+hKD0GYR4kGJSCq github.com/smartcontractkit/chain-selectors v1.0.10/go.mod h1:d4Hi+E1zqjy9HqMkjBE5q1vcG9VGgxf5VxiRHfzi2kE= github.com/smartcontractkit/chainlink-automation v1.0.3 h1:h/ijT0NiyV06VxYVgcNfsE3+8OEzT3Q0Z9au0z1BPWs= github.com/smartcontractkit/chainlink-automation v1.0.3/go.mod h1:RjboV0Qd7YP+To+OrzHGXaxUxoSONveCoAK2TQ1INLU= -github.com/smartcontractkit/chainlink-common v0.1.7-0.20240419205832-845fa69af8d9 h1:elDIBChe7ByPNvCyrSjMLTPKrgY+sKgzzlWe2p3wokY= -github.com/smartcontractkit/chainlink-common v0.1.7-0.20240419205832-845fa69af8d9/go.mod h1:GTDBbovHUSAUk+fuGIySF2A/whhdtHGaWmU61BoERks= -github.com/smartcontractkit/chainlink-cosmos v0.4.1-0.20240419131812-73d148593d92 h1:MvaNzuaQh1vX4CAYLM8qFd99cf0ZF1JNwtDZtLU7WvU= -github.com/smartcontractkit/chainlink-cosmos v0.4.1-0.20240419131812-73d148593d92/go.mod h1:uATrrJ8IsuBkOBJ46USuf73gz9gZy5k5bzGE5/ji/rc= +github.com/smartcontractkit/chainlink-common v0.1.7-0.20240424104752-ed1756cf454c h1:nk3g1il/cG0raV2ymNlytAPvjfYNSvwHP7Gfy6ItmSI= +github.com/smartcontractkit/chainlink-common v0.1.7-0.20240424104752-ed1756cf454c/go.mod h1:GTDBbovHUSAUk+fuGIySF2A/whhdtHGaWmU61BoERks= +github.com/smartcontractkit/chainlink-cosmos v0.4.1-0.20240419213354-ea34a948e2ee h1:eFuBKyEbL2b+eyfgV/Eu9+8HuCEev+IcBi+K9l1dG7g= +github.com/smartcontractkit/chainlink-cosmos v0.4.1-0.20240419213354-ea34a948e2ee/go.mod h1:uATrrJ8IsuBkOBJ46USuf73gz9gZy5k5bzGE5/ji/rc= github.com/smartcontractkit/chainlink-data-streams v0.0.0-20240220203239-09be0ea34540 h1:xFSv8561jsLtF6gYZr/zW2z5qUUAkcFkApin2mnbYTo= github.com/smartcontractkit/chainlink-data-streams v0.0.0-20240220203239-09be0ea34540/go.mod h1:sjAmX8K2kbQhvDarZE1ZZgDgmHJ50s0BBc/66vKY2ek= github.com/smartcontractkit/chainlink-feeds v0.0.0-20240422130241-13c17a91b2ab h1:Ct1oUlyn03HDUVdFHJqtRGRUujMqdoMzvf/Cjhe30Ag= diff --git a/integration-tests/go.mod b/integration-tests/go.mod index 327e5e69d94..5974f453712 100644 --- a/integration-tests/go.mod +++ b/integration-tests/go.mod @@ -25,7 +25,7 @@ require ( github.com/shopspring/decimal v1.3.1 github.com/slack-go/slack v0.12.2 github.com/smartcontractkit/chainlink-automation v1.0.3 - github.com/smartcontractkit/chainlink-common v0.1.7-0.20240419205832-845fa69af8d9 + github.com/smartcontractkit/chainlink-common v0.1.7-0.20240424104752-ed1756cf454c github.com/smartcontractkit/chainlink-testing-framework v1.28.4 github.com/smartcontractkit/chainlink-vrf v0.0.0-20240222010609-cd67d123c772 github.com/smartcontractkit/chainlink/v2 v2.0.0-00010101000000-000000000000 @@ -377,7 +377,7 @@ require ( github.com/shoenig/go-m1cpu v0.1.6 // indirect github.com/sirupsen/logrus v1.9.3 // indirect github.com/smartcontractkit/chain-selectors v1.0.10 // indirect - github.com/smartcontractkit/chainlink-cosmos v0.4.1-0.20240419131812-73d148593d92 // indirect + github.com/smartcontractkit/chainlink-cosmos v0.4.1-0.20240419213354-ea34a948e2ee // indirect github.com/smartcontractkit/chainlink-data-streams v0.0.0-20240220203239-09be0ea34540 // indirect github.com/smartcontractkit/chainlink-feeds v0.0.0-20240422130241-13c17a91b2ab // indirect github.com/smartcontractkit/chainlink-solana v1.0.3-0.20240422172640-59d47c73ba58 // indirect diff --git a/integration-tests/go.sum b/integration-tests/go.sum index 577ae3f2a0d..c8c5ea54823 100644 --- a/integration-tests/go.sum +++ b/integration-tests/go.sum @@ -1517,10 +1517,10 @@ github.com/smartcontractkit/chain-selectors v1.0.10 h1:t9kJeE6B6G+hKD0GYR4kGJSCq github.com/smartcontractkit/chain-selectors v1.0.10/go.mod h1:d4Hi+E1zqjy9HqMkjBE5q1vcG9VGgxf5VxiRHfzi2kE= github.com/smartcontractkit/chainlink-automation v1.0.3 h1:h/ijT0NiyV06VxYVgcNfsE3+8OEzT3Q0Z9au0z1BPWs= github.com/smartcontractkit/chainlink-automation v1.0.3/go.mod h1:RjboV0Qd7YP+To+OrzHGXaxUxoSONveCoAK2TQ1INLU= -github.com/smartcontractkit/chainlink-common v0.1.7-0.20240419205832-845fa69af8d9 h1:elDIBChe7ByPNvCyrSjMLTPKrgY+sKgzzlWe2p3wokY= -github.com/smartcontractkit/chainlink-common v0.1.7-0.20240419205832-845fa69af8d9/go.mod h1:GTDBbovHUSAUk+fuGIySF2A/whhdtHGaWmU61BoERks= -github.com/smartcontractkit/chainlink-cosmos v0.4.1-0.20240419131812-73d148593d92 h1:MvaNzuaQh1vX4CAYLM8qFd99cf0ZF1JNwtDZtLU7WvU= -github.com/smartcontractkit/chainlink-cosmos v0.4.1-0.20240419131812-73d148593d92/go.mod h1:uATrrJ8IsuBkOBJ46USuf73gz9gZy5k5bzGE5/ji/rc= +github.com/smartcontractkit/chainlink-common v0.1.7-0.20240424104752-ed1756cf454c h1:nk3g1il/cG0raV2ymNlytAPvjfYNSvwHP7Gfy6ItmSI= +github.com/smartcontractkit/chainlink-common v0.1.7-0.20240424104752-ed1756cf454c/go.mod h1:GTDBbovHUSAUk+fuGIySF2A/whhdtHGaWmU61BoERks= +github.com/smartcontractkit/chainlink-cosmos v0.4.1-0.20240419213354-ea34a948e2ee h1:eFuBKyEbL2b+eyfgV/Eu9+8HuCEev+IcBi+K9l1dG7g= +github.com/smartcontractkit/chainlink-cosmos v0.4.1-0.20240419213354-ea34a948e2ee/go.mod h1:uATrrJ8IsuBkOBJ46USuf73gz9gZy5k5bzGE5/ji/rc= github.com/smartcontractkit/chainlink-data-streams v0.0.0-20240220203239-09be0ea34540 h1:xFSv8561jsLtF6gYZr/zW2z5qUUAkcFkApin2mnbYTo= github.com/smartcontractkit/chainlink-data-streams v0.0.0-20240220203239-09be0ea34540/go.mod h1:sjAmX8K2kbQhvDarZE1ZZgDgmHJ50s0BBc/66vKY2ek= github.com/smartcontractkit/chainlink-feeds v0.0.0-20240422130241-13c17a91b2ab h1:Ct1oUlyn03HDUVdFHJqtRGRUujMqdoMzvf/Cjhe30Ag= diff --git a/integration-tests/load/go.mod b/integration-tests/load/go.mod index eaf0332af63..37ae0a0d33a 100644 --- a/integration-tests/load/go.mod +++ b/integration-tests/load/go.mod @@ -16,7 +16,7 @@ require ( github.com/rs/zerolog v1.30.0 github.com/slack-go/slack v0.12.2 github.com/smartcontractkit/chainlink-automation v1.0.3 - github.com/smartcontractkit/chainlink-common v0.1.7-0.20240419205832-845fa69af8d9 + github.com/smartcontractkit/chainlink-common v0.1.7-0.20240424104752-ed1756cf454c github.com/smartcontractkit/chainlink-testing-framework v1.28.4 github.com/smartcontractkit/chainlink/integration-tests v0.0.0-20240214231432-4ad5eb95178c github.com/smartcontractkit/chainlink/v2 v2.9.0-beta0.0.20240216210048-da02459ddad8 @@ -366,7 +366,7 @@ require ( github.com/shopspring/decimal v1.3.1 // indirect github.com/sirupsen/logrus v1.9.3 // indirect github.com/smartcontractkit/chain-selectors v1.0.10 // indirect - github.com/smartcontractkit/chainlink-cosmos v0.4.1-0.20240419131812-73d148593d92 // indirect + github.com/smartcontractkit/chainlink-cosmos v0.4.1-0.20240419213354-ea34a948e2ee // indirect github.com/smartcontractkit/chainlink-data-streams v0.0.0-20240220203239-09be0ea34540 // indirect github.com/smartcontractkit/chainlink-feeds v0.0.0-20240422130241-13c17a91b2ab // indirect github.com/smartcontractkit/chainlink-solana v1.0.3-0.20240422172640-59d47c73ba58 // indirect diff --git a/integration-tests/load/go.sum b/integration-tests/load/go.sum index 631336409b3..30bef5b5b79 100644 --- a/integration-tests/load/go.sum +++ b/integration-tests/load/go.sum @@ -1500,10 +1500,10 @@ github.com/smartcontractkit/chain-selectors v1.0.10 h1:t9kJeE6B6G+hKD0GYR4kGJSCq github.com/smartcontractkit/chain-selectors v1.0.10/go.mod h1:d4Hi+E1zqjy9HqMkjBE5q1vcG9VGgxf5VxiRHfzi2kE= github.com/smartcontractkit/chainlink-automation v1.0.3 h1:h/ijT0NiyV06VxYVgcNfsE3+8OEzT3Q0Z9au0z1BPWs= github.com/smartcontractkit/chainlink-automation v1.0.3/go.mod h1:RjboV0Qd7YP+To+OrzHGXaxUxoSONveCoAK2TQ1INLU= -github.com/smartcontractkit/chainlink-common v0.1.7-0.20240419205832-845fa69af8d9 h1:elDIBChe7ByPNvCyrSjMLTPKrgY+sKgzzlWe2p3wokY= -github.com/smartcontractkit/chainlink-common v0.1.7-0.20240419205832-845fa69af8d9/go.mod h1:GTDBbovHUSAUk+fuGIySF2A/whhdtHGaWmU61BoERks= -github.com/smartcontractkit/chainlink-cosmos v0.4.1-0.20240419131812-73d148593d92 h1:MvaNzuaQh1vX4CAYLM8qFd99cf0ZF1JNwtDZtLU7WvU= -github.com/smartcontractkit/chainlink-cosmos v0.4.1-0.20240419131812-73d148593d92/go.mod h1:uATrrJ8IsuBkOBJ46USuf73gz9gZy5k5bzGE5/ji/rc= +github.com/smartcontractkit/chainlink-common v0.1.7-0.20240424104752-ed1756cf454c h1:nk3g1il/cG0raV2ymNlytAPvjfYNSvwHP7Gfy6ItmSI= +github.com/smartcontractkit/chainlink-common v0.1.7-0.20240424104752-ed1756cf454c/go.mod h1:GTDBbovHUSAUk+fuGIySF2A/whhdtHGaWmU61BoERks= +github.com/smartcontractkit/chainlink-cosmos v0.4.1-0.20240419213354-ea34a948e2ee h1:eFuBKyEbL2b+eyfgV/Eu9+8HuCEev+IcBi+K9l1dG7g= +github.com/smartcontractkit/chainlink-cosmos v0.4.1-0.20240419213354-ea34a948e2ee/go.mod h1:uATrrJ8IsuBkOBJ46USuf73gz9gZy5k5bzGE5/ji/rc= github.com/smartcontractkit/chainlink-data-streams v0.0.0-20240220203239-09be0ea34540 h1:xFSv8561jsLtF6gYZr/zW2z5qUUAkcFkApin2mnbYTo= github.com/smartcontractkit/chainlink-data-streams v0.0.0-20240220203239-09be0ea34540/go.mod h1:sjAmX8K2kbQhvDarZE1ZZgDgmHJ50s0BBc/66vKY2ek= github.com/smartcontractkit/chainlink-feeds v0.0.0-20240422130241-13c17a91b2ab h1:Ct1oUlyn03HDUVdFHJqtRGRUujMqdoMzvf/Cjhe30Ag= From 84c7d193dd712ecf82ab251e6cae56bd1c6e61df Mon Sep 17 00:00:00 2001 From: Jordan Krage Date: Wed, 24 Apr 2024 14:58:06 -0500 Subject: [PATCH 2/2] changeset --- .changeset/fuzzy-pans-destroy.md | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 .changeset/fuzzy-pans-destroy.md diff --git a/.changeset/fuzzy-pans-destroy.md b/.changeset/fuzzy-pans-destroy.md new file mode 100644 index 00000000000..3cff19f8d8a --- /dev/null +++ b/.changeset/fuzzy-pans-destroy.md @@ -0,0 +1,5 @@ +--- +"chainlink": minor +--- + +Use sqlutil instead of pg.Opts/Q/Queryer #internal