Merge pull request #167 from vasi/parallel
Add parallel compression
SaveTheRbtz authored Oct 1, 2024
2 parents 92d1f81 + 4e1ed4e commit 0e32fcc
Showing 7 changed files with 345 additions and 21 deletions.
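The headline change is a new WriteMany method on the writer that compresses frames on several goroutines while still writing them (and their seek-table entries) in submission order. A minimal usage sketch under the assumptions visible in this diff — the package path, FrameSource signature, and option names come from the changed files below; the input data and error handling are purely illustrative:

    package main

    import (
        "context"
        "log"
        "os"

        "github.com/klauspost/compress/zstd"
        seekable "github.com/SaveTheRbtz/zstd-seekable-format-go/pkg"
    )

    func main() {
        out, err := os.Create("out.zst")
        if err != nil {
            log.Fatal(err)
        }
        defer out.Close()

        enc, err := zstd.NewWriter(nil)
        if err != nil {
            log.Fatal(err)
        }

        w, err := seekable.NewWriter(out, enc)
        if err != nil {
            log.Fatal(err)
        }

        // Feed frames one at a time; returning (nil, nil) signals end of input.
        frames := [][]byte{[]byte("hello "), []byte("world")}
        i := 0
        frameSource := func() ([]byte, error) {
            if i == len(frames) {
                return nil, nil
            }
            f := frames[i]
            i++
            return f, nil
        }

        if err := w.WriteMany(context.Background(), frameSource,
            seekable.WithConcurrency(4)); err != nil {
            log.Fatal(err)
        }
        if err := w.Close(); err != nil { // writes the seek table
            log.Fatal(err)
        }
    }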
23 changes: 15 additions & 8 deletions cmd/zstdseek/main.go
@@ -2,6 +2,7 @@ package main

import (
"bytes"
"context"
"crypto/sha512"
"errors"
"flag"
@@ -26,6 +27,8 @@ type readCloser struct {
}

func main() {
ctx := context.Background()

var (
inputFlag, chunkingFlag, outputFlag string
qualityFlag int
@@ -156,21 +159,25 @@ func main() {
logger.Fatal("failed to create chunker", zap.Error(err))
}

for {
frameSource := func() ([]byte, error) {
chunk, err := chunker.Next()
if err != nil {
if errors.Is(err, io.EOF) {
break
return nil, nil
}
logger.Fatal("failed to read", zap.Error(err))
}
n, err := w.Write(chunk.Data)
if err != nil {
logger.Fatal("failed to write data", zap.Error(err))
return nil, err
}
// Chunker invalidates the data after calling Next, so we need to clone it
return bytes.Clone(chunk.Data), nil
}

_ = bar.Add(n)
err = w.WriteMany(ctx, frameSource, seekable.WithWriteCallback(func(size uint32) {
_ = bar.Add(int(size))
}))
if err != nil {
logger.Fatal("failed to write data", zap.Error(err))
}

_ = bar.Finish()
input.Close()
w.Close()
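Note the bytes.Clone above: the chunker invalidates its buffer on every Next call, so each frame must be copied before it is handed to concurrent encoders. For other tools, a FrameSource can wrap any io.Reader; a hypothetical helper (not part of this PR, assumes the io and seekable imports) might look like:

    // Hypothetical helper, not part of this PR: slice an io.Reader into
    // fixed-size frames. Each call allocates a fresh buffer, so no clone
    // is needed.
    func readerFrameSource(r io.Reader, frameSize int) seekable.FrameSource {
        return func() ([]byte, error) {
            buf := make([]byte, frameSize)
            n, err := io.ReadFull(r, buf)
            if err == io.EOF {
                return nil, nil // no more frames
            }
            if err != nil && err != io.ErrUnexpectedEOF {
                return nil, err
            }
            return buf[:n], nil
        }
    }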
24 changes: 16 additions & 8 deletions pkg/encoder.go
@@ -25,32 +25,40 @@ func NewEncoder(encoder ZSTDEncoder, opts ...wOption) (Encoder, error) {
return sw.(*writerImpl), err
}

func (s *writerImpl) Encode(src []byte) ([]byte, error) {
func (s *writerImpl) encodeOne(src []byte) ([]byte, seekTableEntry, error) {
if int64(len(src)) > maxChunkSize {
return nil, fmt.Errorf("chunk size too big for seekable format: %d > %d",
len(src), maxChunkSize)
return nil, seekTableEntry{},
fmt.Errorf("chunk size too big for seekable format: %d > %d",
len(src), maxChunkSize)
}

if len(src) == 0 {
return nil, nil
return nil, seekTableEntry{}, nil
}

dst := s.enc.EncodeAll(src, nil)

if int64(len(dst)) > maxChunkSize {
return nil, fmt.Errorf("result size too big for seekable format: %d > %d",
len(src), maxChunkSize)
return nil, seekTableEntry{},
fmt.Errorf("result size too big for seekable format: %d > %d",
len(src), maxChunkSize)
}

entry := seekTableEntry{
return dst, seekTableEntry{
CompressedSize: uint32(len(dst)),
DecompressedSize: uint32(len(src)),
Checksum: uint32((xxhash.Sum64(src) << 32) >> 32),
}, nil
}

func (s *writerImpl) Encode(src []byte) ([]byte, error) {
dst, entry, err := s.encodeOne(src)
if err != nil {
return nil, err
}

s.logger.Debug("appending frame", zap.Object("frame", &entry))
s.frameEntries = append(s.frameEntries, entry)

return dst, nil
}

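The Checksum expression above keeps only the low 32 bits of the 64-bit xxhash, since the seekable format's checksum field is 32 bits wide. An equivalent, arguably more explicit spelling (illustration only, not a change to the code):

    // Equivalent to uint32((xxhash.Sum64(src) << 32) >> 32):
    // keep the least-significant 32 bits of the 64-bit hash.
    checksum := uint32(xxhash.Sum64(src) & 0xFFFFFFFF)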
1 change: 1 addition & 0 deletions pkg/go.mod
@@ -10,6 +10,7 @@ require (
go.uber.org/atomic v1.11.0
go.uber.org/multierr v1.11.0
go.uber.org/zap v1.27.0
golang.org/x/sync v0.8.0
)

require (
2 changes: 2 additions & 0 deletions pkg/go.sum
@@ -18,6 +18,8 @@ go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0=
go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8=
go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E=
golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ=
golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
122 changes: 121 additions & 1 deletion pkg/writer.go
@@ -1,10 +1,14 @@
package seekable

import (
"context"
"fmt"
"io"
"runtime"
"sync"

"golang.org/x/sync/errgroup"

"go.uber.org/multierr"
"go.uber.org/zap"

@@ -53,14 +57,26 @@ type Writer interface {
Close() (err error)
}

// FrameSource returns one frame of data at a time.
// When there are no more frames, returns nil.
type FrameSource func() ([]byte, error)

// ConcurrentWriter allows writing many frames concurrently
type ConcurrentWriter interface {
Writer

// WriteMany writes many frames concurrently
WriteMany(ctx context.Context, frameSource FrameSource, options ...WriteManyOption) error
}

// ZSTDEncoder is the compressor. Tested with github.com/klauspost/compress/zstd.
type ZSTDEncoder interface {
EncodeAll(src, dst []byte) []byte
}

// NewWriter wraps the passed io.Writer and Encoder into an indexed ZSTD stream.
// Resulting stream then can be randomly accessed through the Reader and Decoder interfaces.
func NewWriter(w io.Writer, encoder ZSTDEncoder, opts ...wOption) (Writer, error) {
func NewWriter(w io.Writer, encoder ZSTDEncoder, opts ...wOption) (ConcurrentWriter, error) {
sw := writerImpl{
once: &sync.Once{},
enc: encoder,
@@ -107,6 +123,110 @@ func (s *writerImpl) Close() (err error) {
return
}

type encodeResult struct {
buf []byte
entry seekTableEntry
}

func (s *writerImpl) writeManyEncoder(ctx context.Context, ch chan<- encodeResult, frame []byte) func() error {
return func() error {
dst, entry, err := s.encodeOne(frame)
if err != nil {
return fmt.Errorf("failed to encode frame: %w", err)
}

select {
case <-ctx.Done():
// Fulfill our promise
case ch <- encodeResult{dst, entry}:
close(ch)
}

return nil
}
}

func (s *writerImpl) writeManyProducer(ctx context.Context, frameSource FrameSource, g *errgroup.Group, queue chan<- chan encodeResult) func() error {
return func() error {
for {
frame, err := frameSource()
if err != nil {
return fmt.Errorf("frame source failed: %w", err)
}
if frame == nil {
close(queue)
return nil
}

// Put a channel on the queue as a sort of promise.
// This is a nice trick to keep our results ordered, even when compression
// completes out-of-order.
ch := make(chan encodeResult)
select {
case <-ctx.Done():
return nil
case queue <- ch:
}

g.Go(s.writeManyEncoder(ctx, ch, frame))
}
}
}

func (s *writerImpl) writeManyConsumer(ctx context.Context, callback func(uint32), queue <-chan chan encodeResult) func() error {
return func() error {
for {
var ch <-chan encodeResult
select {
case <-ctx.Done():
return nil
case ch = <-queue:
}
if ch == nil {
return nil
}

// Wait for the block to be complete
var result encodeResult
select {
case <-ctx.Done():
return nil
case result = <-ch:
}

n, err := s.env.WriteFrame(result.buf)
if err != nil {
return fmt.Errorf("failed to write compressed data: %w", err)
}
if n != len(result.buf) {
return fmt.Errorf("partial write: %d out of %d", n, len(result.buf))
}
s.frameEntries = append(s.frameEntries, result.entry)

if callback != nil {
callback(result.entry.DecompressedSize)
}
}
}
}

func (s *writerImpl) WriteMany(ctx context.Context, frameSource FrameSource, options ...WriteManyOption) error {
opts := writeManyOptions{concurrency: runtime.GOMAXPROCS(0)}
for _, o := range options {
if err := o(&opts); err != nil {
return err // no wrap, these should be user-comprehensible
}
}

g, gCtx := errgroup.WithContext(ctx)
g.SetLimit(opts.concurrency + 2) // reader and writer
// Add extra room in the queue, so we can keep throughput high even if blocks finish out of order
queue := make(chan chan encodeResult, opts.concurrency*2)
g.Go(s.writeManyProducer(gCtx, frameSource, g, queue))
g.Go(s.writeManyConsumer(gCtx, opts.writeCallback, queue))
return g.Wait()
}

func (s *writerImpl) writeSeekTable() error {
seekTableBytes, err := s.EndStream()
if err != nil {
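The ordering guarantee in WriteMany comes from the channel-of-channels queue: the producer enqueues one result channel ("promise") per frame in submission order, encoders fulfill their channel whenever they finish, and the single consumer drains the queue in order. Buffering the queue at concurrency*2 keeps encoders busy even when frames complete out of order. A standalone toy illustration of the same pattern — not the library's code; it squares integers instead of compressing frames:

    package main

    import (
        "fmt"
        "math/rand"
        "time"
    )

    func main() {
        inputs := []int{1, 2, 3, 4, 5}
        queue := make(chan chan int, len(inputs))

        // Producer: enqueue one "promise" channel per input, in order,
        // and start a worker that will fulfill it.
        go func() {
            for _, v := range inputs {
                ch := make(chan int, 1)
                queue <- ch
                go func(v int, ch chan int) {
                    time.Sleep(time.Duration(rand.Intn(50)) * time.Millisecond)
                    ch <- v * v // workers may finish out of order
                }(v, ch)
            }
            close(queue)
        }()

        // Consumer: drain the promises in submission order.
        for ch := range queue {
            fmt.Println(<-ch) // always prints 1 4 9 16 25, in order
        }
    }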
26 changes: 26 additions & 0 deletions pkg/writer_options.go
@@ -1,6 +1,8 @@
package seekable

import (
"fmt"

"go.uber.org/zap"

"github.com/SaveTheRbtz/zstd-seekable-format-go/pkg/env"
@@ -15,3 +17,27 @@ func WithWLogger(l *zap.Logger) wOption {
func WithWEnvironment(e env.WEnvironment) wOption {
return func(w *writerImpl) error { w.env = e; return nil }
}

type writeManyOptions struct {
concurrency int
writeCallback func(uint32)
}

type WriteManyOption func(options *writeManyOptions) error

func WithConcurrency(concurrency int) WriteManyOption {
return func(options *writeManyOptions) error {
if concurrency < 1 {
return fmt.Errorf("concurrency must be positive: %d", concurrency)
}
options.concurrency = concurrency
return nil
}
}

func WithWriteCallback(cb func(size uint32)) WriteManyOption {
return func(options *writeManyOptions) error {
options.writeCallback = cb
return nil
}
}
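Taken together, a caller might pass both new options to WriteMany. The fragment below is illustrative only (it reuses w, ctx, and frameSource from the sketch near the top of this page; the written counter is not part of the PR); per the consumer code above, the callback receives each frame's decompressed size and is invoked sequentially from the single consumer goroutine:

    // Illustrative only: cap parallelism and track progress via the callback.
    var written uint64
    err := w.WriteMany(ctx, frameSource,
        seekable.WithConcurrency(runtime.GOMAXPROCS(0)),
        seekable.WithWriteCallback(func(size uint32) {
            written += uint64(size) // safe: called from one goroutine
        }),
    )
    if err != nil {
        log.Fatal(err)
    }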