Skip to content

Commit

Permalink
use services.Config.NewService/Engine (part 3) (#14170)
Browse files Browse the repository at this point in the history
  • Loading branch information
jmank88 authored Aug 26, 2024
1 parent 95ae744 commit 5070cf8
Show file tree
Hide file tree
Showing 7 changed files with 117 additions and 184 deletions.
58 changes: 27 additions & 31 deletions common/client/poller.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package client

import (
"context"
"sync"
"time"

"github.com/smartcontractkit/chainlink-common/pkg/logger"
Expand All @@ -15,83 +14,80 @@ import (
// and delivers the result to a channel. It is used by multinode to poll
// for new heads and implements the Subscription interface.
type Poller[T any] struct {
services.StateMachine
services.Service
eng *services.Engine

pollingInterval time.Duration
pollingFunc func(ctx context.Context) (T, error)
pollingTimeout time.Duration
logger logger.Logger
channel chan<- T
errCh chan error

stopCh services.StopChan
wg sync.WaitGroup
}

// NewPoller creates a new Poller instance and returns a channel to receive the polled data
func NewPoller[
T any,
](pollingInterval time.Duration, pollingFunc func(ctx context.Context) (T, error), pollingTimeout time.Duration, logger logger.Logger) (Poller[T], <-chan T) {
](pollingInterval time.Duration, pollingFunc func(ctx context.Context) (T, error), pollingTimeout time.Duration, lggr logger.Logger) (Poller[T], <-chan T) {
channel := make(chan T)
return Poller[T]{
p := Poller[T]{
pollingInterval: pollingInterval,
pollingFunc: pollingFunc,
pollingTimeout: pollingTimeout,
channel: channel,
logger: logger,
errCh: make(chan error),
stopCh: make(chan struct{}),
}, channel
}
p.Service, p.eng = services.Config{
Name: "Poller",
Start: p.start,
Close: p.close,
}.NewServiceEngine(lggr)
return p, channel
}

var _ types.Subscription = &Poller[any]{}

func (p *Poller[T]) Start() error {
return p.StartOnce("Poller", func() error {
p.wg.Add(1)
go p.pollingLoop()
return nil
})
func (p *Poller[T]) start(ctx context.Context) error {
p.eng.Go(p.pollingLoop)
return nil
}

// Unsubscribe cancels the sending of events to the data channel
func (p *Poller[T]) Unsubscribe() {
_ = p.StopOnce("Poller", func() error {
close(p.stopCh)
p.wg.Wait()
close(p.errCh)
close(p.channel)
return nil
})
_ = p.Close()
}

func (p *Poller[T]) close() error {
close(p.errCh)
close(p.channel)
return nil
}

func (p *Poller[T]) Err() <-chan error {
return p.errCh
}

func (p *Poller[T]) pollingLoop() {
defer p.wg.Done()

func (p *Poller[T]) pollingLoop(ctx context.Context) {
ticker := time.NewTicker(p.pollingInterval)
defer ticker.Stop()

for {
select {
case <-p.stopCh:
case <-ctx.Done():
return
case <-ticker.C:
// Set polling timeout
pollingCtx, cancelPolling := p.stopCh.CtxCancel(context.WithTimeout(context.Background(), p.pollingTimeout))
pollingCtx, cancelPolling := context.WithTimeout(ctx, p.pollingTimeout)
// Execute polling function
result, err := p.pollingFunc(pollingCtx)
cancelPolling()
if err != nil {
p.logger.Warnf("polling error: %v", err)
p.eng.Warnf("polling error: %v", err)
continue
}
// Send result to channel or block if channel is full
select {
case p.channel <- result:
case <-p.stopCh:
case <-ctx.Done():
return
}
}
Expand Down
23 changes: 15 additions & 8 deletions common/client/poller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,22 @@ func Test_Poller(t *testing.T) {
lggr := logger.Test(t)

t.Run("Test multiple start", func(t *testing.T) {
ctx := tests.Context(t)
pollFunc := func(ctx context.Context) (Head, error) {
return nil, nil
}

poller, _ := NewPoller[Head](time.Millisecond, pollFunc, time.Second, lggr)
err := poller.Start()
err := poller.Start(ctx)
require.NoError(t, err)

err = poller.Start()
err = poller.Start(ctx)
require.Error(t, err)
poller.Unsubscribe()
})

t.Run("Test polling for heads", func(t *testing.T) {
ctx := tests.Context(t)
// Mock polling function that returns a new value every time it's called
var pollNumber int
pollLock := sync.Mutex{}
Expand All @@ -50,7 +52,7 @@ func Test_Poller(t *testing.T) {

// Create poller and start to receive data
poller, channel := NewPoller[Head](time.Millisecond, pollFunc, time.Second, lggr)
require.NoError(t, poller.Start())
require.NoError(t, poller.Start(ctx))
defer poller.Unsubscribe()

// Receive updates from the poller
Expand All @@ -63,6 +65,7 @@ func Test_Poller(t *testing.T) {
})

t.Run("Test polling errors", func(t *testing.T) {
ctx := tests.Context(t)
// Mock polling function that returns an error
var pollNumber int
pollLock := sync.Mutex{}
Expand All @@ -77,7 +80,7 @@ func Test_Poller(t *testing.T) {

// Create poller and subscribe to receive data
poller, _ := NewPoller[Head](time.Millisecond, pollFunc, time.Second, olggr)
require.NoError(t, poller.Start())
require.NoError(t, poller.Start(ctx))
defer poller.Unsubscribe()

// Ensure that all errors were logged as expected
Expand All @@ -94,6 +97,7 @@ func Test_Poller(t *testing.T) {
})

t.Run("Test polling timeout", func(t *testing.T) {
ctx := tests.Context(t)
pollFunc := func(ctx context.Context) (Head, error) {
if <-ctx.Done(); true {
return nil, ctx.Err()
Expand All @@ -108,7 +112,7 @@ func Test_Poller(t *testing.T) {

// Create poller and subscribe to receive data
poller, _ := NewPoller[Head](time.Millisecond, pollFunc, pollingTimeout, olggr)
require.NoError(t, poller.Start())
require.NoError(t, poller.Start(ctx))
defer poller.Unsubscribe()

// Ensure that timeout errors were logged as expected
Expand All @@ -119,6 +123,7 @@ func Test_Poller(t *testing.T) {
})

t.Run("Test unsubscribe during polling", func(t *testing.T) {
ctx := tests.Context(t)
wait := make(chan struct{})
closeOnce := sync.OnceFunc(func() { close(wait) })
pollFunc := func(ctx context.Context) (Head, error) {
Expand All @@ -137,7 +142,7 @@ func Test_Poller(t *testing.T) {

// Create poller and subscribe to receive data
poller, _ := NewPoller[Head](time.Millisecond, pollFunc, pollingTimeout, olggr)
require.NoError(t, poller.Start())
require.NoError(t, poller.Start(ctx))

// Unsubscribe while blocked in polling function
<-wait
Expand Down Expand Up @@ -167,8 +172,9 @@ func Test_Poller_Unsubscribe(t *testing.T) {
}

t.Run("Test multiple unsubscribe", func(t *testing.T) {
ctx := tests.Context(t)
poller, channel := NewPoller[Head](time.Millisecond, pollFunc, time.Second, lggr)
err := poller.Start()
err := poller.Start(ctx)
require.NoError(t, err)

<-channel
Expand All @@ -177,8 +183,9 @@ func Test_Poller_Unsubscribe(t *testing.T) {
})

t.Run("Read channel after unsubscribe", func(t *testing.T) {
ctx := tests.Context(t)
poller, channel := NewPoller[Head](time.Millisecond, pollFunc, time.Second, lggr)
err := poller.Start()
err := poller.Start(ctx)
require.NoError(t, err)

poller.Unsubscribe()
Expand Down
74 changes: 27 additions & 47 deletions core/capabilities/remote/target/endtoend_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -276,67 +276,47 @@ func testRemoteTarget(ctx context.Context, t *testing.T, underlying commoncap.Ta
}

type testAsyncMessageBroker struct {
services.StateMachine
t *testing.T
services.Service
eng *services.Engine
t *testing.T

nodes map[p2ptypes.PeerID]remotetypes.Receiver

sendCh chan *remotetypes.MessageBody

stopCh services.StopChan
wg sync.WaitGroup
}

func (a *testAsyncMessageBroker) HealthReport() map[string]error {
return nil
}

func (a *testAsyncMessageBroker) Name() string {
return "testAsyncMessageBroker"
}

func newTestAsyncMessageBroker(t *testing.T, sendChBufferSize int) *testAsyncMessageBroker {
return &testAsyncMessageBroker{
b := &testAsyncMessageBroker{
t: t,
nodes: make(map[p2ptypes.PeerID]remotetypes.Receiver),
stopCh: make(services.StopChan),
sendCh: make(chan *remotetypes.MessageBody, sendChBufferSize),
}
}

func (a *testAsyncMessageBroker) Start(ctx context.Context) error {
return a.StartOnce("testAsyncMessageBroker", func() error {
a.wg.Add(1)
go func() {
defer a.wg.Done()

for {
select {
case <-a.stopCh:
return
case msg := <-a.sendCh:
receiverId := toPeerID(msg.Receiver)

receiver, ok := a.nodes[receiverId]
if !ok {
panic("server not found for peer id")
}

receiver.Receive(tests.Context(a.t), msg)
b.Service, b.eng = services.Config{
Name: "testAsyncMessageBroker",
Start: b.start,
}.NewServiceEngine(logger.TestLogger(t))
return b
}

func (a *testAsyncMessageBroker) start(ctx context.Context) error {
a.eng.Go(func(ctx context.Context) {
for {
select {
case <-ctx.Done():
return
case msg := <-a.sendCh:
receiverId := toPeerID(msg.Receiver)

receiver, ok := a.nodes[receiverId]
if !ok {
panic("server not found for peer id")
}
}
}()
return nil
})
}

func (a *testAsyncMessageBroker) Close() error {
return a.StopOnce("testAsyncMessageBroker", func() error {
close(a.stopCh)

a.wg.Wait()
return nil
receiver.Receive(tests.Context(a.t), msg)
}
}
})
return nil
}

func (a *testAsyncMessageBroker) NewDispatcherForNode(nodePeerID p2ptypes.PeerID) remotetypes.Dispatcher {
Expand Down
4 changes: 2 additions & 2 deletions core/chains/evm/client/rpc_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -546,14 +546,14 @@ func (r *rpcClient) SubscribeToHeads(ctx context.Context) (ch <-chan *evmtypes.H
return channel, forwarder, err
}

func (r *rpcClient) SubscribeToFinalizedHeads(_ context.Context) (<-chan *evmtypes.Head, commontypes.Subscription, error) {
func (r *rpcClient) SubscribeToFinalizedHeads(ctx context.Context) (<-chan *evmtypes.Head, commontypes.Subscription, error) {
interval := r.finalizedBlockPollInterval
if interval == 0 {
return nil, nil, errors.New("FinalizedBlockPollInterval is 0")
}
timeout := interval
poller, channel := commonclient.NewPoller[*evmtypes.Head](interval, r.LatestFinalizedBlock, timeout, r.rpcLog)
if err := poller.Start(); err != nil {
if err := poller.Start(ctx); err != nil {
return nil, nil, err
}
return channel, &poller, nil
Expand Down
Loading

0 comments on commit 5070cf8

Please sign in to comment.