[Storage] Azure bucket sub directory is ignored if the bucket previously exists #4706

Merged · 6 commits · Feb 14, 2025
95 changes: 53 additions & 42 deletions sky/data/storage.py
@@ -354,7 +354,8 @@ def from_metadata(cls, metadata: StoreMetadata, **override_args):
metadata.is_sky_managed),
sync_on_reconstruction=override_args.get('sync_on_reconstruction',
True),
# backward compatibility
# Backward compatibility
# TODO: remove the hasattr check after v0.11.0
Collaborator:
Are we saying that backward compatibility is removed after v0.11.0?
For example, v0.6.0 won't be compatible with v0.11.0?

Collaborator (Author):

Yes, we have a convention of keeping backward compatibility for 3 major releases.

_bucket_sub_path=override_args.get(
'_bucket_sub_path',
metadata._bucket_sub_path # pylint: disable=protected-access
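
To make the backward-compatibility discussion above concrete, here is a minimal standalone sketch of the `hasattr` guard; the metadata classes and the `resolve_sub_path` helper are hypothetical stand-ins, not the SkyPilot ones. Metadata pickled by a release that predates `_bucket_sub_path` simply lacks the attribute, so reconstruction falls back to `None` until the compatibility window closes.

```python
from typing import Optional


class LegacyStoreMetadata:
    """Shape of metadata persisted by a release without _bucket_sub_path."""

    def __init__(self, name: str):
        self.name = name


class CurrentStoreMetadata(LegacyStoreMetadata):
    """Newer releases also persist the optional bucket sub path."""

    def __init__(self, name: str, bucket_sub_path: Optional[str] = None):
        super().__init__(name)
        self._bucket_sub_path = bucket_sub_path


def resolve_sub_path(metadata: LegacyStoreMetadata,
                     **override_args) -> Optional[str]:
    # Same shape as the diff: if the attribute is missing (old metadata),
    # fall back to None; otherwise let an explicit override win over the
    # persisted value.
    return override_args.get(
        '_bucket_sub_path',
        metadata._bucket_sub_path  # pylint: disable=protected-access
    ) if hasattr(metadata, '_bucket_sub_path') else None


print(resolve_sub_path(LegacyStoreMetadata('my-bucket')))      # None
print(resolve_sub_path(CurrentStoreMetadata('my-bucket', 'task-1')))  # task-1
print(resolve_sub_path(CurrentStoreMetadata('my-bucket'),
                       _bucket_sub_path='override'))           # override
```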
@@ -1462,15 +1463,15 @@ def batch_aws_rsync(self,
set to True, the directory is created in the bucket root and
contents are uploaded to it.
"""
sub_path = (f'/{self._bucket_sub_path}'
if self._bucket_sub_path else '')

def get_file_sync_command(base_dir_path, file_names):
includes = ' '.join([
f'--include {shlex.quote(file_name)}'
for file_name in file_names
])
base_dir_path = shlex.quote(base_dir_path)
sub_path = (f'/{self._bucket_sub_path}'
if self._bucket_sub_path else '')
sync_command = ('aws s3 sync --no-follow-symlinks --exclude="*" '
f'{includes} {base_dir_path} '
f's3://{self.name}{sub_path}')
@@ -1485,8 +1486,6 @@ def get_dir_sync_command(src_dir_path, dest_dir_name):
for file_name in excluded_list
])
src_dir_path = shlex.quote(src_dir_path)
sub_path = (f'/{self._bucket_sub_path}'
if self._bucket_sub_path else '')
sync_command = (f'aws s3 sync --no-follow-symlinks {excludes} '
f'{src_dir_path} '
f's3://{self.name}{sub_path}/{dest_dir_name}')
@@ -1500,7 +1499,7 @@ def get_dir_sync_command(src_dir_path, dest_dir_name):

log_path = sky_logging.generate_tmp_logging_file_path(
_STORAGE_LOG_FILE_NAME)
sync_path = f'{source_message} -> s3://{self.name}/'
sync_path = f'{source_message} -> s3://{self.name}{sub_path}/'
with rich_utils.safe_status(
ux_utils.spinner_message(f'Syncing {sync_path}',
log_path=log_path)):
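
For illustration, a standalone sketch of how the hoisted `sub_path` above folds into the `s3://` destination shared by the file-sync and directory-sync helpers and the log message; the helper names below are assumptions, not SkyPilot code.

```python
import shlex
from typing import Optional


def s3_destination(bucket: str,
                   bucket_sub_path: Optional[str] = None,
                   dest_dir_name: str = '') -> str:
    """Builds the s3:// destination used by `aws s3 sync`."""
    # Computed once, outside the per-command helpers, like the hoisted
    # `sub_path` in the diff above.
    sub_path = f'/{bucket_sub_path}' if bucket_sub_path else ''
    suffix = f'/{dest_dir_name}' if dest_dir_name else ''
    return f's3://{bucket}{sub_path}{suffix}'


def dir_sync_command(src_dir_path: str, dest: str) -> str:
    return ('aws s3 sync --no-follow-symlinks '
            f'{shlex.quote(src_dir_path)} {dest}')


# With a sub path, contents land under s3://my-bucket/task-1/outputs.
print(dir_sync_command('~/outputs',
                       s3_destination('my-bucket', 'task-1', 'outputs')))
# Without one, they land in the bucket root, as before.
print(dir_sync_command('~/outputs',
                       s3_destination('my-bucket', None, 'outputs')))
```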
@@ -1959,11 +1958,13 @@ def batch_gsutil_cp(self,
copy_list = '\n'.join(
os.path.abspath(os.path.expanduser(p)) for p in source_path_list)
gsutil_alias, alias_gen = data_utils.get_gsutil_command()
sub_path = (f'/{self._bucket_sub_path}'
if self._bucket_sub_path else '')
sync_command = (f'{alias_gen}; echo "{copy_list}" | {gsutil_alias} '
f'cp -e -n -r -I gs://{self.name}')
f'cp -e -n -r -I gs://{self.name}{sub_path}')
log_path = sky_logging.generate_tmp_logging_file_path(
_STORAGE_LOG_FILE_NAME)
sync_path = f'{source_message} -> gs://{self.name}/'
sync_path = f'{source_message} -> gs://{self.name}{sub_path}/'
with rich_utils.safe_status(
ux_utils.spinner_message(f'Syncing {sync_path}',
log_path=log_path)):
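
A standalone sketch of the GCS copy-list path, using plain `gsutil` in place of SkyPilot's `data_utils.get_gsutil_command` alias helper; the function name and sample paths are illustrative only.

```python
import os
from typing import Optional, Sequence


def gsutil_cp_command(source_paths: Sequence[str],
                      bucket: str,
                      bucket_sub_path: Optional[str] = None) -> str:
    # Absolute paths, one per line, are piped into `gsutil cp -I`, which
    # reads its list of sources from stdin. -e skips symlinks, -n avoids
    # clobbering existing objects, -r recurses into directories.
    copy_list = '\n'.join(
        os.path.abspath(os.path.expanduser(p)) for p in source_paths)
    sub_path = f'/{bucket_sub_path}' if bucket_sub_path else ''
    return (f'echo "{copy_list}" | gsutil -m '
            f'cp -e -n -r -I gs://{bucket}{sub_path}')


# Objects end up under gs://my-bucket/run-42/... instead of the bucket root.
print(gsutil_cp_command(['~/data.csv', '~/checkpoints'], 'my-bucket',
                        'run-42'))
```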
@@ -1995,13 +1996,13 @@ def batch_gsutil_rsync(self,
set to True, the directory is created in the bucket root and
contents are uploaded to it.
"""
sub_path = (f'/{self._bucket_sub_path}'
if self._bucket_sub_path else '')

def get_file_sync_command(base_dir_path, file_names):
sync_format = '|'.join(file_names)
gsutil_alias, alias_gen = data_utils.get_gsutil_command()
base_dir_path = shlex.quote(base_dir_path)
sub_path = (f'/{self._bucket_sub_path}'
if self._bucket_sub_path else '')
sync_command = (f'{alias_gen}; {gsutil_alias} '
f'rsync -e -x \'^(?!{sync_format}$).*\' '
f'{base_dir_path} gs://{self.name}{sub_path}')
@@ -2014,8 +2015,6 @@ def get_dir_sync_command(src_dir_path, dest_dir_name):
excludes = '|'.join(excluded_list)
gsutil_alias, alias_gen = data_utils.get_gsutil_command()
src_dir_path = shlex.quote(src_dir_path)
sub_path = (f'/{self._bucket_sub_path}'
if self._bucket_sub_path else '')
sync_command = (f'{alias_gen}; {gsutil_alias} '
f'rsync -e -r -x \'({excludes})\' {src_dir_path} '
f'gs://{self.name}{sub_path}/{dest_dir_name}')
@@ -2029,7 +2028,7 @@ def get_dir_sync_command(src_dir_path, dest_dir_name):

log_path = sky_logging.generate_tmp_logging_file_path(
_STORAGE_LOG_FILE_NAME)
sync_path = f'{source_message} -> gs://{self.name}/'
sync_path = f'{source_message} -> gs://{self.name}{sub_path}/'
with rich_utils.safe_status(
ux_utils.spinner_message(f'Syncing {sync_path}',
log_path=log_path)):
@@ -2307,15 +2306,24 @@ def from_metadata(cls, metadata: AbstractStore.StoreMetadata,
An instance of AzureBlobStore.
"""
assert isinstance(metadata, AzureBlobStore.AzureBlobStoreMetadata)
return cls(name=override_args.get('name', metadata.name),
storage_account_name=override_args.get(
'storage_account', metadata.storage_account_name),
source=override_args.get('source', metadata.source),
region=override_args.get('region', metadata.region),
is_sky_managed=override_args.get('is_sky_managed',
metadata.is_sky_managed),
sync_on_reconstruction=override_args.get(
'sync_on_reconstruction', True))
# TODO: this needs to be kept in sync with the abstract
# AbstractStore.from_metadata.
return cls(
name=override_args.get('name', metadata.name),
storage_account_name=override_args.get(
'storage_account', metadata.storage_account_name),
source=override_args.get('source', metadata.source),
region=override_args.get('region', metadata.region),
is_sky_managed=override_args.get('is_sky_managed',
metadata.is_sky_managed),
sync_on_reconstruction=override_args.get('sync_on_reconstruction',
True),
# Backward compatibility
# TODO: remove the hasattr check after v0.11.0
_bucket_sub_path=override_args.get(
'_bucket_sub_path',
metadata._bucket_sub_path # pylint: disable=protected-access
) if hasattr(metadata, '_bucket_sub_path') else None)

def get_metadata(self) -> AzureBlobStoreMetadata:
return self.AzureBlobStoreMetadata(
@@ -2795,6 +2803,8 @@ def batch_az_blob_sync(self,
set to True, the directory is created in the bucket root and
contents are uploaded to it.
"""
container_path = (f'{self.container_name}/{self._bucket_sub_path}'
if self._bucket_sub_path else self.container_name)

def get_file_sync_command(base_dir_path, file_names) -> str:
# shlex.quote is not used for file_names as 'az storage blob sync'
@@ -2803,8 +2813,6 @@ def get_file_sync_command(base_dir_path, file_names) -> str:
includes_list = ';'.join(file_names)
includes = f'--include-pattern "{includes_list}"'
base_dir_path = shlex.quote(base_dir_path)
container_path = (f'{self.container_name}/{self._bucket_sub_path}'
if self._bucket_sub_path else self.container_name)
sync_command = (f'az storage blob sync '
f'--account-name {self.storage_account_name} '
f'--account-key {self.storage_account_key} '
@@ -2822,18 +2830,17 @@ def get_dir_sync_command(src_dir_path, dest_dir_name) -> str:
[file_name.rstrip('*') for file_name in excluded_list])
excludes = f'--exclude-path "{excludes_list}"'
src_dir_path = shlex.quote(src_dir_path)
container_path = (f'{self.container_name}/{self._bucket_sub_path}'
if self._bucket_sub_path else
f'{self.container_name}')
if dest_dir_name:
container_path = f'{container_path}/{dest_dir_name}'
dest_dir_name = f'/{dest_dir_name}'
else:
dest_dir_name = ''
sync_command = (f'az storage blob sync '
f'--account-name {self.storage_account_name} '
f'--account-key {self.storage_account_key} '
f'{excludes} '
'--delete-destination false '
f'--source {src_dir_path} '
f'--container {container_path}')
f'--container {container_path}{dest_dir_name}')
return sync_command

# Generate message for upload
@@ -2844,7 +2851,7 @@ def get_dir_sync_command(src_dir_path, dest_dir_name) -> str:
source_message = source_path_list[0]
container_endpoint = data_utils.AZURE_CONTAINER_URL.format(
storage_account_name=self.storage_account_name,
container_name=self.name)
container_name=container_path)
log_path = sky_logging.generate_tmp_logging_file_path(
_STORAGE_LOG_FILE_NAME)
sync_path = f'{source_message} -> {container_endpoint}/'
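
The core of this fix: for Azure, the sub-directory (and, for directory syncs, the destination directory) is now part of the `--container` argument rather than being dropped. A minimal sketch of that composition with placeholder account values; the helper names are assumptions, not the code in this file.

```python
import shlex
from typing import Optional


def az_container_path(container_name: str,
                      bucket_sub_path: Optional[str] = None,
                      dest_dir_name: str = '') -> str:
    # Hoisted once per sync, as in the diff above: the sub path becomes part
    # of the --container argument instead of being ignored.
    path = (f'{container_name}/{bucket_sub_path}'
            if bucket_sub_path else container_name)
    if dest_dir_name:
        path = f'{path}/{dest_dir_name}'
    return path


def az_dir_sync_command(src_dir_path: str, container_path: str) -> str:
    # Placeholder credentials; the real command uses the store's account
    # name and key.
    return ('az storage blob sync '
            '--account-name <storage-account> '
            '--account-key <key> '
            '--delete-destination false '
            f'--source {shlex.quote(src_dir_path)} '
            f'--container {container_path}')


# Uploads ~/outputs into <container>/task-1/outputs instead of the
# container root.
print(az_dir_sync_command('~/outputs',
                          az_container_path('my-container', 'task-1',
                                            'outputs')))
```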
@@ -3238,6 +3245,8 @@ def batch_aws_rsync(self,
set to True, the directory is created in the bucket root and
contents are uploaded to it.
"""
sub_path = (f'/{self._bucket_sub_path}'
if self._bucket_sub_path else '')

def get_file_sync_command(base_dir_path, file_names):
includes = ' '.join([
@@ -3246,8 +3255,6 @@ def get_file_sync_command(base_dir_path, file_names):
])
endpoint_url = cloudflare.create_endpoint()
base_dir_path = shlex.quote(base_dir_path)
sub_path = (f'/{self._bucket_sub_path}'
if self._bucket_sub_path else '')
sync_command = ('AWS_SHARED_CREDENTIALS_FILE='
f'{cloudflare.R2_CREDENTIALS_PATH} '
'aws s3 sync --no-follow-symlinks --exclude="*" '
@@ -3267,8 +3274,6 @@ def get_dir_sync_command(src_dir_path, dest_dir_name):
])
endpoint_url = cloudflare.create_endpoint()
src_dir_path = shlex.quote(src_dir_path)
sub_path = (f'/{self._bucket_sub_path}'
if self._bucket_sub_path else '')
sync_command = ('AWS_SHARED_CREDENTIALS_FILE='
f'{cloudflare.R2_CREDENTIALS_PATH} '
f'aws s3 sync --no-follow-symlinks {excludes} '
@@ -3286,7 +3291,7 @@ def get_dir_sync_command(src_dir_path, dest_dir_name):

log_path = sky_logging.generate_tmp_logging_file_path(
_STORAGE_LOG_FILE_NAME)
sync_path = f'{source_message} -> r2://{self.name}/'
sync_path = f'{source_message} -> r2://{self.name}{sub_path}/'
with rich_utils.safe_status(
ux_utils.spinner_message(f'Syncing {sync_path}',
log_path=log_path)):
@@ -3710,6 +3715,8 @@ def batch_ibm_rsync(self,
set to True, the directory is created in the bucket root and
contents are uploaded to it.
"""
sub_path = (f'/{self._bucket_sub_path}'
if self._bucket_sub_path else '')

def get_dir_sync_command(src_dir_path, dest_dir_name) -> str:
"""returns an rclone command that copies a complete folder
@@ -3731,8 +3738,6 @@ def get_dir_sync_command(src_dir_path, dest_dir_name) -> str:
# .git directory is excluded from the sync
# wrapping src_dir_path with "" to support path with spaces
src_dir_path = shlex.quote(src_dir_path)
sub_path = (f'/{self._bucket_sub_path}'
if self._bucket_sub_path else '')
sync_command = (
'rclone copy --exclude ".git/*" '
f'{src_dir_path} '
@@ -3763,8 +3768,6 @@ def get_file_sync_command(base_dir_path, file_names) -> str:
for file_name in file_names
])
base_dir_path = shlex.quote(base_dir_path)
sub_path = (f'/{self._bucket_sub_path}'
if self._bucket_sub_path else '')
sync_command = (
'rclone copy '
f'{includes} {base_dir_path} '
@@ -3779,7 +3782,8 @@ def get_file_sync_command(base_dir_path, file_names) -> str:

log_path = sky_logging.generate_tmp_logging_file_path(
_STORAGE_LOG_FILE_NAME)
sync_path = f'{source_message} -> cos://{self.region}/{self.name}/'
sync_path = (
f'{source_message} -> cos://{self.region}/{self.name}{sub_path}/')
with rich_utils.safe_status(
ux_utils.spinner_message(f'Syncing {sync_path}',
log_path=log_path)):
@@ -4178,15 +4182,21 @@ def batch_oci_rsync(self,
set to True, the directory is created in the bucket root and
contents are uploaded to it.
"""
sub_path = (f'{self._bucket_sub_path}/'
if self._bucket_sub_path else '')

@oci.with_oci_env
def get_file_sync_command(base_dir_path, file_names):
includes = ' '.join(
[f'--include "{file_name}"' for file_name in file_names])
prefix_arg = ''
if sub_path:
prefix_arg = f'--object-prefix "{sub_path.strip("/")}"'
sync_command = (
'oci os object bulk-upload --no-follow-symlinks --overwrite '
f'--bucket-name {self.name} --namespace-name {self.namespace} '
f'--region {self.region} --src-dir "{base_dir_path}" '
f'{prefix_arg} '
f'{includes}')

return sync_command
@@ -4207,7 +4217,8 @@ def get_dir_sync_command(src_dir_path, dest_dir_name):
sync_command = (
'oci os object bulk-upload --no-follow-symlinks --overwrite '
f'--bucket-name {self.name} --namespace-name {self.namespace} '
f'--region {self.region} --object-prefix "{dest_dir_name}" '
f'--region {self.region} '
f'--object-prefix "{sub_path}{dest_dir_name}" '
f'--src-dir "{src_dir_path}" {excludes}')

return sync_command
@@ -4220,7 +4231,7 @@ def get_dir_sync_command(src_dir_path, dest_dir_name):

log_path = sky_logging.generate_tmp_logging_file_path(
_STORAGE_LOG_FILE_NAME)
sync_path = f'{source_message} -> oci://{self.name}/'
sync_path = f'{source_message} -> oci://{self.name}/{sub_path}'
with rich_utils.safe_status(
ux_utils.spinner_message(f'Syncing {sync_path}',
log_path=log_path)):
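
For OCI, the sub path travels through `--object-prefix` rather than the destination URI. A minimal sketch (assumed helper names) of how the prefix is composed for the directory- and file-sync cases.

```python
from typing import Optional


def oci_object_prefix(bucket_sub_path: Optional[str],
                      dest_dir_name: str = '') -> str:
    """Returns the --object-prefix value, or '' when no prefix is needed."""
    # Mirrors the diff: sub_path carries a trailing slash so it can be
    # concatenated directly with dest_dir_name for directory syncs.
    sub_path = f'{bucket_sub_path}/' if bucket_sub_path else ''
    if dest_dir_name:
        return f'{sub_path}{dest_dir_name}'
    # File syncs strip the trailing slash and use the bare sub path.
    return sub_path.strip('/')


def oci_bulk_upload_command(bucket: str, namespace: str, region: str,
                            src_dir: str, prefix: str) -> str:
    prefix_arg = f'--object-prefix "{prefix}" ' if prefix else ''
    return ('oci os object bulk-upload --no-follow-symlinks --overwrite '
            f'--bucket-name {bucket} --namespace-name {namespace} '
            f'--region {region} --src-dir "{src_dir}" {prefix_arg}'.rstrip())


# Directory sync: objects land under sub-dir/outputs/...
print(oci_bulk_upload_command('my-bucket', 'my-namespace', 'us-ashburn-1',
                              '/tmp/outputs',
                              oci_object_prefix('sub-dir', 'outputs')))
# File sync with no sub path configured: no --object-prefix at all.
print(oci_bulk_upload_command('my-bucket', 'my-namespace', 'us-ashburn-1',
                              '/tmp/files', oci_object_prefix(None)))
```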