diff --git a/p2p/exchange.go b/p2p/exchange.go index 9f0cf9de..ad923061 100644 --- a/p2p/exchange.go +++ b/p2p/exchange.go @@ -292,9 +292,13 @@ func (ex *Exchange[H]) GetRangeByHeight( attribute.Int64("to", int64(to)), )) defer span.End() - session := newSession[H]( + session, err := newSession[H]( ex.ctx, ex.host, ex.peerTracker, ex.protocolID, ex.Params.RequestTimeout, ex.metrics, withValidation(from), ) + // TODO(@vgonkivs): decide what to do with this error. Maybe we should fall into "discovery mode" and try to collect peers??? + if err != nil { + return nil, err + } defer session.close() // we request the next header height that we don't have: `fromHead`+1 amount := to - (from.Height() + 1) diff --git a/p2p/peer_tracker.go b/p2p/peer_tracker.go index f7b9ad27..7c83aeed 100644 --- a/p2p/peer_tracker.go +++ b/p2p/peer_tracker.go @@ -16,11 +16,12 @@ import ( ) type peerTracker struct { - host host.Host - connGater *conngater.BasicConnectionGater - metrics *exchangeMetrics protocolID protocol.ID - peerLk sync.RWMutex + + host host.Host + connGater *conngater.BasicConnectionGater + + peerLk sync.RWMutex // trackedPeers contains active peers that we can request to. // we cache the peer once they disconnect, // so we can guarantee that peerQueue will only contain active peers @@ -30,6 +31,8 @@ type peerTracker struct { // good peers during garbage collection pidstore PeerIDStore + metrics *exchangeMetrics + ctx context.Context cancel context.CancelFunc // done is used to gracefully stop the peerTracker. @@ -135,11 +138,13 @@ func (p *peerTracker) track() { if network.NotConnected == ev.Connectedness { p.disconnected(ev.Peer) } - case subscription := <-identifySub.Out(): - ev := subscription.(event.EvtPeerIdentificationCompleted) - p.connected(ev.Peer) - case subscription := <-protocolSub.Out(): - ev := subscription.(event.EvtPeerProtocolsUpdated) + case identSubscription := <-identifySub.Out(): + ev := identSubscription.(event.EvtPeerIdentificationCompleted) + if slices.Contains(ev.Protocols, p.protocolID) { + p.connected(ev.Peer) + } + case protocolSubscription := <-protocolSub.Out(): + ev := protocolSubscription.(event.EvtPeerProtocolsUpdated) if slices.Contains(ev.Removed, p.protocolID) { p.disconnected(ev.Peer) break @@ -173,15 +178,6 @@ func (p *peerTracker) connected(pID libpeer.ID) { return } - // check that peer supports our protocol id. - protocol, err := p.host.Peerstore().SupportsProtocols(pID, p.protocolID) - if err != nil { - return - } - if !slices.Contains(protocol, p.protocolID) { - return - } - for _, c := range p.host.Network().ConnsToPeer(pID) { // check if connection is short-termed and skip this peer if c.Stat().Limited { diff --git a/p2p/session.go b/p2p/session.go index 2e8594d2..32facb0a 100644 --- a/p2p/session.go +++ b/p2p/session.go @@ -60,23 +60,28 @@ func newSession[H header.Header[H]]( requestTimeout time.Duration, metrics *exchangeMetrics, options ...option[H], -) *session[H] { +) (*session[H], error) { ctx, cancel := context.WithCancel(ctx) ses := &session[H]{ ctx: ctx, cancel: cancel, protocolID: protocolID, host: h, - queue: newPeerQueue(ctx, peerTracker.peers()), peerTracker: peerTracker, requestTimeout: requestTimeout, metrics: metrics, } + peers := peerTracker.peers() + if len(peers) == 0 { + return nil, errors.New("empty peer tracker") + } + ses.queue = newPeerQueue(ctx, peers) + for _, opt := range options { opt(ses) } - return ses + return ses, nil } // getRangeByHeight requests headers from different peers. diff --git a/p2p/session_test.go b/p2p/session_test.go index 7c8599f5..8aa1776a 100644 --- a/p2p/session_test.go +++ b/p2p/session_test.go @@ -6,6 +6,8 @@ import ( "time" "github.com/libp2p/go-libp2p/core/peer" + blankhost "github.com/libp2p/go-libp2p/p2p/host/blank" + swarm "github.com/libp2p/go-libp2p/p2p/net/swarm/testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -25,16 +27,21 @@ func Test_PrepareRequests(t *testing.T) { func Test_Validate(t *testing.T) { suite := headertest.NewTestSuite(t) head := suite.Head() - ses := newSession( + peerId := peer.ID("test") + pT := &peerTracker{trackedPeers: make(map[peer.ID]struct{})} + pT.trackedPeers[peerId] = struct{}{} + pT.host = blankhost.NewBlankHost(swarm.GenSwarm(t)) + ses, err := newSession( context.Background(), nil, - &peerTracker{trackedPeers: make(map[peer.ID]struct{})}, + pT, "", time.Second, nil, withValidation(head), ) + require.NoError(t, err) headers := suite.GenDummyHeaders(5) - err := ses.verify(headers) + err = ses.verify(headers) assert.NoError(t, err) } @@ -42,17 +49,22 @@ func Test_Validate(t *testing.T) { func Test_ValidateFails(t *testing.T) { suite := headertest.NewTestSuite(t) head := suite.Head() - ses := newSession( + + peerId := peer.ID("test") + pT := &peerTracker{trackedPeers: make(map[peer.ID]struct{})} + pT.trackedPeers[peerId] = struct{}{} + pT.host = blankhost.NewBlankHost(swarm.GenSwarm(t)) + ses, err := newSession( context.Background(), - nil, - &peerTracker{trackedPeers: make(map[peer.ID]struct{})}, + blankhost.NewBlankHost(swarm.GenSwarm(t)), + pT, "", time.Second, nil, withValidation(head), ) - + require.NoError(t, err) headers := suite.GenDummyHeaders(5) // break adjacency headers[2] = headers[4] - err := ses.verify(headers) + err = ses.verify(headers) assert.Error(t, err) }