Skip to content

Commit

Permalink
Merge pull request #8629 from OpenMined/fix_hash_forwardref
Browse files Browse the repository at this point in the history
Changed hash computation for Pydantic2 and ForwardRef
  • Loading branch information
shubham3121 authored Mar 28, 2024
2 parents 2a90040 + 9550748 commit a3be8a2
Show file tree
Hide file tree
Showing 13 changed files with 1,761 additions and 1,718 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/pr-tests-hagrid.yml
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ jobs:
if: steps.changes.outputs.hagrid == 'true'
run: |
bandit -r hagrid
safety check -i 42923 -i 54229 -i 54230 -i 54230 -i 54229 -i 62044 -i 65213
safety check -i 42923 -i 54229 -i 54230 -i 54230 -i 54229 -i 62044 -i 65213 -i 54564
- name: Run normal tests
if: steps.changes.outputs.hagrid == 'true'
Expand Down
10 changes: 5 additions & 5 deletions .github/workflows/pr-tests-stack.yml
Original file line number Diff line number Diff line change
Expand Up @@ -293,11 +293,11 @@ jobs:
run: |
pip install --upgrade tox tox-uv==1.5.1
- name: Run syft backend base image building test
if: steps.changes.outputs.stack == 'true'
timeout-minutes: 60
run: |
tox -e backend.test.basecpu
# - name: Run syft backend base image building test
# if: steps.changes.outputs.stack == 'true'
# timeout-minutes: 60
# run: |
# tox -e backend.test.basecpu

pr-tests-notebook-stack:
strategy:
Expand Down
7 changes: 7 additions & 0 deletions .github/workflows/pr-tests-syft.yml
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ jobs:
python-version: ["3.12"]
deployment-type: ["python"]
notebook-paths: ["tutorials"]
bump-version: ["False"]
include:
- python-version: "3.11"
os: "ubuntu-latest"
Expand All @@ -119,6 +120,11 @@ jobs:
os: "ubuntu-latest"
deployment-type: "python"
notebook-paths: "tutorials"
- python-version: "3.12"
os: "ubuntu-latest"
deployment-type: "python"
notebook-paths: "tutorials"
bump-version: "True"

