Skip to content

Commit

Permalink
update transforms
Browse files Browse the repository at this point in the history
  • Loading branch information
koenvanderveen committed Dec 15, 2023
1 parent a8eadf8 commit e96d6d2
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 3 deletions.
19 changes: 16 additions & 3 deletions packages/syft/src/syft/store/blob_storage/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@
from ...types.syft_object import SyftObject
from ...types.transforms import drop
from ...types.transforms import make_set_default
from ...types.transforms import str_url_to_grid_url
from ...types.uid import UID

DEFAULT_TIMEOUT = 10
Expand Down Expand Up @@ -109,7 +110,7 @@ def _read_data(self, **kwargs):


@migrate(BlobRetrieval, BlobRetrievalV1)
def downgrade_blobretrival_v2_to_v1():
def downgrade_blobretrieval_v2_to_v1():
return [
drop(["syft_blob_storage_entry_id", "file_size"]),
]
Expand Down Expand Up @@ -258,21 +259,33 @@ def _read_data(self, stream=False, chunk_size=DEFAULT_CHUNK_SIZE):
return SyftError(message=f"Failed to retrieve with Error: {e}")


@migrate(BlobRetrievalByURL, BlobRetrievalByURLV1)
@migrate(BlobRetrievalByURLV2, BlobRetrievalByURLV1)
def downgrade_blobretrivalbyurl_v2_to_v1():
return [
drop(["syft_blob_storage_entry_id", "file_size"]),
]


@migrate(BlobRetrievalByURLV1, BlobRetrievalByURL)
@migrate(BlobRetrievalByURLV1, BlobRetrievalByURLV2)
def upgrade_blobretrivalbyurl_v1_to_v2():
return [
make_set_default("syft_blob_storage_entry_id", None),
make_set_default("file_size", 1),
]


@migrate(BlobRetrievalByURL, BlobRetrievalByURLV2)
def downgrade_blobretrivalbyurl_v3_to_v2():
return [
str_url_to_grid_url,
]


@migrate(BlobRetrievalByURLV2, BlobRetrievalByURL)
def upgrade_blobretrivalbyurl_v2_to_v3():
return []


@serializable()
class BlobDeposit(SyftObject):
__canonical_name__ = "BlobDeposit"
Expand Down
7 changes: 7 additions & 0 deletions packages/syft/src/syft/types/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,13 @@ def validate_email(context: TransformContext) -> TransformContext:
return context


def str_url_to_grid_url(context: TransformContext) -> TransformContext:
url = context.output.get("url", None)
if url is not None and isinstance(url, str):
context.output["url"] = GridURL.from_url(str)
return context


def add_credentials_for_key(key: str) -> Callable:
def add_credentials(context: TransformContext) -> TransformContext:
context.output[key] = context.credentials
Expand Down

0 comments on commit e96d6d2

Please sign in to comment.