diff --git a/common/client/polling_transformer.go b/common/client/polling_transformer.go new file mode 100644 index 00000000000..ed0261a852f --- /dev/null +++ b/common/client/polling_transformer.go @@ -0,0 +1,98 @@ +package client + +import ( + "context" + "sync" + "time" + + "github.com/smartcontractkit/chainlink-common/pkg/logger" + "github.com/smartcontractkit/chainlink-common/pkg/services" +) + +// PollingTransformer is a component that polls a function at a given interval +// and delivers the result to subscribers +type PollingTransformer[HEAD Head] struct { + interval time.Duration + pollFunc func() (HEAD, error) + + logger logger.Logger + + subscribers []chan<- HEAD + + isPolling bool + stopCh services.StopChan + wg sync.WaitGroup +} + +func NewPollingTransformer[HEAD Head](pollInterval time.Duration, pollFunc func() (HEAD, error), logger logger.Logger) *PollingTransformer[HEAD] { + return &PollingTransformer[HEAD]{ + interval: pollInterval, + pollFunc: pollFunc, + logger: logger, + isPolling: false, + } +} + +// Subscribe adds a Subscriber to the polling transformer +func (pt *PollingTransformer[HEAD]) Subscribe(sub chan<- HEAD) { + pt.subscribers = append(pt.subscribers, sub) +} + +// Unsubscribe removes a Subscriber from the polling transformer +func (pt *PollingTransformer[HEAD]) Unsubscribe(sub chan<- HEAD) { + for i, s := range pt.subscribers { + if s == sub { + close(s) + pt.subscribers = append(pt.subscribers[:i], pt.subscribers[i+1:]...) + return + } + } +} + +// StartPolling starts the polling loop and delivers the polled value to subscribers +func (pt *PollingTransformer[HEAD]) StartPolling() { + pt.stopCh = make(chan struct{}) + pt.wg.Add(1) + go pt.pollingLoop(pt.stopCh.NewCtx()) + pt.isPolling = true +} + +// pollingLoop polls the pollFunc at the given interval and delivers the result to subscribers +func (pt *PollingTransformer[HEAD]) pollingLoop(ctx context.Context, cancel context.CancelFunc) { + defer pt.wg.Done() + defer cancel() + + pollT := time.NewTicker(pt.interval) + defer pollT.Stop() + + for { + select { + case <-ctx.Done(): + for _, subscriber := range pt.subscribers { + close(subscriber) + } + return + case <-pollT.C: + head, err := pt.pollFunc() + if err != nil { + // TODO: handle error + } + pt.logger.Debugw("PollingTransformer: polled value", "head", head) + for _, subscriber := range pt.subscribers { + select { + case subscriber <- head: + // Successfully sent head + default: + // Subscriber's channel is closed + pt.Unsubscribe(subscriber) + } + } + } + } +} + +// StopPolling stops the polling loop +func (pt *PollingTransformer[HEAD]) StopPolling() { + close(pt.stopCh) + pt.wg.Wait() +} diff --git a/common/client/polling_transformer_test.go b/common/client/polling_transformer_test.go new file mode 100644 index 00000000000..f285f6f6abe --- /dev/null +++ b/common/client/polling_transformer_test.go @@ -0,0 +1,64 @@ +package client + +import ( + "math/big" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/smartcontractkit/chainlink-common/pkg/logger" +) + +type TestHead struct { + blockNumber int64 +} + +var _ Head = &TestHead{} + +func (th *TestHead) BlockNumber() int64 { + return th.blockNumber +} + +func (th *TestHead) BlockDifficulty() *big.Int { + return nil +} + +func (th *TestHead) IsValid() bool { + return true +} + +func Test_Polling_Transformer(t *testing.T) { + t.Parallel() + + // Mock polling function that returns a new value every time it's called + var lastBlockNumber int64 + pollFunc := func() (Head, error) { + lastBlockNumber++ + return &TestHead{lastBlockNumber}, nil + } + + pt := NewPollingTransformer(time.Millisecond, pollFunc, logger.Test(t)) + pt.StartPolling() + defer pt.StopPolling() + + // Create a subscriber channel + subscriber := make(chan Head) + pt.Subscribe(subscriber) + defer pt.Unsubscribe(subscriber) + + // Create a goroutine to receive updates from the subscriber + pollCount := 0 + pollMax := 50 + go func() { + for i := 0; i < pollMax; i++ { + value := <-subscriber + pollCount++ + require.Equal(t, int64(pollCount), value.BlockNumber()) + } + }() + + // Wait for a short duration to allow for some polling iterations + time.Sleep(100 * time.Millisecond) + require.Equal(t, pollMax, pollCount) +}