From 1efb52510704e43c3e349df20178480553c4670f Mon Sep 17 00:00:00 2001 From: Jordan Krage Date: Fri, 5 Apr 2024 14:13:35 -0500 Subject: [PATCH] core/services/relay/evm: switch RequestRound DB & Tracker to use sqlutil.DataSource (#12706) --- core/services/chainlink/relayer_factory.go | 1 + core/services/job/spawner_test.go | 1 + core/services/relay/evm/evm.go | 13 +++-- core/services/relay/evm/evm_test.go | 8 ++- core/services/relay/evm/median.go | 13 +++-- .../relay/evm/mocks/request_round_db.go | 53 +++++++++++++------ core/services/relay/evm/request_round_db.go | 29 ++++++---- .../relay/evm/request_round_db_test.go | 18 +++---- .../relay/evm/request_round_tracker.go | 20 ++++--- .../relay/evm/request_round_tracker_test.go | 18 +++++-- 10 files changed, 113 insertions(+), 61 deletions(-) diff --git a/core/services/chainlink/relayer_factory.go b/core/services/chainlink/relayer_factory.go index f5cb1badb95..2dd5e1eb68a 100644 --- a/core/services/chainlink/relayer_factory.go +++ b/core/services/chainlink/relayer_factory.go @@ -68,6 +68,7 @@ func (r *RelayerFactory) NewEVM(ctx context.Context, config EVMFactoryConfig) (m relayerOpts := evmrelay.RelayerOpts{ DB: ccOpts.SqlxDB, + DS: ccOpts.DB, QConfig: ccOpts.AppConfig.Database(), CSAETHKeystore: config.CSAETHKeystore, MercuryPool: r.MercuryPool, diff --git a/core/services/job/spawner_test.go b/core/services/job/spawner_test.go index b6de9d790fa..ac0783e9868 100644 --- a/core/services/job/spawner_test.go +++ b/core/services/job/spawner_test.go @@ -287,6 +287,7 @@ func TestSpawner_CreateJobDeleteJob(t *testing.T) { evmRelayer, err := evmrelayer.NewRelayer(lggr, chain, evmrelayer.RelayerOpts{ DB: db, + DS: db, QConfig: testopts.GeneralConfig.Database(), CSAETHKeystore: keyStore, }) diff --git a/core/services/relay/evm/evm.go b/core/services/relay/evm/evm.go index ddddb82aaed..95cf9efc944 100644 --- a/core/services/relay/evm/evm.go +++ b/core/services/relay/evm/evm.go @@ -21,6 +21,7 @@ import ( ocrtypes "github.com/smartcontractkit/libocr/offchainreporting2plus/types" "github.com/smartcontractkit/chainlink-common/pkg/services" + "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" commontypes "github.com/smartcontractkit/chainlink-common/pkg/types" txmgrcommon "github.com/smartcontractkit/chainlink/v2/common/txmgr" @@ -70,7 +71,8 @@ func init() { var _ commontypes.Relayer = &Relayer{} //nolint:staticcheck type Relayer struct { - db *sqlx.DB + db *sqlx.DB // legacy: prefer to use ds instead + ds sqlutil.DataSource chain legacyevm.Chain lggr logger.Logger ks CSAETHKeystore @@ -93,7 +95,8 @@ type CSAETHKeystore interface { } type RelayerOpts struct { - *sqlx.DB + *sqlx.DB // legacy: prefer to use ds instead + DS sqlutil.DataSource pg.QConfig CSAETHKeystore MercuryPool wsrpc.Pool @@ -104,6 +107,9 @@ func (c RelayerOpts) Validate() error { if c.DB == nil { err = errors.Join(err, errors.New("nil DB")) } + if c.DS == nil { + err = errors.Join(err, errors.New("nil DataSource")) + } if c.QConfig == nil { err = errors.Join(err, errors.New("nil QConfig")) } @@ -129,6 +135,7 @@ func NewRelayer(lggr logger.Logger, chain legacyevm.Chain, opts RelayerOpts) (*R cdcFactory := llo.NewChannelDefinitionCacheFactory(lggr, lloORM, chain.LogPoller()) return &Relayer{ db: opts.DB, + ds: opts.DS, chain: chain, lggr: lggr, ks: opts.CSAETHKeystore, @@ -588,7 +595,7 @@ func (r *Relayer) NewMedianProvider(rargs commontypes.RelayArgs, pargs commontyp return nil, err } - medianContract, err := newMedianContract(configWatcher.ContractConfigTracker(), configWatcher.contractAddress, configWatcher.chain, rargs.JobID, r.db, lggr) + medianContract, err := newMedianContract(configWatcher.ContractConfigTracker(), configWatcher.contractAddress, configWatcher.chain, rargs.JobID, r.ds, lggr) if err != nil { return nil, err } diff --git a/core/services/relay/evm/evm_test.go b/core/services/relay/evm/evm_test.go index 41e51a7ab8f..d53fe910bc3 100644 --- a/core/services/relay/evm/evm_test.go +++ b/core/services/relay/evm/evm_test.go @@ -7,6 +7,7 @@ import ( "github.com/jmoiron/sqlx" + "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" "github.com/smartcontractkit/chainlink/v2/core/internal/testutils/configtest" "github.com/smartcontractkit/chainlink/v2/core/services/pg" "github.com/smartcontractkit/chainlink/v2/core/services/relay/evm" @@ -16,6 +17,7 @@ func TestRelayerOpts_Validate(t *testing.T) { cfg := configtest.NewTestGeneralConfig(t) type fields struct { DB *sqlx.DB + DS sqlutil.DataSource QConfig pg.QConfig CSAETHKeystore evm.CSAETHKeystore } @@ -28,20 +30,23 @@ func TestRelayerOpts_Validate(t *testing.T) { name: "all invalid", fields: fields{ DB: nil, + DS: nil, QConfig: nil, CSAETHKeystore: nil, }, wantErrContains: `nil DB +nil DataSource nil QConfig nil Keystore`, }, { - name: "missing db, keystore", + name: "missing db, ds, keystore", fields: fields{ DB: nil, QConfig: cfg.Database(), }, wantErrContains: `nil DB +nil DataSource nil Keystore`, }, } @@ -49,6 +54,7 @@ nil Keystore`, t.Run(tt.name, func(t *testing.T) { c := evm.RelayerOpts{ DB: tt.fields.DB, + DS: tt.fields.DS, QConfig: tt.fields.QConfig, CSAETHKeystore: tt.fields.CSAETHKeystore, } diff --git a/core/services/relay/evm/median.go b/core/services/relay/evm/median.go index e3200d8e867..2407cff7140 100644 --- a/core/services/relay/evm/median.go +++ b/core/services/relay/evm/median.go @@ -7,7 +7,6 @@ import ( "github.com/ethereum/go-ethereum/accounts/abi/bind" "github.com/ethereum/go-ethereum/common" - "github.com/jmoiron/sqlx" "github.com/pkg/errors" "github.com/smartcontractkit/libocr/gethwrappers2/ocr2aggregator" "github.com/smartcontractkit/libocr/offchainreporting2/reportingplugin/median" @@ -15,6 +14,7 @@ import ( ocrtypes "github.com/smartcontractkit/libocr/offchainreporting2plus/types" "github.com/smartcontractkit/chainlink-common/pkg/services" + "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" "github.com/smartcontractkit/chainlink/v2/core/chains/legacyevm" offchain_aggregator_wrapper "github.com/smartcontractkit/chainlink/v2/core/internal/gethwrappers2/generated/offchainaggregator" "github.com/smartcontractkit/chainlink/v2/core/logger" @@ -30,7 +30,7 @@ type medianContract struct { requestRoundTracker *RequestRoundTracker } -func newMedianContract(configTracker types.ContractConfigTracker, contractAddress common.Address, chain legacyevm.Chain, specID int32, db *sqlx.DB, lggr logger.Logger) (*medianContract, error) { +func newMedianContract(configTracker types.ContractConfigTracker, contractAddress common.Address, chain legacyevm.Chain, specID int32, ds sqlutil.DataSource, lggr logger.Logger) (*medianContract, error) { lggr = lggr.Named("MedianContract") contract, err := offchain_aggregator_wrapper.NewOffchainAggregator(contractAddress, chain.Client()) if err != nil { @@ -58,16 +58,15 @@ func newMedianContract(configTracker types.ContractConfigTracker, contractAddres chain.LogBroadcaster(), specID, lggr, - db, - NewRoundRequestedDB(db.DB, specID, lggr), + ds, + NewRoundRequestedDB(ds, specID, lggr), chain.Config().EVM(), - chain.Config().Database(), ), }, nil } -func (oc *medianContract) Start(context.Context) error { +func (oc *medianContract) Start(ctx context.Context) error { return oc.StartOnce("MedianContract", func() error { - return oc.requestRoundTracker.Start() + return oc.requestRoundTracker.Start(ctx) }) } diff --git a/core/services/relay/evm/mocks/request_round_db.go b/core/services/relay/evm/mocks/request_round_db.go index eb27e8bd526..725fc6e6b37 100644 --- a/core/services/relay/evm/mocks/request_round_db.go +++ b/core/services/relay/evm/mocks/request_round_db.go @@ -3,9 +3,12 @@ package mocks import ( - pg "github.com/smartcontractkit/chainlink/v2/core/services/pg" - ocr2aggregator "github.com/smartcontractkit/libocr/gethwrappers2/ocr2aggregator" + context "context" + + evm "github.com/smartcontractkit/chainlink/v2/core/services/relay/evm" mock "github.com/stretchr/testify/mock" + + ocr2aggregator "github.com/smartcontractkit/libocr/gethwrappers2/ocr2aggregator" ) // RequestRoundDB is an autogenerated mock type for the RequestRoundDB type @@ -13,9 +16,9 @@ type RequestRoundDB struct { mock.Mock } -// LoadLatestRoundRequested provides a mock function with given fields: -func (_m *RequestRoundDB) LoadLatestRoundRequested() (ocr2aggregator.OCR2AggregatorRoundRequested, error) { - ret := _m.Called() +// LoadLatestRoundRequested provides a mock function with given fields: _a0 +func (_m *RequestRoundDB) LoadLatestRoundRequested(_a0 context.Context) (ocr2aggregator.OCR2AggregatorRoundRequested, error) { + ret := _m.Called(_a0) if len(ret) == 0 { panic("no return value specified for LoadLatestRoundRequested") @@ -23,17 +26,17 @@ func (_m *RequestRoundDB) LoadLatestRoundRequested() (ocr2aggregator.OCR2Aggrega var r0 ocr2aggregator.OCR2AggregatorRoundRequested var r1 error - if rf, ok := ret.Get(0).(func() (ocr2aggregator.OCR2AggregatorRoundRequested, error)); ok { - return rf() + if rf, ok := ret.Get(0).(func(context.Context) (ocr2aggregator.OCR2AggregatorRoundRequested, error)); ok { + return rf(_a0) } - if rf, ok := ret.Get(0).(func() ocr2aggregator.OCR2AggregatorRoundRequested); ok { - r0 = rf() + if rf, ok := ret.Get(0).(func(context.Context) ocr2aggregator.OCR2AggregatorRoundRequested); ok { + r0 = rf(_a0) } else { r0 = ret.Get(0).(ocr2aggregator.OCR2AggregatorRoundRequested) } - if rf, ok := ret.Get(1).(func() error); ok { - r1 = rf() + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(_a0) } else { r1 = ret.Error(1) } @@ -41,17 +44,35 @@ func (_m *RequestRoundDB) LoadLatestRoundRequested() (ocr2aggregator.OCR2Aggrega return r0, r1 } -// SaveLatestRoundRequested provides a mock function with given fields: tx, rr -func (_m *RequestRoundDB) SaveLatestRoundRequested(tx pg.Queryer, rr ocr2aggregator.OCR2AggregatorRoundRequested) error { - ret := _m.Called(tx, rr) +// SaveLatestRoundRequested provides a mock function with given fields: ctx, rr +func (_m *RequestRoundDB) SaveLatestRoundRequested(ctx context.Context, rr ocr2aggregator.OCR2AggregatorRoundRequested) error { + ret := _m.Called(ctx, rr) if len(ret) == 0 { panic("no return value specified for SaveLatestRoundRequested") } var r0 error - if rf, ok := ret.Get(0).(func(pg.Queryer, ocr2aggregator.OCR2AggregatorRoundRequested) error); ok { - r0 = rf(tx, rr) + if rf, ok := ret.Get(0).(func(context.Context, ocr2aggregator.OCR2AggregatorRoundRequested) error); ok { + r0 = rf(ctx, rr) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Transact provides a mock function with given fields: _a0, _a1 +func (_m *RequestRoundDB) Transact(_a0 context.Context, _a1 func(evm.RequestRoundDB) error) error { + ret := _m.Called(_a0, _a1) + + if len(ret) == 0 { + panic("no return value specified for Transact") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, func(evm.RequestRoundDB) error) error); ok { + r0 = rf(_a0, _a1) } else { r0 = ret.Error(0) } diff --git a/core/services/relay/evm/request_round_db.go b/core/services/relay/evm/request_round_db.go index b3a5b01bc2c..2b6ae10782d 100644 --- a/core/services/relay/evm/request_round_db.go +++ b/core/services/relay/evm/request_round_db.go @@ -1,43 +1,50 @@ package evm import ( - "database/sql" + "context" "encoding/json" "github.com/pkg/errors" "github.com/smartcontractkit/libocr/gethwrappers2/ocr2aggregator" ocrtypes "github.com/smartcontractkit/libocr/offchainreporting2plus/types" + "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" "github.com/smartcontractkit/chainlink/v2/core/logger" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" ) // RequestRoundDB stores requested rounds for querying by the median plugin. type RequestRoundDB interface { - SaveLatestRoundRequested(tx pg.Queryer, rr ocr2aggregator.OCR2AggregatorRoundRequested) error - LoadLatestRoundRequested() (rr ocr2aggregator.OCR2AggregatorRoundRequested, err error) + SaveLatestRoundRequested(ctx context.Context, rr ocr2aggregator.OCR2AggregatorRoundRequested) error + LoadLatestRoundRequested(context.Context) (rr ocr2aggregator.OCR2AggregatorRoundRequested, err error) + Transact(context.Context, func(db RequestRoundDB) error) error } var _ RequestRoundDB = &requestRoundDB{} //go:generate mockery --quiet --name RequestRoundDB --output ./mocks/ --case=underscore type requestRoundDB struct { - *sql.DB + ds sqlutil.DataSource oracleSpecID int32 lggr logger.Logger } // NewDB returns a new DB scoped to this oracleSpecID -func NewRoundRequestedDB(sqldb *sql.DB, oracleSpecID int32, lggr logger.Logger) *requestRoundDB { - return &requestRoundDB{sqldb, oracleSpecID, lggr} +func NewRoundRequestedDB(ds sqlutil.DataSource, oracleSpecID int32, lggr logger.Logger) *requestRoundDB { + return &requestRoundDB{ds, oracleSpecID, lggr} } -func (d *requestRoundDB) SaveLatestRoundRequested(tx pg.Queryer, rr ocr2aggregator.OCR2AggregatorRoundRequested) error { +func (d *requestRoundDB) Transact(ctx context.Context, fn func(db RequestRoundDB) error) error { + return sqlutil.Transact(ctx, func(ds sqlutil.DataSource) RequestRoundDB { + return NewRoundRequestedDB(ds, d.oracleSpecID, d.lggr) + }, d.ds, nil, fn) +} + +func (d *requestRoundDB) SaveLatestRoundRequested(ctx context.Context, rr ocr2aggregator.OCR2AggregatorRoundRequested) error { rawLog, err := json.Marshal(rr.Raw) if err != nil { return errors.Wrap(err, "could not marshal log as JSON") } - _, err = tx.Exec(` + _, err = d.ds.ExecContext(ctx, ` INSERT INTO ocr2_latest_round_requested (ocr2_oracle_spec_id, requester, config_digest, epoch, round, raw) VALUES ($1,$2,$3,$4,$5,$6) ON CONFLICT (ocr2_oracle_spec_id) DO UPDATE SET requester = EXCLUDED.requester, @@ -50,9 +57,9 @@ VALUES ($1,$2,$3,$4,$5,$6) ON CONFLICT (ocr2_oracle_spec_id) DO UPDATE SET return errors.Wrap(err, "could not save latest round requested") } -func (d *requestRoundDB) LoadLatestRoundRequested() (ocr2aggregator.OCR2AggregatorRoundRequested, error) { +func (d *requestRoundDB) LoadLatestRoundRequested(ctx context.Context) (ocr2aggregator.OCR2AggregatorRoundRequested, error) { rr := ocr2aggregator.OCR2AggregatorRoundRequested{} - rows, err := d.Query(` + rows, err := d.ds.QueryContext(ctx, ` SELECT requester, config_digest, epoch, round, raw FROM ocr2_latest_round_requested WHERE ocr2_oracle_spec_id = $1 diff --git a/core/services/relay/evm/request_round_db_test.go b/core/services/relay/evm/request_round_db_test.go index d10d6a41a61..10932c4e229 100644 --- a/core/services/relay/evm/request_round_db_test.go +++ b/core/services/relay/evm/request_round_db_test.go @@ -12,7 +12,6 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/internal/testutils/pgtest" "github.com/smartcontractkit/chainlink/v2/core/logger" "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/testhelpers" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" "github.com/smartcontractkit/chainlink/v2/core/services/relay/evm" ) @@ -23,8 +22,8 @@ func Test_DB_LatestRoundRequested(t *testing.T) { require.NoError(t, err) lggr := logger.TestLogger(t) - db := evm.NewRoundRequestedDB(sqlDB.DB, 1, lggr) - db2 := evm.NewRoundRequestedDB(sqlDB.DB, 2, lggr) + db := evm.NewRoundRequestedDB(sqlDB, 1, lggr) + db2 := evm.NewRoundRequestedDB(sqlDB, 2, lggr) rawLog := cltest.LogFromFixture(t, "../../../testdata/jsonrpc/round_requested_log_1_1.json") @@ -38,8 +37,8 @@ func Test_DB_LatestRoundRequested(t *testing.T) { t.Run("saves latest round requested", func(t *testing.T) { ctx := testutils.Context(t) - err := pg.SqlxTransaction(ctx, sqlDB, logger.TestLogger(t), func(q pg.Queryer) error { - return db.SaveLatestRoundRequested(q, rr) + err := db.Transact(ctx, func(tx evm.RequestRoundDB) error { + return tx.SaveLatestRoundRequested(ctx, rr) }) require.NoError(t, err) @@ -54,19 +53,20 @@ func Test_DB_LatestRoundRequested(t *testing.T) { Raw: rawLog, } - err = pg.SqlxTransaction(ctx, sqlDB, logger.TestLogger(t), func(q pg.Queryer) error { - return db.SaveLatestRoundRequested(q, rr) + err = db.Transact(ctx, func(tx evm.RequestRoundDB) error { + return tx.SaveLatestRoundRequested(ctx, rr) }) require.NoError(t, err) }) t.Run("loads latest round requested", func(t *testing.T) { + ctx := testutils.Context(t) // There is no round for db2 - lrr, err := db2.LoadLatestRoundRequested() + lrr, err := db2.LoadLatestRoundRequested(ctx) require.NoError(t, err) require.Equal(t, 0, int(lrr.Epoch)) - lrr, err = db.LoadLatestRoundRequested() + lrr, err = db.LoadLatestRoundRequested(ctx) require.NoError(t, err) assert.Equal(t, rr, lrr) diff --git a/core/services/relay/evm/request_round_tracker.go b/core/services/relay/evm/request_round_tracker.go index 1e77ce28089..bb39271f278 100644 --- a/core/services/relay/evm/request_round_tracker.go +++ b/core/services/relay/evm/request_round_tracker.go @@ -9,19 +9,17 @@ import ( gethTypes "github.com/ethereum/go-ethereum/core/types" "github.com/pkg/errors" - "github.com/jmoiron/sqlx" - "github.com/smartcontractkit/libocr/gethwrappers2/ocr2aggregator" ocrtypes "github.com/smartcontractkit/libocr/offchainreporting2plus/types" "github.com/smartcontractkit/chainlink-common/pkg/services" + "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" evmclient "github.com/smartcontractkit/chainlink/v2/core/chains/evm/client" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/log" offchain_aggregator_wrapper "github.com/smartcontractkit/chainlink/v2/core/internal/gethwrappers2/generated/offchainaggregator" "github.com/smartcontractkit/chainlink/v2/core/logger" "github.com/smartcontractkit/chainlink/v2/core/services/ocrcommon" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" ) // RequestRoundTracker subscribes to new request round logs. @@ -35,7 +33,7 @@ type RequestRoundTracker struct { jobID int32 lggr logger.SugaredLogger odb RequestRoundDB - q pg.Q + ds sqlutil.DataSource blockTranslator ocrcommon.BlockTranslator // Start/Stop lifecycle @@ -56,10 +54,9 @@ func NewRequestRoundTracker( logBroadcaster log.Broadcaster, jobID int32, lggr logger.Logger, - db *sqlx.DB, + ds sqlutil.DataSource, odb RequestRoundDB, chain ocrcommon.Config, - qConfig pg.QConfig, ) (o *RequestRoundTracker) { ctx, cancel := context.WithCancel(context.Background()) return &RequestRoundTracker{ @@ -70,7 +67,7 @@ func NewRequestRoundTracker( jobID: jobID, lggr: logger.Sugared(lggr), odb: odb, - q: pg.NewQ(db, lggr, qConfig), + ds: ds, blockTranslator: ocrcommon.NewBlockTranslator(chain, ethClient, lggr), ctx: ctx, ctxCancel: cancel, @@ -79,9 +76,9 @@ func NewRequestRoundTracker( // Start must be called before logs can be delivered // It ought to be called before starting OCR -func (t *RequestRoundTracker) Start() error { +func (t *RequestRoundTracker) Start(ctx context.Context) error { return t.StartOnce("RequestRoundTracker", func() (err error) { - t.latestRoundRequested, err = t.odb.LoadLatestRoundRequested() + t.latestRoundRequested, err = t.odb.LoadLatestRoundRequested(ctx) if err != nil { return errors.Wrap(err, "RequestRoundTracker#Start: failed to load latest round requested") } @@ -141,8 +138,9 @@ func (t *RequestRoundTracker) HandleLog(lb log.Broadcast) { return } if IsLaterThan(raw, t.latestRoundRequested.Raw) { - err = t.q.Transaction(func(q pg.Queryer) error { - if err = t.odb.SaveLatestRoundRequested(q, *rr); err != nil { + ctx := context.TODO() //TODO https://smartcontract-it.atlassian.net/browse/BCF-2887 + err = t.odb.Transact(ctx, func(tx RequestRoundDB) error { + if err = tx.SaveLatestRoundRequested(ctx, *rr); err != nil { return err } return t.logBroadcaster.MarkConsumed(t.ctx, lb) diff --git a/core/services/relay/evm/request_round_tracker_test.go b/core/services/relay/evm/request_round_tracker_test.go index cb2ee2a8d72..324b76dc6de 100644 --- a/core/services/relay/evm/request_round_tracker_test.go +++ b/core/services/relay/evm/request_round_tracker_test.go @@ -93,7 +93,6 @@ func newContractTrackerUni(t *testing.T, opts ...interface{}) (uni contractTrack db, uni.db, chain.EVM(), - chain.Database(), ) return uni @@ -174,6 +173,12 @@ func Test_OCRContractTracker_HandleLog_OCRContractLatestRoundRequested(t *testin uni.db.On("SaveLatestRoundRequested", mock.Anything, mock.MatchedBy(func(rr ocr2aggregator.OCR2AggregatorRoundRequested) bool { return rr.Epoch == 1 && rr.Round == 1 })).Return(nil) + transact := uni.db.On("Transact", mock.Anything, mock.Anything) + transact.Run(func(args mock.Arguments) { + fn := args[1].(func(evm.RequestRoundDB) error) + err2 := fn(uni.db) + transact.ReturnArguments = []any{err2} + }) uni.requestRoundTracker.HandleLog(logBroadcast) @@ -245,6 +250,12 @@ func Test_OCRContractTracker_HandleLog_OCRContractLatestRoundRequested(t *testin uni.lb.On("WasAlreadyConsumed", mock.Anything, mock.Anything).Return(false, nil) uni.db.On("SaveLatestRoundRequested", mock.Anything, mock.Anything).Return(errors.New("something exploded")) + transact := uni.db.On("Transact", mock.Anything, mock.Anything) + transact.Run(func(args mock.Arguments) { + fn := args[1].(func(evm.RequestRoundDB) error) + err := fn(uni.db) + transact.ReturnArguments = []any{err} + }) uni.requestRoundTracker.HandleLog(logBroadcast) @@ -271,9 +282,10 @@ func Test_OCRContractTracker_HandleLog_OCRContractLatestRoundRequested(t *testin uni.lb.On("Register", uni.requestRoundTracker, mock.Anything).Return(func() { eventuallyCloseLogBroadcaster.ItHappened() }) uni.lb.On("IsConnected").Return(true).Maybe() - uni.db.On("LoadLatestRoundRequested").Return(rr, nil) + uni.db.On("LoadLatestRoundRequested", mock.Anything).Return(rr, nil) - require.NoError(t, uni.requestRoundTracker.Start()) + ctx := testutils.Context(t) + require.NoError(t, uni.requestRoundTracker.Start(ctx)) configDigest, epoch, round, err := uni.requestRoundTracker.LatestRoundRequested(testutils.Context(t), 0) require.NoError(t, err)