Skip to content

Commit

Permalink
Merge pull request #8338 from OpenMined/azure_direct_conn
Browse files Browse the repository at this point in the history
Change URL generation Azure mounted files
  • Loading branch information
koenvanderveen authored Dec 15, 2023
2 parents 7448214 + 0130ab4 commit 62b6ccc
Show file tree
Hide file tree
Showing 15 changed files with 1,417 additions and 256 deletions.
865 changes: 865 additions & 0 deletions notebooks/helm/direct_azure.ipynb

Large diffs are not rendered by default.

406 changes: 198 additions & 208 deletions notebooks/helm/docker-helm-syft.ipynb

Large diffs are not rendered by default.

5 changes: 5 additions & 0 deletions packages/grid/seaweedfs/mount_command.sh
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
# 1 = remote config name
# 2 = azure account name
# 3 = seaweedfs bucket name
# 4 = azure container name
# 5 = azure key
echo "remote.configure -name=$1 -type=azure -azure.account_name=$2 -azure.account_key=$5" | weed shell && \
echo "s3.bucket.create -name=$3" | weed shell && \
echo "remote.mount -dir=/buckets/$3 -remote=$1/$4" | weed shell && \
Expand Down
2 changes: 1 addition & 1 deletion packages/syft/setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ syft =
pandas==1.5.3
docker==6.1.3
PyYAML==6.0.1

azure-storage-blob==12.19

install_requires =
%(syft)s
Expand Down
26 changes: 26 additions & 0 deletions packages/syft/src/syft/node/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,19 @@ def init_blob_storage(self, config: Optional[BlobStorageConfig] = None) -> None:
self.blob_store_config = config_
self.blob_storage_client = config_.client_type(config=config_.client_config)

# relative
from ..store.blob_storage.seaweedfs import SeaweedFSConfig

if isinstance(config, SeaweedFSConfig):
blob_storage_service = self.get_service(BlobStorageService)
remote_profiles = blob_storage_service.remote_profile_stash.get_all(
credentials=self.signing_key.verify_key, has_permission=True
).ok()
for remote_profile in remote_profiles:
self.blob_store_config.client_config.remote_profiles[
remote_profile.profile_name
] = remote_profile

def stop(self):
for consumer_list in self.queue_manager.consumers.values():
for c in consumer_list:
Expand Down Expand Up @@ -723,6 +736,13 @@ def init_stores(
document_store = document_store_config.store_type
self.document_store_config = document_store_config

# We add the python id of the current node in order
# to create one connection per Node object in MongoClientCache
# so that we avoid closing the connection from a
# different thread through the garbage collection
if isinstance(self.document_store_config, MongoStoreConfig):
self.document_store_config.client_config.node_obj_python_id = id(self)

self.document_store = document_store(
root_verify_key=self.verify_key,
store_config=document_store_config,
Expand All @@ -746,6 +766,12 @@ def init_stores(
root_verify_key=self.verify_key,
)
elif isinstance(action_store_config, MongoStoreConfig):
# We add the python id of the current node in order
# to create one connection per Node object in MongoClientCache
# so that we avoid closing the connection from a
# different thread through the garbage collection
action_store_config.client_config.node_obj_python_id = id(self)

self.action_store = MongoActionStore(
root_verify_key=self.verify_key, store_config=action_store_config
)
Expand Down
40 changes: 40 additions & 0 deletions packages/syft/src/syft/protocol/protocol_version.json
Original file line number Diff line number Diff line change
Expand Up @@ -981,6 +981,46 @@
"hash": "fd1067b0bb9a6e630a224162ed92f4746d4d5869bc104923ec48c4b9d597594c",
"action": "add"
}
},
"SeaweedSecureFilePathLocation": {
"2": {
"version": 2,
"hash": "3ca49db7536a33d5712485164e95406000df9af2aed78e9f9fa2bb2bbbb34fe6",
"action": "add"
}
},
"AzureSecureFilePathLocation": {
"1": {
"version": 1,
"hash": "1bb15f3f9d7082779f1c9f58de94011487924cb8a8c9c2ec18fd7c161c27fd0e",
"action": "add"
}
},
"RemoteConfig": {
"1": {
"version": 1,
"hash": "ad7bc4780a8ad52e14ce68601852c93d2fe07bda489809cad7cae786d2461754",
"action": "add"
}
},
"AzureRemoteConfig": {
"1": {
"version": 1,
"hash": "c05c6caa27db4e385c642536d4b0ecabc0c71e91220d2e6ce21a2761ca68a673",
"action": "add"
}
},
"BlobRetrievalByURL": {
"2": {
"version": 2,
"hash": "8059ee03016c4d74e408dad9529e877f91829672e0cc42d8cfff9c8e14058adc",
"action": "remove"
},
"3": {
"version": 3,
"hash": "0b664100ea08413ca4ef04665ca910c2cf9535539617ea4ba33687d05cdfe747",
"action": "add"
}
}
}
}
Expand Down
35 changes: 35 additions & 0 deletions packages/syft/src/syft/service/blob_storage/remote_profile.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# relative
from ...serde.serializable import serializable
from ...store.document_store import BaseUIDStoreStash
from ...store.document_store import DocumentStore
from ...store.document_store import PartitionSettings
from ...types.syft_object import SYFT_OBJECT_VERSION_1
from ...types.syft_object import SyftObject


