Skip to content

Commit

Permalink
optimise: use byteBuffer pool in decompression
Browse files Browse the repository at this point in the history
  • Loading branch information
kalbhor committed Jul 15, 2024
1 parent 23ddc6b commit b4830b9
Show file tree
Hide file tree
Showing 5 changed files with 125 additions and 58 deletions.
74 changes: 48 additions & 26 deletions pkg/kgo/compression.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,7 @@ import (
"github.com/pierrec/lz4/v4"
)

// sliceWriter adapts a reusable byte slice so that it can be used as an
// io.Writer.
type sliceWriter struct {
	inner []byte
}

// Write appends p to the accumulated slice. It always reports that all of p
// was consumed and never returns an error.
func (s *sliceWriter) Write(p []byte) (int, error) {
	written := len(p)
	s.inner = append(s.inner, p...)
	return written, nil
}

// sliceWriters pools *sliceWriter values. A newly created writer wraps a
// slice whose length (not just capacity) is 8 KiB, so users are expected to
// truncate it before appending.
var sliceWriters = sync.Pool{
	New: func() any {
		return &sliceWriter{inner: make([]byte, 8<<10)}
	},
}
// byteBuffers pools *bytes.Buffer values used as scratch space.
//
// A newly created buffer is empty and has 8 KiB of capacity pre-allocated.
// The slice handed to bytes.NewBuffer becomes the buffer's initial readable
// contents, so it is created with length 0: passing a slice of length 8 KiB
// would yield a buffer that already "contains" 8 KiB of zero bytes until
// Reset is called. Buffers taken from the pool may have been used before and
// should still be Reset before use.
var byteBuffers = sync.Pool{New: func() any { return bytes.NewBuffer(make([]byte, 0, 8<<10)) }}

type codecType int8

