diff --git a/chainio/dispatcher.go b/chainio/dispatcher.go index 3fdd3fa63e..8012f8905b 100644 --- a/chainio/dispatcher.go +++ b/chainio/dispatcher.go @@ -9,6 +9,7 @@ import ( "github.com/btcsuite/btclog/v2" "github.com/lightningnetwork/lnd/chainntnfs" + "golang.org/x/sync/errgroup" ) // DefaultProcessBlockTimeout is the timeout value used when waiting for one @@ -171,41 +172,34 @@ func (b *BlockbeatDispatcher) dispatchBlocks( // notifyQueues notifies each queue concurrently about the latest block epoch. func (b *BlockbeatDispatcher) notifyQueues() error { - // errChans is a map of channels that will be used to receive errors - // returned from notifying the consumers. - errChans := make(map[uint32]chan error, len(b.consumerQueues)) + eg := &errgroup.Group{} // Notify each queue in goroutines. for qid, consumers := range b.consumerQueues { b.log().Debugf("Notifying queue=%d with %d consumers", qid, len(consumers)) - // Create a signal chan. - errChan := make(chan error, 1) - errChans[qid] = errChan - // Notify each queue concurrently. - go func(qid uint32, c []Consumer, beat Blockbeat) { + eg.Go(func() error { // Notify each consumer in this queue sequentially. - errChan <- DispatchSequential(beat, c) - }(qid, consumers, b.beat) - } + err := DispatchSequential(b.beat, consumers) - // Wait for all consumers in each queue to finish. - for qid, errChan := range errChans { - select { - case err := <-errChan: - if err != nil { - return fmt.Errorf("queue=%d got err: %w", qid, - err) + // Exit early if there's no error. + if err == nil { + return nil } - b.log().Debugf("Notified queue=%d", qid) + return fmt.Errorf("queue=%d got err: %w", qid, err) + }) + } - case <-b.quit: - } + // Wait for all consumers in each queue to finish. + if err := eg.Wait(); err != nil { + return err } + b.log().Debugf("Notified all queues") + return nil } @@ -227,34 +221,30 @@ func DispatchSequential(b Blockbeat, consumers []Consumer) error { // DispatchConcurrent notifies each consumer concurrently about the blockbeat. func DispatchConcurrent(b Blockbeat, consumers []Consumer) error { - // errChans is a map of channels that will be used to receive errors - // returned from notifying the consumers. - errChans := make(map[string]chan error, len(consumers)) + eg := &errgroup.Group{} // Notify each queue in goroutines. for _, c := range consumers { - // Create a signal chan. - errChan := make(chan error, 1) - errChans[c.Name()] = errChan - // Notify each consumer concurrently. - go func(c Consumer, beat Blockbeat) { + eg.Go(func() error { // Send the copy of the beat to the consumer. - errChan <- notifyAndWait( - b, c, DefaultProcessBlockTimeout, - ) - }(c, b) - } + err := notifyAndWait(b, c, DefaultProcessBlockTimeout) + + // Exit early if there's no error. + if err == nil { + return nil + } - // Wait for all consumers in each queue to finish. - for name, errChan := range errChans { - err := <-errChan - if err != nil { b.logger().Errorf("Consumer=%v failed to process "+ - "block: %v", name, err) + "block: %v", c.Name(), err) return err - } + }) + } + + // Wait for all consumers in each queue to finish. + if err := eg.Wait(); err != nil { + return err } return nil