diff --git a/core/chains/evm/monitor/balance.go b/core/chains/evm/monitor/balance.go index 8dbd4ed8507..28682f2509a 100644 --- a/core/chains/evm/monitor/balance.go +++ b/core/chains/evm/monitor/balance.go @@ -47,8 +47,10 @@ type ( NullBalanceMonitor struct{} ) +var _ BalanceMonitor = (*balanceMonitor)(nil) + // NewBalanceMonitor returns a new balanceMonitor -func NewBalanceMonitor(ethClient evmclient.Client, ethKeyStore keystore.Eth, logger logger.Logger) BalanceMonitor { +func NewBalanceMonitor(ethClient evmclient.Client, ethKeyStore keystore.Eth, logger logger.Logger) *balanceMonitor { chainId := ethClient.ConfiguredChainID() bm := &balanceMonitor{ utils.StartStopOnce{}, diff --git a/core/chains/evm/monitor/balance_helpers_test.go b/core/chains/evm/monitor/balance_helpers_test.go new file mode 100644 index 00000000000..624aa69f061 --- /dev/null +++ b/core/chains/evm/monitor/balance_helpers_test.go @@ -0,0 +1,7 @@ +package monitor + +func (bm *balanceMonitor) WorkDone() <-chan struct{} { + return bm.sleeperTask.(interface { + WorkDone() <-chan struct{} + }).WorkDone() +} diff --git a/core/chains/evm/monitor/balance_test.go b/core/chains/evm/monitor/balance_test.go index c908c395671..d6417381815 100644 --- a/core/chains/evm/monitor/balance_test.go +++ b/core/chains/evm/monitor/balance_test.go @@ -169,12 +169,9 @@ func TestBalanceMonitor_OnNewLongestChain_UpdatesBalance(t *testing.T) { // Do the thing bm.OnNewLongestChain(testutils.Context(t), head) - gomega.NewWithT(t).Eventually(func() *big.Int { - return bm.GetEthBalance(k0Addr).ToInt() - }).Should(gomega.Equal(k0bal)) - gomega.NewWithT(t).Eventually(func() *big.Int { - return bm.GetEthBalance(k1Addr).ToInt() - }).Should(gomega.Equal(k1bal)) + <-bm.WorkDone() + assert.Equal(t, k0bal, bm.GetEthBalance(k0Addr).ToInt()) + assert.Equal(t, k1bal, bm.GetEthBalance(k1Addr).ToInt()) // Do it again k0bal2 := big.NewInt(142) @@ -187,12 +184,9 @@ func TestBalanceMonitor_OnNewLongestChain_UpdatesBalance(t *testing.T) { bm.OnNewLongestChain(testutils.Context(t), head) - gomega.NewWithT(t).Eventually(func() *big.Int { - return bm.GetEthBalance(k0Addr).ToInt() - }).Should(gomega.Equal(k0bal2)) - gomega.NewWithT(t).Eventually(func() *big.Int { - return bm.GetEthBalance(k1Addr).ToInt() - }).Should(gomega.Equal(k1bal2)) + <-bm.WorkDone() + assert.Equal(t, k0bal2, bm.GetEthBalance(k0Addr).ToInt()) + assert.Equal(t, k1bal2, bm.GetEthBalance(k1Addr).ToInt()) }) } diff --git a/core/sessions/reaper.go b/core/sessions/reaper.go index a80f0124bb6..c4f0ed6796c 100644 --- a/core/sessions/reaper.go +++ b/core/sessions/reaper.go @@ -13,10 +13,6 @@ type sessionReaper struct { db *sql.DB config SessionReaperConfig lggr logger.Logger - - // Receive from this for testing via sr.RunSignal() - // to be notified after each reaper run. - runSignal chan struct{} } type SessionReaperConfig interface { @@ -26,18 +22,11 @@ type SessionReaperConfig interface { // NewSessionReaper creates a reaper that cleans stale sessions from the store. func NewSessionReaper(db *sql.DB, config SessionReaperConfig, lggr logger.Logger) utils.SleeperTask { - return utils.NewSleeperTask(NewSessionReaperWorker(db, config, lggr)) -} - -func NewSessionReaperWorker(db *sql.DB, config SessionReaperConfig, lggr logger.Logger) *sessionReaper { - return &sessionReaper{ + return utils.NewSleeperTask(&sessionReaper{ db, config, lggr.Named("SessionReaper"), - - // For testing only. - make(chan struct{}, 10), - } + }) } func (sr *sessionReaper) Name() string { @@ -51,11 +40,6 @@ func (sr *sessionReaper) Work() { if err != nil { sr.lggr.Error("unable to reap stale sessions: ", err) } - - select { - case sr.runSignal <- struct{}{}: - default: - } } // DeleteStaleSessions deletes all sessions before the passed time. diff --git a/core/sessions/reaper_helper_test.go b/core/sessions/reaper_helper_test.go deleted file mode 100644 index cec9b72d7ee..00000000000 --- a/core/sessions/reaper_helper_test.go +++ /dev/null @@ -1,5 +0,0 @@ -package sessions - -func (sr *sessionReaper) RunSignal() <-chan struct{} { - return sr.runSignal -} diff --git a/core/sessions/reaper_test.go b/core/sessions/reaper_test.go index 1e325ea5063..a96c3822ef5 100644 --- a/core/sessions/reaper_test.go +++ b/core/sessions/reaper_test.go @@ -10,7 +10,6 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/logger/audit" "github.com/smartcontractkit/chainlink/v2/core/sessions" "github.com/smartcontractkit/chainlink/v2/core/store/models" - "github.com/smartcontractkit/chainlink/v2/core/utils" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -34,8 +33,7 @@ func TestSessionReaper_ReapSessions(t *testing.T) { lggr := logger.TestLogger(t) orm := sessions.NewORM(db, config.SessionTimeout().Duration(), lggr, pgtest.NewQConfig(true), audit.NoopLogger) - rw := sessions.NewSessionReaperWorker(db.DB, config, lggr) - r := utils.NewSleeperTask(rw) + r := sessions.NewSessionReaper(db.DB, config, lggr) t.Cleanup(func() { assert.NoError(t, r.Stop()) @@ -70,7 +68,9 @@ func TestSessionReaper_ReapSessions(t *testing.T) { }) r.WakeUp() - <-rw.RunSignal() + <-r.(interface { + WorkDone() <-chan struct{} + }).WorkDone() sessions, err := orm.Sessions(0, 10) assert.NoError(t, err) diff --git a/core/utils/sleeper_task.go b/core/utils/sleeper_task.go index d020799a9c6..0b45507a82f 100644 --- a/core/utils/sleeper_task.go +++ b/core/utils/sleeper_task.go @@ -19,10 +19,11 @@ type Worker interface { } type sleeperTask struct { - worker Worker - chQueue chan struct{} - chStop chan struct{} - chDone chan struct{} + worker Worker + chQueue chan struct{} + chStop chan struct{} + chDone chan struct{} + chWorkDone chan struct{} StartStopOnce } @@ -36,10 +37,11 @@ type sleeperTask struct { // WakeUp does not block. func NewSleeperTask(worker Worker) SleeperTask { s := &sleeperTask{ - worker: worker, - chQueue: make(chan struct{}, 1), - chStop: make(chan struct{}), - chDone: make(chan struct{}), + worker: worker, + chQueue: make(chan struct{}, 1), + chStop: make(chan struct{}), + chDone: make(chan struct{}), + chWorkDone: make(chan struct{}, 10), } _ = s.StartOnce("SleeperTask-"+worker.Name(), func() error { @@ -83,6 +85,19 @@ func (s *sleeperTask) WakeUp() { } } +func (s *sleeperTask) workDone() { + select { + case s.chWorkDone <- struct{}{}: + default: + } +} + +// WorkDone isn't part of the SleeperTask interface, but can be +// useful in tests to assert that the work has been done. +func (s *sleeperTask) WorkDone() <-chan struct{} { + return s.chWorkDone +} + func (s *sleeperTask) workerLoop() { defer close(s.chDone) @@ -90,6 +105,7 @@ func (s *sleeperTask) workerLoop() { select { case <-s.chQueue: s.worker.Work() + s.workDone() case <-s.chStop: return }