diff --git a/s3transfer/crt.py b/s3transfer/crt.py index 2954477e..aa541843 100644 --- a/s3transfer/crt.py +++ b/s3transfer/crt.py @@ -40,6 +40,7 @@ from s3transfer.constants import MB from s3transfer.exceptions import TransferNotDoneError from s3transfer.futures import BaseTransferFuture, BaseTransferMeta +from s3transfer.manager import TransferManager from s3transfer.utils import ( CallArgs, OSUtils, @@ -181,6 +182,14 @@ def _get_crt_throughput_target_gbps(provided_throughput_target_bytes=None): class CRTTransferManager: + ALLOWED_DOWNLOAD_ARGS = TransferManager.ALLOWED_DOWNLOAD_ARGS + ALLOWED_UPLOAD_ARGS = TransferManager.ALLOWED_UPLOAD_ARGS + ALLOWED_DELETE_ARGS = TransferManager.ALLOWED_DELETE_ARGS + + VALIDATE_SUPPORTED_BUCKET_VALUES = True + + _UNSUPPORTED_BUCKET_PATTERNS = TransferManager._UNSUPPORTED_BUCKET_PATTERNS + def __init__(self, crt_s3_client, crt_request_serializer, osutil=None): """A transfer manager interface for Amazon S3 on CRT s3 client. @@ -226,6 +235,8 @@ def download( extra_args = {} if subscribers is None: subscribers = {} + self._validate_all_known_args(extra_args, self.ALLOWED_DOWNLOAD_ARGS) + self._validate_if_bucket_supported(bucket) callargs = CallArgs( bucket=bucket, key=key, @@ -240,6 +251,8 @@ def upload(self, fileobj, bucket, key, extra_args=None, subscribers=None): extra_args = {} if subscribers is None: subscribers = {} + self._validate_all_known_args(extra_args, self.ALLOWED_UPLOAD_ARGS) + self._validate_if_bucket_supported(bucket) self._validate_checksum_algorithm_supported(extra_args) callargs = CallArgs( bucket=bucket, @@ -255,6 +268,8 @@ def delete(self, bucket, key, extra_args=None, subscribers=None): extra_args = {} if subscribers is None: subscribers = {} + self._validate_all_known_args(extra_args, self.ALLOWED_DELETE_ARGS) + self._validate_if_bucket_supported(bucket) callargs = CallArgs( bucket=bucket, key=key, @@ -266,6 +281,27 @@ def delete(self, bucket, key, extra_args=None, subscribers=None): def shutdown(self, cancel=False): self._shutdown(cancel) + def _validate_if_bucket_supported(self, bucket): + # s3 high level operations don't support some resources + # (eg. S3 Object Lambda) only direct API calls are available + # for such resources + if self.VALIDATE_SUPPORTED_BUCKET_VALUES: + for resource, pattern in self._UNSUPPORTED_BUCKET_PATTERNS.items(): + match = pattern.match(bucket) + if match: + raise ValueError( + f'TransferManager methods do not support {resource} ' + 'resource. Use direct client calls instead.' + ) + + def _validate_all_known_args(self, actual, allowed): + for kwarg in actual: + if kwarg not in allowed: + raise ValueError( + f"Invalid extra_args key '{kwarg}', " + f"must be one of: {', '.join(allowed)}" + ) + def _validate_checksum_algorithm_supported(self, extra_args): checksum_algorithm = extra_args.get('ChecksumAlgorithm') if checksum_algorithm is None: diff --git a/tests/functional/test_crt.py b/tests/functional/test_crt.py index 352e5854..9a3afa75 100644 --- a/tests/functional/test_crt.py +++ b/tests/functional/test_crt.py @@ -373,6 +373,28 @@ def test_upload_throws_error_for_unsupported_checksum(self): [self.record_subscriber], ) + def test_upload_throws_error_for_unsupported_arg(self): + with self.assertRaisesRegex( + ValueError, "Invalid extra_args key 'ContentMD5'" + ): + self.transfer_manager.upload( + self.filename, + self.bucket, + self.key, + {'ContentMD5': '938c2cc0dcc05f2b68c4287040cfcf71'}, + [self.record_subscriber], + ) + + def test_upload_throws_error_on_s3_object_lambda_resource(self): + s3_object_lambda_arn = ( + 'arn:aws:s3-object-lambda:us-west-2:123456789012:' + 'accesspoint:my-accesspoint' + ) + with self.assertRaisesRegex(ValueError, 'methods do not support'): + self.transfer_manager.upload( + self.filename, s3_object_lambda_arn, self.key + ) + def test_upload_with_s3express(self): future = self.transfer_manager.upload( self.filename, @@ -489,6 +511,18 @@ def test_download_to_nonseekable_stream(self): underlying_stream.getvalue(), self.expected_download_content ) + def test_download_throws_error_for_unsupported_arg(self): + with self.assertRaisesRegex( + ValueError, "Invalid extra_args key 'Range'" + ): + self.transfer_manager.download( + self.bucket, + self.key, + self.filename, + {'Range': 'bytes:0-1023'}, + [self.record_subscriber], + ) + def test_download_with_s3express(self): future = self.transfer_manager.download( self.s3express_bucket, @@ -526,6 +560,17 @@ def test_delete(self): ) self._assert_subscribers_called(future) + def test_delete_throws_error_for_unsupported_arg(self): + with self.assertRaisesRegex( + ValueError, "Invalid extra_args key 'BypassGovernanceRetention'" + ): + self.transfer_manager.delete( + self.bucket, + self.key, + {'BypassGovernanceRetention': True}, + [self.record_subscriber], + ) + def test_delete_with_s3express(self): future = self.transfer_manager.delete( self.s3express_bucket, self.key, {}, [self.record_subscriber]