diff --git a/sky/data/storage.py b/sky/data/storage.py
index 5dce3f0a0d8..c3ccb3dfc67 100644
--- a/sky/data/storage.py
+++ b/sky/data/storage.py
@@ -354,7 +354,8 @@ def from_metadata(cls, metadata: StoreMetadata, **override_args):
                                              metadata.is_sky_managed),
             sync_on_reconstruction=override_args.get('sync_on_reconstruction',
                                                      True),
-            # backward compatibility
+            # Backward compatibility
+            # TODO: remove the hasattr check after v0.11.0
             _bucket_sub_path=override_args.get(
                 '_bucket_sub_path',
                 metadata._bucket_sub_path  # pylint: disable=protected-access
@@ -1462,6 +1463,8 @@ def batch_aws_rsync(self,
             set to True, the directory is created in the bucket root and
             contents are uploaded to it.
         """
+        sub_path = (f'/{self._bucket_sub_path}'
+                    if self._bucket_sub_path else '')
 
         def get_file_sync_command(base_dir_path, file_names):
             includes = ' '.join([
@@ -1469,8 +1472,6 @@ def get_file_sync_command(base_dir_path, file_names):
                 for file_name in file_names
             ])
             base_dir_path = shlex.quote(base_dir_path)
-            sub_path = (f'/{self._bucket_sub_path}'
-                        if self._bucket_sub_path else '')
             sync_command = ('aws s3 sync --no-follow-symlinks --exclude="*" '
                             f'{includes} {base_dir_path} '
                             f's3://{self.name}{sub_path}')
@@ -1485,8 +1486,6 @@ def get_dir_sync_command(src_dir_path, dest_dir_name):
                 for file_name in excluded_list
             ])
             src_dir_path = shlex.quote(src_dir_path)
-            sub_path = (f'/{self._bucket_sub_path}'
-                        if self._bucket_sub_path else '')
             sync_command = (f'aws s3 sync --no-follow-symlinks {excludes} '
                             f'{src_dir_path} '
                             f's3://{self.name}{sub_path}/{dest_dir_name}')
@@ -1500,7 +1499,7 @@ def get_dir_sync_command(src_dir_path, dest_dir_name):
 
         log_path = sky_logging.generate_tmp_logging_file_path(
             _STORAGE_LOG_FILE_NAME)
-        sync_path = f'{source_message} -> s3://{self.name}/'
+        sync_path = f'{source_message} -> s3://{self.name}{sub_path}/'
         with rich_utils.safe_status(
                 ux_utils.spinner_message(f'Syncing {sync_path}',
                                          log_path=log_path)):
@@ -1959,11 +1958,13 @@ def batch_gsutil_cp(self,
         copy_list = '\n'.join(
             os.path.abspath(os.path.expanduser(p)) for p in source_path_list)
         gsutil_alias, alias_gen = data_utils.get_gsutil_command()
+        sub_path = (f'/{self._bucket_sub_path}'
+                    if self._bucket_sub_path else '')
         sync_command = (f'{alias_gen}; echo "{copy_list}" | {gsutil_alias} '
-                        f'cp -e -n -r -I gs://{self.name}')
+                        f'cp -e -n -r -I gs://{self.name}{sub_path}')
         log_path = sky_logging.generate_tmp_logging_file_path(
             _STORAGE_LOG_FILE_NAME)
-        sync_path = f'{source_message} -> gs://{self.name}/'
+        sync_path = f'{source_message} -> gs://{self.name}{sub_path}/'
         with rich_utils.safe_status(
                 ux_utils.spinner_message(f'Syncing {sync_path}',
                                          log_path=log_path)):
@@ -1995,13 +1996,13 @@ def get_dir_sync_command(src_dir_path, dest_dir_name):
             set to True, the directory is created in the bucket root and
             contents are uploaded to it.
""" + sub_path = (f'/{self._bucket_sub_path}' + if self._bucket_sub_path else '') def get_file_sync_command(base_dir_path, file_names): sync_format = '|'.join(file_names) gsutil_alias, alias_gen = data_utils.get_gsutil_command() base_dir_path = shlex.quote(base_dir_path) - sub_path = (f'/{self._bucket_sub_path}' - if self._bucket_sub_path else '') sync_command = (f'{alias_gen}; {gsutil_alias} ' f'rsync -e -x \'^(?!{sync_format}$).*\' ' f'{base_dir_path} gs://{self.name}{sub_path}') @@ -2014,8 +2015,6 @@ def get_dir_sync_command(src_dir_path, dest_dir_name): excludes = '|'.join(excluded_list) gsutil_alias, alias_gen = data_utils.get_gsutil_command() src_dir_path = shlex.quote(src_dir_path) - sub_path = (f'/{self._bucket_sub_path}' - if self._bucket_sub_path else '') sync_command = (f'{alias_gen}; {gsutil_alias} ' f'rsync -e -r -x \'({excludes})\' {src_dir_path} ' f'gs://{self.name}{sub_path}/{dest_dir_name}') @@ -2029,7 +2028,7 @@ def get_dir_sync_command(src_dir_path, dest_dir_name): log_path = sky_logging.generate_tmp_logging_file_path( _STORAGE_LOG_FILE_NAME) - sync_path = f'{source_message} -> gs://{self.name}/' + sync_path = f'{source_message} -> gs://{self.name}{sub_path}/' with rich_utils.safe_status( ux_utils.spinner_message(f'Syncing {sync_path}', log_path=log_path)): @@ -2307,15 +2306,24 @@ def from_metadata(cls, metadata: AbstractStore.StoreMetadata, An instance of AzureBlobStore. """ assert isinstance(metadata, AzureBlobStore.AzureBlobStoreMetadata) - return cls(name=override_args.get('name', metadata.name), - storage_account_name=override_args.get( - 'storage_account', metadata.storage_account_name), - source=override_args.get('source', metadata.source), - region=override_args.get('region', metadata.region), - is_sky_managed=override_args.get('is_sky_managed', - metadata.is_sky_managed), - sync_on_reconstruction=override_args.get( - 'sync_on_reconstruction', True)) + # TODO: this needs to be kept in sync with the abstract + # AbstractStore.from_metadata. + return cls( + name=override_args.get('name', metadata.name), + storage_account_name=override_args.get( + 'storage_account', metadata.storage_account_name), + source=override_args.get('source', metadata.source), + region=override_args.get('region', metadata.region), + is_sky_managed=override_args.get('is_sky_managed', + metadata.is_sky_managed), + sync_on_reconstruction=override_args.get('sync_on_reconstruction', + True), + # Backward compatibility + # TODO: remove the hasattr check after v0.11.0 + _bucket_sub_path=override_args.get( + '_bucket_sub_path', + metadata._bucket_sub_path # pylint: disable=protected-access + ) if hasattr(metadata, '_bucket_sub_path') else None) def get_metadata(self) -> AzureBlobStoreMetadata: return self.AzureBlobStoreMetadata( @@ -2795,6 +2803,8 @@ def batch_az_blob_sync(self, set to True, the directory is created in the bucket root and contents are uploaded to it. 
""" + container_path = (f'{self.container_name}/{self._bucket_sub_path}' + if self._bucket_sub_path else self.container_name) def get_file_sync_command(base_dir_path, file_names) -> str: # shlex.quote is not used for file_names as 'az storage blob sync' @@ -2803,8 +2813,6 @@ def get_file_sync_command(base_dir_path, file_names) -> str: includes_list = ';'.join(file_names) includes = f'--include-pattern "{includes_list}"' base_dir_path = shlex.quote(base_dir_path) - container_path = (f'{self.container_name}/{self._bucket_sub_path}' - if self._bucket_sub_path else self.container_name) sync_command = (f'az storage blob sync ' f'--account-name {self.storage_account_name} ' f'--account-key {self.storage_account_key} ' @@ -2822,18 +2830,17 @@ def get_dir_sync_command(src_dir_path, dest_dir_name) -> str: [file_name.rstrip('*') for file_name in excluded_list]) excludes = f'--exclude-path "{excludes_list}"' src_dir_path = shlex.quote(src_dir_path) - container_path = (f'{self.container_name}/{self._bucket_sub_path}' - if self._bucket_sub_path else - f'{self.container_name}') if dest_dir_name: - container_path = f'{container_path}/{dest_dir_name}' + dest_dir_name = f'/{dest_dir_name}' + else: + dest_dir_name = '' sync_command = (f'az storage blob sync ' f'--account-name {self.storage_account_name} ' f'--account-key {self.storage_account_key} ' f'{excludes} ' '--delete-destination false ' f'--source {src_dir_path} ' - f'--container {container_path}') + f'--container {container_path}{dest_dir_name}') return sync_command # Generate message for upload @@ -2844,7 +2851,7 @@ def get_dir_sync_command(src_dir_path, dest_dir_name) -> str: source_message = source_path_list[0] container_endpoint = data_utils.AZURE_CONTAINER_URL.format( storage_account_name=self.storage_account_name, - container_name=self.name) + container_name=container_path) log_path = sky_logging.generate_tmp_logging_file_path( _STORAGE_LOG_FILE_NAME) sync_path = f'{source_message} -> {container_endpoint}/' @@ -3238,6 +3245,8 @@ def batch_aws_rsync(self, set to True, the directory is created in the bucket root and contents are uploaded to it. """ + sub_path = (f'/{self._bucket_sub_path}' + if self._bucket_sub_path else '') def get_file_sync_command(base_dir_path, file_names): includes = ' '.join([ @@ -3246,8 +3255,6 @@ def get_file_sync_command(base_dir_path, file_names): ]) endpoint_url = cloudflare.create_endpoint() base_dir_path = shlex.quote(base_dir_path) - sub_path = (f'/{self._bucket_sub_path}' - if self._bucket_sub_path else '') sync_command = ('AWS_SHARED_CREDENTIALS_FILE=' f'{cloudflare.R2_CREDENTIALS_PATH} ' 'aws s3 sync --no-follow-symlinks --exclude="*" ' @@ -3267,8 +3274,6 @@ def get_dir_sync_command(src_dir_path, dest_dir_name): ]) endpoint_url = cloudflare.create_endpoint() src_dir_path = shlex.quote(src_dir_path) - sub_path = (f'/{self._bucket_sub_path}' - if self._bucket_sub_path else '') sync_command = ('AWS_SHARED_CREDENTIALS_FILE=' f'{cloudflare.R2_CREDENTIALS_PATH} ' f'aws s3 sync --no-follow-symlinks {excludes} ' @@ -3286,7 +3291,7 @@ def get_dir_sync_command(src_dir_path, dest_dir_name): log_path = sky_logging.generate_tmp_logging_file_path( _STORAGE_LOG_FILE_NAME) - sync_path = f'{source_message} -> r2://{self.name}/' + sync_path = f'{source_message} -> r2://{self.name}{sub_path}/' with rich_utils.safe_status( ux_utils.spinner_message(f'Syncing {sync_path}', log_path=log_path)): @@ -3710,6 +3715,8 @@ def batch_ibm_rsync(self, set to True, the directory is created in the bucket root and contents are uploaded to it. 
""" + sub_path = (f'/{self._bucket_sub_path}' + if self._bucket_sub_path else '') def get_dir_sync_command(src_dir_path, dest_dir_name) -> str: """returns an rclone command that copies a complete folder @@ -3731,8 +3738,6 @@ def get_dir_sync_command(src_dir_path, dest_dir_name) -> str: # .git directory is excluded from the sync # wrapping src_dir_path with "" to support path with spaces src_dir_path = shlex.quote(src_dir_path) - sub_path = (f'/{self._bucket_sub_path}' - if self._bucket_sub_path else '') sync_command = ( 'rclone copy --exclude ".git/*" ' f'{src_dir_path} ' @@ -3763,8 +3768,6 @@ def get_file_sync_command(base_dir_path, file_names) -> str: for file_name in file_names ]) base_dir_path = shlex.quote(base_dir_path) - sub_path = (f'/{self._bucket_sub_path}' - if self._bucket_sub_path else '') sync_command = ( 'rclone copy ' f'{includes} {base_dir_path} ' @@ -3779,7 +3782,8 @@ def get_file_sync_command(base_dir_path, file_names) -> str: log_path = sky_logging.generate_tmp_logging_file_path( _STORAGE_LOG_FILE_NAME) - sync_path = f'{source_message} -> cos://{self.region}/{self.name}/' + sync_path = ( + f'{source_message} -> cos://{self.region}/{self.name}{sub_path}/') with rich_utils.safe_status( ux_utils.spinner_message(f'Syncing {sync_path}', log_path=log_path)): @@ -4178,15 +4182,21 @@ def batch_oci_rsync(self, set to True, the directory is created in the bucket root and contents are uploaded to it. """ + sub_path = (f'{self._bucket_sub_path}/' + if self._bucket_sub_path else '') @oci.with_oci_env def get_file_sync_command(base_dir_path, file_names): includes = ' '.join( [f'--include "{file_name}"' for file_name in file_names]) + prefix_arg = '' + if sub_path: + prefix_arg = f'--object-prefix "{sub_path.strip("/")}"' sync_command = ( 'oci os object bulk-upload --no-follow-symlinks --overwrite ' f'--bucket-name {self.name} --namespace-name {self.namespace} ' f'--region {self.region} --src-dir "{base_dir_path}" ' + f'{prefix_arg} ' f'{includes}') return sync_command @@ -4207,7 +4217,8 @@ def get_dir_sync_command(src_dir_path, dest_dir_name): sync_command = ( 'oci os object bulk-upload --no-follow-symlinks --overwrite ' f'--bucket-name {self.name} --namespace-name {self.namespace} ' - f'--region {self.region} --object-prefix "{dest_dir_name}" ' + f'--region {self.region} ' + f'--object-prefix "{sub_path}{dest_dir_name}" ' f'--src-dir "{src_dir_path}" {excludes}') return sync_command @@ -4220,7 +4231,7 @@ def get_dir_sync_command(src_dir_path, dest_dir_name): log_path = sky_logging.generate_tmp_logging_file_path( _STORAGE_LOG_FILE_NAME) - sync_path = f'{source_message} -> oci://{self.name}/' + sync_path = f'{source_message} -> oci://{self.name}/{sub_path}' with rich_utils.safe_status( ux_utils.spinner_message(f'Syncing {sync_path}', log_path=log_path)): diff --git a/tests/smoke_tests/test_mount_and_storage.py b/tests/smoke_tests/test_mount_and_storage.py index c7f3e356c0d..13c3e118f75 100644 --- a/tests/smoke_tests/test_mount_and_storage.py +++ b/tests/smoke_tests/test_mount_and_storage.py @@ -854,16 +854,18 @@ def yield_storage_object( persistent=persistent, mode=mode, _bucket_sub_path=_bucket_sub_path) - yield storage_obj - handle = global_user_state.get_handle_from_storage_name( - storage_obj.name) - if handle: - # If handle exists, delete manually - # TODO(romilb): This is potentially risky - if the delete method has - # bugs, this can cause resource leaks. Ideally we should manually - # eject storage from global_user_state and delete the bucket using - # boto3 directly. 
-            storage_obj.delete()
+        try:
+            yield storage_obj
+        finally:
+            handle = global_user_state.get_handle_from_storage_name(
+                storage_obj.name)
+            if handle:
+                # If handle exists, delete manually
+                # TODO(romilb): This is potentially risky - if the delete method has
+                # bugs, this can cause resource leaks. Ideally we should manually
+                # eject storage from global_user_state and delete the bucket using
+                # boto3 directly.
+                storage_obj.delete()
 
     @pytest.fixture
     def tmp_scratch_storage_obj(self, tmp_bucket_name):
@@ -881,17 +883,19 @@ def tmp_multiple_scratch_storage_obj(self):
             timestamp = str(time.time()).replace('.', '')
             store_obj = storage_lib.Storage(name=f'sky-test-{timestamp}')
             storage_mult_obj.append(store_obj)
-        yield storage_mult_obj
-        for storage_obj in storage_mult_obj:
-            handle = global_user_state.get_handle_from_storage_name(
-                storage_obj.name)
-            if handle:
-                # If handle exists, delete manually
-                # TODO(romilb): This is potentially risky - if the delete method has
-                # bugs, this can cause resource leaks. Ideally we should manually
-                # eject storage from global_user_state and delete the bucket using
-                # boto3 directly.
-                storage_obj.delete()
+        try:
+            yield storage_mult_obj
+        finally:
+            for storage_obj in storage_mult_obj:
+                handle = global_user_state.get_handle_from_storage_name(
+                    storage_obj.name)
+                if handle:
+                    # If handle exists, delete manually
+                    # TODO(romilb): This is potentially risky - if the delete method has
+                    # bugs, this can cause resource leaks. Ideally we should manually
+                    # eject storage from global_user_state and delete the bucket using
+                    # boto3 directly.
+                    storage_obj.delete()
 
     @pytest.fixture
     def tmp_multiple_custom_source_storage_obj(self):
@@ -907,12 +911,14 @@ def tmp_multiple_custom_source_storage_obj(self):
             store_obj = storage_lib.Storage(name=f'sky-test-{timestamp}',
                                             source=src_path)
             storage_mult_obj.append(store_obj)
-        yield storage_mult_obj
-        for storage_obj in storage_mult_obj:
-            handle = global_user_state.get_handle_from_storage_name(
-                storage_obj.name)
-            if handle:
-                storage_obj.delete()
+        try:
+            yield storage_mult_obj
+        finally:
+            for storage_obj in storage_mult_obj:
+                handle = global_user_state.get_handle_from_storage_name(
+                    storage_obj.name)
+                if handle:
+                    storage_obj.delete()
 
     @pytest.fixture
     def tmp_local_storage_obj(self, tmp_bucket_name, tmp_source):
@@ -1099,7 +1105,14 @@ def test_bucket_sub_path(self, tmp_local_storage_obj_with_sub_path,
                              store_type):
         # Creates a new bucket with a local source, uploads files to it
         # and deletes it.
-        tmp_local_storage_obj_with_sub_path.add_store(store_type)
+        region_kwargs = {}
+        if store_type == storage_lib.StoreType.AZURE:
+            # We have to specify the region for Azure storage, as the default
+            # Azure storage account is in centralus region.
+            region_kwargs['region'] = 'centralus'
+
+        tmp_local_storage_obj_with_sub_path.add_store(store_type,
+                                                      **region_kwargs)
 
         # Check files under bucket and filter by prefix
         files = self.list_all_files(store_type,
@@ -1412,7 +1425,13 @@ def test_upload_to_existing_bucket(self, ext_bucket_fixture, request,
         # sky) and verifies that files are written.
         bucket_name, _ = request.getfixturevalue(ext_bucket_fixture)
         storage_obj = storage_lib.Storage(name=bucket_name, source=tmp_source)
-        storage_obj.add_store(store_type)
+        region_kwargs = {}
+        if store_type == storage_lib.StoreType.AZURE:
+            # We have to specify the region for Azure storage, as the default
+            # Azure storage account is in centralus region.
+            region_kwargs['region'] = 'centralus'
+
+        storage_obj.add_store(store_type, **region_kwargs)
 
         # Check if tmp_source/tmp-file exists in the bucket using aws cli
         out = subprocess.check_output(self.cli_ls_cmd(store_type, bucket_name),
@@ -1458,7 +1477,13 @@ def test_copy_mount_existing_storage(self,
     def test_list_source(self, tmp_local_list_storage_obj, store_type):
         # Uses a list in the source field to specify a file and a directory to
         # be uploaded to the storage object.
-        tmp_local_list_storage_obj.add_store(store_type)
+        region_kwargs = {}
+        if store_type == storage_lib.StoreType.AZURE:
+            # We have to specify the region for Azure storage, as the default
+            # Azure storage account is in centralus region.
+            region_kwargs['region'] = 'centralus'
+
+        tmp_local_list_storage_obj.add_store(store_type, **region_kwargs)
 
         # Check if tmp-file exists in the bucket root using cli
         out = subprocess.check_output(self.cli_ls_cmd(
@@ -1513,7 +1538,13 @@ def test_excluded_file_cloud_storage_upload_copy(self, gitignore_structure,
                                                      tmp_gitignore_storage_obj):
         # tests if files included in .gitignore and .git/info/exclude are
         # excluded from being transferred to Storage
-        tmp_gitignore_storage_obj.add_store(store_type)
+        region_kwargs = {}
+        if store_type == storage_lib.StoreType.AZURE:
+            # We have to specify the region for Azure storage, as the default
+            # Azure storage account is in centralus region.
+            region_kwargs['region'] = 'centralus'
+
+        tmp_gitignore_storage_obj.add_store(store_type, **region_kwargs)
         upload_file_name = 'included'
         # Count the number of files with the given file name
         up_cmd = self.cli_count_name_in_bucket(store_type, \