Skip to content

Commit

Permalink
Fix race condition in Poller tests (#13110)
Browse files Browse the repository at this point in the history
* Create polling channel

* Update poller_test.go
  • Loading branch information
DylanTinianov authored May 6, 2024
1 parent 0955d46 commit 466d161
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 34 deletions.
8 changes: 5 additions & 3 deletions common/client/poller.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,11 @@ type Poller[T any] struct {
wg sync.WaitGroup
}

// NewPoller creates a new Poller instance
// NewPoller creates a new Poller instance and returns a channel to receive the polled data
func NewPoller[
T any,
](pollingInterval time.Duration, pollingFunc func(ctx context.Context) (T, error), pollingTimeout time.Duration, channel chan<- T, logger logger.Logger) Poller[T] {
](pollingInterval time.Duration, pollingFunc func(ctx context.Context) (T, error), pollingTimeout time.Duration, logger logger.Logger) (Poller[T], <-chan T) {
channel := make(chan T)
return Poller[T]{
pollingInterval: pollingInterval,
pollingFunc: pollingFunc,
Expand All @@ -39,7 +40,7 @@ func NewPoller[
logger: logger,
errCh: make(chan error),
stopCh: make(chan struct{}),
}
}, channel
}

var _ types.Subscription = &Poller[any]{}
Expand All @@ -58,6 +59,7 @@ func (p *Poller[T]) Unsubscribe() {
close(p.stopCh)
p.wg.Wait()
close(p.errCh)
close(p.channel)
return nil
})
}
Expand Down
40 changes: 9 additions & 31 deletions common/client/poller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,7 @@ func Test_Poller(t *testing.T) {
return nil, nil
}

channel := make(chan Head, 1)
defer close(channel)

poller := NewPoller[Head](time.Millisecond, pollFunc, time.Second, channel, lggr)
poller, _ := NewPoller[Head](time.Millisecond, pollFunc, time.Second, lggr)
err := poller.Start()
require.NoError(t, err)

Expand All @@ -50,12 +47,8 @@ func Test_Poller(t *testing.T) {
return h.ToMockHead(t), nil
}

// data channel to receive updates from the poller
channel := make(chan Head, 1)
defer close(channel)

// Create poller and start to receive data
poller := NewPoller[Head](time.Millisecond, pollFunc, time.Second, channel, lggr)
poller, channel := NewPoller[Head](time.Millisecond, pollFunc, time.Second, lggr)
require.NoError(t, poller.Start())
defer poller.Unsubscribe()

Expand All @@ -79,14 +72,10 @@ func Test_Poller(t *testing.T) {
return nil, fmt.Errorf("polling error %d", pollNumber)
}

// data channel to receive updates from the poller
channel := make(chan Head, 1)
defer close(channel)

olggr, observedLogs := logger.TestObserved(t, zap.WarnLevel)

// Create poller and subscribe to receive data
poller := NewPoller[Head](time.Millisecond, pollFunc, time.Second, channel, olggr)
poller, _ := NewPoller[Head](time.Millisecond, pollFunc, time.Second, olggr)
require.NoError(t, poller.Start())
defer poller.Unsubscribe()

Expand Down Expand Up @@ -114,14 +103,10 @@ func Test_Poller(t *testing.T) {
// Set instant timeout
pollingTimeout := time.Duration(0)

// data channel to receive updates from the poller
channel := make(chan Head, 1)
defer close(channel)

olggr, observedLogs := logger.TestObserved(t, zap.WarnLevel)

// Create poller and subscribe to receive data
poller := NewPoller[Head](time.Millisecond, pollFunc, pollingTimeout, channel, olggr)
poller, _ := NewPoller[Head](time.Millisecond, pollFunc, pollingTimeout, olggr)
require.NoError(t, poller.Start())
defer poller.Unsubscribe()

Expand All @@ -146,14 +131,10 @@ func Test_Poller(t *testing.T) {
// Set long timeout
pollingTimeout := time.Minute

// data channel to receive updates from the poller
channel := make(chan Head, 1)
defer close(channel)

olggr, observedLogs := logger.TestObserved(t, zap.WarnLevel)

// Create poller and subscribe to receive data
poller := NewPoller[Head](time.Millisecond, pollFunc, pollingTimeout, channel, olggr)
poller, _ := NewPoller[Head](time.Millisecond, pollFunc, pollingTimeout, olggr)
require.NoError(t, poller.Start())

// Unsubscribe while blocked in polling function
Expand Down Expand Up @@ -184,8 +165,7 @@ func Test_Poller_Unsubscribe(t *testing.T) {
}

t.Run("Test multiple unsubscribe", func(t *testing.T) {
channel := make(chan Head, 1)
poller := NewPoller[Head](time.Millisecond, pollFunc, time.Second, channel, lggr)
poller, channel := NewPoller[Head](time.Millisecond, pollFunc, time.Second, lggr)
err := poller.Start()
require.NoError(t, err)

Expand All @@ -194,14 +174,12 @@ func Test_Poller_Unsubscribe(t *testing.T) {
poller.Unsubscribe()
})

t.Run("Test unsubscribe with closed channel", func(t *testing.T) {
channel := make(chan Head, 1)
poller := NewPoller[Head](time.Millisecond, pollFunc, time.Second, channel, lggr)
t.Run("Read channel after unsubscribe", func(t *testing.T) {
poller, channel := NewPoller[Head](time.Millisecond, pollFunc, time.Second, lggr)
err := poller.Start()
require.NoError(t, err)

<-channel
close(channel)
poller.Unsubscribe()
require.Equal(t, <-channel, nil)
})
}

0 comments on commit 466d161

Please sign in to comment.