Skip to content

Commit

Permalink
fix(p2p/session): return err if peer tracker is empty
Browse files Browse the repository at this point in the history
  • Loading branch information
vgonkivs committed Jul 10, 2024
1 parent b80ef73 commit c88c554
Show file tree
Hide file tree
Showing 6 changed files with 84 additions and 59 deletions.
12 changes: 9 additions & 3 deletions p2p/exchange.go
Original file line number Diff line number Diff line change
Expand Up @@ -150,9 +150,11 @@ func (ex *Exchange[H]) Head(ctx context.Context, opts ...header.HeadOption[H]) (
// their Head and verify against the given trusted header.
useTrackedPeers := !reqParams.TrustedHead.IsZero()
if useTrackedPeers {
trackedPeers := ex.peerTracker.getPeers(maxUntrustedHeadRequests)
trackedPeers := ex.peerTracker.peers(maxUntrustedHeadRequests)
if len(trackedPeers) > 0 {
peers = trackedPeers
peers = transform(trackedPeers, func(p *peerStat) peer.ID {
return p.peerID
})
log.Debugw("requesting head from tracked peers", "amount", len(peers))
}
}
Expand Down Expand Up @@ -292,9 +294,13 @@ func (ex *Exchange[H]) GetRangeByHeight(
attribute.Int64("to", int64(to)),
))
defer span.End()
session := newSession[H](
session, err := newSession[H](
ex.ctx, ex.host, ex.peerTracker, ex.protocolID, ex.Params.RequestTimeout, ex.metrics, withValidation(from),
)
// TODO(@vgonkivs): decide what to do with this error. Maybe we should fall into "discovery mode" and try to collect peers???
if err != nil {
return nil, err
}
defer session.close()
// we request the next header height that we don't have: `fromHead`+1
amount := to - (from.Height() + 1)
Expand Down
10 changes: 10 additions & 0 deletions p2p/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,3 +124,13 @@ func convertStatusCodeToError(code p2p_pb.StatusCode) error {
return fmt.Errorf("unknown status code %d", code)
}
}

// transform applies a provided function to each element of the input slice,
// producing a new slice with the results of the function.
func transform[T, U any](ts []T, f func(T) U) []U {
us := make([]U, len(ts))
for i := range ts {
us[i] = f(ts[i])
}
return us
}
80 changes: 36 additions & 44 deletions p2p/peer_tracker.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,22 @@ import (
)

