diff --git a/packages/syft/src/syft/service/action/action_object.py b/packages/syft/src/syft/service/action/action_object.py index 2070713710c..caeaf450e23 100644 --- a/packages/syft/src/syft/service/action/action_object.py +++ b/packages/syft/src/syft/service/action/action_object.py @@ -232,6 +232,7 @@ class ActionObjectPointer: "__repr_str__", # pydantic "__repr_args__", # pydantic "__post_init__", # syft + "__validate_private_attrs__", # syft "id", # syft "to_mongo", # syft 🟡 TODO 23: Add composeable / inheritable object passthrough attrs "__attr_searchable__", # syft diff --git a/packages/syft/src/syft/service/code/user_code.py b/packages/syft/src/syft/service/code/user_code.py index 6c8bc77c876..062dbc2b424 100644 --- a/packages/syft/src/syft/service/code/user_code.py +++ b/packages/syft/src/syft/service/code/user_code.py @@ -118,7 +118,7 @@ class UserCodeStatusCollection(SyncableSyftObject): status_dict: dict[NodeIdentity, tuple[UserCodeStatus, str]] = {} user_code_link: LinkedObject - def get_diffs(self, ext_obj: Any) -> list[AttrDiff]: + def syft_get_diffs(self, ext_obj: Any) -> list[AttrDiff]: # relative from ...service.sync.diff_state import AttrDiff diff --git a/packages/syft/src/syft/service/dataset/dataset.py b/packages/syft/src/syft/service/dataset/dataset.py index 6a971bbfb1a..daf92ecdbcb 100644 --- a/packages/syft/src/syft/service/dataset/dataset.py +++ b/packages/syft/src/syft/service/dataset/dataset.py @@ -756,17 +756,17 @@ def create_and_store_twin(context: TransformContext) -> TransformContext: def infer_shape(context: TransformContext) -> TransformContext: - if context.output is not None and context.output["shape"] is None: + if context.output is None: + raise ValueError(f"{context}'s output is None. No transformation happened") + if context.output["shape"] is None: if context.obj is not None and not _is_action_data_empty(context.obj.mock): context.output["shape"] = get_shape_or_len(context.obj.mock) - else: - print(f"{context}'s output is None. No transformation happened") return context def set_data_subjects(context: TransformContext) -> TransformContext | SyftError: if context.output is None: - return SyftError(f"{context}'s output is None. No transformation happened") + raise ValueError(f"{context}'s output is None. No transformation happened") if context.node is None: return SyftError( "f{context}'s node is None, please log in. No trasformation happened" @@ -796,7 +796,7 @@ def add_default_node_uid(context: TransformContext) -> TransformContext: if context.output["node_uid"] is None and context.node is not None: context.output["node_uid"] = context.node.id else: - print(f"{context}'s output is None. No transformation happened.") + raise ValueError(f"{context}'s output is None. No transformation happened") return context diff --git a/packages/syft/src/syft/service/notification/notifications.py b/packages/syft/src/syft/service/notification/notifications.py index e708993a7e0..6df1716ed4a 100644 --- a/packages/syft/src/syft/service/notification/notifications.py +++ b/packages/syft/src/syft/service/notification/notifications.py @@ -151,7 +151,7 @@ def add_msg_creation_time(context: TransformContext) -> TransformContext: if context.output is not None: context.output["created_at"] = DateTime.now() else: - print(f"{context}'s output is None. No transformation happened.") + raise ValueError(f"{context}'s output is None. No transformation happened") return context diff --git a/packages/syft/src/syft/service/policy/policy.py b/packages/syft/src/syft/service/policy/policy.py index 78b3b436765..d0f8b2f7ce2 100644 --- a/packages/syft/src/syft/service/policy/policy.py +++ b/packages/syft/src/syft/service/policy/policy.py @@ -441,11 +441,16 @@ def apply_output( class UserOutputPolicy(OutputPolicy): __canonical_name__ = "UserOutputPolicy" + + # Do not validate private attributes of user-defined policies, User annotations can + # contain any type and throw a NameError when resolving. + __validate_private_attrs__ = False pass class UserInputPolicy(InputPolicy): __canonical_name__ = "UserInputPolicy" + __validate_private_attrs__ = False pass @@ -572,7 +577,7 @@ def generate_unique_class_name(context: TransformContext) -> TransformContext: unique_name = f"{service_class_name}_{context.credentials}_{code_hash}" context.output["unique_name"] = unique_name else: - print(f"{context}'s output is None. No transformation happened.") + raise ValueError(f"{context}'s output is None. No transformation happened") return context @@ -696,7 +701,7 @@ def compile_code(context: TransformContext) -> TransformContext: + context.output["parsed_code"] ) else: - print(f"{context}'s output is None. No transformation happened.") + raise ValueError(f"{context}'s output is None. No transformation happened") return context diff --git a/packages/syft/src/syft/types/syft_object.py b/packages/syft/src/syft/types/syft_object.py index d19ae10c6ac..cbce6600589 100644 --- a/packages/syft/src/syft/types/syft_object.py +++ b/packages/syft/src/syft/types/syft_object.py @@ -419,6 +419,7 @@ def make_id(cls, values: Any) -> Any: __attr_custom_repr__: ClassVar[list[str] | None] = ( None # show these in html repr of an object ) + __validate_private_attrs__: ClassVar[bool] = True def __syft_get_funcs__(self) -> list[tuple[str, Signature]]: funcs = print_type_cache[type(self)] @@ -577,11 +578,14 @@ def __post_init__(self) -> None: pass def _syft_set_validate_private_attrs_(self, **kwargs: Any) -> None: + if not self.__validate_private_attrs__: + return # Validate and set private attributes # https://github.com/pydantic/pydantic/issues/2105 + annotations = typing.get_type_hints(self.__class__, localns=locals()) for attr, decl in self.__private_attributes__.items(): value = kwargs.get(attr, decl.get_default()) - var_annotation = self.__annotations__.get(attr) + var_annotation = annotations.get(attr) if value is not PydanticUndefined: if var_annotation is not None: # Otherwise validate value against the variable annotation