Skip to content

Commit

Permalink
msm tests with fixes and neg_3 hint support
Browse files Browse the repository at this point in the history
  • Loading branch information
feltroidprime committed Jul 25, 2024
1 parent 4e6be8f commit 805d35d
Show file tree
Hide file tree
Showing 20 changed files with 56,064 additions and 6,638 deletions.
44 changes: 42 additions & 2 deletions hydra/algebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,46 @@ def __eq__(self, other: object) -> bool:
return self.value == other
raise TypeError(f"Cannot compare PyFelt and {type(other)}")

def __lt__(self, other: PyFelt | int) -> bool:
if isinstance(other, PyFelt):
return self.value < other.value
if isinstance(other, int):
return self.value < other
raise TypeError(f"Cannot compare PyFelt and {type(other)}")

def __le__(self, other: PyFelt | int) -> bool:
if isinstance(other, PyFelt):
return self.value <= other.value
if isinstance(other, int):
return self.value <= other
raise TypeError(f"Cannot compare PyFelt and {type(other)}")

def __gt__(self, other: PyFelt | int) -> bool:
if isinstance(other, PyFelt):
return self.value > other.value
if isinstance(other, int):
return self.value > other
raise TypeError(f"Cannot compare PyFelt and {type(other)}")

def __ge__(self, other: PyFelt | int) -> bool:
if isinstance(other, PyFelt):
return self.value >= other.value
if isinstance(other, int):
return self.value >= other
raise TypeError(f"Cannot compare PyFelt and {type(other)}")

def __rlt__(self, left: int) -> bool:
return left < self.value

def __rle__(self, left: int) -> bool:
return left <= self.value

def __rgt__(self, left: int) -> bool:
return left > self.value

def __rge__(self, left: int) -> bool:
return left >= self.value

def __ne__(self, other: object) -> bool:
return not self.__eq__(other)

Expand Down Expand Up @@ -290,11 +330,11 @@ def __truediv__(self, other):
return quo

def __floordiv__(self, other: "Polynomial") -> "Polynomial":
quo, rem = Polynomial.__divmod__(self, other)
quo, _ = Polynomial.__divmod__(self, other)
return quo

def __mod__(self, other: "Polynomial") -> "Polynomial":
quo, rem = Polynomial.__divmod__(self, other)
_, rem = Polynomial.__divmod__(self, other)
return rem

def __divmod__(self, denominator: "Polynomial") -> tuple[Polynomial, Polynomial]:
Expand Down
36 changes: 16 additions & 20 deletions hydra/definitions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from hydra.algebra import Polynomial, BaseField, PyFelt, ModuloCircuitElement
from hydra.hints.io import bigint_split, int_to_u384
from hydra.hints.io import bigint_split, int_to_u384, int_to_u256

from starkware.python.math_utils import ec_safe_mult, EcInfinity, ec_safe_add
from dataclasses import dataclass
Expand Down Expand Up @@ -84,20 +84,14 @@ def to_cairo_zero(self) -> str:
return code

def to_cairo_one(self) -> str:
code = f"const {self.cairo_zero_namespace_name}:Curve = \n"
p = bigint_split(self.p, N_LIMBS, BASE)
n = bigint_split(self.n, N_LIMBS, BASE)
a = bigint_split(self.a, N_LIMBS, BASE)
b = bigint_split(self.b, N_LIMBS, BASE)
g = bigint_split(self.fp_generator, N_LIMBS, BASE)
min_one = bigint_split(-1 % self.p, N_LIMBS, BASE)
code = f"const {self.cairo_zero_namespace_name.upper()}:Curve = \n"
code += "Curve {\n"
code += f"p:u384{{limb0: {hex(p[0])}, limb1: {hex(p[1])}, limb2: {hex(p[2])}, limb3: {hex(p[3])}}},\n"
code += f"n:u384{{limb0: {hex(n[0])}, limb1: {hex(n[1])}, limb2: {hex(n[2])}, limb3: {hex(n[3])}}},\n"
code += f"a:u384{{limb0: {hex(a[0])}, limb1: {hex(a[1])}, limb2: {hex(a[2])}, limb3: {hex(a[3])}}},\n"
code += f"b:u384{{limb0: {hex(b[0])}, limb1: {hex(b[1])}, limb2: {hex(b[2])}, limb3: {hex(b[3])}}},\n"
code += f"g:u384{{limb0: {hex(g[0])}, limb1: {hex(g[1])}, limb2: {hex(g[2])}, limb3: {hex(g[3])}}},\n"
code += f"min_one:u384{{limb0: {hex(min_one[0])}, limb1: {hex(min_one[1])}, limb2: {hex(min_one[2])}, limb3: {hex(min_one[3])}}},\n"
code += f"p:{int_to_u384(self.p)},\n"
code += f"n:{int_to_u256(self.n)},\n"
code += f"a:{int_to_u384(self.a)},\n"
code += f"b:{int_to_u384(self.b)},\n"
code += f"g:{int_to_u384(self.fp_generator)},\n"
code += f"min_one:{int_to_u384(-1%self.p)},\n"
code += "};\n"
return code

