Skip to content

Commit

Permalink
Merge branch 'dev' into autosplat-warning
Browse files Browse the repository at this point in the history
  • Loading branch information
kiendang authored Aug 11, 2024
2 parents 2a2966e + 62e6630 commit 4b288f4
Show file tree
Hide file tree
Showing 10 changed files with 226 additions and 43 deletions.
1 change: 0 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,6 @@ repos:
name: "mypy: syft"
always_run: true
files: "^packages/syft/src/syft/"
exclude: "packages/syft/src/syft/types/dicttuple.py"
args: [
"--follow-imports=skip",
"--ignore-missing-imports",
Expand Down
14 changes: 14 additions & 0 deletions packages/syft/src/syft/protocol/protocol_version.json
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,20 @@
"hash": "cf6c1cb55d569af9823d8541ca038806bd350450a919345244ed4f432a099f34",
"action": "add"
}
},
"DatasetPageView": {
"2": {
"version": 2,
"hash": "be1ca6dcd0b3aa0481ce5dce737e78432d06a78ad0c701aaf136be407c798352",
"action": "add"
}
},
"JobItem": {
"2": {
"version": 2,
"hash": "b087d0c62b7d304c6ca80e4fb0e8a7f2a444be8f8cba57490dc09aeb98033105",
"action": "add"
}
}
}
}
Expand Down
1 change: 1 addition & 0 deletions packages/syft/src/syft/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -1370,6 +1370,7 @@ def add_queueitem_to_queue(
action=action,
requested_by=user_id,
job_type=job_type,
endpoint=queue_item.kwargs.get("path", None),
)

# 🟡 TODO 36: Needs distributed lock
Expand Down
12 changes: 10 additions & 2 deletions packages/syft/src/syft/service/dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from ...types.dicttuple import DictTuple
from ...types.syft_object import PartialSyftObject
from ...types.syft_object import SYFT_OBJECT_VERSION_1
from ...types.syft_object import SYFT_OBJECT_VERSION_2
from ...types.syft_object import SyftObject
from ...types.transforms import TransformContext
from ...types.transforms import generate_id
Expand Down Expand Up @@ -596,7 +597,15 @@ def _check_asset_must_contain_mock(asset_list: list[CreateAsset]) -> None:

@serializable()
class DatasetPageView(SyftObject):
# version
__canonical_name__ = "DatasetPageView"
__version__ = SYFT_OBJECT_VERSION_2

datasets: DictTuple[str, Dataset]
total: int


@serializable()
class DatasetPageViewV1(SyftObject):
__canonical_name__ = "DatasetPageView"
__version__ = SYFT_OBJECT_VERSION_1

Expand All @@ -606,7 +615,6 @@ class DatasetPageView(SyftObject):

@serializable()
class CreateDataset(Dataset):
# version
__canonical_name__ = "CreateDataset"
__version__ = SYFT_OBJECT_VERSION_1
asset_list: list[CreateAsset] = []
Expand Down
2 changes: 1 addition & 1 deletion packages/syft/src/syft/service/dataset/dataset_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def search(
name: str,
page_size: int | None = 0,
page_index: int | None = 0,
) -> DatasetPageView | SyftError:
) -> DatasetPageView | DictTuple[str, Dataset] | SyftError:
"""Search a Dataset by name"""
results = self.get_all(context)

Expand Down
46 changes: 42 additions & 4 deletions packages/syft/src/syft/service/job/job_stash.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# stdlib
from collections.abc import Callable
from datetime import datetime
from datetime import timedelta
from datetime import timezone
Expand Down Expand Up @@ -31,9 +32,12 @@
from ...store.document_store import UIDPartitionKey
from ...types.datetime import DateTime
from ...types.datetime import format_timedelta
from ...types.syft_migration import migrate
from ...types.syft_object import SYFT_OBJECT_VERSION_1
from ...types.syft_object import SYFT_OBJECT_VERSION_2
from ...types.syft_object import SyftObject
from ...types.syncable_object import SyncableSyftObject
from ...types.transforms import make_set_default
from ...types.uid import UID
from ...util.markdown import as_markdown_code
from ...util.telemetry import instrument
Expand Down Expand Up @@ -86,7 +90,7 @@ def __str__(self) -> str:
@serializable()
class Job(SyncableSyftObject):
__canonical_name__ = "JobItem"
__version__ = SYFT_OBJECT_VERSION_1
__version__ = SYFT_OBJECT_VERSION_2

id: UID
server_uid: UID
Expand All @@ -107,6 +111,8 @@ class Job(SyncableSyftObject):
user_code_id: UID | None = None
requested_by: UID | None = None
job_type: JobType = JobType.JOB
# used by JobType.TWINAPIJOB
endpoint: str | None = None

__attr_searchable__ = [
"parent_job_id",
Expand Down Expand Up @@ -452,9 +458,8 @@ def summary_html(self) -> str:

try:
# type_html = f'<div class="label {self.type_badge_class()}">{self.object_type_name.upper()}</div>'
description_html = (
f"<span class='syncstate-description'>{self.user_code_name}</span>"
)
job_name = self.user_code_name or self.endpoint or "Job"
description_html = f"<span class='syncstate-description'>{job_name}</span>"
worker_summary = ""
if self.job_worker_id:
worker_copy_button = CopyIDButton(
Expand Down Expand Up @@ -931,3 +936,36 @@ def get_by_user_code_id(
)

return self.query_all(credentials=credentials, qks=qks)


@serializable()
class JobV1(SyncableSyftObject):
__canonical_name__ = "JobItem"
__version__ = SYFT_OBJECT_VERSION_1

id: UID
server_uid: UID
result: Any | None = None
resolved: bool = False
status: JobStatus = JobStatus.CREATED
log_id: UID | None = None
parent_job_id: UID | None = None
n_iters: int | None = 0
current_iter: int | None = None
creation_time: str | None = Field(
default_factory=lambda: str(datetime.now(tz=timezone.utc))
)
action: Action | None = None
job_pid: int | None = None
job_worker_id: UID | None = None
updated_at: DateTime | None = None
user_code_id: UID | None = None
requested_by: UID | None = None
job_type: JobType = JobType.JOB


@migrate(JobV1, Job)
def migrate_job_update_v1_current() -> list[Callable]:
return [
make_set_default("endpoint", None),
]
9 changes: 9 additions & 0 deletions packages/syft/src/syft/service/sync/diff_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
from ..job.job_stash import JobType
from ..log.log import SyftLog
from ..output.output_service import ExecutionOutput
from ..policy.policy import Constant
from ..request.request import Request
from ..response import SyftError
from ..response import SyftSuccess
Expand Down Expand Up @@ -367,6 +368,14 @@ def repr_attr_dict(self, side: str) -> dict[str, Any]:
for attr in repr_attrs:
value = getattr(obj, attr)
res[attr] = value

# if there are constants in UserCode input policy, add to repr
# type ignores since mypy thinks the code is unreachable for some reason
if isinstance(obj, UserCode) and obj.input_policy_init_kwargs is not None: # type: ignore
for input_policy_kwarg in obj.input_policy_init_kwargs.values(): # type: ignore
for input_val in input_policy_kwarg.values():
if isinstance(input_val, Constant):
res[input_val.kw] = input_val.val
return res

def diff_attributes_str(self, side: str) -> str:
Expand Down
Loading

0 comments on commit 4b288f4

Please sign in to comment.