diff --git a/p2p/exchange.go b/p2p/exchange.go index 9f0cf9de..834a3b5b 100644 --- a/p2p/exchange.go +++ b/p2p/exchange.go @@ -150,9 +150,11 @@ func (ex *Exchange[H]) Head(ctx context.Context, opts ...header.HeadOption[H]) ( // their Head and verify against the given trusted header. useTrackedPeers := !reqParams.TrustedHead.IsZero() if useTrackedPeers { - trackedPeers := ex.peerTracker.getPeers(maxUntrustedHeadRequests) + trackedPeers := ex.peerTracker.peers(maxUntrustedHeadRequests) if len(trackedPeers) > 0 { - peers = trackedPeers + peers = transform(trackedPeers, func(p *peerStat) peer.ID { + return p.peerID + }) log.Debugw("requesting head from tracked peers", "amount", len(peers)) } } @@ -292,9 +294,13 @@ func (ex *Exchange[H]) GetRangeByHeight( attribute.Int64("to", int64(to)), )) defer span.End() - session := newSession[H]( + session, err := newSession[H]( ex.ctx, ex.host, ex.peerTracker, ex.protocolID, ex.Params.RequestTimeout, ex.metrics, withValidation(from), ) + // TODO(@vgonkivs): decide what to do with this error. Maybe we should fall into "discovery mode" and try to collect peers??? + if err != nil { + return nil, err + } defer session.close() // we request the next header height that we don't have: `fromHead`+1 amount := to - (from.Height() + 1) diff --git a/p2p/helpers.go b/p2p/helpers.go index dc8cc1c5..637fa961 100644 --- a/p2p/helpers.go +++ b/p2p/helpers.go @@ -124,3 +124,13 @@ func convertStatusCodeToError(code p2p_pb.StatusCode) error { return fmt.Errorf("unknown status code %d", code) } } + +// transform applies a provided function to each element of the input slice, +// producing a new slice with the results of the function. +func transform[T, U any](ts []T, f func(T) U) []U { + us := make([]U, len(ts)) + for i := range ts { + us[i] = f(ts[i]) + } + return us +} diff --git a/p2p/peer_tracker.go b/p2p/peer_tracker.go index f7b9ad27..294ade03 100644 --- a/p2p/peer_tracker.go +++ b/p2p/peer_tracker.go @@ -16,13 +16,13 @@ import ( ) type peerTracker struct { - host host.Host - connGater *conngater.BasicConnectionGater - metrics *exchangeMetrics protocolID protocol.ID - peerLk sync.RWMutex + + host host.Host + connGater *conngater.BasicConnectionGater + + peerLk sync.RWMutex // trackedPeers contains active peers that we can request to. - // we cache the peer once they disconnect, // so we can guarantee that peerQueue will only contain active peers trackedPeers map[libpeer.ID]struct{} @@ -30,6 +30,8 @@ type peerTracker struct { // good peers during garbage collection pidstore PeerIDStore + metrics *exchangeMetrics + ctx context.Context cancel context.CancelFunc // done is used to gracefully stop the peerTracker. @@ -103,19 +105,20 @@ func (p *peerTracker) track() { p.done <- struct{}{} }() - connSubs, err := p.host.EventBus().Subscribe(&event.EvtPeerConnectednessChanged{}) + evtBus := p.host.EventBus() + connSubs, err := evtBus.Subscribe(&event.EvtPeerConnectednessChanged{}) if err != nil { log.Errorw("subscribing to EvtPeerConnectednessChanged", "err", err) return } - identifySub, err := p.host.EventBus().Subscribe(&event.EvtPeerIdentificationCompleted{}) + identifySub, err := evtBus.Subscribe(&event.EvtPeerIdentificationCompleted{}) if err != nil { log.Errorw("subscribing to EvtPeerIdentificationCompleted", "err", err) return } - protocolSub, err := p.host.EventBus().Subscribe(&event.EvtPeerProtocolsUpdated{}) + protocolSub, err := evtBus.Subscribe(&event.EvtPeerProtocolsUpdated{}) if err != nil { log.Errorw("subscribing to EvtPeerProtocolsUpdated", "err", err) return @@ -124,9 +127,7 @@ func (p *peerTracker) track() { for { select { case <-p.ctx.Done(): - err = connSubs.Close() - errors.Join(err, identifySub.Close(), protocolSub.Close()) - if err != nil { + if err := closeSubscriptions(connSubs, identifySub, protocolSub); err != nil { log.Errorw("closing subscriptions", "err", err) } return @@ -135,35 +136,23 @@ func (p *peerTracker) track() { if network.NotConnected == ev.Connectedness { p.disconnected(ev.Peer) } - case subscription := <-identifySub.Out(): - ev := subscription.(event.EvtPeerIdentificationCompleted) - p.connected(ev.Peer) - case subscription := <-protocolSub.Out(): - ev := subscription.(event.EvtPeerProtocolsUpdated) + case identSubscription := <-identifySub.Out(): + ev := identSubscription.(event.EvtPeerIdentificationCompleted) + if slices.Contains(ev.Protocols, p.protocolID) { + p.connected(ev.Peer) + } + case protocolSubscription := <-protocolSub.Out(): + ev := protocolSubscription.(event.EvtPeerProtocolsUpdated) if slices.Contains(ev.Removed, p.protocolID) { p.disconnected(ev.Peer) - break } - p.connected(ev.Peer) + if slices.Contains(ev.Added, p.protocolID) { + p.connected(ev.Peer) + } } } } -// getPeers returns the tracker's currently tracked peers up to the `max`. -func (p *peerTracker) getPeers(max int) []libpeer.ID { - p.peerLk.RLock() - defer p.peerLk.RUnlock() - - peers := make([]libpeer.ID, 0, max) - for peer := range p.trackedPeers { - peers = append(peers, peer) - if len(peers) == max { - break - } - } - return peers -} - func (p *peerTracker) connected(pID libpeer.ID) { if err := pID.Validate(); err != nil { return @@ -173,15 +162,6 @@ func (p *peerTracker) connected(pID libpeer.ID) { return } - // check that peer supports our protocol id. - protocol, err := p.host.Peerstore().SupportsProtocols(pID, p.protocolID) - if err != nil { - return - } - if !slices.Contains(protocol, p.protocolID) { - return - } - for _, c := range p.host.Network().ConnsToPeer(pID) { // check if connection is short-termed and skip this peer if c.Stat().Limited { @@ -219,17 +199,21 @@ func (p *peerTracker) disconnected(pID libpeer.ID) { p.metrics.peersDisconnected(1) } -func (p *peerTracker) peers() []*peerStat { +// peers returns the tracker's currently tracked peers up to the `max`. +func (p *peerTracker) peers(max int) []*peerStat { p.peerLk.RLock() defer p.peerLk.RUnlock() - peers := make([]*peerStat, 0) + peers := make([]*peerStat, 0, max) for peerID := range p.trackedPeers { score := 0 if info := p.host.ConnManager().GetTagInfo(peerID); info != nil { score = info.Tags[string(p.protocolID)] } peers = append(peers, &peerStat{peerID: peerID, peerScore: score}) + if len(peers) == max { + break + } } return peers } @@ -300,3 +284,11 @@ func (p *peerTracker) updateScore(stats *peerStat, size uint64, duration time.Du score := stats.updateStats(size, duration) p.host.ConnManager().TagPeer(stats.peerID, string(p.protocolID), score) } + +func closeSubscriptions(subs ...event.Subscription) error { + var err error + for _, sub := range subs { + err = errors.Join(err, sub.Close()) + } + return err +} diff --git a/p2p/peer_tracker_test.go b/p2p/peer_tracker_test.go index 41796c8a..7151e086 100644 --- a/p2p/peer_tracker_test.go +++ b/p2p/peer_tracker_test.go @@ -62,7 +62,7 @@ func TestPeerTracker_Bootstrap(t *testing.T) { require.NoError(t, err) assert.Eventually(t, func() bool { - return len(tracker.getPeers(7)) > 0 + return len(tracker.peers(7)) > 0 }, time.Millisecond*500, time.Millisecond*100) } diff --git a/p2p/session.go b/p2p/session.go index 2e8594d2..dc297d9d 100644 --- a/p2p/session.go +++ b/p2p/session.go @@ -60,23 +60,28 @@ func newSession[H header.Header[H]]( requestTimeout time.Duration, metrics *exchangeMetrics, options ...option[H], -) *session[H] { +) (*session[H], error) { ctx, cancel := context.WithCancel(ctx) ses := &session[H]{ ctx: ctx, cancel: cancel, protocolID: protocolID, host: h, - queue: newPeerQueue(ctx, peerTracker.peers()), peerTracker: peerTracker, requestTimeout: requestTimeout, metrics: metrics, } + peers := peerTracker.peers(len(peerTracker.trackedPeers)) + if len(peers) == 0 { + return nil, errors.New("empty peer tracker") + } + ses.queue = newPeerQueue(ctx, peers) + for _, opt := range options { opt(ses) } - return ses + return ses, nil } // getRangeByHeight requests headers from different peers. diff --git a/p2p/session_test.go b/p2p/session_test.go index 7c8599f5..8aa1776a 100644 --- a/p2p/session_test.go +++ b/p2p/session_test.go @@ -6,6 +6,8 @@ import ( "time" "github.com/libp2p/go-libp2p/core/peer" + blankhost "github.com/libp2p/go-libp2p/p2p/host/blank" + swarm "github.com/libp2p/go-libp2p/p2p/net/swarm/testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -25,16 +27,21 @@ func Test_PrepareRequests(t *testing.T) { func Test_Validate(t *testing.T) { suite := headertest.NewTestSuite(t) head := suite.Head() - ses := newSession( + peerId := peer.ID("test") + pT := &peerTracker{trackedPeers: make(map[peer.ID]struct{})} + pT.trackedPeers[peerId] = struct{}{} + pT.host = blankhost.NewBlankHost(swarm.GenSwarm(t)) + ses, err := newSession( context.Background(), nil, - &peerTracker{trackedPeers: make(map[peer.ID]struct{})}, + pT, "", time.Second, nil, withValidation(head), ) + require.NoError(t, err) headers := suite.GenDummyHeaders(5) - err := ses.verify(headers) + err = ses.verify(headers) assert.NoError(t, err) } @@ -42,17 +49,22 @@ func Test_Validate(t *testing.T) { func Test_ValidateFails(t *testing.T) { suite := headertest.NewTestSuite(t) head := suite.Head() - ses := newSession( + + peerId := peer.ID("test") + pT := &peerTracker{trackedPeers: make(map[peer.ID]struct{})} + pT.trackedPeers[peerId] = struct{}{} + pT.host = blankhost.NewBlankHost(swarm.GenSwarm(t)) + ses, err := newSession( context.Background(), - nil, - &peerTracker{trackedPeers: make(map[peer.ID]struct{})}, + blankhost.NewBlankHost(swarm.GenSwarm(t)), + pT, "", time.Second, nil, withValidation(head), ) - + require.NoError(t, err) headers := suite.GenDummyHeaders(5) // break adjacency headers[2] = headers[4] - err := ses.verify(headers) + err = ses.verify(headers) assert.Error(t, err) }