Skip to content

Commit

Permalink
Merge pull request #8642 from OpenMined/fix-partial-object
Browse files Browse the repository at this point in the history
Add Empty annotation during class creation for subclasses of PartialModelMetaclass
  • Loading branch information
shubham3121 authored Apr 1, 2024
2 parents 197a4f7 + f8b40af commit cd21efb
Show file tree
Hide file tree
Showing 15 changed files with 118 additions and 58 deletions.
6 changes: 6 additions & 0 deletions packages/syft/src/syft/protocol/data_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,7 @@ def diff_state(self, state: dict) -> tuple[dict, dict]:
+ "Is a unique __canonical_name__ for this subclass missing? "
+ "If the class has changed you will need to define a new class with the changes, "
+ "with same __canonical_name__ and bump the __version__ number."
+ f"{cls.model_fields}"
)

if get_dev_mode() or self.raise_exception:
Expand Down Expand Up @@ -503,6 +504,11 @@ def has_dev(self) -> bool:
return True
return False

def reset_dev_protocol(self) -> None:
if self.has_dev:
del self.protocol_history["dev"]
self.save_history(self.protocol_history)


def get_data_protocol(raise_exception: bool = False) -> DataProtocol:
return DataProtocol(
Expand Down
52 changes: 52 additions & 0 deletions packages/syft/src/syft/protocol/protocol_version.json
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,57 @@
},
"4": {
"release_name": "0.8.5.post1.json"
},
"dev": {
"object_versions": {
"PartialSyftObject": {
"2": {
"version": 2,
"hash": "4182684fe6b6a09901f79ebbbef533044725d7c330722ffe01f4e6d8cc81e0ae",
"action": "remove"
},
"3": {
"version": 3,
"hash": "3db9ad277ff4f29041379da7adb31dfcffa0487c30bcdc6f57a932d7509f7a26",
"action": "add"
}
},
"UserUpdate": {
"3": {
"version": 3,
"hash": "2a2feb8f1b5b57bf9dec3bea3874a2b77dbc1be88d0ceb2f120c92a7af5f7ec8",
"action": "remove"
},
"4": {
"version": 4,
"hash": "a4313c229e8b374d748292b5a12093328cecb0653e03e17d98b4bedb6d0728cd",
"action": "add"
}
},
"UserSearch": {
"2": {
"version": 2,
"hash": "529a5874946f4b8c1a1fa74034000db8fc3a348e488a80c1f02d8ed1cc8aec3a",
"action": "remove"
},
"3": {
"version": 3,
"hash": "3f0c4f9117702f76b70645c3b7c2a514982f33f8d3203030a4ca3628e7120408",
"action": "add"
}
},
"NodeSettingsUpdate": {
"2": {
"version": 2,
"hash": "88775f18141f0eb29342566bdd199c359a13db0a0125e3b8386b10dbf11ab32e",
"action": "remove"
},
"3": {
"version": 3,
"hash": "bcdf3cab07728978a49b483ff046380fb3614904d2f8b5b7239fd0434e6c0465",
"action": "add"
}
}
}
}
}
2 changes: 1 addition & 1 deletion packages/syft/src/syft/protocol/releases/0.8.5.post1.json
Original file line number Diff line number Diff line change
Expand Up @@ -1668,4 +1668,4 @@
}
}
}
}
}
3 changes: 2 additions & 1 deletion packages/syft/src/syft/service/action/action_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from ...types.datetime import DateTime
from ...types.syft_object import PartialSyftObject
from ...types.syft_object import SYFT_OBJECT_VERSION_2
from ...types.syft_object import SYFT_OBJECT_VERSION_3
from ...types.syft_object import SyftObject
from ...types.uid import UID
from .action_object import Action
Expand Down Expand Up @@ -108,7 +109,7 @@ def __repr__(self) -> str:
@serializable()
class NodeActionDataUpdate(PartialSyftObject):
__canonical_name__ = "NodeActionDataUpdate"
__version__ = SYFT_OBJECT_VERSION_2
__version__ = SYFT_OBJECT_VERSION_3

id: UID
type: NodeType
Expand Down
3 changes: 1 addition & 2 deletions packages/syft/src/syft/service/settings/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from ...node.credentials import SyftVerifyKey
from ...serde.serializable import serializable
from ...types.syft_object import PartialSyftObject
from ...types.syft_object import SYFT_OBJECT_VERSION_2
from ...types.syft_object import SYFT_OBJECT_VERSION_3
from ...types.syft_object import SyftObject
from ...types.uid import UID
Expand All @@ -15,7 +14,7 @@
@serializable()
class NodeSettingsUpdate(PartialSyftObject):
__canonical_name__ = "NodeSettingsUpdate"
__version__ = SYFT_OBJECT_VERSION_2
__version__ = SYFT_OBJECT_VERSION_3

