Skip to content

Commit

Permalink
add migrations for SyftWorkerImage
Browse files Browse the repository at this point in the history
- define a new version for SyftWorker
  • Loading branch information
shubham3121 committed Sep 19, 2024
1 parent a2084be commit 5f049f0
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 1 deletion.
7 changes: 7 additions & 0 deletions packages/syft/src/syft/protocol/protocol_version.json
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,13 @@
"hash": "2e1365c5535fa51c22eef79f67dd6444789bc829c27881367e3050e06e2ffbfe",
"action": "remove"
}
},
"SyftWorker": {
"2": {
"version": 2,
"hash": "e996dabbb8ad4ff0bc5d19528077c11f73b9300d810735d367916e4e5b9149b6",
"action": "add"
}
}
}
}
Expand Down
9 changes: 9 additions & 0 deletions packages/syft/src/syft/service/worker/worker_image.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
# stdlib

# stdlib
from collections.abc import Callable

# relative
from ...custom_worker.config import PrebuiltWorkerConfig
from ...custom_worker.config import WorkerConfig
from ...serde.serializable import serializable
from ...server.credentials import SyftVerifyKey
from ...types.datetime import DateTime
from ...types.syft_migration import migrate
from ...types.syft_object import SYFT_OBJECT_VERSION_1
from ...types.syft_object import SYFT_OBJECT_VERSION_2
from ...types.syft_object import SyftObject
Expand Down Expand Up @@ -88,3 +92,8 @@ def built_image_tag(self) -> str | None:
if self.is_built and self.image_identifier:
return self.image_identifier.full_name_with_tag
return None


@migrate(SyftWorkerImageV1, SyftWorkerImage)
def migrate_syft_worker_image_v1_to_v2() -> list[Callable]:
return [] # no migrations needed at data level, only unique and searchable attributes changed
53 changes: 52 additions & 1 deletion packages/syft/src/syft/service/worker/worker_pool.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# stdlib
from collections.abc import Callable
from enum import Enum
from typing import Any
from typing import cast
Expand All @@ -14,12 +15,16 @@
from ...types.datetime import DateTime
from ...types.errors import SyftException
from ...types.result import as_result
from ...types.syft_migration import migrate
from ...types.syft_object import SYFT_OBJECT_VERSION_1
from ...types.syft_object import SYFT_OBJECT_VERSION_2
from ...types.syft_object import SyftObject
from ...types.syft_object import short_uid
from ...types.transforms import TransformContext
from ...types.uid import UID
from ..response import SyftError
from .worker_image import SyftWorkerImage
from .worker_image import SyftWorkerImageV1


@serializable(canonical_name="WorkerStatus", version=1)
Expand All @@ -44,7 +49,7 @@ class WorkerHealth(Enum):


@serializable()
class SyftWorker(SyftObject):
class SyftWorkerV1(SyftObject):
__canonical_name__ = "SyftWorker"
__version__ = SYFT_OBJECT_VERSION_1

Expand All @@ -60,6 +65,36 @@ class SyftWorker(SyftObject):
"created_at",
]

id: UID
name: str
container_id: str | None = None
created_at: DateTime = DateTime.now()
healthcheck: WorkerHealth | None = None
status: WorkerStatus
image: SyftWorkerImageV1 | None = None
worker_pool_name: str
consumer_state: ConsumerState = ConsumerState.DETACHED
job_id: UID | None = None
to_be_deleted: bool = False


@serializable()
class SyftWorker(SyftObject):
__canonical_name__ = "SyftWorker"
__version__ = SYFT_OBJECT_VERSION_2

__attr_unique__ = ["name"]
__attr_searchable__ = ["name", "container_id", "to_be_deleted"]
__repr_attrs__ = [
"name",
"container_id",
"image",
"status",
"healthcheck",
"worker_pool_name",
"created_at",
]

id: UID
name: str
container_id: str | None = None
Expand Down Expand Up @@ -283,3 +318,19 @@ def _get_worker_container_status(
container_status,
SyftError(message=f"Unknown container status: {container_status}"),
)


def migrate_worker_image_v1_to_v2(context: TransformContext) -> TransformContext:
old_image = context["image"]
if isinstance(old_image, SyftWorkerImageV1):
new_image = old_image.migrate_to(
version=SYFT_OBJECT_VERSION_2,
context=context.to_server_context(),
)
context["image"] = new_image
return context


@migrate(SyftWorkerV1, SyftWorker)
def migrate_worker_v1_to_v2() -> list[Callable]:
return [migrate_worker_image_v1_to_v2]

0 comments on commit 5f049f0

Please sign in to comment.