Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix bug in s3hub cp subcommand #46

Merged
merged 9 commits into from
Jan 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion LICENSE
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,4 @@ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
SOFTWARE.
7 changes: 6 additions & 1 deletion app/domain/model/s3.go
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,12 @@ func (b Bucket) Split() (Bucket, S3Key) {
if len(s) == 1 {
return b, ""
}
return Bucket(s[0]), S3Key(strings.Join(s[1:], "/"))

key := strings.Join(s[1:], "/")
if key == "" {
return Bucket(s[0]), S3Key("")
}
return Bucket(s[0]), S3Key(filepath.Clean(key))
}

// Validate returns true if the Bucket is valid.
Expand Down
14 changes: 12 additions & 2 deletions app/domain/model/s3_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -659,11 +659,21 @@ func TestBucket_Split(t *testing.T) {
want: Bucket("abc"),
want1: S3Key(filepath.Join("def", "ghi")),
},
{
name: "If Bucket is 'abc/def/ghi/', Split() returns 'abc' and 'def/ghi/'",
b: Bucket(filepath.Join("abc", "def", "ghi/")),
want: Bucket("abc"),
want1: S3Key(filepath.Join("def", "ghi")),
},
{
name: "If Bucket is 'abc/def/../ghi/jkl', Split() returns 'abc' and 'def/../ghi/jkl'",
b: Bucket(filepath.Join("abc", "def", "..", "ghi", "jkl")),
want: Bucket("abc"),
want1: S3Key(filepath.Join("ghi", "jkl")),
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got, got1 := tt.b.Split()
if got != tt.want {
t.Errorf("Bucket.Split() got = %v, want %v", got, tt.want)
Expand Down
1 change: 1 addition & 0 deletions app/external/s3.go
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,7 @@ func (c *S3ObjectUploader) UploadS3Object(ctx context.Context, input *service.S3
if err != nil {
return nil, err
}

return &service.S3ObjectUploaderOutput{
ContentType: input.S3Object.ContentType(),
ContentLength: input.S3Object.ContentLength(),
Expand Down
23 changes: 20 additions & 3 deletions app/external/s3_helper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,17 +36,34 @@ func CreateS3Buckets(t *testing.T, client *s3.Client, buckets []model.Bucket) {
}
}

// DeleteAllS3BucketDelete deletes all S3 buckets.
// DeleteAllS3BucketDelete deletes all S3 buckets and objects.
func DeleteAllS3BucketDelete(t *testing.T, client *s3.Client) {
t.Helper()
ctx := context.Background()

buckets, err := client.ListBuckets(context.Background(), &s3.ListBucketsInput{})
buckets, err := client.ListBuckets(ctx, &s3.ListBucketsInput{})
if err != nil {
t.Fatal(err)
}

for _, bucket := range buckets.Buckets {
if _, err := client.DeleteBucket(context.Background(), &s3.DeleteBucketInput{Bucket: bucket.Name}); err != nil {
output, err := client.ListObjects(ctx, &s3.ListObjectsInput{
Bucket: bucket.Name,
})
if err != nil {
t.Fatal(err)
}

for _, object := range output.Contents {
if _, err := client.DeleteObject(ctx, &s3.DeleteObjectInput{
Bucket: bucket.Name,
Key: object.Key,
}); err != nil {
t.Fatal(err)
}
}

if _, err := client.DeleteBucket(ctx, &s3.DeleteBucketInput{Bucket: bucket.Name}); err != nil {
Comment on lines +39 to +66
Copy link
Contributor

@coderabbitai coderabbitai bot Jan 26, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The implementation of DeleteAllS3BucketDelete now includes deleting all objects within an S3 bucket before attempting to delete the bucket itself. This change addresses potential issues with deleting non-empty buckets. However, it's important to handle pagination in the ListObjects call to ensure all objects are deleted, especially in buckets with a large number of objects.

+ // Ensure to handle pagination for buckets with many objects
+ var continuationToken *string = nil
+ for {
+     output, err := client.ListObjects(ctx, &s3.ListObjectsInput{
+         Bucket: bucket.Name,
+         ContinuationToken: continuationToken,
+     })
+     if err != nil {
+         t.Fatal(err)
+     }
+     for _, object := range output.Contents {
+         _, err := client.DeleteObject(ctx, &s3.DeleteObjectInput{
+             Bucket: bucket.Name,
+             Key:    object.Key,
+         })
+         if err != nil {
+             t.Fatal(err)
+         }
+     }
+     if !output.IsTruncated {
+         break
+     }
+     continuationToken = output.NextContinuationToken
+ }

Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation.

Suggested change
// DeleteAllS3BucketDelete deletes all S3 buckets and objects.
func DeleteAllS3BucketDelete(t *testing.T, client *s3.Client) {
t.Helper()
ctx := context.Background()
buckets, err := client.ListBuckets(context.Background(), &s3.ListBucketsInput{})
buckets, err := client.ListBuckets(ctx, &s3.ListBucketsInput{})
if err != nil {
t.Fatal(err)
}
for _, bucket := range buckets.Buckets {
if _, err := client.DeleteBucket(context.Background(), &s3.DeleteBucketInput{Bucket: bucket.Name}); err != nil {
output, err := client.ListObjects(ctx, &s3.ListObjectsInput{
Bucket: bucket.Name,
})
if err != nil {
t.Fatal(err)
}
for _, object := range output.Contents {
if _, err := client.DeleteObject(ctx, &s3.DeleteObjectInput{
Bucket: bucket.Name,
Key: object.Key,
}); err != nil {
t.Fatal(err)
}
}
if _, err := client.DeleteBucket(ctx, &s3.DeleteBucketInput{Bucket: bucket.Name}); err != nil {
// DeleteAllS3BucketDelete deletes all S3 buckets and objects.
func DeleteAllS3BucketDelete(t *testing.T, client *s3.Client) {
t.Helper()
ctx := context.Background()
buckets, err := client.ListBuckets(ctx, &s3.ListBucketsInput{})
if err != nil {
t.Fatal(err)
}
for _, bucket := range buckets.Buckets {
// Ensure to handle pagination for buckets with many objects
var continuationToken *string = nil
for {
output, err := client.ListObjects(ctx, &s3.ListObjectsInput{
Bucket: bucket.Name,
ContinuationToken: continuationToken,
})
if err != nil {
t.Fatal(err)
}
for _, object := range output.Contents {
_, err := client.DeleteObject(ctx, &s3.DeleteObjectInput{
Bucket: bucket.Name,
Key: object.Key,
})
if err != nil {
t.Fatal(err)
}
}
if !output.IsTruncated {
break
}
continuationToken = output.NextContinuationToken
}
if _, err := client.DeleteBucket(ctx, &s3.DeleteBucketInput{Bucket: bucket.Name}); err != nil {

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The DeleteAllS3BucketDelete() method is only used for unit tests, so there are not many S3 objects. We don't use pagination when executing unit tests.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@nao1215, understood! Given that DeleteAllS3BucketDelete is used solely for unit testing with a limited number of objects, your approach makes sense. Skipping pagination under these conditions seems reasonable. Thank you for clarifying the context and usage of this method. If there's anything else you need, feel free to ask!

t.Fatal(err)
}
}
Expand Down
15 changes: 15 additions & 0 deletions app/interactor/mock/s3.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package mock

import (
"context"

"github.com/nao1215/rainbow/app/usecase"
)

// S3ObjectsLister is a function adapter that mocks the S3 objects
// lister usecase: the function value itself is invoked whenever
// ListS3Objects is called, so tests can plug in arbitrary behavior.
type S3ObjectsLister func(ctx context.Context, input *usecase.S3ObjectsListerInput) (*usecase.S3ObjectsListerOutput, error)

// ListS3Objects satisfies the usecase interface by delegating straight
// to the underlying function value.
func (fn S3ObjectsLister) ListS3Objects(ctx context.Context, input *usecase.S3ObjectsListerInput) (*usecase.S3ObjectsListerOutput, error) {
	return fn(ctx, input)
}
1 change: 1 addition & 0 deletions app/interactor/s3.go
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,7 @@ func (u *FileUploader) UploadFile(ctx context.Context, input *usecase.FileUpload
if err != nil {
return nil, err
}

return &usecase.FileUploaderOutput{
ContentType: output.ContentType,
ContentLength: output.ContentLength,
Expand Down
93 changes: 56 additions & 37 deletions cmd/subcmd/s3hub/cp.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"strings"

"github.com/fatih/color"
"github.com/gogf/gf/os/gfile"
"github.com/nao1215/rainbow/app/domain/model"
"github.com/nao1215/rainbow/app/usecase"
"github.com/nao1215/rainbow/cmd/subcmd"
Expand Down Expand Up @@ -65,15 +66,15 @@ type copyPathPair struct {
From string
// To is a path of destination.
To string
// copyType is a type of copy.
// Type indicates the direction of the copy operation: from local to S3, from S3 to local, or within S3.
Type copyType
}

// newCopyPathPair returns a new copyPathPair.
func newCopyPathPair(from, to string) *copyPathPair {
pair := &copyPathPair{
From: filepath.Clean(from),
To: filepath.Clean(to),
From: from,
To: to,
}
pair.Type = pair.copyType()
return pair
Expand Down Expand Up @@ -121,34 +122,48 @@ func (c *cpCmd) Do() error {
case copyTypeS3ToS3:
return c.s3ToS3()
case copyTypeUnknown:
fallthrough
default:
return fmt.Errorf("unsupported copy type. from=%s, to=%s",
color.YellowString(c.pair.From), color.YellowString(c.pair.To))
}
return nil
}

// copyTargetsInLocal collects the local file paths to copy. A regular
// file yields a single-element slice; otherwise the source is walked
// recursively and every discovered path is returned.
func (c *cpCmd) copyTargetsInLocal() ([]string, error) {
	from := c.pair.From
	if gfile.IsFile(from) {
		return []string{from}, nil
	}
	return file.WalkDir(from)
}

// localToS3 copies from local to S3.
func (c *cpCmd) localToS3() error {
targets, err := file.WalkDir(c.pair.From)
targets, err := c.copyTargetsInLocal()
if err != nil {
return err
}
toBucket, toKey := model.NewBucketWithoutProtocol(c.pair.To).Split()

toBucket, toKey := model.NewBucketWithoutProtocol(c.pair.To).Split()
fileNum := len(targets)

nao1215 marked this conversation as resolved.
Show resolved Hide resolved
for i, v := range targets {
data, err := os.ReadFile(filepath.Clean(v))
if err != nil {
return err
return fmt.Errorf("can not read file %s: %w", color.YellowString(v), err)
}

if _, err := c.s3hub.FileUploader.UploadFile(c.ctx, &usecase.FileUploaderInput{
Bucket: toBucket,
Region: c.s3hub.region,
Key: toKey,
Key: model.S3Key(filepath.Join(toKey.String(), filepath.Base(v))),
Data: data,
}); err != nil {
return err
return fmt.Errorf("can not upload file %s: %w", color.YellowString(v), err)
}
c.printf("[%d/%d] copy %s to %s\n",
i+1,
Expand All @@ -163,44 +178,29 @@ func (c *cpCmd) localToS3() error {
// s3ToLocal copies from S3 to local.
func (c *cpCmd) s3ToLocal() error {
fromBucket, fromKey := model.NewBucketWithoutProtocol(c.pair.From).Split()
_, toKey := model.NewBucketWithoutProtocol(c.pair.To).Split()

listOutput, err := c.s3hub.ListS3Objects(c.ctx, &usecase.S3ObjectsListerInput{
Bucket: fromBucket,
})
targets, err := c.filterS3Objects(fromBucket, fromKey)
if err != nil {
return err
}

targets := make([]model.S3Key, 0, len(listOutput.Objects))
for _, v := range listOutput.Objects {
if strings.Contains(v.S3Key.String(), fromKey.String()) {
targets = append(targets, v.S3Key)
}
}

if len(targets) == 0 {
return fmt.Errorf("no objects found. bucket=%s, key=%s",
color.YellowString(fromBucket.String()), color.YellowString(fromKey.String()))
}

fileNum := len(targets)
for i, v := range targets {
downloadOutput, err := c.s3hub.S3ObjectDownloader.DownloadS3Object(c.ctx, &usecase.S3ObjectDownloaderInput{
Bucket: fromBucket,
Key: v,
})
if err != nil {
return err
return fmt.Errorf("can not download s3 object=%s: %w",
color.YellowString(fromBucket.Join(v).WithProtocol().String()), err)
}

relativePath, err := filepath.Rel(fromKey.String(), v.String())
if err != nil {
return err
destinationPath := filepath.Clean(filepath.Join(c.pair.To, fromKey.String()))
if err := os.MkdirAll(filepath.Dir(destinationPath), 0750); err != nil {
return fmt.Errorf("can not create directory %s: %w", color.YellowString(filepath.Dir(destinationPath)), err)
}
destinationPath := filepath.Join(toKey.String(), relativePath)

if err := downloadOutput.S3Object.ToFile(destinationPath, 0644); err != nil {
return err
return fmt.Errorf("can not write file to %s: %w", color.YellowString(destinationPath), err)
}

c.printf("[%d/%d] copy %s to %s\n",
Expand All @@ -213,6 +213,29 @@ func (c *cpCmd) s3ToLocal() error {
return nil
}

// filterS3Objects lists every object in fromBucket and keeps the keys
// whose bucket-qualified path ("bucket/key") contains fromKey as a
// substring. It returns an error when listing fails or nothing matches.
func (c *cpCmd) filterS3Objects(fromBucket model.Bucket, fromKey model.S3Key) ([]model.S3Key, error) {
	listOutput, err := c.s3hub.ListS3Objects(c.ctx, &usecase.S3ObjectsListerInput{
		Bucket: fromBucket,
	})
	if err != nil {
		return nil, fmt.Errorf("%w: bucket=%s", err, color.YellowString(fromBucket.String()))
	}

	bucketName := fromBucket.String()
	needle := fromKey.String()

	matched := make([]model.S3Key, 0, len(listOutput.Objects))
	for _, obj := range listOutput.Objects {
		// NOTE(review): the bucket name is deliberately part of the
		// haystack — presumably so a fromKey that still carries the
		// bucket prefix matches; confirm against Bucket.Split callers.
		if strings.Contains(filepath.Join(bucketName, obj.S3Key.String()), needle) {
			matched = append(matched, obj.S3Key)
		}
	}

	if len(matched) == 0 {
		return nil, fmt.Errorf("no objects found. bucket=%s, key=%s",
			color.YellowString(bucketName), color.YellowString(needle))
	}
	return matched, nil
}

// s3ToS3 copies from S3 to S3.
func (c *cpCmd) s3ToS3() error {
fromBucket, fromKey := model.NewBucketWithoutProtocol(c.pair.From).Split()
Expand All @@ -238,11 +261,7 @@ func (c *cpCmd) s3ToS3() error {

fileNum := len(targets)
for i, v := range targets {
relativePath, err := filepath.Rel(fromKey.String(), v.String())
if err != nil {
return err
}
destinationKey := model.S3Key(filepath.Join(toKey.String(), relativePath))
destinationKey := model.S3Key(filepath.Clean(filepath.Join(toKey.String(), v.String())))

if _, err := c.s3hub.S3ObjectCopier.CopyS3Object(c.ctx, &usecase.S3ObjectCopierInput{
SourceBucket: fromBucket,
Expand Down
Loading
Loading