@serializable()
class RemoteProfile(SyftObject):
    """Base SyftObject type for remote storage profiles.

    Declares no fields of its own. ``AzureRemoteProfile`` subclasses it and
    ``RemoteProfileStash`` stores objects of this type.
    """

    # NOTE(review): the canonical name is "RemoteConfig" although the class is
    # called RemoteProfile. The protocol state records this type under
    # "RemoteConfig", so presumably the two must stay in sync -- confirm
    # before renaming either one.
    __canonical_name__ = "RemoteConfig"
    __version__ = SYFT_OBJECT_VERSION_1


@serializable()
class AzureRemoteProfile(RemoteProfile):
    """Remote profile describing one Azure account/container pair.

    All four fields are required strings.
    """

    # NOTE(review): canonical name is "AzureRemoteConfig" while the class is
    # named AzureRemoteProfile -- presumably kept for protocol compatibility;
    # confirm before renaming.
    __canonical_name__ = "AzureRemoteConfig"
    __version__ = SYFT_OBJECT_VERSION_1

    profile_name: str  # used by seaweedfs
    account_name: str  # Azure account name
    # Azure account key.
    # NOTE(review): this is a credential held as a plain str field -- confirm
    # it is not exposed through reprs, logs or API responses.
    account_key: str
    container_name: str  # Azure container name


@serializable()
class RemoteProfileStash(BaseUIDStoreStash):
    """Stash for ``RemoteProfile`` objects, built on a ``DocumentStore``."""

    object_type = RemoteProfile
    # The partition is named after RemoteProfile's canonical name.
    settings: PartitionSettings = PartitionSettings(
        name=RemoteProfile.__canonical_name__, object_type=RemoteProfile
    )

    def __init__(self, store: DocumentStore) -> None:
        """Create the stash on ``store`` by delegating to the base class."""
        super().__init__(store=store)
49 changes: 39 additions & 10 deletions packages/syft/src/syft/service/blob_storage/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@
from ...store.blob_storage.seaweedfs import SeaweedFSBlobDeposit
from ...store.document_store import DocumentStore
from ...store.document_store import UIDPartitionKey
from ...types.blob_storage import AzureSecureFilePathLocation
from ...types.blob_storage import BlobFileType
from ...types.blob_storage import BlobStorageEntry
from ...types.blob_storage import BlobStorageMetadata
from ...types.blob_storage import CreateBlobStorageEntry
from ...types.blob_storage import SecureFilePathLocation
from ...types.uid import UID
from ..context import AuthedServiceContext
from ..response import SyftError
Expand All @@ -28,6 +28,8 @@
from ..service import TYPE_TO_SERVICE
from ..service import service_method
from ..user.user_roles import GUEST_ROLE_LEVEL
from .remote_profile import AzureRemoteProfile
from .remote_profile import RemoteProfileStash
from .stash import BlobStorageStash

