diff --git a/store/heightsub.go b/store/heightsub.go index a69f28f6..2335001d 100644 --- a/store/heightsub.go +++ b/store/heightsub.go @@ -18,13 +18,13 @@ type heightSub[H header.Header[H]] struct { // that has been fully verified and inserted into the subjective chain height atomic.Uint64 heightReqsLk sync.Mutex - heightReqs map[uint64][]chan H + heightReqs map[uint64]map[chan H]struct{} } // newHeightSub instantiates new heightSub. func newHeightSub[H header.Header[H]]() *heightSub[H] { return &heightSub[H]{ - heightReqs: make(map[uint64][]chan H), + heightReqs: make(map[uint64]map[chan H]struct{}), } } @@ -56,16 +56,24 @@ func (hs *heightSub[H]) Sub(ctx context.Context, height uint64) (H, error) { return zero, errElapsedHeight } resp := make(chan H, 1) - hs.heightReqs[height] = append(hs.heightReqs[height], resp) + reqs, ok := hs.heightReqs[height] + if !ok { + reqs = make(map[chan H]struct{}) + hs.heightReqs[height] = reqs + } + reqs[resp] = struct{}{} hs.heightReqsLk.Unlock() select { case resp := <-resp: return resp, nil case <-ctx.Done(): - // no need to keep the request, if the op is canceled + // no need to keep the request, if the op has canceled hs.heightReqsLk.Lock() - delete(hs.heightReqs, height) + delete(reqs, resp) + if len(reqs) == 0 { + delete(hs.heightReqs, height) + } hs.heightReqsLk.Unlock() return zero, ctx.Err() } @@ -98,7 +106,7 @@ func (hs *heightSub[H]) Pub(headers ...H) { if ln == 1 { reqs, ok := hs.heightReqs[from] if ok { - for _, req := range reqs { + for req := range reqs { req <- headers[0] // reqs must always be buffered, so this won't block } delete(hs.heightReqs, from) @@ -113,7 +121,7 @@ func (hs *heightSub[H]) Pub(headers ...H) { if height >= from && height <= to { // and if so, calculate its position and fulfill requests h := headers[height-from] - for _, req := range reqs { + for req := range reqs { req <- h // reqs must always be buffered, so this won't block } delete(hs.heightReqs, height) diff --git a/store/heightsub_test.go b/store/heightsub_test.go index 34981037..3a48d950 100644 --- a/store/heightsub_test.go +++ b/store/heightsub_test.go @@ -46,3 +46,40 @@ func TestHeightSub(t *testing.T) { assert.NotNil(t, h) } } + +func TestHeightSubCancellation(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + + h := headertest.RandDummyHeader(t) + hs := newHeightSub[*headertest.DummyHeader]() + + sub := make(chan *headertest.DummyHeader) + go func() { + // subscribe first time + h, _ := hs.Sub(ctx, h.HeightI) + sub <- h + }() + + // give a bit time for subscription to settle + time.Sleep(time.Millisecond * 10) + + // subscribe again but with failed canceled context + canceledCtx, cancel := context.WithCancel(ctx) + cancel() + _, err := hs.Sub(canceledCtx, h.HeightI) + assert.Error(t, err) + + // publish header + hs.Pub(h) + + // ensure we still get our header + select { + case subH := <-sub: + assert.Equal(t, h.HeightI, subH.HeightI) + case <-ctx.Done(): + t.Error(ctx.Err()) + } + // ensure we don't have any active subscriptions + assert.Len(t, hs.heightReqs, 0) +}