Skip to content

Commit

Permalink
Merge pull request #8193 from OpenMined/fix-protocol-bug
Browse files Browse the repository at this point in the history
Fix protocol version bug
  • Loading branch information
shubham3121 authored Oct 29, 2023
2 parents 08f0658 + db811a2 commit 258d72d
Showing 1 changed file with 50 additions and 14 deletions.
64 changes: 50 additions & 14 deletions packages/syft/src/syft/protocol/data_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,34 +95,38 @@ def _hash_to_sha256(obj_dict: Dict) -> str:
def build_state(self, stop_key: Optional[str] = None) -> dict:
sorted_dict = sort_dict_naturally(self.protocol_history)
state_dict = defaultdict(dict)
for k, _v in sorted_dict.items():
object_versions = sorted_dict[k]["object_versions"]
for protocol_number in sorted_dict:
object_versions = sorted_dict[protocol_number]["object_versions"]
for canonical_name, versions in object_versions.items():
for version, object_metadata in versions.items():
action = object_metadata["action"]
version = object_metadata["version"]
hash_str = object_metadata["hash"]
state_versions = state_dict[canonical_name]
state_version_hashes = [val[0] for val in state_versions.values()]
if action == "add" and (
str(version) in state_versions.keys()
or hash_str in state_versions.values()
or hash_str in state_version_hashes
):
raise Exception(
f"Can't add {object_metadata} already in state {versions}"
)
elif action == "remove" and (
str(version) not in state_versions.keys()
or hash_str not in state_versions.values()
and hash_str not in state_version_hashes
):
raise Exception(
f"Can't remove {object_metadata} missing from state {versions} for object {canonical_name}."
)
if action == "add":
state_dict[canonical_name][str(version)] = hash_str
state_dict[canonical_name][str(version)] = (
hash_str,
protocol_number,
)
elif action == "remove":
del state_dict[canonical_name][str(version)]
# stop early
if stop_key == k:
if stop_key == protocol_number:
return state_dict
return state_dict

Expand Down Expand Up @@ -152,49 +156,61 @@ def diff_state(self, state: dict) -> tuple[dict, dict]:
if canonical_name not in state:
# new object so its an add
object_diff[canonical_name][str(version)] = {}
object_diff[canonical_name][str(version)]["version"] = version
object_diff[canonical_name][str(version)]["version"] = int(version)
object_diff[canonical_name][str(version)]["hash"] = hash_str
object_diff[canonical_name][str(version)]["action"] = "add"
continue

versions = state[canonical_name]
if (
str(version) in versions.keys()
and versions[str(version)] == hash_str
and versions[str(version)][0] == hash_str
):
# already there so do nothing
continue
elif str(version) in versions.keys():
is_protocol_dev = versions[str(version)][1] == "dev"
if is_protocol_dev:
# force overwrite existing object so its an add
object_diff[canonical_name][str(version)] = {}
object_diff[canonical_name][str(version)]["version"] = int(
version
)
object_diff[canonical_name][str(version)]["hash"] = hash_str
object_diff[canonical_name][str(version)]["action"] = "add"
continue

raise Exception(
f"{canonical_name} for class {cls.__name__} fqn {cls} "
+ f"version {version} hash has changed. "
+ f"{hash_str} not in {versions.values()}. "
+ "Is a unique __canonical_name__ for this subclass missing? "
+ "If the class has changed you will need to bump the version number."
+ "If the class has changed you will need to define a new class with the changes, "
+ "with same __canonical_name__ and bump the __version__ number."
)
else:
# new object so its an add
object_diff[canonical_name][str(version)] = {}
object_diff[canonical_name][str(version)]["version"] = version
object_diff[canonical_name][str(version)]["version"] = int(version)
object_diff[canonical_name][str(version)]["hash"] = hash_str
object_diff[canonical_name][str(version)]["action"] = "add"
continue

# now check for remove actions
for canonical_name in state:
for version, hash_str in state[canonical_name].items():
for version, (hash_str, _) in state[canonical_name].items():
if canonical_name not in compare_dict:
# missing so its a remove
object_diff[canonical_name][str(version)] = {}
object_diff[canonical_name][str(version)]["version"] = version
object_diff[canonical_name][str(version)]["version"] = int(version)
object_diff[canonical_name][str(version)]["hash"] = hash_str
object_diff[canonical_name][str(version)]["action"] = "remove"
continue
versions = compare_dict[canonical_name]
if str(version) not in versions.keys():
# missing so its a remove
object_diff[canonical_name][str(version)] = {}
object_diff[canonical_name][str(version)]["version"] = version
object_diff[canonical_name][str(version)]["version"] = int(version)
object_diff[canonical_name][str(version)]["hash"] = hash_str
object_diff[canonical_name][str(version)]["action"] = "remove"
continue
Expand All @@ -212,7 +228,27 @@ def stage_protocol_changes(self) -> Result[SyftSuccess, SyftError]:
if canonical_name not in object_versions:
object_versions[canonical_name] = {}
change_count += 1
object_versions[canonical_name][version] = version_metadata
action = version_metadata["action"]

# Allow removal of class that only been staged to dev
if (
action == "remove"
and str(version) in object_versions[canonical_name]
):
# Delete the whole class if only single version exists
if len(object_versions[canonical_name]) == 1:
del object_versions[canonical_name]
else:
# In case of multiple versions of the class only delete the selected
del object_versions[canonical_name][str(version)]

else: # Add or overwrite existing data in dev
object_versions[canonical_name][str(version)] = version_metadata

# Sort the version dict
object_versions[canonical_name] = sort_dict_naturally(
object_versions[canonical_name]
)

current_history["dev"]["object_versions"] = object_versions

Expand Down

0 comments on commit 258d72d

Please sign in to comment.