From 1d53e1a1ced36fc14e9a373a6fa560b9e5f41bf9 Mon Sep 17 00:00:00 2001 From: Oleg Kovalov Date: Tue, 30 Jul 2024 15:14:04 +0200 Subject: [PATCH] updated writeHead after a real write --- store/store.go | 53 ++++++++++++++++++++------------------------- store/store_test.go | 12 ++++++---- 2 files changed, 31 insertions(+), 34 deletions(-) diff --git a/store/store.go b/store/store.go index 73e0afcf..83d98494 100644 --- a/store/store.go +++ b/store/store.go @@ -201,6 +201,9 @@ func (s *Store[H]) Head(ctx context.Context, _ ...header.HeadOption[H]) (H, erro var zero H return zero, err } + + s.writeHead.CompareAndSwap(nil, &head) + return head, nil } @@ -326,19 +329,10 @@ func (s *Store[H]) Append(ctx context.Context, headers ...H) error { return nil } - var err error // take current write head to verify headers against - var head H - if headPtr := s.writeHead.Load(); headPtr == nil { - head, err = s.Head(ctx) - if err != nil { - return err - } - // store header from the disk. - gotHead := head - s.writeHead.Store(&gotHead) - } else { - head = *headPtr + head, err := s.Head(ctx) + if err != nil { + return err } slices.SortFunc(headers, func(a, b H) int { @@ -369,13 +363,6 @@ func (s *Store[H]) Append(ctx context.Context, headers ...H) error { } verified = append(verified, h) head = h - s.addKnownHeader(head) - } - - onWrite := func() { - newHead := s.tryAdvanceHead() - log.Infow("new head", "height", newHead.Height(), "hash", newHead.Hash()) - s.metrics.newHead(newHead.Height()) } // queue headers to be written on disk @@ -383,15 +370,14 @@ func (s *Store[H]) Append(ctx context.Context, headers ...H) error { case s.writes <- verified: // we return an error here after writing, // as there might be an invalid header in between of a given range - onWrite() return err default: s.metrics.writesQueueBlocked(ctx) } + // if the writes queue is full, we block until it is not select { case s.writes <- verified: - onWrite() return err case <-s.writesDn: return errStoppedStore @@ -430,6 +416,9 @@ func (s *Store[H]) flushLoop() { s.metrics.flush(ctx, time.Since(startTime), s.pending.Len(), true) continue } + + s.tryAdvanceHead(toFlush...) + s.metrics.flush(ctx, time.Since(startTime), s.pending.Len(), false) // reset pending s.pending.Reset() @@ -518,18 +507,21 @@ func (s *Store[H]) get(ctx context.Context, hash header.Hash) ([]byte, error) { return data, nil } -func (s *Store[H]) addKnownHeader(h H) { - s.knownHeadersLk.Lock() - s.knownHeaders[h.Height()] = h - s.knownHeadersLk.Unlock() -} - // try advance heighest header if we saw a higher continuous before. -func (s *Store[H]) tryAdvanceHead() H { +func (s *Store[H]) tryAdvanceHead(headers ...H) { + headPtr := s.writeHead.Load() + if headPtr == nil || len(headers) == 0 { + return + } + s.knownHeadersLk.Lock() defer s.knownHeadersLk.Unlock() - head := *s.writeHead.Load() + for _, h := range headers { + s.knownHeaders[h.Height()] = h + } + + head := *headPtr height := head.Height() currHead := head @@ -549,8 +541,9 @@ func (s *Store[H]) tryAdvanceHead() H { // we don't need CAS here because that's the only place // where writeHead is updated, knownHeadersLk ensures 1 goroutine. s.writeHead.Store(&head) + log.Infow("new head", "height", head.Height(), "hash", head.Hash()) + s.metrics.newHead(head.Height()) } - return head } // indexTo saves mapping between header Height and Hash to the given batch. diff --git a/store/store_test.go b/store/store_test.go index fcaefa6e..a0f2512f 100644 --- a/store/store_test.go +++ b/store/store_test.go @@ -1,6 +1,7 @@ package store import ( + "bytes" "context" "math/rand" stdsync "sync" @@ -22,7 +23,7 @@ func TestStore(t *testing.T) { suite := headertest.NewTestSuite(t) ds := sync.MutexWrap(datastore.NewMapDatastore()) - store := NewTestStore(t, ctx, ds, suite.Head()) + store := NewTestStore(t, ctx, ds, suite.Head(), WithWriteBatchSize(5)) head, err := store.Head(ctx) require.NoError(t, err) @@ -38,9 +39,12 @@ func TestStore(t *testing.T) { assert.Equal(t, h.Hash(), out[i].Hash()) } - head, err = store.Head(ctx) - require.NoError(t, err) - assert.Equal(t, out[len(out)-1].Hash(), head.Hash()) + // we need to wait for a flush + assert.Eventually(t, func() bool { + head, err = store.Head(ctx) + require.NoError(t, err) + return bytes.Equal(out[len(out)-1].Hash(), head.Hash()) + }, time.Second, 100*time.Millisecond) ok, err := store.Has(ctx, in[5].Hash()) require.NoError(t, err)