Skip to content

Commit

Permalink
Darren/feat/add subscription cache (#866)
Browse files Browse the repository at this point in the history
* ehancement: create a cache for block based subscriptions

* minor: change function names for subscriptions

* test: add unit test for message cache

* chore: add license headers

* refactor: fix up error handling

* fix: remove bad test

* fix: PR comments

* fix: PR comments - remove block cache

* refactor(subscriptions): store structs in cache, not bytes

* fix(license): add license header

* chore(subscriptions): revert unit test changes

* enhancement: resolve pr comments to use simplelru

* enhancement: resolve pr comments - use id as key
  • Loading branch information
darrenvechain authored Nov 5, 2024
1 parent ad5bfbd commit 9d5b515
Show file tree
Hide file tree
Showing 9 changed files with 202 additions and 44 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/on-pull-request.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ name: Pull Request CI
on:
pull_request:
branches:
- master
- '*'

concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number }}
Expand Down
40 changes: 27 additions & 13 deletions api/subscriptions/beat2_reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,14 @@ import (
type beat2Reader struct {
repo *chain.Repository
blockReader chain.BlockReader
cache *messageCache[Beat2Message]
}

func newBeat2Reader(repo *chain.Repository, position thor.Bytes32) *beat2Reader {
func newBeat2Reader(repo *chain.Repository, position thor.Bytes32, cache *messageCache[Beat2Message]) *beat2Reader {
return &beat2Reader{
repo: repo,
blockReader: repo.NewBlockReader(position),
cache: cache,
}
}

Expand All @@ -33,21 +35,32 @@ func (br *beat2Reader) Read() ([]interface{}, bool, error) {
}
var msgs []interface{}

bloomGenerator := &bloom.Generator{}

bloomAdd := func(key []byte) {
key = bytes.TrimLeft(key, "\x00")
// exclude non-address key
if len(key) <= thor.AddressLength {
bloomGenerator.Add(key)
for _, block := range blocks {
msg, _, err := br.cache.GetOrAdd(block.Header().ID(), br.generateBeat2Message(block))
if err != nil {
return nil, false, err
}
msgs = append(msgs, msg)
}
return msgs, len(blocks) > 0, nil
}

func (br *beat2Reader) generateBeat2Message(block *chain.ExtendedBlock) func() (Beat2Message, error) {
return func() (Beat2Message, error) {
bloomGenerator := &bloom.Generator{}

bloomAdd := func(key []byte) {
key = bytes.TrimLeft(key, "\x00")
// exclude non-address key
if len(key) <= thor.AddressLength {
bloomGenerator.Add(key)
}
}

for _, block := range blocks {
header := block.Header()
receipts, err := br.repo.GetBlockReceipts(header.ID())
if err != nil {
return nil, false, err
return Beat2Message{}, err
}
txs := block.Transactions()
for i, receipt := range receipts {
Expand All @@ -74,7 +87,7 @@ func (br *beat2Reader) Read() ([]interface{}, bool, error) {
const bitsPerKey = 20
filter := bloomGenerator.Generate(bitsPerKey, bloom.K(bitsPerKey))

msgs = append(msgs, &Beat2Message{
beat2 := Beat2Message{
Number: header.Number(),
ID: header.ID(),
ParentID: header.ParentID(),
Expand All @@ -84,7 +97,8 @@ func (br *beat2Reader) Read() ([]interface{}, bool, error) {
Bloom: hexutil.Encode(filter.Bits),
K: filter.K,
Obsolete: block.Obsolete,
})
}

return beat2, nil
}
return msgs, len(blocks) > 0, nil
}
10 changes: 6 additions & 4 deletions api/subscriptions/beat2_reader_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,20 +19,22 @@ func TestBeat2Reader_Read(t *testing.T) {
newBlock := generatedBlocks[1]

// Act
beatReader := newBeat2Reader(repo, genesisBlk.Header().ID())
beatReader := newBeat2Reader(repo, genesisBlk.Header().ID(), newMessageCache[Beat2Message](10))
res, ok, err := beatReader.Read()

// Assert
assert.NoError(t, err)
assert.True(t, ok)
if beatMsg, ok := res[0].(*Beat2Message); !ok {
if beatMsg, ok := res[0].(Beat2Message); !ok {
t.Fatal("unexpected type")
} else {
assert.Equal(t, newBlock.Header().Number(), beatMsg.Number)
assert.Equal(t, newBlock.Header().ID(), beatMsg.ID)
assert.Equal(t, newBlock.Header().ParentID(), beatMsg.ParentID)
assert.Equal(t, newBlock.Header().Timestamp(), beatMsg.Timestamp)
assert.Equal(t, uint32(newBlock.Header().TxsFeatures()), beatMsg.TxsFeatures)
// GasLimit is not part of the deprecated BeatMessage
assert.Equal(t, newBlock.Header().GasLimit(), beatMsg.GasLimit)
}
}

Expand All @@ -42,7 +44,7 @@ func TestBeat2Reader_Read_NoNewBlocksToRead(t *testing.T) {
newBlock := generatedBlocks[1]

// Act
beatReader := newBeat2Reader(repo, newBlock.Header().ID())
beatReader := newBeat2Reader(repo, newBlock.Header().ID(), newMessageCache[Beat2Message](10))
res, ok, err := beatReader.Read()

// Assert
Expand All @@ -56,7 +58,7 @@ func TestBeat2Reader_Read_ErrorWhenReadingBlocks(t *testing.T) {
repo, _, _ := initChain(t)

// Act
beatReader := newBeat2Reader(repo, thor.MustParseBytes32("0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"))
beatReader := newBeat2Reader(repo, thor.MustParseBytes32("0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"), newMessageCache[Beat2Message](10))
res, ok, err := beatReader.Read()

// Assert
Expand Down
46 changes: 30 additions & 16 deletions api/subscriptions/beat_reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,14 @@ import (
type beatReader struct {
repo *chain.Repository
blockReader chain.BlockReader
cache *messageCache[BeatMessage]
}

func newBeatReader(repo *chain.Repository, position thor.Bytes32) *beatReader {
func newBeatReader(repo *chain.Repository, position thor.Bytes32, cache *messageCache[BeatMessage]) *beatReader {
return &beatReader{
repo: repo,
blockReader: repo.NewBlockReader(position),
cache: cache,
}
}

Expand All @@ -33,40 +35,51 @@ func (br *beatReader) Read() ([]interface{}, bool, error) {
}
var msgs []interface{}
for _, block := range blocks {
msg, _, err := br.cache.GetOrAdd(block.Header().ID(), br.generateBeatMessage(block))
if err != nil {
return nil, false, err
}
msgs = append(msgs, msg)
}
return msgs, len(blocks) > 0, nil
}

func (br *beatReader) generateBeatMessage(block *chain.ExtendedBlock) func() (BeatMessage, error) {
return func() (BeatMessage, error) {
header := block.Header()
receipts, err := br.repo.GetBlockReceipts(header.ID())
if err != nil {
return nil, false, err
return BeatMessage{}, err
}
txs := block.Transactions()
bloomContent := &bloomContent{}
content := &bloomContent{}
for i, receipt := range receipts {
bloomContent.add(receipt.GasPayer.Bytes())
content.add(receipt.GasPayer.Bytes())
for _, output := range receipt.Outputs {
for _, event := range output.Events {
bloomContent.add(event.Address.Bytes())
content.add(event.Address.Bytes())
for _, topic := range event.Topics {
bloomContent.add(topic.Bytes())
content.add(topic.Bytes())
}
}
for _, transfer := range output.Transfers {
bloomContent.add(transfer.Sender.Bytes())
bloomContent.add(transfer.Recipient.Bytes())
content.add(transfer.Sender.Bytes())
content.add(transfer.Recipient.Bytes())
}
}
origin, _ := txs[i].Origin()
bloomContent.add(origin.Bytes())
content.add(origin.Bytes())
}
signer, _ := header.Signer()
bloomContent.add(signer.Bytes())
bloomContent.add(header.Beneficiary().Bytes())
content.add(signer.Bytes())
content.add(header.Beneficiary().Bytes())

k := bloom.LegacyEstimateBloomK(bloomContent.len())
k := bloom.LegacyEstimateBloomK(content.len())
bloom := bloom.NewLegacyBloom(k)
for _, item := range bloomContent.items {
for _, item := range content.items {
bloom.Add(item)
}
msgs = append(msgs, &BeatMessage{
beat := BeatMessage{
Number: header.Number(),
ID: header.ID(),
ParentID: header.ParentID(),
Expand All @@ -75,9 +88,10 @@ func (br *beatReader) Read() ([]interface{}, bool, error) {
Bloom: hexutil.Encode(bloom.Bits[:]),
K: uint32(k),
Obsolete: block.Obsolete,
})
}

return beat, nil
}
return msgs, len(blocks) > 0, nil
}

type bloomContent struct {
Expand Down
8 changes: 4 additions & 4 deletions api/subscriptions/beat_reader_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,13 @@ func TestBeatReader_Read(t *testing.T) {
newBlock := generatedBlocks[1]

// Act
beatReader := newBeatReader(repo, genesisBlk.Header().ID())
beatReader := newBeatReader(repo, genesisBlk.Header().ID(), newMessageCache[BeatMessage](10))
res, ok, err := beatReader.Read()

// Assert
assert.NoError(t, err)
assert.True(t, ok)
if beatMsg, ok := res[0].(*BeatMessage); !ok {
if beatMsg, ok := res[0].(BeatMessage); !ok {
t.Fatal("unexpected type")
} else {
assert.Equal(t, newBlock.Header().Number(), beatMsg.Number)
Expand All @@ -42,7 +42,7 @@ func TestBeatReader_Read_NoNewBlocksToRead(t *testing.T) {
newBlock := generatedBlocks[1]

// Act
beatReader := newBeatReader(repo, newBlock.Header().ID())
beatReader := newBeatReader(repo, newBlock.Header().ID(), newMessageCache[BeatMessage](10))
res, ok, err := beatReader.Read()

// Assert
Expand All @@ -56,7 +56,7 @@ func TestBeatReader_Read_ErrorWhenReadingBlocks(t *testing.T) {
repo, _, _ := initChain(t)

// Act
beatReader := newBeatReader(repo, thor.MustParseBytes32("0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"))
beatReader := newBeatReader(repo, thor.MustParseBytes32("0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"), newMessageCache[BeatMessage](10))
res, ok, err := beatReader.Read()

// Assert
Expand Down
2 changes: 0 additions & 2 deletions api/subscriptions/block_reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,11 @@ import (
)

type blockReader struct {
repo *chain.Repository
blockReader chain.BlockReader
}

func newBlockReader(repo *chain.Repository, position thor.Bytes32) *blockReader {
return &blockReader{
repo: repo,
blockReader: repo.NewBlockReader(position),
}
}
Expand Down
65 changes: 65 additions & 0 deletions api/subscriptions/message_cache.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
// Copyright (c) 2024 The VeChainThor developers
//
// Distributed under the GNU Lesser General Public License v3.0 software license, see the accompanying
// file LICENSE or <https://www.gnu.org/licenses/lgpl-3.0.html>

package subscriptions

import (
"fmt"
"sync"

"github.com/hashicorp/golang-lru/simplelru"
"github.com/vechain/thor/v2/thor"
)

// messageCache is a generic cache that stores messages of any type.
type messageCache[T any] struct {
cache *simplelru.LRU
mu sync.RWMutex
}

// newMessageCache creates a new messageCache with the specified cache size.
func newMessageCache[T any](cacheSize uint32) *messageCache[T] {
if cacheSize > 1000 {
cacheSize = 1000
}
if cacheSize == 0 {
cacheSize = 1
}
cache, err := simplelru.NewLRU(int(cacheSize), nil)
if err != nil {
// lru.New only throws an error if the number is less than 1
panic(fmt.Errorf("failed to create message cache: %v", err))
}
return &messageCache[T]{
cache: cache,
}
}

// GetOrAdd returns the message of the block. If the message is not in the cache,
// it will generate the message and add it to the cache. The second return value
// indicates whether the message is newly generated.
func (mc *messageCache[T]) GetOrAdd(id thor.Bytes32, createMessage func() (T, error)) (T, bool, error) {
mc.mu.RLock()
msg, ok := mc.cache.Get(id)
mc.mu.RUnlock()
if ok {
return msg.(T), false, nil
}

mc.mu.Lock()
defer mc.mu.Unlock()
msg, ok = mc.cache.Get(id)
if ok {
return msg.(T), false, nil
}

newMsg, err := createMessage()
if err != nil {
var zero T
return zero, false, err
}
mc.cache.Add(id, newMsg)
return newMsg, true, nil
}
61 changes: 61 additions & 0 deletions api/subscriptions/message_cache_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
// Copyright (c) 2024 The VeChainThor developers
//
// Distributed under the GNU Lesser General Public License v3.0 software license, see the accompanying
// file LICENSE or <https://www.gnu.org/licenses/lgpl-3.0.html>

package subscriptions

import (
"sync"
"sync/atomic"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/vechain/thor/v2/block"
)

type message struct {
id string
}

func handler(blk *block.Block) func() (message, error) {
return func() (message, error) {
msg := message{
id: blk.Header().ID().String(),
}
return msg, nil
}
}

func TestMessageCache_GetOrAdd(t *testing.T) {
_, generatedBlocks, _ := initChain(t)

blk0 := generatedBlocks[0]
blk1 := generatedBlocks[1]

cache := newMessageCache[message](10)

counter := atomic.Int32{}
wg := sync.WaitGroup{}
for i := 0; i < 100; i++ {
wg.Add(1)
start := time.Now().Add(20 * time.Millisecond)
go func() {
defer wg.Done()
time.Sleep(time.Until(start))
_, added, err := cache.GetOrAdd(blk0.Header().ID(), handler(blk0))
assert.NoError(t, err)
if added {
counter.Add(1)
}
}()
}
wg.Wait()
assert.Equal(t, counter.Load(), int32(1))

_, added, err := cache.GetOrAdd(blk1.Header().ID(), handler(blk1))
assert.NoError(t, err)
assert.True(t, added)
assert.Equal(t, cache.cache.Len(), 2)
}
Loading

0 comments on commit 9d5b515

Please sign in to comment.