diff --git a/cmd/zstdseek/main.go b/cmd/zstdseek/main.go
index da7b2ab..2b09730 100644
--- a/cmd/zstdseek/main.go
+++ b/cmd/zstdseek/main.go
@@ -2,6 +2,7 @@ package main
 
 import (
 	"bytes"
+	"context"
 	"crypto/sha512"
 	"errors"
 	"flag"
@@ -26,6 +27,8 @@ type readCloser struct {
 }
 
 func main() {
+	ctx := context.Background()
+
 	var (
 		inputFlag, chunkingFlag, outputFlag string
 		qualityFlag                         int
@@ -156,21 +159,25 @@ func main() {
 		logger.Fatal("failed to create chunker", zap.Error(err))
 	}
 
-	for {
+	frameSource := func() ([]byte, error) {
 		chunk, err := chunker.Next()
 		if err != nil {
 			if errors.Is(err, io.EOF) {
-				break
+				return nil, nil
 			}
-			logger.Fatal("failed to read", zap.Error(err))
-		}
-		n, err := w.Write(chunk.Data)
-		if err != nil {
-			logger.Fatal("failed to write data", zap.Error(err))
+			return nil, err
 		}
+		// Chunker invalidates the data after calling Next, so we need to clone it.
+		return bytes.Clone(chunk.Data), nil
+	}
 
-		_ = bar.Add(n)
+	err = w.WriteMany(ctx, frameSource, seekable.WithWriteCallback(func(size uint32) {
+		_ = bar.Add(int(size))
+	}))
+	if err != nil {
+		logger.Fatal("failed to write data", zap.Error(err))
 	}
+	_ = bar.Finish()
 
 	input.Close()
 	w.Close()
diff --git a/pkg/encoder.go b/pkg/encoder.go
index 7b343bd..b1a976d 100644
--- a/pkg/encoder.go
+++ b/pkg/encoder.go
@@ -25,32 +25,40 @@ func NewEncoder(encoder ZSTDEncoder, opts ...wOption) (Encoder, error) {
 	return sw.(*writerImpl), err
 }
 
-func (s *writerImpl) Encode(src []byte) ([]byte, error) {
+func (s *writerImpl) encodeOne(src []byte) ([]byte, seekTableEntry, error) {
 	if int64(len(src)) > maxChunkSize {
-		return nil, fmt.Errorf("chunk size too big for seekable format: %d > %d",
-			len(src), maxChunkSize)
+		return nil, seekTableEntry{},
+			fmt.Errorf("chunk size too big for seekable format: %d > %d",
+				len(src), maxChunkSize)
 	}
 
	if len(src) == 0 {
-		return nil, nil
+		return nil, seekTableEntry{}, nil
 	}
 
 	dst := s.enc.EncodeAll(src, nil)
 	if int64(len(dst)) > maxChunkSize {
-		return nil, fmt.Errorf("result size too big for seekable format: %d > %d",
-			len(src), maxChunkSize)
+		return nil, seekTableEntry{},
+			fmt.Errorf("result size too big for seekable format: %d > %d",
+				len(dst), maxChunkSize)
 	}
 
-	entry := seekTableEntry{
+	return dst, seekTableEntry{
 		CompressedSize:   uint32(len(dst)),
 		DecompressedSize: uint32(len(src)),
 		Checksum:         uint32((xxhash.Sum64(src) << 32) >> 32),
+	}, nil
+}
+
+func (s *writerImpl) Encode(src []byte) ([]byte, error) {
+	dst, entry, err := s.encodeOne(src)
+	if err != nil {
+		return nil, err
 	}
 
 	s.logger.Debug("appending frame", zap.Object("frame", &entry))
 	s.frameEntries = append(s.frameEntries, entry)
-
 	return dst, nil
 }
diff --git a/pkg/go.mod b/pkg/go.mod
index 654c635..37cb6a6 100644
--- a/pkg/go.mod
+++ b/pkg/go.mod
@@ -10,6 +10,7 @@ require (
 	go.uber.org/atomic v1.11.0
 	go.uber.org/multierr v1.11.0
 	go.uber.org/zap v1.27.0
+	golang.org/x/sync v0.8.0
 )
 
 require (
diff --git a/pkg/go.sum b/pkg/go.sum
index c81ee09..084b29b 100644
--- a/pkg/go.sum
+++ b/pkg/go.sum
@@ -18,6 +18,8 @@ go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0=
 go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
 go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8=
 go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E=
+golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ=
+golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
 gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
 gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
 gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
diff --git a/pkg/writer.go b/pkg/writer.go
index d00cbc2..b3be5c0 100644
--- a/pkg/writer.go
+++ b/pkg/writer.go
@@ -1,10 +1,14 @@
 package seekable
 
 import (
+	"context"
 	"fmt"
 	"io"
+	"runtime"
 	"sync"
 
+	"golang.org/x/sync/errgroup"
+
 	"go.uber.org/multierr"
 	"go.uber.org/zap"
 
@@ -53,6 +57,18 @@ type Writer interface {
 	Close() (err error)
 }
 
+// FrameSource returns one frame of data at a time.
+// When there are no more frames, it returns nil.
+type FrameSource func() ([]byte, error)
+
+// ConcurrentWriter allows writing many frames concurrently.
+type ConcurrentWriter interface {
+	Writer
+
+	// WriteMany writes many frames concurrently.
+	WriteMany(ctx context.Context, frameSource FrameSource, options ...WriteManyOption) error
+}
+
 // ZSTDEncoder is the compressor. Tested with github.com/klauspost/compress/zstd.
 type ZSTDEncoder interface {
 	EncodeAll(src, dst []byte) []byte
@@ -60,7 +76,7 @@ type ZSTDEncoder interface {
 
 // NewWriter wraps the passed io.Writer and Encoder into an indexed ZSTD stream.
 // The resulting stream can then be randomly accessed through the Reader and Decoder interfaces.
-func NewWriter(w io.Writer, encoder ZSTDEncoder, opts ...wOption) (Writer, error) {
+func NewWriter(w io.Writer, encoder ZSTDEncoder, opts ...wOption) (ConcurrentWriter, error) {
 	sw := writerImpl{
 		once: &sync.Once{},
 		enc:  encoder,
@@ -107,6 +123,110 @@ func (s *writerImpl) Close() (err error) {
 	return
 }
 
+type encodeResult struct {
+	buf   []byte
+	entry seekTableEntry
+}
+
+func (s *writerImpl) writeManyEncoder(ctx context.Context, ch chan<- encodeResult, frame []byte) func() error {
+	return func() error {
+		dst, entry, err := s.encodeOne(frame)
+		if err != nil {
+			return fmt.Errorf("failed to encode frame: %w", err)
+		}
+
+		select {
+		case <-ctx.Done():
+		case ch <- encodeResult{dst, entry}:
+			// Fulfill our promise.
+			close(ch)
+		}
+
+		return nil
+	}
+}
+
+func (s *writerImpl) writeManyProducer(ctx context.Context, frameSource FrameSource, g *errgroup.Group, queue chan<- chan encodeResult) func() error {
+	return func() error {
+		for {
+			frame, err := frameSource()
+			if err != nil {
+				return fmt.Errorf("frame source failed: %w", err)
+			}
+			if frame == nil {
+				close(queue)
+				return nil
+			}
+
+			// Put a channel on the queue as a sort of promise.
+			// This is a nice trick to keep our results ordered, even when compression
+			// completes out-of-order.
+			ch := make(chan encodeResult)
+			select {
+			case <-ctx.Done():
+				return nil
+			case queue <- ch:
+			}
+
+			g.Go(s.writeManyEncoder(ctx, ch, frame))
+		}
+	}
+}
+
+func (s *writerImpl) writeManyConsumer(ctx context.Context, callback func(uint32), queue <-chan chan encodeResult) func() error {
+	return func() error {
+		for {
+			var ch <-chan encodeResult
+			select {
+			case <-ctx.Done():
+				return nil
+			case ch = <-queue:
+			}
+			if ch == nil {
+				return nil
+			}
+
+			// Wait for the block to be complete.
+			var result encodeResult
+			select {
+			case <-ctx.Done():
+				return nil
+			case result = <-ch:
+			}
+
+			n, err := s.env.WriteFrame(result.buf)
+			if err != nil {
+				return fmt.Errorf("failed to write compressed data: %w", err)
+			}
+			if n != len(result.buf) {
+				return fmt.Errorf("partial write: %d out of %d", n, len(result.buf))
+			}
+			s.frameEntries = append(s.frameEntries, result.entry)
+
+			if callback != nil {
+				callback(result.entry.DecompressedSize)
+			}
+		}
+	}
+}
+
+func (s *writerImpl) WriteMany(ctx context.Context, frameSource FrameSource, options ...WriteManyOption) error {
+	opts := writeManyOptions{concurrency: runtime.GOMAXPROCS(0)}
+	for _, o := range options {
+		if err := o(&opts); err != nil {
+			return err // no wrap, these should be user-comprehensible
+		}
+	}
+
+	g, gCtx := errgroup.WithContext(ctx)
+	g.SetLimit(opts.concurrency + 2) // +2 for the producer and consumer goroutines
+	// Add extra room in the queue, so we can keep throughput high even if blocks finish out of order.
+	queue := make(chan chan encodeResult, opts.concurrency*2)
+	g.Go(s.writeManyProducer(gCtx, frameSource, g, queue))
+	g.Go(s.writeManyConsumer(gCtx, opts.writeCallback, queue))
+	return g.Wait()
+}
+
 func (s *writerImpl) writeSeekTable() error {
 	seekTableBytes, err := s.EndStream()
 	if err != nil {
diff --git a/pkg/writer_options.go b/pkg/writer_options.go
index 1d21bf1..7b56aaa 100644
--- a/pkg/writer_options.go
+++ b/pkg/writer_options.go
@@ -1,6 +1,8 @@
 package seekable
 
 import (
+	"fmt"
+
 	"go.uber.org/zap"
 
 	"github.com/SaveTheRbtz/zstd-seekable-format-go/pkg/env"
@@ -15,3 +17,27 @@ func WithWLogger(l *zap.Logger) wOption {
 func WithWEnvironment(e env.WEnvironment) wOption {
 	return func(w *writerImpl) error { w.env = e; return nil }
 }
+
+type writeManyOptions struct {
+	concurrency   int
+	writeCallback func(uint32)
+}
+
+type WriteManyOption func(options *writeManyOptions) error
+
+func WithConcurrency(concurrency int) WriteManyOption {
+	return func(options *writeManyOptions) error {
+		if concurrency < 1 {
+			return fmt.Errorf("concurrency must be positive: %d", concurrency)
+		}
+		options.concurrency = concurrency
+		return nil
+	}
+}
+
+func WithWriteCallback(cb func(size uint32)) WriteManyOption {
+	return func(options *writeManyOptions) error {
+		options.writeCallback = cb
+		return nil
+	}
+}
diff --git a/pkg/writer_test.go b/pkg/writer_test.go
index 3337b52..ae8c75d 100644
--- a/pkg/writer_test.go
+++ b/pkg/writer_test.go
@@ -2,8 +2,10 @@ package seekable
 
 import (
 	"bytes"
+	"context"
 	"crypto/rand"
 	"encoding/binary"
+	"errors"
 	"fmt"
 	"io"
 	"testing"
@@ -70,6 +72,139 @@ func TestWriter(t *testing.T) {
 	assert.Equal(t, concat, readBuf[:n])
 }
 
+func makeTestFrame(t *testing.T, idx int) []byte {
+	var b bytes.Buffer
+	for i := 0; i < 100; i++ {
+		s := fmt.Sprintf("test%d", idx+i)
+		_, err := b.WriteString(s)
+		require.NoError(t, err)
+	}
+	return b.Bytes()
+}
+
+func makeTestFrameSource(frames [][]byte) FrameSource {
+	idx := 0
+	return func() ([]byte, error) {
+		if idx >= len(frames) {
+			return nil, nil
+		}
+		ret := frames[idx]
+		idx++
+		return ret, nil
+	}
+}
+
+func TestConcurrentWriter(t *testing.T) {
+	t.Parallel()
+
+	ctx := context.Background()
+
+	enc, err := zstd.NewWriter(nil, zstd.WithEncoderLevel(zstd.SpeedFastest))
+	require.NoError(t, err)
+
+	// Set up test data.
+	const frameCount = 20
+	var frames [][]byte
+	var concat []byte
+	for i := 0; i < frameCount; i++ {
+		frame := makeTestFrame(t, i)
+		frames = append(frames, frame)
+		concat = append(concat, frame...)
+	}
+
+	// Write concurrently.
+	var b bytes.Buffer
+	bw := io.Writer(&b)
+	concurrentWriter, err := NewWriter(bw, enc)
+	require.NoError(t, err)
+
+	var totalWritten int
+	err = concurrentWriter.WriteMany(ctx, makeTestFrameSource(frames), WithConcurrency(5),
+		WithWriteCallback(func(size uint32) {
+			totalWritten += int(size)
+		}))
+	require.NoError(t, err)
+	require.Equal(t, len(concat), totalWritten)
+
+	// Write one frame at a time.
+	var nb bytes.Buffer
+	nbw := io.Writer(&nb)
+	oneWriter, err := NewWriter(nbw, enc)
+	require.NoError(t, err)
+
+	for i := 0; i < frameCount; i++ {
+		require.NoError(t, err)
+		_, err = oneWriter.Write(frames[i])
+		require.NoError(t, err)
+	}
+
+	// Output should be the same.
+	assert.Equal(t, b.Bytes(), nb.Bytes())
+
+	concurrentImpl := concurrentWriter.(*writerImpl)
+	oneImpl := oneWriter.(*writerImpl)
+	assert.Equal(t, concurrentImpl.frameEntries, oneImpl.frameEntries)
+
+	// Test decompression.
+	dec, err := zstd.NewReader(nil)
+	require.NoError(t, err)
+	decoded, err := dec.DecodeAll(b.Bytes(), nil)
+	require.NoError(t, err)
+	assert.Equal(t, concat, decoded)
+}
+
+type failingWriteEnvironment struct {
+	n   int
+	err error
+}
+
+func (e failingWriteEnvironment) WriteFrame(p []byte) (n int, err error) {
+	return e.n, e.err
+}
+
+func (e failingWriteEnvironment) WriteSeekTable(p []byte) (n int, err error) {
+	return e.n, e.err
+}
+
+func TestConcurrentWriterErrors(t *testing.T) {
+	t.Parallel()
+
+	manyFrames := [][]byte{}
+	for i := 0; i < 100; i++ {
+		manyFrames = append(manyFrames, []byte(fmt.Sprintf("test%d", i)))
+	}
+
+	ctx := context.Background()
+	enc, err := zstd.NewWriter(nil, zstd.WithEncoderLevel(zstd.SpeedFastest))
+	require.NoError(t, err)
+	w, err := NewWriter(nil, enc)
+	require.NoError(t, err)
+
+	frameSource := makeTestFrameSource([][]byte{})
+	err = w.WriteMany(ctx, frameSource, WithConcurrency(0))
+	assert.ErrorContains(t, err, "concurrency must be positive")
+
+	frameSource = func() ([]byte, error) {
+		return nil, errors.New("test error")
+	}
+	err = w.WriteMany(ctx, frameSource)
+	assert.ErrorContains(t, err, "frame source failed: test error")
+
+	var b bytes.Buffer
+	w, err = NewWriter(&b, enc,
+		WithWEnvironment(failingWriteEnvironment{0, errors.New("test error")}))
+	require.NoError(t, err)
+	frameSource = makeTestFrameSource(manyFrames) // enough that we have to wait on ctx
+	err = w.WriteMany(ctx, frameSource, WithConcurrency(1))
+	assert.ErrorContains(t, err, "failed to write compressed data")
+
+	w, err = NewWriter(&b, enc,
+		WithWEnvironment(failingWriteEnvironment{1, nil}))
+	require.NoError(t, err)
+	err = w.WriteMany(ctx, frameSource, WithConcurrency(1))
+	assert.ErrorContains(t, err, "partial write")
+}
+
 type fakeWriteEnvironment struct {
 	bw io.Writer
 }
@@ -117,7 +252,26 @@ func TestWriteEnvironment(t *testing.T) {
 	assert.Equal(t, concat, readBuf[:n])
 }
 
+func makeRepeatingFrameSource(frame []byte, count int) FrameSource {
+	idx := 0
+	return func() ([]byte, error) {
+		if idx >= count {
+			return nil, nil
+		}
+		idx++
+		return frame, nil
+	}
+}
+
+type nullWriter struct{}
+
+func (nullWriter) Write(p []byte) (n int, err error) {
+	return len(p), nil
+}
+
 func BenchmarkWrite(b *testing.B) {
+	ctx := context.Background()
+
 	enc, err := zstd.NewWriter(nil, zstd.WithEncoderLevel(zstd.SpeedFastest))
 	require.NoError(b, err)
 
@@ -126,20 +280,26 @@ func BenchmarkWrite(b *testing.B) {
 		writeBuf := make([]byte, sz)
 		_, err := rand.Read(writeBuf)
 		require.NoError(b, err)
-		var buf bytes.Buffer
-		bw := io.Writer(&buf)
-		w, err := NewWriter(bw, enc)
+
+		w, err := NewWriter(nullWriter{}, enc)
 		require.NoError(b, err)
 
 		b.Run(fmt.Sprintf("%d", sz), func(b *testing.B) {
 			b.SetBytes(sz)
 			b.ResetTimer()
-			// TODO: Limit memory consumption.
 			for i := 0; i < b.N; i++ {
 				_, _ = w.Write(writeBuf)
 			}
 		})
 
+		b.Run(fmt.Sprintf("Parallel-%d", sz), func(b *testing.B) {
+			b.SetBytes(sz)
+			b.ResetTimer()
+
+			err = w.WriteMany(ctx, makeRepeatingFrameSource(writeBuf, b.N))
+			require.NoError(b, err)
+		})
+
 		err = w.Close()
 		require.NoError(b, err)
 	}
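
Usage sketch (reviewer aid, not part of the change): the stand-alone program below shows how a caller might drive the new WriteMany API. The import path matches this repository's pkg/ module; the payload, frame size, and panic-based error handling are illustrative assumptions only.

// Minimal sketch of driving ConcurrentWriter.WriteMany; the payload and
// chunking scheme here are hypothetical, not taken from this diff.
package main

import (
	"bytes"
	"context"

	"github.com/klauspost/compress/zstd"
	seekable "github.com/SaveTheRbtz/zstd-seekable-format-go/pkg"
)

func main() {
	enc, err := zstd.NewWriter(nil, zstd.WithEncoderLevel(zstd.SpeedFastest))
	if err != nil {
		panic(err)
	}

	var out bytes.Buffer
	// After this change, NewWriter returns a ConcurrentWriter.
	w, err := seekable.NewWriter(&out, enc)
	if err != nil {
		panic(err)
	}

	// Serve fixed-size frames from an in-memory payload;
	// returning nil signals that there are no more frames.
	payload := bytes.Repeat([]byte("hello seekable zstd "), 1<<16)
	const frameSize = 1 << 20
	frameSource := func() ([]byte, error) {
		if len(payload) == 0 {
			return nil, nil
		}
		n := frameSize
		if n > len(payload) {
			n = len(payload)
		}
		frame := payload[:n]
		payload = payload[n:]
		return frame, nil
	}

	// Frames are compressed on up to 4 goroutines but written in input order.
	if err := w.WriteMany(context.Background(), frameSource,
		seekable.WithConcurrency(4)); err != nil {
		panic(err)
	}
	// Close ends the stream and writes the seek table.
	if err := w.Close(); err != nil {
		panic(err)
	}
}

Because the producer enqueues one result channel per frame before spawning its encoder, compressed frames are written strictly in source order even when compression finishes out of order; this is why TestConcurrentWriter can require byte-identical output between the concurrent and sequential paths.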