From 5e136c250d7a2fc7d33876aa75da208198975f76 Mon Sep 17 00:00:00 2001 From: Kaiyuan Eric Chen Date: Thu, 13 Feb 2025 16:10:35 -0800 Subject: [PATCH 1/2] [Docs] Add vector database tutorial to documentation (#4713) * add vdb * add vdb to ai gallery * fix warning --- README.md | 3 ++- .../_gallery_original/applications/vector_database.md | 1 + docs/source/_gallery_original/index.rst | 1 + docs/source/docs/index.rst | 2 +- examples/vector_database/README.md | 10 +++++----- 5 files changed, 10 insertions(+), 7 deletions(-) create mode 120000 docs/source/_gallery_original/applications/vector_database.md diff --git a/README.md b/README.md index 8a3361f9f41..ce3ccac8606 100644 --- a/README.md +++ b/README.md @@ -26,6 +26,7 @@ ---- :fire: *News* :fire: +- [Jan 2025] Prepare and Serve Large-Scale Image Search with **Vector Database**: [**blog post**](https://blog.skypilot.co/large-scale-vector-database/) [**example**](./examples/vector_database/) - [Jan 2025] Launch and Serve **[DeepSeek-R1](https://github.com/deepseek-ai/DeepSeek-R1)** and **[Janus](https://github.com/deepseek-ai/DeepSeek-Janus)** on Kubernetes or Any Cloud: [**R1 example**](./llm/deepseek-r1/) and [**Janus example**](./llm/deepseek-janus/) - [Oct 2024] :tada: **SkyPilot crossed 1M+ downloads** :tada:: Thank you to our community! [**Twitter/X**](https://x.com/skypilot_org/status/1844770841718067638) - [Sep 2024] Point, Launch and Serve **Llama 3.2** on Kubernetes or Any Cloud: [**example**](./llm/llama-3_2/) @@ -187,7 +188,7 @@ Runnable examples: - [LocalGPT](./llm/localgpt) - [Falcon](./llm/falcon) - Add yours here & see more in [`llm/`](./llm)! -- Framework examples: [PyTorch DDP](https://github.com/skypilot-org/skypilot/blob/master/examples/resnet_distributed_torch.yaml), [DeepSpeed](./examples/deepspeed-multinode/sky.yaml), [JAX/Flax on TPU](https://github.com/skypilot-org/skypilot/blob/master/examples/tpu/tpuvm_mnist.yaml), [Stable Diffusion](https://github.com/skypilot-org/skypilot/tree/master/examples/stable_diffusion), [Detectron2](https://github.com/skypilot-org/skypilot/blob/master/examples/detectron2_docker.yaml), [Distributed](https://github.com/skypilot-org/skypilot/blob/master/examples/resnet_distributed_tf_app.py) [TensorFlow](https://github.com/skypilot-org/skypilot/blob/master/examples/resnet_app_storage.yaml), [Ray Train](examples/distributed_ray_train/ray_train.yaml), [NeMo](https://github.com/skypilot-org/skypilot/blob/master/examples/nemo/), [programmatic grid search](https://github.com/skypilot-org/skypilot/blob/master/examples/huggingface_glue_imdb_grid_search_app.py), [Docker](https://github.com/skypilot-org/skypilot/blob/master/examples/docker/echo_app.yaml), [Cog](https://github.com/skypilot-org/skypilot/blob/master/examples/cog/), [Unsloth](https://github.com/skypilot-org/skypilot/blob/master/examples/unsloth/unsloth.yaml), [Ollama](https://github.com/skypilot-org/skypilot/blob/master/llm/ollama), [llm.c](https://github.com/skypilot-org/skypilot/tree/master/llm/gpt-2), [Airflow](./examples/airflow/training_workflow) and [many more (`examples/`)](./examples). 
+- Framework examples: [Vector Database](./examples/vector_database/), [PyTorch DDP](https://github.com/skypilot-org/skypilot/blob/master/examples/resnet_distributed_torch.yaml), [DeepSpeed](./examples/deepspeed-multinode/sky.yaml), [JAX/Flax on TPU](https://github.com/skypilot-org/skypilot/blob/master/examples/tpu/tpuvm_mnist.yaml), [Stable Diffusion](https://github.com/skypilot-org/skypilot/tree/master/examples/stable_diffusion), [Detectron2](https://github.com/skypilot-org/skypilot/blob/master/examples/detectron2_docker.yaml), [Distributed](https://github.com/skypilot-org/skypilot/blob/master/examples/resnet_distributed_tf_app.py) [TensorFlow](https://github.com/skypilot-org/skypilot/blob/master/examples/resnet_app_storage.yaml), [Ray Train](examples/distributed_ray_train/ray_train.yaml), [NeMo](https://github.com/skypilot-org/skypilot/blob/master/examples/nemo/), [programmatic grid search](https://github.com/skypilot-org/skypilot/blob/master/examples/huggingface_glue_imdb_grid_search_app.py), [Docker](https://github.com/skypilot-org/skypilot/blob/master/examples/docker/echo_app.yaml), [Cog](https://github.com/skypilot-org/skypilot/blob/master/examples/cog/), [Unsloth](https://github.com/skypilot-org/skypilot/blob/master/examples/unsloth/unsloth.yaml), [Ollama](https://github.com/skypilot-org/skypilot/blob/master/llm/ollama), [llm.c](https://github.com/skypilot-org/skypilot/tree/master/llm/gpt-2), [Airflow](./examples/airflow/training_workflow) and [many more (`examples/`)](./examples). Case Studies and Integrations: [Community Spotlights](https://blog.skypilot.co/community/) diff --git a/docs/source/_gallery_original/applications/vector_database.md b/docs/source/_gallery_original/applications/vector_database.md new file mode 120000 index 00000000000..ebcd50df736 --- /dev/null +++ b/docs/source/_gallery_original/applications/vector_database.md @@ -0,0 +1 @@ +../../../../examples/vector_database/README.md \ No newline at end of file diff --git a/docs/source/_gallery_original/index.rst b/docs/source/_gallery_original/index.rst index e049a4ad322..8e0d0b16c35 100644 --- a/docs/source/_gallery_original/index.rst +++ b/docs/source/_gallery_original/index.rst @@ -50,6 +50,7 @@ Contents :maxdepth: 1 :caption: Applications + Image Vector Database Tabby: Coding Assistant LocalGPT: Chat with PDF diff --git a/docs/source/docs/index.rst b/docs/source/docs/index.rst index 2e9ca6859c6..ea5d6c6c18e 100644 --- a/docs/source/docs/index.rst +++ b/docs/source/docs/index.rst @@ -108,7 +108,7 @@ Runnable examples: * `LocalGPT `_ * Add yours here & see more in `llm/ `_! -* Framework examples: `PyTorch DDP `_, `DeepSpeed `_, `JAX/Flax on TPU `_, `Stable Diffusion `_, `Detectron2 `_, `Distributed `_ `TensorFlow `_, `NeMo `_, `programmatic grid search `_, `Docker `_, `Cog `_, `Unsloth `_, `Ollama `_, `llm.c `__, `Airflow `_ and `many more `_. +* Framework examples: `Vector Database `_, `PyTorch DDP `_, `DeepSpeed `_, `JAX/Flax on TPU `_, `Stable Diffusion `_, `Detectron2 `_, `Distributed `_ `TensorFlow `_, `NeMo `_, `programmatic grid search `_, `Docker `_, `Cog `_, `Unsloth `_, `Ollama `_, `llm.c `__, `Airflow `_ and `many more `_. Case Studies and Integrations: `Community Spotlights `_ diff --git a/examples/vector_database/README.md b/examples/vector_database/README.md index f127d2c176e..20581cc421c 100644 --- a/examples/vector_database/README.md +++ b/examples/vector_database/README.md @@ -4,7 +4,7 @@ VectorDB with SkyPilot

-### Large-Scale Image Search +## Large-Scale Image Search As the volume of image data grows, the need for efficient and powerful search methods becomes critical. Traditional keyword-based or metadata-based search often fails to capture the full semantic meaning in images. A vector database enables semantic search: you can find images that conceptually match a query (e.g., "a photo of a cloud") rather than relying on textual tags. In particular: @@ -17,7 +17,7 @@ SkyPilot streamlines the process of running such large-scale jobs in the cloud. Please find the complete blog post [here](https://blog.skypilot.co/large-scale-vector-database/) -### Step 0: Set Up The Environment +## Step 0: Set Up The Environment Install the following prerequisites: * SkyPilot: Make sure SkyPilot is installed and that `sky check` succeeds. Refer to [SkyPilot’s documentation](https://docs.skypilot.co/en/latest/getting-started/installation.html) for instructions. * Hugging Face Token: To download the dataset from the Hugging Face Hub, you will need your token. Follow the steps below to configure your token. @@ -28,7 +28,7 @@ HF_TOKEN=hf_xxxxx ``` or set the environment variable `HF_TOKEN`. -### Step 1: Compute Vectors from Image Data with OpenAI CLIP +## Step 1: Compute Vectors from Image Data with OpenAI CLIP You need to convert images into vector representations (embeddings) so they can be stored in a vector database. Models like [CLIP by OpenAI](https://openai.com/index/clip/) learn powerful representations that map images and text into the same embedding space. This allows for semantic similarity calculations, making queries like “a photo of a cloud” match relevant images. Use the following command to launch a job that processes your image dataset and computes the CLIP embeddings: @@ -51,7 +51,7 @@ You can also use `sky jobs queue` and `sky jobs dashboard` to see the status of SkyPilot Dashboard
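For a concrete sense of what the embedding job computes, here is a minimal sketch of the per-batch CLIP embedding logic using the Hugging Face `transformers` API. This is an illustration only, not the worker script shipped in `examples/vector_database/`; the checkpoint name and the `embed_images` helper are assumptions.

```python
# Illustrative sketch only; not the actual worker script in examples/vector_database/.
# Assumes the openai/clip-vit-base-patch32 checkpoint; `embed_images` is a hypothetical helper.
import torch
from PIL import Image
from transformers import CLIPModel, CLIPProcessor

model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")


def embed_images(image_paths):
    """Return L2-normalized CLIP embeddings for a batch of image files."""
    images = [Image.open(p).convert("RGB") for p in image_paths]
    inputs = processor(images=images, return_tensors="pt")
    with torch.no_grad():
        features = model.get_image_features(**inputs)
    # Normalize so cosine similarity reduces to a dot product at query time.
    return features / features.norm(dim=-1, keepdim=True)
```

The job launched above runs this kind of computation at scale in the cloud and writes the resulting embeddings out for Step 2; the sketch omits batching, I/O, and GPU placement for brevity.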

-### Step 2: Construct the Vector Database from Computed Embeddings +## Step 2: Construct the Vector Database from Computed Embeddings Once you have the image embeddings, you need a specialized engine to perform rapid similarity searches at scale. In this example, we use [ChromaDB](https://docs.trychroma.com/getting-started) to store and query the embeddings. This step ingests the embeddings from Step 1 into a vector database to enable real-time or near real-time search over millions of vectors. To construct the database from embeddings: @@ -68,7 +68,7 @@ Processing batches: 100%|██████████| 1/1 [00:02<00:00, 2.39 Processing files: 100%|██████████| 12/12 [00:05<00:00, 2.04it/s]/1 [00:00 Date: Thu, 13 Feb 2025 17:39:06 -0800 Subject: [PATCH 2/2] [Storage] Azure bucket sub directory is ignored if the bucket previously exists (#4706) * Set sub path for azure storage * Full path for the ux * fix variable reference * minor * enforce azure region for bucket * Delete storage even when failure happens --- sky/data/storage.py | 95 ++++++++++++--------- tests/smoke_tests/test_mount_and_storage.py | 93 +++++++++++++------- 2 files changed, 115 insertions(+), 73 deletions(-) diff --git a/sky/data/storage.py b/sky/data/storage.py index 5dce3f0a0d8..c3ccb3dfc67 100644 --- a/sky/data/storage.py +++ b/sky/data/storage.py @@ -354,7 +354,8 @@ def from_metadata(cls, metadata: StoreMetadata, **override_args): metadata.is_sky_managed), sync_on_reconstruction=override_args.get('sync_on_reconstruction', True), - # backward compatibility + # Backward compatibility + # TODO: remove the hasattr check after v0.11.0 _bucket_sub_path=override_args.get( '_bucket_sub_path', metadata._bucket_sub_path # pylint: disable=protected-access @@ -1462,6 +1463,8 @@ def batch_aws_rsync(self, set to True, the directory is created in the bucket root and contents are uploaded to it. 
""" + sub_path = (f'/{self._bucket_sub_path}' + if self._bucket_sub_path else '') def get_file_sync_command(base_dir_path, file_names): includes = ' '.join([ @@ -1469,8 +1472,6 @@ def get_file_sync_command(base_dir_path, file_names): for file_name in file_names ]) base_dir_path = shlex.quote(base_dir_path) - sub_path = (f'/{self._bucket_sub_path}' - if self._bucket_sub_path else '') sync_command = ('aws s3 sync --no-follow-symlinks --exclude="*" ' f'{includes} {base_dir_path} ' f's3://{self.name}{sub_path}') @@ -1485,8 +1486,6 @@ def get_dir_sync_command(src_dir_path, dest_dir_name): for file_name in excluded_list ]) src_dir_path = shlex.quote(src_dir_path) - sub_path = (f'/{self._bucket_sub_path}' - if self._bucket_sub_path else '') sync_command = (f'aws s3 sync --no-follow-symlinks {excludes} ' f'{src_dir_path} ' f's3://{self.name}{sub_path}/{dest_dir_name}') @@ -1500,7 +1499,7 @@ def get_dir_sync_command(src_dir_path, dest_dir_name): log_path = sky_logging.generate_tmp_logging_file_path( _STORAGE_LOG_FILE_NAME) - sync_path = f'{source_message} -> s3://{self.name}/' + sync_path = f'{source_message} -> s3://{self.name}{sub_path}/' with rich_utils.safe_status( ux_utils.spinner_message(f'Syncing {sync_path}', log_path=log_path)): @@ -1959,11 +1958,13 @@ def batch_gsutil_cp(self, copy_list = '\n'.join( os.path.abspath(os.path.expanduser(p)) for p in source_path_list) gsutil_alias, alias_gen = data_utils.get_gsutil_command() + sub_path = (f'/{self._bucket_sub_path}' + if self._bucket_sub_path else '') sync_command = (f'{alias_gen}; echo "{copy_list}" | {gsutil_alias} ' - f'cp -e -n -r -I gs://{self.name}') + f'cp -e -n -r -I gs://{self.name}{sub_path}') log_path = sky_logging.generate_tmp_logging_file_path( _STORAGE_LOG_FILE_NAME) - sync_path = f'{source_message} -> gs://{self.name}/' + sync_path = f'{source_message} -> gs://{self.name}{sub_path}/' with rich_utils.safe_status( ux_utils.spinner_message(f'Syncing {sync_path}', log_path=log_path)): @@ -1995,13 +1996,13 @@ def batch_gsutil_rsync(self, set to True, the directory is created in the bucket root and contents are uploaded to it. """ + sub_path = (f'/{self._bucket_sub_path}' + if self._bucket_sub_path else '') def get_file_sync_command(base_dir_path, file_names): sync_format = '|'.join(file_names) gsutil_alias, alias_gen = data_utils.get_gsutil_command() base_dir_path = shlex.quote(base_dir_path) - sub_path = (f'/{self._bucket_sub_path}' - if self._bucket_sub_path else '') sync_command = (f'{alias_gen}; {gsutil_alias} ' f'rsync -e -x \'^(?!{sync_format}$).*\' ' f'{base_dir_path} gs://{self.name}{sub_path}') @@ -2014,8 +2015,6 @@ def get_dir_sync_command(src_dir_path, dest_dir_name): excludes = '|'.join(excluded_list) gsutil_alias, alias_gen = data_utils.get_gsutil_command() src_dir_path = shlex.quote(src_dir_path) - sub_path = (f'/{self._bucket_sub_path}' - if self._bucket_sub_path else '') sync_command = (f'{alias_gen}; {gsutil_alias} ' f'rsync -e -r -x \'({excludes})\' {src_dir_path} ' f'gs://{self.name}{sub_path}/{dest_dir_name}') @@ -2029,7 +2028,7 @@ def get_dir_sync_command(src_dir_path, dest_dir_name): log_path = sky_logging.generate_tmp_logging_file_path( _STORAGE_LOG_FILE_NAME) - sync_path = f'{source_message} -> gs://{self.name}/' + sync_path = f'{source_message} -> gs://{self.name}{sub_path}/' with rich_utils.safe_status( ux_utils.spinner_message(f'Syncing {sync_path}', log_path=log_path)): @@ -2307,15 +2306,24 @@ def from_metadata(cls, metadata: AbstractStore.StoreMetadata, An instance of AzureBlobStore. 
""" assert isinstance(metadata, AzureBlobStore.AzureBlobStoreMetadata) - return cls(name=override_args.get('name', metadata.name), - storage_account_name=override_args.get( - 'storage_account', metadata.storage_account_name), - source=override_args.get('source', metadata.source), - region=override_args.get('region', metadata.region), - is_sky_managed=override_args.get('is_sky_managed', - metadata.is_sky_managed), - sync_on_reconstruction=override_args.get( - 'sync_on_reconstruction', True)) + # TODO: this needs to be kept in sync with the abstract + # AbstractStore.from_metadata. + return cls( + name=override_args.get('name', metadata.name), + storage_account_name=override_args.get( + 'storage_account', metadata.storage_account_name), + source=override_args.get('source', metadata.source), + region=override_args.get('region', metadata.region), + is_sky_managed=override_args.get('is_sky_managed', + metadata.is_sky_managed), + sync_on_reconstruction=override_args.get('sync_on_reconstruction', + True), + # Backward compatibility + # TODO: remove the hasattr check after v0.11.0 + _bucket_sub_path=override_args.get( + '_bucket_sub_path', + metadata._bucket_sub_path # pylint: disable=protected-access + ) if hasattr(metadata, '_bucket_sub_path') else None) def get_metadata(self) -> AzureBlobStoreMetadata: return self.AzureBlobStoreMetadata( @@ -2795,6 +2803,8 @@ def batch_az_blob_sync(self, set to True, the directory is created in the bucket root and contents are uploaded to it. """ + container_path = (f'{self.container_name}/{self._bucket_sub_path}' + if self._bucket_sub_path else self.container_name) def get_file_sync_command(base_dir_path, file_names) -> str: # shlex.quote is not used for file_names as 'az storage blob sync' @@ -2803,8 +2813,6 @@ def get_file_sync_command(base_dir_path, file_names) -> str: includes_list = ';'.join(file_names) includes = f'--include-pattern "{includes_list}"' base_dir_path = shlex.quote(base_dir_path) - container_path = (f'{self.container_name}/{self._bucket_sub_path}' - if self._bucket_sub_path else self.container_name) sync_command = (f'az storage blob sync ' f'--account-name {self.storage_account_name} ' f'--account-key {self.storage_account_key} ' @@ -2822,18 +2830,17 @@ def get_dir_sync_command(src_dir_path, dest_dir_name) -> str: [file_name.rstrip('*') for file_name in excluded_list]) excludes = f'--exclude-path "{excludes_list}"' src_dir_path = shlex.quote(src_dir_path) - container_path = (f'{self.container_name}/{self._bucket_sub_path}' - if self._bucket_sub_path else - f'{self.container_name}') if dest_dir_name: - container_path = f'{container_path}/{dest_dir_name}' + dest_dir_name = f'/{dest_dir_name}' + else: + dest_dir_name = '' sync_command = (f'az storage blob sync ' f'--account-name {self.storage_account_name} ' f'--account-key {self.storage_account_key} ' f'{excludes} ' '--delete-destination false ' f'--source {src_dir_path} ' - f'--container {container_path}') + f'--container {container_path}{dest_dir_name}') return sync_command # Generate message for upload @@ -2844,7 +2851,7 @@ def get_dir_sync_command(src_dir_path, dest_dir_name) -> str: source_message = source_path_list[0] container_endpoint = data_utils.AZURE_CONTAINER_URL.format( storage_account_name=self.storage_account_name, - container_name=self.name) + container_name=container_path) log_path = sky_logging.generate_tmp_logging_file_path( _STORAGE_LOG_FILE_NAME) sync_path = f'{source_message} -> {container_endpoint}/' @@ -3238,6 +3245,8 @@ def batch_aws_rsync(self, set to True, the 
directory is created in the bucket root and contents are uploaded to it. """ + sub_path = (f'/{self._bucket_sub_path}' + if self._bucket_sub_path else '') def get_file_sync_command(base_dir_path, file_names): includes = ' '.join([ @@ -3246,8 +3255,6 @@ def get_file_sync_command(base_dir_path, file_names): ]) endpoint_url = cloudflare.create_endpoint() base_dir_path = shlex.quote(base_dir_path) - sub_path = (f'/{self._bucket_sub_path}' - if self._bucket_sub_path else '') sync_command = ('AWS_SHARED_CREDENTIALS_FILE=' f'{cloudflare.R2_CREDENTIALS_PATH} ' 'aws s3 sync --no-follow-symlinks --exclude="*" ' @@ -3267,8 +3274,6 @@ def get_dir_sync_command(src_dir_path, dest_dir_name): ]) endpoint_url = cloudflare.create_endpoint() src_dir_path = shlex.quote(src_dir_path) - sub_path = (f'/{self._bucket_sub_path}' - if self._bucket_sub_path else '') sync_command = ('AWS_SHARED_CREDENTIALS_FILE=' f'{cloudflare.R2_CREDENTIALS_PATH} ' f'aws s3 sync --no-follow-symlinks {excludes} ' @@ -3286,7 +3291,7 @@ def get_dir_sync_command(src_dir_path, dest_dir_name): log_path = sky_logging.generate_tmp_logging_file_path( _STORAGE_LOG_FILE_NAME) - sync_path = f'{source_message} -> r2://{self.name}/' + sync_path = f'{source_message} -> r2://{self.name}{sub_path}/' with rich_utils.safe_status( ux_utils.spinner_message(f'Syncing {sync_path}', log_path=log_path)): @@ -3710,6 +3715,8 @@ def batch_ibm_rsync(self, set to True, the directory is created in the bucket root and contents are uploaded to it. """ + sub_path = (f'/{self._bucket_sub_path}' + if self._bucket_sub_path else '') def get_dir_sync_command(src_dir_path, dest_dir_name) -> str: """returns an rclone command that copies a complete folder @@ -3731,8 +3738,6 @@ def get_dir_sync_command(src_dir_path, dest_dir_name) -> str: # .git directory is excluded from the sync # wrapping src_dir_path with "" to support path with spaces src_dir_path = shlex.quote(src_dir_path) - sub_path = (f'/{self._bucket_sub_path}' - if self._bucket_sub_path else '') sync_command = ( 'rclone copy --exclude ".git/*" ' f'{src_dir_path} ' @@ -3763,8 +3768,6 @@ def get_file_sync_command(base_dir_path, file_names) -> str: for file_name in file_names ]) base_dir_path = shlex.quote(base_dir_path) - sub_path = (f'/{self._bucket_sub_path}' - if self._bucket_sub_path else '') sync_command = ( 'rclone copy ' f'{includes} {base_dir_path} ' @@ -3779,7 +3782,8 @@ def get_file_sync_command(base_dir_path, file_names) -> str: log_path = sky_logging.generate_tmp_logging_file_path( _STORAGE_LOG_FILE_NAME) - sync_path = f'{source_message} -> cos://{self.region}/{self.name}/' + sync_path = ( + f'{source_message} -> cos://{self.region}/{self.name}{sub_path}/') with rich_utils.safe_status( ux_utils.spinner_message(f'Syncing {sync_path}', log_path=log_path)): @@ -4178,15 +4182,21 @@ def batch_oci_rsync(self, set to True, the directory is created in the bucket root and contents are uploaded to it. 
""" + sub_path = (f'{self._bucket_sub_path}/' + if self._bucket_sub_path else '') @oci.with_oci_env def get_file_sync_command(base_dir_path, file_names): includes = ' '.join( [f'--include "{file_name}"' for file_name in file_names]) + prefix_arg = '' + if sub_path: + prefix_arg = f'--object-prefix "{sub_path.strip("/")}"' sync_command = ( 'oci os object bulk-upload --no-follow-symlinks --overwrite ' f'--bucket-name {self.name} --namespace-name {self.namespace} ' f'--region {self.region} --src-dir "{base_dir_path}" ' + f'{prefix_arg} ' f'{includes}') return sync_command @@ -4207,7 +4217,8 @@ def get_dir_sync_command(src_dir_path, dest_dir_name): sync_command = ( 'oci os object bulk-upload --no-follow-symlinks --overwrite ' f'--bucket-name {self.name} --namespace-name {self.namespace} ' - f'--region {self.region} --object-prefix "{dest_dir_name}" ' + f'--region {self.region} ' + f'--object-prefix "{sub_path}{dest_dir_name}" ' f'--src-dir "{src_dir_path}" {excludes}') return sync_command @@ -4220,7 +4231,7 @@ def get_dir_sync_command(src_dir_path, dest_dir_name): log_path = sky_logging.generate_tmp_logging_file_path( _STORAGE_LOG_FILE_NAME) - sync_path = f'{source_message} -> oci://{self.name}/' + sync_path = f'{source_message} -> oci://{self.name}/{sub_path}' with rich_utils.safe_status( ux_utils.spinner_message(f'Syncing {sync_path}', log_path=log_path)): diff --git a/tests/smoke_tests/test_mount_and_storage.py b/tests/smoke_tests/test_mount_and_storage.py index c7f3e356c0d..13c3e118f75 100644 --- a/tests/smoke_tests/test_mount_and_storage.py +++ b/tests/smoke_tests/test_mount_and_storage.py @@ -854,16 +854,18 @@ def yield_storage_object( persistent=persistent, mode=mode, _bucket_sub_path=_bucket_sub_path) - yield storage_obj - handle = global_user_state.get_handle_from_storage_name( - storage_obj.name) - if handle: - # If handle exists, delete manually - # TODO(romilb): This is potentially risky - if the delete method has - # bugs, this can cause resource leaks. Ideally we should manually - # eject storage from global_user_state and delete the bucket using - # boto3 directly. - storage_obj.delete() + try: + yield storage_obj + finally: + handle = global_user_state.get_handle_from_storage_name( + storage_obj.name) + if handle: + # If handle exists, delete manually + # TODO(romilb): This is potentially risky - if the delete method has + # bugs, this can cause resource leaks. Ideally we should manually + # eject storage from global_user_state and delete the bucket using + # boto3 directly. + storage_obj.delete() @pytest.fixture def tmp_scratch_storage_obj(self, tmp_bucket_name): @@ -881,17 +883,19 @@ def tmp_multiple_scratch_storage_obj(self): timestamp = str(time.time()).replace('.', '') store_obj = storage_lib.Storage(name=f'sky-test-{timestamp}') storage_mult_obj.append(store_obj) - yield storage_mult_obj - for storage_obj in storage_mult_obj: - handle = global_user_state.get_handle_from_storage_name( - storage_obj.name) - if handle: - # If handle exists, delete manually - # TODO(romilb): This is potentially risky - if the delete method has - # bugs, this can cause resource leaks. Ideally we should manually - # eject storage from global_user_state and delete the bucket using - # boto3 directly. 
- storage_obj.delete() + try: + yield storage_mult_obj + finally: + for storage_obj in storage_mult_obj: + handle = global_user_state.get_handle_from_storage_name( + storage_obj.name) + if handle: + # If handle exists, delete manually + # TODO(romilb): This is potentially risky - if the delete method has + # bugs, this can cause resource leaks. Ideally we should manually + # eject storage from global_user_state and delete the bucket using + # boto3 directly. + storage_obj.delete() @pytest.fixture def tmp_multiple_custom_source_storage_obj(self): @@ -907,12 +911,14 @@ def tmp_multiple_custom_source_storage_obj(self): store_obj = storage_lib.Storage(name=f'sky-test-{timestamp}', source=src_path) storage_mult_obj.append(store_obj) - yield storage_mult_obj - for storage_obj in storage_mult_obj: - handle = global_user_state.get_handle_from_storage_name( - storage_obj.name) - if handle: - storage_obj.delete() + try: + yield storage_mult_obj + finally: + for storage_obj in storage_mult_obj: + handle = global_user_state.get_handle_from_storage_name( + storage_obj.name) + if handle: + storage_obj.delete() @pytest.fixture def tmp_local_storage_obj(self, tmp_bucket_name, tmp_source): @@ -1099,7 +1105,14 @@ def test_bucket_sub_path(self, tmp_local_storage_obj_with_sub_path, store_type): # Creates a new bucket with a local source, uploads files to it # and deletes it. - tmp_local_storage_obj_with_sub_path.add_store(store_type) + region_kwargs = {} + if store_type == storage_lib.StoreType.AZURE: + # We have to specify the region for Azure storage, as the default + # Azure storage account is in centralus region. + region_kwargs['region'] = 'centralus' + + tmp_local_storage_obj_with_sub_path.add_store(store_type, + **region_kwargs) # Check files under bucket and filter by prefix files = self.list_all_files(store_type, @@ -1412,7 +1425,13 @@ def test_upload_to_existing_bucket(self, ext_bucket_fixture, request, # sky) and verifies that files are written. bucket_name, _ = request.getfixturevalue(ext_bucket_fixture) storage_obj = storage_lib.Storage(name=bucket_name, source=tmp_source) - storage_obj.add_store(store_type) + region_kwargs = {} + if store_type == storage_lib.StoreType.AZURE: + # We have to specify the region for Azure storage, as the default + # Azure storage account is in centralus region. + region_kwargs['region'] = 'centralus' + + storage_obj.add_store(store_type, **region_kwargs) # Check if tmp_source/tmp-file exists in the bucket using aws cli out = subprocess.check_output(self.cli_ls_cmd(store_type, bucket_name), @@ -1458,7 +1477,13 @@ def test_copy_mount_existing_storage(self, def test_list_source(self, tmp_local_list_storage_obj, store_type): # Uses a list in the source field to specify a file and a directory to # be uploaded to the storage object. - tmp_local_list_storage_obj.add_store(store_type) + region_kwargs = {} + if store_type == storage_lib.StoreType.AZURE: + # We have to specify the region for Azure storage, as the default + # Azure storage account is in centralus region. 
+ region_kwargs['region'] = 'centralus' + + tmp_local_list_storage_obj.add_store(store_type, **region_kwargs) # Check if tmp-file exists in the bucket root using cli out = subprocess.check_output(self.cli_ls_cmd( @@ -1513,7 +1538,13 @@ def test_excluded_file_cloud_storage_upload_copy(self, gitignore_structure, tmp_gitignore_storage_obj): # tests if files included in .gitignore and .git/info/exclude are # excluded from being transferred to Storage - tmp_gitignore_storage_obj.add_store(store_type) + region_kwargs = {} + if store_type == storage_lib.StoreType.AZURE: + # We have to specify the region for Azure storage, as the default + # Azure storage account is in centralus region. + region_kwargs['region'] = 'centralus' + + tmp_gitignore_storage_obj.add_store(store_type, **region_kwargs) upload_file_name = 'included' # Count the number of files with the given file name up_cmd = self.cli_count_name_in_bucket(store_type, \