diff --git a/snakemake_storage_plugin_azure/__init__.py b/snakemake_storage_plugin_azure/__init__.py
index c7f8b7a..e89831c 100644
--- a/snakemake_storage_plugin_azure/__init__.py
+++ b/snakemake_storage_plugin_azure/__init__.py
@@ -1,5 +1,6 @@
 import os
 from dataclasses import dataclass, field
+from pathlib import Path
 from typing import Iterable, List, Optional
 from urllib.parse import urlparse
 
@@ -9,9 +10,19 @@
     EnvironmentCredential,
     ManagedIdentityCredential,
 )
-from azure.storage.blob import BlobClient, BlobServiceClient, ContainerClient
+from azure.storage.blob import (
+    BlobClient,
+    BlobProperties,
+    BlobServiceClient,
+    ContainerClient,
+)
+from snakemake_interface_common.exceptions import WorkflowError
 from snakemake_interface_storage_plugins.common import Operation
-from snakemake_interface_storage_plugins.io import IOCacheStorageInterface, Mtime
+from snakemake_interface_storage_plugins.io import (
+    IOCacheStorageInterface,
+    Mtime,
+    get_constant_prefix,
+)
 from snakemake_interface_storage_plugins.settings import StorageProviderSettingsBase
 from snakemake_interface_storage_plugins.storage_object import (
     StorageObjectGlob,
@@ -82,7 +93,7 @@ def __post_init__(self):
         # use mock storage credential for tests
         test_credential = os.getenv("AZURITE_CONNECTION_STRING")
         if test_credential:
-            self.blob_account_client = BlobServiceClient.from_connection_string(
+            self.bsc: BlobServiceClient = BlobServiceClient.from_connection_string(
                 test_credential
             )
         else:
@@ -94,7 +105,7 @@ def __post_init__(self):
                 ManagedIdentityCredential(),
                 EnvironmentCredential(),
             )
-            self.blob_account_client = BlobServiceClient(
+            self.bsc = BlobServiceClient(
                 endpoint_url, credential=ChainedTokenCredential(*credential_chain)
             )
 
@@ -179,9 +190,7 @@ def list_objects(self) -> Iterable[str]:
 
         This is optional and can raise a NotImplementedError() instead.
         """
-        cc = self.blob_account_client.get_container_client(
-            self.get_storage_container_name()
-        )
+        cc = self.bsc.get_container_client(self.get_storage_container_name())
 
         return [o for o in cc.list_blob_names()]
 
@@ -199,22 +208,22 @@ def __post_init__(self):
         # This is optional and can be removed if not needed.
         # Alternatively, you can e.g. prepare a connection to your storage backend here.
         # and set additional attributes.
-        self.blob_account_client: BlobServiceClient = self.provider.blob_account_client
+        self.bsc: BlobServiceClient = self.provider.bsc
         if self.is_valid_query():
             parsed = urlparse(self.query)
             self.container_name = parsed.netloc
             self.blob_path = parsed.path.lstrip("/")
             self._local_suffix = self._local_suffix_from_key(self.blob_path)
+            self._is_dir = None
 
     def container_client(self) -> ContainerClient:
         """Return initialized ContainerClient."""
-        return self.blob_account_client.get_container_client(self.container_name)
+        return self.bsc.get_container_client(self.container_name)
 
-    def blob_client(self) -> BlobClient:
+    def blob_client(self, blob_path: Optional[str] = None) -> BlobClient:
         """Return initialized BlobClient."""
-        return self.blob_account_client.get_blob_client(
-            self.container_name, self.blob_path
-        )
+        path = blob_path if blob_path else self.blob_path
+        return self.bsc.get_blob_client(self.container_name, path)
 
     async def inventory(self, cache: IOCacheStorageInterface):
         """From this file, try to find as much existence and modification date
@@ -241,6 +250,21 @@ async def inventory(self, cache: IOCacheStorageInterface):
             cache.size[key] = o.size
             cache.exists_remote[key] = True
 
+    def is_dir(self) -> bool:
+        """Return True if the query has blobs under its prefix."""
+        if self._is_dir is None:
+            self._is_dir = any(self.get_prefix_blobs())
+        return self._is_dir
+
+    def get_prefix_blobs(
+        self, prefix: Optional[str] = None
+    ) -> Iterable[BlobProperties]:
+        """Return an iterator over the blobs in storage that match the given
+        prefix (by default, the query's blob path plus a trailing slash)."""
+        prefix = f"{self.blob_path}/" if prefix is None else prefix
+        return (
+            item
+            for item in self.container_client().list_blobs(name_starts_with=prefix)
+            if item.name != prefix
+        )
+
     def get_inventory_parent(self) -> Optional[str]:
         """Return the parent directory of this object."""
         # this is optional and can be left as is
@@ -262,25 +286,45 @@ def cleanup(self):
     # provided by snakemake-interface-storage-plugins.
     @retry_decorator
     def exists(self) -> bool:
-        """Return True if the object exists."""
+        """Return True if the blob exists, or if the container exists and the
+        query path is a directory prefix; otherwise False."""
         if not self.container_client().exists():
             return False
-        else:
-            return self.blob_client().exists()
+
+        # the query is a "directory" if any blobs exist under its prefix
+        if self.is_dir():
+            return True
+
+        return self.blob_client().exists()
 
     @retry_decorator
     def mtime(self) -> float:
         """Returns the modification time."""
+        if self.is_dir():
+            # for a directory, report the newest blob under the prefix
+            return max(
+                item.last_modified.timestamp() for item in self.get_prefix_blobs()
+            )
         return self.blob_client().get_blob_properties().last_modified.timestamp()
 
     @retry_decorator
     def size(self) -> int:
         """Returns the size in bytes."""
+        if self.is_dir():
+            # for a directory, report the total size of the blobs under the prefix
+            return sum(item.size for item in self.get_prefix_blobs())
        return self.blob_client().get_blob_properties().size
 
     @retry_decorator
     def retrieve_object(self):
-        self.download_blob_from_storage()
+        if self.is_dir():
+            self.local_path().mkdir(parents=True, exist_ok=True)
+            for item in self.get_prefix_blobs():
+                # strip the query prefix to rebuild the relative layout locally
+                name = item.name[len(self.blob_path) :].lstrip("/")
+                local_path = self.local_path() / name
+                local_path.parent.mkdir(parents=True, exist_ok=True)
+                self.download_blob_from_storage(item.name, local_path)
+        else:
+            self.download_blob_from_storage()
+
         # Ensure that the object is accessible locally under self.local_path()
         if not self.local_path().exists():
             raise FileNotFoundError(
@@ -294,27 +338,46 @@ def store_object(self):
         """
         Stores the local object in cloud storage.
 
+        If the local object is a directory, the directory is uploaded to storage.
+        If the storage container does not exist, it is created; note that this
+        requires a credential with container-create permissions.
         """
         if not self.container_client().exists():
-            self.blob_account_client.create_container(self.container_name)
-
-        # Ensure that the object is stored at the location
-        # specified by self.local_path().
-        if self.local_path().exists():
-            self.upload_blob_to_storage()
-
-    def upload_blob_to_storage(self):
-        """Uploads the blob to storage, opening a connection and streaming the bytes."""
-        with open(self.local_path(), "rb") as data:
-            self.blob_client().upload_blob(data, overwrite=True)
+            self.bsc.create_container(self.container_name)
+
+        if self.local_path().is_dir():
+            self._is_dir = True
+            for item in self.local_path().rglob("*"):
+                if item.is_file():
+                    # name each blob by the file's path relative to the local dir
+                    path = Path(self.blob_path) / item.relative_to(self.local_path())
+                    self.upload_blob_to_storage(item, path)
+        else:
+            # Ensure that the object is stored at the location
+            # specified by self.local_path().
+            if self.local_path().exists():
+                self.upload_blob_to_storage(self.local_path(), Path(self.blob_path))
+
+    def upload_blob_to_storage(self, local_path: Path, remote_path: Path):
+        """Uploads the file at local_path to the storage location remote_path,
+        opening a connection and streaming the bytes. Raises FileNotFoundError
+        if the local file does not exist."""
+        if not local_path.exists():
+            raise FileNotFoundError(f"File {local_path} not found.")
+
+        with open(str(local_path), "rb") as data:
+            self.blob_client(blob_path=str(remote_path)).upload_blob(
+                data, overwrite=True
+            )
 
-    def download_blob_from_storage(self):
+    def download_blob_from_storage(
+        self, blob_path: Optional[str] = None, local_path: Optional[Path] = None
+    ):
         """Downloads the blob from storage, opening connection and streaming the bytes."""
-        with open(self.local_path(), "wb") as data:
-            data.write(self.blob_client().download_blob().readall())
+        file_path = self.local_path() if local_path is None else local_path
+        blob_path = self.blob_path if blob_path is None else blob_path
+        with open(str(file_path), "wb") as data:
+            data.write(self.blob_client(blob_path=blob_path).download_blob().readall())
 
     @retry_decorator
     def remove(self):
@@ -329,4 +392,14 @@ def list_candidate_matches(self) -> Iterable[str]:
         """Return a list of candidate matches in the storage for the query."""
         # This is used by glob_wildcards() to find matches for wildcards in the query.
         # The method has to return concretized queries without any remaining wildcards.
-        ...
+        prefix = get_constant_prefix(self.query)
+        container_prefix = f"{urlparse(self.query).scheme}://{self.container_name}"
+        if prefix.startswith(container_prefix):
+            prefix = prefix[len(container_prefix) :].lstrip("/")
+            return (item.name for item in self.get_prefix_blobs(prefix=prefix))
+        else:
+            raise WorkflowError(
+                f"Azure storage object {self.query} cannot be used to list matching "
+                "objects because the container name contains a wildcard, which is "
+                "not supported."
+            )
diff --git a/tests/tests.py b/tests/tests.py
index d74f277..480262a 100644
--- a/tests/tests.py
+++ b/tests/tests.py
@@ -17,7 +17,10 @@ class TestStorageNoSettings(TestStorageBase):
     __test__ = True
-    retrieve_only = True
+    retrieve_only = False
+    store_only = False
+    delete = True
+    files_only = False
 
     def get_query_not_existing(self, tmp_path) -> str:
         container = uuid.uuid4().hex
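
For reviewers who want to exercise the new directory semantics outside the Snakemake test harness, here is a minimal sketch of the prefix-listing idea that is_dir(), size(), mtime(), and get_prefix_blobs() build on, run against a local Azurite emulator. The container name, blob names, and payloads are illustrative assumptions, not part of this diff; only AZURITE_CONNECTION_STRING mirrors the test hook in the provider's __post_init__.

    import os

    from azure.storage.blob import BlobServiceClient

    # Assumes a local Azurite emulator reachable through AZURITE_CONNECTION_STRING,
    # mirroring the test hook used by the provider above.
    bsc = BlobServiceClient.from_connection_string(
        os.environ["AZURITE_CONNECTION_STRING"]
    )
    cc = bsc.get_container_client("demo")  # hypothetical container name
    if not cc.exists():
        cc.create_container()

    # A "directory" in blob storage is nothing more than a shared name prefix.
    for name in ("results/a.txt", "results/sub/b.txt"):
        cc.upload_blob(name, data=b"hello", overwrite=True)

    # Mirrors get_prefix_blobs(): list everything under the prefix, skipping any
    # blob whose name equals the prefix itself.
    prefix = "results/"
    blobs = [b for b in cc.list_blobs(name_starts_with=prefix) if b.name != prefix]

    print(any(blobs))                                       # is_dir() analogue -> True
    print(sum(b.size for b in blobs))                       # size() analogue -> 10
    print(max(b.last_modified.timestamp() for b in blobs))  # mtime() analogue

The same round trip is what the widened tests above now cover: with retrieve_only/store_only/files_only set to False and delete set to True, the shared test base stores, retrieves, and removes directory queries as well as single blobs.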