diff --git a/p2p/exchange.go b/p2p/exchange.go index c498425e..564f50b2 100644 --- a/p2p/exchange.go +++ b/p2p/exchange.go @@ -5,6 +5,7 @@ import ( "context" "errors" "fmt" + "golang.org/x/sync/singleflight" "math/rand" "sort" "time" @@ -35,6 +36,10 @@ var ( // chosen. const minHeadResponses = 2 +// syncKeyHeadOrigin represents the origin value used specifically for Head requests, +// serving as a sync key to prevent redundant queries. +const syncKeyHeadOrigin = "0" + // maxUntrustedHeadRequests is the number of head requests to be made to // the network in order to determine the network head. var maxUntrustedHeadRequests = 4 @@ -52,6 +57,8 @@ type Exchange[H header.Header[H]] struct { peerTracker *peerTracker metrics *exchangeMetrics + singleFlight *singleflight.Group + Params ClientParameters } @@ -81,11 +88,12 @@ func NewExchange[H header.Header[H]]( } ex := &Exchange[H]{ - host: host, - protocolID: protocolID(params.networkID), - peerTracker: newPeerTracker(host, gater, params.pidstore, metrics), - Params: params, - metrics: metrics, + host: host, + protocolID: protocolID(params.networkID), + peerTracker: newPeerTracker(host, gater, params.pidstore, metrics), + Params: params, + metrics: metrics, + singleFlight: &singleflight.Group{}, } ex.trustedPeers = func() peer.IDSlice { @@ -124,6 +132,19 @@ func (ex *Exchange[H]) Head(ctx context.Context, opts ...header.HeadOption[H]) ( ctx, span := tracerClient.Start(ctx, "head") defer span.End() + head, err, _ := ex.singleFlight.Do(syncKeyHeadOrigin, func() (interface{}, error) { + return ex.head(ctx, span, opts...) + }) + ex.singleFlight.Forget(syncKeyHeadOrigin) + if err != nil { + span.SetStatus(codes.Error, err.Error()) + return head.(H), err + } + span.SetStatus(codes.Ok, "") + return head.(H), err +} + +func (ex *Exchange[H]) head(ctx context.Context, span trace.Span, opts ...header.HeadOption[H]) (H, error) { reqCtx := ctx startTime := time.Now() if deadline, ok := ctx.Deadline(); ok { @@ -244,7 +265,6 @@ func (ex *Exchange[H]) Head(ctx context.Context, opts ...header.HeadOption[H]) ( } ex.metrics.head(ctx, time.Since(startTime), len(headers), headType, headStatusOk) - span.SetStatus(codes.Ok, "") return head, nil } diff --git a/p2p/exchange_test.go b/p2p/exchange_test.go index 1a775bd3..66b99e05 100644 --- a/p2p/exchange_test.go +++ b/p2p/exchange_test.go @@ -3,6 +3,7 @@ package p2p import ( "context" "strconv" + sync2 "sync" "testing" "time" @@ -156,6 +157,75 @@ func TestExchange_RequestHead_UnresponsivePeer(t *testing.T) { assert.NotNil(t, head) } +func TestExchange_RequestHeadFlightProtection(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + + hosts := createMocknet(t, 3) + exchg, trustedStore := createP2PExAndServer(t, hosts[0], hosts[1]) + + // create new server-side exchange that will act as the tracked peer + // it will have a higher chain head than the trusted peer so that the + // test can determine which peer was asked + trackedStore := headertest.NewStore[*headertest.DummyHeader](t, headertest.NewTestSuite(t), 50) + serverSideEx, err := NewExchangeServer[*headertest.DummyHeader](hosts[2], trackedStore, + WithNetworkID[ServerParameters](networkID), + ) + require.NoError(t, err) + err = serverSideEx.Start(ctx) + require.NoError(t, err) + t.Cleanup(func() { + err = serverSideEx.Stop(ctx) + require.NoError(t, err) + }) + // create the same requests + tests := make([]struct { + requestFromTrusted bool + lastHeader *headertest.DummyHeader + expectedHeight uint64 + expectedHash header.Hash + }, 10) + for i := 0; i < 10; i++ { + tests[i] = struct { + requestFromTrusted bool + lastHeader *headertest.DummyHeader + expectedHeight uint64 + expectedHash header.Hash + }{ + requestFromTrusted: true, + lastHeader: trustedStore.Headers[trustedStore.HeadHeight-1], + expectedHeight: trustedStore.HeadHeight, + expectedHash: trustedStore.Headers[trustedStore.HeadHeight].Hash(), + } + } + + var wg sync2.WaitGroup + // run over goroutine + for i, tt := range tests { + wg.Add(1) + go func(testStruct struct { + requestFromTrusted bool + lastHeader *headertest.DummyHeader + expectedHeight uint64 + expectedHash header.Hash + }, it int) { + defer wg.Done() + var opts []header.HeadOption[*headertest.DummyHeader] + if !testStruct.requestFromTrusted { + opts = append(opts, header.WithTrustedHead[*headertest.DummyHeader](testStruct.lastHeader)) + } + + h, errG := exchg.Head(ctx, opts...) + require.NoError(t, errG) + + assert.Equal(t, testStruct.expectedHeight, h.Height()) + assert.Equal(t, testStruct.expectedHash, h.Hash()) + + }(tt, i) + } + wg.Wait() +} + func TestExchange_RequestHeader(t *testing.T) { hosts := createMocknet(t, 2) exchg, store := createP2PExAndServer(t, hosts[0], hosts[1])