diff --git a/datanode/candlesv2/candle_updates.go b/datanode/candlesv2/candle_updates.go index 136450c921..5b84b15297 100644 --- a/datanode/candlesv2/candle_updates.go +++ b/datanode/candlesv2/candle_updates.go @@ -83,7 +83,7 @@ func (s *CandleUpdates) run(ctx context.Context) { case <-ctx.Done(): return case subscriptionMsg := <-s.subscriptionMsgChan: - s.handleSubscription(subscriptions, subscriptionMsg, lastCandle) + subscriptions = s.handleSubscription(subscriptions, subscriptionMsg, lastCandle) case now := <-ticker.C: if len(subscriptions) == 0 { lastCandle = nil @@ -105,31 +105,37 @@ func (s *CandleUpdates) run(ctx context.Context) { lastCandle = &candles[len(candles)-1] } - s.sendCandlesToSubscribers(candles, subscriptions) + subscriptions = s.sendCandlesToSubscribers(candles, subscriptions) } } } -func (s *CandleUpdates) handleSubscription(subscriptions map[string]chan entities.Candle, subscription subscriptionMsg, lastCandle *entities.Candle) { +func (s *CandleUpdates) handleSubscription(subscriptions map[string]chan entities.Candle, subscription subscriptionMsg, lastCandle *entities.Candle) map[string]chan entities.Candle { if subscription.subscribe { - s.addSubscription(subscriptions, subscription, lastCandle) - } else { - removeSubscription(subscriptions, subscription.id) + return s.addSubscription(subscriptions, subscription, lastCandle) } + return removeSubscription(subscriptions, subscription.id) } -func (s *CandleUpdates) addSubscription(subscriptions map[string]chan entities.Candle, subscription subscriptionMsg, lastCandle *entities.Candle) { - subscriptions[subscription.id] = subscription.out +func (s *CandleUpdates) addSubscription(subscriptions map[string]chan entities.Candle, subscription subscriptionMsg, lastCandle *entities.Candle) map[string]chan entities.Candle { if lastCandle != nil { - s.sendCandlesToSubscribers([]entities.Candle{*lastCandle}, map[string]chan entities.Candle{subscription.id: subscription.out}) + if rm := s.sendCandlesToSubscribers([]entities.Candle{*lastCandle}, map[string]chan entities.Candle{subscription.id: subscription.out}); len(rm) == 0 { + // try to send the last candle data to the new subscription, if it fails, don't update the map + return subscriptions + } } + subscriptions[subscription.id] = subscription.out + return subscriptions } -func removeSubscription(subscriptions map[string]chan entities.Candle, subscriptionID string) { - if _, ok := subscriptions[subscriptionID]; ok { - close(subscriptions[subscriptionID]) +func removeSubscription(subscriptions map[string]chan entities.Candle, subscriptionID string) map[string]chan entities.Candle { + if ch, ok := subscriptions[subscriptionID]; ok { + // first delete delete(subscriptions, subscriptionID) + // then close + close(ch) } + return subscriptions } func closeAllSubscriptions(subscribers map[string]chan entities.Candle) { @@ -215,15 +221,17 @@ func (s *CandleUpdates) getCandleUpdates(ctx context.Context, lastCandle *entiti return updates, nil } -func (s *CandleUpdates) sendCandlesToSubscribers(candles []entities.Candle, subscriptions map[string]chan entities.Candle) { +func (s *CandleUpdates) sendCandlesToSubscribers(candles []entities.Candle, subscriptions map[string]chan entities.Candle) map[string]chan entities.Candle { + ret := subscriptions for subscriptionID, outCh := range subscriptions { for _, candle := range candles { select { case outCh <- candle: default: - removeSubscription(subscriptions, subscriptionID) + ret = removeSubscription(subscriptions, subscriptionID) break } } } + return ret } diff --git a/datanode/candlesv2/candle_updates_test.go b/datanode/candlesv2/candle_updates_test.go index 840c385907..8bfb6c356e 100644 --- a/datanode/candlesv2/candle_updates_test.go +++ b/datanode/candlesv2/candle_updates_test.go @@ -18,6 +18,7 @@ package candlesv2_test import ( "context" "fmt" + "sync" "testing" "time" @@ -28,6 +29,7 @@ import ( "github.com/shopspring/decimal" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) type nonReturningCandleSource struct{} @@ -67,14 +69,74 @@ func (t *testCandleSource) GetCandleDataForTimeSpan(ctx context.Context, candleI } } +func TestSubscribeAndUnsubscribeCloseChannelPanic(t *testing.T) { + testCandleSource := &testCandleSource{candles: make(chan []entities.Candle, 3), errorCh: make(chan error)} + // ensure the sub channels are buffered + updates := candlesv2.NewCandleUpdates(context.Background(), logging.NewTestLogger(), "testCandles", + testCandleSource, newTestCandleConfig(5).CandleUpdates) + startTime := time.Now() + + updated := startTime + firstCandle := createCandle(startTime, updated, 1, 1, 1, 1, 10, 200) + lastCandle := firstCandle // just for the sake of types + fCh := make(chan struct{}) + wg := sync.WaitGroup{} + wg.Add(1) + go func() { + defer wg.Done() + testCandleSource.candles <- []entities.Candle{firstCandle} + close(fCh) + // keep updating the most recent candle + for i := 0; i < 3; i++ { + updated = updated.Add(time.Second * time.Duration(i)) + lastCandle = createCandle(startTime, updated, 1, 1, 1, 1, 10, 200) + testCandleSource.candles <- []entities.Candle{lastCandle} + } + }() + <-fCh + // ensure the first candle is sent + sub1Id, out1, _ := updates.Subscribe() + sub2Id, out2, _ := updates.Subscribe() + + candle1 := <-out1 + assert.Equal(t, firstCandle, candle1) + + candle2 := <-out2 + assert.Equal(t, firstCandle, candle2) + + // unsubscribe the first subscriber + updates.Unsubscribe(sub1Id) + // now wait for the updates: + wg.Wait() + sub3Id, out3, _ := updates.Subscribe() + candle3 := <-out3 + require.Equal(t, lastCandle, candle3) + // this should unsubscribe sub2 already + testCandleSource.errorCh <- fmt.Errorf("transient error") + + // this sub should get instantly unsubscribed. + errSub, eOut, _ := updates.Subscribe() + require.NotNil(t, eOut) + // reading from the channel should indicate it was closed already due to the error + // once the channel is closed, the subscriber effectively has to have been removed. + _, closed := <-eOut + require.True(t, closed) + // we can still safely call unsubscribe, though: + updates.Unsubscribe(errSub) + updates.Unsubscribe(sub2Id) + updates.Unsubscribe(sub3Id) +} + func TestSubscribeAndUnsubscribeWhenCandleSourceErrorsAlways(t *testing.T) { errorsAlwaysCandleSource := &errorsAlwaysCandleSource{} updates := candlesv2.NewCandleUpdates(context.Background(), logging.NewTestLogger(), "testCandles", errorsAlwaysCandleSource, newTestCandleConfig(0).CandleUpdates) - sub1Id, _, _ := updates.Subscribe() - sub2Id, _, _ := updates.Subscribe() + sub1Id, _, err1 := updates.Subscribe() + sub2Id, _, err2 := updates.Subscribe() + require.NoError(t, err1) + require.NoError(t, err2) updates.Unsubscribe(sub1Id) updates.Unsubscribe(sub2Id)