Skip to content

Commit

Permalink
feat: support context cancellation in file store (#803)
Browse files Browse the repository at this point in the history
Closes #619

---------

Signed-off-by: Lucas Rodriguez <[email protected]>
  • Loading branch information
lucasrod16 authored Aug 20, 2024
1 parent d9e1d43 commit 606f636
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 5 deletions.
8 changes: 4 additions & 4 deletions content/file/file.go
Original file line number Diff line number Diff line change
Expand Up @@ -395,7 +395,7 @@ func (s *Store) Predecessors(ctx context.Context, node ocispec.Descriptor) ([]oc
}

// Add adds a file into the file store.
func (s *Store) Add(_ context.Context, name, mediaType, path string) (ocispec.Descriptor, error) {
func (s *Store) Add(ctx context.Context, name, mediaType, path string) (ocispec.Descriptor, error) {
if s.isClosedSet() {
return ocispec.Descriptor{}, ErrStoreClosed
}
Expand Down Expand Up @@ -426,7 +426,7 @@ func (s *Store) Add(_ context.Context, name, mediaType, path string) (ocispec.De
// generate descriptor
var desc ocispec.Descriptor
if fi.IsDir() {
desc, err = s.descriptorFromDir(name, mediaType, path)
desc, err = s.descriptorFromDir(ctx, name, mediaType, path)
} else {
desc, err = s.descriptorFromFile(fi, mediaType, path)
}
Expand Down Expand Up @@ -505,7 +505,7 @@ func (s *Store) pushDir(name, target string, expected ocispec.Descriptor, conten
}

// descriptorFromDir generates descriptor from the given directory.
func (s *Store) descriptorFromDir(name, mediaType, dir string) (desc ocispec.Descriptor, err error) {
func (s *Store) descriptorFromDir(ctx context.Context, name, mediaType, dir string) (desc ocispec.Descriptor, err error) {
// make a temp file to store the gzip
gz, err := s.tempFile()
if err != nil {
Expand All @@ -532,7 +532,7 @@ func (s *Store) descriptorFromDir(name, mediaType, dir string) (desc ocispec.Des
tw := io.MultiWriter(gzw, tarDigester.Hash())
buf := bufPool.Get().(*[]byte)
defer bufPool.Put(buf)
if err := tarDirectory(dir, name, tw, s.TarReproducible, *buf); err != nil {
if err := tarDirectory(ctx, dir, name, tw, s.TarReproducible, *buf); err != nil {
return ocispec.Descriptor{}, fmt.Errorf("failed to tar %s: %w", dir, err)
}

Expand Down
9 changes: 8 additions & 1 deletion content/file/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package file
import (
"archive/tar"
"compress/gzip"
"context"
"errors"
"fmt"
"io"
Expand All @@ -31,7 +32,7 @@ import (

// tarDirectory walks the directory specified by path, and tar those files with a new
// path prefix.
func tarDirectory(root, prefix string, w io.Writer, removeTimes bool, buf []byte) (err error) {
func tarDirectory(ctx context.Context, root, prefix string, w io.Writer, removeTimes bool, buf []byte) (err error) {
tw := tar.NewWriter(w)
defer func() {
closeErr := tw.Close()
Expand All @@ -45,6 +46,12 @@ func tarDirectory(root, prefix string, w io.Writer, removeTimes bool, buf []byte
return err
}

select {
case <-ctx.Done():
return ctx.Err()
default:
}

// Rename path
name, err := filepath.Rel(root, path)
if err != nil {
Expand Down
78 changes: 78 additions & 0 deletions content/file/utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,89 @@ limitations under the License.
package file

import (
"compress/gzip"
"context"
"errors"
"os"
"path/filepath"
"testing"
)

func Test_tarDirectory(t *testing.T) {
setup := func(t *testing.T) (tmpdir string, gz *os.File, gw *gzip.Writer) {
tmpdir = t.TempDir()

paths := []string{
filepath.Join(tmpdir, "file1.txt"),
filepath.Join(tmpdir, "file2.txt"),
}

for _, p := range paths {
err := os.WriteFile(p, []byte("test content"), 0644)
if err != nil {
t.Fatal(err)
}
}

gz, err := os.CreateTemp(tmpdir, "tarDirectory-*")
if err != nil {
t.Fatal(err)
}

return tmpdir, gz, gzip.NewWriter(gz)
}

t.Run("success", func(t *testing.T) {
tmpdir, gz, gw := setup(t)
defer func() {
if err := gw.Close(); err != nil {
t.Fatal(err)
}
if err := gz.Close(); err != nil {
t.Fatal(err)
}
}()

err := tarDirectory(context.Background(), tmpdir, "prefix", gw, false, nil)
if err != nil {
t.Fatal(err)
}

_, err = gz.Stat()
if err != nil {
t.Fatal(err)
}
})

t.Run("context cancellation", func(t *testing.T) {
tmpdir, gz, gw := setup(t)
defer func() {
if err := gw.Close(); err != nil {
t.Fatal(err)
}
if err := gz.Close(); err != nil {
t.Fatal(err)
}
}()

ctx, cancel := context.WithCancel(context.Background())
cancel()

err := tarDirectory(ctx, tmpdir, "prefix", gw, false, nil)
if err == nil {
t.Fatal("expected context cancellation error, got nil")
}
if !errors.Is(err, context.Canceled) {
t.Fatalf("expected context.Canceled error, got %v", err)
}

_, err = gz.Stat()
if err != nil {
t.Fatal(err)
}
})
}

func Test_ensureBasePath(t *testing.T) {
root := t.TempDir()
if err := os.MkdirAll(filepath.Join(root, "hello world", "foo", "bar"), 0700); err != nil {
Expand Down

0 comments on commit 606f636

Please sign in to comment.