Skip to content

Commit

Permalink
Move final exp tests to proper test
Browse files Browse the repository at this point in the history
  • Loading branch information
feltroidprime committed Aug 2, 2024
1 parent f00a0d1 commit cbc444e
Show file tree
Hide file tree
Showing 4 changed files with 108 additions and 64 deletions.
22 changes: 0 additions & 22 deletions hydra/hints/bls.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,25 +64,3 @@ def get_root_and_scaling_factor_bls(mlo: E12) -> tuple[E12, E12]:
shift = x**s
root = (shift * mlo) ** e
return root, shift


if __name__ == "__main__":
import random

from hydra.hints.multi_miller_witness import get_miller_loop_output

random.seed(0)
for i in range(5):
# Test a correct case where final_exp(miller loop output) == 1
f = get_miller_loop_output(curve_id=CurveID.BLS12_381, will_be_one=True)
root, w_full = get_root_and_scaling_factor_bls(f)
assert f**h == ONE, f"f^h!=1"
assert f * w_full == root**lam, f"f * w_full!= root**lam"

# Test a wrong case where final_exp(miller loop output) != 1
f = get_miller_loop_output(curve_id=CurveID.BLS12_381, will_be_one=False)
root, w_full = get_root_and_scaling_factor_bls(f)
assert f**h != ONE, f"f^h==1"
assert f * w_full != root**lam, f"f * w_full == root**lam although f^h!=1"

print(f"{i}-th check ok")
17 changes: 16 additions & 1 deletion hydra/hints/multi_miller_witness.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,21 @@ def get_final_exp_witness(curve_id: int, f: E12) -> tuple[E12, E12]:
raise ValueError(f"Curve ID {curve_id} not supported")


def get_lambda(curve_id: CurveID) -> int:
x = CURVES[curve_id.value].x
q = CURVES[curve_id.value].p
if curve_id == CurveID.BN254:
λ = (
6 * x + 2 + q - q**2 + q**3
) # https://eprint.iacr.org/2008/096.pdf See section 4 for BN curves.
return λ
elif curve_id == CurveID.BLS12_381:
λ = -x + q
return λ
else:
raise ValueError(f"Curve ID {curve_id} not supported")


def get_m_dash_root(f: E12) -> E12:
assert f.curve_id == CurveID.BN254.value

Expand Down Expand Up @@ -152,7 +167,7 @@ def get_rth_root(f: E12) -> E12:
h = (CURVES[f.curve_id].p ** 12 - 1) // r
r_inv = pow(r, -1, h)
res = f**r_inv
assert res**r == f, "res**r should be f"
# assert res**r == f, "res**r should be f"
return res


