Skip to content

Commit

Permalink
Merge branch 'main' into pypi-wheels-config
Browse files Browse the repository at this point in the history
  • Loading branch information
raugfer committed Aug 19, 2024
2 parents 75c91b8 + d2011dc commit b1857a8
Show file tree
Hide file tree
Showing 24 changed files with 2,103 additions and 193 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/fustat.yml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ jobs:
run: |
source venv/bin/activate && ./tools/make/fustat_format_check.sh
- name: Compile cairo files
run: source venv/bin/activate && make build
run: source venv/bin/activate && make clean && ./tools/make/build.sh tests
- name: Run fustat programs
run: |
source venv/bin/activate
Expand Down
4 changes: 2 additions & 2 deletions hydra/garaga/algebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -1032,5 +1032,5 @@ def validate_degrees(self, msm_size: int) -> bool:
assert degrees["b"]["denominator"] <= msm_size + 5
return True

def print_as_sage_poly(self, var: str = "x") -> str:
return f"(({self.b.numerator.print_as_sage_poly(var)}) / ({self.b.denominator.print_as_sage_poly(var)}) * y + ({self.a.numerator.print_as_sage_poly(var)} / ({self.a.denominator.print_as_sage_poly(var)})"
def print_as_sage_poly(self, var: str = "x", as_hex: bool = False) -> str:
return f"(({self.b.numerator.print_as_sage_poly(var, as_hex)}) / ({self.b.denominator.print_as_sage_poly(var, as_hex)}) * y + ({self.a.numerator.print_as_sage_poly(var, as_hex)} / ({self.a.denominator.print_as_sage_poly(var, as_hex)})"
3 changes: 3 additions & 0 deletions hydra/garaga/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,6 +494,9 @@ class G1Point:
y: int
curve_id: CurveID

def __str__(self) -> str:
return f"G1Point({hex(self.x)}, {hex(self.y)}) on curve {self.curve_id}"

def __hash__(self):
return hash((self.x, self.y, self.curve_id))

Expand Down
82 changes: 68 additions & 14 deletions hydra/garaga/hints/ecip.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import functools
from dataclasses import dataclass

import garaga_rs

