diff --git a/store/heightsub.go b/store/heightsub.go index 2335001d..80139c74 100644 --- a/store/heightsub.go +++ b/store/heightsub.go @@ -34,8 +34,17 @@ func (hs *heightSub[H]) Height() uint64 { } // SetHeight sets the new head height for heightSub. +// Only the higher height can be set, otherwise no-op. func (hs *heightSub[H]) SetHeight(height uint64) { - hs.height.Store(height) + for { + curr := hs.height.Load() + if curr > height { + return + } + if hs.height.CompareAndSwap(curr, height) { + return + } + } } // Sub subscribes for a header of a given height. @@ -89,12 +98,7 @@ func (hs *heightSub[H]) Pub(headers ...H) { return } - height := hs.Height() from, to := headers[0].Height(), headers[ln-1].Height() - if height+1 != from && height != 0 { // height != 0 is needed to enable init from any height and not only 1 - log.Fatalf("PLEASE FILE A BUG REPORT: headers given to the heightSub are in the wrong order: expected %d, got %d", height+1, from) - return - } hs.SetHeight(to) hs.heightReqsLk.Lock() @@ -114,17 +118,17 @@ func (hs *heightSub[H]) Pub(headers ...H) { return } - // instead of looping over each header in 'headers', we can loop over each request - // which will drastically decrease idle iterations, as there will be less requests than headers - for height, reqs := range hs.heightReqs { - // then we look if any of the requests match the given range of headers - if height >= from && height <= to { - // and if so, calculate its position and fulfill requests - h := headers[height-from] - for req := range reqs { - req <- h // reqs must always be buffered, so this won't block - } - delete(hs.heightReqs, height) + for _, h := range headers { + height := h.Height() + + reqs, ok := hs.heightReqs[height] + if !ok { + continue + } + + for req := range reqs { + req <- h // reqs must always be buffered, so this won't block } + delete(hs.heightReqs, height) } } diff --git a/store/heightsub_test.go b/store/heightsub_test.go index 3a48d950..64ef1804 100644 --- a/store/heightsub_test.go +++ b/store/heightsub_test.go @@ -47,6 +47,37 @@ func TestHeightSub(t *testing.T) { } } +func TestHeightSubNonAdjacement(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + hs := newHeightSub[*headertest.DummyHeader]() + + { + h := headertest.RandDummyHeader(t) + h.HeightI = 100 + hs.SetHeight(99) + hs.Pub(h) + } + + { + go func() { + // fixes flakiness on CI + time.Sleep(time.Millisecond) + + h1 := headertest.RandDummyHeader(t) + h1.HeightI = 200 + h2 := headertest.RandDummyHeader(t) + h2.HeightI = 300 + hs.Pub(h1, h2) + }() + + h, err := hs.Sub(ctx, 200) + assert.NoError(t, err) + assert.NotNil(t, h) + } +} + func TestHeightSubCancellation(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() diff --git a/store/store.go b/store/store.go index 657a22c7..ef908a9a 100644 --- a/store/store.go +++ b/store/store.go @@ -1,9 +1,11 @@ package store import ( + "cmp" "context" "errors" "fmt" + "slices" "sync/atomic" "time" @@ -319,6 +321,10 @@ func (s *Store[H]) Append(ctx context.Context, headers ...H) error { head = *headPtr } + slices.SortFunc(headers, func(a, b H) int { + return cmp.Compare(a.Height(), b.Height()) + }) + // collect valid headers verified := make([]H, 0, lh) for i, h := range headers {