type peerTracker struct {
host host.Host
connGater *conngater.BasicConnectionGater
metrics *exchangeMetrics
protocolID protocol.ID
peerLk sync.RWMutex

host host.Host
connGater *conngater.BasicConnectionGater

peerLk sync.RWMutex
// trackedPeers contains active peers that we can request to.
// we cache the peer once they disconnect,
// so we can guarantee that peerQueue will only contain active peers
trackedPeers map[libpeer.ID]struct{}

// an optional interface used to periodically dump
// good peers during garbage collection
pidstore PeerIDStore

metrics *exchangeMetrics

ctx context.Context
cancel context.CancelFunc
// done is used to gracefully stop the peerTracker.
Expand Down Expand Up @@ -103,19 +105,20 @@ func (p *peerTracker) track() {
p.done <- struct{}{}
}()

connSubs, err := p.host.EventBus().Subscribe(&event.EvtPeerConnectednessChanged{})
evtBus := p.host.EventBus()
connSubs, err := evtBus.Subscribe(&event.EvtPeerConnectednessChanged{})
if err != nil {
log.Errorw("subscribing to EvtPeerConnectednessChanged", "err", err)
return
}

identifySub, err := p.host.EventBus().Subscribe(&event.EvtPeerIdentificationCompleted{})
identifySub, err := evtBus.Subscribe(&event.EvtPeerIdentificationCompleted{})
if err != nil {
log.Errorw("subscribing to EvtPeerIdentificationCompleted", "err", err)
return
}

protocolSub, err := p.host.EventBus().Subscribe(&event.EvtPeerProtocolsUpdated{})
protocolSub, err := evtBus.Subscribe(&event.EvtPeerProtocolsUpdated{})
if err != nil {
log.Errorw("subscribing to EvtPeerProtocolsUpdated", "err", err)
return
Expand All @@ -124,9 +127,7 @@ func (p *peerTracker) track() {
for {
select {
case <-p.ctx.Done():
err = connSubs.Close()
errors.Join(err, identifySub.Close(), protocolSub.Close())
if err != nil {
if err := closeSubscriptions(connSubs, identifySub, protocolSub); err != nil {
log.Errorw("closing subscriptions", "err", err)
}
return
Expand All @@ -135,35 +136,23 @@ func (p *peerTracker) track() {
if network.NotConnected == ev.Connectedness {
p.disconnected(ev.Peer)
}
case subscription := <-identifySub.Out():
ev := subscription.(event.EvtPeerIdentificationCompleted)
p.connected(ev.Peer)
case subscription := <-protocolSub.Out():
ev := subscription.(event.EvtPeerProtocolsUpdated)
case identSubscription := <-identifySub.Out():
ev := identSubscription.(event.EvtPeerIdentificationCompleted)
if slices.Contains(ev.Protocols, p.protocolID) {
p.connected(ev.Peer)
}
case protocolSubscription := <-protocolSub.Out():
ev := protocolSubscription.(event.EvtPeerProtocolsUpdated)
if slices.Contains(ev.Removed, p.protocolID) {
p.disconnected(ev.Peer)
break
}
p.connected(ev.Peer)
if slices.Contains(ev.Added, p.protocolID) {
p.connected(ev.Peer)
}
}
}
}

// getPeers returns the tracker's currently tracked peers up to the `max`.
func (p *peerTracker) getPeers(max int) []libpeer.ID {
p.peerLk.RLock()
defer p.peerLk.RUnlock()

peers := make([]libpeer.ID, 0, max)
for peer := range p.trackedPeers {
peers = append(peers, peer)
if len(peers) == max {
break
}
}
return peers
}

func (p *peerTracker) connected(pID libpeer.ID) {
if err := pID.Validate(); err != nil {
return
Expand All @@ -173,15 +162,6 @@ func (p *peerTracker) connected(pID libpeer.ID) {
return
}

// check that peer supports our protocol id.
protocol, err := p.host.Peerstore().SupportsProtocols(pID, p.protocolID)
if err != nil {
return
}
if !slices.Contains(protocol, p.protocolID) {
return
}

for _, c := range p.host.Network().ConnsToPeer(pID) {
// check if connection is short-termed and skip this peer
if c.Stat().Limited {
Expand Down Expand Up @@ -219,17 +199,21 @@ func (p *peerTracker) disconnected(pID libpeer.ID) {
p.metrics.peersDisconnected(1)
}

func (p *peerTracker) peers() []*peerStat {
// peers returns the tracker's currently tracked peers up to the `max`.
func (p *peerTracker) peers(max int) []*peerStat {
p.peerLk.RLock()
defer p.peerLk.RUnlock()

peers := make([]*peerStat, 0)
peers := make([]*peerStat, 0, max)
for peerID := range p.trackedPeers {
score := 0
if info := p.host.ConnManager().GetTagInfo(peerID); info != nil {
score = info.Tags[string(p.protocolID)]
}
peers = append(peers, &peerStat{peerID: peerID, peerScore: score})
if len(peers) == max {
break
}
}
return peers
}
Expand Down Expand Up @@ -300,3 +284,11 @@ func (p *peerTracker) updateScore(stats *peerStat, size uint64, duration time.Du
score := stats.updateStats(size, duration)
p.host.ConnManager().TagPeer(stats.peerID, string(p.protocolID), score)
}

func closeSubscriptions(subs ...event.Subscription) error {
var err error
for _, sub := range subs {
err = errors.Join(err, sub.Close())
}
return err
}
2 changes: 1 addition & 1 deletion p2p/peer_tracker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ func TestPeerTracker_Bootstrap(t *testing.T) {
require.NoError(t, err)

assert.Eventually(t, func() bool {
return len(tracker.getPeers(7)) > 0
return len(tracker.peers(7)) > 0
}, time.Millisecond*500, time.Millisecond*100)
}

Expand Down
11 changes: 8 additions & 3 deletions p2p/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,23 +60,28 @@ func newSession[H header.Header[H]](
requestTimeout time.Duration,
metrics *exchangeMetrics,
options ...option[H],
) *session[H] {
) (*session[H], error) {
ctx, cancel := context.WithCancel(ctx)
ses := &session[H]{
ctx: ctx,
cancel: cancel,
protocolID: protocolID,
host: h,
queue: newPeerQueue(ctx, peerTracker.peers()),
peerTracker: peerTracker,
requestTimeout: requestTimeout,
metrics: metrics,
}

peers := peerTracker.peers(len(peerTracker.trackedPeers))
if len(peers) == 0 {
return nil, errors.New("empty peer tracker")
}
ses.queue = newPeerQueue(ctx, peers)

for _, opt := range options {
opt(ses)
}
return ses
return ses, nil
}

// getRangeByHeight requests headers from different peers.
Expand Down
28 changes: 20 additions & 8 deletions p2p/session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import (
"time"

"github.com/libp2p/go-libp2p/core/peer"
blankhost "github.com/libp2p/go-libp2p/p2p/host/blank"
swarm "github.com/libp2p/go-libp2p/p2p/net/swarm/testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

Expand All @@ -25,34 +27,44 @@ func Test_PrepareRequests(t *testing.T) {
func Test_Validate(t *testing.T) {
suite := headertest.NewTestSuite(t)
head := suite.Head()
ses := newSession(
peerId := peer.ID("test")
pT := &peerTracker{trackedPeers: make(map[peer.ID]struct{})}
pT.trackedPeers[peerId] = struct{}{}
pT.host = blankhost.NewBlankHost(swarm.GenSwarm(t))
ses, err := newSession(
context.Background(),
nil,
&peerTracker{trackedPeers: make(map[peer.ID]struct{})},
pT,
"", time.Second, nil,
withValidation(head),
)

require.NoError(t, err)
headers := suite.GenDummyHeaders(5)
err := ses.verify(headers)
err = ses.verify(headers)
assert.NoError(t, err)
}

// Test_ValidateFails ensures that non-adjacent range will return an error.
func Test_ValidateFails(t *testing.T) {
suite := headertest.NewTestSuite(t)
head := suite.Head()
ses := newSession(

peerId := peer.ID("test")
pT := &peerTracker{trackedPeers: make(map[peer.ID]struct{})}
pT.trackedPeers[peerId] = struct{}{}
pT.host = blankhost.NewBlankHost(swarm.GenSwarm(t))
ses, err := newSession(
context.Background(),
nil,
&peerTracker{trackedPeers: make(map[peer.ID]struct{})},
blankhost.NewBlankHost(swarm.GenSwarm(t)),
pT,
"", time.Second, nil,
withValidation(head),
)

require.NoError(t, err)
headers := suite.GenDummyHeaders(5)
// break adjacency
headers[2] = headers[4]
err := ses.verify(headers)
err = ses.verify(headers)
assert.Error(t, err)
}

0 comments on commit c88c554

Please sign in to comment.