diff --git a/pkg/kgo/compression.go b/pkg/kgo/compression.go
index c5f27234..814dc808 100644
--- a/pkg/kgo/compression.go
+++ b/pkg/kgo/compression.go
@@ -14,15 +14,8 @@ import (
 	"github.com/pierrec/lz4/v4"
 )
 
-// sliceWriter a reusable slice as an io.Writer
-type sliceWriter struct{ inner []byte }
-
-func (s *sliceWriter) Write(p []byte) (int, error) {
-	s.inner = append(s.inner, p...)
-	return len(p), nil
-}
-
-var sliceWriters = sync.Pool{New: func() any { r := make([]byte, 8<<10); return &sliceWriter{inner: r} }}
+// byteBuffers pools reusable *bytes.Buffer values (empty, 8 KiB capacity when new); call Reset after Get, a reused buffer may still hold old data.
+var byteBuffers = sync.Pool{New: func() any { return bytes.NewBuffer(make([]byte, 0, 8<<10)) }}
 
 type codecType int8
 
@@ -175,9 +168,7 @@ type zstdEncoder struct {
 //
 // The writer should be put back to its pool after the returned slice is done
 // being used.
-func (c *compressor) compress(dst *sliceWriter, src []byte, produceRequestVersion int16) ([]byte, codecType) {
-	dst.inner = dst.inner[:0]
-
+func (c *compressor) compress(dst *bytes.Buffer, src []byte, produceRequestVersion int16) ([]byte, codecType) {
 	var use codecType
 	for _, option := range c.options {
 		if option == codecZstd && produceRequestVersion < 7 {
@@ -187,6 +178,7 @@ func (c *compressor) compress(dst *sliceWriter, src []byte, produceRequestVersio
 		break
 	}
 
+	var out []byte
 	switch use {
 	case codecNone:
 		return src, 0
@@ -200,10 +192,7 @@ func (c *compressor) compress(dst *sliceWriter, src []byte, produceRequestVersio
 		if err := gz.Close(); err != nil {
 			return nil, -1
 		}
-
-	case codecSnappy:
-		dst.inner = s2.EncodeSnappy(dst.inner[:cap(dst.inner)], src)
-
+		out = dst.Bytes()
 	case codecLZ4:
 		lz := c.lz4Pool.Get().(*lz4.Writer)
 		defer c.lz4Pool.Put(lz)
@@ -214,13 +203,34 @@ func (c *compressor) compress(dst *sliceWriter, src []byte, produceRequestVersio
 		if err := lz.Close(); err != nil {
 			return nil, -1
 		}
+		out = dst.Bytes()
+	case codecSnappy:
+		// Because the Snappy and Zstd codecs do not accept an io.Writer interface
+		// and directly take a []byte slice, here, the underlying []byte slice (`dst`)
+		// obtained from the bytes.Buffer{} from the pool is passed.
+		// As the `Write()` method on the buffer isn't used, its internal
+		// book-keeping goes out of sync, making the buffer unusable for further
+		// reading and writing via its methods (e.g. `Bytes()`). For subsequent
+		// reads, the returned slice has to be used directly.
+		//
+		// In this particular context, it is acceptable as there are no subsequent
+		// operations performed on the buffer and it is immediately returned to the
+		// pool and `Reset()` the next time it is obtained and used where `compress()`
+		// is called.
+		if l := s2.MaxEncodedLen(len(src)); l > dst.Cap() {
+			dst.Grow(l)
+		}
+		out = s2.EncodeSnappy(dst.Bytes(), src)
 	case codecZstd:
 		zstdEnc := c.zstdPool.Get().(*zstdEncoder)
 		defer c.zstdPool.Put(zstdEnc)
-		dst.inner = zstdEnc.inner.EncodeAll(src, dst.inner)
+		if l := zstdEnc.inner.MaxEncodedSize(len(src)); l > dst.Cap() {
+			dst.Grow(l)
+		}
+		out = zstdEnc.inner.EncodeAll(src, dst.Bytes())
 	}
 
-	return dst.inner, use
+	return out, use
 }
 
 type decompressor struct {
@@ -259,38 +269,51 @@ type zstdDecoder struct {
 }
 
 func (d *decompressor) decompress(src []byte, codec byte) ([]byte, error) {
-	switch codecType(codec) {
-	case codecNone:
+	// Early return in case there is no compression
+	compCodec := codecType(codec)
+	if compCodec == codecNone {
 		return src, nil
+	}
+	out := byteBuffers.Get().(*bytes.Buffer)
+	out.Reset()
+	defer byteBuffers.Put(out)
+
+	switch compCodec {
 	case codecGzip:
 		ungz := d.ungzPool.Get().(*gzip.Reader)
 		defer d.ungzPool.Put(ungz)
 		if err := ungz.Reset(bytes.NewReader(src)); err != nil {
 			return nil, err
 		}
-		out := new(bytes.Buffer)
 		if _, err := io.Copy(out, ungz); err != nil {
 			return nil, err
 		}
-		return out.Bytes(), nil
+		return append([]byte(nil), out.Bytes()...), nil
 	case codecSnappy:
 		if len(src) > 16 && bytes.HasPrefix(src, xerialPfx) {
 			return xerialDecode(src)
 		}
-		return s2.Decode(nil, src)
+		decoded, err := s2.Decode(out.Bytes(), src)
+		if err != nil {
+			return nil, err
+		}
+		return append([]byte(nil), decoded...), nil
 	case codecLZ4:
 		unlz4 := d.unlz4Pool.Get().(*lz4.Reader)
 		defer d.unlz4Pool.Put(unlz4)
 		unlz4.Reset(bytes.NewReader(src))
-		out := new(bytes.Buffer)
 		if _, err := io.Copy(out, unlz4); err != nil {
 			return nil, err
 		}
-		return out.Bytes(), nil
+		return append([]byte(nil), out.Bytes()...), nil
 	case codecZstd:
 		unzstd := d.unzstdPool.Get().(*zstdDecoder)
 		defer d.unzstdPool.Put(unzstd)
-		return unzstd.inner.DecodeAll(src, nil)
+		decoded, err := unzstd.inner.DecodeAll(src, out.Bytes())
+		if err != nil {
+			return nil, err
+		}
+		return append([]byte(nil), decoded...), nil
 	default:
 		return nil, errors.New("unknown compression codec")
 	}
diff --git a/pkg/kgo/compression_test.go b/pkg/kgo/compression_test.go
index 29b247e4..8ac4710f 100644
--- a/pkg/kgo/compression_test.go
+++ b/pkg/kgo/compression_test.go
@@ -4,6 +4,7 @@ import (
 	"bytes"
 	"encoding/base64"
 	"fmt"
+	"math/rand"
 	"reflect"
 	"sync"
 	"testing"
@@ -46,9 +47,23 @@ func TestNewCompressor(t *testing.T) {
 }
 
 func TestCompressDecompress(t *testing.T) {
+	randStr := func(length int) []byte {
+		const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
+		b := make([]byte, length)
+		for i := range b {
+			b[i] = charset[rand.Intn(len(charset))]
+		}
+		return b
+	}
+
 	t.Parallel()
 	d := newDecompressor()
-	in := []byte("foo")
+	inputs := [][]byte{
+		randStr(1 << 2),
+		randStr(1 << 5),
+		randStr(1 << 8),
+	}
+
 	var wg sync.WaitGroup
 	for _, produceVersion := range []int16{
 		0, 7,
@@ -74,18 +89,21 @@ func TestCompressDecompress(t *testing.T) {
 			for i := 0; i < 3; i++ {
 				wg.Add(1)
 				go func() {
+					w := byteBuffers.Get().(*bytes.Buffer)
 					defer wg.Done()
-					w := sliceWriters.Get().(*sliceWriter)
-					defer sliceWriters.Put(w)
-					got, used := c.compress(w, in, produceVersion)
+					defer byteBuffers.Put(w)
+					for _, in := range inputs {
+						w.Reset()
 
-					got, err := d.decompress(got, byte(used))
-					if err != nil {
-						t.Errorf("unexpected decompress err: %v", err)
-						return
-					}
-					if !bytes.Equal(got, in) {
-						t.Errorf("got decompress %s != exp compress in %s", got, in)
+						got, used := c.compress(w, in, produceVersion)
+						got, err := d.decompress(got, byte(used))
+						if err != nil {
+							t.Errorf("unexpected decompress err: %v", err)
+							return
+						}
+						if !bytes.Equal(got, in) {
+							t.Errorf("got decompress %s != exp compress in %s", got, in)
+						}
 					}
 				}()
 			}
@@ -102,16 +120,38 @@ func BenchmarkCompress(b *testing.B) {
 		b.Run(fmt.Sprint(codec), func(b *testing.B) {
 			var afterSize int
 			for i := 0; i < b.N; i++ {
-				w := sliceWriters.Get().(*sliceWriter)
+				w := byteBuffers.Get().(*bytes.Buffer)
+				w.Reset()
 				after, _ := c.compress(w, in, 99)
 				afterSize = len(after)
-				sliceWriters.Put(w)
+				byteBuffers.Put(w)
 			}
 			b.Logf("%d => %d", len(in), afterSize)
 		})
 	}
 }
 
+func BenchmarkDecompress(b *testing.B) {
+	in := bytes.Repeat([]byte("abcdefghijklmno pqrs tuvwxy z"), 100)
+	for _, codec := range []codecType{codecGzip, codecSnappy, codecLZ4, codecZstd} {
+		c, _ := newCompressor(CompressionCodec{codec: codec})
+		w := byteBuffers.Get().(*bytes.Buffer)
+		w.Reset()
+		// Use the returned slice: snappy and zstd bypass w's Write, so w.Bytes() is empty for them.
+		compressed, _ := c.compress(w, in, 99)
+
+		b.Run(fmt.Sprint(codec), func(b *testing.B) {
+			d := newDecompressor()
+			for i := 0; i < b.N; i++ {
+				if _, err := d.decompress(compressed, byte(codec)); err != nil {
+					b.Fatal(err)
+				}
+			}
+		})
+		byteBuffers.Put(w)
+	}
+}
+
 func Test_xerialDecode(t *testing.T) {
 	tests := []struct {
 		name string
diff --git a/pkg/kgo/logger.go b/pkg/kgo/logger.go
index 997e23b5..bfc5dc0d 100644
--- a/pkg/kgo/logger.go
+++ b/pkg/kgo/logger.go
@@ -1,6 +1,7 @@
 package kgo
 
 import (
+	"bytes"
 	"fmt"
 	"io"
 	"strings"
@@ -73,28 +74,27 @@ type basicLogger struct {
 
 func (b *basicLogger) Level() LogLevel { return b.level }
 func (b *basicLogger) Log(level LogLevel, msg string, keyvals ...any) {
-	buf := sliceWriters.Get().(*sliceWriter)
-	defer sliceWriters.Put(buf)
+	buf := byteBuffers.Get().(*bytes.Buffer)
+	defer byteBuffers.Put(buf)
 
-	buf.inner = buf.inner[:0]
+	buf.Reset()
 	if b.pfxFn != nil {
-		buf.inner = append(buf.inner, b.pfxFn()...)
+		buf.WriteString(b.pfxFn())
 	}
-	buf.inner = append(buf.inner, '[')
-	buf.inner = append(buf.inner, level.String()...)
-	buf.inner = append(buf.inner, "] "...)
-	buf.inner = append(buf.inner, msg...)
+	buf.WriteByte('[')
+	buf.WriteString(level.String())
+	buf.WriteString("] ")
+	buf.WriteString(msg)
 
 	if len(keyvals) > 0 {
-		buf.inner = append(buf.inner, "; "...)
+		buf.WriteString("; ")
 		format := strings.Repeat("%v: %v, ", len(keyvals)/2)
-		format = format[:len(format)-2] // trim trailing comma and space
+		format = strings.TrimSuffix(format, ", ") // trim trailing comma and space; no-op (not a panic) for a lone keyval
 		fmt.Fprintf(buf, format, keyvals...)
 	}
 
-	buf.inner = append(buf.inner, '\n')
-
-	b.dst.Write(buf.inner)
+	buf.WriteByte('\n')
+	b.dst.Write(buf.Bytes())
 }
 
 // nopLogger, the default logger, drops everything.
diff --git a/pkg/kgo/produce_request_test.go b/pkg/kgo/produce_request_test.go
index 8a652a2f..d6ddb08a 100644
--- a/pkg/kgo/produce_request_test.go
+++ b/pkg/kgo/produce_request_test.go
@@ -163,7 +163,9 @@ func TestRecBatchAppendTo(t *testing.T) {
 	compressor, _ = newCompressor(CompressionCodec{codec: 2}) // snappy
 	{
 		kbatch.Attributes |= 0x0002 // snappy
-		kbatch.Records, _ = compressor.compress(sliceWriters.Get().(*sliceWriter), kbatch.Records, version)
+		w := byteBuffers.Get().(*bytes.Buffer)
+		w.Reset()
+		kbatch.Records, _ = compressor.compress(w, kbatch.Records, version)
 	}
 
 	fixFields()
@@ -254,7 +256,9 @@ func TestMessageSetAppendTo(t *testing.T) {
 		Offset:     1,
 		Attributes: 0x02,
 	}
-	kset0c.Value, _ = compressor.compress(sliceWriters.Get().(*sliceWriter), kset0raw, 1) // version 0, 1 use message set 0
+	w := byteBuffers.Get().(*bytes.Buffer)
+	w.Reset()
+	kset0c.Value, _ = compressor.compress(w, kset0raw, 1) // version 0, 1 use message set 0
 	kset0c.CRC = int32(crc32.ChecksumIEEE(kset0c.AppendTo(nil)[16:]))
 	kset0c.MessageSize = int32(len(kset0c.AppendTo(nil)[12:]))
 
@@ -265,7 +269,9 @@ func TestMessageSetAppendTo(t *testing.T) {
 		Attributes: 0x02,
 		Timestamp:  kset11.Timestamp,
 	}
-	kset1c.Value, _ = compressor.compress(sliceWriters.Get().(*sliceWriter), kset1raw, 2) // version 2 use message set 1
+	wbuf := byteBuffers.Get().(*bytes.Buffer)
+	wbuf.Reset()
+	kset1c.Value, _ = compressor.compress(wbuf, kset1raw, 2) // version 2 use message set 1
 	kset1c.CRC = int32(crc32.ChecksumIEEE(kset1c.AppendTo(nil)[16:]))
 	kset1c.MessageSize = int32(len(kset1c.AppendTo(nil)[12:]))
 
diff --git a/pkg/kgo/sink.go b/pkg/kgo/sink.go
index d7db8cff..8e33e56c 100644
--- a/pkg/kgo/sink.go
+++ b/pkg/kgo/sink.go
@@ -2114,8 +2114,9 @@ func (b seqRecBatch) appendTo(
 	m.CompressedBytes = m.UncompressedBytes
 
 	if compressor != nil {
-		w := sliceWriters.Get().(*sliceWriter)
-		defer sliceWriters.Put(w)
+		w := byteBuffers.Get().(*bytes.Buffer)
+		defer byteBuffers.Put(w)
+		w.Reset()
 
 		compressed, codec := compressor.compress(w, toCompress, version)
 		if compressed != nil && // nil would be from an error
@@ -2195,8 +2196,9 @@ func (b seqRecBatch) appendToAsMessageSet(dst []byte, version uint8, compressor
 	m.CompressedBytes = m.UncompressedBytes
 
 	if compressor != nil {
-		w := sliceWriters.Get().(*sliceWriter)
-		defer sliceWriters.Put(w)
+		w := byteBuffers.Get().(*bytes.Buffer)
+		defer byteBuffers.Put(w)
+		w.Reset()
 
 		compressed, codec := compressor.compress(w, toCompress, int16(version))
 		inner := &Record{Value: compressed}