diff --git a/api-put-object-fan-out.go b/api-put-object-fan-out.go index 0ae9142e1..3023b949c 100644 --- a/api-put-object-fan-out.go +++ b/api-put-object-fan-out.go @@ -85,7 +85,10 @@ func (c *Client) PutObjectFanOut(ctx context.Context, bucket string, fanOutData policy.SetEncryption(fanOutReq.SSE) // Set checksum headers if any. - policy.SetChecksum(fanOutReq.Checksum) + err := policy.SetChecksum(fanOutReq.Checksum) + if err != nil { + return nil, err + } url, formData, err := c.PresignedPostPolicy(ctx, policy) if err != nil { diff --git a/functional_tests.go b/functional_tests.go index c10b19f3d..43383d134 100644 --- a/functional_tests.go +++ b/functional_tests.go @@ -160,7 +160,7 @@ func logError(testName, function string, args map[string]interface{}, startTime } else { logFailure(testName, function, args, startTime, alert, message, err) if !isRunOnFail() { - panic(err) + panic(fmt.Sprintf("Test failed with message: %s, err: %v", message, err)) } } } @@ -2032,7 +2032,7 @@ func testPutObjectWithChecksums() { h := test.cs.Hasher() h.Reset() - // Test with Wrong CRC. + // Test with a bad CRC - we haven't called h.Write(b), so this is a checksum of empty data meta[test.cs.Key()] = base64.StdEncoding.EncodeToString(h.Sum(nil)) args["metadata"] = meta args["range"] = "false" @@ -2638,7 +2638,6 @@ func testTrailingChecksums() { test.ChecksumCRC32C = hashMultiPart(b, int(test.PO.PartSize), test.hasher) // Set correct CRC. - // c.TraceOn(os.Stderr) resp, err := c.PutObject(context.Background(), bucketName, objectName, bytes.NewReader(b), int64(bufSize), test.PO) if err != nil { logError(testName, function, args, startTime, "", "PutObject failed", err) @@ -2690,6 +2689,8 @@ func testTrailingChecksums() { delete(args, "metadata") } + + logSuccess(testName, function, args, startTime) } // Test PutObject with custom checksums. @@ -5146,50 +5147,22 @@ func testPresignedPostPolicy() { return } - // Save the data - _, err = c.PutObject(context.Background(), bucketName, objectName, bytes.NewReader(buf), int64(len(buf)), minio.PutObjectOptions{ContentType: "binary/octet-stream"}) - if err != nil { - logError(testName, function, args, startTime, "", "PutObject failed", err) - return - } - policy := minio.NewPostPolicy() - - if err := policy.SetBucket(""); err == nil { - logError(testName, function, args, startTime, "", "SetBucket did not fail for invalid conditions", err) - return - } - if err := policy.SetKey(""); err == nil { - logError(testName, function, args, startTime, "", "SetKey did not fail for invalid conditions", err) - return - } - if err := policy.SetExpires(time.Date(1, time.January, 1, 0, 0, 0, 0, time.UTC)); err == nil { - logError(testName, function, args, startTime, "", "SetExpires did not fail for invalid conditions", err) - return - } - if err := policy.SetContentType(""); err == nil { - logError(testName, function, args, startTime, "", "SetContentType did not fail for invalid conditions", err) - return - } - if err := policy.SetContentLengthRange(1024*1024, 1024); err == nil { - logError(testName, function, args, startTime, "", "SetContentLengthRange did not fail for invalid conditions", err) - return - } - if err := policy.SetUserMetadata("", ""); err == nil { - logError(testName, function, args, startTime, "", "SetUserMetadata did not fail for invalid conditions", err) - return - } - policy.SetBucket(bucketName) policy.SetKey(objectName) policy.SetExpires(time.Now().UTC().AddDate(0, 0, 10)) // expires in 10 days policy.SetContentType("binary/octet-stream") policy.SetContentLengthRange(10, 1024*1024) policy.SetUserMetadata(metadataKey, metadataValue) + policy.SetContentEncoding("gzip") // Add CRC32C checksum := minio.ChecksumCRC32C.ChecksumBytes(buf) - policy.SetChecksum(checksum) + err = policy.SetChecksum(checksum) + if err != nil { + logError(testName, function, args, startTime, "", "SetChecksum failed", err) + return + } args["policy"] = policy.String() @@ -5285,7 +5258,7 @@ func testPresignedPostPolicy() { expectedLocation := scheme + os.Getenv(serverEndpoint) + "/" + bucketName + "/" + objectName expectedLocationBucketDNS := scheme + bucketName + "." + os.Getenv(serverEndpoint) + "/" + objectName - if !strings.Contains(expectedLocation, "s3.amazonaws.com/") { + if !strings.Contains(expectedLocation, ".amazonaws.com/") { // Test when not against AWS S3. if val, ok := res.Header["Location"]; ok { if val[0] != expectedLocation && val[0] != expectedLocationBucketDNS { @@ -5297,9 +5270,194 @@ func testPresignedPostPolicy() { return } } - want := checksum.Encoded() - if got := res.Header.Get("X-Amz-Checksum-Crc32c"); got != want { - logError(testName, function, args, startTime, "", fmt.Sprintf("Want checksum %q, got %q", want, got), nil) + wantChecksumCrc32c := checksum.Encoded() + if got := res.Header.Get("X-Amz-Checksum-Crc32c"); got != wantChecksumCrc32c { + logError(testName, function, args, startTime, "", fmt.Sprintf("Want checksum %q, got %q", wantChecksumCrc32c, got), nil) + return + } + + // Ensure that when we subsequently GetObject, the checksum is returned + gopts := minio.GetObjectOptions{Checksum: true} + r, err := c.GetObject(context.Background(), bucketName, objectName, gopts) + if err != nil { + logError(testName, function, args, startTime, "", "GetObject failed", err) + return + } + st, err := r.Stat() + if err != nil { + logError(testName, function, args, startTime, "", "Stat failed", err) + return + } + if st.ChecksumCRC32C != wantChecksumCrc32c { + logError(testName, function, args, startTime, "", fmt.Sprintf("Want checksum %s, got %s", wantChecksumCrc32c, st.ChecksumCRC32C), nil) + return + } + + logSuccess(testName, function, args, startTime) +} + +// testPresignedPostPolicyWrongFile tests that when we have a policy with a checksum, we cannot POST the wrong file +func testPresignedPostPolicyWrongFile() { + // initialize logging params + startTime := time.Now() + testName := getFuncName() + function := "PresignedPostPolicy(policy)" + args := map[string]interface{}{ + "policy": "", + } + + c, err := NewClient(ClientConfig{}) + if err != nil { + logError(testName, function, args, startTime, "", "MinIO client object creation failed", err) + return + } + + // Generate a new random bucket name. + bucketName := randString(60, rand.NewSource(time.Now().UnixNano()), "minio-go-test-") + + // Make a new bucket in 'us-east-1' (source bucket). + err = c.MakeBucket(context.Background(), bucketName, minio.MakeBucketOptions{Region: "us-east-1"}) + if err != nil { + logError(testName, function, args, startTime, "", "MakeBucket failed", err) + return + } + + defer cleanupBucket(bucketName, c) + + // Generate 33K of data. + reader := getDataReader("datafile-33-kB") + defer reader.Close() + + objectName := randString(60, rand.NewSource(time.Now().UnixNano()), "") + // Azure requires the key to not start with a number + metadataKey := randString(60, rand.NewSource(time.Now().UnixNano()), "user") + metadataValue := randString(60, rand.NewSource(time.Now().UnixNano()), "") + + buf, err := io.ReadAll(reader) + if err != nil { + logError(testName, function, args, startTime, "", "ReadAll failed", err) + return + } + + policy := minio.NewPostPolicy() + policy.SetBucket(bucketName) + policy.SetKey(objectName) + policy.SetExpires(time.Now().UTC().AddDate(0, 0, 10)) // expires in 10 days + policy.SetContentType("binary/octet-stream") + policy.SetContentLengthRange(10, 1024*1024) + policy.SetUserMetadata(metadataKey, metadataValue) + + // Add CRC32C of the 33kB file that the policy will explicitly allow. + checksum := minio.ChecksumCRC32C.ChecksumBytes(buf) + err = policy.SetChecksum(checksum) + if err != nil { + logError(testName, function, args, startTime, "", "SetChecksum failed", err) + return + } + + args["policy"] = policy.String() + + presignedPostPolicyURL, formData, err := c.PresignedPostPolicy(context.Background(), policy) + if err != nil { + logError(testName, function, args, startTime, "", "PresignedPostPolicy failed", err) + return + } + + // At this stage, we have a policy that allows us to upload datafile-33-kB. + // Test that uploading datafile-10-kB, with a different checksum, fails as expected + filePath := getMintDataDirFilePath("datafile-10-kB") + if filePath == "" { + // Make a temp file with 10 KB data. + file, err := os.CreateTemp(os.TempDir(), "PresignedPostPolicyTest") + if err != nil { + logError(testName, function, args, startTime, "", "TempFile creation failed", err) + return + } + if _, err = io.Copy(file, getDataReader("datafile-10-kB")); err != nil { + logError(testName, function, args, startTime, "", "Copy failed", err) + return + } + if err = file.Close(); err != nil { + logError(testName, function, args, startTime, "", "File Close failed", err) + return + } + filePath = file.Name() + } + fileReader := getDataReader("datafile-10-kB") + defer fileReader.Close() + buf10k, err := io.ReadAll(fileReader) + if err != nil { + logError(testName, function, args, startTime, "", "ReadAll failed", err) + return + } + otherChecksum := minio.ChecksumCRC32C.ChecksumBytes(buf10k) + + var formBuf bytes.Buffer + writer := multipart.NewWriter(&formBuf) + for k, v := range formData { + if k == "x-amz-checksum-crc32c" { + v = otherChecksum.Encoded() + } + writer.WriteField(k, v) + } + + // Add file to post request + f, err := os.Open(filePath) + defer f.Close() + if err != nil { + logError(testName, function, args, startTime, "", "File open failed", err) + return + } + w, err := writer.CreateFormFile("file", filePath) + if err != nil { + logError(testName, function, args, startTime, "", "CreateFormFile failed", err) + return + } + _, err = io.Copy(w, f) + if err != nil { + logError(testName, function, args, startTime, "", "Copy failed", err) + return + } + writer.Close() + + httpClient := &http.Client{ + Timeout: 30 * time.Second, + Transport: createHTTPTransport(), + } + args["url"] = presignedPostPolicyURL.String() + + req, err := http.NewRequest(http.MethodPost, presignedPostPolicyURL.String(), bytes.NewReader(formBuf.Bytes())) + if err != nil { + logError(testName, function, args, startTime, "", "HTTP request failed", err) + return + } + + req.Header.Set("Content-Type", writer.FormDataContentType()) + + // Make the POST request with the form data. + res, err := httpClient.Do(req) + if err != nil { + logError(testName, function, args, startTime, "", "HTTP request failed", err) + return + } + defer res.Body.Close() + if res.StatusCode != http.StatusForbidden { + logError(testName, function, args, startTime, "", "HTTP request unexpected status", errors.New(res.Status)) + return + } + + // Read the response body, ensure it has checksum failure message + resBody, err := io.ReadAll(res.Body) + if err != nil { + logError(testName, function, args, startTime, "", "ReadAll failed", err) + return + } + + // Normalize the response body, because S3 uses quotes around the policy condition components + // in the error message, MinIO does not. + resBodyStr := strings.ReplaceAll(string(resBody), `"`, "") + if !strings.Contains(resBodyStr, "Policy Condition failed: [eq, $x-amz-checksum-crc32c, aHnJMw==]") { + logError(testName, function, args, startTime, "", "Unexpected response body", errors.New(resBodyStr)) return } @@ -8581,7 +8739,7 @@ func testEncryptedCopyObjectWrapper(c *minio.Client, bucketName string, sseSrc, dstEncryption = sseDst } // 3. get copied object and check if content is equal - coreClient := minio.Core{c} + coreClient := minio.Core{Client: c} reader, _, _, err := coreClient.GetObject(context.Background(), bucketName, "dstObject", minio.GetObjectOptions{ServerSideEncryption: dstEncryption}) if err != nil { logError(testName, function, args, startTime, "", "GetObject failed", err) @@ -8697,7 +8855,6 @@ func testUnencryptedToSSECCopyObject() { bucketName := randString(60, rand.NewSource(time.Now().UnixNano()), "minio-go-test-") sseDst := encrypt.DefaultPBKDF([]byte("correct horse battery staple"), []byte(bucketName+"dstObject")) - // c.TraceOn(os.Stderr) testEncryptedCopyObjectWrapper(c, bucketName, nil, sseDst) } @@ -8719,7 +8876,6 @@ func testUnencryptedToSSES3CopyObject() { var sseSrc encrypt.ServerSide sseDst := encrypt.NewSSE() - // c.TraceOn(os.Stderr) testEncryptedCopyObjectWrapper(c, bucketName, sseSrc, sseDst) } @@ -8740,7 +8896,6 @@ func testUnencryptedToUnencryptedCopyObject() { bucketName := randString(60, rand.NewSource(time.Now().UnixNano()), "minio-go-test-") var sseSrc, sseDst encrypt.ServerSide - // c.TraceOn(os.Stderr) testEncryptedCopyObjectWrapper(c, bucketName, sseSrc, sseDst) } @@ -8762,7 +8917,6 @@ func testEncryptedSSECToSSECCopyObject() { sseSrc := encrypt.DefaultPBKDF([]byte("correct horse battery staple"), []byte(bucketName+"srcObject")) sseDst := encrypt.DefaultPBKDF([]byte("correct horse battery staple"), []byte(bucketName+"dstObject")) - // c.TraceOn(os.Stderr) testEncryptedCopyObjectWrapper(c, bucketName, sseSrc, sseDst) } @@ -8784,7 +8938,6 @@ func testEncryptedSSECToSSES3CopyObject() { sseSrc := encrypt.DefaultPBKDF([]byte("correct horse battery staple"), []byte(bucketName+"srcObject")) sseDst := encrypt.NewSSE() - // c.TraceOn(os.Stderr) testEncryptedCopyObjectWrapper(c, bucketName, sseSrc, sseDst) } @@ -8806,7 +8959,6 @@ func testEncryptedSSECToUnencryptedCopyObject() { sseSrc := encrypt.DefaultPBKDF([]byte("correct horse battery staple"), []byte(bucketName+"srcObject")) var sseDst encrypt.ServerSide - // c.TraceOn(os.Stderr) testEncryptedCopyObjectWrapper(c, bucketName, sseSrc, sseDst) } @@ -8828,7 +8980,6 @@ func testEncryptedSSES3ToSSECCopyObject() { sseSrc := encrypt.NewSSE() sseDst := encrypt.DefaultPBKDF([]byte("correct horse battery staple"), []byte(bucketName+"dstObject")) - // c.TraceOn(os.Stderr) testEncryptedCopyObjectWrapper(c, bucketName, sseSrc, sseDst) } @@ -8850,7 +9001,6 @@ func testEncryptedSSES3ToSSES3CopyObject() { sseSrc := encrypt.NewSSE() sseDst := encrypt.NewSSE() - // c.TraceOn(os.Stderr) testEncryptedCopyObjectWrapper(c, bucketName, sseSrc, sseDst) } @@ -8872,7 +9022,6 @@ func testEncryptedSSES3ToUnencryptedCopyObject() { sseSrc := encrypt.NewSSE() var sseDst encrypt.ServerSide - // c.TraceOn(os.Stderr) testEncryptedCopyObjectWrapper(c, bucketName, sseSrc, sseDst) } @@ -8894,7 +9043,6 @@ func testEncryptedCopyObjectV2() { sseSrc := encrypt.DefaultPBKDF([]byte("correct horse battery staple"), []byte(bucketName+"srcObject")) sseDst := encrypt.DefaultPBKDF([]byte("correct horse battery staple"), []byte(bucketName+"dstObject")) - // c.TraceOn(os.Stderr) testEncryptedCopyObjectWrapper(c, bucketName, sseSrc, sseDst) } @@ -10619,7 +10767,6 @@ func testUserMetadataCopying() { return } - // c.TraceOn(os.Stderr) testUserMetadataCopyingWrapper(c) } @@ -10790,7 +10937,6 @@ func testUserMetadataCopyingV2() { return } - // c.TraceOn(os.Stderr) testUserMetadataCopyingWrapper(c) } @@ -13637,6 +13783,7 @@ func main() { testGetObjectReadAtFunctional() testGetObjectReadAtWhenEOFWasReached() testPresignedPostPolicy() + testPresignedPostPolicyWrongFile() testCopyObject() testComposeObjectErrorCases() testCompose10KSources() diff --git a/post-policy.go b/post-policy.go index b5414af29..26bf441b5 100644 --- a/post-policy.go +++ b/post-policy.go @@ -85,7 +85,7 @@ func (p *PostPolicy) SetExpires(t time.Time) error { // SetKey - Sets an object name for the policy based upload. func (p *PostPolicy) SetKey(key string) error { - if strings.TrimSpace(key) == "" || key == "" { + if strings.TrimSpace(key) == "" { return errInvalidArgument("Object name is empty.") } policyCond := policyCondition{ @@ -118,7 +118,7 @@ func (p *PostPolicy) SetKeyStartsWith(keyStartsWith string) error { // SetBucket - Sets bucket at which objects will be uploaded to. func (p *PostPolicy) SetBucket(bucketName string) error { - if strings.TrimSpace(bucketName) == "" || bucketName == "" { + if strings.TrimSpace(bucketName) == "" { return errInvalidArgument("Bucket name is empty.") } policyCond := policyCondition{ @@ -135,7 +135,7 @@ func (p *PostPolicy) SetBucket(bucketName string) error { // SetCondition - Sets condition for credentials, date and algorithm func (p *PostPolicy) SetCondition(matchType, condition, value string) error { - if strings.TrimSpace(value) == "" || value == "" { + if strings.TrimSpace(value) == "" { return errInvalidArgument("No value specified for condition") } @@ -156,7 +156,7 @@ func (p *PostPolicy) SetCondition(matchType, condition, value string) error { // SetTagging - Sets tagging for the object for this policy based upload. func (p *PostPolicy) SetTagging(tagging string) error { - if strings.TrimSpace(tagging) == "" || tagging == "" { + if strings.TrimSpace(tagging) == "" { return errInvalidArgument("No tagging specified.") } _, err := tags.ParseObjectXML(strings.NewReader(tagging)) @@ -178,7 +178,7 @@ func (p *PostPolicy) SetTagging(tagging string) error { // SetContentType - Sets content-type of the object for this policy // based upload. func (p *PostPolicy) SetContentType(contentType string) error { - if strings.TrimSpace(contentType) == "" || contentType == "" { + if strings.TrimSpace(contentType) == "" { return errInvalidArgument("No content type specified.") } policyCond := policyCondition{ @@ -211,7 +211,7 @@ func (p *PostPolicy) SetContentTypeStartsWith(contentTypeStartsWith string) erro // SetContentDisposition - Sets content-disposition of the object for this policy func (p *PostPolicy) SetContentDisposition(contentDisposition string) error { - if strings.TrimSpace(contentDisposition) == "" || contentDisposition == "" { + if strings.TrimSpace(contentDisposition) == "" { return errInvalidArgument("No content disposition specified.") } policyCond := policyCondition{ @@ -226,6 +226,23 @@ func (p *PostPolicy) SetContentDisposition(contentDisposition string) error { return nil } +// SetContentEncoding - Sets content-encoding of the object for this policy +func (p *PostPolicy) SetContentEncoding(contentEncoding string) error { + if strings.TrimSpace(contentEncoding) == "" { + return errInvalidArgument("No content encoding specified.") + } + policyCond := policyCondition{ + matchType: "eq", + condition: "$Content-Encoding", + value: contentEncoding, + } + if err := p.addNewPolicy(policyCond); err != nil { + return err + } + p.formData["Content-Encoding"] = contentEncoding + return nil +} + // SetContentLengthRange - Set new min and max content length // condition for all incoming uploads. func (p *PostPolicy) SetContentLengthRange(minLen, maxLen int64) error { @@ -246,7 +263,7 @@ func (p *PostPolicy) SetContentLengthRange(minLen, maxLen int64) error { // SetSuccessActionRedirect - Sets the redirect success url of the object for this policy // based upload. func (p *PostPolicy) SetSuccessActionRedirect(redirect string) error { - if strings.TrimSpace(redirect) == "" || redirect == "" { + if strings.TrimSpace(redirect) == "" { return errInvalidArgument("Redirect is empty") } policyCond := policyCondition{ @@ -264,7 +281,7 @@ func (p *PostPolicy) SetSuccessActionRedirect(redirect string) error { // SetSuccessStatusAction - Sets the status success code of the object for this policy // based upload. func (p *PostPolicy) SetSuccessStatusAction(status string) error { - if strings.TrimSpace(status) == "" || status == "" { + if strings.TrimSpace(status) == "" { return errInvalidArgument("Status is empty") } policyCond := policyCondition{ @@ -282,10 +299,10 @@ func (p *PostPolicy) SetSuccessStatusAction(status string) error { // SetUserMetadata - Set user metadata as a key/value couple. // Can be retrieved through a HEAD request or an event. func (p *PostPolicy) SetUserMetadata(key, value string) error { - if strings.TrimSpace(key) == "" || key == "" { + if strings.TrimSpace(key) == "" { return errInvalidArgument("Key is empty") } - if strings.TrimSpace(value) == "" || value == "" { + if strings.TrimSpace(value) == "" { return errInvalidArgument("Value is empty") } headerName := fmt.Sprintf("x-amz-meta-%s", key) @@ -304,7 +321,7 @@ func (p *PostPolicy) SetUserMetadata(key, value string) error { // SetUserMetadataStartsWith - Set how an user metadata should starts with. // Can be retrieved through a HEAD request or an event. func (p *PostPolicy) SetUserMetadataStartsWith(key, value string) error { - if strings.TrimSpace(key) == "" || key == "" { + if strings.TrimSpace(key) == "" { return errInvalidArgument("Key is empty") } headerName := fmt.Sprintf("x-amz-meta-%s", key) @@ -321,11 +338,29 @@ func (p *PostPolicy) SetUserMetadataStartsWith(key, value string) error { } // SetChecksum sets the checksum of the request. -func (p *PostPolicy) SetChecksum(c Checksum) { +func (p *PostPolicy) SetChecksum(c Checksum) error { if c.IsSet() { p.formData[amzChecksumAlgo] = c.Type.String() p.formData[c.Type.Key()] = c.Encoded() + + policyCond := policyCondition{ + matchType: "eq", + condition: fmt.Sprintf("$%s", amzChecksumAlgo), + value: c.Type.String(), + } + if err := p.addNewPolicy(policyCond); err != nil { + return err + } + policyCond = policyCondition{ + matchType: "eq", + condition: fmt.Sprintf("$%s", c.Type.Key()), + value: c.Encoded(), + } + if err := p.addNewPolicy(policyCond); err != nil { + return err + } } + return nil } // SetEncryption - sets encryption headers for POST API diff --git a/post-policy_test.go b/post-policy_test.go new file mode 100644 index 000000000..c105e053c --- /dev/null +++ b/post-policy_test.go @@ -0,0 +1,459 @@ +/* + * MinIO Go Library for Amazon S3 Compatible Cloud Storage + * Copyright 2015-2023 MinIO, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package minio + +import ( + "strings" + "testing" + "time" + + "github.com/minio/minio-go/v7/pkg/encrypt" +) + +func TestPostPolicySetExpires(t *testing.T) { + tests := []struct { + name string + input time.Time + wantErr bool + wantResult string + }{ + { + name: "valid time", + input: time.Date(2023, time.March, 2, 15, 4, 5, 0, time.UTC), + wantErr: false, + wantResult: "2023-03-02T15:04:05", + }, + { + name: "time before 1970", + input: time.Date(1, time.January, 1, 0, 0, 0, 0, time.UTC), + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + pp := NewPostPolicy() + + err := pp.SetExpires(tt.input) + if (err != nil) != tt.wantErr { + t.Errorf("%s: want error: %v, got: %v", tt.name, tt.wantErr, err) + } + + if tt.wantResult != "" { + result := pp.String() + if !strings.Contains(result, tt.wantResult) { + t.Errorf("%s: want result to contain: '%s', got: '%s'", tt.name, tt.wantResult, result) + } + } + }) + } +} + +func TestPostPolicySetKey(t *testing.T) { + tests := []struct { + name string + input string + wantErr bool + wantResult string + }{ + { + name: "valid key", + input: "my-object", + wantResult: `"eq","$key","my-object"`, + }, + { + name: "empty key", + input: "", + wantErr: true, + }, + { + name: "key with spaces", + input: " ", + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + pp := NewPostPolicy() + + err := pp.SetKey(tt.input) + if (err != nil) != tt.wantErr { + t.Errorf("%s: want error: %v, got: %v", tt.name, tt.wantErr, err) + } + + if tt.wantResult != "" { + result := pp.String() + if !strings.Contains(result, tt.wantResult) { + t.Errorf("%s: want result to contain: '%s', got: '%s'", tt.name, tt.wantResult, result) + } + } + }) + } +} + +func TestPostPolicySetKeyStartsWith(t *testing.T) { + tests := []struct { + name string + input string + want string + }{ + { + name: "valid key prefix", + input: "my-prefix/", + want: `["starts-with","$key","my-prefix/"]`, + }, + { + name: "empty prefix (allow any key)", + input: "", + want: `["starts-with","$key",""]`, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + pp := NewPostPolicy() + + err := pp.SetKeyStartsWith(tt.input) + if err != nil { + t.Errorf("%s: want no error, got: %v", tt.name, err) + } + + if tt.want != "" { + result := pp.String() + if !strings.Contains(result, tt.want) { + t.Errorf("%s: want result to contain: '%s', got: '%s'", tt.name, tt.want, result) + } + } + }) + } +} + +func TestPostPolicySetBucket(t *testing.T) { + tests := []struct { + name string + input string + wantErr bool + wantResult string + }{ + { + name: "valid bucket", + input: "my-bucket", + wantResult: `"eq","$bucket","my-bucket"`, + }, + { + name: "empty bucket", + input: "", + wantErr: true, + }, + { + name: "bucket with spaces", + input: " ", + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + pp := NewPostPolicy() + + err := pp.SetBucket(tt.input) + if (err != nil) != tt.wantErr { + t.Errorf("%s: want error: %v, got: %v", tt.name, tt.wantErr, err) + } + + if tt.wantResult != "" { + result := pp.String() + if !strings.Contains(result, tt.wantResult) { + t.Errorf("%s: want result to contain: '%s', got: '%s'", tt.name, tt.wantResult, result) + } + } + }) + } +} + +func TestPostPolicySetCondition(t *testing.T) { + tests := []struct { + name string + matchType string + condition string + value string + wantErr bool + wantResult string + }{ + { + name: "valid eq condition", + matchType: "eq", + condition: "X-Amz-Date", + value: "20210324T000000Z", + wantResult: `"eq","$X-Amz-Date","20210324T000000Z"`, + }, + { + name: "empty value", + matchType: "eq", + condition: "X-Amz-Date", + value: "", + wantErr: true, + }, + { + name: "invalid condition", + matchType: "eq", + condition: "Invalid-Condition", + value: "somevalue", + wantErr: true, + }, + { + name: "valid starts-with condition", + matchType: "starts-with", + condition: "X-Amz-Credential", + value: "my-access-key", + wantResult: `"starts-with","$X-Amz-Credential","my-access-key"`, + }, + { + name: "empty condition", + matchType: "eq", + condition: "", + value: "somevalue", + wantErr: true, + }, + { + name: "empty matchType", + matchType: "", + condition: "X-Amz-Date", + value: "somevalue", + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + pp := NewPostPolicy() + + err := pp.SetCondition(tt.matchType, tt.condition, tt.value) + if (err != nil) != tt.wantErr { + t.Errorf("%s: want error: %v, got: %v", tt.name, tt.wantErr, err) + } + + if tt.wantResult != "" { + result := pp.String() + if !strings.Contains(result, tt.wantResult) { + t.Errorf("%s: want result to contain: '%s', got: '%s'", tt.name, tt.wantResult, result) + } + } + }) + } +} + +func TestPostPolicySetTagging(t *testing.T) { + tests := []struct { + name string + tagging string + wantErr bool + wantResult string + }{ + { + name: "valid tagging", + tagging: `key1value1`, + wantResult: `"eq","$tagging","key1value1"`, + }, + { + name: "empty tagging", + tagging: "", + wantErr: true, + }, + { + name: "whitespace tagging", + tagging: " ", + wantErr: true, + }, + { + name: "invalid XML", + tagging: `key1value1`, + wantErr: true, + }, + { + name: "invalid schema", + tagging: ``, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + pp := NewPostPolicy() + + err := pp.SetTagging(tt.tagging) + if (err != nil) != tt.wantErr { + t.Errorf("%s: want error: %v, got: %v", tt.name, tt.wantErr, err) + } + + if tt.wantResult != "" { + result := pp.String() + if !strings.Contains(result, tt.wantResult) { + t.Errorf("%s: want result to contain: '%s', got: '%s'", tt.name, tt.wantResult, result) + } + } + }) + } +} + +func TestPostPolicySetUserMetadata(t *testing.T) { + tests := []struct { + name string + key string + value string + wantErr bool + wantResult string + }{ + { + name: "valid metadata", + key: "user-key", + value: "user-value", + wantResult: `"eq","$x-amz-meta-user-key","user-value"`, + }, + { + name: "empty key", + key: "", + value: "somevalue", + wantErr: true, + }, + { + name: "empty value", + key: "user-key", + value: "", + wantErr: true, + }, + { + name: "key with spaces", + key: " ", + value: "somevalue", + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + pp := NewPostPolicy() + + err := pp.SetUserMetadata(tt.key, tt.value) + if (err != nil) != tt.wantErr { + t.Errorf("%s: want error: %v, got: %v", tt.name, tt.wantErr, err) + } + + if tt.wantResult != "" { + result := pp.String() + if !strings.Contains(result, tt.wantResult) { + t.Errorf("%s: want result to contain: '%s', got: '%s'", tt.name, tt.wantResult, result) + } + } + }) + } +} + +func TestPostPolicySetChecksum(t *testing.T) { + tests := []struct { + name string + checksum Checksum + wantErr bool + wantResult string + }{ + { + name: "valid checksum SHA256", + checksum: ChecksumSHA256.ChecksumBytes([]byte("somerandomdata")), + wantResult: `[["eq","$x-amz-checksum-algorithm","SHA256"],["eq","$x-amz-checksum-sha256","29/7Qm/iMzZ1O3zMbO0luv6mYWyS6JIqPYV9lc8w1PA="]]`, + }, + { + name: "valid checksum CRC32", + checksum: ChecksumCRC32.ChecksumBytes([]byte("somerandomdata")), + wantResult: `[["eq","$x-amz-checksum-algorithm","CRC32"],["eq","$x-amz-checksum-crc32","7sOPnw=="]]`, + }, + { + name: "empty checksum", + checksum: Checksum{}, + wantResult: "", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + pp := NewPostPolicy() + + err := pp.SetChecksum(tt.checksum) + if (err != nil) != tt.wantErr { + t.Errorf("%s: want error: %v, got: %v", tt.name, tt.wantErr, err) + } + + if tt.wantResult != "" { + result := pp.String() + if !strings.Contains(result, tt.wantResult) { + t.Errorf("%s: want result to contain: '%s', got: '%s'", tt.name, tt.wantResult, result) + } + } + }) + } +} + +func TestPostPolicySetEncryption(t *testing.T) { + tests := []struct { + name string + sseType string + keyID string + want map[string]string + }{ + { + name: "SSE-S3 encryption", + sseType: "SSE-S3", + keyID: "my-key-id", + want: map[string]string{ + "X-Amz-Server-Side-Encryption": "aws:kms", + "X-Amz-Server-Side-Encryption-Aws-Kms-Key-Id": "my-key-id", + }, + }, + { + name: "SSE-C encryption with Key ID", + sseType: "SSE-C", + keyID: "my-key-id", + want: map[string]string{ + "X-Amz-Server-Side-Encryption-Customer-Key": "bXktc2VjcmV0LWtleTEyMzQ1Njc4OTBhYmNkZWZnaGk=", + "X-Amz-Server-Side-Encryption-Customer-Key-Md5": "T1mefJwyXBH43sRtfEgRZQ==", + "X-Amz-Server-Side-Encryption-Customer-Algorithm": "AES256", + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + pp := NewPostPolicy() + + var sse encrypt.ServerSide + var err error + if tt.sseType == "SSE-S3" { + sse, err = encrypt.NewSSEKMS(tt.keyID, nil) + if err != nil { + t.Fatalf("Failed to create SSE-KMS: %v", err) + } + } else if tt.sseType == "SSE-C" { + sse, err = encrypt.NewSSEC([]byte("my-secret-key1234567890abcdefghi")) + if err != nil { + t.Fatalf("Failed to create SSE-C: %v", err) + } + } else { + t.Fatalf("Unknown SSE type: %s", tt.sseType) + } + + pp.SetEncryption(sse) + + for k, v := range tt.want { + if pp.formData[k] != v { + t.Errorf("%s: want %s: %s, got: %s", tt.name, k, v, pp.formData[k]) + } + } + }) + } +} diff --git a/retry-continous.go b/retry-continous.go index 0b92611b8..81fcf16f1 100644 --- a/retry-continous.go +++ b/retry-continous.go @@ -39,7 +39,7 @@ func (c *Client) newRetryTimerContinous(baseSleep, maxSleep time.Duration, jitte if attempt > maxAttempt { attempt = maxAttempt } - // sleep = random_between(0, min(cap, base * 2 ** attempt)) + // sleep = random_between(0, min(maxSleep, base * 2 ** attempt)) sleep := baseSleep * time.Duration(1< maxSleep { sleep = maxSleep diff --git a/retry.go b/retry.go index 15f4dca4f..4cc45920c 100644 --- a/retry.go +++ b/retry.go @@ -59,7 +59,7 @@ func (c *Client) newRetryTimer(ctx context.Context, maxRetry int, baseSleep, max jitter = MaxJitter } - // sleep = random_between(0, min(cap, base * 2 ** attempt)) + // sleep = random_between(0, min(maxSleep, base * 2 ** attempt)) sleep := baseSleep * time.Duration(1< maxSleep { sleep = maxSleep