-
Notifications
You must be signed in to change notification settings - Fork 1.7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
2f60dbe
commit 3da9770
Showing
2 changed files
with
162 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
package client | ||
|
||
import ( | ||
"context" | ||
"sync" | ||
"time" | ||
|
||
"github.com/smartcontractkit/chainlink-common/pkg/logger" | ||
"github.com/smartcontractkit/chainlink-common/pkg/services" | ||
) | ||
|
||
// PollingTransformer is a component that polls a function at a given interval | ||
// and delivers the result to subscribers | ||
type PollingTransformer[HEAD Head] struct { | ||
interval time.Duration | ||
pollFunc func() (HEAD, error) | ||
|
||
logger logger.Logger | ||
|
||
subscribers []chan<- HEAD | ||
|
||
isPolling bool | ||
stopCh services.StopChan | ||
wg sync.WaitGroup | ||
} | ||
|
||
func NewPollingTransformer[HEAD Head](pollInterval time.Duration, pollFunc func() (HEAD, error), logger logger.Logger) *PollingTransformer[HEAD] { | ||
return &PollingTransformer[HEAD]{ | ||
interval: pollInterval, | ||
pollFunc: pollFunc, | ||
logger: logger, | ||
isPolling: false, | ||
} | ||
} | ||
|
||
// Subscribe adds a Subscriber to the polling transformer | ||
func (pt *PollingTransformer[HEAD]) Subscribe(sub chan<- HEAD) { | ||
pt.subscribers = append(pt.subscribers, sub) | ||
} | ||
|
||
// Unsubscribe removes a Subscriber from the polling transformer | ||
func (pt *PollingTransformer[HEAD]) Unsubscribe(sub chan<- HEAD) { | ||
for i, s := range pt.subscribers { | ||
if s == sub { | ||
close(s) | ||
pt.subscribers = append(pt.subscribers[:i], pt.subscribers[i+1:]...) | ||
return | ||
} | ||
} | ||
} | ||
|
||
// StartPolling starts the polling loop and delivers the polled value to subscribers | ||
func (pt *PollingTransformer[HEAD]) StartPolling() { | ||
pt.stopCh = make(chan struct{}) | ||
pt.wg.Add(1) | ||
go pt.pollingLoop(pt.stopCh.NewCtx()) | ||
pt.isPolling = true | ||
} | ||
|
||
// pollingLoop polls the pollFunc at the given interval and delivers the result to subscribers | ||
func (pt *PollingTransformer[HEAD]) pollingLoop(ctx context.Context, cancel context.CancelFunc) { | ||
defer pt.wg.Done() | ||
defer cancel() | ||
|
||
pollT := time.NewTicker(pt.interval) | ||
defer pollT.Stop() | ||
|
||
for { | ||
select { | ||
case <-ctx.Done(): | ||
for _, subscriber := range pt.subscribers { | ||
close(subscriber) | ||
} | ||
return | ||
case <-pollT.C: | ||
head, err := pt.pollFunc() | ||
if err != nil { | ||
// TODO: handle error | ||
} | ||
pt.logger.Debugw("PollingTransformer: polled value", "head", head) | ||
for _, subscriber := range pt.subscribers { | ||
select { | ||
case subscriber <- head: | ||
// Successfully sent head | ||
default: | ||
// Subscriber's channel is closed | ||
pt.Unsubscribe(subscriber) | ||
} | ||
} | ||
} | ||
} | ||
} | ||
|
||
// StopPolling stops the polling loop | ||
func (pt *PollingTransformer[HEAD]) StopPolling() { | ||
close(pt.stopCh) | ||
pt.wg.Wait() | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
package client | ||
|
||
import ( | ||
"math/big" | ||
"testing" | ||
"time" | ||
|
||
"github.com/stretchr/testify/require" | ||
|
||
"github.com/smartcontractkit/chainlink-common/pkg/logger" | ||
) | ||
|
||
type TestHead struct { | ||
blockNumber int64 | ||
} | ||
|
||
var _ Head = &TestHead{} | ||
|
||
func (th *TestHead) BlockNumber() int64 { | ||
return th.blockNumber | ||
} | ||
|
||
func (th *TestHead) BlockDifficulty() *big.Int { | ||
return nil | ||
} | ||
|
||
func (th *TestHead) IsValid() bool { | ||
return true | ||
} | ||
|
||
func Test_Polling_Transformer(t *testing.T) { | ||
t.Parallel() | ||
|
||
// Mock polling function that returns a new value every time it's called | ||
var lastBlockNumber int64 | ||
pollFunc := func() (Head, error) { | ||
lastBlockNumber++ | ||
return &TestHead{lastBlockNumber}, nil | ||
} | ||
|
||
pt := NewPollingTransformer(time.Millisecond, pollFunc, logger.Test(t)) | ||
pt.StartPolling() | ||
defer pt.StopPolling() | ||
|
||
// Create a subscriber channel | ||
subscriber := make(chan Head) | ||
pt.Subscribe(subscriber) | ||
defer pt.Unsubscribe(subscriber) | ||
|
||
// Create a goroutine to receive updates from the subscriber | ||
pollCount := 0 | ||
pollMax := 50 | ||
go func() { | ||
for i := 0; i < pollMax; i++ { | ||
value := <-subscriber | ||
pollCount++ | ||
require.Equal(t, int64(pollCount), value.BlockNumber()) | ||
} | ||
}() | ||
|
||
// Wait for a short duration to allow for some polling iterations | ||
time.Sleep(100 * time.Millisecond) | ||
require.Equal(t, pollMax, pollCount) | ||
} |