Expand Down Expand Up @@ -246,8 +240,8 @@ def NAF(x):
line_function_sparsity=None,
final_exp_cofactor=None,
fp_generator=6,
Gx=0,
Gy=0,
Gx=0x2AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAD245A,
Gy=0x20AE19A1B8A086B4E01EDD2C7748D14C923D4D7E6D7C61B229E9C5A27ECED3D9,
),
}

Expand Down Expand Up @@ -276,7 +270,9 @@ def is_generator(g: int, p: int) -> bool:
return True


def get_base_field(curve_id: int) -> BaseField:
def get_base_field(curve_id: int | CurveID) -> BaseField:
if isinstance(curve_id, CurveID):
curve_id = curve_id.value
return BaseField(CURVES[curve_id].p)


Expand Down Expand Up @@ -630,13 +626,13 @@ def replace_consecutive_zeros(lst):
i = 0
while i < len(lst):
if i < len(lst) - 1 and lst[i] == 0 and lst[i + 1] == 0:
result.append(3) # Replace consecutive zeros with 3
result.append(3) # Replace consecutive zeros with 3
i += 2
elif lst[i] == -1:
result.append(2) # Replace -1 with 2
result.append(2) # Replace -1 with 2
i += 1
else:
result.append(lst[i])
result.append(lst[i])
i += 1
return result

Expand Down
17 changes: 9 additions & 8 deletions hydra/hints/ecip.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,12 @@
from starkware.python.math_utils import is_quad_residue, sqrt as sqrt_mod_p
from hydra.poseidon_transcript import hades_permutation
from hydra.hints.io import int_to_u384, int_array_to_u384_array
from hydra.hints.neg_3 import construct_digit_vectors


def derive_ec_point_from_X(
x: PyFelt | int, curve_id: CurveID
) -> tuple[PyFelt, list[PyFelt]]:
) -> tuple[PyFelt, PyFelt, list[PyFelt]]:
field = get_base_field(curve_id.value)
if isinstance(x, int):
x = field(x)
Expand All @@ -37,18 +38,16 @@ def derive_ec_point_from_X(
)
attempt += 1

y = sqrt_mod_p(rhs.value, field.p)
assert field(y) ** 2 == rhs
return x, y, g_rhs_roots
y = field(sqrt_mod_p(rhs.value, field.p))
assert y**2 == rhs
return x, y, [field(r) for r in g_rhs_roots]


