diff --git a/core/chains/evm/logpoller/disabled.go b/core/chains/evm/logpoller/disabled.go index 80a5550b0a5..602bd5cec27 100644 --- a/core/chains/evm/logpoller/disabled.go +++ b/core/chains/evm/logpoller/disabled.go @@ -37,6 +37,8 @@ func (disabled) RegisterFilter(filter Filter, qopts ...pg.QOpt) error { return E func (disabled) UnregisterFilter(name string, qopts ...pg.QOpt) error { return ErrDisabled } +func (disabled) HasFilter(name string) bool { return false } + func (disabled) LatestBlock(qopts ...pg.QOpt) (int64, error) { return -1, ErrDisabled } func (disabled) GetBlocksRange(ctx context.Context, numbers []uint64, qopts ...pg.QOpt) ([]LogPollerBlock, error) { diff --git a/core/chains/evm/logpoller/log_poller.go b/core/chains/evm/logpoller/log_poller.go index 60767970289..69a3143859f 100644 --- a/core/chains/evm/logpoller/log_poller.go +++ b/core/chains/evm/logpoller/log_poller.go @@ -36,6 +36,7 @@ type LogPoller interface { ReplayAsync(fromBlock int64) RegisterFilter(filter Filter, qopts ...pg.QOpt) error UnregisterFilter(name string, qopts ...pg.QOpt) error + HasFilter(name string) bool LatestBlock(qopts ...pg.QOpt) (int64, error) GetBlocksRange(ctx context.Context, numbers []uint64, qopts ...pg.QOpt) ([]LogPollerBlock, error) @@ -258,6 +259,15 @@ func (lp *logPoller) UnregisterFilter(name string, qopts ...pg.QOpt) error { return nil } +// HasFilter returns true if the log poller has an active filter with the given name. +func (lp *logPoller) HasFilter(name string) bool { + lp.filterMu.RLock() + defer lp.filterMu.RUnlock() + + _, ok := lp.filters[name] + return ok +} + func (lp *logPoller) Filter(from, to *big.Int, bh *common.Hash) ethereum.FilterQuery { lp.filterMu.Lock() defer lp.filterMu.Unlock() diff --git a/core/chains/evm/logpoller/log_poller_test.go b/core/chains/evm/logpoller/log_poller_test.go index 17df4dc6291..c434d18c999 100644 --- a/core/chains/evm/logpoller/log_poller_test.go +++ b/core/chains/evm/logpoller/log_poller_test.go @@ -760,6 +760,13 @@ func TestLogPoller_LoadFilters(t *testing.T) { require.True(t, ok) assert.True(t, filter.Contains(&filter3)) assert.True(t, filter3.Contains(&filter)) + + t.Run("HasFilter", func(t *testing.T) { + assert.True(t, th.LogPoller.HasFilter("first Filter")) + assert.True(t, th.LogPoller.HasFilter("second Filter")) + assert.True(t, th.LogPoller.HasFilter("third Filter")) + assert.False(t, th.LogPoller.HasFilter("fourth Filter")) + }) } func TestLogPoller_GetBlocks_Range(t *testing.T) { diff --git a/core/chains/evm/logpoller/mocks/log_poller.go b/core/chains/evm/logpoller/mocks/log_poller.go index 9e96947abc7..3c30454a4f7 100644 --- a/core/chains/evm/logpoller/mocks/log_poller.go +++ b/core/chains/evm/logpoller/mocks/log_poller.go @@ -68,6 +68,20 @@ func (_m *LogPoller) GetBlocksRange(ctx context.Context, numbers []uint64, qopts return r0, r1 } +// HasFilter provides a mock function with given fields: name +func (_m *LogPoller) HasFilter(name string) bool { + ret := _m.Called(name) + + var r0 bool + if rf, ok := ret.Get(0).(func(string) bool); ok { + r0 = rf(name) + } else { + r0 = ret.Get(0).(bool) + } + + return r0 +} + // HealthReport provides a mock function with given fields: func (_m *LogPoller) HealthReport() map[string]error { ret := _m.Called()