Skip to content

Commit

Permalink
Merge branch 'dev' into shubham/fix-kaniko-build
Browse files Browse the repository at this point in the history
  • Loading branch information
shubham3121 authored Aug 12, 2024
2 parents ef517cf + f859665 commit a61089f
Show file tree
Hide file tree
Showing 14 changed files with 347 additions and 133 deletions.
2 changes: 1 addition & 1 deletion docs/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
certifi>=2024.7.4 # not directly required, pinned by Snyk to avoid a vulnerability
idna>=3.7 # not directly required, pinned by Snyk to avoid a vulnerability
ipython==8.10.0
jinja2>=3.1.4 # not directly required, pinned by Snyk to avoid a vulnerability
markupsafe==2.0.1
Expand All @@ -12,4 +13,3 @@ sphinx-code-include==1.1.1
sphinx-copybutton==0.4.0
sphinx-panels==0.6.0
urllib3>=2.2.2 # not directly required, pinned by Snyk to avoid a vulnerability
idna>=3.7 # not directly required, pinned by Snyk to avoid a vulnerability
65 changes: 37 additions & 28 deletions packages/syft/src/syft/client/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import types
from typing import Any
from typing import TYPE_CHECKING
from typing import _GenericAlias
from typing import cast
from typing import get_args
from typing import get_origin
Expand All @@ -19,12 +18,9 @@
from nacl.exceptions import BadSignatureError
from pydantic import BaseModel
from pydantic import ConfigDict
from pydantic import EmailStr
from pydantic import TypeAdapter
from result import OkErr
from result import Result
from typeguard import TypeCheckError
from typeguard import check_type

# relative
from ..abstract_server import AbstractServer
Expand All @@ -47,6 +43,8 @@
from ..service.response import SyftSuccess
from ..service.service import UserLibConfigRegistry
from ..service.service import UserServiceConfigRegistry
from ..service.service import _format_signature
from ..service.service import _signature_error_message
from ..service.user.user_roles import ServiceRole
from ..service.warnings import APIEndpointWarning
from ..service.warnings import WarningContext
Expand Down Expand Up @@ -97,6 +95,23 @@ def _has_config_dict(t: Any) -> bool:
)


_config_dict = ConfigDict(arbitrary_types_allowed=True)


def _check_type(v: object, t: Any) -> Any:
# TypeAdapter only accepts `config` arg if `t` does not
# already contain a ConfigDict
# i.e model_config in BaseModel and __pydantic_config__ in
# other types.
type_adapter = (
TypeAdapter(t, config=_config_dict)
if not _has_config_dict(t)
else TypeAdapter(t)
)

return type_adapter.validate_python(v)


class APIRegistry:
__api_registry__: dict[tuple, SyftAPI] = OrderedDict()

