diff --git a/packages/syft/src/syft/protocol/data_protocol.py b/packages/syft/src/syft/protocol/data_protocol.py index 0f050949e0d..c49da061573 100644 --- a/packages/syft/src/syft/protocol/data_protocol.py +++ b/packages/syft/src/syft/protocol/data_protocol.py @@ -95,34 +95,38 @@ def _hash_to_sha256(obj_dict: Dict) -> str: def build_state(self, stop_key: Optional[str] = None) -> dict: sorted_dict = sort_dict_naturally(self.protocol_history) state_dict = defaultdict(dict) - for k, _v in sorted_dict.items(): - object_versions = sorted_dict[k]["object_versions"] + for protocol_number in sorted_dict: + object_versions = sorted_dict[protocol_number]["object_versions"] for canonical_name, versions in object_versions.items(): for version, object_metadata in versions.items(): action = object_metadata["action"] version = object_metadata["version"] hash_str = object_metadata["hash"] state_versions = state_dict[canonical_name] + state_version_hashes = [val[0] for val in state_versions.values()] if action == "add" and ( str(version) in state_versions.keys() - or hash_str in state_versions.values() + or hash_str in state_version_hashes ): raise Exception( f"Can't add {object_metadata} already in state {versions}" ) elif action == "remove" and ( str(version) not in state_versions.keys() - or hash_str not in state_versions.values() + and hash_str not in state_version_hashes ): raise Exception( f"Can't remove {object_metadata} missing from state {versions} for object {canonical_name}." ) if action == "add": - state_dict[canonical_name][str(version)] = hash_str + state_dict[canonical_name][str(version)] = ( + hash_str, + protocol_number, + ) elif action == "remove": del state_dict[canonical_name][str(version)] # stop early - if stop_key == k: + if stop_key == protocol_number: return state_dict return state_dict @@ -152,7 +156,7 @@ def diff_state(self, state: dict) -> tuple[dict, dict]: if canonical_name not in state: # new object so its an add object_diff[canonical_name][str(version)] = {} - object_diff[canonical_name][str(version)]["version"] = version + object_diff[canonical_name][str(version)]["version"] = int(version) object_diff[canonical_name][str(version)]["hash"] = hash_str object_diff[canonical_name][str(version)]["action"] = "add" continue @@ -160,33 +164,45 @@ def diff_state(self, state: dict) -> tuple[dict, dict]: versions = state[canonical_name] if ( str(version) in versions.keys() - and versions[str(version)] == hash_str + and versions[str(version)][0] == hash_str ): # already there so do nothing continue elif str(version) in versions.keys(): + is_protocol_dev = versions[str(version)][1] == "dev" + if is_protocol_dev: + # force overwrite existing object so its an add + object_diff[canonical_name][str(version)] = {} + object_diff[canonical_name][str(version)]["version"] = int( + version + ) + object_diff[canonical_name][str(version)]["hash"] = hash_str + object_diff[canonical_name][str(version)]["action"] = "add" + continue + raise Exception( f"{canonical_name} for class {cls.__name__} fqn {cls} " + f"version {version} hash has changed. " + f"{hash_str} not in {versions.values()}. " + "Is a unique __canonical_name__ for this subclass missing? " - + "If the class has changed you will need to bump the version number." + + "If the class has changed you will need to define a new class with the changes, " + + "with same __canonical_name__ and bump the __version__ number." ) else: # new object so its an add object_diff[canonical_name][str(version)] = {} - object_diff[canonical_name][str(version)]["version"] = version + object_diff[canonical_name][str(version)]["version"] = int(version) object_diff[canonical_name][str(version)]["hash"] = hash_str object_diff[canonical_name][str(version)]["action"] = "add" continue # now check for remove actions for canonical_name in state: - for version, hash_str in state[canonical_name].items(): + for version, (hash_str, _) in state[canonical_name].items(): if canonical_name not in compare_dict: # missing so its a remove object_diff[canonical_name][str(version)] = {} - object_diff[canonical_name][str(version)]["version"] = version + object_diff[canonical_name][str(version)]["version"] = int(version) object_diff[canonical_name][str(version)]["hash"] = hash_str object_diff[canonical_name][str(version)]["action"] = "remove" continue @@ -194,7 +210,7 @@ def diff_state(self, state: dict) -> tuple[dict, dict]: if str(version) not in versions.keys(): # missing so its a remove object_diff[canonical_name][str(version)] = {} - object_diff[canonical_name][str(version)]["version"] = version + object_diff[canonical_name][str(version)]["version"] = int(version) object_diff[canonical_name][str(version)]["hash"] = hash_str object_diff[canonical_name][str(version)]["action"] = "remove" continue @@ -212,7 +228,27 @@ def stage_protocol_changes(self) -> Result[SyftSuccess, SyftError]: if canonical_name not in object_versions: object_versions[canonical_name] = {} change_count += 1 - object_versions[canonical_name][version] = version_metadata + action = version_metadata["action"] + + # Allow removal of class that only been staged to dev + if ( + action == "remove" + and str(version) in object_versions[canonical_name] + ): + # Delete the whole class if only single version exists + if len(object_versions[canonical_name]) == 1: + del object_versions[canonical_name] + else: + # In case of multiple versions of the class only delete the selected + del object_versions[canonical_name][str(version)] + + else: # Add or overwrite existing data in dev + object_versions[canonical_name][str(version)] = version_metadata + + # Sort the version dict + object_versions[canonical_name] = sort_dict_naturally( + object_versions[canonical_name] + ) current_history["dev"]["object_versions"] = object_versions