diff --git a/store/store.go b/store/store.go index 526b195f..9afbfe78 100644 --- a/store/store.go +++ b/store/store.go @@ -6,7 +6,6 @@ import ( "errors" "fmt" "slices" - "sync" "sync/atomic" "time" @@ -55,11 +54,6 @@ type Store[H header.Header[H]] struct { // writeHead maintains the current write head writeHead atomic.Pointer[H] - knownHeadersLk sync.Mutex - // knownHeaders tracks all processed headers - // to advance writeHead only over continuous headers. - knownHeaders map[uint64]H - // pending keeps headers pending to be written in one batch pending *batch[H] @@ -117,8 +111,6 @@ func newStore[H header.Header[H]](ds datastore.Batching, opts ...Option) (*Store writesDn: make(chan struct{}), pending: newBatch[H](params.WriteBatchSize), Params: params, - - knownHeaders: make(map[uint64]H), }, nil } @@ -423,7 +415,7 @@ func (s *Store[H]) flushLoop() { time.Sleep(sleep) } - s.tryAdvanceHead(toFlush...) + s.tryAdvanceHead(ctx, toFlush...) s.metrics.flush(ctx, time.Since(startTime), s.pending.Len(), false) // reset pending @@ -513,43 +505,35 @@ func (s *Store[H]) get(ctx context.Context, hash header.Hash) ([]byte, error) { return data, nil } -// try advance heighest header if we saw a higher continuous before. -func (s *Store[H]) tryAdvanceHead(headers ...H) { - headPtr := s.writeHead.Load() - if headPtr == nil || len(headers) == 0 { +// try advance heighest writeHead based on passed or already written headers. +func (s *Store[H]) tryAdvanceHead(ctx context.Context, headers ...H) { + writeHead := s.writeHead.Load() + if writeHead == nil || len(headers) == 0 { return } - s.knownHeadersLk.Lock() - defer s.knownHeadersLk.Unlock() + currHeight := (*writeHead).Height() - for _, h := range headers { - s.knownHeaders[h.Height()] = h + // advance based on passed headers. + for i := 0; i < len(headers); i++ { + if headers[i].Height() != currHeight+1 { + break + } + s.writeHead.Store(&headers[i]) + currHeight++ } - currHead := *headPtr - height := currHead.Height() - newHead := currHead + ctx, cancel := context.WithTimeout(ctx, time.Second) + defer cancel() - // try to move to the next height. - for len(s.knownHeaders) > 0 { - h, ok := s.knownHeaders[height+1] - if !ok { + // advance based on already written headers. + for { + newHead, err := s.GetByHeight(ctx, currHeight+1) + if err != nil { break } - newHead = h - delete(s.knownHeaders, height+1) - height++ - } - - // we found higher continuous header - update. - if currHead.Height() < newHead.Height() { - // we don't need CAS here because that's the only place - // where writeHead is updated, knownHeadersLk ensures 1 goroutine. - // NOTE: Store[H].Head also updates writeHead but only once when it's nil. s.writeHead.Store(&newHead) - log.Infow("new head", "height", newHead.Height(), "hash", newHead.Hash()) - s.metrics.newHead(newHead.Height()) + currHeight++ } } diff --git a/store/store_test.go b/store/store_test.go index a0f2512f..ef9d0172 100644 --- a/store/store_test.go +++ b/store/store_test.go @@ -248,7 +248,7 @@ func TestStore_Append_stableHeadWhenGaps(t *testing.T) { err := store.Append(ctx, missedChunk...) require.NoError(t, err) // wait for batch to be written. - time.Sleep(100 * time.Millisecond) + time.Sleep(time.Second) // after appending missing headers we're on the latest header. head, err := store.Head(ctx)