diff --git a/utils.go b/utils.go index eb5b689..d806839 100644 --- a/utils.go +++ b/utils.go @@ -95,11 +95,15 @@ func ReadAllContext(ctx context.Context, r io.Reader) ([]byte, error) { // with a cancelable context. func WriteAllContext(ctx context.Context, w io.Writer, data []byte) error { const chunkSize = 4 * 1024 * 1024 // 4MB + return writeAllContext(ctx, w, data, chunkSize) +} + +func writeAllContext(ctx context.Context, w io.Writer, data []byte, chunkSize int) error { for len(data) > 0 { - if ctx.Err() != nil { - return ctx.Err() + if err := ctx.Err(); err != nil { + return err } - n, err := w.Write(data[:max(chunkSize, len(data))]) + n, err := w.Write(data[:min(chunkSize, len(data))]) if err != nil { return err } diff --git a/utils_test.go b/utils_test.go new file mode 100644 index 0000000..34b3273 --- /dev/null +++ b/utils_test.go @@ -0,0 +1,70 @@ +package fs + +import ( + "bytes" + "context" + "errors" + "testing" + "time" +) + +type failContext struct { + errAfter int + err error + + counter int +} + +func (*failContext) Deadline() (deadline time.Time, ok bool) { return time.Time{}, false } +func (*failContext) Done() <-chan struct{} { return nil } +func (*failContext) Value(key any) any { return nil } + +func (f *failContext) Err() error { + f.counter++ + if f.counter > f.errAfter { + return f.err + } + return nil +} + +func Test_writeAllContext(t *testing.T) { + type args struct { + ctx context.Context + data []byte + chunkSize int + } + tests := []struct { + name string + args args + wantW string + wantErr bool + }{ + {name: " chunkSize2", args: args{ctx: context.Background(), data: []byte(""), chunkSize: 2}, wantW: "", wantErr: false}, + {name: "1 chunkSize2", args: args{ctx: context.Background(), data: []byte("1"), chunkSize: 2}, wantW: "1", wantErr: false}, + {name: "12 chunkSize2", args: args{ctx: context.Background(), data: []byte("12"), chunkSize: 2}, wantW: "12", wantErr: false}, + {name: "123 chunkSize2", args: args{ctx: context.Background(), data: []byte("123"), chunkSize: 2}, wantW: "123", wantErr: false}, + {name: "1234 chunkSize2", args: args{ctx: context.Background(), data: []byte("1234"), chunkSize: 2}, wantW: "1234", wantErr: false}, + {name: "12345 chunkSize2", args: args{ctx: context.Background(), data: []byte("12345"), chunkSize: 2}, wantW: "12345", wantErr: false}, + + {name: " chunkSize3", args: args{ctx: context.Background(), data: []byte(""), chunkSize: 3}, wantW: "", wantErr: false}, + {name: "1 chunkSize3", args: args{ctx: context.Background(), data: []byte("1"), chunkSize: 3}, wantW: "1", wantErr: false}, + {name: "12 chunkSize3", args: args{ctx: context.Background(), data: []byte("12"), chunkSize: 3}, wantW: "12", wantErr: false}, + {name: "123 chunkSize3", args: args{ctx: context.Background(), data: []byte("123"), chunkSize: 3}, wantW: "123", wantErr: false}, + {name: "1234 chunkSize3", args: args{ctx: context.Background(), data: []byte("1234"), chunkSize: 3}, wantW: "1234", wantErr: false}, + {name: "12345 chunkSize3", args: args{ctx: context.Background(), data: []byte("12345"), chunkSize: 3}, wantW: "12345", wantErr: false}, + + {name: "12345 chunkSize2 error", args: args{ctx: &failContext{errAfter: 1, err: errors.New("contextError")}, data: []byte("12345"), chunkSize: 2}, wantW: "12", wantErr: true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + w := &bytes.Buffer{} + if err := writeAllContext(tt.args.ctx, w, tt.args.data, tt.args.chunkSize); (err != nil) != tt.wantErr { + t.Errorf("writeAllContext() error = %v, wantErr %v", err, tt.wantErr) + return + } + if gotW := w.String(); gotW != tt.wantW { + t.Errorf("writeAllContext() = %v, want %v", gotW, tt.wantW) + } + }) + } +}