diff --git a/p2p/exchange.go b/p2p/exchange.go index c498425e..184c2732 100644 --- a/p2p/exchange.go +++ b/p2p/exchange.go @@ -5,6 +5,7 @@ import ( "context" "errors" "fmt" + "golang.org/x/sync/singleflight" "math/rand" "sort" "time" @@ -35,6 +36,10 @@ var ( // chosen. const minHeadResponses = 2 +// syncKeyHeadOrigin represents the origin value used specifically for Head requests, +// serving as a sync key to prevent redundant queries. +const syncKeyHeadOrigin = "0" + // maxUntrustedHeadRequests is the number of head requests to be made to // the network in order to determine the network head. var maxUntrustedHeadRequests = 4 @@ -52,6 +57,8 @@ type Exchange[H header.Header[H]] struct { peerTracker *peerTracker metrics *exchangeMetrics + singleFlight *singleflight.Group + Params ClientParameters } @@ -81,11 +88,12 @@ func NewExchange[H header.Header[H]]( } ex := &Exchange[H]{ - host: host, - protocolID: protocolID(params.networkID), - peerTracker: newPeerTracker(host, gater, params.pidstore, metrics), - Params: params, - metrics: metrics, + host: host, + protocolID: protocolID(params.networkID), + peerTracker: newPeerTracker(host, gater, params.pidstore, metrics), + Params: params, + metrics: metrics, + singleFlight: &singleflight.Group{}, } ex.trustedPeers = func() peer.IDSlice { @@ -124,6 +132,19 @@ func (ex *Exchange[H]) Head(ctx context.Context, opts ...header.HeadOption[H]) ( ctx, span := tracerClient.Start(ctx, "head") defer span.End() + head, err, _ := ex.singleFlight.Do(syncKeyHeadOrigin, func() (interface{}, error) { + return ex.head(ctx, span, opts...) + }) + ex.singleFlight.Forget(syncKeyHeadOrigin) + if err != nil { + span.SetStatus(codes.Error, err.Error()) + return head.(H), err + } + span.SetStatus(codes.Ok, "") + return head.(H), nil +} + +func (ex *Exchange[H]) head(ctx context.Context, span trace.Span, opts ...header.HeadOption[H]) (H, error) { reqCtx := ctx startTime := time.Now() if deadline, ok := ctx.Deadline(); ok { @@ -244,7 +265,6 @@ func (ex *Exchange[H]) Head(ctx context.Context, opts ...header.HeadOption[H]) ( } ex.metrics.head(ctx, time.Since(startTime), len(headers), headType, headStatusOk) - span.SetStatus(codes.Ok, "") return head, nil } diff --git a/p2p/exchange_test.go b/p2p/exchange_test.go index 1a775bd3..42e1ae94 100644 --- a/p2p/exchange_test.go +++ b/p2p/exchange_test.go @@ -3,6 +3,7 @@ package p2p import ( "context" "strconv" + sync2 "sync" "testing" "time" @@ -156,6 +157,64 @@ func TestExchange_RequestHead_UnresponsivePeer(t *testing.T) { assert.NotNil(t, head) } +func TestExchange_RequestHeadFlightProtection(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + + hosts := createMocknet(t, 3) + exchg, trustedStore := createP2PExAndServer(t, hosts[0], hosts[1]) + + // create the same requests + tests := []struct { + requestFromTrusted bool + lastHeader *headertest.DummyHeader + expectedHeight uint64 + expectedHash header.Hash + }{ + { + requestFromTrusted: true, + lastHeader: trustedStore.Headers[trustedStore.HeadHeight-1], + expectedHeight: trustedStore.HeadHeight, + expectedHash: trustedStore.Headers[trustedStore.HeadHeight].Hash(), + }, + { + // request from untrusted peer should be the same as trusted bc of single-preflight + requestFromTrusted: false, + lastHeader: trustedStore.Headers[trustedStore.HeadHeight-1], + expectedHeight: trustedStore.HeadHeight, + expectedHash: trustedStore.Headers[trustedStore.HeadHeight].Hash(), + }, + } + + var wg sync2.WaitGroup + // run over goroutine + for i, tt := range tests { + wg.Add(1) + go func(testStruct struct { + requestFromTrusted bool + lastHeader *headertest.DummyHeader + expectedHeight uint64 + expectedHash header.Hash + }, it int) { + defer wg.Done() + var opts []header.HeadOption[*headertest.DummyHeader] + if !testStruct.requestFromTrusted { + opts = append(opts, header.WithTrustedHead[*headertest.DummyHeader](testStruct.lastHeader)) + } + + h, errG := exchg.Head(ctx, opts...) + require.NoError(t, errG) + + assert.Equal(t, testStruct.expectedHeight, h.Height()) + assert.Equal(t, testStruct.expectedHash, h.Hash()) + + }(tt, i) + // ensure first Head will be locked by request from trusted peer + time.Sleep(time.Microsecond) + } + wg.Wait() +} + func TestExchange_RequestHeader(t *testing.T) { hosts := createMocknet(t, 2) exchg, store := createP2PExAndServer(t, hosts[0], hosts[1]) diff --git a/sync/sync.go b/sync/sync.go index 05e1a781..1d86d8e0 100644 --- a/sync/sync.go +++ b/sync/sync.go @@ -33,7 +33,7 @@ var log = logging.Logger("header/sync") type Syncer[H header.Header[H]] struct { sub header.Subscriber[H] // to subscribe for new Network Heads store syncStore[H] // to store all the headers to - getter syncGetter[H] // to fetch headers from + getter header.Getter[H] // to fetch headers from metrics *metrics // stateLk protects state which represents the current or latest sync @@ -80,7 +80,7 @@ func NewSyncer[H header.Header[H]]( return &Syncer[H]{ sub: sub, store: syncStore[H]{Store: store}, - getter: syncGetter[H]{Getter: getter}, + getter: getter, metrics: metrics, triggerSync: make(chan struct{}, 1), // should be buffered Params: ¶ms, diff --git a/sync/sync_getter.go b/sync/sync_getter.go deleted file mode 100644 index 267240c5..00000000 --- a/sync/sync_getter.go +++ /dev/null @@ -1,52 +0,0 @@ -package sync - -import ( - "context" - "sync" - "sync/atomic" - - "github.com/celestiaorg/go-header" -) - -// syncGetter is a Getter wrapper that ensure only one Head call happens at the time -type syncGetter[H header.Header[H]] struct { - getterLk sync.RWMutex - isGetterLk atomic.Bool - header.Getter[H] -} - -// Lock locks the getter for single user. -// Reports 'true' if the lock was held by the current routine. -// Does not require unlocking on 'false'. -func (sg *syncGetter[H]) Lock() bool { - // the lock construction here ensures only one routine is freed at a time - // while others wait via Rlock - acquiredLock := sg.getterLk.TryLock() - if !acquiredLock { - sg.getterLk.RLock() - defer sg.getterLk.RUnlock() - return false - } - sg.isGetterLk.Store(acquiredLock) - return acquiredLock -} - -// Unlock unlocks the getter. -func (sg *syncGetter[H]) Unlock() { - sg.checkLock("Unlock without preceding Lock on syncGetter") - sg.getterLk.Unlock() - sg.isGetterLk.Store(false) -} - -// Head must be called with held Lock. -func (sg *syncGetter[H]) Head(ctx context.Context, opts ...header.HeadOption[H]) (H, error) { - sg.checkLock("Head without preceding Lock on syncGetter") - return sg.Getter.Head(ctx, opts...) -} - -// checkLock ensures api safety -func (sg *syncGetter[H]) checkLock(msg string) { - if !sg.isGetterLk.Load() { - panic(msg) - } -} diff --git a/sync/sync_getter_test.go b/sync/sync_getter_test.go deleted file mode 100644 index 47d228b1..00000000 --- a/sync/sync_getter_test.go +++ /dev/null @@ -1,72 +0,0 @@ -package sync - -import ( - "context" - "errors" - "sync" - "sync/atomic" - "testing" - "time" - - "github.com/stretchr/testify/assert" - - "github.com/celestiaorg/go-header" - "github.com/celestiaorg/go-header/headertest" -) - -func TestSyncGetterHead(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - t.Cleanup(cancel) - - fex := &fakeGetter[*headertest.DummyHeader]{} - sex := &syncGetter[*headertest.DummyHeader]{Getter: fex} - - var wg sync.WaitGroup - for i := 0; i < 100; i++ { - wg.Add(1) - go func() { - defer wg.Done() - if !sex.Lock() { - return - } - defer sex.Unlock() - h, err := sex.Head(ctx) - if h != nil || err != errFakeHead { - t.Fail() - } - }() - } - wg.Wait() - - assert.EqualValues(t, 1, fex.hits.Load()) -} - -var errFakeHead = errors.New("head") - -type fakeGetter[H header.Header[H]] struct { - hits atomic.Uint32 -} - -func (f *fakeGetter[H]) Head(ctx context.Context, _ ...header.HeadOption[H]) (h H, err error) { - f.hits.Add(1) - select { - case <-time.After(time.Millisecond * 100): - err = errFakeHead - case <-ctx.Done(): - err = ctx.Err() - } - - return -} - -func (f *fakeGetter[H]) Get(ctx context.Context, hash header.Hash) (H, error) { - panic("implement me") -} - -func (f *fakeGetter[H]) GetByHeight(ctx context.Context, u uint64) (H, error) { - panic("implement me") -} - -func (f *fakeGetter[H]) GetRangeByHeight(ctx context.Context, from H, to uint64) ([]H, error) { - panic("implement me") -} diff --git a/sync/sync_head.go b/sync/sync_head.go index c74347b7..3c121607 100644 --- a/sync/sync_head.go +++ b/sync/sync_head.go @@ -29,14 +29,6 @@ func (s *Syncer[H]) Head(ctx context.Context, _ ...header.HeadOption[H]) (H, err return sbjHead, nil } - // single-flight protection ensure only one Head is requested at the time - if !s.getter.Lock() { - // means that other routine held the lock and set the subjective head for us, - // so just recursively get it - return s.Head(ctx) - } - defer s.getter.Unlock() - s.metrics.outdatedHead(s.ctx) reqCtx, cancel := context.WithTimeout(ctx, headRequestTimeout) @@ -80,14 +72,6 @@ func (s *Syncer[H]) subjectiveHead(ctx context.Context) (H, error) { } // otherwise, request head from a trusted peer log.Infow("stored head header expired", "height", storeHead.Height()) - // single-flight protection - // ensure only one Head is requested at the time - if !s.getter.Lock() { - // means that other routine held the lock and set the subjective head for us, - // so just recursively get it - return s.subjectiveHead(ctx) - } - defer s.getter.Unlock() trustHead, err := s.getter.Head(ctx) if err != nil {