BlobDepositType = Union[OnDiskBlobDeposit, SeaweedFSBlobDeposit]
Expand All @@ -37,10 +39,12 @@
class BlobStorageService(AbstractService):
store: DocumentStore
stash: BlobStorageStash
remote_profile_stash: RemoteProfileStash

def __init__(self, store: DocumentStore) -> None:
self.store = store
self.stash = BlobStorageStash(store=store)
self.remote_profile_stash = RemoteProfileStash(store=store)

@service_method(path="blob_storage.get_all", name="get_all")
def get_all_blob_storage_entries(
Expand All @@ -61,34 +65,53 @@ def mount_azure(
bucket_name: str,
):
# stdlib
import sys

# TODO: fix arguments

remote_name = f"{account_name}{container_name}"
remote_name = "".join(ch for ch in remote_name if ch.isalnum())
args_dict = {
"account_name": account_name,
"account_key": account_key,
"container_name": container_name,
"remote_name": f"{account_name}{container_name}",
"remote_name": remote_name,
"bucket_name": bucket_name,
}

new_profile = AzureRemoteProfile(
profile_name=remote_name,
account_name=account_name,
account_key=account_key,
container_name=container_name,
)
res = self.remote_profile_stash.set(context.credentials, new_profile)
if res.is_err():
return SyftError(message=res.value)
remote_profile = res.ok()
seaweed_config = context.node.blob_storage_client.config
# Cache the profile here so that it can be looked up by remote_name
# when reading a file from Azure
seaweed_config.remote_profiles[remote_name] = remote_profile

# TODO: possibly wrap this in try/except
cfg = context.node.blob_store_config.client_config
init_request = requests.post(url=cfg.mount_url, json=args_dict) # nosec
print(init_request.content)
# TODO check return code

print(bucket_name, file=sys.stderr)

res = context.node.blob_storage_client.connect().client.list_objects(
Bucket=bucket_name
)
print(res)
# stdlib
objects = res["Contents"]
file_sizes = [object["Size"] for object in objects]
file_paths = [object["Key"] for object in objects]
secure_file_paths = [
SecureFilePathLocation(path=file_path) for file_path in file_paths
AzureSecureFilePathLocation(
path=file_path,
azure_profile_name=remote_name,
bucket_name=bucket_name,
)
for file_path in file_paths
]

for sfp, file_size in zip(secure_file_paths, file_sizes):
Expand All @@ -112,12 +135,17 @@ def get_files_from_bucket(self, context: AuthedServiceContext, bucket_name: str)
return result
bse_list = result.ok()
# stdlib
import sys

print(bse_list, file=sys.stderr)
blob_files = []
for bse in bse_list:
self.stash.set(obj=bse, credentials=context.credentials)
# We create an empty ActionObject and set its blob_storage_entry_id to bse.id
# such that we can call reload_cache which creates
# the BlobRetrieval (user needs permission to do this)
# This could be a BlobRetrievalByURL that creates a BlobFile
# and then sets it in the cache (it does not contain the data, only the BlobFile).
# In the client, when reading the file, we will create **another** BlobRetrieval
# object to read the actual data
blob_file = ActionObject.empty()
blob_file.syft_blob_storage_entry_id = bse.id
blob_file.syft_client_verify_key = context.credentials
Expand Down Expand Up @@ -146,6 +174,7 @@ def get_blob_storage_metadata_by_uid(
return blob_storage_entry.to(BlobStorageMetadata)
return SyftError(message=result.err())

# TODO: replace name with `create_blob_retrieval`
@service_method(
path="blob_storage.read",
name="read",
Expand Down
5 changes: 4 additions & 1 deletion packages/syft/src/syft/service/queue/queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,10 @@ def handle_message(message: bytes):

credentials = queue_item.syft_client_verify_key

job_item = worker.job_stash.get_by_uid(credentials, queue_item.job_id).ok()
res = worker.job_stash.get_by_uid(credentials, queue_item.job_id)
if res.is_err():
raise Exception(res.value)
job_item = res.ok()

queue_item.status = Status.PROCESSING
queue_item.node_uid = worker.id
Expand Down
Loading

0 comments on commit 62b6ccc

Please sign in to comment.