diff --git a/p2p/exchange.go b/p2p/exchange.go index ad923061..e17795d6 100644 --- a/p2p/exchange.go +++ b/p2p/exchange.go @@ -99,7 +99,10 @@ func (ex *Exchange[H]) Start(ctx context.Context) error { ex.ctx, ex.cancel = context.WithCancel(context.Background()) log.Infow("client: starting client", "protocol ID", ex.protocolID) - go ex.peerTracker.track() + err := ex.peerTracker.track() + if err != nil { + return err + } // bootstrap the peerTracker with trusted peers as well as previously seen // peers if provided. @@ -150,9 +153,11 @@ func (ex *Exchange[H]) Head(ctx context.Context, opts ...header.HeadOption[H]) ( // their Head and verify against the given trusted header. useTrackedPeers := !reqParams.TrustedHead.IsZero() if useTrackedPeers { - trackedPeers := ex.peerTracker.getPeers(maxUntrustedHeadRequests) + trackedPeers := ex.peerTracker.peers(maxUntrustedHeadRequests) if len(trackedPeers) > 0 { - peers = trackedPeers + peers = transform(trackedPeers, func(p *peerStat) peer.ID { + return p.peerID + }) log.Debugw("requesting head from tracked peers", "amount", len(peers)) } } diff --git a/p2p/helpers.go b/p2p/helpers.go index dc8cc1c5..637fa961 100644 --- a/p2p/helpers.go +++ b/p2p/helpers.go @@ -124,3 +124,13 @@ func convertStatusCodeToError(code p2p_pb.StatusCode) error { return fmt.Errorf("unknown status code %d", code) } } + +// transform applies a provided function to each element of the input slice, +// producing a new slice with the results of the function. +func transform[T, U any](ts []T, f func(T) U) []U { + us := make([]U, len(ts)) + for i := range ts { + us[i] = f(ts[i]) + } + return us +} diff --git a/p2p/peer_tracker.go b/p2p/peer_tracker.go index 7c83aeed..f795eb80 100644 --- a/p2p/peer_tracker.go +++ b/p2p/peer_tracker.go @@ -23,7 +23,6 @@ type peerTracker struct { peerLk sync.RWMutex // trackedPeers contains active peers that we can request to. - // we cache the peer once they disconnect, // so we can guarantee that peerQueue will only contain active peers trackedPeers map[libpeer.ID]struct{} @@ -101,72 +100,59 @@ func (p *peerTracker) connectToPeer(ctx context.Context, peer libpeer.ID) { } } -func (p *peerTracker) track() { - defer func() { - p.done <- struct{}{} - }() +// track creates subscriptions for different types of libp2p.Events to efficiently handle peers. +func (p *peerTracker) track() error { + evtBus := p.host.EventBus() - connSubs, err := p.host.EventBus().Subscribe(&event.EvtPeerConnectednessChanged{}) + connSubs, err := evtBus.Subscribe(&event.EvtPeerConnectednessChanged{}) if err != nil { log.Errorw("subscribing to EvtPeerConnectednessChanged", "err", err) - return + return err } - identifySub, err := p.host.EventBus().Subscribe(&event.EvtPeerIdentificationCompleted{}) + identifySub, err := evtBus.Subscribe(&event.EvtPeerIdentificationCompleted{}) if err != nil { log.Errorw("subscribing to EvtPeerIdentificationCompleted", "err", err) - return + return err } - protocolSub, err := p.host.EventBus().Subscribe(&event.EvtPeerProtocolsUpdated{}) + protocolSub, err := evtBus.Subscribe(&event.EvtPeerProtocolsUpdated{}) if err != nil { log.Errorw("subscribing to EvtPeerProtocolsUpdated", "err", err) - return + return err } - for { - select { - case <-p.ctx.Done(): - err = connSubs.Close() - errors.Join(err, identifySub.Close(), protocolSub.Close()) - if err != nil { - log.Errorw("closing subscriptions", "err", err) + go func() { + for { + select { + case <-p.ctx.Done(): + if err := closeSubscriptions(connSubs, identifySub, protocolSub); err != nil { + log.Errorw("closing subscriptions", "err", err) + } + p.done <- struct{}{} + return + case connSubscription := <-connSubs.Out(): + ev := connSubscription.(event.EvtPeerConnectednessChanged) + if network.NotConnected == ev.Connectedness { + p.disconnected(ev.Peer) + } + case identSubscription := <-identifySub.Out(): + ev := identSubscription.(event.EvtPeerIdentificationCompleted) + if slices.Contains(ev.Protocols, p.protocolID) { + p.connected(ev.Peer) + } + case protocolSubscription := <-protocolSub.Out(): + ev := protocolSubscription.(event.EvtPeerProtocolsUpdated) + if slices.Contains(ev.Removed, p.protocolID) { + p.disconnected(ev.Peer) + } + if slices.Contains(ev.Added, p.protocolID) { + p.connected(ev.Peer) + } } - return - case connSubscription := <-connSubs.Out(): - ev := connSubscription.(event.EvtPeerConnectednessChanged) - if network.NotConnected == ev.Connectedness { - p.disconnected(ev.Peer) - } - case identSubscription := <-identifySub.Out(): - ev := identSubscription.(event.EvtPeerIdentificationCompleted) - if slices.Contains(ev.Protocols, p.protocolID) { - p.connected(ev.Peer) - } - case protocolSubscription := <-protocolSub.Out(): - ev := protocolSubscription.(event.EvtPeerProtocolsUpdated) - if slices.Contains(ev.Removed, p.protocolID) { - p.disconnected(ev.Peer) - break - } - p.connected(ev.Peer) - } - } -} - -// getPeers returns the tracker's currently tracked peers up to the `max`. -func (p *peerTracker) getPeers(max int) []libpeer.ID { - p.peerLk.RLock() - defer p.peerLk.RUnlock() - - peers := make([]libpeer.ID, 0, max) - for peer := range p.trackedPeers { - peers = append(peers, peer) - if len(peers) == max { - break } - } - return peers + }() + return nil } func (p *peerTracker) connected(pID libpeer.ID) { @@ -215,17 +201,21 @@ func (p *peerTracker) disconnected(pID libpeer.ID) { p.metrics.peersDisconnected(1) } -func (p *peerTracker) peers() []*peerStat { +// peers returns the tracker's currently tracked peers up to the `max`. +func (p *peerTracker) peers(max int) []*peerStat { p.peerLk.RLock() defer p.peerLk.RUnlock() - peers := make([]*peerStat, 0) + peers := make([]*peerStat, 0, max) for peerID := range p.trackedPeers { score := 0 if info := p.host.ConnManager().GetTagInfo(peerID); info != nil { score = info.Tags[string(p.protocolID)] } peers = append(peers, &peerStat{peerID: peerID, peerScore: score}) + if len(peers) == max { + break + } } return peers } @@ -296,3 +286,11 @@ func (p *peerTracker) updateScore(stats *peerStat, size uint64, duration time.Du score := stats.updateStats(size, duration) p.host.ConnManager().TagPeer(stats.peerID, string(p.protocolID), score) } + +func closeSubscriptions(subs ...event.Subscription) error { + var err error + for _, sub := range subs { + err = errors.Join(err, sub.Close()) + } + return err +} diff --git a/p2p/peer_tracker_test.go b/p2p/peer_tracker_test.go index 41796c8a..dda0fd7e 100644 --- a/p2p/peer_tracker_test.go +++ b/p2p/peer_tracker_test.go @@ -56,13 +56,14 @@ func TestPeerTracker_Bootstrap(t *testing.T) { require.NoError(t, err) tracker := newPeerTracker(hosts[0], connGater, "private", pidstore, nil) - go tracker.track() + err = tracker.track() + require.NoError(t, err) err = tracker.bootstrap(prevSeen[:2]) require.NoError(t, err) assert.Eventually(t, func() bool { - return len(tracker.getPeers(7)) > 0 + return len(tracker.peers(7)) > 0 }, time.Millisecond*500, time.Millisecond*100) } diff --git a/p2p/session.go b/p2p/session.go index 32facb0a..dc297d9d 100644 --- a/p2p/session.go +++ b/p2p/session.go @@ -72,7 +72,7 @@ func newSession[H header.Header[H]]( metrics: metrics, } - peers := peerTracker.peers() + peers := peerTracker.peers(len(peerTracker.trackedPeers)) if len(peers) == 0 { return nil, errors.New("empty peer tracker") }