Skip to content

Commit

Permalink
Merge pull request #46 from nao1215/feat/s3hub-cp
Browse files Browse the repository at this point in the history
Fix bug in s3hub cp subcommand
  • Loading branch information
nao1215 authored Jan 27, 2024
2 parents d03dc7a + 28a9ab8 commit f4aa82b
Show file tree
Hide file tree
Showing 12 changed files with 499 additions and 137 deletions.
2 changes: 1 addition & 1 deletion LICENSE
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,4 @@ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
SOFTWARE.
7 changes: 6 additions & 1 deletion app/domain/model/s3.go
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,12 @@ func (b Bucket) Split() (Bucket, S3Key) {
if len(s) == 1 {
return b, ""
}
return Bucket(s[0]), S3Key(strings.Join(s[1:], "/"))

key := strings.Join(s[1:], "/")
if key == "" {
return Bucket(s[0]), S3Key("")
}
return Bucket(s[0]), S3Key(filepath.Clean(key))
}

// Validate returns true if the Bucket is valid.
Expand Down
14 changes: 12 additions & 2 deletions app/domain/model/s3_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -659,11 +659,21 @@ func TestBucket_Split(t *testing.T) {
want: Bucket("abc"),
want1: S3Key(filepath.Join("def", "ghi")),
},
{
name: "If Bucket is 'abc/def/ghi/', Split() returns 'abc' and 'def/ghi/'",
b: Bucket(filepath.Join("abc", "def", "ghi/")),
want: Bucket("abc"),
want1: S3Key(filepath.Join("def", "ghi")),
},
{
name: "If Bucket is 'abc/def/../ghi/jkl', Split() returns 'abc' and 'def/../ghi/jkl'",
b: Bucket(filepath.Join("abc", "def", "..", "ghi", "jkl")),
want: Bucket("abc"),
want1: S3Key(filepath.Join("ghi", "jkl")),
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got, got1 := tt.b.Split()
if got != tt.want {
t.Errorf("Bucket.Split() got = %v, want %v", got, tt.want)
Expand Down
1 change: 1 addition & 0 deletions app/external/s3.go
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,7 @@ func (c *S3ObjectUploader) UploadS3Object(ctx context.Context, input *service.S3
if err != nil {
return nil, err
}

return &service.S3ObjectUploaderOutput{
ContentType: input.S3Object.ContentType(),
ContentLength: input.S3Object.ContentLength(),
Expand Down
23 changes: 20 additions & 3 deletions app/external/s3_helper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,17 +36,34 @@ func CreateS3Buckets(t *testing.T, client *s3.Client, buckets []model.Bucket) {
}
}

// DeleteAllS3BucketDelete deletes all S3 buckets.
// DeleteAllS3BucketDelete deletes all S3 buckets and objects.
func DeleteAllS3BucketDelete(t *testing.T, client *s3.Client) {
t.Helper()
ctx := context.Background()

buckets, err := client.ListBuckets(context.Background(), &s3.ListBucketsInput{})
buckets, err := client.ListBuckets(ctx, &s3.ListBucketsInput{})
if err != nil {
t.Fatal(err)
}

for _, bucket := range buckets.Buckets {
if _, err := client.DeleteBucket(context.Background(), &s3.DeleteBucketInput{Bucket: bucket.Name}); err != nil {
output, err := client.ListObjects(ctx, &s3.ListObjectsInput{
Bucket: bucket.Name,
})
if err != nil {
t.Fatal(err)
}

for _, object := range output.Contents {
if _, err := client.DeleteObject(ctx, &s3.DeleteObjectInput{
Bucket: bucket.Name,
Key: object.Key,
}); err != nil {
t.Fatal(err)
}
}

if _, err := client.DeleteBucket(ctx, &s3.DeleteBucketInput{Bucket: bucket.Name}); err != nil {
t.Fatal(err)
}
}
Expand Down
15 changes: 15 additions & 0 deletions app/interactor/mock/s3.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package mock

import (
"context"

"github.com/nao1215/rainbow/app/usecase"
)

// S3ObjectsLister is a mock of the S3ObjectLister interface.
type S3ObjectsLister func(ctx context.Context, input *usecase.S3ObjectsListerInput) (*usecase.S3ObjectsListerOutput, error)

// ListS3Objects calls the ListS3ObjectsFunc.
func (m S3ObjectsLister) ListS3Objects(ctx context.Context, input *usecase.S3ObjectsListerInput) (*usecase.S3ObjectsListerOutput, error) {
return m(ctx, input)
}
1 change: 1 addition & 0 deletions app/interactor/s3.go
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,7 @@ func (u *FileUploader) UploadFile(ctx context.Context, input *usecase.FileUpload
if err != nil {
return nil, err
}

return &usecase.FileUploaderOutput{
ContentType: output.ContentType,
ContentLength: output.ContentLength,
Expand Down
93 changes: 56 additions & 37 deletions cmd/subcmd/s3hub/cp.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"strings"

"github.com/fatih/color"
"github.com/gogf/gf/os/gfile"
"github.com/nao1215/rainbow/app/domain/model"
"github.com/nao1215/rainbow/app/usecase"
"github.com/nao1215/rainbow/cmd/subcmd"
Expand Down Expand Up @@ -65,15 +66,15 @@ type copyPathPair struct {
From string
// To is a path of destination.
To string
// copyType is a type of copy.
// Type indicates the direction of the copy operation: from local to S3, from S3 to local, or within S3.
Type copyType
}

// newCopyPathPair returns a new copyPathPair.
func newCopyPathPair(from, to string) *copyPathPair {
pair := &copyPathPair{
From: filepath.Clean(from),
To: filepath.Clean(to),
From: from,
To: to,
}
pair.Type = pair.copyType()
return pair
Expand Down Expand Up @@ -121,34 +122,48 @@ func (c *cpCmd) Do() error {
case copyTypeS3ToS3:
return c.s3ToS3()
case copyTypeUnknown:
fallthrough
default:
return fmt.Errorf("unsupported copy type. from=%s, to=%s",
color.YellowString(c.pair.From), color.YellowString(c.pair.To))
}
return nil
}

// copyTargetsInLocal returns a slice of target files in local.
func (c *cpCmd) copyTargetsInLocal() ([]string, error) {
if gfile.IsFile(c.pair.From) {
return []string{c.pair.From}, nil
}
targets, err := file.WalkDir(c.pair.From)
if err != nil {
return nil, err
}
return targets, nil
}

// localToS3 copies from local to S3.
func (c *cpCmd) localToS3() error {
targets, err := file.WalkDir(c.pair.From)
targets, err := c.copyTargetsInLocal()
if err != nil {
return err
}
toBucket, toKey := model.NewBucketWithoutProtocol(c.pair.To).Split()

toBucket, toKey := model.NewBucketWithoutProtocol(c.pair.To).Split()
fileNum := len(targets)

for i, v := range targets {
data, err := os.ReadFile(filepath.Clean(v))
if err != nil {
return err
return fmt.Errorf("can not read file %s: %w", color.YellowString(v), err)
}

if _, err := c.s3hub.FileUploader.UploadFile(c.ctx, &usecase.FileUploaderInput{
Bucket: toBucket,
Region: c.s3hub.region,
Key: toKey,
Key: model.S3Key(filepath.Join(toKey.String(), filepath.Base(v))),
Data: data,
}); err != nil {
return err
return fmt.Errorf("can not upload file %s: %w", color.YellowString(v), err)
}
c.printf("[%d/%d] copy %s to %s\n",
i+1,
Expand All @@ -163,44 +178,29 @@ func (c *cpCmd) localToS3() error {
// s3ToLocal copies from S3 to local.
func (c *cpCmd) s3ToLocal() error {
fromBucket, fromKey := model.NewBucketWithoutProtocol(c.pair.From).Split()
_, toKey := model.NewBucketWithoutProtocol(c.pair.To).Split()

listOutput, err := c.s3hub.ListS3Objects(c.ctx, &usecase.S3ObjectsListerInput{
Bucket: fromBucket,
})
targets, err := c.filterS3Objects(fromBucket, fromKey)
if err != nil {
return err
}

targets := make([]model.S3Key, 0, len(listOutput.Objects))
for _, v := range listOutput.Objects {
if strings.Contains(v.S3Key.String(), fromKey.String()) {
targets = append(targets, v.S3Key)
}
}

if len(targets) == 0 {
return fmt.Errorf("no objects found. bucket=%s, key=%s",
color.YellowString(fromBucket.String()), color.YellowString(fromKey.String()))
}

fileNum := len(targets)
for i, v := range targets {
downloadOutput, err := c.s3hub.S3ObjectDownloader.DownloadS3Object(c.ctx, &usecase.S3ObjectDownloaderInput{
Bucket: fromBucket,
Key: v,
})
if err != nil {
return err
return fmt.Errorf("can not download s3 object=%s: %w",
color.YellowString(fromBucket.Join(v).WithProtocol().String()), err)
}

relativePath, err := filepath.Rel(fromKey.String(), v.String())
if err != nil {
return err
destinationPath := filepath.Clean(filepath.Join(c.pair.To, fromKey.String()))
if err := os.MkdirAll(filepath.Dir(destinationPath), 0750); err != nil {
return fmt.Errorf("can not create directory %s: %w", color.YellowString(filepath.Dir(destinationPath)), err)
}
destinationPath := filepath.Join(toKey.String(), relativePath)

if err := downloadOutput.S3Object.ToFile(destinationPath, 0644); err != nil {
return err
return fmt.Errorf("can not write file to %s: %w", color.YellowString(destinationPath), err)
}

c.printf("[%d/%d] copy %s to %s\n",
Expand All @@ -213,6 +213,29 @@ func (c *cpCmd) s3ToLocal() error {
return nil
}

// filterS3Objects returns a slice of S3Key that matches the fromKey.
func (c *cpCmd) filterS3Objects(fromBucket model.Bucket, fromKey model.S3Key) ([]model.S3Key, error) {
listOutput, err := c.s3hub.ListS3Objects(c.ctx, &usecase.S3ObjectsListerInput{
Bucket: fromBucket,
})
if err != nil {
return nil, fmt.Errorf("%w: bucket=%s", err, color.YellowString(fromBucket.String()))
}

targets := make([]model.S3Key, 0, len(listOutput.Objects))
for _, v := range listOutput.Objects {
if strings.Contains(filepath.Join(fromBucket.String(), v.S3Key.String()), fromKey.String()) {
targets = append(targets, v.S3Key)
}
}

if len(targets) == 0 {
return nil, fmt.Errorf("no objects found. bucket=%s, key=%s",
color.YellowString(fromBucket.String()), color.YellowString(fromKey.String()))
}
return targets, nil
}

// s3ToS3 copies from S3 to S3.
func (c *cpCmd) s3ToS3() error {
fromBucket, fromKey := model.NewBucketWithoutProtocol(c.pair.From).Split()
Expand All @@ -238,11 +261,7 @@ func (c *cpCmd) s3ToS3() error {

fileNum := len(targets)
for i, v := range targets {
relativePath, err := filepath.Rel(fromKey.String(), v.String())
if err != nil {
return err
}
destinationKey := model.S3Key(filepath.Join(toKey.String(), relativePath))
destinationKey := model.S3Key(filepath.Clean(filepath.Join(toKey.String(), v.String())))

if _, err := c.s3hub.S3ObjectCopier.CopyS3Object(c.ctx, &usecase.S3ObjectCopierInput{
SourceBucket: fromBucket,
Expand Down
Loading

0 comments on commit f4aa82b

Please sign in to comment.