Skip to content

Commit

Permalink
Merge pull request #8570 from OpenMined/eelco/pydantic-warnings
Browse files Browse the repository at this point in the history
fix warnings in dataset transform + privateattr check
  • Loading branch information
eelcovdw authored Mar 14, 2024
2 parents e7e90cf + a632dca commit 2b2d9c2
Show file tree
Hide file tree
Showing 6 changed files with 20 additions and 10 deletions.
1 change: 1 addition & 0 deletions packages/syft/src/syft/service/action/action_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,7 @@ class ActionObjectPointer:
"__repr_str__", # pydantic
"__repr_args__", # pydantic
"__post_init__", # syft
"__validate_private_attrs__", # syft
"id", # syft
"to_mongo", # syft 🟡 TODO 23: Add composeable / inheritable object passthrough attrs
"__attr_searchable__", # syft
Expand Down
2 changes: 1 addition & 1 deletion packages/syft/src/syft/service/code/user_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ class UserCodeStatusCollection(SyncableSyftObject):
status_dict: dict[NodeIdentity, tuple[UserCodeStatus, str]] = {}
user_code_link: LinkedObject

def get_diffs(self, ext_obj: Any) -> list[AttrDiff]:
def syft_get_diffs(self, ext_obj: Any) -> list[AttrDiff]:
# relative
from ...service.sync.diff_state import AttrDiff

Expand Down
10 changes: 5 additions & 5 deletions packages/syft/src/syft/service/dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -756,17 +756,17 @@ def create_and_store_twin(context: TransformContext) -> TransformContext:


def infer_shape(context: TransformContext) -> TransformContext:
if context.output is not None and context.output["shape"] is None:
if context.output is None:
raise ValueError(f"{context}'s output is None. No transformation happened")
if context.output["shape"] is None:
if context.obj is not None and not _is_action_data_empty(context.obj.mock):
context.output["shape"] = get_shape_or_len(context.obj.mock)
else:
print(f"{context}'s output is None. No transformation happened")
return context


def set_data_subjects(context: TransformContext) -> TransformContext | SyftError:
if context.output is None:
return SyftError(f"{context}'s output is None. No transformation happened")
raise ValueError(f"{context}'s output is None. No transformation happened")
if context.node is None:
return SyftError(
"f{context}'s node is None, please log in. No trasformation happened"
Expand Down Expand Up @@ -796,7 +796,7 @@ def add_default_node_uid(context: TransformContext) -> TransformContext:
if context.output["node_uid"] is None and context.node is not None:
context.output["node_uid"] = context.node.id
else:
print(f"{context}'s output is None. No transformation happened.")
raise ValueError(f"{context}'s output is None. No transformation happened")
return context


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def add_msg_creation_time(context: TransformContext) -> TransformContext:
if context.output is not None:
context.output["created_at"] = DateTime.now()
else:
print(f"{context}'s output is None. No transformation happened.")
raise ValueError(f"{context}'s output is None. No transformation happened")
return context


Expand Down
9 changes: 7 additions & 2 deletions packages/syft/src/syft/service/policy/policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,11 +441,16 @@ def apply_output(

class UserOutputPolicy(OutputPolicy):
__canonical_name__ = "UserOutputPolicy"

# Do not validate private attributes of user-defined policies, User annotations can
# contain any type and throw a NameError when resolving.
__validate_private_attrs__ = False
pass


class UserInputPolicy(InputPolicy):
__canonical_name__ = "UserInputPolicy"
__validate_private_attrs__ = False
pass


Expand Down Expand Up @@ -572,7 +577,7 @@ def generate_unique_class_name(context: TransformContext) -> TransformContext:
unique_name = f"{service_class_name}_{context.credentials}_{code_hash}"
context.output["unique_name"] = unique_name
else:
print(f"{context}'s output is None. No transformation happened.")
raise ValueError(f"{context}'s output is None. No transformation happened")

return context

Expand Down Expand Up @@ -696,7 +701,7 @@ def compile_code(context: TransformContext) -> TransformContext:
+ context.output["parsed_code"]
)
else:
print(f"{context}'s output is None. No transformation happened.")
raise ValueError(f"{context}'s output is None. No transformation happened")

return context

Expand Down
6 changes: 5 additions & 1 deletion packages/syft/src/syft/types/syft_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,6 +419,7 @@ def make_id(cls, values: Any) -> Any:
__attr_custom_repr__: ClassVar[list[str] | None] = (
None # show these in html repr of an object
)
__validate_private_attrs__: ClassVar[bool] = True

def __syft_get_funcs__(self) -> list[tuple[str, Signature]]:
funcs = print_type_cache[type(self)]
Expand Down Expand Up @@ -577,11 +578,14 @@ def __post_init__(self) -> None:
pass

def _syft_set_validate_private_attrs_(self, **kwargs: Any) -> None:
if not self.__validate_private_attrs__:
return
# Validate and set private attributes
# https://github.com/pydantic/pydantic/issues/2105
annotations = typing.get_type_hints(self.__class__, localns=locals())
for attr, decl in self.__private_attributes__.items():
value = kwargs.get(attr, decl.get_default())
var_annotation = self.__annotations__.get(attr)
var_annotation = annotations.get(attr)
if value is not PydanticUndefined:
if var_annotation is not None:
# Otherwise validate value against the variable annotation
Expand Down

0 comments on commit 2b2d9c2

Please sign in to comment.