Skip to content

Commit

Permalink
updated writeHead after a real write
Browse files Browse the repository at this point in the history
  • Loading branch information
cristaloleg committed Jul 30, 2024
1 parent 1a0ed39 commit 1d53e1a
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 34 deletions.
53 changes: 23 additions & 30 deletions store/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,9 @@ func (s *Store[H]) Head(ctx context.Context, _ ...header.HeadOption[H]) (H, erro
var zero H
return zero, err
}

s.writeHead.CompareAndSwap(nil, &head)

return head, nil
}

Expand Down Expand Up @@ -326,19 +329,10 @@ func (s *Store[H]) Append(ctx context.Context, headers ...H) error {
return nil
}

var err error
// take current write head to verify headers against
var head H
if headPtr := s.writeHead.Load(); headPtr == nil {
head, err = s.Head(ctx)
if err != nil {
return err
}
// store header from the disk.
gotHead := head
s.writeHead.Store(&gotHead)
} else {
head = *headPtr
head, err := s.Head(ctx)
if err != nil {
return err
}

slices.SortFunc(headers, func(a, b H) int {
Expand Down Expand Up @@ -369,29 +363,21 @@ func (s *Store[H]) Append(ctx context.Context, headers ...H) error {
}
verified = append(verified, h)
head = h
s.addKnownHeader(head)
}

onWrite := func() {
newHead := s.tryAdvanceHead()
log.Infow("new head", "height", newHead.Height(), "hash", newHead.Hash())
s.metrics.newHead(newHead.Height())
}

// queue headers to be written on disk
select {
case s.writes <- verified:
// we return an error here after writing,
// as there might be an invalid header in between of a given range
onWrite()
return err
default:
s.metrics.writesQueueBlocked(ctx)
}

// if the writes queue is full, we block until it is not
select {
case s.writes <- verified:
onWrite()
return err
case <-s.writesDn:
return errStoppedStore
Expand Down Expand Up @@ -430,6 +416,9 @@ func (s *Store[H]) flushLoop() {
s.metrics.flush(ctx, time.Since(startTime), s.pending.Len(), true)
continue
}

s.tryAdvanceHead(toFlush...)

s.metrics.flush(ctx, time.Since(startTime), s.pending.Len(), false)
// reset pending
s.pending.Reset()
Expand Down Expand Up @@ -518,18 +507,21 @@ func (s *Store[H]) get(ctx context.Context, hash header.Hash) ([]byte, error) {
return data, nil
}

func (s *Store[H]) addKnownHeader(h H) {
s.knownHeadersLk.Lock()
s.knownHeaders[h.Height()] = h
s.knownHeadersLk.Unlock()
}

// try advance heighest header if we saw a higher continuous before.
func (s *Store[H]) tryAdvanceHead() H {
func (s *Store[H]) tryAdvanceHead(headers ...H) {
headPtr := s.writeHead.Load()
if headPtr == nil || len(headers) == 0 {
return
}

s.knownHeadersLk.Lock()
defer s.knownHeadersLk.Unlock()

head := *s.writeHead.Load()
for _, h := range headers {
s.knownHeaders[h.Height()] = h
}

head := *headPtr
height := head.Height()
currHead := head

Expand All @@ -549,8 +541,9 @@ func (s *Store[H]) tryAdvanceHead() H {
// we don't need CAS here because that's the only place
// where writeHead is updated, knownHeadersLk ensures 1 goroutine.
s.writeHead.Store(&head)
log.Infow("new head", "height", head.Height(), "hash", head.Hash())
s.metrics.newHead(head.Height())
}
return head
}

// indexTo saves mapping between header Height and Hash to the given batch.
Expand Down
12 changes: 8 additions & 4 deletions store/store_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package store

import (
"bytes"
"context"
"math/rand"
stdsync "sync"
Expand All @@ -22,7 +23,7 @@ func TestStore(t *testing.T) {
suite := headertest.NewTestSuite(t)

ds := sync.MutexWrap(datastore.NewMapDatastore())
store := NewTestStore(t, ctx, ds, suite.Head())
store := NewTestStore(t, ctx, ds, suite.Head(), WithWriteBatchSize(5))

head, err := store.Head(ctx)
require.NoError(t, err)
Expand All @@ -38,9 +39,12 @@ func TestStore(t *testing.T) {
assert.Equal(t, h.Hash(), out[i].Hash())
}

head, err = store.Head(ctx)
require.NoError(t, err)
assert.Equal(t, out[len(out)-1].Hash(), head.Hash())
// we need to wait for a flush
assert.Eventually(t, func() bool {
head, err = store.Head(ctx)
require.NoError(t, err)
return bytes.Equal(out[len(out)-1].Hash(), head.Hash())
}, time.Second, 100*time.Millisecond)

ok, err := store.Has(ctx, in[5].Hash())
require.NoError(t, err)
Expand Down

0 comments on commit 1d53e1a

Please sign in to comment.