Skip to content

Commit

Permalink
misc(p2p/peerTracker): extend conditions for peers handling
Browse files Browse the repository at this point in the history
  • Loading branch information
vgonkivs committed Mar 15, 2024
1 parent 0c5d8c2 commit df866eb
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 19 deletions.
7 changes: 4 additions & 3 deletions p2p/exchange.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,11 @@ func NewExchange[H header.Header[H]](
}
}

id := protocolID(params.networkID)
ex := &Exchange[H]{
host: host,
protocolID: protocolID(params.networkID),
peerTracker: newPeerTracker(host, gater, params.pidstore, metrics),
protocolID: id,
peerTracker: newPeerTracker(host, gater, id, params.pidstore, metrics),
Params: params,
metrics: metrics,
}
Expand Down Expand Up @@ -172,7 +173,7 @@ func (ex *Exchange[H]) Head(ctx context.Context, opts ...header.HeadOption[H]) (
trace.WithAttributes(attribute.String("peerID", from.String())),
)
defer newSpan.End()

headers, err := ex.request(reqCtx, from, headerReq)
if err != nil {
newSpan.SetStatus(codes.Error, err.Error())
Expand Down
85 changes: 73 additions & 12 deletions p2p/peer_tracker.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,16 @@ package p2p

import (
"context"
"errors"
"slices"
"sync"
"time"

"github.com/libp2p/go-libp2p/core/event"
"github.com/libp2p/go-libp2p/core/host"
"github.com/libp2p/go-libp2p/core/network"
libpeer "github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/core/protocol"
"github.com/libp2p/go-libp2p/p2p/net/conngater"
)

Expand Down Expand Up @@ -37,6 +40,8 @@ type peerTracker struct {
// online until pruneDeadline, it will be removed and its score will be lost
disconnectedPeers map[libpeer.ID]*peerStat

protocolID protocol.ID

// an optional interface used to periodically dump
// good peers during garbage collection
pidstore PeerIDStore
Expand All @@ -51,13 +56,15 @@ type peerTracker struct {
func newPeerTracker(
h host.Host,
connGater *conngater.BasicConnectionGater,
protocolID protocol.ID,
pidstore PeerIDStore,
metrics *exchangeMetrics,
) *peerTracker {
ctx, cancel := context.WithCancel(context.Background())
return &peerTracker{
host: h,
connGater: connGater,
protocolID: protocolID,
metrics: metrics,
trackedPeers: make(map[libpeer.ID]*peerStat),
disconnectedPeers: make(map[libpeer.ID]*peerStat),
Expand Down Expand Up @@ -105,7 +112,16 @@ func (p *peerTracker) bootstrap(ctx context.Context, trusted []libpeer.ID) error

// connectToPeer attempts to connect to the given peer.
func (p *peerTracker) connectToPeer(ctx context.Context, peer libpeer.ID) {
err := p.host.Connect(ctx, p.host.Peerstore().PeerInfo(peer))
// check that peer supports our protocol id.
protocol, err := p.host.Peerstore().SupportsProtocols(peer, p.protocolID)
if err != nil {
return
}
if !slices.Contains(protocol, p.protocolID) {
return
}

err = p.host.Connect(ctx, p.host.Peerstore().PeerInfo(peer))
if err != nil {
log.Debugw("failed to connect to peer", "id", peer.String(), "err", err)
return
Expand All @@ -123,27 +139,50 @@ func (p *peerTracker) track() {
p.connected(c.RemotePeer())
}

subs, err := p.host.EventBus().Subscribe(&event.EvtPeerConnectednessChanged{})
connSubs, err := p.host.EventBus().Subscribe(&event.EvtPeerConnectednessChanged{})
if err != nil {
log.Errorw("subscribing to EvtPeerConnectednessChanged", "err", err)
return
}

identifySub, err := p.host.EventBus().Subscribe(&event.EvtPeerIdentificationCompleted{})
if err != nil {
log.Errorw("subscribing to EvtPeerIdentificationCompleted", "err", err)
return
}

protocolSub, err := p.host.EventBus().Subscribe(&event.EvtPeerProtocolsUpdated{})
if err != nil {
log.Errorw("subscribing to EvtPeerProtocolsUpdated", "err", err)
return
}

for {
select {
case <-p.ctx.Done():
err = subs.Close()
err = connSubs.Close()
errors.Join(err, identifySub.Close(), protocolSub.Close())
if err != nil {
log.Errorw("closing subscription", "err", err)
log.Errorw("closing subscriptions", "err", err)
}
return
case subscription := <-subs.Out():
ev := subscription.(event.EvtPeerConnectednessChanged)
switch ev.Connectedness {
case network.Connected:
p.connected(ev.Peer)
case network.NotConnected:
case connSubscription := <-connSubs.Out():
ev := connSubscription.(event.EvtPeerConnectednessChanged)
if network.NotConnected == ev.Connectedness {
p.disconnected(ev.Peer)
}
case subscription := <-identifySub.Out():
ev := subscription.(event.EvtPeerIdentificationCompleted)
p.connected(ev.Peer)
case subscription := <-protocolSub.Out():
ev := subscription.(event.EvtPeerProtocolsUpdated)
if slices.Contains(ev.Removed, p.protocolID) {
p.disconnected(ev.Peer)
break
}

if slices.Contains(ev.Added, p.protocolID) {
p.connected(ev.Peer)
}
}
}
Expand All @@ -165,6 +204,9 @@ func (p *peerTracker) getPeers(max int) []libpeer.ID {
}

func (p *peerTracker) connected(pID libpeer.ID) {
if err := pID.Validate(); err != nil {
return
}
if p.host.ID() == pID {
return
}
Expand All @@ -176,11 +218,22 @@ func (p *peerTracker) connected(pID libpeer.ID) {
}
}

// check that peer supports our protocol id.
protocol, err := p.host.Peerstore().SupportsProtocols(pID, p.protocolID)
if err != nil {
return
}
if !slices.Contains(protocol, p.protocolID) {
return
}

p.peerLk.Lock()
defer p.peerLk.Unlock()

// additional check in p.trackedPeers should be done,
// because libp2p does not emit multiple Connected events per 1 peer
if _, ok := p.trackedPeers[pID]; ok {
return
}

stats, ok := p.disconnectedPeers[pID]
if !ok {
stats = &peerStat{peerID: pID, peerScore: defaultScore}
Expand All @@ -193,6 +246,10 @@ func (p *peerTracker) connected(pID libpeer.ID) {
}

func (p *peerTracker) disconnected(pID libpeer.ID) {
if err := pID.Validate(); err != nil {
return
}

p.peerLk.Lock()
defer p.peerLk.Unlock()
stats, ok := p.trackedPeers[pID]
Expand Down Expand Up @@ -295,6 +352,10 @@ func (p *peerTracker) stop(ctx context.Context) error {

// blockPeer blocks a peer on the networking level and removes it from the local cache.
func (p *peerTracker) blockPeer(pID libpeer.ID, reason error) {
if err := pID.Validate(); err != nil {
return
}

// add peer to the blacklist, so we can't connect to it in the future.
err := p.connGater.BlockPeer(pID)
if err != nil {
Expand Down
7 changes: 3 additions & 4 deletions p2p/peer_tracker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ func TestPeerTracker_GC(t *testing.T) {
require.NoError(t, err)

pidstore := newDummyPIDStore()
p := newPeerTracker(h[0], connGater, pidstore, nil)
p := newPeerTracker(h[0], connGater, protocolID("private"), pidstore, nil)

maxAwaitingTime = time.Millisecond

Expand Down Expand Up @@ -68,7 +68,7 @@ func TestPeerTracker_BlockPeer(t *testing.T) {
h := createMocknet(t, 2)
connGater, err := conngater.NewBasicConnectionGater(sync.MutexWrap(datastore.NewMapDatastore()))
require.NoError(t, err)
p := newPeerTracker(h[0], connGater, nil, nil)
p := newPeerTracker(h[0], connGater, protocolID("private"), nil, nil)
maxAwaitingTime = time.Millisecond
p.blockPeer(h[1].ID(), errors.New("test"))
require.Len(t, connGater.ListBlockedPeers(), 1)
Expand All @@ -82,7 +82,6 @@ func TestPeerTracker_Bootstrap(t *testing.T) {
connGater, err := conngater.NewBasicConnectionGater(sync.MutexWrap(datastore.NewMapDatastore()))
require.NoError(t, err)

// mn := createMocknet(t, 10)
mn, err := mocknet.FullMeshConnected(10)
require.NoError(t, err)

Expand All @@ -101,7 +100,7 @@ func TestPeerTracker_Bootstrap(t *testing.T) {
err = pidstore.Put(ctx, prevSeen[2:])
require.NoError(t, err)

tracker := newPeerTracker(mn.Hosts()[0], connGater, pidstore, nil)
tracker := newPeerTracker(mn.Hosts()[0], connGater, protocolID("private"), pidstore, nil)

go tracker.track()

Expand Down

0 comments on commit df866eb

Please sign in to comment.