from garaga.algebra import Fp2, FunctionFelt, Polynomial, PyFelt, RationalFunction, T
from garaga.definitions import CURVES, CurveID, G1Point, G2Point, get_base_field
from garaga.hints.neg_3 import (
Expand Down Expand Up @@ -86,7 +88,7 @@ def rhs_compute(x: PyFelt | Fp2) -> PyFelt | Fp2:


def zk_ecip_hint(
Bs: list[G1Point] | list[G2Point], scalars: list[int]
Bs: list[G1Point] | list[G2Point], scalars: list[int], use_rust: bool = True
) -> tuple[G1Point | G2Point, FunctionFelt[T]]:
"""
Inputs:
Expand All @@ -102,12 +104,39 @@ def zk_ecip_hint(
"""
assert len(Bs) == len(scalars)

dss = construct_digit_vectors(scalars)
Q, Ds = ecip_functions(Bs, dss)
dlogs = [dlog(D) for D in Ds]
sum_dlog = dlogs[0]
for i in range(1, len(dlogs)):
sum_dlog = sum_dlog + (-3) ** i * dlogs[i]
ec_group_class = get_ec_group_class_from_ec_point(Bs[0])
if ec_group_class == G1Point and use_rust:
pts = []
c_id = Bs[0].curve_id
if c_id == CurveID.BLS12_381:
nb = 48
else:
nb = 32
for pt in Bs:
pts.extend([pt.x.to_bytes(nb, "big"), pt.y.to_bytes(nb, "big")])
field_type = get_field_type_from_ec_point(Bs[0])
field = get_base_field(c_id.value, field_type)

q, a_num, a_den, b_num, b_den = garaga_rs.zk_ecip_hint(pts, scalars, c_id.value)

a_num = [field(int(f, 16)) for f in a_num] if len(a_num) > 0 else [field.zero()]
a_den = [field(int(f, 16)) for f in a_den] if len(a_den) > 0 else [field.one()]
b_num = [field(int(f, 16)) for f in b_num] if len(b_num) > 0 else [field.zero()]
b_den = [field(int(f, 16)) for f in b_den] if len(b_den) > 0 else [field.one()]

Q = G1Point(int(q[0], 16), int(q[1], 16), c_id)
sum_dlog = FunctionFelt(
RationalFunction(Polynomial(a_num), Polynomial(a_den)),
RationalFunction(Polynomial(b_num), Polynomial(b_den)),
)
else:
dss = construct_digit_vectors(scalars)
Q, Ds = ecip_functions(Bs, dss)
dlogs = [dlog(D) for D in Ds]
sum_dlog = dlogs[0]
field = get_base_field(Q.curve_id.value, PyFelt)
for i in range(1, len(dlogs)):
sum_dlog = sum_dlog + (-3) ** i * dlogs[i]
return Q, sum_dlog


Expand All @@ -117,10 +146,11 @@ def verify_ecip(
Q: G1Point | G2Point = None,
sum_dlog: FunctionFelt[T] = None,
A0: G1Point | G2Point = None,
use_rust: bool = True,
) -> bool:
# Prover :
if Q is None or sum_dlog is None:
Q, sum_dlog = zk_ecip_hint(Bs, scalars)
Q, sum_dlog = zk_ecip_hint(Bs, scalars, use_rust)
else:
Q = Q
sum_dlog = sum_dlog
Expand Down Expand Up @@ -411,8 +441,8 @@ def construct_function(Ps: list[G1Point] | list[G2Point]) -> FF:
raise EmptyListOfPoints(
"Cannot construct a function from an empty list of points"
)

xs = [(P, line(P, -P)) for P in Ps]

while len(xs) != 1:
xs2 = []

Expand All @@ -425,7 +455,10 @@ def construct_function(Ps: list[G1Point] | list[G2Point]) -> FF:
for n in range(0, len(xs) // 2):
(A, aNum) = xs[2 * n]
(B, bNum) = xs[2 * n + 1]
num = (aNum * bNum * line(A, B)).reduce()
aNum_bNum = aNum * bNum
line_AB = line(A, B)
product = aNum_bNum * line_AB
num = product.reduce()
den = (line(A, -A) * line(B, -B)).to_poly()
D = num.div_by_poly(den)
xs2.append((A.add(B), D))
Expand All @@ -437,7 +470,7 @@ def construct_function(Ps: list[G1Point] | list[G2Point]) -> FF:

assert xs[0][0].is_infinity()

return D.normalize()
return xs[-1][1].normalize()


def row_function(
Expand Down Expand Up @@ -471,7 +504,6 @@ def ecip_functions(
Bs: list[G1Point] | list[G2Point], dss: list[list[int]]
) -> tuple[G1Point | G2Point, list[FF]]:
dss.reverse()

ec_group_class = G1Point if isinstance(Bs[0], G1Point) else G2Point
Q = ec_group_class.infinity(Bs[0].curve_id)
Ds = []
Expand Down Expand Up @@ -504,6 +536,7 @@ def dlog(d: FF) -> FunctionFelt:

d: FF = d.reduce()
assert len(d.coeffs) == 2, f"D has {len(d.coeffs)} coeffs: {d.coeffs}"

Dx = FF([d[0].differentiate(), d[1].differentiate()], d.curve_id)
Dy: Polynomial = d[1] # B(x)

Expand Down Expand Up @@ -573,7 +606,7 @@ def print_ff(ff: FF):
string = ""
coeffs = ff.coeffs
for i, p in enumerate(coeffs[::-1]):
coeff_str = p.print_as_sage_poly(var_name=f"x")
coeff_str = p.print_as_sage_poly(var_name=f"x", as_hex=True)

if i == len(coeffs) - 1:
if coeff_str == "":
Expand Down Expand Up @@ -635,4 +668,25 @@ def build_cairo1_tests_derive_ec_point_from_X(x: int, curve_id: CurveID, idx: in
# print(f"Average number of roots: {average_n_roots / n}")
# print(f"Max number of roots: {max_n_roots}")

verify_ecip([G1Point.gen_random_point(CurveID.SECP256K1)], scalars=[-1])
# verify_ecip([G1Point.gen_random_point(CurveID.SECP256K1)], scalars=[-1])

import time

order = CURVES[CurveID.BN254.value].n
n = 50
Bs = [G1Point.gen_random_point(CurveID.BN254) for _ in range(n)]
ss = [random.randint(1, order) for _ in range(n)]

t0 = time.time()
ZZ = zk_ecip_hint(Bs, ss, use_rust=False)
time_taken_py = time.time() - t0
print(f"Time taken py : {time_taken_py}")

t0 = time.time()
ZZ_rs = zk_ecip_hint(Bs, ss, use_rust=True)
time_taken_rs = time.time() - t0
print(f"Time taken rs : {time_taken_rs}")

print(f"Ratio: {time_taken_py / time_taken_rs}")

assert ZZ == ZZ_rs
1 change: 1 addition & 0 deletions hydra/garaga/hints/neg_3.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def neg_3_base_le(scalar: int) -> list[int]:
# For remainder 1 and 0, no change is required
digits.append(remainder)
scalar = -(scalar // 3) # divide by -3 for the next digit

return digits


Expand Down
28 changes: 24 additions & 4 deletions hydra/garaga/starknet/groth16_contract_generator/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

class ECIP_OPS_CLASS_HASH(Enum):
MAINNET = None
SEPOLIA = 0x0245A22F0FE79CFF1A66622EF3A1B545DD109E159C7138CC3B6C224782DBD6D8
SEPOLIA = 0x03D917FCAF6737E3110800D8A29E534FFB885DF71313909F5D14D1D203345F06


def precompute_lines_from_vk(vk: Groth16VerifyingKey) -> StructArray:
Expand Down Expand Up @@ -81,7 +81,8 @@ def gen_groth16_verifier(
use garaga::ec_ops::{{G1PointTrait, G2PointTrait, ec_safe_add}};
use super::{{N_PUBLIC_INPUTS, vk, ic, precomputed_lines}};
const ECIP_OPS_CLASS_HASH: felt252 = {ecip_class_hash.value};
const ECIP_OPS_CLASS_HASH: felt252 = {hex(ecip_class_hash.value)};
use starknet::ContractAddress;
#[storage]
struct Storage {{}}
Expand All @@ -95,6 +96,8 @@ def gen_groth16_verifier(
small_Q: E12DMulQuotient,
msm_hint: Array<felt252>,
) -> bool {{
// DO NOT EDIT THIS FUNCTION UNLESS YOU KNOW WHAT YOU ARE DOING.
// ONLY EDIT THE process_public_inputs FUNCTION BELOW.
groth16_proof.a.assert_on_curve({curve_id.value});
groth16_proof.b.assert_on_curve({curve_id.value});
groth16_proof.c.assert_on_curve({curve_id.value});
Expand Down Expand Up @@ -128,15 +131,32 @@ def gen_groth16_verifier(
}}
}};
// Perform the pairing check.
multi_pairing_check_{curve_id.name.lower()}_3P_2F_with_extra_miller_loop_result(
let check = multi_pairing_check_{curve_id.name.lower()}_3P_2F_with_extra_miller_loop_result(
G1G2Pair {{ p: vk_x, q: vk.gamma_g2 }},
G1G2Pair {{ p: groth16_proof.c, q: vk.delta_g2 }},
G1G2Pair {{ p: groth16_proof.a.negate({curve_id.value}), q: groth16_proof.b }},
vk.alpha_beta_miller_loop_result,
precomputed_lines.span(),
mpcheck_hint,
small_Q
)
);
if check == true {{
self
.process_public_inputs(
starknet::get_caller_address(), groth16_proof.public_inputs
);
return true;
}} else {{
return false;
}}
}}
}}
#[generate_trait]
impl InternalFunctions of InternalFunctionsTrait {{
fn process_public_inputs(
ref self: ContractState, user: ContractAddress, public_inputs: Span<u256>,
) {{ // Process the public inputs with respect to the caller address (user).
// Update the storage, emit events, call other contracts, etc.
}}
}}
}}
Expand Down
1 change: 1 addition & 0 deletions maturin.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
cd tools/garaga_rs && maturin develop --release
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ mod Groth16VerifierBLS12_381 {
use super::{N_PUBLIC_INPUTS, vk, ic, precomputed_lines};

const ECIP_OPS_CLASS_HASH: felt252 =
1027657496336879269758359542067106285349450528473286930334547874574577161944;
0x3d917fcaf6737e3110800d8a29e534ffb885df71313909f5d14d1d203345f06;
use starknet::ContractAddress;

#[storage]
struct Storage {}
Expand All @@ -39,6 +40,8 @@ mod Groth16VerifierBLS12_381 {
small_Q: E12DMulQuotient,
msm_hint: Array<felt252>,
) -> bool {
// DO NOT EDIT THIS FUNCTION UNLESS YOU KNOW WHAT YOU ARE DOING.
// ONLY EDIT THE process_public_inputs FUNCTION BELOW.
groth16_proof.a.assert_on_curve(1);
groth16_proof.b.assert_on_curve(1);
groth16_proof.c.assert_on_curve(1);
Expand Down Expand Up @@ -72,15 +75,32 @@ mod Groth16VerifierBLS12_381 {
}
};
// Perform the pairing check.
multi_pairing_check_bls12_381_3P_2F_with_extra_miller_loop_result(
let check = multi_pairing_check_bls12_381_3P_2F_with_extra_miller_loop_result(
G1G2Pair { p: vk_x, q: vk.gamma_g2 },
G1G2Pair { p: groth16_proof.c, q: vk.delta_g2 },
G1G2Pair { p: groth16_proof.a.negate(1), q: groth16_proof.b },
vk.alpha_beta_miller_loop_result,
precomputed_lines.span(),
mpcheck_hint,
small_Q
)
);
if check == true {
self
.process_public_inputs(
starknet::get_caller_address(), groth16_proof.public_inputs
);
return true;
} else {
return false;
}
}
}
#[generate_trait]
impl InternalFunctions of InternalFunctionsTrait {
fn process_public_inputs(
ref self: ContractState, user: ContractAddress, public_inputs: Span<u256>,
) { // Process the public inputs with respect to the caller address (user).
// Update the storage, emit events, call other contracts, etc.
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ mod Groth16VerifierBN254 {
use super::{N_PUBLIC_INPUTS, vk, ic, precomputed_lines};

const ECIP_OPS_CLASS_HASH: felt252 =
1027657496336879269758359542067106285349450528473286930334547874574577161944;
0x3d917fcaf6737e3110800d8a29e534ffb885df71313909f5d14d1d203345f06;
use starknet::ContractAddress;

#[storage]
struct Storage {}
Expand All @@ -39,6 +40,8 @@ mod Groth16VerifierBN254 {
small_Q: E12DMulQuotient,
msm_hint: Array<felt252>,
) -> bool {
// DO NOT EDIT THIS FUNCTION UNLESS YOU KNOW WHAT YOU ARE DOING.
// ONLY EDIT THE process_public_inputs FUNCTION BELOW.
groth16_proof.a.assert_on_curve(0);
groth16_proof.b.assert_on_curve(0);
groth16_proof.c.assert_on_curve(0);
Expand Down Expand Up @@ -72,15 +75,32 @@ mod Groth16VerifierBN254 {
}
};
// Perform the pairing check.
multi_pairing_check_bn254_3P_2F_with_extra_miller_loop_result(
let check = multi_pairing_check_bn254_3P_2F_with_extra_miller_loop_result(
G1G2Pair { p: vk_x, q: vk.gamma_g2 },
G1G2Pair { p: groth16_proof.c, q: vk.delta_g2 },
G1G2Pair { p: groth16_proof.a.negate(0), q: groth16_proof.b },
vk.alpha_beta_miller_loop_result,
precomputed_lines.span(),
mpcheck_hint,
small_Q
)
);
if check == true {
self
.process_public_inputs(
starknet::get_caller_address(), groth16_proof.public_inputs
);
return true;
} else {
return false;
}
}
}
#[generate_trait]
impl InternalFunctions of InternalFunctionsTrait {
fn process_public_inputs(
ref self: ContractState, user: ContractAddress, public_inputs: Span<u256>,
) { // Process the public inputs with respect to the caller address (user).
// Update the storage, emit events, call other contracts, etc.
}
}
}
Expand Down
Loading

0 comments on commit b1857a8

Please sign in to comment.