diff --git a/channeldb/graph.go b/channeldb/graph.go
index 2333bc94c6..2eaf08edf6 100644
--- a/channeldb/graph.go
+++ b/channeldb/graph.go
@@ -10,6 +10,7 @@ import (
 	"io"
 	"math"
 	"net"
+	"sort"
 	"sync"
 	"time"
 
@@ -1704,12 +1705,25 @@ func (c *ChannelGraph) FilterKnownChanIDs(chanIDs []uint64) ([]uint64, error) {
 	return newChanIDs, nil
 }
 
+// BlockChannelRange represents a range of channels for a given block height.
+type BlockChannelRange struct {
+	// Height is the height of the block all of the channels below were
+	// included in.
+	Height uint32
+
+	// Channels is the list of channels known to us, identified by their
+	// short channel IDs, that were included in the block at the height
+	// above.
+	Channels []lnwire.ShortChannelID
+}
+
 // FilterChannelRange returns the channel ID's of all known channels which were
-// mined in a block height within the passed range. This method can be used to
-// quickly share with a peer the set of channels we know of within a particular
-// range to catch them up after a period of time offline.
-func (c *ChannelGraph) FilterChannelRange(startHeight, endHeight uint32) ([]uint64, error) {
-	var chanIDs []uint64
+// mined in a block height within the passed range. The channel IDs are grouped
+// by their common block height. This method can be used to quickly share with a
+// peer the set of channels we know of within a particular range to catch them
+// up after a period of time offline.
+func (c *ChannelGraph) FilterChannelRange(startHeight,
+	endHeight uint32) ([]BlockChannelRange, error) {
 
 	startChanID := &lnwire.ShortChannelID{
 		BlockHeight: startHeight,
@@ -1728,6 +1742,7 @@ func (c *ChannelGraph) FilterChannelRange(startHeight, endHeight uint32) ([]uint
 	byteOrder.PutUint64(chanIDStart[:], startChanID.ToUint64())
 	byteOrder.PutUint64(chanIDEnd[:], endChanID.ToUint64())
 
+	var channelsPerBlock map[uint32][]lnwire.ShortChannelID
 	err := kvdb.View(c.db, func(tx kvdb.RTx) error {
 		edges := tx.ReadBucket(edgeBucket)
 		if edges == nil {
@@ -1742,33 +1757,51 @@ func (c *ChannelGraph) FilterChannelRange(startHeight, endHeight uint32) ([]uint
 
 		// We'll now iterate through the database, and find each
 		// channel ID that resides within the specified range.
-		var cid uint64
 		for k, _ := cursor.Seek(chanIDStart[:]); k != nil &&
 			bytes.Compare(k, chanIDEnd[:]) <= 0; k, _ = cursor.Next() {
 
 			// This channel ID rests within the target range, so
-			// we'll convert it into an integer and add it to our
-			// returned set.
-			cid = byteOrder.Uint64(k)
-			chanIDs = append(chanIDs, cid)
+			// we'll add it to our returned set.
+			rawCid := byteOrder.Uint64(k)
+			cid := lnwire.NewShortChanIDFromInt(rawCid)
+			channelsPerBlock[cid.BlockHeight] = append(
+				channelsPerBlock[cid.BlockHeight], cid,
+			)
 		}
 
 		return nil
 	}, func() {
-		chanIDs = nil
+		channelsPerBlock = make(map[uint32][]lnwire.ShortChannelID)
 	})
 
 	switch {
 	// If we don't know of any channels yet, then there's nothing to
 	// filter, so we'll return an empty slice.
-	case err == ErrGraphNoEdgesFound:
-		return chanIDs, nil
+	case err == ErrGraphNoEdgesFound || len(channelsPerBlock) == 0:
+		return nil, nil
 
 	case err != nil:
 		return nil, err
 	}
 
-	return chanIDs, nil
+	// Return the channel ranges in ascending block height order.
+	blocks := make([]uint32, 0, len(channelsPerBlock))
+	for block := range channelsPerBlock {
+		blocks = append(blocks, block)
+	}
+	sort.Slice(blocks, func(i, j int) bool {
+		return blocks[i] < blocks[j]
+	})
+
+	channelRanges := make([]BlockChannelRange, 0, len(channelsPerBlock))
+	for _, block := range blocks {
+		channelRanges = append(channelRanges, BlockChannelRange{
+			Height:   block,
+			Channels: channelsPerBlock[block],
+		})
+	}
+
+	return channelRanges, nil
 }
 
 // FetchChanInfos returns the set of channel edges that correspond to the passed
diff --git a/channeldb/graph_test.go b/channeldb/graph_test.go
index 2abdcc8e40..331d176998 100644
--- a/channeldb/graph_test.go
+++ b/channeldb/graph_test.go
@@ -1848,24 +1848,32 @@ func TestFilterChannelRange(t *testing.T) {
 		t.Fatalf("expected zero chans, instead got %v", len(resp))
 	}
 
-	// To start, we'll create a set of channels, each mined in a block 10
+	// To start, we'll create a set of channels, two mined in each block, 10
 	// blocks after the prior one.
 	startHeight := uint32(100)
 	endHeight := startHeight
 	const numChans = 10
-	chanIDs := make([]uint64, 0, numChans)
-	for i := 0; i < numChans; i++ {
+	channelRanges := make([]BlockChannelRange, 0, numChans/2)
+	for i := 0; i < numChans/2; i++ {
 		chanHeight := endHeight
-		channel, chanID := createEdge(
-			uint32(chanHeight), uint32(i+1), 0, 0, node1, node2,
+		channel1, chanID1 := createEdge(
+			chanHeight, uint32(i+1), 0, 0, node1, node2,
 		)
-
-		if err := graph.AddChannelEdge(&channel); err != nil {
+		if err := graph.AddChannelEdge(&channel1); err != nil {
 			t.Fatalf("unable to create channel edge: %v", err)
 		}
 
-		chanIDs = append(chanIDs, chanID.ToUint64())
+		channel2, chanID2 := createEdge(
+			chanHeight, uint32(i+2), 0, 0, node1, node2,
+		)
+		if err := graph.AddChannelEdge(&channel2); err != nil {
+			t.Fatalf("unable to create channel edge: %v", err)
+		}
+		channelRanges = append(channelRanges, BlockChannelRange{
+			Height:   chanHeight,
+			Channels: []lnwire.ShortChannelID{chanID1, chanID2},
+		})
 
 		endHeight += 10
 	}
@@ -1876,7 +1884,7 @@
 		startHeight uint32
 		endHeight   uint32
 
-		resp []uint64
+		resp []BlockChannelRange
 	}{
 		// If we query for the entire range, then we should get the same
 		// set of short channel IDs back.
 		{
 			startHeight: startHeight,
 			endHeight:   endHeight,
 
-			resp: chanIDs,
+			resp: channelRanges,
 		},
 
 		// If we query for a range of channels right before our range, we
@@ -1900,7 +1908,7 @@
 			startHeight: endHeight - 10,
 			endHeight:   endHeight - 10,
 
-			resp: chanIDs[9:],
+			resp: channelRanges[4:],
 		},
 
 		// If we query for just the first height, we should only get a
@@ -1909,7 +1917,14 @@
 			startHeight: startHeight,
 			endHeight:   startHeight,
 
-			resp: chanIDs[:1],
+			resp: channelRanges[:1],
+		},
+
+		{
+			startHeight: startHeight + 10,
+			endHeight:   endHeight - 10,
+
+			resp: channelRanges[1:5],
 		},
 	}
 	for i, queryCase := range queryCases {
diff --git a/discovery/chan_series.go b/discovery/chan_series.go
index ffb59b4ef5..42ebe88892 100644
--- a/discovery/chan_series.go
+++ b/discovery/chan_series.go
@@ -39,10 +39,11 @@ type ChannelGraphTimeSeries interface {
 	superSet []lnwire.ShortChannelID) ([]lnwire.ShortChannelID, error)
 
 	// FilterChannelRange returns the set of channels that we created
-	// between the start height and the end height. We'll use this to to a
-	// remote peer's QueryChannelRange message.
+	// between the start height and the end height. The channel IDs are
+	// grouped by their common block height. We'll use this to respond to a
+	// remote peer's QueryChannelRange message.
 	FilterChannelRange(chain chainhash.Hash,
-		startHeight, endHeight uint32) ([]lnwire.ShortChannelID, error)
+		startHeight, endHeight uint32) ([]channeldb.BlockChannelRange, error)
 
 	// FetchChanAnns returns a full set of channel announcements as well as
 	// their updates that match the set of specified short channel ID's.
@@ -203,26 +204,15 @@ func (c *ChanSeries) FilterKnownChanIDs(chain chainhash.Hash,
 }
 
 // FilterChannelRange returns the set of channels that we created between the
-// start height and the end height. We'll use this respond to a remote peer's
-// QueryChannelRange message.
+// start height and the end height. The channel IDs are grouped by their common
+// block height. We'll use this to respond to a remote peer's QueryChannelRange
+// message.
 //
 // NOTE: This is part of the ChannelGraphTimeSeries interface.
 func (c *ChanSeries) FilterChannelRange(chain chainhash.Hash,
-	startHeight, endHeight uint32) ([]lnwire.ShortChannelID, error) {
+	startHeight, endHeight uint32) ([]channeldb.BlockChannelRange, error) {
 
-	chansInRange, err := c.graph.FilterChannelRange(startHeight, endHeight)
-	if err != nil {
-		return nil, err
-	}
-
-	chanResp := make([]lnwire.ShortChannelID, 0, len(chansInRange))
-	for _, chanID := range chansInRange {
-		chanResp = append(
-			chanResp, lnwire.NewShortChanIDFromInt(chanID),
-		)
-	}
-
-	return chanResp, nil
+	return c.graph.FilterChannelRange(startHeight, endHeight)
 }
 
 // FetchChanAnns returns a full set of channel announcements as well as their
diff --git a/discovery/gossiper.go b/discovery/gossiper.go
index 845bce469d..b29c23e813 100644
--- a/discovery/gossiper.go
+++ b/discovery/gossiper.go
@@ -265,16 +265,6 @@ type AuthenticatedGossiper struct {
 	// every new block height.
 	blockEpochs *chainntnfs.BlockEpochEvent
 
-	// prematureAnnouncements maps a block height to a set of network
-	// messages which are "premature" from our PoV. A message is premature
-	// if it claims to be anchored in a block which is beyond the current
-	// main chain tip as we know it. Premature network messages will be
-	// processed once the chain tip as we know it extends to/past the
-	// premature height.
-	//
-	// TODO(roasbeef): limit premature networkMsgs to N
-	prematureAnnouncements map[uint32][]*networkMsg
-
 	// prematureChannelUpdates is a map of ChannelUpdates we have received
 	// that wasn't associated with any channel we know about. We store
 	// them temporarily, such that we can reprocess them when a
@@ -338,21 +328,22 @@ func New(cfg Config, selfKey *btcec.PublicKey) *AuthenticatedGossiper {
 		networkMsgs:             make(chan *networkMsg),
 		quit:                    make(chan struct{}),
 		chanPolicyUpdates:       make(chan *chanPolicyUpdateRequest),
-		prematureAnnouncements:  make(map[uint32][]*networkMsg),
 		prematureChannelUpdates: make(map[uint64][]*networkMsg),
 		channelMtx:              multimutex.NewMutex(),
 		recentRejects:           make(map[uint64]struct{}),
 		heightForLastChanUpdate: make(map[uint64][2]uint32),
-		syncMgr: newSyncManager(&SyncManagerCfg{
-			ChainHash:               cfg.ChainHash,
-			ChanSeries:              cfg.ChanSeries,
-			RotateTicker:            cfg.RotateTicker,
-			HistoricalSyncTicker:    cfg.HistoricalSyncTicker,
-			NumActiveSyncers:        cfg.NumActiveSyncers,
-			IgnoreHistoricalFilters: cfg.IgnoreHistoricalFilters,
-		}),
 	}
 
+	gossiper.syncMgr = newSyncManager(&SyncManagerCfg{
+		ChainHash:               cfg.ChainHash,
+		ChanSeries:              cfg.ChanSeries,
+		RotateTicker:            cfg.RotateTicker,
+		HistoricalSyncTicker:    cfg.HistoricalSyncTicker,
+		NumActiveSyncers:        cfg.NumActiveSyncers,
+		IgnoreHistoricalFilters: cfg.IgnoreHistoricalFilters,
+		BestHeight:              gossiper.latestHeight,
+	})
+
 	gossiper.reliableSender = newReliableSender(&reliableSenderCfg{
 		NotifyWhenOnline:  cfg.NotifyWhenOnline,
 		NotifyWhenOffline: cfg.NotifyWhenOffline,
@@ -1045,33 +1036,11 @@ func (d *AuthenticatedGossiper) networkHandler() {
 			d.Lock()
 			blockHeight := uint32(newBlock.Height)
 			d.bestHeight = blockHeight
+			d.Unlock()
 
 			log.Debugf("New block: height=%d, hash=%s", blockHeight,
 				newBlock.Hash)
 
-			// Next we check if we have any premature announcements
-			// for this height, if so, then we process them once
-			// more as normal announcements.
-			premature := d.prematureAnnouncements[blockHeight]
-			if len(premature) == 0 {
-				d.Unlock()
-				continue
-			}
-			delete(d.prematureAnnouncements, blockHeight)
-			d.Unlock()
-
-			log.Infof("Re-processing %v premature announcements "+
-				"for height %v", len(premature), blockHeight)
-
-			for _, ann := range premature {
-				emittedAnnouncements := d.processNetworkAnnouncement(ann)
-				if emittedAnnouncements != nil {
-					announcements.AddMsgs(
-						emittedAnnouncements...,
-					)
-				}
-			}
-
 		// The trickle timer has ticked, which indicates we should
 		// flush to the network the pending batch of new announcements
 		// we've received since the last trickle tick.
@@ -1501,7 +1470,6 @@ func (d *AuthenticatedGossiper) addNode(msg *lnwire.NodeAnnouncement) error {
 func (d *AuthenticatedGossiper) processNetworkAnnouncement(
 	nMsg *networkMsg) []networkMsg {
 
-	// isPremature *MUST* be called with the gossiper's lock held.
 	isPremature := func(chanID lnwire.ShortChannelID, delta uint32) bool {
 		// TODO(roasbeef) make height delta 6
 		//  * or configurable
@@ -1593,18 +1561,12 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(
 		// to be fully verified once we advance forward in the chain.
 		d.Lock()
 		if nMsg.isRemote && isPremature(msg.ShortChannelID, 0) {
-			blockHeight := msg.ShortChannelID.BlockHeight
 			log.Infof("Announcement for chan_id=(%v), is "+
 				"premature: advertises height %v, only "+
 				"height %v is known",
 				msg.ShortChannelID.ToUint64(),
 				msg.ShortChannelID.BlockHeight,
 				d.bestHeight)
-
-			d.prematureAnnouncements[blockHeight] = append(
-				d.prematureAnnouncements[blockHeight],
-				nMsg,
-			)
 			d.Unlock()
 			return nil
 		}
@@ -1824,11 +1786,6 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(
 				"height %v, only height %v is known",
 				shortChanID, blockHeight,
 				d.bestHeight)
-
-			d.prematureAnnouncements[blockHeight] = append(
-				d.prematureAnnouncements[blockHeight],
-				nMsg,
-			)
 			d.Unlock()
 			return nil
 		}
@@ -2124,10 +2081,6 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(
 		// to other clients if this constraint was changed.
 		d.Lock()
 		if isPremature(msg.ShortChannelID, d.cfg.ProofMatureDelta) {
-			d.prematureAnnouncements[needBlockHeight] = append(
-				d.prematureAnnouncements[needBlockHeight],
-				nMsg,
-			)
 			log.Infof("Premature proof announcement, "+
 				"current block height lower than needed: %v <"+
 				" %v, add announcement to reprocessing batch",
@@ -2644,3 +2597,10 @@ func IsKeepAliveUpdate(update *lnwire.ChannelUpdate,
 	}
 	return true
 }
+
+// latestHeight returns the latest chain height known to the gossiper.
+func (d *AuthenticatedGossiper) latestHeight() uint32 {
+	d.Lock()
+	defer d.Unlock()
+	return d.bestHeight
+}
diff --git a/discovery/gossiper_test.go b/discovery/gossiper_test.go
index 728cf1ec89..c3ed17b1d6 100644
--- a/discovery/gossiper_test.go
+++ b/discovery/gossiper_test.go
@@ -918,8 +918,7 @@ func TestPrematureAnnouncement(t *testing.T) {
 
 	// Pretending that we receive the valid channel update announcement from
 	// remote side, but block height of this announcement is greater than
-	// highest know to us, for that reason it should be added to the
-	// repeat/premature batch.
+	// highest known to us, so it should be rejected.
 	ua, err := createUpdateAnnouncement(1, 0, nodeKeyPriv1, timestamp)
 	if err != nil {
 		t.Fatalf("can't create update announcement: %v", err)
@@ -934,31 +933,6 @@ func TestPrematureAnnouncement(t *testing.T) {
 	if len(ctx.router.edges) != 0 {
 		t.Fatal("edge update was added to router")
 	}
-
-	// Generate new block and waiting the previously added announcements
-	// to be proceeded.
-	newBlock := &wire.MsgBlock{}
-	ctx.notifier.notifyBlock(newBlock.Header.BlockHash(), 1)
-
-	select {
-	case <-ctx.broadcastedMessage:
-	case <-time.After(2 * trickleDelay):
-		t.Fatal("announcement wasn't broadcasted")
-	}
-
-	if len(ctx.router.infos) != 1 {
-		t.Fatalf("edge wasn't added to router: %v", err)
-	}
-
-	select {
-	case <-ctx.broadcastedMessage:
-	case <-time.After(2 * trickleDelay):
-		t.Fatal("announcement wasn't broadcasted")
-	}
-
-	if len(ctx.router.edges) != 1 {
-		t.Fatalf("edge update wasn't added to router: %v", err)
-	}
 }
 
 // TestSignatureAnnouncementLocalFirst ensures that the AuthenticatedGossiper
diff --git a/discovery/sync_manager.go b/discovery/sync_manager.go
index daebcc35a8..a0e73b0663 100644
--- a/discovery/sync_manager.go
+++ b/discovery/sync_manager.go
@@ -89,6 +89,9 @@ type SyncManagerCfg struct {
 	// This prevents ranges with old start times from causing us to dump the
 	// graph on connect.
 	IgnoreHistoricalFilters bool
+
+	// BestHeight returns the latest known height of the chain.
+	BestHeight func() uint32
 }
 
 // SyncManager is a subsystem of the gossiper that manages the gossip syncers
@@ -419,7 +422,11 @@ func (m *SyncManager) createGossipSyncer(peer lnpeer.Peer) *GossipSyncer {
 		sendToPeerSync: func(msgs ...lnwire.Message) error {
 			return peer.SendMessageLazy(true, msgs...)
 		},
-		ignoreHistoricalFilters: m.cfg.IgnoreHistoricalFilters,
+		ignoreHistoricalFilters:   m.cfg.IgnoreHistoricalFilters,
+		maxUndelayedQueryReplies:  DefaultMaxUndelayedQueryReplies,
+		delayedQueryReplyInterval: DefaultDelayedQueryReplyInterval,
+		bestHeight:                m.cfg.BestHeight,
+		maxQueryChanRangeReplies:  maxQueryChanRangeReplies,
 	})
 
 	// Gossip syncers are initialized by default in a PassiveSync type
diff --git a/discovery/sync_manager_test.go b/discovery/sync_manager_test.go
index c7a228f8cf..ac721868d8 100644
--- a/discovery/sync_manager_test.go
+++ b/discovery/sync_manager_test.go
@@ -2,7 +2,6 @@ package discovery
 
 import (
 	"fmt"
-	"math"
 	"reflect"
 	"sync/atomic"
 	"testing"
@@ -12,6 +11,7 @@ import (
 	"github.com/lightningnetwork/lnd/lntest/wait"
 	"github.com/lightningnetwork/lnd/lnwire"
 	"github.com/lightningnetwork/lnd/ticker"
+	"github.com/stretchr/testify/require"
 )
 
 // randPeer creates a random peer.
@@ -34,6 +34,9 @@ func newTestSyncManager(numActiveSyncers int) *SyncManager {
 		RotateTicker:         ticker.NewForce(DefaultSyncerRotationInterval),
 		HistoricalSyncTicker: ticker.NewForce(DefaultHistoricalSyncInterval),
 		NumActiveSyncers:     numActiveSyncers,
+		BestHeight: func() uint32 {
+			return latestKnownHeight
+		},
 	})
 }
 
@@ -202,7 +205,7 @@ func TestSyncManagerInitialHistoricalSync(t *testing.T) {
 	syncMgr.InitSyncState(peer)
 	assertMsgSent(t, peer, &lnwire.QueryChannelRange{
 		FirstBlockHeight: 0,
-		NumBlocks:        math.MaxUint32,
+		NumBlocks:        latestKnownHeight,
 	})
 
 	// The graph should not be considered as synced since the initial
@@ -290,7 +293,7 @@ func TestSyncManagerForceHistoricalSync(t *testing.T) {
 	syncMgr.InitSyncState(peer)
 	assertMsgSent(t, peer, &lnwire.QueryChannelRange{
 		FirstBlockHeight: 0,
-		NumBlocks:        math.MaxUint32,
+		NumBlocks:        latestKnownHeight,
 	})
 
 	// If an additional peer connects, then a historical sync should not be
@@ -305,7 +308,7 @@ func TestSyncManagerForceHistoricalSync(t *testing.T) {
 	syncMgr.cfg.HistoricalSyncTicker.(*ticker.Force).Force <- time.Time{}
 	assertMsgSent(t, extraPeer, &lnwire.QueryChannelRange{
 		FirstBlockHeight: 0,
-		NumBlocks:        math.MaxUint32,
+		NumBlocks:        latestKnownHeight,
 	})
 }
 
@@ -326,7 +329,7 @@ func TestSyncManagerGraphSyncedAfterHistoricalSyncReplacement(t *testing.T) {
 	syncMgr.InitSyncState(peer)
 	assertMsgSent(t, peer, &lnwire.QueryChannelRange{
 		FirstBlockHeight: 0,
-		NumBlocks:        math.MaxUint32,
+		NumBlocks:        latestKnownHeight,
 	})
 
 	// The graph should not be considered as synced since the initial
@@ -531,14 +534,18 @@ func assertTransitionToChansSynced(t *testing.T, s *GossipSyncer, peer *mockPeer
 
 	query := &lnwire.QueryChannelRange{
 		FirstBlockHeight: 0,
-		NumBlocks:        math.MaxUint32,
+		NumBlocks:        latestKnownHeight,
 	}
 	assertMsgSent(t, peer, query)
 
-	s.ProcessQueryMsg(&lnwire.ReplyChannelRange{
+	require.Eventually(t, func() bool {
+		return s.syncState() == waitingQueryRangeReply
+	}, time.Second, 500*time.Millisecond)
+
+	require.NoError(t, s.ProcessQueryMsg(&lnwire.ReplyChannelRange{
 		QueryChannelRange: *query,
 		Complete:          1,
-	}, nil)
+	}, nil))
 
 	chanSeries := s.cfg.channelSeries.(*mockChannelGraphTimeSeries)
 
diff --git a/discovery/syncer.go b/discovery/syncer.go
index 10b6d4205d..04a722f221 100644
--- a/discovery/syncer.go
+++ b/discovery/syncer.go
@@ -4,6 +4,8 @@ import (
 	"errors"
"fmt" "math" + "math/rand" + "sort" "sync" "sync/atomic" "time" @@ -128,6 +130,14 @@ const ( // maxUndelayedQueryReplies queries. DefaultDelayedQueryReplyInterval = 5 * time.Second + // maxQueryChanRangeReplies specifies the default limit of replies to + // process for a single QueryChannelRange request. + maxQueryChanRangeReplies = 500 + + // maxQueryChanRangeRepliesZlibFactor specifies the factor applied to + // the maximum number of replies allowed for zlib encoded replies. + maxQueryChanRangeRepliesZlibFactor = 4 + // chanRangeQueryBuffer is the number of blocks back that we'll go when // asking the remote peer for their any channels they know of beyond // our highest known channel ID. @@ -237,6 +247,13 @@ type gossipSyncerCfg struct { // This prevents ranges with old start times from causing us to dump the // graph on connect. ignoreHistoricalFilters bool + + // bestHeight returns the latest height known of the chain. + bestHeight func() uint32 + + // maxQueryChanRangeReplies is the maximum number of replies we'll allow + // for a single QueryChannelRange request. + maxQueryChanRangeReplies uint32 } // GossipSyncer is a struct that handles synchronizing the channel graph state @@ -313,6 +330,11 @@ type GossipSyncer struct { // buffer all the chunked response to our query. bufferedChanRangeReplies []lnwire.ShortChannelID + // numChanRangeRepliesRcvd is used to track the number of replies + // received as part of a QueryChannelRange. This field is primarily used + // within the waitingQueryChanReply state. + numChanRangeRepliesRcvd uint32 + // newChansToQuery is used to pass the set of channels we should query // for from the waitingQueryChanReply state to the queryNewChannels // state. @@ -738,17 +760,27 @@ func (g *GossipSyncer) processChanRangeReply(msg *lnwire.ReplyChannelRange) erro g.bufferedChanRangeReplies = append( g.bufferedChanRangeReplies, msg.ShortChanIDs..., ) + switch g.cfg.encodingType { + case lnwire.EncodingSortedPlain: + g.numChanRangeRepliesRcvd++ + case lnwire.EncodingSortedZlib: + g.numChanRangeRepliesRcvd += maxQueryChanRangeRepliesZlibFactor + default: + return fmt.Errorf("unhandled encoding type %v", g.cfg.encodingType) + } log.Infof("GossipSyncer(%x): buffering chan range reply of size=%v", g.cfg.peerPub[:], len(msg.ShortChanIDs)) - // If this isn't the last response, then we can exit as we've already - // buffered the latest portion of the streaming reply. + // If this isn't the last response and we can continue to receive more, + // then we can exit as we've already buffered the latest portion of the + // streaming reply. + maxReplies := g.cfg.maxQueryChanRangeReplies switch { // If we're communicating with a legacy node, we'll need to look at the // complete field. case isLegacyReplyChannelRange(g.curQueryRangeMsg, msg): - if msg.Complete == 0 { + if msg.Complete == 0 && g.numChanRangeRepliesRcvd < maxReplies { return nil } @@ -760,7 +792,8 @@ func (g *GossipSyncer) processChanRangeReply(msg *lnwire.ReplyChannelRange) erro // TODO(wilmer): This might require some padding if the remote // node is not aware of the last height we sent them, i.e., is // behind a few blocks from us. 
-		if replyLastHeight < queryLastHeight {
+		if replyLastHeight < queryLastHeight &&
+			g.numChanRangeRepliesRcvd < maxReplies {
 			return nil
 		}
 	}
@@ -783,6 +816,7 @@ func (g *GossipSyncer) processChanRangeReply(msg *lnwire.ReplyChannelRange) erro
 	g.curQueryRangeMsg = nil
 	g.prevReplyChannelRange = nil
 	g.bufferedChanRangeReplies = nil
+	g.numChanRangeRepliesRcvd = 0
 
 	// If there aren't any channels that we don't know of, then we can
 	// switch straight to our terminal state.
@@ -834,9 +868,17 @@ func (g *GossipSyncer) genChanRangeQuery(
 		startHeight = uint32(newestChan.BlockHeight - chanRangeQueryBuffer)
 	}
 
+	// Determine the number of blocks to request based on our best height.
+	// We'll guard against a potential underflow of the unsigned
+	// subtraction and explicitly clamp numBlocks to its minimum value of
+	// 1.
+	bestHeight := g.cfg.bestHeight()
+	numBlocks := uint32(1)
+	if bestHeight > startHeight {
+		numBlocks = bestHeight - startHeight
+	}
+
 	log.Infof("GossipSyncer(%x): requesting new chans from height=%v "+
-		"and %v blocks after", g.cfg.peerPub[:], startHeight,
-		math.MaxUint32-startHeight)
+		"and %v blocks after", g.cfg.peerPub[:], startHeight, numBlocks)
 
 	// Finally, we'll craft the channel range query, using our starting
 	// height, then asking for all known channels to the foreseeable end of
@@ -844,7 +886,7 @@ func (g *GossipSyncer) genChanRangeQuery(
 	query := &lnwire.QueryChannelRange{
 		ChainHash:        g.cfg.chainHash,
 		FirstBlockHeight: startHeight,
-		NumBlocks:        math.MaxUint32 - startHeight,
+		NumBlocks:        numBlocks,
 	}
 	g.curQueryRangeMsg = query
 
@@ -919,7 +961,7 @@ func (g *GossipSyncer) replyChanRangeQuery(query *lnwire.QueryChannelRange) erro
 	// channel ID's that match their query.
 	startBlock := query.FirstBlockHeight
 	endBlock := query.LastBlockHeight()
-	channelRange, err := g.cfg.channelSeries.FilterChannelRange(
+	channelRanges, err := g.cfg.channelSeries.FilterChannelRange(
 		query.ChainHash, startBlock, endBlock,
 	)
 	if err != nil {
@@ -929,102 +971,98 @@ func (g *GossipSyncer) replyChanRangeQuery(query *lnwire.QueryChannelRange) erro
 
 	// TODO(roasbeef): means can't send max uint above?
 	//  * or make internal 64
 
-	// In the base case (no actual response) the first block and last block
-	// will match those of the query. In the loop below, we'll update these
-	// two variables incrementally with each chunk to properly compute the
-	// starting block for each response and the number of blocks in a
-	// response.
-	firstBlockHeight := startBlock
-	lastBlockHeight := endBlock
-
-	numChannels := int32(len(channelRange))
-	numChansSent := int32(0)
-	for {
-		// We'll send our this response in a streaming manner,
-		// chunk-by-chunk. We do this as there's a transport message
-		// size limit which we'll need to adhere to.
-		var channelChunk []lnwire.ShortChannelID
-
-		// We know this is the final chunk, if the difference between
-		// the total number of channels, and the number of channels
-		// we've sent is less-than-or-equal to the chunk size.
-		isFinalChunk := (numChannels - numChansSent) <= g.cfg.chunkSize
-
-		// If this is indeed the last chunk, then we'll send the
-		// remainder of the channels.
-		if isFinalChunk {
-			channelChunk = channelRange[numChansSent:]
-
-			log.Infof("GossipSyncer(%x): sending final chan "+
-				"range chunk, size=%v", g.cfg.peerPub[:],
-				len(channelChunk))
-		} else {
-			// Otherwise, we'll only send off a fragment exactly
-			// sized to the proper chunk size.
-			channelChunk = channelRange[numChansSent : numChansSent+g.cfg.chunkSize]
-
-			log.Infof("GossipSyncer(%x): sending range chunk of "+
-				"size=%v", g.cfg.peerPub[:], len(channelChunk))
-		}
-
-		// If we have any channels at all to return, then we need to
-		// update our pointers to the first and last blocks for each
-		// response.
-		if len(channelChunk) > 0 {
-			// If this is the first response we'll send, we'll point
-			// the first block to the first block in the query.
-			// Otherwise, we'll continue from the block we left off
-			// at.
-			if numChansSent == 0 {
-				firstBlockHeight = startBlock
-			} else {
-				firstBlockHeight = lastBlockHeight
-			}
-
-			// If this is the last response we'll send, we'll point
-			// the last block to the last block of the query.
-			// Otherwise, we'll set it to the height of the last
-			// channel in the chunk.
-			if isFinalChunk {
-				lastBlockHeight = endBlock
-			} else {
-				lastBlockHeight = channelChunk[len(channelChunk)-1].BlockHeight
-			}
-		}
+	// We'll send our response in a streaming manner, chunk-by-chunk. We do
+	// this as there's a transport message size limit which we'll need to
+	// adhere to. We also need to make sure all of our replies cover the
+	// expected range of the query.
+	sendReplyForChunk := func(channelChunk []lnwire.ShortChannelID,
+		firstHeight, lastHeight uint32, finalChunk bool) error {
 
-		// The number of blocks contained in this response (the total
-		// span) is the difference between the last channel ID and the
-		// first in the range. We add one as even if all channels
+		// The number of blocks contained in the current chunk (the
+		// total span) is the difference between the last and first
+		// block heights in the range. We add one as even if all
+		// channels
 		// returned are in the same block, we need to count that.
-		numBlocksInResp := lastBlockHeight - firstBlockHeight + 1
+		numBlocks := lastHeight - firstHeight + 1
+		complete := uint8(0)
+		if finalChunk {
+			complete = 1
+		}
 
-		// With our chunk assembled, we'll now send to the remote peer
-		// the current chunk.
-		replyChunk := lnwire.ReplyChannelRange{
+		return g.cfg.sendToPeerSync(&lnwire.ReplyChannelRange{
 			QueryChannelRange: lnwire.QueryChannelRange{
 				ChainHash:        query.ChainHash,
-				NumBlocks:        numBlocksInResp,
-				FirstBlockHeight: firstBlockHeight,
+				NumBlocks:        numBlocks,
+				FirstBlockHeight: firstHeight,
 			},
-			Complete:     0,
+			Complete:     complete,
 			EncodingType: g.cfg.encodingType,
 			ShortChanIDs: channelChunk,
+		})
+	}
+
+	var (
+		firstHeight  = query.FirstBlockHeight
+		lastHeight   uint32
+		channelChunk []lnwire.ShortChannelID
+	)
+	for _, channelRange := range channelRanges {
+		channels := channelRange.Channels
+		numChannels := int32(len(channels))
+		numLeftToAdd := g.cfg.chunkSize - int32(len(channelChunk))
+
+		// Include the current block in the ongoing chunk if it can fit
+		// and move on to the next block.
+		if numChannels <= numLeftToAdd {
+			channelChunk = append(channelChunk, channels...)
+			continue
 		}
-		if isFinalChunk {
-			replyChunk.Complete = 1
-		}
-		if err := g.cfg.sendToPeerSync(&replyChunk); err != nil {
+
+		// Otherwise, we need to send our existing channel chunk as its
+		// own reply and start a new one for the current block. We'll
+		// mark the end of our current chunk as the height before the
+		// current block to ensure the whole query range is replied to.
+		log.Infof("GossipSyncer(%x): sending range chunk of size=%v",
+			g.cfg.peerPub[:], len(channelChunk))
+		lastHeight = channelRange.Height - 1
+		err := sendReplyForChunk(
+			channelChunk, firstHeight, lastHeight, false,
+		)
+		if err != nil {
 			return err
 		}
 
-		// If this was the final chunk, then we'll exit now as our
-		// response is now complete.
-		if isFinalChunk {
-			return nil
+		// With the reply constructed, we'll start tallying channels
+		// for our next one, keeping our chunk size in mind. This may
+		// result in channels for this block being left out of the
+		// reply, but this isn't an issue since we'll randomly shuffle
+		// them and we assume a historical gossip sync is performed at
+		// a later time.
+		firstHeight = channelRange.Height
+		chunkSize := numChannels
+		exceedsChunkSize := numChannels > g.cfg.chunkSize
+		if exceedsChunkSize {
+			rand.Shuffle(len(channels), func(i, j int) {
+				channels[i], channels[j] = channels[j], channels[i]
+			})
+			chunkSize = g.cfg.chunkSize
+		}
+		channelChunk = channels[:chunkSize]
+
+		// Sort the chunk once again if we had to shuffle it.
+		if exceedsChunkSize {
+			sort.Slice(channelChunk, func(i, j int) bool {
+				return channelChunk[i].ToUint64() <
+					channelChunk[j].ToUint64()
+			})
 		}
-
-		numChansSent += int32(len(channelChunk))
 	}
+
+	// Send the remaining chunk as the final reply.
+	log.Infof("GossipSyncer(%x): sending final chan range chunk, size=%v",
+		g.cfg.peerPub[:], len(channelChunk))
+	return sendReplyForChunk(
+		channelChunk, firstHeight, query.LastBlockHeight(), true,
+	)
 }
 
 // replyShortChanIDs will be dispatched in response to a query by the remote
@@ -1274,11 +1312,23 @@ func (g *GossipSyncer) FilterGossipMsgs(msgs ...msgWithSenders) {
 
 // ProcessQueryMsg is used by outside callers to pass new channel time series
 // queries to the internal processing goroutine.
-func (g *GossipSyncer) ProcessQueryMsg(msg lnwire.Message, peerQuit <-chan struct{}) {
+func (g *GossipSyncer) ProcessQueryMsg(msg lnwire.Message, peerQuit <-chan struct{}) error {
 	var msgChan chan lnwire.Message
 	switch msg.(type) {
 	case *lnwire.QueryChannelRange, *lnwire.QueryShortChanIDs:
 		msgChan = g.queryMsgs
+
+	// Reply messages should only be expected in states where we're waiting
+	// for a reply.
+	case *lnwire.ReplyChannelRange, *lnwire.ReplyShortChanIDsEnd:
+		syncState := g.syncState()
+		if syncState != waitingQueryRangeReply &&
+			syncState != waitingQueryChanReply {
+			return fmt.Errorf("received unexpected query reply "+
+				"message %T", msg)
+		}
+		msgChan = g.gossipMsgs
+
 	default:
 		msgChan = g.gossipMsgs
 	}
@@ -1288,6 +1338,8 @@ func (g *GossipSyncer) ProcessQueryMsg(msg lnwire.Message, peerQuit <-chan struc
 	case <-peerQuit:
 	case <-g.quit:
 	}
+
+	return nil
 }
 
 // setSyncState sets the gossip syncer's state to the given state.
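A note on the reply-splitting rule above, since it is the heart of the change to replyChanRangeQuery: replies are now aligned to block boundaries rather than cut at a raw channel count. The following is a minimal, self-contained Go sketch of that rule; blockRange, reply, and chunkReplies are hypothetical stand-ins for channeldb.BlockChannelRange, lnwire.ReplyChannelRange, and the loop in the patch, and the sketch deliberately omits the shuffle-and-truncate handling the real code applies when a single block alone exceeds the chunk size.

package main

import "fmt"

// blockRange is a hypothetical stand-in for channeldb.BlockChannelRange:
// all channels known to us that were mined in one particular block.
type blockRange struct {
	height   uint32
	channels []uint64
}

// reply captures the fields of lnwire.ReplyChannelRange that matter here:
// the span of heights covered, the channels included, and the final marker.
type reply struct {
	firstHeight uint32
	numBlocks   uint32
	channels    []uint64
	complete    bool
}

// chunkReplies splits per-block ranges into replies of at most chunkSize
// channels each, never splitting one block across two replies, and keeps the
// height spans of consecutive replies contiguous so that together they cover
// the entire queried range [firstHeight, lastHeight].
func chunkReplies(firstHeight, lastHeight uint32, ranges []blockRange,
	chunkSize int) []reply {

	var (
		replies []reply
		chunk   []uint64
		start   = firstHeight
	)
	for _, r := range ranges {
		// If the whole block fits into the ongoing chunk, buffer it
		// and move on to the next block.
		if len(chunk)+len(r.channels) <= chunkSize {
			chunk = append(chunk, r.channels...)
			continue
		}

		// Otherwise flush the ongoing chunk as a reply that covers
		// every height up to, but not including, the current block.
		replies = append(replies, reply{
			firstHeight: start,
			numBlocks:   r.height - start,
			channels:    chunk,
		})

		// Start a new chunk at the current block. The real code would
		// also cap an oversized block at chunkSize here, shuffling
		// before truncating so different channels are dropped each
		// time.
		start = r.height
		chunk = append([]uint64(nil), r.channels...)
	}

	// The final reply extends to the end of the queried range and carries
	// the complete marker.
	replies = append(replies, reply{
		firstHeight: start,
		numBlocks:   lastHeight - start + 1,
		channels:    chunk,
		complete:    true,
	})
	return replies
}

func main() {
	ranges := []blockRange{
		{height: 100, channels: []uint64{1, 2}},
		{height: 110, channels: []uint64{3, 4}},
		{height: 120, channels: []uint64{5, 6}},
	}
	for _, r := range chunkReplies(100, 130, ranges, 4) {
		fmt.Printf("start=%d span=%d chans=%v complete=%v\n",
			r.firstHeight, r.numBlocks, r.channels, r.complete)
	}
}

Note how the non-final reply ends at the height just before the block that overflowed the chunk: consecutive replies meet exactly at a block boundary, which is what lets the remote peer check that the stream of replies jointly covers the full queried span.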
diff --git a/discovery/syncer_test.go b/discovery/syncer_test.go
index 8e99fa49ef..d9da938232 100644
--- a/discovery/syncer_test.go
+++ b/discovery/syncer_test.go
@@ -5,6 +5,7 @@ import (
 	"fmt"
 	"math"
 	"reflect"
+	"sort"
 	"sync"
 	"testing"
 	"time"
@@ -12,7 +13,9 @@ import (
 	"github.com/btcsuite/btcd/chaincfg"
 	"github.com/btcsuite/btcd/chaincfg/chainhash"
 	"github.com/davecgh/go-spew/spew"
+	"github.com/lightningnetwork/lnd/channeldb"
 	"github.com/lightningnetwork/lnd/lnwire"
+	"github.com/stretchr/testify/require"
 )
 
 const (
@@ -95,11 +98,36 @@ func (m *mockChannelGraphTimeSeries) FilterKnownChanIDs(chain chainhash.Hash,
 
 	return <-m.filterResp, nil
 }
 func (m *mockChannelGraphTimeSeries) FilterChannelRange(chain chainhash.Hash,
-	startHeight, endHeight uint32) ([]lnwire.ShortChannelID, error) {
+	startHeight, endHeight uint32) ([]channeldb.BlockChannelRange, error) {
 
 	m.filterRangeReqs <- filterRangeReq{startHeight, endHeight}
+	reply := <-m.filterRangeResp
 
-	return <-m.filterRangeResp, nil
+	channelsPerBlock := make(map[uint32][]lnwire.ShortChannelID)
+	for _, cid := range reply {
+		channelsPerBlock[cid.BlockHeight] = append(
+			channelsPerBlock[cid.BlockHeight], cid,
+		)
+	}
+
+	// Return the channel ranges in ascending block height order.
+	blocks := make([]uint32, 0, len(channelsPerBlock))
+	for block := range channelsPerBlock {
+		blocks = append(blocks, block)
+	}
+	sort.Slice(blocks, func(i, j int) bool {
+		return blocks[i] < blocks[j]
+	})
+
+	channelRanges := make([]channeldb.BlockChannelRange, 0, len(channelsPerBlock))
+	for _, block := range blocks {
+		channelRanges = append(channelRanges, channeldb.BlockChannelRange{
+			Height:   block,
+			Channels: channelsPerBlock[block],
+		})
+	}
+
+	return channelRanges, nil
 }
 func (m *mockChannelGraphTimeSeries) FetchChanAnns(chain chainhash.Hash,
 	shortChanIDs []lnwire.ShortChannelID) ([]lnwire.Message, error) {
@@ -158,6 +186,10 @@ func newTestSyncer(hID lnwire.ShortChannelID,
 			return nil
 		},
 		delayedQueryReplyInterval: 2 * time.Second,
+		bestHeight: func() uint32 {
+			return latestKnownHeight
+		},
+		maxQueryChanRangeReplies: maxQueryChanRangeReplies,
 	}
 	syncer := newGossipSyncer(cfg)
 
@@ -825,6 +857,7 @@ func TestGossipSyncerReplyChanRangeQuery(t *testing.T) {
 	// reply. We should get three sets of messages as two of them should be
 	// full, while the other is the final fragment.
 	const numExpectedChunks = 3
+	var prevResp *lnwire.ReplyChannelRange
 	respMsgs := make([]lnwire.ShortChannelID, 0, 5)
 	for i := 0; i < numExpectedChunks; i++ {
 		select {
@@ -852,14 +885,14 @@ func TestGossipSyncerReplyChanRangeQuery(t *testing.T) {
 		// channels.
 		case i == 0:
 			expectedFirstBlockHeight = startingBlockHeight
-			expectedNumBlocks = chunkSize + 1
+			expectedNumBlocks = 4
 
 		// The last reply should range starting from the next
 		// block of our previous reply up until the ending
 		// height of the query. It should also have the Complete
 		// bit set.
 		case i == numExpectedChunks-1:
-			expectedFirstBlockHeight = respMsgs[len(respMsgs)-1].BlockHeight
+			expectedFirstBlockHeight = prevResp.LastBlockHeight() + 1
 			expectedNumBlocks = endingBlockHeight - expectedFirstBlockHeight + 1
 			expectedComplete = 1
 
 		// the next block of our previous reply up until it
 		// reaches its maximum capacity of channels.
 		default:
-			expectedFirstBlockHeight = respMsgs[len(respMsgs)-1].BlockHeight
-			expectedNumBlocks = 5
+			expectedFirstBlockHeight = prevResp.LastBlockHeight() + 1
+			expectedNumBlocks = 4
 		}
 
 		switch {
@@ -886,9 +919,10 @@ func TestGossipSyncerReplyChanRangeQuery(t *testing.T) {
 		case rangeResp.Complete != expectedComplete:
 			t.Fatalf("Complete in resp #%d incorrect: "+
 				"expected %v, got %v", i+1,
-				expectedNumBlocks, rangeResp.Complete)
+				expectedComplete, rangeResp.Complete)
 		}
 
+		prevResp = rangeResp
 		respMsgs = append(respMsgs, rangeResp.ShortChanIDs...)
 	}
 }
@@ -1134,9 +1168,9 @@ func TestGossipSyncerGenChanRangeQuery(t *testing.T) {
 			rangeQuery.FirstBlockHeight,
 			startingHeight-chanRangeQueryBuffer)
 	}
-	if rangeQuery.NumBlocks != math.MaxUint32-firstHeight {
+	if rangeQuery.NumBlocks != latestKnownHeight-firstHeight {
 		t.Fatalf("wrong num blocks: expected %v, got %v",
-			math.MaxUint32-firstHeight, rangeQuery.NumBlocks)
+			latestKnownHeight-firstHeight, rangeQuery.NumBlocks)
 	}
 
 	// Generating a historical range query should result in a start height
@@ -1149,9 +1183,9 @@ func TestGossipSyncerGenChanRangeQuery(t *testing.T) {
 		t.Fatalf("incorrect chan range query: expected %v, %v", 0,
 			rangeQuery.FirstBlockHeight)
 	}
-	if rangeQuery.NumBlocks != math.MaxUint32 {
+	if rangeQuery.NumBlocks != latestKnownHeight {
 		t.Fatalf("wrong num blocks: expected %v, got %v",
-			math.MaxUint32, rangeQuery.NumBlocks)
+			latestKnownHeight, rangeQuery.NumBlocks)
 	}
 }
 
@@ -1495,10 +1529,12 @@ func TestGossipSyncerDelayDOS(t *testing.T) {
 	// inherently disjoint.
 	var syncer2Chans []lnwire.ShortChannelID
 	for i := 0; i < numTotalChans; i++ {
-		syncer2Chans = append(syncer2Chans, lnwire.ShortChannelID{
-			BlockHeight: highestID.BlockHeight - 1,
-			TxIndex:     uint32(i),
-		})
+		syncer2Chans = append([]lnwire.ShortChannelID{
+			{
+				BlockHeight: highestID.BlockHeight - uint32(i) - 1,
+				TxIndex:     uint32(i),
+			},
+		}, syncer2Chans...)
 	}
 
 	// We'll kick off the test by asserting syncer1 sends over the
@@ -2234,7 +2270,7 @@ func TestGossipSyncerHistoricalSync(t *testing.T) {
 	// sent to the remote peer with a FirstBlockHeight of 0.
 	expectedMsg := &lnwire.QueryChannelRange{
 		FirstBlockHeight: 0,
-		NumBlocks:        math.MaxUint32,
+		NumBlocks:        latestKnownHeight,
 	}
 
 	select {
@@ -2302,3 +2338,80 @@ func TestGossipSyncerSyncedSignal(t *testing.T) {
 		t.Fatal("expected to receive chansSynced signal")
 	}
 }
+
+// TestGossipSyncerMaxChannelRangeReplies ensures that a gossip syncer
+// transitions its state after receiving the maximum possible number of replies
+// for a single QueryChannelRange message, and that any further replies after
+// said limit are not processed.
+func TestGossipSyncerMaxChannelRangeReplies(t *testing.T) {
+	t.Parallel()
+
+	msgChan, syncer, chanSeries := newTestSyncer(
+		lnwire.ShortChannelID{BlockHeight: latestKnownHeight},
+		defaultEncoding, defaultChunkSize,
+	)
+
+	// We'll tune the maxQueryChanRangeReplies to a more sensible value for
+	// the sake of testing.
+	syncer.cfg.maxQueryChanRangeReplies = 100
+
+	syncer.Start()
+	defer syncer.Stop()
+
+	// Upon initialization, the syncer should submit a QueryChannelRange
+	// request.
+	var query *lnwire.QueryChannelRange
+	select {
+	case msgs := <-msgChan:
+		require.Len(t, msgs, 1)
+		require.IsType(t, &lnwire.QueryChannelRange{}, msgs[0])
+		query = msgs[0].(*lnwire.QueryChannelRange)
+
+	case <-time.After(time.Second):
+		t.Fatal("expected query channel range request msg")
+	}
+
+	// We'll send the maximum number of replies allowed to a
+	// QueryChannelRange request with each reply consuming only one block in
+	// order to transition the syncer's state.
+	for i := uint32(0); i < syncer.cfg.maxQueryChanRangeReplies; i++ {
+		reply := &lnwire.ReplyChannelRange{
+			QueryChannelRange: *query,
+			ShortChanIDs: []lnwire.ShortChannelID{
+				{
+					BlockHeight: query.FirstBlockHeight + i,
+				},
+			},
+		}
+		reply.FirstBlockHeight = query.FirstBlockHeight + i
+		reply.NumBlocks = 1
+		require.NoError(t, syncer.ProcessQueryMsg(reply, nil))
+	}
+
+	// We should receive a filter request for the syncer's local channels
+	// after processing all of the replies. We'll send back a nil response
+	// indicating that no new channels need to be synced, so it should
+	// transition to its final chansSynced state.
+	select {
+	case <-chanSeries.filterReq:
+	case <-time.After(time.Second):
+		t.Fatal("expected local filter request of known channels")
+	}
+	select {
+	case chanSeries.filterResp <- nil:
+	case <-time.After(time.Second):
+		t.Fatal("timed out sending filter response")
+	}
+	assertSyncerStatus(t, syncer, chansSynced, ActiveSync)
+
+	// Finally, attempting to process another reply for the same query
+	// should result in an error.
+	require.Error(t, syncer.ProcessQueryMsg(&lnwire.ReplyChannelRange{
+		QueryChannelRange: *query,
+		ShortChanIDs: []lnwire.ShortChannelID{
+			{
+				BlockHeight: query.LastBlockHeight() + 1,
+			},
+		},
+	}, nil))
+}
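For reference, the reply-limit accounting that processChanRangeReply performs can be summarized in isolation. Below is a minimal sketch, assuming the constants from the diff (a budget of 500 and a zlib weight of 4); the encoding and replyTracker types are hypothetical simplifications of lnwire's encoding types and the GossipSyncer's numChanRangeRepliesRcvd field. Weighting a zlib reply more heavily keeps the budget roughly proportional to the number of channel IDs received, since a compressed reply can pack several plain replies' worth of IDs.

package main

import (
	"errors"
	"fmt"
)

// encoding is a hypothetical stand-in for lnwire's short channel ID
// encoding types.
type encoding uint8

const (
	encodingPlain encoding = iota
	encodingZlib
)

const (
	// maxReplies mirrors maxQueryChanRangeReplies from the diff.
	maxReplies = 500

	// zlibWeight mirrors maxQueryChanRangeRepliesZlibFactor: one zlib
	// reply consumes this many units of the budget because it can carry
	// roughly that many plain replies' worth of channel IDs.
	zlibWeight = 4
)

// replyTracker tallies replies against the budget, weighting zlib replies
// more heavily since each one can carry more channel IDs.
type replyTracker struct {
	received uint32
}

// record accounts for one reply and reports whether the peer may keep
// sending more before we force the query to terminate.
func (t *replyTracker) record(enc encoding) (bool, error) {
	switch enc {
	case encodingPlain:
		t.received++
	case encodingZlib:
		t.received += zlibWeight
	default:
		return false, errors.New("unhandled encoding type")
	}
	return t.received < maxReplies, nil
}

func main() {
	var t replyTracker
	for i := 0; ; i++ {
		more, err := t.record(encodingZlib)
		if err != nil || !more {
			// With zlib replies the budget is exhausted after
			// 500/4 = 125 replies.
			fmt.Printf("stopping after %d replies\n", i+1)
			return
		}
	}
}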
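Similarly, the genChanRangeQuery change above bounds NumBlocks by our best known height rather than math.MaxUint32. Here is a minimal sketch of that computation, assuming uint32 heights as in the diff; numBlocksToQuery is a hypothetical helper that guards before subtracting so the unsigned arithmetic can't wrap, then clamps the result to a minimum of one block.

package main

import "fmt"

// numBlocksToQuery returns how many blocks a QueryChannelRange should cover
// when starting at startHeight with bestHeight as our known chain tip. The
// heights are uint32, so we compare before subtracting to avoid an unsigned
// wrap-around, and clamp to a minimum of one block.
func numBlocksToQuery(bestHeight, startHeight uint32) uint32 {
	if bestHeight <= startHeight {
		return 1
	}
	return bestHeight - startHeight
}

func main() {
	fmt.Println(numBlocksToQuery(700000, 699500)) // 500
	fmt.Println(numBlocksToQuery(700000, 700000)) // clamped to 1
	fmt.Println(numBlocksToQuery(700000, 700010)) // wrap guarded: 1
}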