id: UID
name: str
Expand Down
5 changes: 3 additions & 2 deletions packages/syft/src/syft/service/user/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from ...types.syft_object import PartialSyftObject
from ...types.syft_object import SYFT_OBJECT_VERSION_2
from ...types.syft_object import SYFT_OBJECT_VERSION_3
from ...types.syft_object import SYFT_OBJECT_VERSION_4
from ...types.syft_object import SyftObject
from ...types.transforms import TransformContext
from ...types.transforms import drop
Expand Down Expand Up @@ -116,7 +117,7 @@ def check_pwd(password: str, hashed_password: str) -> bool:
@serializable()
class UserUpdate(PartialSyftObject):
__canonical_name__ = "UserUpdate"
__version__ = SYFT_OBJECT_VERSION_3
__version__ = SYFT_OBJECT_VERSION_4

@field_validator("role", mode="before")
@classmethod
Expand Down Expand Up @@ -158,7 +159,7 @@ class UserCreate(SyftObject):
@serializable()
class UserSearch(PartialSyftObject):
__canonical_name__ = "UserSearch"
__version__ = SYFT_OBJECT_VERSION_2
__version__ = SYFT_OBJECT_VERSION_3

id: UID
email: EmailStr
Expand Down
2 changes: 1 addition & 1 deletion packages/syft/src/syft/service/user/user_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ def update(

edits_non_role_attrs = any(
getattr(user_update, attr) is not Empty
for attr in user_update.dict()
for attr in user_update.to_dict()
if attr != "role"
)

Expand Down
14 changes: 11 additions & 3 deletions packages/syft/src/syft/types/syft_metaclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,20 @@ class Empty(metaclass=EmptyType):


class PartialModelMetaclass(ModelMetaclass):
def __call__(cls: type[_T], *args: Any, **kwargs: Any) -> _T:
def __new__(
mcs,
cls_name: str,
bases: tuple[type[Any], ...],
namespace: dict[str, Any],
*args: Any,
**kwargs: Any,
) -> type:
cls = super().__new__(mcs, cls_name, bases, namespace, *args, **kwargs)

for field_info in cls.model_fields.values():
if field_info.annotation is not None and field_info.is_required():
field_info.annotation = field_info.annotation | EmptyType
field_info.default = Empty

cls.model_rebuild(force=True)

return super().__call__(*args, **kwargs) # type: ignore[misc]
return cls
33 changes: 12 additions & 21 deletions packages/syft/src/syft/types/syft_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
from typing import Union
from typing import get_args
from typing import get_origin
import warnings

# third party
import pandas as pd
Expand Down Expand Up @@ -554,25 +553,17 @@ def to(self, projection: type, context: Context | None = None) -> Any:
def to_dict(
self, exclude_none: bool = False, exclude_empty: bool = False
) -> dict[str, Any]:
warnings.warn(
"`SyftObject.to_dict` is deprecated and will be removed in a future version",
PendingDeprecationWarning,
stacklevel=2,
)
# 🟡 TODO 18: Remove to_dict and replace usage with transforms etc
if not exclude_none and not exclude_empty:
return self.dict()
else:
new_dict = {}
for k, v in dict(self).items():
# exclude dynamically added syft attributes
if k in DYNAMIC_SYFT_ATTRIBUTES:
continue
if exclude_empty and v is not Empty:
new_dict[k] = v
if exclude_none and v is not None:
new_dict[k] = v
return new_dict
new_dict = {}
for k, v in dict(self).items():
# exclude dynamically added syft attributes
if k in DYNAMIC_SYFT_ATTRIBUTES:
continue
if exclude_empty and v is Empty:
continue
if exclude_none and v is None:
continue
new_dict[k] = v
return new_dict

def __post_init__(self) -> None:
pass
Expand Down Expand Up @@ -957,7 +948,7 @@ class PartialSyftObject(SyftObject, metaclass=PartialModelMetaclass):
"""Syft Object to which partial arguments can be provided."""

__canonical_name__ = "PartialSyftObject"
__version__ = SYFT_OBJECT_VERSION_2
__version__ = SYFT_OBJECT_VERSION_3

def __iter__(self) -> TupleGenerator:
yield from ((k, v) for k, v in super().__iter__() if v is not Empty)
Expand Down
2 changes: 1 addition & 1 deletion packages/syft/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def stage_protocol(protocol_file: Path):
stage_protocol_changes()
# bump_protocol_version()
yield dp.protocol_history
dp.revert_latest_protocol()
dp.reset_dev_protocol()
dp.save_history(dp.protocol_history)

# Cleanup release dir, remove unused released files
Expand Down
2 changes: 1 addition & 1 deletion packages/syft/tests/syft/action_graph/action_graph_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def test_node_action_data_update() -> None:
assert len(node_action_data_update.to_dict(exclude_empty=True)) == 1
assert (
node_action_data_update.to_dict(exclude_empty=False)
== node_action_data_update.dict()
== node_action_data_update.to_dict()
)


Expand Down
4 changes: 2 additions & 2 deletions packages/syft/tests/syft/settings/settings_service_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,10 +159,10 @@ def mock_stash_get_all(root_verify_key) -> Ok:
assert response.is_ok() is True
assert len(response.ok()) == len(mock_stash_get_all_output)
assert (
updated_settings.model_dump() == new_settings.model_dump()
updated_settings.to_dict() == new_settings.to_dict()
) # the first settings is updated
assert (
not_updated_settings.model_dump() == settings.model_dump()
not_updated_settings.to_dict() == settings.to_dict()
) # the second settings is not updated


Expand Down
2 changes: 1 addition & 1 deletion packages/syft/tests/syft/stores/base_stash_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ class MockStash(BaseUIDStoreStash):


def get_object_values(obj: SyftObject) -> tuple[Any]:
return tuple(obj.dict().values())
return tuple(obj.to_dict().values())


def add_mock_object(root_verify_key, stash: MockStash, obj: MockObject) -> MockObject:
Expand Down
20 changes: 11 additions & 9 deletions packages/syft/tests/syft/users/user_service_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,9 @@ def mock_get_by_email(credentials: SyftVerifyKey, email: str) -> Ok:
return Ok(None)

expected_user = guest_create_user.to(User)
expected_output = expected_user.to(UserView)
expected_output: UserView = expected_user.to(UserView)
expected_output.syft_client_verify_key = authed_context.credentials
expected_output.syft_node_location = authed_context.node.id

def mock_set(
credentials: SyftVerifyKey,
Expand Down Expand Up @@ -168,7 +170,7 @@ def mock_get_by_uid(credentials: SyftVerifyKey, uid: UID) -> Ok:
monkeypatch.setattr(user_service.stash, "get_by_uid", mock_get_by_uid)
response = user_service.view(authed_context, uid_to_view)
assert isinstance(response, UserView)
assert response.model_dump() == expected_output.model_dump()
assert response.to_dict() == expected_output.to_dict()


def test_userservice_get_all_success(
Expand All @@ -189,7 +191,7 @@ def mock_get_all(credentials: SyftVerifyKey) -> Ok:
assert isinstance(response, list)
assert len(response) == len(expected_output)
assert all(
r.model_dump() == expected.model_dump()
r.to_dict() == expected.to_dict()
for r, expected in zip(response, expected_output)
)

Expand Down Expand Up @@ -230,24 +232,24 @@ def mock_find_all(credentials: SyftVerifyKey, **kwargs) -> Ok | Err:
response = user_service.search(authed_context, id=guest_user.id)
assert isinstance(response, list)
assert all(
r.model_dump() == expected.model_dump()
r.to_dict() == expected.to_dict()
for r, expected in zip(response, expected_output)
)
# assert response.model_dump() == expected_output.model_dump()
# assert response.to_dict() == expected_output.to_dict()

# Search via email
response = user_service.search(authed_context, email=guest_user.email)
assert isinstance(response, list)
assert all(
r.model_dump() == expected.model_dump()
r.to_dict() == expected.to_dict()
for r, expected in zip(response, expected_output)
)

# Search via name
response = user_service.search(authed_context, name=guest_user.name)
assert isinstance(response, list)
assert all(
r.model_dump() == expected.model_dump()
r.to_dict() == expected.to_dict()
for r, expected in zip(response, expected_output)
)

Expand All @@ -258,7 +260,7 @@ def mock_find_all(credentials: SyftVerifyKey, **kwargs) -> Ok | Err:
)
assert isinstance(response, list)
assert all(
r.model_dump() == expected.model_dump()
r.to_dict() == expected.to_dict()
for r, expected in zip(response, expected_output)
)

Expand All @@ -268,7 +270,7 @@ def mock_find_all(credentials: SyftVerifyKey, **kwargs) -> Ok | Err:
)
assert isinstance(response, list)
assert all(
r.model_dump() == expected.model_dump()
r.to_dict() == expected.to_dict()
for r, expected in zip(response, expected_output)
)

Expand Down
Loading

0 comments on commit cd21efb

Please sign in to comment.