Expand Down Expand Up @@ -175,9 +167,7 @@ type zstdEncoder struct {
//
// The writer should be put back to its pool after the returned slice is done
// being used.
func (c *compressor) compress(dst *sliceWriter, src []byte, produceRequestVersion int16) ([]byte, codecType) {
dst.inner = dst.inner[:0]

func (c *compressor) compress(dst *bytes.Buffer, src []byte, produceRequestVersion int16) ([]byte, codecType) {
var use codecType
for _, option := range c.options {
if option == codecZstd && produceRequestVersion < 7 {
Expand All @@ -187,6 +177,7 @@ func (c *compressor) compress(dst *sliceWriter, src []byte, produceRequestVersio
break
}

var out []byte
switch use {
case codecNone:
return src, 0
Expand All @@ -200,10 +191,7 @@ func (c *compressor) compress(dst *sliceWriter, src []byte, produceRequestVersio
if err := gz.Close(); err != nil {
return nil, -1
}

case codecSnappy:
dst.inner = s2.EncodeSnappy(dst.inner[:cap(dst.inner)], src)

out = dst.Bytes()
case codecLZ4:
lz := c.lz4Pool.Get().(*lz4.Writer)
defer c.lz4Pool.Put(lz)
Expand All @@ -214,13 +202,34 @@ func (c *compressor) compress(dst *sliceWriter, src []byte, produceRequestVersio
if err := lz.Close(); err != nil {
return nil, -1
}
out = dst.Bytes()
case codecSnappy:
// Because the Snappy and Zstd codecs do not accept an io.Writer interface
// and directly take a []byte slice, here, the underlying []byte slice (`dst`)
// obtained from the bytes.Buffer{} from the pool is passed.
// As the `Write()` method on the buffer isn't used, its internal
// book-keeping goes out of sync, making the buffer unusable for further
// reading and writing via its methods (e.g. accessing via `Bytes()`). For subsequent
// reads, the underlying slice has to be used directly.
//
// In this particular context, it is acceptable as there are no subsequent
// operations performed on the buffer and it is immediately returned to the
// pool and `Reset()` the next time it is obtained and used where `compress()`
// is called.
if l := s2.MaxEncodedLen(len(src)); l > dst.Cap() {
dst.Grow(l)
}
out = s2.EncodeSnappy(dst.Bytes(), src)
case codecZstd:
zstdEnc := c.zstdPool.Get().(*zstdEncoder)
defer c.zstdPool.Put(zstdEnc)
dst.inner = zstdEnc.inner.EncodeAll(src, dst.inner)
if l := zstdEnc.inner.MaxEncodedSize(len(src)); l > dst.Cap() {
dst.Grow(l)
}
out = zstdEnc.inner.EncodeAll(src, dst.Bytes())
}

return dst.inner, use
return out, use
}

type decompressor struct {
Expand Down Expand Up @@ -259,38 +268,51 @@ type zstdDecoder struct {
}

func (d *decompressor) decompress(src []byte, codec byte) ([]byte, error) {
switch codecType(codec) {
case codecNone:
// Early return in case there is no compression
compCodec := codecType(codec)
if compCodec == codecNone {
return src, nil
}
out := byteBuffers.Get().(*bytes.Buffer)
out.Reset()
defer byteBuffers.Put(out)

switch compCodec {
case codecGzip:
ungz := d.ungzPool.Get().(*gzip.Reader)
defer d.ungzPool.Put(ungz)
if err := ungz.Reset(bytes.NewReader(src)); err != nil {
return nil, err
}
out := new(bytes.Buffer)
if _, err := io.Copy(out, ungz); err != nil {
return nil, err
}
return out.Bytes(), nil
return append([]byte(nil), out.Bytes()...), nil
case codecSnappy:
if len(src) > 16 && bytes.HasPrefix(src, xerialPfx) {
return xerialDecode(src)
}
return s2.Decode(nil, src)
decoded, err := s2.Decode(out.Bytes(), src)
if err != nil {
return nil, err
}
return append([]byte(nil), decoded...), nil
case codecLZ4:
unlz4 := d.unlz4Pool.Get().(*lz4.Reader)
defer d.unlz4Pool.Put(unlz4)
unlz4.Reset(bytes.NewReader(src))
out := new(bytes.Buffer)
if _, err := io.Copy(out, unlz4); err != nil {
return nil, err
}
return out.Bytes(), nil
return append([]byte(nil), out.Bytes()...), nil
case codecZstd:
unzstd := d.unzstdPool.Get().(*zstdDecoder)
defer d.unzstdPool.Put(unzstd)
return unzstd.inner.DecodeAll(src, nil)
decoded, err := unzstd.inner.DecodeAll(src, out.Bytes())
if err != nil {
return nil, err
}
return append([]byte(nil), decoded...), nil
default:
return nil, errors.New("unknown compression codec")
}
Expand Down
63 changes: 50 additions & 13 deletions pkg/kgo/compression_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bytes"
"encoding/base64"
"fmt"
"math/rand"
"reflect"
"sync"
"testing"
Expand Down Expand Up @@ -46,9 +47,23 @@ func TestNewCompressor(t *testing.T) {
}

func TestCompressDecompress(t *testing.T) {
randStr := func(length int) []byte {
const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
b := make([]byte, length)
for i := range b {
b[i] = charset[rand.Intn(len(charset))]
}
return b
}

t.Parallel()
d := newDecompressor()
in := []byte("foo")
inputs := [][]byte{
randStr(1 << 2),
randStr(1 << 5),
randStr(1 << 8),
}

var wg sync.WaitGroup
for _, produceVersion := range []int16{
0, 7,
Expand All @@ -74,18 +89,21 @@ func TestCompressDecompress(t *testing.T) {
for i := 0; i < 3; i++ {
wg.Add(1)
go func() {
w := byteBuffers.Get().(*bytes.Buffer)
defer wg.Done()
w := sliceWriters.Get().(*sliceWriter)
defer sliceWriters.Put(w)
got, used := c.compress(w, in, produceVersion)
defer byteBuffers.Put(w)
for _, in := range inputs {
w.Reset()

got, err := d.decompress(got, byte(used))
if err != nil {
t.Errorf("unexpected decompress err: %v", err)
return
}
if !bytes.Equal(got, in) {
t.Errorf("got decompress %s != exp compress in %s", got, in)
got, used := c.compress(w, in, produceVersion)
got, err := d.decompress(got, byte(used))
if err != nil {
t.Errorf("unexpected decompress err: %v", err)
return
}
if !bytes.Equal(got, in) {
t.Errorf("got decompress %s != exp compress in %s", got, in)
}
}
}()
}
Expand All @@ -102,16 +120,35 @@ func BenchmarkCompress(b *testing.B) {
b.Run(fmt.Sprint(codec), func(b *testing.B) {
var afterSize int
for i := 0; i < b.N; i++ {
w := sliceWriters.Get().(*sliceWriter)
w := byteBuffers.Get().(*bytes.Buffer)
w.Reset()
after, _ := c.compress(w, in, 99)
afterSize = len(after)
sliceWriters.Put(w)
byteBuffers.Put(w)
}
b.Logf("%d => %d", len(in), afterSize)
})
}
}

// BenchmarkDecompress measures decompress for each supported codec, using a
// payload produced by the matching compressor.
func BenchmarkDecompress(b *testing.B) {
	in := bytes.Repeat([]byte("abcdefghijklmno pqrs tuvwxy z"), 100)
	for _, codec := range []codecType{codecGzip, codecSnappy, codecLZ4, codecZstd} {
		c, _ := newCompressor(CompressionCodec{codec: codec})
		w := byteBuffers.Get().(*bytes.Buffer)
		w.Reset()

		// The slice returned by compress is the compressed payload. It must
		// be used instead of w.Bytes(): the snappy and zstd paths encode
		// directly into w's backing slice without going through Write, so
		// w.Bytes() stays empty for those codecs.
		compressed, _ := c.compress(w, in, 99)
		if compressed == nil { // compress returns nil on encoder errors
			b.Fatalf("codec %v: compress failed", codec)
		}

		// Build the decompressor once per codec so the timed loop measures
		// decompression (including reuse of its pooled decoders) rather than
		// decompressor construction.
		d := newDecompressor()

		b.Run(fmt.Sprint(codec), func(b *testing.B) {
			for i := 0; i < b.N; i++ {
				if _, err := d.decompress(compressed, byte(codec)); err != nil {
					b.Fatalf("unexpected decompress err: %v", err)
				}
			}
		})

		// compressed may alias w's storage, so w is only returned to the
		// pool once the sub-benchmark has finished with it.
		byteBuffers.Put(w)
	}
}

func Test_xerialDecode(t *testing.T) {
tests := []struct {
name string
Expand Down
24 changes: 12 additions & 12 deletions pkg/kgo/logger.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package kgo

import (
"bytes"
"fmt"
"io"
"strings"
Expand Down Expand Up @@ -73,28 +74,27 @@ type basicLogger struct {

// Level returns the value of the logger's level field.
func (b *basicLogger) Level() LogLevel {
	return b.level
}
// Log formats a single log line into a pooled buffer and writes it to b.dst
// with one Write call.
//
// The line consists of the optional prefix returned by b.pfxFn, the level in
// square brackets, a space, and msg. If keyvals is non-empty, "; " follows,
// then the keyvals rendered as "key: value" pairs separated by ", ". A
// newline is always appended.
//
// keyvals is expected to hold alternating keys and values. Arguments beyond
// the last complete pair are not dropped: fmt renders them with its
// "%!(EXTRA ...)" notation.
func (b *basicLogger) Log(level LogLevel, msg string, keyvals ...any) {
	buf := byteBuffers.Get().(*bytes.Buffer)
	defer byteBuffers.Put(buf)

	// Pooled buffers may still hold a previous line.
	buf.Reset()
	if b.pfxFn != nil {
		buf.WriteString(b.pfxFn())
	}
	buf.WriteByte('[')
	buf.WriteString(level.String())
	buf.WriteString("] ")
	buf.WriteString(msg)

	if len(keyvals) > 0 {
		buf.WriteString("; ")
		// One "%v: %v" per complete pair, without the trailing ", ".
		// TrimSuffix is used instead of slicing off the last two bytes so
		// that a single dangling key (zero complete pairs, hence an empty
		// format) does not cause a slice-bounds panic.
		format := strings.TrimSuffix(strings.Repeat("%v: %v, ", len(keyvals)/2), ", ")
		fmt.Fprintf(buf, format, keyvals...)
	}

	buf.WriteByte('\n')
	// The result of Write is not checked; a failed log write is not reported.
	b.dst.Write(buf.Bytes())
}

// nopLogger, the default logger, drops everything.
Expand Down
12 changes: 9 additions & 3 deletions pkg/kgo/produce_request_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,9 @@ func TestRecBatchAppendTo(t *testing.T) {
compressor, _ = newCompressor(CompressionCodec{codec: 2}) // snappy
{
kbatch.Attributes |= 0x0002 // snappy
kbatch.Records, _ = compressor.compress(sliceWriters.Get().(*sliceWriter), kbatch.Records, version)
w := byteBuffers.Get().(*bytes.Buffer)
w.Reset()
kbatch.Records, _ = compressor.compress(w, kbatch.Records, version)
}

fixFields()
Expand Down Expand Up @@ -254,7 +256,9 @@ func TestMessageSetAppendTo(t *testing.T) {
Offset: 1,
Attributes: 0x02,
}
kset0c.Value, _ = compressor.compress(sliceWriters.Get().(*sliceWriter), kset0raw, 1) // version 0, 1 use message set 0
w := byteBuffers.Get().(*bytes.Buffer)
w.Reset()
kset0c.Value, _ = compressor.compress(w, kset0raw, 1) // version 0, 1 use message set 0
kset0c.CRC = int32(crc32.ChecksumIEEE(kset0c.AppendTo(nil)[16:]))
kset0c.MessageSize = int32(len(kset0c.AppendTo(nil)[12:]))

Expand All @@ -265,7 +269,9 @@ func TestMessageSetAppendTo(t *testing.T) {
Attributes: 0x02,
Timestamp: kset11.Timestamp,
}
kset1c.Value, _ = compressor.compress(sliceWriters.Get().(*sliceWriter), kset1raw, 2) // version 2 use message set 1
wbuf := byteBuffers.Get().(*bytes.Buffer)
wbuf.Reset()
kset1c.Value, _ = compressor.compress(wbuf, kset1raw, 2) // version 2 use message set 1
kset1c.CRC = int32(crc32.ChecksumIEEE(kset1c.AppendTo(nil)[16:]))
kset1c.MessageSize = int32(len(kset1c.AppendTo(nil)[12:]))

Expand Down
10 changes: 6 additions & 4 deletions pkg/kgo/sink.go
Original file line number Diff line number Diff line change
Expand Up @@ -2094,8 +2094,9 @@ func (b seqRecBatch) appendTo(
m.CompressedBytes = m.UncompressedBytes

if compressor != nil {
w := sliceWriters.Get().(*sliceWriter)
defer sliceWriters.Put(w)
w := byteBuffers.Get().(*bytes.Buffer)
defer byteBuffers.Put(w)
w.Reset()

compressed, codec := compressor.compress(w, toCompress, version)
if compressed != nil && // nil would be from an error
Expand Down Expand Up @@ -2175,8 +2176,9 @@ func (b seqRecBatch) appendToAsMessageSet(dst []byte, version uint8, compressor
m.CompressedBytes = m.UncompressedBytes

if compressor != nil {
w := sliceWriters.Get().(*sliceWriter)
defer sliceWriters.Put(w)
w := byteBuffers.Get().(*bytes.Buffer)
defer byteBuffers.Put(w)
w.Reset()

compressed, codec := compressor.compress(w, toCompress, int16(version))
inner := &Record{Value: compressed}
Expand Down

0 comments on commit b4830b9

Please sign in to comment.