diff --git a/hydra/garaga/starknet/groth16_contract_generator/parsing_utils.py b/hydra/garaga/starknet/groth16_contract_generator/parsing_utils.py index a1353ad3..f062096f 100644 --- a/hydra/garaga/starknet/groth16_contract_generator/parsing_utils.py +++ b/hydra/garaga/starknet/groth16_contract_generator/parsing_utils.py @@ -1,3 +1,4 @@ +import codecs import dataclasses import hashlib import json @@ -23,6 +24,12 @@ ) +class KeyPatternNotFound(Exception): + def __init__(self, key_patterns): + super().__init__(f"No key found with patterns {key_patterns}") + self.key_patterns = key_patterns + + def iterate_nested_dict(d): for key, value in d.items(): if isinstance(value, dict): @@ -49,7 +56,7 @@ def find_item_from_key_patterns(data: dict, key_patterns: List[str]) -> Any: if best_match is not None: return best_match else: - raise ValueError(f"No key found with patterns {key_patterns}") + raise KeyPatternNotFound(key_patterns) def try_parse_g1_point_from_key( @@ -151,7 +158,7 @@ def try_parse_g2_point(point: Any, curve_id: CurveID = None) -> G2Point: def try_guessing_curve_id_from_json(data: dict) -> CurveID: try: curve_id = CurveID.from_str(find_item_from_key_patterns(data, ["curve"])) - except (ValueError, KeyError): + except (ValueError, KeyError, KeyPatternNotFound): # Try guessing the curve id from the bit size of the first found integer in the json. x = None for value in iterate_nested_dict(data): @@ -201,7 +208,7 @@ def from_dict(data: dict) -> "Groth16VerifyingKey": curve_id = try_guessing_curve_id_from_json(data) try: verifying_key = find_item_from_key_patterns(data, ["verifying_key"]) - except ValueError: + except KeyPatternNotFound: verifying_key = data try: return Groth16VerifyingKey( @@ -220,7 +227,7 @@ def from_dict(data: dict) -> "Groth16VerifyingKey": for point in find_item_from_key_patterns(verifying_key, ["ic"]) ], ) - except ValueError: + except (ValueError, KeyPatternNotFound): # Gnark case. g1_points = find_item_from_key_patterns(verifying_key, ["g1"]) g2_points = find_item_from_key_patterns(verifying_key, ["g2"]) @@ -318,23 +325,26 @@ def from_dict( curve_id = try_guessing_curve_id_from_json(data) try: proof = find_item_from_key_patterns(data, ["proof"]) - except ValueError: + except KeyPatternNotFound: proof = data try: - seal = io.to_hex_str(find_item_from_key_patterns(data, ["seal"])) - image_id = io.to_hex_str(find_item_from_key_patterns(data, ["image_id"])) - journal = io.to_hex_str(find_item_from_key_patterns(data, ["journal"])) + seal = find_item_from_key_patterns(data, ["seal"]) + image_id = find_item_from_key_patterns(data, ["image_id"]) + journal = find_item_from_key_patterns(data, ["journal"]) + print("Seal: {}\nImage_id: {}\nJournal: {}".format(seal, image_id, journal)) return Groth16Proof._from_risc0( - seal=bytes.fromhex(seal[2:]), - image_id=bytes.fromhex(image_id[2:]), - journal=bytes.fromhex(journal[2:]), + seal=codecs.decode(seal[2:].replace("\\x", ""), "hex"), + image_id=codecs.decode(image_id[2:].replace("\\x", ""), "hex"), + journal=codecs.decode(journal[2:].replace("\\x", ""), "hex"), ) except ValueError: pass except KeyError: pass + except KeyPatternNotFound: + pass except Exception as e: print(f"Error: {e}") raise e @@ -347,7 +357,11 @@ def from_dict( else: raise ValueError(f"Invalid public inputs format: {public_inputs}") else: - public_inputs = find_item_from_key_patterns(data, ["public"]) + try: + public_inputs = find_item_from_key_patterns(data, ["public"]) + except KeyPatternNotFound as e: + print(f"Error: {e}") + raise e return Groth16Proof( a=try_parse_g1_point_from_key(proof, ["a"], curve_id), b=try_parse_g2_point_from_key(proof, ["b"], curve_id),