diff --git a/packages/syft/src/syft/client/api.py b/packages/syft/src/syft/client/api.py index c40677221e2..8842e27e0dc 100644 --- a/packages/syft/src/syft/client/api.py +++ b/packages/syft/src/syft/client/api.py @@ -17,7 +17,10 @@ # third party from nacl.exceptions import BadSignatureError +from pydantic import BaseModel +from pydantic import ConfigDict from pydantic import EmailStr +from pydantic import TypeAdapter from result import OkErr from result import Result from typeguard import check_type @@ -85,6 +88,14 @@ IPYNB_BACKGROUND_PREFIXES = ["_ipy", "_repr", "__ipython", "__pydantic"] +def _has_config_dict(t: Any) -> bool: + return ( + isinstance(t, type) + and issubclass(t, BaseModel) + or hasattr(t, "__pydantic_config__") + ) + + class APIRegistry: __api_registry__: dict[tuple, SyftAPI] = OrderedDict() @@ -1321,30 +1332,25 @@ def validate_callable_args_and_kwargs( t = index_syft_by_module_name(param.annotation) else: t = param.annotation - msg = None - try: - if t is not inspect.Parameter.empty: - if isinstance(t, _GenericAlias) and type(None) in t.__args__: - success = False - for v in t.__args__: - if issubclass(v, EmailStr): - v = str - try: - check_type(value, v) # raises Exception - success = True - break # only need one to match - except Exception: # nosec - pass - if not success: - raise TypeError() - else: - check_type(value, t) # raises Exception - except TypeError: - _type_str = getattr(t, "__name__", str(t)) - msg = f"`{key}` must be of type `{_type_str}` not `{type(value).__name__}`" - if msg: - return SyftError(message=msg) + if t is not inspect.Parameter.empty: + try: + config_kw = ( + {"config": ConfigDict(arbitrary_types_allowed=True)} + if not _has_config_dict(t) + else {} + ) + + # TypeAdapter only accepts `config` arg if `t` does not + # already contain a ConfigDict + # i.e model_config in BaseModel and __pydantic_config__ in + # other types. + TypeAdapter(t, **config_kw).validate_python(value) + except Exception: + _type_str = getattr(t, "__name__", str(t)) + return SyftError( + message=f"`{key}` must be of type `{_type_str}` not `{type(value).__name__}`" + ) _valid_kwargs[key] = value