runs-on: ${{ matrix.os }}
steps:
Expand Down Expand Up @@ -183,6 +189,7 @@ jobs:
env:
ORCHESTRA_DEPLOYMENT_TYPE: "${{ matrix.deployment-type }}"
TEST_NOTEBOOK_PATHS: "${{ matrix.notebook-paths }}"
BUMP_VERSION: "${{ matrix.bump-version }}"
with:
timeout_seconds: 2400
max_attempts: 3
Expand Down
18 changes: 6 additions & 12 deletions notebooks/api/0.8/06-multiple-code-requests.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -478,7 +478,8 @@
"outputs": [],
"source": [
"datasets = ds_client.datasets.search(name=\"My Sample Dataset - II\")\n",
"dataset_ptr2 = datasets[0]"
"dataset_ptr2 = datasets[0]\n",
"dataset_ptr2"
]
},
{
Expand All @@ -489,7 +490,8 @@
"outputs": [],
"source": [
"# Validate if input policy is violated\n",
"sum_ptr = ds_client.code.calculate_sum(data=dataset_ptr2.assets[0])"
"sum_ptr = ds_client.code.calculate_sum(data=dataset_ptr2.assets[0])\n",
"sum_ptr"
]
},
{
Expand All @@ -499,7 +501,7 @@
"metadata": {},
"outputs": [],
"source": [
"assert isinstance(sum_ptr, sy.SyftError), sum_ptr"
"assert isinstance(sum_ptr, sy.SyftError), (sum_ptr, str(dataset_ptr2.assets[0]))"
]
},
{
Expand Down Expand Up @@ -547,17 +549,9 @@
},
"outputs": [],
"source": [
"if node.node_type.value == \"python\":\n",
"if node.deployment_type.value in [\"python\", \"single_container\"]:\n",
" node.land()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "40",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
Expand Down
12 changes: 6 additions & 6 deletions notebooks/api/0.8/10-container-images.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,8 @@
"metadata": {},
"outputs": [],
"source": [
"custom_dockerfile_str = f\"\"\"\n",
"FROM openmined/grid-backend:{syft_base_worker_tag}\n",
"custom_dockerfile_str = \"\"\"\n",
"FROM openmined/grid-backend:0.8.5-beta.10\n",
"\n",
"RUN pip install pydicom\n",
"\n",
Expand Down Expand Up @@ -1108,8 +1108,8 @@
"metadata": {},
"outputs": [],
"source": [
"custom_dockerfile_str_2 = f\"\"\"\n",
"FROM openmined/grid-backend:{syft_base_worker_tag}\n",
"custom_dockerfile_str_2 = \"\"\"\n",
"FROM openmined/grid-backend:0.8.5-beta.10\n",
"\n",
"RUN pip install opendp\n",
"\"\"\".strip()\n",
Expand Down Expand Up @@ -1260,8 +1260,8 @@
"metadata": {},
"outputs": [],
"source": [
"custom_dockerfile_str_3 = f\"\"\"\n",
"FROM openmined/grid-backend:{syft_base_worker_tag}\n",
"custom_dockerfile_str_3 = \"\"\"\n",
"FROM openmined/grid-backend:0.8.5-beta.10\n",
"\n",
"RUN pip install recordlinkage\n",
"\"\"\".strip()\n",
Expand Down
2 changes: 1 addition & 1 deletion packages/hagrid/hagrid/orchestra.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,8 +487,8 @@ def launch(
in_memory_workers: bool = True,
) -> NodeHandle | None:
NodeType = ImportFromSyft.import_node_type()
os.environ["DEV_MODE"] = str(dev_mode)
if dev_mode is True:
os.environ["DEV_MODE"] = "True"
thread_workers = True

# syft 0.8.1
Expand Down
7 changes: 1 addition & 6 deletions packages/syft/src/syft/node/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@
from ..types.uid import UID
from ..util.experimental_flags import flags
from ..util.telemetry import instrument
from ..util.util import get_dev_mode
from ..util.util import get_env
from ..util.util import get_queue_address
from ..util.util import random_name
Expand Down Expand Up @@ -178,10 +179,6 @@ def get_default_root_password() -> str | None:
return get_env(DEFAULT_ROOT_PASSWORD, "changethis") # nosec


def get_dev_mode() -> bool:
    """Report whether the DEV_MODE environment flag is enabled (off by default)."""
    raw_flag = get_env("DEV_MODE", "False")
    return str_to_bool(raw_flag)


def get_enable_warnings() -> bool:
    """Report whether the ENABLE_WARNINGS environment flag is enabled (off by default)."""
    raw_flag = get_env("ENABLE_WARNINGS", "False")
    return str_to_bool(raw_flag)

Expand Down Expand Up @@ -1427,8 +1424,6 @@ def get_unauthed_context(
return UnauthedServiceContext(node=self, login_credentials=login_credentials)

def create_initial_settings(self, admin_email: str) -> NodeSettingsV2 | None:
if self.name is None:
self.name = random_name()
try:
settings_stash = SettingsStash(store=self.document_store)
if self.signing_key is None:
Expand Down
58 changes: 51 additions & 7 deletions packages/syft/src/syft/protocol/data_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@
import os
from pathlib import Path
import re
from types import UnionType
import typing
from typing import Any
import warnings

# third party
from packaging.version import parse
Expand All @@ -24,6 +27,7 @@
from ..service.response import SyftSuccess
from ..types.dicttuple import DictTuple
from ..types.syft_object import SyftBaseObject
from ..util.util import get_dev_mode

PROTOCOL_STATE_FILENAME = "protocol_version.json"
PROTOCOL_TYPE = str | int
Expand Down Expand Up @@ -53,9 +57,36 @@ def protocol_release_dir() -> Path:
return data_protocol_dir() / "releases"


def handle_union_type_klass_name(type_klass_name: str) -> str:
if type_klass_name == typing.Union.__name__:
return UnionType.__name__
return type_klass_name


def handle_annotation_repr_(annotation: type) -> str:
"""Handle typing representation."""
origin = typing.get_origin(annotation)
args = typing.get_args(annotation)
if origin and args:
args_repr = ", ".join(getattr(arg, "__name__", str(arg)) for arg in args)
origin_repr = getattr(origin, "__name__", str(origin))

# Handle typing.Union and types.UnionType
origin_repr = handle_union_type_klass_name(origin_repr)
return f"{origin_repr}: [{args_repr}]"
elif args:
args_repr = ", ".join(
getattr(arg, "__name__", str(arg)) for arg in sorted(args)
)
return args_repr
else:
return repr(annotation)


class DataProtocol:
def __init__(self, filename: str) -> None:
    def __init__(self, filename: str, raise_exception: bool = False) -> None:
        # filename: name of the protocol-state JSON file, resolved under
        # data_protocol_dir().
        # raise_exception: when True, a protocol-hash mismatch found while
        # diffing raises instead of only warning (see the
        # get_dev_mode()/self.raise_exception check in diff_state).
        self.file_path = data_protocol_dir() / filename
        self.raise_exception = raise_exception
        self.load_state()

def load_state(self) -> None:
Expand All @@ -67,8 +98,12 @@ def load_state(self) -> None:
@staticmethod
def _calculate_object_hash(klass: type[SyftBaseObject]) -> str:
# TODO: this depends on what is marked as serde

# Rebuild the model to ensure that the fields are up to date
# and any ForwardRef are resolved
klass.model_rebuild()
field_data = {
field: repr(field_info.annotation)
field: handle_annotation_repr_(field_info.rebuild_annotation())
for field, field_info in sorted(
klass.model_fields.items(), key=itemgetter(0)
)
Expand Down Expand Up @@ -211,14 +246,20 @@ def diff_state(self, state: dict) -> tuple[dict, dict]:
object_diff[canonical_name][str(version)]["action"] = "add"
continue

raise Exception(
error_msg = (
f"{canonical_name} for class {cls.__name__} fqn {cls} "
+ f"version {version} hash has changed. "
+ f"{hash_str} not in {versions.values()}. "
+ "Is a unique __canonical_name__ for this subclass missing? "
+ "If the class has changed you will need to define a new class with the changes, "
+ "with same __canonical_name__ and bump the __version__ number."
)

if get_dev_mode() or self.raise_exception:
raise Exception(error_msg)
else:
warnings.warn(error_msg, stacklevel=1, category=UserWarning)
break
else:
# new object so its an add
object_diff[canonical_name][str(version)] = {}
Expand Down Expand Up @@ -463,17 +504,20 @@ def has_dev(self) -> bool:
return False


# NOTE(review): this is the pre-change side of the diff view; the current
# version of get_data_protocol (taking ``raise_exception``) follows it.
def get_data_protocol() -> DataProtocol:
    return DataProtocol(filename=data_protocol_file_name())
def get_data_protocol(raise_exception: bool = False) -> DataProtocol:
    """Construct the DataProtocol over the canonical protocol-state file.

    Args:
        raise_exception: forwarded to DataProtocol; when True, protocol
            hash mismatches raise instead of only warning.
    """
    protocol_filename = data_protocol_file_name()
    return DataProtocol(filename=protocol_filename, raise_exception=raise_exception)


def stage_protocol_changes() -> Result[SyftSuccess, SyftError]:
    # NOTE(review): the first assignment below is dead — it is immediately
    # overwritten by the next line. These look like the old/new sides of a
    # diff rendered together; only the raise_exception=True call takes
    # effect. Confirm against the real file before relying on this.
    data_protocol = get_data_protocol()
    data_protocol = get_data_protocol(raise_exception=True)
    return data_protocol.stage_protocol_changes()


def bump_protocol_version() -> Result[SyftSuccess, SyftError]:
    # NOTE(review): the first assignment below is dead — it is immediately
    # overwritten by the next line. These look like the old/new sides of a
    # diff rendered together; only the raise_exception=True call takes
    # effect. Confirm against the real file before relying on this.
    data_protocol = get_data_protocol()
    data_protocol = get_data_protocol(raise_exception=True)
    return data_protocol.bump_protocol_version()


Expand Down
Loading

0 comments on commit a3be8a2

Please sign in to comment.