def zk_ecip_hint(
Bs: list[G1Point], dss: list[list[int]]
) -> tuple[G1Point, FunctionFelt]:
def zk_ecip_hint(Bs: list[G1Point], scalars: list[int]) -> tuple[G1Point, FunctionFelt]:
"""
Inputs:
- Bs: list of points on the curve
- dss: list of digits of the points in Bs (obtained from scalars using hints.neg3.construct_digit_vectors)
- scalars: list of scalars
Returns:
- Q: MSM of Bs by scalars contained in dss matrix
- sum_dlog: sum of the logarithmic derivatives of the functions in Ds
Expand All @@ -57,6 +56,8 @@ def zk_ecip_hint(
Partial Ref : https://gist.github.com/Liam-Eagen/666d0771f4968adccd6087465b8c5bd4
Full algo verifying it available in tests/benchmarks.py::test_msm_n_points
"""
assert len(Bs) == len(scalars)
dss = construct_digit_vectors(scalars)
Q, Ds = ecip_functions(Bs, dss)
dlogs = [dlog(D) for D in Ds]
sum_dlog = dlogs[0]
Expand Down
35 changes: 25 additions & 10 deletions hydra/hints/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,17 @@ def bigint_split(

def int_to_u384(x: int | PyFelt) -> str:
limbs = bigint_split(x, 4, 2**96)
return f"u384{{limb0:{limbs[0]}, limb1:{limbs[1]}, limb2:{limbs[2]}, limb3:{limbs[3]}}}"
return f"u384{{limb0:{hex(limbs[0])}, limb1:{hex(limbs[1])}, limb2:{hex(limbs[2])}, limb3:{hex(limbs[3])}}}"


def int_to_u256(x: int | PyFelt) -> str:
assert 0 <= x < 2**256, f"Value {x} is too large to fit in a u256"
limbs = bigint_split(x, 2, 2**128)
return f"u256{{low:{hex(limbs[0])}, high:{hex(limbs[1])}}}"


def int_array_to_u256_array(x: list) -> str:
return f"array![{', '.join([int_to_u256(i) for i in x])}]"


def int_array_to_u384_array(x: list) -> str:
Expand Down Expand Up @@ -154,20 +164,25 @@ def fill_uint256(x: int, ids: object):


def padd_function_felt(
f: FunctionFelt, n: int
f: FunctionFelt, n: int, py_felt: bool = False
) -> tuple[list[int], list[int], list[int], list[int]]:
a_num = f.a.numerator.get_value_coeffs()
a_den = f.a.denominator.get_value_coeffs()
b_num = f.b.numerator.get_value_coeffs()
b_den = f.b.denominator.get_value_coeffs()
a_num = f.a.numerator.get_coeffs() if py_felt else f.a.numerator.get_value_coeffs()
a_den = (
f.a.denominator.get_coeffs() if py_felt else f.a.denominator.get_value_coeffs()
)
b_num = f.b.numerator.get_coeffs() if py_felt else f.b.numerator.get_value_coeffs()
b_den = (
f.b.denominator.get_coeffs() if py_felt else f.b.denominator.get_value_coeffs()
)
assert len(a_num) <= n + 1
assert len(a_den) <= n + 2
assert len(b_num) <= n + 2
assert len(b_den) <= n + 5
a_num = a_num + [0] * (n + 1 - len(a_num))
a_den = a_den + [0] * (n + 2 - len(a_den))
b_num = b_num + [0] * (n + 2 - len(b_num))
b_den = b_den + [0] * (n + 5 - len(b_den))
zero = [f.a.numerator.field.zero()] if py_felt else [0]
a_num = a_num + zero * (n + 1 - len(a_num))
a_den = a_den + zero * (n + 2 - len(a_den))
b_num = b_num + zero * (n + 2 - len(b_num))
b_den = b_den + zero * (n + 5 - len(b_den))
return (a_num, a_den, b_num, b_den)


Expand Down
38 changes: 28 additions & 10 deletions hydra/modulo_circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,15 +382,33 @@ def write_struct(
self,
struct: Cairo1SerializableStruct,
write_source: WriteOps = WriteOps.INPUT,
) -> list[ModuloCircuitElement]:
assert all(
type(elmt) == PyFelt for elmt in struct.elmts
), f"Expected PyFelt, got {type(struct.elmts)}"
self.input_structs.append(struct)
if len(struct) == 1:
return self.write_element(struct.elmts[0], write_source)
else:
return self.write_elements(struct.elmts, write_source)
) -> (
list[ModuloCircuitElement]
| list[list[ModuloCircuitInstruction]]
| ModuloCircuitElement
):
all_pyfelt = all(type(elmt) == PyFelt for elmt in struct.elmts)
all_cairo1serializablestruct = all(
isinstance(elmt, Cairo1SerializableStruct) for elmt in struct.elmts
)
assert (
all_pyfelt or all_cairo1serializablestruct
), f"Expected list of PyFelt or Cairo1SerializableStruct, got {[type(elmt) for elmt in struct.elmts]}"

if all_pyfelt:
self.input_structs.append(struct)
if len(struct) == 1:
return self.write_element(struct.elmts[0], write_source)
else:
return self.write_elements(struct.elmts, write_source)
elif all_cairo1serializablestruct:
result = [self.write_struct(elmt, write_source) for elmt in struct.elmts]
# Ensure only the larger struct is appended
self.input_structs = [
s for s in self.input_structs if s not in struct.elmts
]
self.input_structs.append(struct)
return result

def write_elements(
self, elmts: list[PyFelt], operation: WriteOps, sparsity: list[int] = None
Expand Down Expand Up @@ -927,7 +945,7 @@ def compile_circuit_cairo_1(
signature_input = f"{','.join([x.serialize_input_signature() for x in self.input_structs])}"
else:
raise ValueError(
f"Input structs must have the same number of elements as the input: {len(self.input_structs)=} != {len(self.input)=}"
f"Input structs must have the same number of elements as the input: {sum([len(x) for x in self.input_structs])=} != {len(self.input)=}"
)
else:
signature_input = f"mut input: Array<u384>"
Expand Down
Loading

0 comments on commit 805d35d

Please sign in to comment.