From 09c085358f48268a9f7f89a4decea4a701b0cf0b Mon Sep 17 00:00:00 2001 From: "Masih H. Derkani" Date: Tue, 28 Jan 2025 10:19:17 +0000 Subject: [PATCH] Pool decode buffers when encoding messages using compression (#854) Pool the decode buffers and set strict max cap allocation for zstd decompressor. See Benchamark results below. Before: ``` BenchmarkCborEncoding BenchmarkCborEncoding-12 467352 2376 ns/op 14680 B/op 10 allocs/op BenchmarkCborDecoding BenchmarkCborDecoding-12 104410 11347 ns/op 14944 B/op 27 allocs/op BenchmarkZstdEncoding BenchmarkZstdEncoding-12 286735 3897 ns/op 46748 B/op 12 allocs/op BenchmarkZstdDecoding BenchmarkZstdDecoding-12 110794 10783 ns/op 28512 B/op 28 allocs/op ``` After: ``` BenchmarkCborEncoding BenchmarkCborEncoding-12 436754 2383 ns/op 14680 B/op 10 allocs/op BenchmarkCborDecoding BenchmarkCborDecoding-12 106809 11280 ns/op 14944 B/op 27 allocs/op BenchmarkZstdEncoding BenchmarkZstdEncoding-12 294043 3918 ns/op 46746 B/op 12 allocs/op BenchmarkZstdDecoding BenchmarkZstdDecoding-12 114854 10747 ns/op 18314 B/op 27 allocs/op ``` Fixes: #849 --- encoding_bench_test.go | 22 ++++++++++++++++++++-- internal/encoding/encoding.go | 23 ++++++++++++++++++----- 2 files changed, 38 insertions(+), 7 deletions(-) diff --git a/encoding_bench_test.go b/encoding_bench_test.go index 4ec1f8a0..a63be17b 100644 --- a/encoding_bench_test.go +++ b/encoding_bench_test.go @@ -43,7 +43,7 @@ func BenchmarkCborDecoding(b *testing.B) { for pb.Next() { var got PartialGMessage require.NoError(b, encoder.Decode(data, &got)) - require.Equal(b, msg, &got) + requireEqualPartialMessages(b, msg, &got) } }) } @@ -79,11 +79,29 @@ func BenchmarkZstdDecoding(b *testing.B) { for pb.Next() { var got PartialGMessage require.NoError(b, encoder.Decode(data, &got)) - require.Equal(b, msg, &got) + requireEqualPartialMessages(b, msg, &got) } }) } +func requireEqualPartialMessages(b *testing.B, expected, actual *PartialGMessage) { + // Because empty ECChain gets marshaled as null, we need to use ECChain.Eq for + // checking equality. Hence, the custom equality check. + require.Equal(b, expected.Sender, actual.Sender) + require.Equal(b, expected.Signature, actual.Signature) + require.Equal(b, expected.VoteValueKey, actual.VoteValueKey) + require.Equal(b, expected.Ticket, actual.Ticket) + require.True(b, expected.Vote.Eq(&actual.Vote)) + if expected.Justification == nil { + require.Nil(b, actual.Justification) + } else { + require.NotNil(b, actual.Justification) + require.Equal(b, expected.Justification.Signature, actual.Justification.Signature) + require.Equal(b, expected.Justification.Signers, actual.Justification.Signers) + require.True(b, expected.Justification.Vote.Eq(&actual.Justification.Vote)) + } +} + func generateRandomPartialGMessage(b *testing.B, rng *rand.Rand) *PartialGMessage { var pgmsg PartialGMessage pgmsg.GMessage = generateRandomGMessage(b, rng) diff --git a/internal/encoding/encoding.go b/internal/encoding/encoding.go index 1be3cf08..65e03bb2 100644 --- a/internal/encoding/encoding.go +++ b/internal/encoding/encoding.go @@ -3,6 +3,7 @@ package encoding import ( "bytes" "fmt" + "sync" "github.com/klauspost/compress/zstd" cbg "github.com/whyrusleeping/cbor-gen" @@ -13,6 +14,13 @@ import ( // size in GossipSub. const maxDecompressedSize = 1 << 20 +var bufferPool = sync.Pool{ + New: func() any { + buf := make([]byte, maxDecompressedSize) + return &buf + }, +} + type CBORMarshalUnmarshaler interface { cbg.CBORMarshaler cbg.CBORUnmarshaler @@ -30,11 +38,11 @@ func NewCBOR[T CBORMarshalUnmarshaler]() *CBOR[T] { } func (c *CBOR[T]) Encode(m T) ([]byte, error) { - var buf bytes.Buffer - if err := m.MarshalCBOR(&buf); err != nil { + var out bytes.Buffer + if err := m.MarshalCBOR(&out); err != nil { return nil, err } - return buf.Bytes(), nil + return out.Bytes(), nil } func (c *CBOR[T]) Decode(v []byte, t T) error { @@ -53,7 +61,9 @@ func NewZSTD[T CBORMarshalUnmarshaler]() (*ZSTD[T], error) { if err != nil { return nil, err } - reader, err := zstd.NewReader(nil, zstd.WithDecoderMaxMemory(maxDecompressedSize)) + reader, err := zstd.NewReader(nil, + zstd.WithDecoderMaxMemory(maxDecompressedSize), + zstd.WithDecodeAllCapLimit(true)) if err != nil { return nil, err } @@ -78,7 +88,10 @@ func (c *ZSTD[T]) Encode(m T) ([]byte, error) { } func (c *ZSTD[T]) Decode(v []byte, t T) error { - cborEncoded, err := c.decompressor.DecodeAll(v, make([]byte, 0, len(v))) + buf := bufferPool.Get().(*[]byte) + defer bufferPool.Put(buf) + + cborEncoded, err := c.decompressor.DecodeAll(v, (*buf)[:0]) if err != nil { return err }