Skip to content

Commit

Permalink
Improved RiscZero error handling and parsing in parsing_utils.py (#233)
Browse files Browse the repository at this point in the history
Co-authored-by: Stefan Madzharov <[email protected]>
  • Loading branch information
stefanMadzharov and StefanMadzh authored Oct 18, 2024
1 parent 39c8d21 commit 1018d21
Showing 1 changed file with 26 additions and 12 deletions.
38 changes: 26 additions & 12 deletions hydra/garaga/starknet/groth16_contract_generator/parsing_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import codecs
import dataclasses
import hashlib
import json
Expand All @@ -23,6 +24,12 @@
)


class KeyPatternNotFound(Exception):
def __init__(self, key_patterns):
super().__init__(f"No key found with patterns {key_patterns}")
self.key_patterns = key_patterns


def iterate_nested_dict(d):
for key, value in d.items():
if isinstance(value, dict):
Expand All @@ -49,7 +56,7 @@ def find_item_from_key_patterns(data: dict, key_patterns: List[str]) -> Any:
if best_match is not None:
return best_match
else:
raise ValueError(f"No key found with patterns {key_patterns}")
raise KeyPatternNotFound(key_patterns)


def try_parse_g1_point_from_key(
Expand Down Expand Up @@ -151,7 +158,7 @@ def try_parse_g2_point(point: Any, curve_id: CurveID = None) -> G2Point:
def try_guessing_curve_id_from_json(data: dict) -> CurveID:
try:
curve_id = CurveID.from_str(find_item_from_key_patterns(data, ["curve"]))
except (ValueError, KeyError):
except (ValueError, KeyError, KeyPatternNotFound):
# Try guessing the curve id from the bit size of the first found integer in the json.
x = None
for value in iterate_nested_dict(data):
Expand Down Expand Up @@ -201,7 +208,7 @@ def from_dict(data: dict) -> "Groth16VerifyingKey":
curve_id = try_guessing_curve_id_from_json(data)
try:
verifying_key = find_item_from_key_patterns(data, ["verifying_key"])
except ValueError:
except KeyPatternNotFound:
verifying_key = data
try:
return Groth16VerifyingKey(
Expand All @@ -220,7 +227,7 @@ def from_dict(data: dict) -> "Groth16VerifyingKey":
for point in find_item_from_key_patterns(verifying_key, ["ic"])
],
)
except ValueError:
except (ValueError, KeyPatternNotFound):
# Gnark case.
g1_points = find_item_from_key_patterns(verifying_key, ["g1"])
g2_points = find_item_from_key_patterns(verifying_key, ["g2"])
Expand Down Expand Up @@ -318,23 +325,26 @@ def from_dict(
curve_id = try_guessing_curve_id_from_json(data)
try:
proof = find_item_from_key_patterns(data, ["proof"])
except ValueError:
except KeyPatternNotFound:
proof = data

try:
seal = io.to_hex_str(find_item_from_key_patterns(data, ["seal"]))
image_id = io.to_hex_str(find_item_from_key_patterns(data, ["image_id"]))
journal = io.to_hex_str(find_item_from_key_patterns(data, ["journal"]))
seal = find_item_from_key_patterns(data, ["seal"])
image_id = find_item_from_key_patterns(data, ["image_id"])
journal = find_item_from_key_patterns(data, ["journal"])

print("Seal: {}\nImage_id: {}\nJournal: {}".format(seal, image_id, journal))
return Groth16Proof._from_risc0(
seal=bytes.fromhex(seal[2:]),
image_id=bytes.fromhex(image_id[2:]),
journal=bytes.fromhex(journal[2:]),
seal=codecs.decode(seal[2:].replace("\\x", ""), "hex"),
image_id=codecs.decode(image_id[2:].replace("\\x", ""), "hex"),
journal=codecs.decode(journal[2:].replace("\\x", ""), "hex"),
)
except ValueError:
pass
except KeyError:
pass
except KeyPatternNotFound:
pass
except Exception as e:
print(f"Error: {e}")
raise e
Expand All @@ -347,7 +357,11 @@ def from_dict(
else:
raise ValueError(f"Invalid public inputs format: {public_inputs}")
else:
public_inputs = find_item_from_key_patterns(data, ["public"])
try:
public_inputs = find_item_from_key_patterns(data, ["public"])
except KeyPatternNotFound as e:
print(f"Error: {e}")
raise e
return Groth16Proof(
a=try_parse_g1_point_from_key(proof, ["a"], curve_id),
b=try_parse_g2_point_from_key(proof, ["b"], curve_id),
Expand Down

0 comments on commit 1018d21

Please sign in to comment.