directory support
jakevc committed Aug 4, 2024
1 parent cfd0650 commit 7862d94
Showing 2 changed files with 108 additions and 33 deletions.
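In short: the plugin can now treat a blob-name prefix as a directory, so directory() outputs can be stored to and retrieved from Azure Blob Storage, and wildcard queries can be globbed. A minimal, hypothetical Snakefile sketch of what this enables (account, container, and the exact composition of directory() and the storage directive are assumptions, not part of this commit):

    # Snakefile (hypothetical names; assumes Snakemake >= 8 with this plugin)
    storage:
        provider="azure",
        endpoint_url="https://myaccount.blob.core.windows.net",

    rule make_tree:
        output:
            directory(storage.azure("az://mycontainer/results/tree"))
        shell:
            "mkdir -p {output} && echo a > {output}/a.txt && echo b > {output}/b.txt"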
136 changes: 104 additions & 32 deletions snakemake_storage_plugin_azure/__init__.py
@@ -1,5 +1,6 @@
 import os
 from dataclasses import dataclass, field
+from pathlib import Path
 from typing import Iterable, List, Optional
 from urllib.parse import urlparse

@@ -9,9 +10,19 @@
     EnvironmentCredential,
     ManagedIdentityCredential,
 )
-from azure.storage.blob import BlobClient, BlobServiceClient, ContainerClient
+from azure.storage.blob import (
+    BlobClient,
+    BlobProperties,
+    BlobServiceClient,
+    ContainerClient,
+)
 from snakemake_interface_common.exceptions import WorkflowError
 from snakemake_interface_storage_plugins.common import Operation
-from snakemake_interface_storage_plugins.io import IOCacheStorageInterface, Mtime
+from snakemake_interface_storage_plugins.io import (
+    IOCacheStorageInterface,
+    Mtime,
+    get_constant_prefix,
+)
 from snakemake_interface_storage_plugins.settings import StorageProviderSettingsBase
 from snakemake_interface_storage_plugins.storage_object import (
     StorageObjectGlob,
@@ -82,7 +93,7 @@ def __post_init__(self):
         # use mock storage credential for tests
         test_credential = os.getenv("AZURITE_CONNECTION_STRING")
         if test_credential:
-            self.blob_account_client = BlobServiceClient.from_connection_string(
+            self.bsc: BlobServiceClient = BlobServiceClient.from_connection_string(
                 test_credential
             )
         else:
@@ -94,7 +105,7 @@ def __post_init__(self):
                 ManagedIdentityCredential(),
                 EnvironmentCredential(),
             )
-            self.blob_account_client = BlobServiceClient(
+            self.bsc = BlobServiceClient(
                 endpoint_url, credential=ChainedTokenCredential(*credential_chain)
             )

@@ -179,9 +190,7 @@ def list_objects(self) -> Iterable[str]:
         This is optional and can raise a NotImplementedError() instead.
         """
-        cc = self.blob_account_client.get_container_client(
-            self.get_storage_container_name()
-        )
+        cc = self.bsc.get_container_client(self.get_storage_container_name())
         return [o for o in cc.list_blob_names()]


@@ -199,22 +208,22 @@ def __post_init__(self):
         # This is optional and can be removed if not needed.
         # Alternatively, you can e.g. prepare a connection to your storage backend here.
         # and set additional attributes.
-        self.blob_account_client: BlobServiceClient = self.provider.blob_account_client
+        self.bsc: BlobServiceClient = self.provider.bsc
         if self.is_valid_query():
             parsed = urlparse(self.query)
             self.container_name = parsed.netloc
             self.blob_path = parsed.path.lstrip("/")
             self._local_suffix = self._local_suffix_from_key(self.blob_path)
+            self._is_dir = None

     def container_client(self) -> ContainerClient:
         """Return initialized ContainerClient."""
-        return self.blob_account_client.get_container_client(self.container_name)
+        return self.bsc.get_container_client(self.container_name)

-    def blob_client(self) -> BlobClient:
+    def blob_client(self, blob_path=None) -> BlobClient:
         """Return initialized BlobClient."""
-        return self.blob_account_client.get_blob_client(
-            self.container_name, self.blob_path
-        )
+        path = blob_path if blob_path else self.blob_path
+        return self.bsc.get_blob_client(self.container_name, path)

     async def inventory(self, cache: IOCacheStorageInterface):
         """From this file, try to find as much existence and modification date
@@ -241,6 +250,21 @@ async def inventory(self, cache: IOCacheStorageInterface):
             cache.size[key] = o.size
             cache.exists_remote[key] = True

+    def is_dir(self):
+        """Return True if the query has blobs under its prefix."""
+        if self._is_dir is None:
+            self._is_dir = any(self.get_prefix_blobs())
+        return self._is_dir
+
+    def get_prefix_blobs(
+        self, prefix: Optional[str] = None
+    ) -> Iterable[BlobProperties]:
+        """Return an iterator over blobs in the storage that match the given
+        prefix (the query prefix by default)."""
+        if prefix is None:
+            prefix = self.blob_path + "/"
+        return (
+            item
+            for item in self.container_client().list_blobs(name_starts_with=prefix)
+            if item.name != prefix
+        )
+
     def get_inventory_parent(self) -> Optional[str]:
         """Return the parent directory of this object."""
         # this is optional and can be left as is
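Blob storage has no real directories: is_dir() and get_prefix_blobs() emulate them by asking whether any blob name starts with the query path plus "/". The same idea expressed directly against the SDK (connection string and names are placeholders):

    from azure.storage.blob import BlobServiceClient

    bsc = BlobServiceClient.from_connection_string("<AZURITE_CONNECTION_STRING>")
    cc = bsc.get_container_client("mycontainer")

    # "results/tree/a.txt" and "results/tree/sub/b.txt" are flat blob names;
    # listing with name_starts_with makes the prefix behave like a directory.
    for blob in cc.list_blobs(name_starts_with="results/tree/"):
        print(blob.name, blob.size, blob.last_modified)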
@@ -262,25 +286,45 @@ def cleanup(self):
     # provided by snakemake-interface-storage-plugins.
     @retry_decorator
     def exists(self) -> bool:
-        """Return True if the object exists."""
+        """Return True if the object exists, or if the container exists and the
+        path is a directory, otherwise False."""
         if not self.container_client().exists():
             return False
-        else:
-            return self.blob_client().exists()
+
+        # the blob is a directory
+        if self.is_dir():
+            return True
+
+        return self.blob_client().exists()

     @retry_decorator
     def mtime(self) -> float:
         """Returns the modification time."""
+        if self.is_dir():
+            return max(
+                item.last_modified.timestamp() for item in self.get_prefix_blobs()
+            )
         return self.blob_client().get_blob_properties().last_modified.timestamp()

     @retry_decorator
     def size(self) -> int:
         """Returns the size in bytes."""
+        if self.is_dir():
+            return sum(item.size for item in self.get_prefix_blobs())
         return self.blob_client().get_blob_properties().size

     @retry_decorator
     def retrieve_object(self):
-        self.download_blob_from_storage()
+        if self.is_dir():
+            self.local_path().mkdir(parents=True, exist_ok=True)
+            for item in self.get_prefix_blobs():
+                name = item.name[len(self.blob_path) :].lstrip("/")
+                local_path = self.local_path() / name
+                local_path.parent.mkdir(parents=True, exist_ok=True)
+                self.download_blob_from_storage(item.name, local_path)
+        else:
+            self.download_blob_from_storage()

         # Ensure that the object is accessible locally under self.local_path()
         if not self.local_path().exists():
             raise FileNotFoundError(
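For directory queries, mtime() and size() aggregate over all blobs under the prefix: the newest blob's timestamp becomes the directory mtime, and the sizes are summed. A toy illustration with made-up values:

    from datetime import datetime, timezone
    from types import SimpleNamespace

    blobs = [
        SimpleNamespace(size=10, last_modified=datetime(2024, 8, 1, tzinfo=timezone.utc)),
        SimpleNamespace(size=20, last_modified=datetime(2024, 8, 3, tzinfo=timezone.utc)),
    ]
    dir_mtime = max(b.last_modified.timestamp() for b in blobs)  # newest blob wins
    dir_size = sum(b.size for b in blobs)  # 30 bytes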
@@ -294,27 +338,46 @@ def store_object(self):
         """
         Stores the local object in cloud storage.
         If the local object is a directory, the directory is uploaded to the storage.
+        If the storage container does not exist, it is created; this creates the
+        dependency that the provided credential has container create permissions.
         """
         if not self.container_client().exists():
-            self.blob_account_client.create_container(self.container_name)
+            self.bsc.create_container(self.container_name)

-        # Ensure that the object is stored at the location
-        # specified by self.local_path().
-        if self.local_path().exists():
-            self.upload_blob_to_storage()
-
-    def upload_blob_to_storage(self):
-        """Uploads the blob to storage, opening a connection and streaming the bytes."""
-        with open(self.local_path(), "rb") as data:
-            self.blob_client().upload_blob(data, overwrite=True)
+        if self.local_path().is_dir():
+            self._is_dir = True
+            for item in self.local_path().rglob("*"):
+                if item.is_file():
+                    path = Path(self.blob_path) / item.relative_to(self.local_path())
+                    self.upload_blob_to_storage(item, path)
+        else:
+            # Ensure that the object is stored at the location
+            # specified by self.local_path().
+            if self.local_path().exists():
+                self.upload_blob_to_storage(self.local_path(), Path(self.blob_path))
+
+    def upload_blob_to_storage(self, local_path: Path, remote_path: Path):
+        """Upload the file at local_path to the blob at remote_path, opening a
+        connection and streaming the bytes; raise FileNotFoundError if the local
+        file does not exist."""
+        if not local_path.exists():
+            raise FileNotFoundError(f"File {local_path} not found.")
+
+        with open(str(local_path), "rb") as data:
+            self.blob_client(blob_path=str(remote_path)).upload_blob(
+                data, overwrite=True
+            )

-    def download_blob_from_storage(self):
+    def download_blob_from_storage(
+        self, blob_path: Optional[str] = None, local_path: Optional[Path] = None
+    ):
         """Downloads the blob from storage,
         opening connection and streaming the bytes."""
-        with open(self.local_path(), "wb") as data:
-            data.write(self.blob_client().download_blob().readall())
+        file_path = self.local_path() if local_path is None else local_path
+        blob_path = self.blob_path if blob_path is None else blob_path
+        with open(str(file_path), "wb") as data:
+            data.write(self.blob_client(blob_path=blob_path).download_blob().readall())

     @retry_decorator
     def remove(self):
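The directory upload walks local_path() with rglob("*") and mirrors each file's path relative to the local directory under the query's blob path; the download branch inverts the same mapping. The mapping itself is plain path arithmetic (paths hypothetical):

    from pathlib import Path

    blob_path = "results/tree"
    local_dir = Path("/tmp/wd/results/tree")
    local_file = local_dir / "sub" / "b.txt"

    remote = Path(blob_path) / local_file.relative_to(local_dir)
    print(remote.as_posix())  # results/tree/sub/b.txt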
@@ -329,4 +392,13 @@ def list_candidate_matches(self) -> Iterable[str]:
         """Return a list of candidate matches in the storage for the query."""
         # This is used by glob_wildcards() to find matches for wildcards in the query.
         # The method has to return concretized queries without any remaining wildcards.
-        ...
+        prefix = get_constant_prefix(self.query)
+        if prefix.startswith(self.container_name):
+            prefix = prefix[len(self.container_name) :]
+            return (item.name for item in self.get_prefix_blobs(prefix=prefix))
+        else:
+            raise WorkflowError(
+                f"Azure storage object {self.query} cannot be used to list "
+                "matching objects because the container name contains a "
+                "wildcard, which is not supported."
+            )
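list_candidate_matches() is what glob_wildcards() consults, so wildcard discovery now works against Azure queries as well. A hypothetical sketch, assuming the Snakemake >= 8 storage syntax:

    # Snakefile (hypothetical)
    samples = glob_wildcards(
        storage.azure("az://mycontainer/results/{sample}.txt")
    ).sample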
5 changes: 4 additions & 1 deletion tests/tests.py
@@ -17,7 +17,10 @@

 class TestStorageNoSettings(TestStorageBase):
     __test__ = True
-    retrieve_only = True
+    retrieve_only = False
+    store_only = False
+    delete = True
+    files_only = False

     def get_query_not_existing(self, tmp_path) -> str:
         container = uuid.uuid4().hex
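These flags presumably come from the shared TestStorageBase suite in snakemake-interface-storage-plugins: with retrieve_only and store_only both False the full store/retrieve round trip runs, delete = True also exercises removal, and files_only = False opts the plugin into the new directory tests.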
