From 95dadf8d9f1349db723c02a6f9c7362d0577666f Mon Sep 17 00:00:00 2001 From: Joel Diaz Date: Wed, 11 Nov 2020 10:31:57 -0500 Subject: [PATCH] UPSTREAM: docker/distribution: 3296: allow pointing to an AWS config file as a parameter for the s3 driver Recognize a new parameter when setting up the AWS client so that a generic AWS config file can be used instead of having to specify AWS access and secret keys. This should allow someone to use different authentication methods beyond just access key, secret key (and optionally session token). Using the current supported auth methods a valid file would look like: ``` [default] aws_access_key_id = AKMYAWSACCCESSKEYID aws_secret_access_key = myawssecretaccesskey ``` But you can also specify alternative auth methods: ``` [default] role_arn = arn:aws:iam:ACCOUNT_NUM:role/ROLE_NAME web_identity_token_file = /path/to/token ``` Signed-off-by: Tiger Kaovilai --- registry/storage/driver/s3-aws/s3.go | 235 +++++++++++------- registry/storage/driver/s3-aws/s3_test.go | 165 ++++++------ registry/storage/driver/storagedriver.go | 30 ++- .../v3/registry/storage/driver/s3-aws/s3.go | 235 +++++++++++------- .../registry/storage/driver/storagedriver.go | 30 ++- 5 files changed, 428 insertions(+), 267 deletions(-) diff --git a/registry/storage/driver/s3-aws/s3.go b/registry/storage/driver/s3-aws/s3.go index 7e0c48650d2..5d18d91bcd5 100644 --- a/registry/storage/driver/s3-aws/s3.go +++ b/registry/storage/driver/s3-aws/s3.go @@ -15,9 +15,9 @@ import ( "bytes" "context" "crypto/tls" + "errors" "fmt" "io" - "io/ioutil" "math" "net/http" "path/filepath" @@ -36,7 +36,6 @@ import ( "github.com/aws/aws-sdk-go/service/s3" dcontext "github.com/distribution/distribution/v3/context" - "github.com/distribution/distribution/v3/registry/client/transport" storagedriver "github.com/distribution/distribution/v3/registry/storage/driver" "github.com/distribution/distribution/v3/registry/storage/driver/base" 
"github.com/distribution/distribution/v3/registry/storage/driver/factory" @@ -93,7 +92,7 @@ var validRegions = map[string]struct{}{} // validObjectACLs contains known s3 object Acls var validObjectACLs = map[string]struct{}{} -//DriverParameters A struct that encapsulates all of the driver parameters after all values have been set +// DriverParameters A struct that encapsulates all of the driver parameters after all values have been set type DriverParameters struct { AccessKey string SecretKey string @@ -118,6 +117,8 @@ type DriverParameters struct { SessionToken string UseDualStack bool Accelerate bool + VirtualHostedStyle bool + CredentialsConfigPath string } func init() { @@ -197,6 +198,11 @@ func FromParameters(parameters map[string]interface{}) (*Driver, error) { secretKey = "" } + credentialsConfigPath := parameters["credentialsconfigpath"] + if credentialsConfigPath == nil { + credentialsConfigPath = "" + } + regionEndpoint := parameters["regionendpoint"] if regionEndpoint == nil { regionEndpoint = "" @@ -417,6 +423,23 @@ func FromParameters(parameters map[string]interface{}) (*Driver, error) { return nil, fmt.Errorf("the multipartcombinesmallpart parameter should be a boolean") } + virtualHostedStyleBool := false + virtualHostedStyle := parameters["virtualhostedstyle"] + switch virtualHostedStyle := virtualHostedStyle.(type) { + case string: + b, err := strconv.ParseBool(virtualHostedStyle) + if err != nil { + return nil, fmt.Errorf("the virtualHostedStyle parameter should be a boolean") + } + virtualHostedStyleBool = b + case bool: + virtualHostedStyleBool = virtualHostedStyle + case nil: + // do nothing + default: + return nil, fmt.Errorf("the virtualHostedStyle parameter should be a boolean") + } + sessionToken := "" accelerateBool := false @@ -460,6 +483,8 @@ func FromParameters(parameters map[string]interface{}) (*Driver, error) { fmt.Sprint(sessionToken), useDualStackBool, accelerateBool, + virtualHostedStyleBool, + fmt.Sprint(credentialsConfigPath), 
} return New(params) @@ -503,6 +528,12 @@ func New(params DriverParameters) (*Driver, error) { return nil, fmt.Errorf("on Amazon S3 this storage driver can only be used with v4 authentication") } + // Makes no sense to provide access/secret key and the location of a + // config file with credentials. + if (params.AccessKey != "" || params.SecretKey != "") && params.CredentialsConfigPath != "" { + return nil, fmt.Errorf("cannot set both access/secret key and credentials file path") + } + awsConfig := aws.NewConfig() if params.AccessKey != "" && params.SecretKey != "" { @@ -515,6 +546,9 @@ func New(params DriverParameters) (*Driver, error) { } if params.RegionEndpoint != "" { + if !params.VirtualHostedStyle { + awsConfig.WithS3ForcePathStyle(true) + } awsConfig.WithEndpoint(params.RegionEndpoint) awsConfig.WithS3ForcePathStyle(params.ForcePathStyle) } @@ -522,32 +556,35 @@ func New(params DriverParameters) (*Driver, error) { awsConfig.WithS3UseAccelerate(params.Accelerate) awsConfig.WithRegion(params.Region) awsConfig.WithDisableSSL(!params.Secure) - if params.UseDualStack { - awsConfig.UseDualStackEndpoint = endpoints.DualStackEndpointStateEnabled - } + awsConfig.WithUseDualStack(params.UseDualStack) - if params.UserAgent != "" || params.SkipVerify { - httpTransport := http.DefaultTransport - if params.SkipVerify { - httpTransport = &http.Transport{ - TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, - } - } - if params.UserAgent != "" { - awsConfig.WithHTTPClient(&http.Client{ - Transport: transport.NewTransport(httpTransport, transport.NewHeaderRequestModifier(http.Header{http.CanonicalHeaderKey("User-Agent"): []string{params.UserAgent}})), - }) - } else { - awsConfig.WithHTTPClient(&http.Client{ - Transport: transport.NewTransport(httpTransport), - }) + if params.SkipVerify { + httpTransport := &http.Transport{ + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, } + awsConfig.WithHTTPClient(&http.Client{ + Transport: httpTransport, + }) } - sess, err 
:= session.NewSession(awsConfig) + sessionOptions := session.Options{ + Config: *awsConfig, + } + if params.CredentialsConfigPath != "" { + sessionOptions.SharedConfigState = session.SharedConfigEnable + sessionOptions.SharedConfigFiles = []string{ + params.CredentialsConfigPath, + } + } + sess, err := session.NewSessionWithOptions(sessionOptions) if err != nil { return nil, fmt.Errorf("failed to create new session with aws config: %v", err) } + + if params.UserAgent != "" { + sess.Handlers.Build.PushBack(request.MakeAddToUserAgentFreeFormHandler(params.UserAgent)) + } + s3obj := s3.New(sess) // enable S3 compatible signature v2 signing instead @@ -606,7 +643,7 @@ func (d *driver) GetContent(ctx context.Context, path string) ([]byte, error) { if err != nil { return nil, err } - return ioutil.ReadAll(reader) + return io.ReadAll(reader) } // PutContent stores the []byte content at a location designated by "path". @@ -632,10 +669,9 @@ func (d *driver) Reader(ctx context.Context, path string, offset int64) (io.Read Key: aws.String(d.s3Path(path)), Range: aws.String("bytes=" + strconv.FormatInt(offset, 10) + "-"), }) - if err != nil { if s3Err, ok := err.(awserr.Error); ok && s3Err.Code() == "InvalidRange" { - return ioutil.NopCloser(bytes.NewReader(nil)), nil + return io.NopCloser(bytes.NewReader(nil)), nil } return nil, parseError(path, err) @@ -696,7 +732,7 @@ func (d *driver) Writer(ctx context.Context, path string, appendParam bool) (sto return nil, parseError(path, err) } allParts = append(allParts, partsList.Parts...) - for *resp.IsTruncated { + for *partsList.IsTruncated { partsList, err = d.S3.ListParts(&s3.ListPartsInput{ Bucket: aws.String(d.Bucket), Key: aws.String(key), @@ -923,54 +959,71 @@ func (d *driver) copy(ctx context.Context, sourcePath string, destPath string) e return err } -func min(a, b int) int { - if a < b { - return a - } - return b -} - // Delete recursively deletes all objects stored at "path" and its subpaths. 
// We must be careful since S3 does not guarantee read after delete consistency func (d *driver) Delete(ctx context.Context, path string) error { s3Objects := make([]*s3.ObjectIdentifier, 0, listMax) - - // manually add the given path if it's a file - stat, err := d.Stat(ctx, path) - if err != nil { - return err - } - if stat != nil && !stat.IsDir() { - path := d.s3Path(path) - s3Objects = append(s3Objects, &s3.ObjectIdentifier{ - Key: &path, - }) - } - - // list objects under the given path as a subpath (suffix with slash "/") - s3Path := d.s3Path(path) + "/" + s3Path := d.s3Path(path) listObjectsInput := &s3.ListObjectsV2Input{ Bucket: aws.String(d.Bucket), Prefix: aws.String(s3Path), } -ListLoop: + for { // list all the objects resp, err := d.S3.ListObjectsV2(listObjectsInput) // resp.Contents can only be empty on the first call // if there were no more results to return after the first call, resp.IsTruncated would have been false - // and the loop would be exited without recalling ListObjects + // and the loop would exit without recalling ListObjects if err != nil || len(resp.Contents) == 0 { - break ListLoop + return storagedriver.PathNotFoundError{Path: path} } for _, key := range resp.Contents { + // Skip if we encounter a key that is not a subpath (so that deleting "/a" does not delete "/ab"). + if len(*key.Key) > len(s3Path) && (*key.Key)[len(s3Path)] != '/' { + continue + } s3Objects = append(s3Objects, &s3.ObjectIdentifier{ Key: key.Key, }) } + // Delete objects only if the list is not empty, otherwise S3 API returns a cryptic error + if len(s3Objects) > 0 { + // NOTE: according to AWS docs https://docs.aws.amazon.com/AmazonS3/latest/API/API_ListObjectsV2.html + // by default the response returns up to 1,000 key names. The response _might_ contain fewer keys but it will never contain more. + // 10000 keys is coincidentally (?) 
also the max number of keys that can be deleted in a single Delete operation, so we'll just smack + // Delete here straight away and reset the object slice when successful. + resp, err := d.S3.DeleteObjects(&s3.DeleteObjectsInput{ + Bucket: aws.String(d.Bucket), + Delete: &s3.Delete{ + Objects: s3Objects, + Quiet: aws.Bool(false), + }, + }) + if err != nil { + return err + } + + if len(resp.Errors) > 0 { + // NOTE: AWS SDK s3.Error does not implement error interface which + // is pretty intensely sad, so we have to do away with this for now. + errs := make([]error, 0, len(resp.Errors)) + for _, err := range resp.Errors { + errs = append(errs, errors.New(err.String())) + } + return storagedriver.Errors{ + DriverName: driverName, + Errs: errs, + } + } + } + // NOTE: we don't want to reallocate + // the slice so we simply "reset" it + s3Objects = s3Objects[:0] + // resp.Contents must have at least one element or we would have returned not found listObjectsInput.StartAfter = resp.Contents[len(resp.Contents)-1].Key @@ -981,35 +1034,17 @@ ListLoop: } } - total := len(s3Objects) - if total == 0 { - return storagedriver.PathNotFoundError{Path: path} - } - - // need to chunk objects into groups of 1000 per s3 restrictions - for i := 0; i < total; i += 1000 { - _, err := d.S3.DeleteObjects(&s3.DeleteObjectsInput{ - Bucket: aws.String(d.Bucket), - Delete: &s3.Delete{ - Objects: s3Objects[i:min(i+1000, total)], - Quiet: aws.Bool(false), - }, - }) - if err != nil { - return err - } - } return nil } // URLFor returns a URL which may be used to retrieve the content stored at the given path. // May return an UnsupportedMethodErr in certain StorageDriver implementations. 
func (d *driver) URLFor(ctx context.Context, path string, options map[string]interface{}) (string, error) { - methodString := "GET" + methodString := http.MethodGet method, ok := options["method"] if ok { methodString, ok = method.(string) - if !ok || (methodString != "GET" && methodString != "HEAD") { + if !ok || (methodString != http.MethodGet && methodString != http.MethodHead) { return "", storagedriver.ErrUnsupportedMethod{} } } @@ -1026,12 +1061,12 @@ func (d *driver) URLFor(ctx context.Context, path string, options map[string]int var req *request.Request switch methodString { - case "GET": + case http.MethodGet: req, _ = d.S3.GetObjectRequest(&s3.GetObjectInput{ Bucket: aws.String(d.Bucket), Key: aws.String(d.s3Path(path)), }) - case "HEAD": + case http.MethodHead: req, _ = d.S3.HeadObjectRequest(&s3.HeadObjectInput{ Bucket: aws.String(d.Bucket), Key: aws.String(d.s3Path(path)), @@ -1077,7 +1112,7 @@ func (d *driver) doWalk(parentCtx context.Context, objectCount *int64, path, pre // the most recent skip directory to avoid walking over undesirable files prevSkipDir string ) - prevDir = prefix + path + prevDir = strings.Replace(path, d.s3Path(""), prefix, 1) listObjectsInput := &s3.ListObjectsV2Input{ Bucket: aws.String(d.Bucket), @@ -1166,16 +1201,22 @@ func (d *driver) doWalk(parentCtx context.Context, objectCount *int64, path, pre // directoryDiff finds all directories that are not in common between // the previous and current paths in sorted order. 
// -// Eg 1 directoryDiff("/path/to/folder", "/path/to/folder/folder/file") -// => [ "/path/to/folder/folder" ], -// Eg 2 directoryDiff("/path/to/folder/folder1", "/path/to/folder/folder2/file") -// => [ "/path/to/folder/folder2" ] -// Eg 3 directoryDiff("/path/to/folder/folder1/file", "/path/to/folder/folder2/file") -// => [ "/path/to/folder/folder2" ] -// Eg 4 directoryDiff("/path/to/folder/folder1/file", "/path/to/folder/folder2/folder1/file") -// => [ "/path/to/folder/folder2", "/path/to/folder/folder2/folder1" ] -// Eg 5 directoryDiff("/", "/path/to/folder/folder/file") -// => [ "/path", "/path/to", "/path/to/folder", "/path/to/folder/folder" ], +// # Examples +// +// directoryDiff("/path/to/folder", "/path/to/folder/folder/file") +// // => [ "/path/to/folder/folder" ] +// +// directoryDiff("/path/to/folder/folder1", "/path/to/folder/folder2/file") +// // => [ "/path/to/folder/folder2" ] +// +// directoryDiff("/path/to/folder/folder1/file", "/path/to/folder/folder2/file") +// // => [ "/path/to/folder/folder2" ] +// +// directoryDiff("/path/to/folder/folder1/file", "/path/to/folder/folder2/folder1/file") +// // => [ "/path/to/folder/folder2", "/path/to/folder/folder2/folder1" ] +// +// directoryDiff("/", "/path/to/folder/folder/file") +// // => [ "/path", "/path/to", "/path/to/folder", "/path/to/folder/folder" ] func directoryDiff(prev, current string) []string { var paths []string @@ -1351,7 +1392,7 @@ func (w *writer) Write(p []byte) (int, error) { } defer resp.Body.Close() w.parts = nil - w.readyPart, err = ioutil.ReadAll(resp.Body) + w.readyPart, err = io.ReadAll(resp.Body) if err != nil { return 0, err } @@ -1463,6 +1504,30 @@ func (w *writer) Commit() error { }) } + // This is an edge case when we are trying to upload an empty chunk of data using + // a MultiPart upload. 
As a result we are trying to complete the MultipartUpload + // with an empty slice of `completedUploadedParts` which will always lead to 400 + // being returned from S3 See: https://docs.aws.amazon.com/sdk-for-go/api/service/s3/#CompletedMultipartUpload + // Solution: we upload an empty i.e. 0 byte part as a single part and then append it + // to the completedUploadedParts slice used to complete the Multipart upload. + if len(w.parts) == 0 { + resp, err := w.driver.S3.UploadPart(&s3.UploadPartInput{ + Bucket: aws.String(w.driver.Bucket), + Key: aws.String(w.key), + PartNumber: aws.Int64(1), + UploadId: aws.String(w.uploadID), + Body: bytes.NewReader(nil), + }) + if err != nil { + return err + } + + completedUploadedParts = append(completedUploadedParts, &s3.CompletedPart{ + ETag: resp.ETag, + PartNumber: aws.Int64(1), + }) + } + sort.Sort(completedUploadedParts) _, err = w.driver.S3.CompleteMultipartUpload(&s3.CompleteMultipartUploadInput{ diff --git a/registry/storage/driver/s3-aws/s3_test.go b/registry/storage/driver/s3-aws/s3_test.go index 74a3226aab6..80d4284f40e 100644 --- a/registry/storage/driver/s3-aws/s3_test.go +++ b/registry/storage/driver/s3-aws/s3_test.go @@ -4,7 +4,6 @@ import ( "bytes" "errors" "fmt" - "io/ioutil" "math/rand" "os" "path" @@ -27,27 +26,34 @@ import ( // Hook up gocheck into the "go test" runner. 
func Test(t *testing.T) { check.TestingT(t) } -var s3DriverConstructor func(rootDirectory, storageClass string) (*Driver, error) -var skipS3 func() string +var ( + s3DriverConstructor func(rootDirectory, storageClass string) (*Driver, error) + skipS3 func() string +) func init() { - accessKey := os.Getenv("AWS_ACCESS_KEY") - secretKey := os.Getenv("AWS_SECRET_KEY") - bucket := os.Getenv("S3_BUCKET") - encrypt := os.Getenv("S3_ENCRYPT") - keyID := os.Getenv("S3_KEY_ID") - secure := os.Getenv("S3_SECURE") - skipVerify := os.Getenv("S3_SKIP_VERIFY") - v4Auth := os.Getenv("S3_V4_AUTH") - region := os.Getenv("AWS_REGION") - objectACL := os.Getenv("S3_OBJECT_ACL") - root, err := ioutil.TempDir("", "driver-") - regionEndpoint := os.Getenv("REGION_ENDPOINT") - forcePathStyle := os.Getenv("AWS_S3_FORCE_PATH_STYLE") - sessionToken := os.Getenv("AWS_SESSION_TOKEN") - useDualStack := os.Getenv("S3_USE_DUALSTACK") - combineSmallPart := os.Getenv("MULTIPART_COMBINE_SMALL_PART") - accelerate := os.Getenv("S3_ACCELERATE") + var ( + accessKey = os.Getenv("AWS_ACCESS_KEY") + secretKey = os.Getenv("AWS_SECRET_KEY") + bucket = os.Getenv("S3_BUCKET") + encrypt = os.Getenv("S3_ENCRYPT") + keyID = os.Getenv("S3_KEY_ID") + secure = os.Getenv("S3_SECURE") + skipVerify = os.Getenv("S3_SKIP_VERIFY") + v4Auth = os.Getenv("S3_V4_AUTH") + region = os.Getenv("AWS_REGION") + objectACL = os.Getenv("S3_OBJECT_ACL") + regionEndpoint = os.Getenv("REGION_ENDPOINT") + forcePathStyle = os.Getenv("AWS_S3_FORCE_PATH_STYLE") + sessionToken = os.Getenv("AWS_SESSION_TOKEN") + useDualStack = os.Getenv("S3_USE_DUALSTACK") + combineSmallPart = os.Getenv("MULTIPART_COMBINE_SMALL_PART") + accelerate = os.Getenv("S3_ACCELERATE") + virtualHostedStyle = os.Getenv("S3_VIRTUAL_HOSTED_STYLE") + credentialsConfigPath = os.Getenv("AWS_SHARED_CREDENTIALS_FILE") + ) + + root, err := os.MkdirTemp("", "driver-") if err != nil { panic(err) } @@ -114,6 +120,14 @@ func init() { } } + virtualHostedStyleBool := true + if 
virtualHostedStyle != "" { + virtualHostedStyleBool, err = strconv.ParseBool(virtualHostedStyle) + if err != nil { + return nil, err + } + } + parameters := DriverParameters{ accessKey, secretKey, @@ -138,6 +152,8 @@ func init() { sessionToken, useDualStackBool, accelerateBool, + virtualHostedStyleBool, + credentialsConfigPath, } return New(parameters) @@ -161,12 +177,7 @@ func TestEmptyRootList(t *testing.T) { t.Skip(skipS3()) } - validRoot, err := ioutil.TempDir("", "driver-") - if err != nil { - t.Fatalf("unexpected error creating temporary directory: %v", err) - } - defer os.Remove(validRoot) - + validRoot := t.TempDir() rootedDriver, err := s3DriverConstructor(validRoot, s3.StorageClassStandard) if err != nil { t.Fatalf("unexpected error creating rooted driver: %v", err) @@ -199,9 +210,9 @@ func TestEmptyRootList(t *testing.T) { } keys, _ = slashRootDriver.List(ctx, "/") - for _, path := range keys { - if !storagedriver.PathRegexp.MatchString(path) { - t.Fatalf("unexpected string in path: %q != %q", path, storagedriver.PathRegexp) + for _, p := range keys { + if !storagedriver.PathRegexp.MatchString(p) { + t.Fatalf("unexpected string in path: %q != %q", p, storagedriver.PathRegexp) } } } @@ -244,12 +255,7 @@ func TestStorageClass(t *testing.T) { t.Skip(skipS3()) } - rootDir, err := ioutil.TempDir("", "driver-") - if err != nil { - t.Fatalf("unexpected error creating temporary directory: %v", err) - } - defer os.Remove(rootDir) - + rootDir := t.TempDir() contents := []byte("contents") ctx := context.Background() for _, storageClass := range s3StorageClasses { @@ -302,13 +308,9 @@ func TestDelete(t *testing.T) { t.Skip(skipS3()) } - rootDir, err := ioutil.TempDir("", "driver-") - if err != nil { - t.Fatalf("unexpected error creating temporary directory: %v", err) - } - defer os.Remove(rootDir) + rootDir := t.TempDir() - driver, err := s3DriverConstructor(rootDir, s3.StorageClassStandard) + drvr, err := s3DriverConstructor(rootDir, s3.StorageClassStandard) if err 
!= nil { t.Fatalf("unexpected error creating driver with standard storage: %v", err) } @@ -343,7 +345,7 @@ func TestDelete(t *testing.T) { return false } - var objs = []string{ + objs := []string{ "/file1", "/file1-2", "/file1/2", @@ -411,40 +413,40 @@ func TestDelete(t *testing.T) { } // objects to skip auto-created test case - var skipCase = map[string]bool{ + skipCase := map[string]bool{ // special case where deleting "/file1" also deletes "/file1/2" is tested explicitly "/file1": true, } // create a test case for each file - for _, path := range objs { - if skipCase[path] { + for _, p := range objs { + if skipCase[p] { continue } tcs = append(tcs, testCase{ - name: fmt.Sprintf("delete path:'%s'", path), - delete: path, - expected: []string{path}, + name: fmt.Sprintf("delete path:'%s'", p), + delete: p, + expected: []string{p}, }) } init := func() []string { // init file structure matching objs var created []string - for _, path := range objs { - err := driver.PutContent(context.Background(), path, []byte("content "+path)) + for _, p := range objs { + err := drvr.PutContent(context.Background(), p, []byte("content "+p)) if err != nil { - fmt.Printf("unable to init file %s: %s\n", path, err) + fmt.Printf("unable to init file %s: %s\n", p, err) continue } - created = append(created, path) + created = append(created, p) } return created } cleanup := func(objs []string) { var lastErr error - for _, path := range objs { - err := driver.Delete(context.Background(), path) + for _, p := range objs { + err := drvr.Delete(context.Background(), p) if err != nil { switch err.(type) { case storagedriver.PathNotFoundError: @@ -463,7 +465,7 @@ func TestDelete(t *testing.T) { t.Run(tc.name, func(t *testing.T) { objs := init() - err := driver.Delete(context.Background(), tc.delete) + err := drvr.Delete(context.Background(), tc.delete) if tc.err != nil { if err == nil { @@ -491,7 +493,7 @@ func TestDelete(t *testing.T) { return false } for _, path := range objs { - stat, err := 
driver.Stat(context.Background(), path) + stat, err := drvr.Stat(context.Background(), path) if err != nil { switch err.(type) { case storagedriver.PathNotFoundError: @@ -525,18 +527,14 @@ func TestWalk(t *testing.T) { t.Skip(skipS3()) } - rootDir, err := ioutil.TempDir("", "driver-") - if err != nil { - t.Fatalf("unexpected error creating temporary directory: %v", err) - } - defer os.Remove(rootDir) + rootDir := t.TempDir() - driver, err := s3DriverConstructor(rootDir, s3.StorageClassStandard) + drvr, err := s3DriverConstructor(rootDir, s3.StorageClassStandard) if err != nil { t.Fatalf("unexpected error creating driver with standard storage: %v", err) } - var fileset = []string{ + fileset := []string{ "/file1", "/folder1/file1", "/folder2/file1", @@ -547,22 +545,22 @@ func TestWalk(t *testing.T) { // create file structure matching fileset above var created []string - for _, path := range fileset { - err := driver.PutContent(context.Background(), path, []byte("content "+path)) + for _, p := range fileset { + err := drvr.PutContent(context.Background(), p, []byte("content "+p)) if err != nil { - fmt.Printf("unable to create file %s: %s\n", path, err) + fmt.Printf("unable to create file %s: %s\n", p, err) continue } - created = append(created, path) + created = append(created, p) } // cleanup defer func() { var lastErr error - for _, path := range created { - err := driver.Delete(context.Background(), path) + for _, p := range created { + err := drvr.Delete(context.Background(), p) if err != nil { - _ = fmt.Errorf("cleanup failed for path %s: %s", path, err) + _ = fmt.Errorf("cleanup failed for path %s: %s", p, err) lastErr = err } } @@ -664,7 +662,7 @@ func TestWalk(t *testing.T) { tc.from = "/" } t.Run(tc.name, func(t *testing.T) { - err := driver.Walk(context.Background(), tc.from, func(fileInfo storagedriver.FileInfo) error { + err := drvr.Walk(context.Background(), tc.from, func(fileInfo storagedriver.FileInfo) error { walked = append(walked, fileInfo.Path()) 
return tc.fn(fileInfo) }) @@ -684,12 +682,7 @@ func TestOverThousandBlobs(t *testing.T) { t.Skip(skipS3()) } - rootDir, err := ioutil.TempDir("", "driver-") - if err != nil { - t.Fatalf("unexpected error creating temporary directory: %v", err) - } - defer os.Remove(rootDir) - + rootDir := t.TempDir() standardDriver, err := s3DriverConstructor(rootDir, s3.StorageClassStandard) if err != nil { t.Fatalf("unexpected error creating driver with standard storage: %v", err) @@ -717,12 +710,7 @@ func TestMoveWithMultipartCopy(t *testing.T) { t.Skip(skipS3()) } - rootDir, err := ioutil.TempDir("", "driver-") - if err != nil { - t.Fatalf("unexpected error creating temporary directory: %v", err) - } - defer os.Remove(rootDir) - + rootDir := t.TempDir() d, err := s3DriverConstructor(rootDir, s3.StorageClassStandard) if err != nil { t.Fatalf("unexpected error creating driver: %v", err) @@ -771,12 +759,7 @@ func TestListObjectsV2(t *testing.T) { t.Skip(skipS3()) } - rootDir, err := ioutil.TempDir("", "driver-") - if err != nil { - t.Fatalf("unexpected error creating temporary directory: %v", err) - } - defer os.Remove(rootDir) - + rootDir := t.TempDir() d, err := s3DriverConstructor(rootDir, s3.StorageClassStandard) if err != nil { t.Fatalf("unexpected error creating driver: %v", err) @@ -789,8 +772,8 @@ func TestListObjectsV2(t *testing.T) { for i := 0; i < n; i++ { filePaths = append(filePaths, fmt.Sprintf("%s/%d", prefix, i)) } - for _, path := range filePaths { - if err := d.PutContent(ctx, path, []byte(path)); err != nil { + for _, p := range filePaths { + if err := d.PutContent(ctx, p, []byte(p)); err != nil { t.Fatalf("unexpected error putting content: %v", err) } } diff --git a/registry/storage/driver/storagedriver.go b/registry/storage/driver/storagedriver.go index 9a9b9a8f4e1..d573e6176df 100644 --- a/registry/storage/driver/storagedriver.go +++ b/registry/storage/driver/storagedriver.go @@ -17,14 +17,14 @@ type Version string // Major returns the major (primary) 
component of a version. func (version Version) Major() uint { - majorPart := strings.Split(string(version), ".")[0] + majorPart, _, _ := strings.Cut(string(version), ".") major, _ := strconv.ParseUint(majorPart, 10, 0) return uint(major) } // Minor returns the minor (secondary) component of a version. func (version Version) Minor() uint { - minorPart := strings.Split(string(version), ".")[1] + _, minorPart, _ := strings.Cut(string(version), ".") minor, _ := strconv.ParseUint(minorPart, 10, 0) return uint(minor) } @@ -66,7 +66,7 @@ type StorageDriver interface { Stat(ctx context.Context, path string) (FileInfo, error) // List returns a list of the objects that are direct descendants of the - //given path. + // given path. List(ctx context.Context, path string) ([]string, error) // Move moves an object stored at sourcePath to destPath, removing the @@ -169,3 +169,27 @@ type Error struct { func (err Error) Error() string { return fmt.Sprintf("%s: %s", err.DriverName, err.Enclosed) } + +// Errors provides the envelope for multiple errors +// for use within the storagedriver implementations. 
+type Errors struct { + DriverName string + Errs []error +} + +var _ error = Errors{} + +func (e Errors) Error() string { + switch len(e.Errs) { + case 0: + return "" + case 1: + return e.Errs[0].Error() + default: + msg := "errors:\n" + for _, err := range e.Errs { + msg += err.Error() + "\n" + } + return msg + } +} diff --git a/vendor/github.com/distribution/distribution/v3/registry/storage/driver/s3-aws/s3.go b/vendor/github.com/distribution/distribution/v3/registry/storage/driver/s3-aws/s3.go index 7e0c48650d2..5d18d91bcd5 100644 --- a/vendor/github.com/distribution/distribution/v3/registry/storage/driver/s3-aws/s3.go +++ b/vendor/github.com/distribution/distribution/v3/registry/storage/driver/s3-aws/s3.go @@ -15,9 +15,9 @@ import ( "bytes" "context" "crypto/tls" + "errors" "fmt" "io" - "io/ioutil" "math" "net/http" "path/filepath" @@ -36,7 +36,6 @@ import ( "github.com/aws/aws-sdk-go/service/s3" dcontext "github.com/distribution/distribution/v3/context" - "github.com/distribution/distribution/v3/registry/client/transport" storagedriver "github.com/distribution/distribution/v3/registry/storage/driver" "github.com/distribution/distribution/v3/registry/storage/driver/base" "github.com/distribution/distribution/v3/registry/storage/driver/factory" @@ -93,7 +92,7 @@ var validRegions = map[string]struct{}{} // validObjectACLs contains known s3 object Acls var validObjectACLs = map[string]struct{}{} -//DriverParameters A struct that encapsulates all of the driver parameters after all values have been set +// DriverParameters A struct that encapsulates all of the driver parameters after all values have been set type DriverParameters struct { AccessKey string SecretKey string @@ -118,6 +117,8 @@ type DriverParameters struct { SessionToken string UseDualStack bool Accelerate bool + VirtualHostedStyle bool + CredentialsConfigPath string } func init() { @@ -197,6 +198,11 @@ func FromParameters(parameters map[string]interface{}) (*Driver, error) { secretKey = "" } + 
credentialsConfigPath := parameters["credentialsconfigpath"] + if credentialsConfigPath == nil { + credentialsConfigPath = "" + } + regionEndpoint := parameters["regionendpoint"] if regionEndpoint == nil { regionEndpoint = "" @@ -417,6 +423,23 @@ func FromParameters(parameters map[string]interface{}) (*Driver, error) { return nil, fmt.Errorf("the multipartcombinesmallpart parameter should be a boolean") } + virtualHostedStyleBool := false + virtualHostedStyle := parameters["virtualhostedstyle"] + switch virtualHostedStyle := virtualHostedStyle.(type) { + case string: + b, err := strconv.ParseBool(virtualHostedStyle) + if err != nil { + return nil, fmt.Errorf("the virtualHostedStyle parameter should be a boolean") + } + virtualHostedStyleBool = b + case bool: + virtualHostedStyleBool = virtualHostedStyle + case nil: + // do nothing + default: + return nil, fmt.Errorf("the virtualHostedStyle parameter should be a boolean") + } + sessionToken := "" accelerateBool := false @@ -460,6 +483,8 @@ func FromParameters(parameters map[string]interface{}) (*Driver, error) { fmt.Sprint(sessionToken), useDualStackBool, accelerateBool, + virtualHostedStyleBool, + fmt.Sprint(credentialsConfigPath), } return New(params) @@ -503,6 +528,12 @@ func New(params DriverParameters) (*Driver, error) { return nil, fmt.Errorf("on Amazon S3 this storage driver can only be used with v4 authentication") } + // Makes no sense to provide access/secret key and the location of a + // config file with credentials. 
+ if (params.AccessKey != "" || params.SecretKey != "") && params.CredentialsConfigPath != "" { + return nil, fmt.Errorf("cannot set both access/secret key and credentials file path") + } + awsConfig := aws.NewConfig() if params.AccessKey != "" && params.SecretKey != "" { @@ -515,6 +546,9 @@ func New(params DriverParameters) (*Driver, error) { } if params.RegionEndpoint != "" { + if !params.VirtualHostedStyle { + awsConfig.WithS3ForcePathStyle(true) + } awsConfig.WithEndpoint(params.RegionEndpoint) awsConfig.WithS3ForcePathStyle(params.ForcePathStyle) } @@ -522,32 +556,35 @@ func New(params DriverParameters) (*Driver, error) { awsConfig.WithS3UseAccelerate(params.Accelerate) awsConfig.WithRegion(params.Region) awsConfig.WithDisableSSL(!params.Secure) - if params.UseDualStack { - awsConfig.UseDualStackEndpoint = endpoints.DualStackEndpointStateEnabled - } + awsConfig.WithUseDualStack(params.UseDualStack) - if params.UserAgent != "" || params.SkipVerify { - httpTransport := http.DefaultTransport - if params.SkipVerify { - httpTransport = &http.Transport{ - TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, - } - } - if params.UserAgent != "" { - awsConfig.WithHTTPClient(&http.Client{ - Transport: transport.NewTransport(httpTransport, transport.NewHeaderRequestModifier(http.Header{http.CanonicalHeaderKey("User-Agent"): []string{params.UserAgent}})), - }) - } else { - awsConfig.WithHTTPClient(&http.Client{ - Transport: transport.NewTransport(httpTransport), - }) + if params.SkipVerify { + httpTransport := &http.Transport{ + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, } + awsConfig.WithHTTPClient(&http.Client{ + Transport: httpTransport, + }) } - sess, err := session.NewSession(awsConfig) + sessionOptions := session.Options{ + Config: *awsConfig, + } + if params.CredentialsConfigPath != "" { + sessionOptions.SharedConfigState = session.SharedConfigEnable + sessionOptions.SharedConfigFiles = []string{ + params.CredentialsConfigPath, + } + } + sess, err 
:= session.NewSessionWithOptions(sessionOptions) if err != nil { return nil, fmt.Errorf("failed to create new session with aws config: %v", err) } + + if params.UserAgent != "" { + sess.Handlers.Build.PushBack(request.MakeAddToUserAgentFreeFormHandler(params.UserAgent)) + } + s3obj := s3.New(sess) // enable S3 compatible signature v2 signing instead @@ -606,7 +643,7 @@ func (d *driver) GetContent(ctx context.Context, path string) ([]byte, error) { if err != nil { return nil, err } - return ioutil.ReadAll(reader) + return io.ReadAll(reader) } // PutContent stores the []byte content at a location designated by "path". @@ -632,10 +669,9 @@ func (d *driver) Reader(ctx context.Context, path string, offset int64) (io.Read Key: aws.String(d.s3Path(path)), Range: aws.String("bytes=" + strconv.FormatInt(offset, 10) + "-"), }) - if err != nil { if s3Err, ok := err.(awserr.Error); ok && s3Err.Code() == "InvalidRange" { - return ioutil.NopCloser(bytes.NewReader(nil)), nil + return io.NopCloser(bytes.NewReader(nil)), nil } return nil, parseError(path, err) @@ -696,7 +732,7 @@ func (d *driver) Writer(ctx context.Context, path string, appendParam bool) (sto return nil, parseError(path, err) } allParts = append(allParts, partsList.Parts...) - for *resp.IsTruncated { + for *partsList.IsTruncated { partsList, err = d.S3.ListParts(&s3.ListPartsInput{ Bucket: aws.String(d.Bucket), Key: aws.String(key), @@ -923,54 +959,71 @@ func (d *driver) copy(ctx context.Context, sourcePath string, destPath string) e return err } -func min(a, b int) int { - if a < b { - return a - } - return b -} - // Delete recursively deletes all objects stored at "path" and its subpaths. 
// We must be careful since S3 does not guarantee read after delete consistency func (d *driver) Delete(ctx context.Context, path string) error { s3Objects := make([]*s3.ObjectIdentifier, 0, listMax) - - // manually add the given path if it's a file - stat, err := d.Stat(ctx, path) - if err != nil { - return err - } - if stat != nil && !stat.IsDir() { - path := d.s3Path(path) - s3Objects = append(s3Objects, &s3.ObjectIdentifier{ - Key: &path, - }) - } - - // list objects under the given path as a subpath (suffix with slash "/") - s3Path := d.s3Path(path) + "/" + s3Path := d.s3Path(path) listObjectsInput := &s3.ListObjectsV2Input{ Bucket: aws.String(d.Bucket), Prefix: aws.String(s3Path), } -ListLoop: + for { // list all the objects resp, err := d.S3.ListObjectsV2(listObjectsInput) // resp.Contents can only be empty on the first call // if there were no more results to return after the first call, resp.IsTruncated would have been false - and the loop would be exited without recalling ListObjects + and the loop would exit without recalling ListObjects if err != nil || len(resp.Contents) == 0 { - break ListLoop + return storagedriver.PathNotFoundError{Path: path} } for _, key := range resp.Contents { + // Skip if we encounter a key that is not a subpath (so that deleting "/a" does not delete "/ab"). + if len(*key.Key) > len(s3Path) && (*key.Key)[len(s3Path)] != '/' { + continue + } s3Objects = append(s3Objects, &s3.ObjectIdentifier{ Key: key.Key, }) } + // Delete objects only if the list is not empty, otherwise S3 API returns a cryptic error + if len(s3Objects) > 0 { + // NOTE: according to AWS docs https://docs.aws.amazon.com/AmazonS3/latest/API/API_ListObjectsV2.html + // by default the response returns up to 1,000 key names. The response _might_ contain fewer keys but it will never contain more. + 1000 keys is coincidentally (?)
also the max number of keys that can be deleted in a single Delete operation, so we'll just smack + // Delete here straight away and reset the object slice when successful. + resp, err := d.S3.DeleteObjects(&s3.DeleteObjectsInput{ + Bucket: aws.String(d.Bucket), + Delete: &s3.Delete{ + Objects: s3Objects, + Quiet: aws.Bool(false), + }, + }) + if err != nil { + return err + } + + if len(resp.Errors) > 0 { + // NOTE: AWS SDK s3.Error does not implement error interface which + // is pretty intensely sad, so we have to do away with this for now. + errs := make([]error, 0, len(resp.Errors)) + for _, err := range resp.Errors { + errs = append(errs, errors.New(err.String())) + } + return storagedriver.Errors{ + DriverName: driverName, + Errs: errs, + } + } + } + // NOTE: we don't want to reallocate + // the slice so we simply "reset" it + s3Objects = s3Objects[:0] + // resp.Contents must have at least one element or we would have returned not found listObjectsInput.StartAfter = resp.Contents[len(resp.Contents)-1].Key @@ -981,35 +1034,17 @@ ListLoop: } } - total := len(s3Objects) - if total == 0 { - return storagedriver.PathNotFoundError{Path: path} - } - - // need to chunk objects into groups of 1000 per s3 restrictions - for i := 0; i < total; i += 1000 { - _, err := d.S3.DeleteObjects(&s3.DeleteObjectsInput{ - Bucket: aws.String(d.Bucket), - Delete: &s3.Delete{ - Objects: s3Objects[i:min(i+1000, total)], - Quiet: aws.Bool(false), - }, - }) - if err != nil { - return err - } - } return nil } // URLFor returns a URL which may be used to retrieve the content stored at the given path. // May return an UnsupportedMethodErr in certain StorageDriver implementations. 
func (d *driver) URLFor(ctx context.Context, path string, options map[string]interface{}) (string, error) { - methodString := "GET" + methodString := http.MethodGet method, ok := options["method"] if ok { methodString, ok = method.(string) - if !ok || (methodString != "GET" && methodString != "HEAD") { + if !ok || (methodString != http.MethodGet && methodString != http.MethodHead) { return "", storagedriver.ErrUnsupportedMethod{} } } @@ -1026,12 +1061,12 @@ func (d *driver) URLFor(ctx context.Context, path string, options map[string]int var req *request.Request switch methodString { - case "GET": + case http.MethodGet: req, _ = d.S3.GetObjectRequest(&s3.GetObjectInput{ Bucket: aws.String(d.Bucket), Key: aws.String(d.s3Path(path)), }) - case "HEAD": + case http.MethodHead: req, _ = d.S3.HeadObjectRequest(&s3.HeadObjectInput{ Bucket: aws.String(d.Bucket), Key: aws.String(d.s3Path(path)), @@ -1077,7 +1112,7 @@ func (d *driver) doWalk(parentCtx context.Context, objectCount *int64, path, pre // the most recent skip directory to avoid walking over undesirable files prevSkipDir string ) - prevDir = prefix + path + prevDir = strings.Replace(path, d.s3Path(""), prefix, 1) listObjectsInput := &s3.ListObjectsV2Input{ Bucket: aws.String(d.Bucket), @@ -1166,16 +1201,22 @@ func (d *driver) doWalk(parentCtx context.Context, objectCount *int64, path, pre // directoryDiff finds all directories that are not in common between // the previous and current paths in sorted order. 
// -// Eg 1 directoryDiff("/path/to/folder", "/path/to/folder/folder/file") -// => [ "/path/to/folder/folder" ], -// Eg 2 directoryDiff("/path/to/folder/folder1", "/path/to/folder/folder2/file") -// => [ "/path/to/folder/folder2" ] -// Eg 3 directoryDiff("/path/to/folder/folder1/file", "/path/to/folder/folder2/file") -// => [ "/path/to/folder/folder2" ] -// Eg 4 directoryDiff("/path/to/folder/folder1/file", "/path/to/folder/folder2/folder1/file") -// => [ "/path/to/folder/folder2", "/path/to/folder/folder2/folder1" ] -// Eg 5 directoryDiff("/", "/path/to/folder/folder/file") -// => [ "/path", "/path/to", "/path/to/folder", "/path/to/folder/folder" ], +// # Examples +// +// directoryDiff("/path/to/folder", "/path/to/folder/folder/file") +// // => [ "/path/to/folder/folder" ] +// +// directoryDiff("/path/to/folder/folder1", "/path/to/folder/folder2/file") +// // => [ "/path/to/folder/folder2" ] +// +// directoryDiff("/path/to/folder/folder1/file", "/path/to/folder/folder2/file") +// // => [ "/path/to/folder/folder2" ] +// +// directoryDiff("/path/to/folder/folder1/file", "/path/to/folder/folder2/folder1/file") +// // => [ "/path/to/folder/folder2", "/path/to/folder/folder2/folder1" ] +// +// directoryDiff("/", "/path/to/folder/folder/file") +// // => [ "/path", "/path/to", "/path/to/folder", "/path/to/folder/folder" ] func directoryDiff(prev, current string) []string { var paths []string @@ -1351,7 +1392,7 @@ func (w *writer) Write(p []byte) (int, error) { } defer resp.Body.Close() w.parts = nil - w.readyPart, err = ioutil.ReadAll(resp.Body) + w.readyPart, err = io.ReadAll(resp.Body) if err != nil { return 0, err } @@ -1463,6 +1504,30 @@ func (w *writer) Commit() error { }) } + // This is an edge case when we are trying to upload an empty chunk of data using + // a MultiPart upload. 
As a result we are trying to complete the MultipartUpload + // with an empty slice of `completedUploadedParts` which will always lead to 400 + // being returned from S3 See: https://docs.aws.amazon.com/sdk-for-go/api/service/s3/#CompletedMultipartUpload + // Solution: we upload an empty i.e. 0 byte part as a single part and then append it + // to the completedUploadedParts slice used to complete the Multipart upload. + if len(w.parts) == 0 { + resp, err := w.driver.S3.UploadPart(&s3.UploadPartInput{ + Bucket: aws.String(w.driver.Bucket), + Key: aws.String(w.key), + PartNumber: aws.Int64(1), + UploadId: aws.String(w.uploadID), + Body: bytes.NewReader(nil), + }) + if err != nil { + return err + } + + completedUploadedParts = append(completedUploadedParts, &s3.CompletedPart{ + ETag: resp.ETag, + PartNumber: aws.Int64(1), + }) + } + sort.Sort(completedUploadedParts) _, err = w.driver.S3.CompleteMultipartUpload(&s3.CompleteMultipartUploadInput{ diff --git a/vendor/github.com/distribution/distribution/v3/registry/storage/driver/storagedriver.go b/vendor/github.com/distribution/distribution/v3/registry/storage/driver/storagedriver.go index 9a9b9a8f4e1..d573e6176df 100644 --- a/vendor/github.com/distribution/distribution/v3/registry/storage/driver/storagedriver.go +++ b/vendor/github.com/distribution/distribution/v3/registry/storage/driver/storagedriver.go @@ -17,14 +17,14 @@ type Version string // Major returns the major (primary) component of a version. func (version Version) Major() uint { - majorPart := strings.Split(string(version), ".")[0] + majorPart, _, _ := strings.Cut(string(version), ".") major, _ := strconv.ParseUint(majorPart, 10, 0) return uint(major) } // Minor returns the minor (secondary) component of a version. 
func (version Version) Minor() uint { - minorPart := strings.Split(string(version), ".")[1] + _, minorPart, _ := strings.Cut(string(version), ".") minor, _ := strconv.ParseUint(minorPart, 10, 0) return uint(minor) } @@ -66,7 +66,7 @@ type StorageDriver interface { Stat(ctx context.Context, path string) (FileInfo, error) // List returns a list of the objects that are direct descendants of the - //given path. + // given path. List(ctx context.Context, path string) ([]string, error) // Move moves an object stored at sourcePath to destPath, removing the @@ -169,3 +169,27 @@ type Error struct { func (err Error) Error() string { return fmt.Sprintf("%s: %s", err.DriverName, err.Enclosed) } + +// Errors provides the envelope for multiple errors +// for use within the storagedriver implementations. +type Errors struct { + DriverName string + Errs []error +} + +var _ error = Errors{} + +func (e Errors) Error() string { + switch len(e.Errs) { + case 0: + return "" + case 1: + return e.Errs[0].Error() + default: + msg := "errors:\n" + for _, err := range e.Errs { + msg += err.Error() + "\n" + } + return msg + } +}