Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add parallel compression #167

Merged
merged 12 commits into from
Oct 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 15 additions & 8 deletions cmd/zstdseek/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package main

import (
"bytes"
"context"
"crypto/sha512"
"errors"
"flag"
Expand All @@ -26,6 +27,8 @@ type readCloser struct {
}

func main() {
ctx := context.Background()

var (
inputFlag, chunkingFlag, outputFlag string
qualityFlag int
Expand Down Expand Up @@ -156,21 +159,25 @@ func main() {
logger.Fatal("failed to create chunker", zap.Error(err))
}

for {
frameSource := func() ([]byte, error) {
chunk, err := chunker.Next()
if err != nil {
if errors.Is(err, io.EOF) {
break
return nil, nil
}
logger.Fatal("failed to read", zap.Error(err))
}
n, err := w.Write(chunk.Data)
if err != nil {
logger.Fatal("failed to write data", zap.Error(err))
return nil, err
}
// Chunker invalidates the data after calling Next, so we need to clone it
return bytes.Clone(chunk.Data), nil
}

_ = bar.Add(n)
err = w.WriteMany(ctx, frameSource, seekable.WithWriteCallback(func(size uint32) {
_ = bar.Add(int(size))
}))
if err != nil {
logger.Fatal("failed to write data", zap.Error(err))
}

_ = bar.Finish()
input.Close()
w.Close()
Expand Down
24 changes: 16 additions & 8 deletions pkg/encoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,32 +25,40 @@
return sw.(*writerImpl), err
}

func (s *writerImpl) Encode(src []byte) ([]byte, error) {
func (s *writerImpl) encodeOne(src []byte) ([]byte, seekTableEntry, error) {
if int64(len(src)) > maxChunkSize {
return nil, fmt.Errorf("chunk size too big for seekable format: %d > %d",
len(src), maxChunkSize)
return nil, seekTableEntry{},
fmt.Errorf("chunk size too big for seekable format: %d > %d",
len(src), maxChunkSize)

Check warning on line 32 in pkg/encoder.go

View check run for this annotation

Codecov / codecov/patch

pkg/encoder.go#L30-L32

Added lines #L30 - L32 were not covered by tests
}

if len(src) == 0 {
return nil, nil
return nil, seekTableEntry{}, nil
}

dst := s.enc.EncodeAll(src, nil)

if int64(len(dst)) > maxChunkSize {
return nil, fmt.Errorf("result size too big for seekable format: %d > %d",
len(src), maxChunkSize)
return nil, seekTableEntry{},
fmt.Errorf("result size too big for seekable format: %d > %d",
len(src), maxChunkSize)

Check warning on line 44 in pkg/encoder.go

View check run for this annotation

Codecov / codecov/patch

pkg/encoder.go#L42-L44

Added lines #L42 - L44 were not covered by tests
}

entry := seekTableEntry{
return dst, seekTableEntry{
CompressedSize: uint32(len(dst)),
DecompressedSize: uint32(len(src)),
Checksum: uint32((xxhash.Sum64(src) << 32) >> 32),
}, nil
}

func (s *writerImpl) Encode(src []byte) ([]byte, error) {
dst, entry, err := s.encodeOne(src)
if err != nil {
return nil, err

Check warning on line 57 in pkg/encoder.go

View check run for this annotation

Codecov / codecov/patch

pkg/encoder.go#L57

Added line #L57 was not covered by tests
}

s.logger.Debug("appending frame", zap.Object("frame", &entry))
s.frameEntries = append(s.frameEntries, entry)

return dst, nil
}

Expand Down
1 change: 1 addition & 0 deletions pkg/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ require (
go.uber.org/atomic v1.11.0
go.uber.org/multierr v1.11.0
go.uber.org/zap v1.27.0
golang.org/x/sync v0.8.0
)

require (
Expand Down
2 changes: 2 additions & 0 deletions pkg/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0=
go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8=
go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E=
golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ=
golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
Expand Down
122 changes: 121 additions & 1 deletion pkg/writer.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
package seekable

import (
"context"
"fmt"
"io"
"runtime"
"sync"

"golang.org/x/sync/errgroup"

"go.uber.org/multierr"
"go.uber.org/zap"

Expand Down Expand Up @@ -53,14 +57,26 @@
Close() (err error)
}

// FrameSource returns one frame of data at a time.
// When there are no more frames, it returns a nil slice and a nil error.
// A non-nil error stops WriteMany, which reports it to its caller.
//
// WriteMany passes each returned slice to an encoder goroutine and then calls
// the source again, so a returned slice must stay valid and unmodified after
// the call returns (clone the data if the underlying buffer is reused).
type FrameSource func() ([]byte, error)

// ConcurrentWriter is a Writer that can additionally compress and write many
// frames concurrently.
type ConcurrentWriter interface {
Writer

// WriteMany pulls frames from frameSource until it returns a nil frame,
// compresses them concurrently and writes the results in the order the
// frames were produced. Behaviour can be tuned with WriteManyOption values
// such as WithConcurrency and WithWriteCallback.
WriteMany(ctx context.Context, frameSource FrameSource, options ...WriteManyOption) error
}

// ZSTDEncoder is the compressor. Tested with github.com/klauspost/compress/zstd.
//
// WriteMany calls EncodeAll on the same encoder from several goroutines at
// once, so an implementation used with WriteMany must tolerate concurrent
// EncodeAll calls.
type ZSTDEncoder interface {
// EncodeAll compresses src and returns the compressed bytes.
// This package passes nil as dst.
EncodeAll(src, dst []byte) []byte
}

// NewWriter wraps the passed io.Writer and Encoder into an indexed ZSTD stream.
// The resulting stream can then be randomly accessed through the Reader and Decoder interfaces.
func NewWriter(w io.Writer, encoder ZSTDEncoder, opts ...wOption) (Writer, error) {
func NewWriter(w io.Writer, encoder ZSTDEncoder, opts ...wOption) (ConcurrentWriter, error) {
sw := writerImpl{
once: &sync.Once{},
enc: encoder,
Expand Down Expand Up @@ -107,6 +123,110 @@
return
}

// encodeResult is what an encoder goroutine hands to the WriteMany consumer
// for a single frame.
type encodeResult struct {
// buf is the compressed frame returned by encodeOne.
buf []byte
// entry is the seek table entry encodeOne built for the same frame.
entry seekTableEntry
}

func (s *writerImpl) writeManyEncoder(ctx context.Context, ch chan<- encodeResult, frame []byte) func() error {
return func() error {
SaveTheRbtz marked this conversation as resolved.
Show resolved Hide resolved
dst, entry, err := s.encodeOne(frame)
if err != nil {
return fmt.Errorf("failed to encode frame: %w", err)

Check warning on line 135 in pkg/writer.go

View check run for this annotation

Codecov / codecov/patch

pkg/writer.go#L135

Added line #L135 was not covered by tests
}

select {
case <-ctx.Done():
// Fulfill our promise
case ch <- encodeResult{dst, entry}:
close(ch)
}

return nil
}
}

func (s *writerImpl) writeManyProducer(ctx context.Context, frameSource FrameSource, g *errgroup.Group, queue chan<- chan encodeResult) func() error {
return func() error {
for {
SaveTheRbtz marked this conversation as resolved.
Show resolved Hide resolved
frame, err := frameSource()
if err != nil {
return fmt.Errorf("frame source failed: %w", err)
}
if frame == nil {
close(queue)
return nil
}

// Put a channel on the queue as a sort of promise.
// This is a nice trick to keep our results ordered, even when compression
// completes out-of-order.
ch := make(chan encodeResult)
select {
case <-ctx.Done():
return nil
case queue <- ch:
}

g.Go(s.writeManyEncoder(ctx, ch, frame))
}
}
}

func (s *writerImpl) writeManyConsumer(ctx context.Context, callback func(uint32), queue <-chan chan encodeResult) func() error {
return func() error {
for {
var ch <-chan encodeResult
select {
case <-ctx.Done():
return nil
case ch = <-queue:
}
if ch == nil {
return nil
}

// Wait for the block to be complete
var result encodeResult
select {
case <-ctx.Done():
return nil

Check warning on line 193 in pkg/writer.go

View check run for this annotation

Codecov / codecov/patch

pkg/writer.go#L192-L193

Added lines #L192 - L193 were not covered by tests
case result = <-ch:
SaveTheRbtz marked this conversation as resolved.
Show resolved Hide resolved
}

n, err := s.env.WriteFrame(result.buf)
if err != nil {
return fmt.Errorf("failed to write compressed data: %w", err)
}
if n != len(result.buf) {
return fmt.Errorf("partial write: %d out of %d", n, len(result.buf))
}
s.frameEntries = append(s.frameEntries, result.entry)

if callback != nil {
callback(result.entry.DecompressedSize)
}
}
}
}

// WriteMany reads frames from frameSource until it returns a nil frame,
// compresses them concurrently and writes them in the order they were
// produced.
//
// The default concurrency is runtime.GOMAXPROCS(0); see WithConcurrency and
// WithWriteCallback for the available options. An error returned by an option
// is passed back unwrapped. A nil frameSource is rejected with an error
// instead of panicking inside a worker goroutine. Otherwise the first error
// from the source, an encoder or the underlying write is returned.
func (s *writerImpl) WriteMany(ctx context.Context, frameSource FrameSource, options ...WriteManyOption) error {
	if frameSource == nil {
		return fmt.Errorf("frame source must not be nil")
	}

	opts := writeManyOptions{concurrency: runtime.GOMAXPROCS(0)}
	for _, o := range options {
		if err := o(&opts); err != nil {
			return err // no wrap, these should be user-comprehensible
		}
	}

	g, gCtx := errgroup.WithContext(ctx)
	g.SetLimit(opts.concurrency + 2) // reader and writer
	// Add extra room in the queue, so we can keep throughput high even if blocks finish out of order
	queue := make(chan chan encodeResult, opts.concurrency*2)
	g.Go(s.writeManyProducer(gCtx, frameSource, g, queue))
	g.Go(s.writeManyConsumer(gCtx, opts.writeCallback, queue))
	return g.Wait()
}

func (s *writerImpl) writeSeekTable() error {
seekTableBytes, err := s.EndStream()
if err != nil {
Expand Down
26 changes: 26 additions & 0 deletions pkg/writer_options.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package seekable

import (
"fmt"

"go.uber.org/zap"

"github.com/SaveTheRbtz/zstd-seekable-format-go/pkg/env"
Expand All @@ -15,3 +17,27 @@ func WithWLogger(l *zap.Logger) wOption {
func WithWEnvironment(e env.WEnvironment) wOption {
return func(w *writerImpl) error { w.env = e; return nil }
}

// writeManyOptions collects the settings for a single WriteMany call.
type writeManyOptions struct {
	// concurrency is the value set through WithConcurrency.
	concurrency int
	// writeCallback is the function set through WithWriteCallback; it may be nil.
	writeCallback func(uint32)
}

// WriteManyOption adjusts the settings of a WriteMany call. A non-nil error
// means the option was rejected.
type WriteManyOption func(options *writeManyOptions) error

// WithConcurrency sets the concurrency used by WriteMany. Values below 1 are
// rejected with an error and leave the current setting untouched.
func WithConcurrency(concurrency int) WriteManyOption {
	return func(o *writeManyOptions) error {
		if concurrency >= 1 {
			o.concurrency = concurrency
			return nil
		}
		return fmt.Errorf("concurrency must be positive: %d", concurrency)
	}
}

// WithWriteCallback registers cb as the write callback of a WriteMany call.
// It never fails; a nil cb is stored as is.
func WithWriteCallback(cb func(size uint32)) WriteManyOption {
	apply := func(o *writeManyOptions) error {
		o.writeCallback = cb
		return nil
	}
	return apply
}
Loading
Loading