Skip to content

Commit

Permalink
fix WriteAllContext
Browse files Browse the repository at this point in the history
  • Loading branch information
ungerik committed Nov 16, 2023
1 parent 0419dcd commit 43bcc72
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 3 deletions.
10 changes: 7 additions & 3 deletions utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,11 +95,15 @@ func ReadAllContext(ctx context.Context, r io.Reader) ([]byte, error) {
// with a cancelable context.
func WriteAllContext(ctx context.Context, w io.Writer, data []byte) error {
const chunkSize = 4 * 1024 * 1024 // 4MB
return writeAllContext(ctx, w, data, chunkSize)
}

func writeAllContext(ctx context.Context, w io.Writer, data []byte, chunkSize int) error {
for len(data) > 0 {
if ctx.Err() != nil {
return ctx.Err()
if err := ctx.Err(); err != nil {
return err
}
n, err := w.Write(data[:max(chunkSize, len(data))])
n, err := w.Write(data[:min(chunkSize, len(data))])
if err != nil {
return err
}
Expand Down
70 changes: 70 additions & 0 deletions utils_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
package fs

import (
"bytes"
"context"
"errors"
"testing"
"time"
)

type failContext struct {
errAfter int
err error

counter int
}

func (*failContext) Deadline() (deadline time.Time, ok bool) { return time.Time{}, false }
func (*failContext) Done() <-chan struct{} { return nil }
func (*failContext) Value(key any) any { return nil }

func (f *failContext) Err() error {
f.counter++
if f.counter > f.errAfter {
return f.err
}
return nil
}

func Test_writeAllContext(t *testing.T) {
type args struct {
ctx context.Context
data []byte
chunkSize int
}
tests := []struct {
name string
args args
wantW string
wantErr bool
}{
{name: " chunkSize2", args: args{ctx: context.Background(), data: []byte(""), chunkSize: 2}, wantW: "", wantErr: false},
{name: "1 chunkSize2", args: args{ctx: context.Background(), data: []byte("1"), chunkSize: 2}, wantW: "1", wantErr: false},
{name: "12 chunkSize2", args: args{ctx: context.Background(), data: []byte("12"), chunkSize: 2}, wantW: "12", wantErr: false},
{name: "123 chunkSize2", args: args{ctx: context.Background(), data: []byte("123"), chunkSize: 2}, wantW: "123", wantErr: false},
{name: "1234 chunkSize2", args: args{ctx: context.Background(), data: []byte("1234"), chunkSize: 2}, wantW: "1234", wantErr: false},
{name: "12345 chunkSize2", args: args{ctx: context.Background(), data: []byte("12345"), chunkSize: 2}, wantW: "12345", wantErr: false},

{name: " chunkSize3", args: args{ctx: context.Background(), data: []byte(""), chunkSize: 3}, wantW: "", wantErr: false},
{name: "1 chunkSize3", args: args{ctx: context.Background(), data: []byte("1"), chunkSize: 3}, wantW: "1", wantErr: false},
{name: "12 chunkSize3", args: args{ctx: context.Background(), data: []byte("12"), chunkSize: 3}, wantW: "12", wantErr: false},
{name: "123 chunkSize3", args: args{ctx: context.Background(), data: []byte("123"), chunkSize: 3}, wantW: "123", wantErr: false},
{name: "1234 chunkSize3", args: args{ctx: context.Background(), data: []byte("1234"), chunkSize: 3}, wantW: "1234", wantErr: false},
{name: "12345 chunkSize3", args: args{ctx: context.Background(), data: []byte("12345"), chunkSize: 3}, wantW: "12345", wantErr: false},

{name: "12345 chunkSize2 error", args: args{ctx: &failContext{errAfter: 1, err: errors.New("contextError")}, data: []byte("12345"), chunkSize: 2}, wantW: "12", wantErr: true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w := &bytes.Buffer{}
if err := writeAllContext(tt.args.ctx, w, tt.args.data, tt.args.chunkSize); (err != nil) != tt.wantErr {
t.Errorf("writeAllContext() error = %v, wantErr %v", err, tt.wantErr)
return
}
if gotW := w.String(); gotW != tt.wantW {
t.Errorf("writeAllContext() = %v, want %v", gotW, tt.wantW)
}
})
}
}

0 comments on commit 43bcc72

Please sign in to comment.