Skip to content

Commit

Permalink
[copilot][flytedirectory] multipart blob download (#5715)
Browse files Browse the repository at this point in the history
* add download multipart blob

Signed-off-by: wayner0628 <[email protected]>

* recursively process subparts

Signed-off-by: wayner0628 <[email protected]>

* implement GetItems function

Signed-off-by: wayner0628 <[email protected]>

* add unit testing

Signed-off-by: wayner0628 <[email protected]>

* Parallelly handle blob items

Signed-off-by: wayner0628 <[email protected]>

* fix lint error

Signed-off-by: wayner0628 <[email protected]>

* implement GetItems function

Signed-off-by: wayner0628 <[email protected]>

* add mutex avoid racing

Signed-off-by: wayner0628 <[email protected]>

* avoid infinite call

Signed-off-by: wayner0628 <[email protected]>

* protect critical variables

Signed-off-by: wayner0628 <[email protected]>

* avoid infinite call

Signed-off-by: wayner0628 <[email protected]>

* lint

Signed-off-by: wayner0628 <[email protected]>

* add more unit tests

Signed-off-by: wayner0628 <[email protected]>

* add more unit tests

Signed-off-by: wayner0628 <[email protected]>

* fix mock

Signed-off-by: wayner0628 <[email protected]>

* Accept incoming changes

Signed-off-by: wayner0628 <[email protected]>

* multipart blob download based on new api

Signed-off-by: wayner0628 <[email protected]>

* cache store stop listing at end cursor

Signed-off-by: wayner0628 <[email protected]>

* lint

Signed-off-by: wayner0628 <[email protected]>

* remove old api mock

Signed-off-by: wayner0628 <[email protected]>

* remove old api mock

Signed-off-by: wayner0628 <[email protected]>

* remove old api mock

Signed-off-by: wayner0628 <[email protected]>

* update mem_store List to return global path

Signed-off-by: wayner0628 <[email protected]>

* change mkdir perm

Signed-off-by: wayner0628 <[email protected]>

* add comments and handle more errors

Signed-off-by: wayner0628 <[email protected]>

* lint

Co-authored-by: Han-Ru Chen (Future-Outlier) <[email protected]>
Signed-off-by: Wei-Yu Kao <[email protected]>

* address race condition and aggregate errors

Signed-off-by: wayner0628 <[email protected]>

* fix tests

Signed-off-by: Future-Outlier <[email protected]>

* err msg enhancement

Signed-off-by: Future-Outlier <[email protected]>

---------

Signed-off-by: wayner0628 <[email protected]>
Signed-off-by: Wei-Yu Kao <[email protected]>
Signed-off-by: Future-Outlier <[email protected]>
Co-authored-by: Han-Ru Chen (Future-Outlier) <[email protected]>
  • Loading branch information
wayner0628 and Future-Outlier authored Nov 8, 2024
1 parent fef67b8 commit b5f23a6
Show file tree
Hide file tree
Showing 4 changed files with 339 additions and 38 deletions.
204 changes: 168 additions & 36 deletions flytecopilot/data/download.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@ import (
"io/ioutil"
"os"
"path"
"path/filepath"
"reflect"
"strconv"
"sync"

"github.com/ghodss/yaml"
"github.com/golang/protobuf/jsonpb"
Expand All @@ -31,57 +33,187 @@ type Downloader struct {
mode core.IOStrategy_DownloadMode
}

// TODO add support for multipart blobs
func (d Downloader) handleBlob(ctx context.Context, blob *core.Blob, toFilePath string) (interface{}, error) {
ref := storage.DataReference(blob.Uri)
scheme, _, _, err := ref.Split()
// TODO add timeout and rate limit
// TODO use chunk to download
func (d Downloader) handleBlob(ctx context.Context, blob *core.Blob, toPath string) (interface{}, error) {
/*
handleBlob handles the retrieval and local storage of blob data, including support for both single and multipart blob types.
For multipart blobs, it lists all parts recursively and spawns concurrent goroutines to download each part while managing file I/O in parallel.
- The function begins by validating the blob URI and categorizing the blob type (single or multipart).
- In the multipart case, it recursively lists all blob parts and launches goroutines to download and save each part.
Goroutine closure and I/O success tracking are managed to avoid resource leaks.
- For single-part blobs, it directly downloads and writes the data to the specified path.
Life Cycle:
1. Blob URI -> Blob Metadata Type check -> Recursive List parts if Multipart -> Launch goroutines to download parts
(input blob object) (determine multipart/single) (List API, handles recursive case) (each part handled in parallel)
2. Download part or full blob -> Save locally with error checks -> Handle reader/writer closures -> Return local path or error
(download each part) (error on write or directory) (close streams safely, track success) (completion or report missing closures)
*/

blobRef := storage.DataReference(blob.Uri)
scheme, _, _, err := blobRef.Split()
if err != nil {
return nil, errors.Wrapf(err, "Blob uri incorrectly formatted")
}
var reader io.ReadCloser
if scheme == "http" || scheme == "https" {
reader, err = DownloadFileFromHTTP(ctx, ref)
} else {
if blob.GetMetadata().GetType().Dimensionality == core.BlobType_MULTIPART {
logger.Warnf(ctx, "Currently only single part blobs are supported, we will force multipart to be 'path/00000'")
ref, err = d.store.ConstructReference(ctx, ref, "000000")
if err != nil {

if blob.GetMetadata().GetType().Dimensionality == core.BlobType_MULTIPART {
// Collect all parts of the multipart blob recursively (List API handles nested directories)
// Set maxItems to 100 as a parameter for the List API, enabling batch retrieval of items until all are downloaded
maxItems := 100
cursor := storage.NewCursorAtStart()
var items []storage.DataReference
var absPaths []string
for {
items, cursor, err = d.store.List(ctx, blobRef, maxItems, cursor)
if err != nil || len(items) == 0 {
logger.Errorf(ctx, "failed to collect items from multipart blob [%s]", blobRef)
return nil, err
}
for _, item := range items {
absPaths = append(absPaths, item.String())
}
if storage.IsCursorEnd(cursor) {
break
}
}

// Track the count of successful downloads and the total number of items
downloadSuccess := 0
itemCount := len(absPaths)
// Track successful closures of readers and writers in deferred functions
readerCloseSuccessCount := 0
writerCloseSuccessCount := 0
// We use Mutex to avoid race conditions when updating counters and creating directories
var mu sync.Mutex
var wg sync.WaitGroup
for _, absPath := range absPaths {
absPath := absPath

wg.Add(1)
go func() {
defer wg.Done()
defer func() {
if err := recover(); err != nil {
logger.Errorf(ctx, "recover receives error: [%s]", err)
}
}()

ref := storage.DataReference(absPath)
reader, err := DownloadFileFromStorage(ctx, ref, d.store)
if err != nil {
logger.Errorf(ctx, "Failed to download from ref [%s]", ref)
return
}
defer func() {
err := reader.Close()
if err != nil {
logger.Errorf(ctx, "failed to close Blob read stream @ref [%s].\n"+
"Error: %s", ref, err)
}
mu.Lock()
readerCloseSuccessCount++
mu.Unlock()
}()

_, _, k, err := ref.Split()
if err != nil {
logger.Errorf(ctx, "Failed to parse ref [%s]", ref)
return
}
newPath := filepath.Join(toPath, k)
dir := filepath.Dir(newPath)

mu.Lock()
// os.MkdirAll creates the specified directory structure if it doesn’t already exist
// 0777: the directory can be read and written by anyone
err = os.MkdirAll(dir, 0777)
mu.Unlock()
if err != nil {
logger.Errorf(ctx, "failed to make dir at path [%s]", dir)
return
}

writer, err := os.Create(newPath)
if err != nil {
logger.Errorf(ctx, "failed to open file at path [%s]", newPath)
return
}
defer func() {
err := writer.Close()
if err != nil {
logger.Errorf(ctx, "failed to close File write stream.\n"+
"Error: [%s]", err)
}
mu.Lock()
writerCloseSuccessCount++
mu.Unlock()
}()

_, err = io.Copy(writer, reader)
if err != nil {
logger.Errorf(ctx, "failed to write remote data to local filesystem")
return
}
mu.Lock()
downloadSuccess++
mu.Unlock()
}()
}
// Go routines are synchronized with a WaitGroup to prevent goroutine leaks.
wg.Wait()
if downloadSuccess != itemCount || readerCloseSuccessCount != itemCount || writerCloseSuccessCount != itemCount {
return nil, errors.Errorf(
"Failed to copy %d out of %d remote files from [%s] to local [%s].\n"+
"Failed to close %d readers\n"+
"Failed to close %d writers.",
itemCount-downloadSuccess, itemCount, blobRef, toPath, itemCount-readerCloseSuccessCount, itemCount-writerCloseSuccessCount,
)
}
logger.Infof(ctx, "successfully copied %d remote files from [%s] to local [%s]", downloadSuccess, blobRef, toPath)
return toPath, nil
} else if blob.GetMetadata().GetType().Dimensionality == core.BlobType_SINGLE {
// reader should be declared here (avoid being shared across all goroutines)
var reader io.ReadCloser
if scheme == "http" || scheme == "https" {
reader, err = DownloadFileFromHTTP(ctx, blobRef)
} else {
reader, err = DownloadFileFromStorage(ctx, blobRef, d.store)
}
reader, err = DownloadFileFromStorage(ctx, ref, d.store)
}
if err != nil {
logger.Errorf(ctx, "Failed to download from ref [%s]", ref)
return nil, err
}
defer func() {
err := reader.Close()
if err != nil {
logger.Errorf(ctx, "failed to close Blob read stream @ref [%s]. Error: %s", ref, err)
logger.Errorf(ctx, "Failed to download from ref [%s]", blobRef)
return nil, err
}
}()
defer func() {
err := reader.Close()
if err != nil {
logger.Errorf(ctx, "failed to close Blob read stream @ref [%s]. Error: %s", blobRef, err)
}
}()

writer, err := os.Create(toFilePath)
if err != nil {
return nil, errors.Wrapf(err, "failed to open file at path %s", toFilePath)
}
defer func() {
err := writer.Close()
writer, err := os.Create(toPath)
if err != nil {
logger.Errorf(ctx, "failed to close File write stream. Error: %s", err)
return nil, errors.Wrapf(err, "failed to open file at path %s", toPath)
}
}()
v, err := io.Copy(writer, reader)
if err != nil {
return nil, errors.Wrapf(err, "failed to write remote data to local filesystem")
defer func() {
err := writer.Close()
if err != nil {
logger.Errorf(ctx, "failed to close File write stream. Error: %s", err)
}
}()
v, err := io.Copy(writer, reader)
if err != nil {
return nil, errors.Wrapf(err, "failed to write remote data to local filesystem")
}
logger.Infof(ctx, "Successfully copied [%d] bytes remote data from [%s] to local [%s]", v, blobRef, toPath)
return toPath, nil
}
logger.Infof(ctx, "Successfully copied [%d] bytes remote data from [%s] to local [%s]", v, ref, toFilePath)
return toFilePath, nil

return nil, errors.Errorf("unexpected Blob type encountered")
}

func (d Downloader) handleSchema(ctx context.Context, schema *core.Schema, toFilePath string) (interface{}, error) {
// TODO Handle schema type
return d.handleBlob(ctx, &core.Blob{Uri: schema.Uri, Metadata: &core.BlobMetadata{Type: &core.BlobType{Dimensionality: core.BlobType_MULTIPART}}}, toFilePath)
}

Expand Down
151 changes: 151 additions & 0 deletions flytecopilot/data/download_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
package data

import (
"bytes"
"context"
"os"
"path/filepath"
"testing"

"github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core"
"github.com/flyteorg/flyte/flytestdlib/promutils"
"github.com/flyteorg/flyte/flytestdlib/storage"

"github.com/stretchr/testify/assert"
)

func TestHandleBlobMultipart(t *testing.T) {
t.Run("Successful Query", func(t *testing.T) {
s, err := storage.NewDataStore(&storage.Config{Type: storage.TypeMemory}, promutils.NewTestScope())
assert.NoError(t, err)
ref := storage.DataReference("s3://container/folder/file1")
s.WriteRaw(context.Background(), ref, 0, storage.Options{}, bytes.NewReader([]byte{}))
ref = storage.DataReference("s3://container/folder/file2")
s.WriteRaw(context.Background(), ref, 0, storage.Options{}, bytes.NewReader([]byte{}))

d := Downloader{store: s}

blob := &core.Blob{
Uri: "s3://container/folder",
Metadata: &core.BlobMetadata{
Type: &core.BlobType{
Dimensionality: core.BlobType_MULTIPART,
},
},
}

toPath := "./inputs"
defer func() {
err := os.RemoveAll(toPath)
if err != nil {
t.Errorf("Failed to delete directory: %v", err)
}
}()

result, err := d.handleBlob(context.Background(), blob, toPath)
assert.NoError(t, err)
assert.Equal(t, toPath, result)

// Check if files were created and data written
for _, file := range []string{"file1", "file2"} {
if _, err := os.Stat(filepath.Join(toPath, "folder", file)); os.IsNotExist(err) {
t.Errorf("expected file %s to exist", file)
}
}
})

t.Run("No Items", func(t *testing.T) {
s, err := storage.NewDataStore(&storage.Config{Type: storage.TypeMemory}, promutils.NewTestScope())
assert.NoError(t, err)

d := Downloader{store: s}

blob := &core.Blob{
Uri: "s3://container/folder",
Metadata: &core.BlobMetadata{
Type: &core.BlobType{
Dimensionality: core.BlobType_MULTIPART,
},
},
}

toPath := "./inputs"
defer func() {
err := os.RemoveAll(toPath)
if err != nil {
t.Errorf("Failed to delete directory: %v", err)
}
}()

result, err := d.handleBlob(context.Background(), blob, toPath)
assert.Error(t, err)
assert.Nil(t, result)
})
}

func TestHandleBlobSinglePart(t *testing.T) {
s, err := storage.NewDataStore(&storage.Config{Type: storage.TypeMemory}, promutils.NewTestScope())
assert.NoError(t, err)
ref := storage.DataReference("s3://container/file")
s.WriteRaw(context.Background(), ref, 0, storage.Options{}, bytes.NewReader([]byte{}))

d := Downloader{store: s}

blob := &core.Blob{
Uri: "s3://container/file",
Metadata: &core.BlobMetadata{
Type: &core.BlobType{
Dimensionality: core.BlobType_SINGLE,
},
},
}

toPath := "./input"
defer func() {
err := os.RemoveAll(toPath)
if err != nil {
t.Errorf("Failed to delete file: %v", err)
}
}()

result, err := d.handleBlob(context.Background(), blob, toPath)
assert.NoError(t, err)
assert.Equal(t, toPath, result)

// Check if files were created and data written
if _, err := os.Stat(toPath); os.IsNotExist(err) {
t.Errorf("expected file %s to exist", toPath)
}
}

func TestHandleBlobHTTP(t *testing.T) {
s, err := storage.NewDataStore(&storage.Config{Type: storage.TypeMemory}, promutils.NewTestScope())
assert.NoError(t, err)
d := Downloader{store: s}

blob := &core.Blob{
Uri: "https://raw.githubusercontent.com/flyteorg/flyte/master/README.md",
Metadata: &core.BlobMetadata{
Type: &core.BlobType{
Dimensionality: core.BlobType_SINGLE,
},
},
}

toPath := "./input"
defer func() {
err := os.RemoveAll(toPath)
if err != nil {
t.Errorf("Failed to delete file: %v", err)
}
}()

result, err := d.handleBlob(context.Background(), blob, toPath)
assert.NoError(t, err)
assert.Equal(t, toPath, result)

// Check if files were created and data written
if _, err := os.Stat(toPath); os.IsNotExist(err) {
t.Errorf("expected file %s to exist", toPath)
}
}
Loading

0 comments on commit b5f23a6

Please sign in to comment.