diff --git a/p2p/peer_tracker.go b/p2p/peer_tracker.go index 33fffc91..f4834aeb 100644 --- a/p2p/peer_tracker.go +++ b/p2p/peer_tracker.go @@ -2,6 +2,7 @@ package p2p import ( "context" + "sort" "sync" "time" @@ -15,8 +16,10 @@ import ( const ( // defaultScore specifies the score for newly connected peers. defaultScore float32 = 1 - // maxTrackerSize specifies the max amount of peers that can be added to the peerTracker. + // maxPeerTrackerSize specifies the max amount of peers that can be added to the peerTracker. maxPeerTrackerSize = 100 + // minPeerTrackerSizeBeforeGC specifies the minimum amount of tracked peers before the peerTracker starts removing peers with lower peer scores. + minPeerTrackerSizeBeforeGC = 10 ) var ( @@ -240,31 +243,53 @@ func (p *peerTracker) gc() { p.done <- struct{}{} return case <-ticker.C: - p.peerLk.Lock() - - now := time.Now() - var deletedDisconnectedNum int - for id, peer := range p.disconnectedPeers { - if peer.pruneDeadline.Before(now) { - delete(p.disconnectedPeers, id) - deletedDisconnectedNum++ - } - } + p.cleanUpDisconnectedPeers() + p.cleanUpTrackedPeers() + p.dumpPeers(p.ctx) + } + } +} - var deletedTrackedNum int - for id, peer := range p.trackedPeers { - if peer.peerScore <= defaultScore { - delete(p.trackedPeers, id) - deletedTrackedNum++ - } - } - p.peerLk.Unlock() +func (p *peerTracker) cleanUpDisconnectedPeers() { + p.peerLk.Lock() + defer p.peerLk.Unlock() - p.metrics.peersDisconnected(-deletedDisconnectedNum) - p.metrics.peersTracked(-deletedTrackedNum) - p.dumpPeers(p.ctx) + now := time.Now() + var deletedDisconnectedNum int + for id, peer := range p.disconnectedPeers { + if peer.pruneDeadline.Before(now) { + delete(p.disconnectedPeers, id) + deletedDisconnectedNum++ + } + } + p.metrics.peersDisconnected(-deletedDisconnectedNum) +} + +func (p *peerTracker) cleanUpTrackedPeers() { + p.peerLk.Lock() + defer p.peerLk.Unlock() + + if len(p.trackedPeers) <= minPeerTrackerSizeBeforeGC { + return + } + + var deletedTrackedNum int + orderedPeers := make([]*peerStat, 0, len(p.trackedPeers)) + for _, peer := range p.trackedPeers { + orderedPeers = append(orderedPeers, peer) + } + sort.Slice(orderedPeers, func(i, j int) bool { + return orderedPeers[i].peerScore < orderedPeers[j].peerScore + }) + + for _, peer := range orderedPeers[:len(orderedPeers)-minPeerTrackerSizeBeforeGC] { + if peer.peerScore > defaultScore { + break } + delete(p.trackedPeers, peer.peerID) + deletedTrackedNum++ } + p.metrics.peersTracked(-deletedTrackedNum) } // dumpPeers stores peers to the peerTracker's PeerIDStore if diff --git a/p2p/peer_tracker_test.go b/p2p/peer_tracker_test.go index 782d3710..3543d985 100644 --- a/p2p/peer_tracker_test.go +++ b/p2p/peer_tracker_test.go @@ -33,7 +33,13 @@ func TestPeerTracker_GC(t *testing.T) { maxAwaitingTime = time.Millisecond - peerlist := generateRandomPeerlist(t, 4) + peerlist := generateRandomPeerlist(t, minPeerTrackerSizeBeforeGC) + for i := 0; i < minPeerTrackerSizeBeforeGC; i++ { + p.trackedPeers[peerlist[i]] = &peerStat{peerID: peerlist[i], peerScore: 0.5} + } + + // add peers to trackedPeers to make total number of peers > maxPeerTrackerSize + peerlist = generateRandomPeerlist(t, 4) pid1 := peerlist[0] pid2 := peerlist[1] pid3 := peerlist[2] @@ -54,13 +60,14 @@ func TestPeerTracker_GC(t *testing.T) { err = p.stop(ctx) require.NoError(t, err) - require.Nil(t, p.trackedPeers[pid1]) + // ensure amount of peers in trackedPeers is equal to minPeerTrackerSizeBeforeGC + require.Len(t, p.trackedPeers, minPeerTrackerSizeBeforeGC) require.Nil(t, p.disconnectedPeers[pid3]) // ensure good peers were dumped to store peers, err := pidstore.Load(ctx) require.NoError(t, err) - assert.Equal(t, 1, len(peers)) + require.Equal(t, minPeerTrackerSizeBeforeGC, len(peers)) } func TestPeerTracker_BlockPeer(t *testing.T) {