Skip to content

Commit

Permalink
Pool decode buffers when encoding messages using compression (#854)
Browse files Browse the repository at this point in the history
Pool the decode buffers and set strict max cap allocation for zstd
decompressor.

See Benchamark results below.

Before:

```
BenchmarkCborEncoding
BenchmarkCborEncoding-12    	  467352	      2376 ns/op	   14680 B/op	      10 allocs/op
BenchmarkCborDecoding
BenchmarkCborDecoding-12    	  104410	     11347 ns/op	   14944 B/op	      27 allocs/op
BenchmarkZstdEncoding
BenchmarkZstdEncoding-12    	  286735	      3897 ns/op	   46748 B/op	      12 allocs/op
BenchmarkZstdDecoding
BenchmarkZstdDecoding-12    	  110794	     10783 ns/op	   28512 B/op	      28 allocs/op
```

After:

```
BenchmarkCborEncoding
BenchmarkCborEncoding-12    	  436754	      2383 ns/op	   14680 B/op	      10 allocs/op
BenchmarkCborDecoding
BenchmarkCborDecoding-12    	  106809	     11280 ns/op	   14944 B/op	      27 allocs/op
BenchmarkZstdEncoding
BenchmarkZstdEncoding-12    	  294043	      3918 ns/op	   46746 B/op	      12 allocs/op
BenchmarkZstdDecoding
BenchmarkZstdDecoding-12    	  114854	     10747 ns/op	   18314 B/op	      27 allocs/op
```

Fixes: #849
  • Loading branch information
masih authored Jan 28, 2025
1 parent fe47737 commit 09c0853
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 7 deletions.
22 changes: 20 additions & 2 deletions encoding_bench_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ func BenchmarkCborDecoding(b *testing.B) {
for pb.Next() {
var got PartialGMessage
require.NoError(b, encoder.Decode(data, &got))
require.Equal(b, msg, &got)
requireEqualPartialMessages(b, msg, &got)
}
})
}
Expand Down Expand Up @@ -79,11 +79,29 @@ func BenchmarkZstdDecoding(b *testing.B) {
for pb.Next() {
var got PartialGMessage
require.NoError(b, encoder.Decode(data, &got))
require.Equal(b, msg, &got)
requireEqualPartialMessages(b, msg, &got)
}
})
}

func requireEqualPartialMessages(b *testing.B, expected, actual *PartialGMessage) {
// Because empty ECChain gets marshaled as null, we need to use ECChain.Eq for
// checking equality. Hence, the custom equality check.
require.Equal(b, expected.Sender, actual.Sender)
require.Equal(b, expected.Signature, actual.Signature)
require.Equal(b, expected.VoteValueKey, actual.VoteValueKey)
require.Equal(b, expected.Ticket, actual.Ticket)
require.True(b, expected.Vote.Eq(&actual.Vote))
if expected.Justification == nil {
require.Nil(b, actual.Justification)
} else {
require.NotNil(b, actual.Justification)
require.Equal(b, expected.Justification.Signature, actual.Justification.Signature)
require.Equal(b, expected.Justification.Signers, actual.Justification.Signers)
require.True(b, expected.Justification.Vote.Eq(&actual.Justification.Vote))
}
}

func generateRandomPartialGMessage(b *testing.B, rng *rand.Rand) *PartialGMessage {
var pgmsg PartialGMessage
pgmsg.GMessage = generateRandomGMessage(b, rng)
Expand Down
23 changes: 18 additions & 5 deletions internal/encoding/encoding.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package encoding
import (
"bytes"
"fmt"
"sync"

"github.com/klauspost/compress/zstd"
cbg "github.com/whyrusleeping/cbor-gen"
Expand All @@ -13,6 +14,13 @@ import (
// size in GossipSub.
const maxDecompressedSize = 1 << 20

var bufferPool = sync.Pool{
New: func() any {
buf := make([]byte, maxDecompressedSize)
return &buf
},
}

type CBORMarshalUnmarshaler interface {
cbg.CBORMarshaler
cbg.CBORUnmarshaler
Expand All @@ -30,11 +38,11 @@ func NewCBOR[T CBORMarshalUnmarshaler]() *CBOR[T] {
}

func (c *CBOR[T]) Encode(m T) ([]byte, error) {
var buf bytes.Buffer
if err := m.MarshalCBOR(&buf); err != nil {
var out bytes.Buffer
if err := m.MarshalCBOR(&out); err != nil {
return nil, err
}
return buf.Bytes(), nil
return out.Bytes(), nil
}

func (c *CBOR[T]) Decode(v []byte, t T) error {
Expand All @@ -53,7 +61,9 @@ func NewZSTD[T CBORMarshalUnmarshaler]() (*ZSTD[T], error) {
if err != nil {
return nil, err
}
reader, err := zstd.NewReader(nil, zstd.WithDecoderMaxMemory(maxDecompressedSize))
reader, err := zstd.NewReader(nil,
zstd.WithDecoderMaxMemory(maxDecompressedSize),
zstd.WithDecodeAllCapLimit(true))
if err != nil {
return nil, err
}
Expand All @@ -78,7 +88,10 @@ func (c *ZSTD[T]) Encode(m T) ([]byte, error) {
}

func (c *ZSTD[T]) Decode(v []byte, t T) error {
cborEncoded, err := c.decompressor.DecodeAll(v, make([]byte, 0, len(v)))
buf := bufferPool.Get().(*[]byte)
defer bufferPool.Put(buf)

cborEncoded, err := c.decompressor.DecodeAll(v, (*buf)[:0])
if err != nil {
return err
}
Expand Down

0 comments on commit 09c0853

Please sign in to comment.