Expand Down
41 changes: 0 additions & 41 deletions hydra/precompiled_circuits/multi_pairing_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,44 +440,3 @@ def get_pairing_check_input(
return c_input[:-6], M
else:
return c_input, None


if __name__ == "__main__":

def test_mpcheck(curve_id: CurveID, n_pairs: int, include_m: bool = False):
c = MultiPairingCheckCircuit(
name="mock", curve_id=curve_id.value, n_pairs=n_pairs
)
circuit_input, m = get_pairing_check_input(
curve_id, n_pairs, include_m=include_m
)
c.write_p_and_q_raw(circuit_input)
M = c.write_elements(m, WriteOps.INPUT) if m is not None else None
c.multi_pairing_check(n_pairs, M)
c.finalize_circuit()

def total_cost(c):
summ = c.summarize()
summ["total_steps_cost"] = (
summ["MULMOD"] * 8
+ summ["ADDMOD"] * 4
+ summ["ASSERT_EQ"] * 2
+ summ["POSEIDON"] * 17
+ summ["RLC"] * 28
)
return summ

print(total_cost(c))
print(
f"Test {curve_id.name} {n_pairs=} {'with m' if include_m else 'without m'} passed"
)
print(f"n_eq: {c.accumulate_poly_instructions[0].n}")
print(
f"Q max degree: {max([q.degree() for q in c.accumulate_poly_instructions[0].Qis])}"
)

for curve_id in [CurveID.BN254, CurveID.BLS12_381]:
for n_pairs in [2, 3]:
print(f"Testing {curve_id.name} {n_pairs=}")
test_mpcheck(curve_id, n_pairs)
test_mpcheck(curve_id, n_pairs, include_m=True)
92 changes: 92 additions & 0 deletions tests/hydra/hints/test_final_exp_witness.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
import pytest
import random
from hydra.hints.multi_miller_witness import (
get_miller_loop_output,
get_final_exp_witness,
get_lambda,
)
from hydra.hints.tower_backup import E12, E6
from hydra.precompiled_circuits.multi_pairing_check import (
MultiPairingCheckCircuit,
get_pairing_check_input,
WriteOps,
get_max_Q_degree,
)
from hydra.definitions import CurveID, CURVES, get_sparsity


@pytest.mark.parametrize("seed", range(5))
@pytest.mark.parametrize("curve_id", [CurveID.BN254, CurveID.BLS12_381])
def test_final_exp_witness(seed, curve_id):
random.seed(seed)
ONE = E12.one(curve_id.value)
λ = get_lambda(curve_id)
q = CURVES[curve_id.value].p
r = CURVES[curve_id.value].n
h = (q**12 - 1) // r

# Test correct case
f_correct = get_miller_loop_output(curve_id=curve_id, will_be_one=True)
root_correct, w_full_correct = get_final_exp_witness(curve_id.value, f_correct)

e6_subfield = E12(
[E6.random(curve_id.value), E6.zero(curve_id.value)], curve_id.value
)
scaling_factor_sparsity = get_sparsity(e6_subfield.to_direct())
scaling_factor = w_full_correct.to_direct()
# Assert sparsity is correct: for every index where the sparsity is 0, the coefficient must 0 in scaling factor
for i in range(len(scaling_factor_sparsity)):
if scaling_factor_sparsity[i] == 0:
assert scaling_factor[i].value == 0
# Therefore scaling factor lies in Fp6

assert f_correct**h == ONE, "f^h should equal 1 for correct case"
assert (
f_correct * w_full_correct == root_correct**λ
), "f * w_full should equal root**λ for correct case"

# Test incorrect case
f_incorrect = get_miller_loop_output(curve_id=curve_id, will_be_one=False)
root_incorrect, w_full_incorrect = get_final_exp_witness(
curve_id.value, f_incorrect
)

assert f_incorrect**h != ONE, "f^h should not equal 1 for incorrect case"
assert (
f_incorrect * w_full_incorrect != root_incorrect**λ
), "f * w_full should not equal root**λ for incorrect case"

print(f"{seed}-th check ok")


@pytest.mark.parametrize("curve_id", [CurveID.BN254, CurveID.BLS12_381])
@pytest.mark.parametrize("n_pairs", [2, 3, 4, 5])
@pytest.mark.parametrize("include_m", [False, True])
def test_mpcheck(curve_id: CurveID, n_pairs: int, include_m: bool):
c = MultiPairingCheckCircuit(name="mock", curve_id=curve_id.value, n_pairs=n_pairs)
circuit_input, m = get_pairing_check_input(curve_id, n_pairs, include_m=include_m)
c.write_p_and_q_raw(circuit_input)
M = c.write_elements(m, WriteOps.INPUT) if m is not None else None
c.multi_pairing_check(n_pairs, M) # Check done implicitely here
c.finalize_circuit()

def total_cost(c):
summ = c.summarize()
summ["total_steps_cost"] = (
summ["MULMOD"] * 8
+ summ["ADDMOD"] * 4
+ summ["ASSERT_EQ"] * 2
+ summ["POSEIDON"] * 17
+ summ["RLC"] * 28
)
return summ

cost = total_cost(c)
q_max_degree = max([q.degree() for q in c.accumulate_poly_instructions[0].Qis])

# Assertions
assert q_max_degree <= get_max_Q_degree(curve_id.value, n_pairs)

print(f"\nTest {curve_id.name} {n_pairs=} {'with m' if include_m else 'without m'}")
print(f"Total cost: {cost}")
print(f"Q max degree: {q_max_degree}")

0 comments on commit cbc444e

Please sign in to comment.