diff --git a/.github/workflows/on-pull-request.yaml b/.github/workflows/on-pull-request.yaml index 57cee70ee..6c6300109 100644 --- a/.github/workflows/on-pull-request.yaml +++ b/.github/workflows/on-pull-request.yaml @@ -3,7 +3,7 @@ name: Pull Request CI on: pull_request: branches: - - master + - '*' concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number }} diff --git a/api/subscriptions/beat2_reader.go b/api/subscriptions/beat2_reader.go index 9d1940538..267ba9d00 100644 --- a/api/subscriptions/beat2_reader.go +++ b/api/subscriptions/beat2_reader.go @@ -17,12 +17,14 @@ import ( type beat2Reader struct { repo *chain.Repository blockReader chain.BlockReader + cache *messageCache[Beat2Message] } -func newBeat2Reader(repo *chain.Repository, position thor.Bytes32) *beat2Reader { +func newBeat2Reader(repo *chain.Repository, position thor.Bytes32, cache *messageCache[Beat2Message]) *beat2Reader { return &beat2Reader{ repo: repo, blockReader: repo.NewBlockReader(position), + cache: cache, } } @@ -33,21 +35,32 @@ func (br *beat2Reader) Read() ([]interface{}, bool, error) { } var msgs []interface{} - bloomGenerator := &bloom.Generator{} - - bloomAdd := func(key []byte) { - key = bytes.TrimLeft(key, "\x00") - // exclude non-address key - if len(key) <= thor.AddressLength { - bloomGenerator.Add(key) + for _, block := range blocks { + msg, _, err := br.cache.GetOrAdd(block.Header().ID(), br.generateBeat2Message(block)) + if err != nil { + return nil, false, err } + msgs = append(msgs, msg) } + return msgs, len(blocks) > 0, nil +} + +func (br *beat2Reader) generateBeat2Message(block *chain.ExtendedBlock) func() (Beat2Message, error) { + return func() (Beat2Message, error) { + bloomGenerator := &bloom.Generator{} + + bloomAdd := func(key []byte) { + key = bytes.TrimLeft(key, "\x00") + // exclude non-address key + if len(key) <= thor.AddressLength { + bloomGenerator.Add(key) + } + } - for _, block := range blocks { header := block.Header() receipts, err := br.repo.GetBlockReceipts(header.ID()) if err != nil { - return nil, false, err + return Beat2Message{}, err } txs := block.Transactions() for i, receipt := range receipts { @@ -74,7 +87,7 @@ func (br *beat2Reader) Read() ([]interface{}, bool, error) { const bitsPerKey = 20 filter := bloomGenerator.Generate(bitsPerKey, bloom.K(bitsPerKey)) - msgs = append(msgs, &Beat2Message{ + beat2 := Beat2Message{ Number: header.Number(), ID: header.ID(), ParentID: header.ParentID(), @@ -84,7 +97,8 @@ func (br *beat2Reader) Read() ([]interface{}, bool, error) { Bloom: hexutil.Encode(filter.Bits), K: filter.K, Obsolete: block.Obsolete, - }) + } + + return beat2, nil } - return msgs, len(blocks) > 0, nil } diff --git a/api/subscriptions/beat2_reader_test.go b/api/subscriptions/beat2_reader_test.go index dbeecbd7f..4a292f485 100644 --- a/api/subscriptions/beat2_reader_test.go +++ b/api/subscriptions/beat2_reader_test.go @@ -19,13 +19,13 @@ func TestBeat2Reader_Read(t *testing.T) { newBlock := generatedBlocks[1] // Act - beatReader := newBeat2Reader(repo, genesisBlk.Header().ID()) + beatReader := newBeat2Reader(repo, genesisBlk.Header().ID(), newMessageCache[Beat2Message](10)) res, ok, err := beatReader.Read() // Assert assert.NoError(t, err) assert.True(t, ok) - if beatMsg, ok := res[0].(*Beat2Message); !ok { + if beatMsg, ok := res[0].(Beat2Message); !ok { t.Fatal("unexpected type") } else { assert.Equal(t, newBlock.Header().Number(), beatMsg.Number) @@ -33,6 +33,8 @@ func TestBeat2Reader_Read(t *testing.T) { assert.Equal(t, newBlock.Header().ParentID(), beatMsg.ParentID) assert.Equal(t, newBlock.Header().Timestamp(), beatMsg.Timestamp) assert.Equal(t, uint32(newBlock.Header().TxsFeatures()), beatMsg.TxsFeatures) + // GasLimit is not part of the deprecated BeatMessage + assert.Equal(t, newBlock.Header().GasLimit(), beatMsg.GasLimit) } } @@ -42,7 +44,7 @@ func TestBeat2Reader_Read_NoNewBlocksToRead(t *testing.T) { newBlock := generatedBlocks[1] // Act - beatReader := newBeat2Reader(repo, newBlock.Header().ID()) + beatReader := newBeat2Reader(repo, newBlock.Header().ID(), newMessageCache[Beat2Message](10)) res, ok, err := beatReader.Read() // Assert @@ -56,7 +58,7 @@ func TestBeat2Reader_Read_ErrorWhenReadingBlocks(t *testing.T) { repo, _, _ := initChain(t) // Act - beatReader := newBeat2Reader(repo, thor.MustParseBytes32("0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff")) + beatReader := newBeat2Reader(repo, thor.MustParseBytes32("0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"), newMessageCache[Beat2Message](10)) res, ok, err := beatReader.Read() // Assert diff --git a/api/subscriptions/beat_reader.go b/api/subscriptions/beat_reader.go index ed4e77147..c315f1dc2 100644 --- a/api/subscriptions/beat_reader.go +++ b/api/subscriptions/beat_reader.go @@ -17,12 +17,14 @@ import ( type beatReader struct { repo *chain.Repository blockReader chain.BlockReader + cache *messageCache[BeatMessage] } -func newBeatReader(repo *chain.Repository, position thor.Bytes32) *beatReader { +func newBeatReader(repo *chain.Repository, position thor.Bytes32, cache *messageCache[BeatMessage]) *beatReader { return &beatReader{ repo: repo, blockReader: repo.NewBlockReader(position), + cache: cache, } } @@ -33,40 +35,51 @@ func (br *beatReader) Read() ([]interface{}, bool, error) { } var msgs []interface{} for _, block := range blocks { + msg, _, err := br.cache.GetOrAdd(block.Header().ID(), br.generateBeatMessage(block)) + if err != nil { + return nil, false, err + } + msgs = append(msgs, msg) + } + return msgs, len(blocks) > 0, nil +} + +func (br *beatReader) generateBeatMessage(block *chain.ExtendedBlock) func() (BeatMessage, error) { + return func() (BeatMessage, error) { header := block.Header() receipts, err := br.repo.GetBlockReceipts(header.ID()) if err != nil { - return nil, false, err + return BeatMessage{}, err } txs := block.Transactions() - bloomContent := &bloomContent{} + content := &bloomContent{} for i, receipt := range receipts { - bloomContent.add(receipt.GasPayer.Bytes()) + content.add(receipt.GasPayer.Bytes()) for _, output := range receipt.Outputs { for _, event := range output.Events { - bloomContent.add(event.Address.Bytes()) + content.add(event.Address.Bytes()) for _, topic := range event.Topics { - bloomContent.add(topic.Bytes()) + content.add(topic.Bytes()) } } for _, transfer := range output.Transfers { - bloomContent.add(transfer.Sender.Bytes()) - bloomContent.add(transfer.Recipient.Bytes()) + content.add(transfer.Sender.Bytes()) + content.add(transfer.Recipient.Bytes()) } } origin, _ := txs[i].Origin() - bloomContent.add(origin.Bytes()) + content.add(origin.Bytes()) } signer, _ := header.Signer() - bloomContent.add(signer.Bytes()) - bloomContent.add(header.Beneficiary().Bytes()) + content.add(signer.Bytes()) + content.add(header.Beneficiary().Bytes()) - k := bloom.LegacyEstimateBloomK(bloomContent.len()) + k := bloom.LegacyEstimateBloomK(content.len()) bloom := bloom.NewLegacyBloom(k) - for _, item := range bloomContent.items { + for _, item := range content.items { bloom.Add(item) } - msgs = append(msgs, &BeatMessage{ + beat := BeatMessage{ Number: header.Number(), ID: header.ID(), ParentID: header.ParentID(), @@ -75,9 +88,10 @@ func (br *beatReader) Read() ([]interface{}, bool, error) { Bloom: hexutil.Encode(bloom.Bits[:]), K: uint32(k), Obsolete: block.Obsolete, - }) + } + + return beat, nil } - return msgs, len(blocks) > 0, nil } type bloomContent struct { diff --git a/api/subscriptions/beat_reader_test.go b/api/subscriptions/beat_reader_test.go index 6e6974af9..07d020a7b 100644 --- a/api/subscriptions/beat_reader_test.go +++ b/api/subscriptions/beat_reader_test.go @@ -19,13 +19,13 @@ func TestBeatReader_Read(t *testing.T) { newBlock := generatedBlocks[1] // Act - beatReader := newBeatReader(repo, genesisBlk.Header().ID()) + beatReader := newBeatReader(repo, genesisBlk.Header().ID(), newMessageCache[BeatMessage](10)) res, ok, err := beatReader.Read() // Assert assert.NoError(t, err) assert.True(t, ok) - if beatMsg, ok := res[0].(*BeatMessage); !ok { + if beatMsg, ok := res[0].(BeatMessage); !ok { t.Fatal("unexpected type") } else { assert.Equal(t, newBlock.Header().Number(), beatMsg.Number) @@ -42,7 +42,7 @@ func TestBeatReader_Read_NoNewBlocksToRead(t *testing.T) { newBlock := generatedBlocks[1] // Act - beatReader := newBeatReader(repo, newBlock.Header().ID()) + beatReader := newBeatReader(repo, newBlock.Header().ID(), newMessageCache[BeatMessage](10)) res, ok, err := beatReader.Read() // Assert @@ -56,7 +56,7 @@ func TestBeatReader_Read_ErrorWhenReadingBlocks(t *testing.T) { repo, _, _ := initChain(t) // Act - beatReader := newBeatReader(repo, thor.MustParseBytes32("0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff")) + beatReader := newBeatReader(repo, thor.MustParseBytes32("0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"), newMessageCache[BeatMessage](10)) res, ok, err := beatReader.Read() // Assert diff --git a/api/subscriptions/block_reader.go b/api/subscriptions/block_reader.go index 8c1d3ba50..817fb1c1a 100644 --- a/api/subscriptions/block_reader.go +++ b/api/subscriptions/block_reader.go @@ -11,13 +11,11 @@ import ( ) type blockReader struct { - repo *chain.Repository blockReader chain.BlockReader } func newBlockReader(repo *chain.Repository, position thor.Bytes32) *blockReader { return &blockReader{ - repo: repo, blockReader: repo.NewBlockReader(position), } } diff --git a/api/subscriptions/message_cache.go b/api/subscriptions/message_cache.go new file mode 100644 index 000000000..0ef85bf94 --- /dev/null +++ b/api/subscriptions/message_cache.go @@ -0,0 +1,65 @@ +// Copyright (c) 2024 The VeChainThor developers +// +// Distributed under the GNU Lesser General Public License v3.0 software license, see the accompanying +// file LICENSE or + +package subscriptions + +import ( + "fmt" + "sync" + + "github.com/hashicorp/golang-lru/simplelru" + "github.com/vechain/thor/v2/thor" +) + +// messageCache is a generic cache that stores messages of any type. +type messageCache[T any] struct { + cache *simplelru.LRU + mu sync.RWMutex +} + +// newMessageCache creates a new messageCache with the specified cache size. +func newMessageCache[T any](cacheSize uint32) *messageCache[T] { + if cacheSize > 1000 { + cacheSize = 1000 + } + if cacheSize == 0 { + cacheSize = 1 + } + cache, err := simplelru.NewLRU(int(cacheSize), nil) + if err != nil { + // lru.New only throws an error if the number is less than 1 + panic(fmt.Errorf("failed to create message cache: %v", err)) + } + return &messageCache[T]{ + cache: cache, + } +} + +// GetOrAdd returns the message of the block. If the message is not in the cache, +// it will generate the message and add it to the cache. The second return value +// indicates whether the message is newly generated. +func (mc *messageCache[T]) GetOrAdd(id thor.Bytes32, createMessage func() (T, error)) (T, bool, error) { + mc.mu.RLock() + msg, ok := mc.cache.Get(id) + mc.mu.RUnlock() + if ok { + return msg.(T), false, nil + } + + mc.mu.Lock() + defer mc.mu.Unlock() + msg, ok = mc.cache.Get(id) + if ok { + return msg.(T), false, nil + } + + newMsg, err := createMessage() + if err != nil { + var zero T + return zero, false, err + } + mc.cache.Add(id, newMsg) + return newMsg, true, nil +} diff --git a/api/subscriptions/message_cache_test.go b/api/subscriptions/message_cache_test.go new file mode 100644 index 000000000..6bf65c600 --- /dev/null +++ b/api/subscriptions/message_cache_test.go @@ -0,0 +1,61 @@ +// Copyright (c) 2024 The VeChainThor developers +// +// Distributed under the GNU Lesser General Public License v3.0 software license, see the accompanying +// file LICENSE or + +package subscriptions + +import ( + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/vechain/thor/v2/block" +) + +type message struct { + id string +} + +func handler(blk *block.Block) func() (message, error) { + return func() (message, error) { + msg := message{ + id: blk.Header().ID().String(), + } + return msg, nil + } +} + +func TestMessageCache_GetOrAdd(t *testing.T) { + _, generatedBlocks, _ := initChain(t) + + blk0 := generatedBlocks[0] + blk1 := generatedBlocks[1] + + cache := newMessageCache[message](10) + + counter := atomic.Int32{} + wg := sync.WaitGroup{} + for i := 0; i < 100; i++ { + wg.Add(1) + start := time.Now().Add(20 * time.Millisecond) + go func() { + defer wg.Done() + time.Sleep(time.Until(start)) + _, added, err := cache.GetOrAdd(blk0.Header().ID(), handler(blk0)) + assert.NoError(t, err) + if added { + counter.Add(1) + } + }() + } + wg.Wait() + assert.Equal(t, counter.Load(), int32(1)) + + _, added, err := cache.GetOrAdd(blk1.Header().ID(), handler(blk1)) + assert.NoError(t, err) + assert.True(t, added) + assert.Equal(t, cache.cache.Len(), 2) +} diff --git a/api/subscriptions/subscriptions.go b/api/subscriptions/subscriptions.go index 73f895bf3..7582da5bb 100644 --- a/api/subscriptions/subscriptions.go +++ b/api/subscriptions/subscriptions.go @@ -31,6 +31,8 @@ type Subscriptions struct { pendingTx *pendingTx done chan struct{} wg sync.WaitGroup + beat2Cache *messageCache[Beat2Message] + beatCache *messageCache[BeatMessage] } type msgReader interface { @@ -67,8 +69,10 @@ func New(repo *chain.Repository, allowedOrigins []string, backtraceLimit uint32, return false }, }, - pendingTx: newPendingTx(txpool), - done: make(chan struct{}), + pendingTx: newPendingTx(txpool), + done: make(chan struct{}), + beat2Cache: newMessageCache[Beat2Message](backtraceLimit), + beatCache: newMessageCache[BeatMessage](backtraceLimit), } sub.wg.Add(1) @@ -158,7 +162,7 @@ func (s *Subscriptions) handleBeatReader(w http.ResponseWriter, req *http.Reques if err != nil { return nil, err } - return newBeatReader(s.repo, position), nil + return newBeatReader(s.repo, position, s.beatCache), nil } func (s *Subscriptions) handleBeat2Reader(w http.ResponseWriter, req *http.Request) (*beat2Reader, error) { @@ -166,7 +170,7 @@ func (s *Subscriptions) handleBeat2Reader(w http.ResponseWriter, req *http.Reque if err != nil { return nil, err } - return newBeat2Reader(s.repo, position), nil + return newBeat2Reader(s.repo, position, s.beat2Cache), nil } func (s *Subscriptions) handleSubject(w http.ResponseWriter, req *http.Request) error {