diff --git a/Makefile b/Makefile index feafaba..6a48ece 100644 --- a/Makefile +++ b/Makefile @@ -23,7 +23,7 @@ clean: ## Clean project -rm -rf $(S3HUB) $(SPARE) cover.out cover.html test: ## Start unit test - env GOOS=$(GOOS) $(GO_TEST) -cover $(GO_PKGROOT) -coverprofile=cover.out + env GOOS=$(GOOS) $(GO_TEST) -coverpkg=./... -coverprofile=cover.out -cover ./... $(GO_TOOL) cover -html=cover.out -o cover.html coverage-tree: test ## Generate coverage tree diff --git a/app/domain/model/aws_test.go b/app/domain/model/aws_test.go index 12aec22..9a39d83 100644 --- a/app/domain/model/aws_test.go +++ b/app/domain/model/aws_test.go @@ -2,6 +2,8 @@ package model import ( "testing" + + "github.com/aws/aws-sdk-go-v2/aws" ) func TestNewAWSProfile(t *testing.T) { //nolint @@ -74,3 +76,47 @@ func TestAWSProfileString(t *testing.T) { }) } } + +func TestAWSConfig_Region(t *testing.T) { + t.Parallel() + + type fields struct { + Config *aws.Config + } + tests := []struct { + name string + fields fields + want Region + }{ + { + name: "If aws config region is ap-northeast-1, return RegionAPNortheast1", + fields: fields{ + Config: &aws.Config{ + Region: string(RegionAPNortheast1), + }, + }, + want: RegionAPNortheast1, + }, + { + name: "If aws config region isempty, return RegionUSEast1", + fields: fields{ + Config: &aws.Config{ + Region: "", + }, + }, + want: RegionUSEast1, + }, + } + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + c := &AWSConfig{ + Config: tt.fields.Config, + } + if got := c.Region(); got != tt.want { + t.Errorf("AWSConfig.Region() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/app/domain/model/s3.go b/app/domain/model/s3.go index 3e6e209..3f5b0cd 100644 --- a/app/domain/model/s3.go +++ b/app/domain/model/s3.go @@ -182,12 +182,21 @@ func NewBucketWithoutProtocol(s string) Bucket { // WithProtocol returns the Bucket with the protocol. func (b Bucket) WithProtocol() Bucket { + if strings.HasPrefix(b.String(), S3Protocol) { + return b + } return Bucket(S3Protocol + b.String()) } // Join returns the Bucket with the S3Key. // e.g. "bucket" + "key" -> "bucket/key" func (b Bucket) Join(key S3Key) Bucket { + if b.Empty() || key.Empty() { + return b + } + if strings.HasSuffix(key.String(), "/") { + key = S3Key(strings.TrimSuffix(key.String(), "/")) + } return Bucket(fmt.Sprintf("%s/%s", b.String(), key.String())) } @@ -325,26 +334,26 @@ type BucketSet struct { CreationDate time.Time } -// S3ObjectIdentifierSets is the set of the S3ObjectSet. -type S3ObjectIdentifierSets []S3ObjectIdentifier +// S3ObjectIdentifiers is the set of the S3ObjectSet. +type S3ObjectIdentifiers []S3ObjectIdentifier -// Len returns the length of the S3ObjectIdentifierSets. -func (s S3ObjectIdentifierSets) Len() int { +// Len returns the length of the S3ObjectIdentifiers. +func (s S3ObjectIdentifiers) Len() int { return len(s) } // Less defines the ordering of S3ObjectIdentifier instances. -func (s S3ObjectIdentifierSets) Less(i, j int) bool { +func (s S3ObjectIdentifiers) Less(i, j int) bool { return s[i].S3Key < s[j].S3Key } // Swap swaps the elements with indexes i and j. -func (s S3ObjectIdentifierSets) Swap(i, j int) { +func (s S3ObjectIdentifiers) Swap(i, j int) { s[i], s[j] = s[j], s[i] } // ToS3ObjectIdentifiers converts the S3ObjectSets to the ObjectIdentifiers. -func (s S3ObjectIdentifierSets) ToS3ObjectIdentifiers() []types.ObjectIdentifier { +func (s S3ObjectIdentifiers) ToS3ObjectIdentifiers() []types.ObjectIdentifier { ids := make([]types.ObjectIdentifier, 0, s.Len()) for _, o := range s { ids = append(ids, *o.ToAWSS3ObjectIdentifier()) @@ -389,6 +398,18 @@ func (k S3Key) IsAll() bool { } func (k S3Key) Join(key S3Key) S3Key { + if key.Empty() { + return k + } + if strings.HasPrefix(key.String(), "/") { + key = S3Key(strings.TrimPrefix(key.String(), "/")) + } + if strings.HasSuffix(key.String(), "/") { + key = S3Key(strings.TrimSuffix(key.String(), "/")) + } + if k.Empty() { + return key + } return S3Key(fmt.Sprintf("%s/%s", k.String(), key)) } diff --git a/app/domain/model/s3_test.go b/app/domain/model/s3_test.go index 8017ce0..ab84410 100644 --- a/app/domain/model/s3_test.go +++ b/app/domain/model/s3_test.go @@ -2,11 +2,19 @@ package model import ( + "bytes" "errors" + "os" "path/filepath" "runtime" + "sort" "strings" "testing" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/s3/types" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" ) func TestRegionString(t *testing.T) { @@ -317,6 +325,11 @@ func TestBucketValidate(t *testing.T) { b: Bucket(""), wantErr: true, }, + { + name: "failure. bucket name is too short", + b: Bucket("ab"), + wantErr: true, + }, } for _, tt := range tests { tt := tt @@ -725,3 +738,556 @@ func TestS3Key_IsAll(t *testing.T) { }) } } + +func TestNewDeleteRetryCount(t *testing.T) { + t.Parallel() + + type args struct { + i int + } + tests := []struct { + name string + args args + want DeleteObjectsRetryCount + }{ + { + name: "input is 1, NewDeleteRetryCount() returns 1", + args: args{ + i: 1, + }, + want: DeleteObjectsRetryCount(1), + }, + { + name: "input is 0, NewDeleteRetryCount() returns 0", + args: args{ + i: 0, + }, + want: DeleteObjectsRetryCount(0), + }, + { + name: "input is -1, NewDeleteRetryCount() returns 0", + args: args{ + i: -1, + }, + want: DeleteObjectsRetryCount(0), + }, + { + name: "input is over MaxS3DeleteObjectsRetryCount, NewDeleteRetryCount() returns MaxS3DeleteObjectsRetryCount", + args: args{ + i: MaxS3DeleteObjectsRetryCount + 1, + }, + want: DeleteObjectsRetryCount(MaxS3DeleteObjectsRetryCount), + }, + } + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + if got := NewDeleteRetryCount(tt.args.i); got != tt.want { + t.Errorf("NewDeleteRetryCount() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestBucket_WithProtocol(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + b Bucket + want Bucket + }{ + { + name: "If Bucket is 'abc', WithProtocol() returns 's3://abc'", + b: Bucket("abc"), + want: Bucket("s3://abc"), + }, + { + name: "If Bucket is 's3://abc', WithProtocol() returns 's3://abc'", + b: Bucket("s3://abc"), + want: Bucket("s3://abc"), + }, + { + name: "If Bucket is 's3://abc/def', WithProtocol() returns 's3://abc/def'", + b: Bucket("s3://abc/def"), + want: Bucket("s3://abc/def"), + }, + { + name: "If Bucket is '', WithProtocol() returns 's3://'", + b: Bucket(""), + want: Bucket("s3://"), + }, + { + name: "If Bucket is 's3://', WithProtocol() returns 's3://'", + b: Bucket("s3://"), + want: Bucket("s3://"), + }, + } + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + if got := tt.b.WithProtocol(); got != tt.want { + t.Errorf("Bucket.WithProtocol() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestNewBucketWithoutProtocol(t *testing.T) { + t.Parallel() + type args struct { + s string + } + tests := []struct { + name string + args args + want Bucket + }{ + { + name: "If input is 's3://abc', NewBucketWithoutProtocol() returns 'abc'", + args: args{ + s: "s3://abc", + }, + want: Bucket("abc"), + }, + { + name: "If input is 's3://abc/def', NewBucketWithoutProtocol() returns 'abc/def'", + args: args{ + s: "s3://abc/def", + }, + want: Bucket("abc/def"), + }, + { + name: "If input is 'abc', NewBucketWithoutProtocol() returns 'abc'", + args: args{ + s: "abc", + }, + want: Bucket("abc"), + }, + { + name: "If input is 'abc/def', NewBucketWithoutProtocol() returns 'abc/def'", + args: args{ + s: "abc/def", + }, + want: Bucket("abc/def"), + }, + { + name: "If input is '', NewBucketWithoutProtocol() returns ''", + args: args{ + s: "", + }, + want: Bucket(""), + }, + { + name: "If input is 's3://', NewBucketWithoutProtocol() returns ''", + args: args{ + s: "s3://", + }, + want: Bucket(""), + }, + } + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + if got := NewBucketWithoutProtocol(tt.args.s); got != tt.want { + t.Errorf("NewBucketWithoutProtocol() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestBucket_Join(t *testing.T) { + t.Parallel() + + type args struct { + key S3Key + } + tests := []struct { + name string + b Bucket + args args + want Bucket + }{ + { + name: "If Bucket is 'abc' and key is 'def', Join() returns 'abc/def'", + b: Bucket("abc"), + args: args{ + S3Key("def"), + }, + want: Bucket("abc/def"), + }, + { + name: "If Bucket is 'abc' and key is 'def/ghi', Join() returns 'abc/def/ghi'", + b: Bucket("abc"), + args: args{ + S3Key("def/ghi"), + }, + want: Bucket("abc/def/ghi"), + }, + { + name: "If Bucket is 'abc' and key is '', Join() returns 'abc'", + b: Bucket("abc"), + args: args{ + S3Key(""), + }, + want: Bucket("abc"), + }, + { + name: "If Bucket is 'abc' and key is 'def/', Join() returns 'abc/def'", + b: Bucket("abc"), + args: args{ + S3Key("def/"), + }, + want: Bucket("abc/def"), + }, + { + name: "If Bucket is '' and key is 'def', Join() returns ''", + b: Bucket(""), + args: args{ + S3Key("def"), + }, + want: Bucket(""), + }, + } + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + if got := tt.b.Join(tt.args.key); got != tt.want { + t.Errorf("Bucket.Join() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestS3ObjectIdentifiers_Len(t *testing.T) { + t.Parallel() + tests := []struct { + name string + s S3ObjectIdentifiers + want int + }{ + { + name: "If S3ObjectIdentifiers has two S3ObjectIdentifierSet, Len() returns 2", + s: S3ObjectIdentifiers{S3ObjectIdentifier{}, S3ObjectIdentifier{}}, + want: 2, + }, + { + name: "If S3ObjectIdentifiers is empty, Len() returns 0", + s: S3ObjectIdentifiers{}, + want: 0, + }, + } + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + if got := tt.s.Len(); got != tt.want { + t.Errorf("S3ObjectIdentifiers.Len() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestSortS3ObjectIdentifiers(t *testing.T) { + t.Parallel() + t.Run("If S3ObjectIdentifiers has three S3ObjectIdentifierSet, sort.Sort returns sorted S3ObjectIdentifiers", func(t *testing.T) { + t.Parallel() + s := S3ObjectIdentifiers{ + S3ObjectIdentifier{ + S3Key: S3Key("ghi"), + }, + S3ObjectIdentifier{ + S3Key: S3Key("abc"), + }, + S3ObjectIdentifier{ + S3Key: S3Key("def"), + }, + } + want := S3ObjectIdentifiers{ + S3ObjectIdentifier{ + S3Key: S3Key("abc"), + }, + S3ObjectIdentifier{ + S3Key: S3Key("def"), + }, + S3ObjectIdentifier{ + S3Key: S3Key("ghi"), + }, + } + sort.Sort(s) + if diff := cmp.Diff(s, want); diff != "" { + t.Errorf("sort.Sort() mismatch (-want +got):\n%s", diff) + } + }) +} + +func TestS3ObjectIdentifiers_ToS3ObjectIdentifiers(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + s S3ObjectIdentifiers + want []types.ObjectIdentifier + }{ + { + name: "If S3ObjectIdentifiers has two S3ObjectIdentifierSet, ToS3ObjectIdentifiers() returns []types.ObjectIdentifier", + s: S3ObjectIdentifiers{ + { + S3Key: S3Key("abc"), + VersionID: VersionID("def"), + }, + { + S3Key: S3Key("ghi"), + VersionID: VersionID("jkl"), + }, + { + S3Key: S3Key("mno"), + VersionID: VersionID("pqr"), + }, + }, + want: []types.ObjectIdentifier{ + { + Key: aws.String("abc"), + VersionId: aws.String("def"), + }, + { + Key: aws.String("ghi"), + VersionId: aws.String("jkl"), + }, + { + Key: aws.String("mno"), + VersionId: aws.String("pqr"), + }, + }, + }, + } + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got := tt.s.ToS3ObjectIdentifiers() + + opt := cmpopts.IgnoreUnexported(types.ObjectIdentifier{}) + if diff := cmp.Diff(got, tt.want, opt); diff != "" { + t.Errorf("S3ObjectIdentifiers.ToS3ObjectIdentifiers() mismatch (-want +got):\n%s", diff) + } + }) + } +} + +func TestS3Key_Join(t *testing.T) { + t.Parallel() + + type args struct { + key S3Key + } + tests := []struct { + name string + k S3Key + args args + want S3Key + }{ + { + name: "If S3Key is 'abc' and key is 'def', Join() returns 'abc/def'", + k: S3Key("abc"), + args: args{ + key: S3Key("def"), + }, + want: S3Key("abc/def"), + }, + { + name: "If S3Key is 'abc' and key is 'def/ghi', Join() returns 'abc/def/ghi'", + k: S3Key("abc"), + args: args{ + key: S3Key("def/ghi"), + }, + want: S3Key("abc/def/ghi"), + }, + { + name: "If S3Key is 'abc' and key is '', Join() returns 'abc'", + k: S3Key("abc"), + args: args{ + key: S3Key(""), + }, + want: S3Key("abc"), + }, + { + name: "If S3Key is 'abc' and key is 'def/', Join() returns 'abc/def'", + k: S3Key("abc"), + args: args{ + key: S3Key("def/"), + }, + want: S3Key("abc/def"), + }, + { + name: "If S3Key is '' and key is '/def', Join() returns 'def'", + k: S3Key(""), + args: args{ + key: S3Key("/def"), + }, + want: S3Key("def"), + }, + { + name: "If S3Key is '' and key is 'def', Join() returns 'def'", + args: args{ + key: S3Key("def"), + }, + want: S3Key("def"), + }, + } + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + if got := tt.k.Join(tt.args.key); got != tt.want { + t.Errorf("S3Key.Join() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestS3Object_ToFile(t *testing.T) { + t.Parallel() + + t.Run("If S3Object is 'abc', ToFile() writes 'abc' to the file", func(t *testing.T) { + t.Parallel() + + want := []byte("abc") + obj := NewS3Object(want) + tmpDir := os.TempDir() + tmpFilePath := filepath.Join(tmpDir, "s3object.txt") + + if err := obj.ToFile(tmpFilePath, 0600); err != nil { + t.Fatalf("S3Object.ToFile() error = %v", err) + } + + got, err := os.ReadFile(filepath.Clean(tmpFilePath)) + if err != nil { + t.Fatalf("os.ReadFile() error = %v", err) + } + + if diff := cmp.Diff(got, want); diff != "" { + t.Errorf("S3Object.ToFile() mismatch (-want +got):\n%s", diff) + } + + if err := os.RemoveAll(tmpFilePath); err != nil { + t.Fatalf("os.RemoveAll() error = %v", err) + } + }) +} + +func TestS3Object_ContentType(t *testing.T) { + t.Parallel() + + t.Run("If S3Object is png file, ContentType() returns 'image/png'", func(t *testing.T) { + t.Parallel() + + b, err := os.ReadFile(filepath.Join("testdata", "lena.png")) + if err != nil { + t.Fatalf("os.ReadFile() error = %v", err) + } + + got := NewS3Object(b).ContentType() + want := "image/png" + if diff := cmp.Diff(got, want); diff != "" { + t.Errorf("S3Object.ContentType() mismatch (-want +got):\n%s", diff) + } + }) + + t.Run("If S3Object is json file, ContentType() returns 'application/json'", func(t *testing.T) { + t.Parallel() + + b, err := os.ReadFile(filepath.Join("testdata", "s3policy.json")) + if err != nil { + t.Fatalf("os.ReadFile() error = %v", err) + } + + got := NewS3Object(b).ContentType() + want := "application/json" + if diff := cmp.Diff(got, want); diff != "" { + t.Errorf("S3Object.ContentType() mismatch (-want +got):\n%s", diff) + } + }) + + t.Run("If S3Object is markdown file, ContentType() returns 'text/plain; charset=utf-8'", func(t *testing.T) { + t.Parallel() + + b, err := os.ReadFile(filepath.Join("testdata", "sample.md")) + if err != nil { + t.Fatalf("os.ReadFile() error = %v", err) + } + + got := NewS3Object(b).ContentType() + want := "text/plain; charset=utf-8" + + if diff := cmp.Diff(got, want); diff != "" { + t.Errorf("S3Object.ContentType() mismatch (-want +got):\n%s", diff) + } + }) + + t.Run("If S3Object is empty, ContentType() returns 'text/plain'", func(t *testing.T) { + t.Parallel() + + got := NewS3Object([]byte{}).ContentType() + want := "text/plain" + + if diff := cmp.Diff(got, want); diff != "" { + t.Errorf("S3Object.ContentType() mismatch (-want +got):\n%s", diff) + } + }) + + t.Run("If S3Object is nil, ContentType() returns 'text/plain'", func(t *testing.T) { + t.Parallel() + + got := NewS3Object(nil).ContentType() + want := "text/plain" + + if diff := cmp.Diff(got, want); diff != "" { + t.Errorf("S3Object.ContentType() mismatch (-want +got):\n%s", diff) + } + }) +} + +func TestS3Object_ContentLength(t *testing.T) { + t.Parallel() + + type fields struct { + Buffer *bytes.Buffer + } + tests := []struct { + name string + fields fields + want int64 + }{ + { + name: "If S3Object is 'abc', ContentLength() returns 3", + fields: fields{ + Buffer: bytes.NewBuffer([]byte("abc")), + }, + want: 3, + }, + { + name: "If S3Object is empty, ContentLength() returns 0", + fields: fields{ + Buffer: bytes.NewBuffer([]byte{}), + }, + want: 0, + }, + } + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + s := &S3Object{ + Buffer: tt.fields.Buffer, + } + if got := s.ContentLength(); got != tt.want { + t.Errorf("S3Object.ContentLength() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/app/domain/model/testdata/lena.png b/app/domain/model/testdata/lena.png new file mode 100644 index 0000000..59ef68a Binary files /dev/null and b/app/domain/model/testdata/lena.png differ diff --git a/app/domain/model/testdata/sample.md b/app/domain/model/testdata/sample.md new file mode 100644 index 0000000..5f6b797 --- /dev/null +++ b/app/domain/model/testdata/sample.md @@ -0,0 +1,5 @@ +# Sample markdown +## sample header 1 +This is comment + +## sample header 2 \ No newline at end of file diff --git a/app/domain/service/s3.go b/app/domain/service/s3.go index 6ff541b..4db0c42 100644 --- a/app/domain/service/s3.go +++ b/app/domain/service/s3.go @@ -76,7 +76,7 @@ type S3ObjectsDeleterInput struct { // Region is the region of the bucket that you want to delete. Region model.Region // S3ObjectSets is the list of the objects to delete. - S3ObjectSets model.S3ObjectIdentifierSets + S3ObjectSets model.S3ObjectIdentifiers } // S3ObjectsDeleterOutput is the output of the DeleteBucketObjects method. @@ -96,7 +96,7 @@ type S3ObjectsListerInput struct { // S3ObjectsListerOutput is the output of the ListBucketObjects method. type S3ObjectsListerOutput struct { // Objects is the list of the objects. - Objects model.S3ObjectIdentifierSets + Objects model.S3ObjectIdentifiers } // S3ObjectsLister is the interface that wraps the basic ListBucketObjects method. diff --git a/app/external/s3.go b/app/external/s3.go index e45881f..580daee 100644 --- a/app/external/s3.go +++ b/app/external/s3.go @@ -242,7 +242,7 @@ func NewS3ObjectsLister(client *s3.Client) *S3ObjectsLister { // ListS3Objects lists the objects in the bucket. func (c *S3ObjectsLister) ListS3Objects(ctx context.Context, input *service.S3ObjectsListerInput) (*service.S3ObjectsListerOutput, error) { - var objects model.S3ObjectIdentifierSets + var objects model.S3ObjectIdentifiers in := &s3.ListObjectsV2Input{ Bucket: aws.String(input.Bucket.String()), MaxKeys: aws.Int32(model.MaxS3Keys), diff --git a/app/usecase/s3.go b/app/usecase/s3.go index 9d08d6a..34371c9 100644 --- a/app/usecase/s3.go +++ b/app/usecase/s3.go @@ -46,7 +46,7 @@ type S3ObjectsListerInput struct { // S3ObjectsListerOutput is the output of the ListObjects method. type S3ObjectsListerOutput struct { // Objects is the list of the objects. - Objects model.S3ObjectIdentifierSets + Objects model.S3ObjectIdentifiers } // S3ObjectsLister is the interface that wraps the basic ListObjects method. @@ -73,7 +73,7 @@ type S3ObjectsDeleterInput struct { // Bucket is the name of the bucket that you want to delete. Bucket model.Bucket // S3ObjectSets is the list of the objects to delete. - S3ObjectSets model.S3ObjectIdentifierSets + S3ObjectSets model.S3ObjectIdentifiers } // S3ObjectsDeleterOutput is the output of the DeleteObjects method. diff --git a/cmd/subcmd/s3hub/rm.go b/cmd/subcmd/s3hub/rm.go index b3d99f2..e0cf86f 100644 --- a/cmd/subcmd/s3hub/rm.go +++ b/cmd/subcmd/s3hub/rm.go @@ -135,7 +135,7 @@ func (r *rmCmd) remove(bucket model.Bucket, key model.S3Key) error { func (r *rmCmd) removeObject(bucket model.Bucket, key model.S3Key) error { if _, err := r.S3App.S3ObjectsDeleter.DeleteS3Objects(r.ctx, &usecase.S3ObjectsDeleterInput{ Bucket: bucket, - S3ObjectSets: model.S3ObjectIdentifierSets{ + S3ObjectSets: model.S3ObjectIdentifiers{ model.S3ObjectIdentifier{ S3Key: key, },