Expand Down Expand Up @@ -1308,7 +1323,10 @@ def validate_callable_args_and_kwargs(
for key, value in kwargs.items():
if key not in signature.parameters:
return SyftError(
message=f"""Invalid parameter: `{key}`. Valid Parameters: {list(signature.parameters)}"""
message=(
f"Invalid parameter: `{key}`.\n"
f"{_signature_error_message(_format_signature(signature))}"
)
)
param = signature.parameters[key]
if isinstance(param.annotation, str):
Expand All @@ -1320,21 +1338,15 @@ def validate_callable_args_and_kwargs(

if t is not inspect.Parameter.empty:
try:
config_kw = (
{"config": ConfigDict(arbitrary_types_allowed=True)}
if not _has_config_dict(t)
else {}
)

# TypeAdapter only accepts `config` arg if `t` does not
# already contain a ConfigDict
# i.e model_config in BaseModel and __pydantic_config__ in
# other types.
TypeAdapter(t, **config_kw).validate_python(value)
except Exception:
_check_type(value, t)
except ValueError:
_type_str = getattr(t, "__name__", str(t))

return SyftError(
message=f"`{key}` must be of type `{_type_str}` not `{type(value).__name__}`"
message=(
f"`{key}` must be of type `{_type_str}` not `{type(value).__name__}`\n"
f"{_signature_error_message(_format_signature(signature))}"
)
)

_valid_kwargs[key] = value
Expand All @@ -1353,15 +1365,8 @@ def validate_callable_args_and_kwargs(
msg = None
try:
if t is not inspect.Parameter.empty:
if isinstance(t, _GenericAlias) and type(None) in t.__args__:
for v in t.__args__:
if issubclass(v, EmailStr):
v = str
check_type(arg, v) # raises Exception
break # only need one to match
else:
check_type(arg, t) # raises Exception
except TypeCheckError:
_check_type(arg, t)
except ValueError:
t_arg = type(arg)
if (
autoreload_enabled()
Expand All @@ -1372,7 +1377,11 @@ def validate_callable_args_and_kwargs(
pass
else:
_type_str = getattr(t, "__name__", str(t))
msg = f"Arg is `{arg}`. \nIt must be of type `{_type_str}`, not `{type(arg).__name__}`"

msg = (
f"Arg is `{arg}`. \nIt must be of type `{_type_str}`, not `{type(arg).__name__}`\n"
f"{_signature_error_message(_format_signature(signature))}"
)

if msg:
return SyftError(message=msg)
Expand Down
32 changes: 17 additions & 15 deletions packages/syft/src/syft/service/action/action_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -913,21 +913,23 @@ def syft_lineage_id(self) -> LineageID:

@model_validator(mode="before")
@classmethod
def __check_action_data(cls, values: dict) -> dict:
v = values.get("syft_action_data_cache")
if values.get("syft_action_data_type", None) is None:
values["syft_action_data_type"] = type(v)
if not isinstance(v, ActionDataEmpty):
if inspect.isclass(v):
values["syft_action_data_repr_"] = truncate_str(repr_cls(v))
else:
values["syft_action_data_repr_"] = truncate_str(
v._repr_markdown_()
if v is not None and hasattr(v, "_repr_markdown_")
else v.__repr__()
)
values["syft_action_data_str_"] = truncate_str(str(v))
values["syft_has_bool_attr"] = hasattr(v, "__bool__")
def __check_action_data(cls, values: Any) -> dict:
if isinstance(values, dict):
v = values.get("syft_action_data_cache")
if values.get("syft_action_data_type", None) is None:
values["syft_action_data_type"] = type(v)
if not isinstance(v, ActionDataEmpty):
if inspect.isclass(v):
values["syft_action_data_repr_"] = truncate_str(repr_cls(v))
else:
values["syft_action_data_repr_"] = truncate_str(
v._repr_markdown_()
if v is not None and hasattr(v, "_repr_markdown_")
else v.__repr__()
)
values["syft_action_data_str_"] = truncate_str(str(v))
values["syft_has_bool_attr"] = hasattr(v, "__bool__")

return values

@property
Expand Down
84 changes: 73 additions & 11 deletions packages/syft/src/syft/service/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,22 @@
# stdlib
from collections import defaultdict
from collections.abc import Callable
from collections.abc import Iterable
from copy import deepcopy
import functools
from functools import partial
from functools import reduce
import inspect
from inspect import Parameter
import logging
import operator
import types
import typing
from typing import Any
from typing import TYPE_CHECKING

# third party
from pydantic import ValidationError
from result import Ok
from result import OkErr
from typing_extensions import Self
Expand All @@ -34,6 +40,9 @@
from ..server.credentials import SyftVerifyKey
from ..store.document_store import DocumentStore
from ..store.linked_obj import LinkedObject
from ..types.syft_metaclass import Empty
from ..types.syft_metaclass import EmptyType
from ..types.syft_object import EXCLUDED_FROM_SIGNATURE
from ..types.syft_object import SYFT_OBJECT_VERSION_1
from ..types.syft_object import SyftBaseObject
from ..types.syft_object import SyftObject
Expand Down Expand Up @@ -260,16 +269,59 @@ def deconstruct_param(param: inspect.Parameter) -> dict[str, Any]:


def types_for_autosplat(signature: Signature, autosplat: list[str]) -> dict[str, type]:
autosplat_types = {}
for k, v in signature.parameters.items():
if k in autosplat:
autosplat_types[k] = v.annotation
return autosplat_types
return {k: v.annotation for k, v in signature.parameters.items() if k in autosplat}


def _check_empty_union(x: Any) -> bool:
return isinstance(
x, typing._UnionGenericAlias | types.UnionType
) and EmptyType in typing.get_args(x)


def _check_empty_parameter(p: Parameter) -> bool:
return _check_empty_union(p.annotation) and p.default is Empty


def _make_union_type(args: Iterable) -> types.UnionType:
return reduce(operator.or_, args)


def _replace_empty_parameter(p: Parameter) -> Parameter:
return Parameter(
name=p.name,
default="optional",
annotation=_make_union_type(
t for t in typing.get_args(p.annotation) if t is not EmptyType
),
kind=p.kind,
)


def _format_signature(s: inspect.Signature) -> inspect.Signature:
params = (
(_replace_empty_parameter(p) if _check_empty_parameter(p) else p)
for p in s.parameters.values()
)

return inspect.Signature(
parameters=params,
return_annotation=inspect.Signature.empty,
)


_SIGNATURE_ERROR_MESSAGE = (
"Please provide the correct arguments to the method according to this signature"
)


def _signature_error_message(s: inspect.Signature) -> str:
return f"{_SIGNATURE_ERROR_MESSAGE}\n{s}"


def reconstruct_args_kwargs(
signature: Signature,
autosplat: list[str],
expanded_signature: Signature,
args: tuple[Any, ...],
kwargs: dict[Any, str],
) -> tuple[tuple[Any, ...], dict[str, Any]]:
Expand All @@ -282,7 +334,13 @@ def reconstruct_args_kwargs(
for key in keys:
if key in kwargs:
init_kwargs[key] = kwargs.pop(key)
autosplat_objs[autosplat_key] = autosplat_type(**init_kwargs)
try:
autosplat_objs[autosplat_key] = autosplat_type(**init_kwargs)
except ValidationError:
raise TypeError(
f"Invalid argument(s) provided. "
f"{_signature_error_message(_format_signature(expanded_signature))}"
)

final_kwargs = {}
for param_key, param in signature.parameters.items():
Expand All @@ -293,7 +351,10 @@ def reconstruct_args_kwargs(
elif not isinstance(param.default, type(Parameter.empty)):
final_kwargs[param_key] = param.default
else:
raise Exception(f"Missing {param_key} not in kwargs.")
raise TypeError(
f"Missing argument {param_key}."
f"{_signature_error_message(_format_signature(expanded_signature))}"
)

if "context" in kwargs:
final_kwargs["context"] = kwargs["context"]
Expand All @@ -320,7 +381,7 @@ def expand_signature(signature: Signature, autosplat: list[str]) -> Signature:

# Reorder the parameter based on if they have default value or not
new_params = sorted(
new_mapping.values(),
(v for k, v in new_mapping.items() if k not in EXCLUDED_FROM_SIGNATURE),
key=lambda param: param.default is param.empty,
reverse=True,
)
Expand Down Expand Up @@ -354,6 +415,9 @@ def wrapper(func: Any) -> Callable:

input_signature = deepcopy(signature)

if autosplat is not None and len(autosplat) > 0:
signature = expand_signature(signature=input_signature, autosplat=autosplat)

@functools.wraps(func)
def _decorator(self: Any, *args: Any, **kwargs: Any) -> Callable:
communication_protocol = kwargs.pop("communication_protocol", None)
Expand All @@ -366,6 +430,7 @@ def _decorator(self: Any, *args: Any, **kwargs: Any) -> Callable:
args, kwargs = reconstruct_args_kwargs(
signature=input_signature,
autosplat=autosplat,
expanded_signature=signature,
args=args,
kwargs=kwargs,
)
Expand All @@ -386,9 +451,6 @@ def _decorator(self: Any, *args: Any, **kwargs: Any) -> Callable:
attach_attribute_to_syft_object(result=result, attr_dict=attrs_to_attach)
return result

if autosplat is not None and len(autosplat) > 0:
signature = expand_signature(signature=input_signature, autosplat=autosplat)

config = ServiceConfig(
public_path=_path if path is None else path,
private_path=_path,
Expand Down
13 changes: 3 additions & 10 deletions packages/syft/src/syft/service/user/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from bcrypt import hashpw
from pydantic import EmailStr
from pydantic import ValidationError
from pydantic import field_validator

# relative
from ...client.api import APIRegistry
Expand Down Expand Up @@ -188,13 +187,6 @@ class UserUpdate(PartialSyftObject):
__canonical_name__ = "UserUpdate"
__version__ = SYFT_OBJECT_VERSION_1

@field_validator("role", mode="before")
@classmethod
def str_to_role(cls, v: Any) -> Any:
if isinstance(v, str) and hasattr(ServiceRole, v.upper()):
return getattr(ServiceRole, v.upper())
return v

email: EmailStr
name: str
role: ServiceRole # make sure role cant be set without uid
Expand Down Expand Up @@ -344,14 +336,15 @@ def update(
)
if api is None:
return SyftError(message=f"You must login to {self.server_uid}")
user_update = UserUpdate(

result = api.services.user.update(
uid=self.id,
name=name,
institution=institution,
website=website,
role=role,
mock_execution_permission=mock_execution_permission,
)
result = api.services.user.update(uid=self.id, **user_update)

if isinstance(result, SyftError):
return result
Expand Down
Loading

0 comments on commit a61089f

Please sign in to comment.