diff --git a/archive_tmp/bn254/pairing_final_exp.py b/archive_tmp/bn254/pairing_final_exp.py index 140d297d..1376cf93 100644 --- a/archive_tmp/bn254/pairing_final_exp.py +++ b/archive_tmp/bn254/pairing_final_exp.py @@ -1,8 +1,7 @@ -from algebra import Polynomial -from algebra import PyFelt, BaseField +from src.algebra import Polynomial +from src.algebra import PyFelt, BaseField from tools.extension_trick import ( gnark_to_v, - gnark_to_v_bigint3, flatten, neg_e6, v_to_gnark, @@ -294,10 +293,10 @@ def frobenius_torus(x: list): 5722266937896532885780051958958348231143373700109372999374820235121374419868, ) # 1 / v^((p-1)/2) res = flatten(mul_e6((t0, t1, t2), (v0, (0, 0), (0, 0)))) - res_bigint3 = gnark_to_v_bigint3([split(x) for x in res]) + # res_bigint3 = gnark_to_v_bigint3([split(x) for x in res]) res = gnark_to_v(res) - return res, res_bigint3 + return res def frobenius_cube_torus(x: list): diff --git a/archive_tmp/bn254/pairing_final_exp_circuit_generation.py b/archive_tmp/bn254/pairing_final_exp_circuit_generation.py deleted file mode 100644 index d6fc4ad5..00000000 --- a/archive_tmp/bn254/pairing_final_exp_circuit_generation.py +++ /dev/null @@ -1,326 +0,0 @@ -from algebra import Polynomial -from algebra import PyFelt, BaseField -from tools.extension_trick import ( - gnark_to_v, - flatten, - neg_e6, - v_to_gnark, - pack_e6, - flatten, - div_e6, - mul_e6, - mul_e2, - inv_e12, - mul_e12, - pack_e12, -) -from src.extension_field_modulo_circuit import ( - ExtensionFieldModuloCircuit, - ModuloElement, - EuclideanPolyAccumulator, -) -from src.hints.extf_mul import nondeterministic_extension_field_mul_divmod -from definitions import BN254_ID -from hints.io import bigint_split -from poseidon_transcript import CairoPoseidonTranscript -from dataclasses import dataclass - -p = 0x30644E72E131A029B85045B68181585D97816A916871CA8D3C208C16D87CFD47 -BASE = 2**96 -DEGREE = 3 -N_LIMBS = 4 -STARK = 3618502788666131213697322783095070105623107215331596699973092056135872020481 -field = BaseField(p) - - -def to_fp6(x: list) -> Polynomial: - return Polynomial([PyFelt(xi, field) for xi in x]) - - -def mul_torus( - y1: list, y2: list, continuable_hash: int, y1_bigint3=None, y2_bigint3=None -): - num_min_v, continuable_hash = mul_trick_e6( - y1, y2, continuable_hash, x_bigint3=y1_bigint3, y_bigint3=y2_bigint3 - ) - num_min_v[1] = num_min_v[1] + 1 - - num = num_min_v - den = [y1i + y2i for y1i, y2i in zip(y1, y2)] - if y1_bigint3 is None: - y1_bigint3 = [split(x) for x in y1] - if y2_bigint3 is None: - y2_bigint3 = [split(x) for x in y2] - - den_bigint3 = [ - (y1i[0] + y2i[0], y1i[1] + y2i[1], y1i[2] + y2i[2]) - for y1i, y2i in zip(y1_bigint3, y2_bigint3) - ] - res, continuable_hash = div_trick_e6( - num, den, continuable_hash, y_bigint3=den_bigint3 - ) - return res, continuable_hash - - -def div_trick_e6( - x: list[ModuloElement], - y: list[ModuloElement], - continuable_hash: int, -) -> (list, int): - x_gnark, y_gnark = pack_e6(v_to_gnark(x)), pack_e6(v_to_gnark(y)) - div = flatten(div_e6(x_gnark, y_gnark)) - div = gnark_to_v(div) - check, h = mul_trick_e6( - y, - div, - continuable_hash, - ) - assert x == check, f"{x} != {check}" - return div, h - - -def expt_torus(x: list, continuable_hash: int) -> (list, int): - t3, continuable_hash = square_torus(x, continuable_hash) - t5, continuable_hash = square_torus(t3, continuable_hash) - result, continuable_hash = square_torus(t5, continuable_hash) - t0, continuable_hash = square_torus(result, continuable_hash) - t2, continuable_hash = mul_torus(x, t0, continuable_hash) - t0, continuable_hash = mul_torus(t3, t2, continuable_hash) - t1, continuable_hash = mul_torus(x, t0, continuable_hash) - t4, continuable_hash = mul_torus(result, t2, continuable_hash) - t6, continuable_hash = square_torus(t2, continuable_hash) - t1, continuable_hash = mul_torus(t0, t1, continuable_hash) - t0, continuable_hash = mul_torus(t3, t1, continuable_hash) - t6, continuable_hash = n_square_torus(t6, 6, continuable_hash) - t5, continuable_hash = mul_torus(t5, t6, continuable_hash) - t5, continuable_hash = mul_torus(t4, t5, continuable_hash) - t5, continuable_hash = n_square_torus(t5, 7, continuable_hash) - t4, continuable_hash = mul_torus(t4, t5, continuable_hash) - t4, continuable_hash = n_square_torus(t4, 8, continuable_hash) - t4, continuable_hash = mul_torus(t0, t4, continuable_hash) - t3, continuable_hash = mul_torus(t3, t4, continuable_hash) - t3, continuable_hash = n_square_torus(t3, 6, continuable_hash) - t2, continuable_hash = mul_torus(t2, t3, continuable_hash) - t2, continuable_hash = n_square_torus(t2, 8, continuable_hash) - t2, continuable_hash = mul_torus(t0, t2, continuable_hash) - t2, continuable_hash = n_square_torus(t2, 6, continuable_hash) - t2, continuable_hash = mul_torus(t0, t2, continuable_hash) - t2, continuable_hash = n_square_torus(t2, 10, continuable_hash) - t1, continuable_hash = mul_torus(t1, t2, continuable_hash) - t1, continuable_hash = n_square_torus(t1, 6, continuable_hash) - t0, continuable_hash = mul_torus(t0, t1, continuable_hash) - z, continuable_hash = mul_torus(result, t0, continuable_hash) - return z, continuable_hash - - -def n_square_torus(x: list, n: int, continuable_hash: int): - if n == 0: - return x, continuable_hash - else: - x, continuable_hash = square_torus(x, continuable_hash) - return n_square_torus(x, n - 1, continuable_hash) - - -def square_torus(x: list[ModuloElement], circuit: ExtensionFieldModuloCircuit): - # x_gnark = pack_e6(v_to_gnark(x)) - # sq = [int(x) for x in gnark_to_v(flatten(square_torus_e6(x_gnark)))] - v_tmp = [(2 * sq_i - x_i) % p for sq_i, x_i in zip(sq, x)] - - x_poly = to_fp6(v_tmp) - y_poly = to_fp6(x) - z_poly = x_poly * y_poly - - z_polyq = z_poly // irreducible_poly - z_polyr = z_poly % irreducible_poly - z_polyq_coeffs = z_polyq.get_coeffs() - # print(f"z_polyq_coeffs={z_polyq_coeffs}") - z_polyr_coeffs = z_polyr.get_coeffs() - # print(f"z_polyr_coeffs={z_polyr_coeffs}") - - return sq, h - - -def frobenius_square_torus(x: list): - x_fr2 = [ - x[0] * 2203960485148121921418603742825762020974279258880205651967 % p, - x[1] - * 21888242871839275220042445260109153167277707414472061641714758635765020556617 - % p, - x[2] - * 21888242871839275222246405745257275088696311157297823662689037894645226208582 - % p, - x[3] * 2203960485148121921418603742825762020974279258880205651967 % p, - x[4] - * 21888242871839275220042445260109153167277707414472061641714758635765020556617 - % p, - x[5] - * 21888242871839275222246405745257275088696311157297823662689037894645226208582 - % p, - ] - return x_fr2 - - -def frobenius_torus(x: list): - x_gnark = pack_e6(v_to_gnark(x)) - t0 = (x_gnark[0][0], -x_gnark[0][1] % p) - t1 = (x_gnark[1][0], -x_gnark[1][1] % p) - t2 = (x_gnark[2][0], -x_gnark[2][1] % p) - t1 = mul_e2( - t1, - ( - 21575463638280843010398324269430826099269044274347216827212613867836435027261, - 10307601595873709700152284273816112264069230130616436755625194854815875713954, - ), # (1,9)^(2*(p-1)/6) - ) - t2 = mul_e2( - t2, - ( - 2581911344467009335267311115468803099551665605076196740867805258568234346338, - 19937756971775647987995932169929341994314640652964949448313374472400716661030, - ), # (1,9)^(4*(p-1)/6) - ) - v0 = ( - 18566938241244942414004596690298913868373833782006617400804628704885040364344, - 5722266937896532885780051958958348231143373700109372999374820235121374419868, - ) # 1 / v^((p-1)/2) - res = flatten(mul_e6((t0, t1, t2), (v0, (0, 0), (0, 0)))) - res_bigint3 = gnark_to_v_bigint3([split(x) for x in res]) - - res = gnark_to_v(res) - return res, res_bigint3 - - -def frobenius_cube_torus(x: list): - x_gnark = pack_e6(v_to_gnark(x)) - t0 = (x_gnark[0][0], -x_gnark[0][1] % p) - t1 = (x_gnark[1][0], -x_gnark[1][1] % p) - t2 = (x_gnark[2][0], -x_gnark[2][1] % p) - t1 = mul_e2( - t1, - ( - 3772000881919853776433695186713858239009073593817195771773381919316419345261, - 2236595495967245188281701248203181795121068902605861227855261137820944008926, - ), # (1,9)^(2*(p^3-1)/6) - ) - t2 = mul_e2( - t2, - ( - 5324479202449903542726783395506214481928257762400643279780343368557297135718, - 16208900380737693084919495127334387981393726419856888799917914180988844123039, - ), # (1,9)^(4*(p^3-1)/6) - ) - v0 = ( - 10190819375481120917420622822672549775783927716138318623895010788866272024264, - 303847389135065887422783454877609941456349188919719272345083954437860409601, - ) # 1 / v^((p^3-1)/2) - res = flatten(mul_e6((t0, t1, t2), (v0, (0, 0), (0, 0)))) - res_bigint3 = gnark_to_v_bigint3([split(x) for x in res]) - res = gnark_to_v(res) - return res, res_bigint3 - - -def inverse_torus(x: list): - return [-xi % p for xi in x] - - -def decompress_torus(x: ((int, int), (int, int), (int, int))): - num = (x, ((1, 0), (0, 0), (0, 0))) - den = (x, ((-1 % p, 0), (0, 0), (0, 0))) - res = pack_e12(inv_e12(den[0], den[1])) - res = mul_e12(num, res) - return res - - -def final_exponentiation( - z: (((int, int), (int, int), (int, int)), ((int, int), (int, int), (int, int))), - unsafe: bool, - continuable_hash: int = int.from_bytes(b"GaragaBN254FinalExp", "big"), -): - circuit = ExtensionFieldModuloCircuit("BN254_final_exp", BN254_ID, 6, field) - - MIN_9 = circuit.write_element(field(-9 % p)) - MIN_ONE = circuit.write_element(field(-1 % p)) - - c_num = circuit.write_elements( - [PyFelt(z[0][i][j], field) for i in range(3) for j in range(2)] - ) - if unsafe: - z_c1 = circuit.write_elements( - [PyFelt(z[1][i][j], field) for i in range(3) for j in range(2)] - ) - else: - if z[1] == ((0, 0), (0, 0), (0, 0)): - selector1 = 1 - z_c1 = circuit.write_elements( - [ - field.one(), - field.zero(), - field.zero(), - field.zero(), - field.zero(), - field.zero(), - ] - ) - else: - selector1 = 0 - z_c1 = circuit.write_elements( - [PyFelt(z[1][i][j], field) for i in range(3) for j in range(2)] - ) - - c_num_full = [ - circuit.mul(MIN_ONE, circuit.add(c_num[0], circuit.mul(MIN_9, c_num[1]))), - circuit.mul(MIN_ONE, circuit.add(c_num[2], circuit.mul(MIN_9, c_num[3]))), - circuit.mul(MIN_ONE, circuit.add(c_num[4], circuit.mul(MIN_9, c_num[5]))), - circuit.mul(MIN_ONE, c_num[1]), - circuit.mul(MIN_ONE, c_num[3]), - circuit.mul(MIN_ONE, c_num[5]), - ] - z_c1_full = [ - circuit.add(z_c1[0], circuit.mul(MIN_9, z_c1[1])), - circuit.add(z_c1[2], circuit.mul(MIN_9, z_c1[3])), - circuit.add(z_c1[4], circuit.mul(MIN_9, z_c1[5])), - z_c1[1], - z_c1[3], - z_c1[5], - ] - - Z_fake = PyFelt(42) - Z_fake = circuit.write_element(Z_fake) - return circuit, continuable_hash - - -if __name__ == "__main__": - x = [ - 15631577932152315104652445523700417040601500707877284609546312920354446056447, - 1274881022144191920838043222130710344172476924365725732436425248566978625605, - 14374765490310691286872600100687989211994071432725749506715026469291207213364, - 19232683452852686150799946178434694116955802884971349389480427332156028484678, - 4711060662209480322403082802390043737109415216436721343938907246739585294619, - 12628528420035269572171509623830053865991813551619118245630623189571187704212, - 6132046658265970172317265843030970288646178101127187503319861429480398294166, - 696877141756131447795834834192003128716698847022516178077777960435426094082, - 19968037526512504126402565293093453753511856148614571257107664150629413134903, - 19711115225256248898674588007895864056457997172157519591556283079102178159639, - 4264731731400846354398198898948247059528185839861404225131520284631392266215, - 3153660797904284033741194851243498835351306539671786555576214661552094399141, - ] - z = [ - 17264119758069723980713015158403419364912226240334615592005620718956030922389, - 1300711225518851207585954685848229181392358478699795190245709208408267917898, - 8894217292938489450175280157304813535227569267786222825147475294561798790624, - 1829859855596098509359522796979920150769875799037311140071969971193843357227, - 4968700049505451466697923764727215585075098085662966862137174841375779106779, - 12814315002058128940449527172080950701976819591738376253772993495204862218736, - 4233474252585134102088637248223601499779641130562251948384759786370563844606, - 9420544134055737381096389798327244442442230840902787283326002357297404128074, - 13457906610892676317612909831857663099224588803620954529514857102808143524905, - 5122435115068592725432309312491733755581898052459744089947319066829791570839, - 8891987925005301465158626530377582234132838601606565363865129986128301774627, - 440796048150724096437130979851431985500142692666486515369083499585648077975, - ] - x = pack_e12(x) - c, continuable_hash = final_exponentiation(x, True) - # print(f"f = {f}") - # print(f"z = {z}") - # print(f"hash={continuable_hash}") - # assert pack_e12(z) == f diff --git a/contracts/cairo_bn254/bn254_precompiles.cairo b/contracts/cairo_bn254/bn254_precompiles.cairo deleted file mode 100644 index 50ea5d04..00000000 --- a/contracts/cairo_bn254/bn254_precompiles.cairo +++ /dev/null @@ -1,87 +0,0 @@ -%lang starknet - -// Starkware dependencies. -from starkware.cairo.common.cairo_builtins import HashBuiltin -from starkware.cairo.common.cairo_secp.bigint import BigInt3 -from starkware.starknet.common.syscalls import get_caller_address - -// Project dependencies. -from openzeppelin.upgrades.library import Proxy - -// Local dependencies. -from src.bn254.g1 import G1PointFull -from contracts.cairo_bn254.library import BN254Precompiles, PairingInput - -@external -func initializer{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}() { - let (owner) = get_caller_address(); - Proxy.initializer(owner); - return (); -} - -@view -func ecAdd{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}( - a: G1PointFull, b: G1PointFull -) -> (res: G1PointFull) { - let (res) = BN254Precompiles.ec_add(a, b); - return (res=res); -} - -@view -func ecMul{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}( - a: G1PointFull, s: BigInt3 -) -> (res: G1PointFull) { - alloc_locals; - let (res) = BN254Precompiles.ec_mul(a, s); - return (res=res); -} - -@view -func ecPairing{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}( - input_len: felt, input: PairingInput* -) -> (res: felt) { - let (res) = BN254Precompiles.ec_pairing(input_len, input); - return (res=res); -} - -// -// Proxy administration -// - -// @notice Return the current implementation hash. -// @return implementation The implementation class hash. -@view -func getImplementationHash{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}() -> ( - implementation: felt -) { - return Proxy.get_implementation_hash(); -} - -// @notice Return the current admin address. -// @return admin The admin address. -@view -func getAdmin{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}() -> (admin: felt) { - return Proxy.get_admin(); -} - -// @notice Upgrade the contract to the new implementation. -// @dev This function is only callable by the admin. -// @param new_implementation The new implementation class hash. -@external -func upgrade{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}( - new_implementation: felt -) { - Proxy.assert_only_admin(); - Proxy._set_implementation_hash(new_implementation); - return (); -} - -// @notice Transfer admin rights to a new admin. -// @dev This function is only callable by the admin. -// @param new_admin The new admin address. -@external -func setAdmin{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}(new_admin: felt) { - Proxy.assert_only_admin(); - Proxy._set_admin(new_admin); - return (); -} diff --git a/contracts/cairo_bn254/library.cairo b/contracts/cairo_bn254/library.cairo deleted file mode 100644 index 140a13f7..00000000 --- a/contracts/cairo_bn254/library.cairo +++ /dev/null @@ -1,107 +0,0 @@ -%lang starknet - -from starkware.cairo.common.math import assert_nn, assert_le -from starkware.cairo.common.cairo_builtins import HashBuiltin -from starkware.cairo.common.cairo_secp.bigint import BigInt3 -from starkware.cairo.common.registers import get_fp_and_pc - -from src.bn254.g1 import G1Point, G1PointFull, g1 -from src.bn254.g2 import G2Point -from src.bn254.pairing import miller_loop, final_exponentiation -from src.bn254.towers.e2 import E2 -from src.bn254.towers.e12 import E12, e12 - -struct E2Full { - a0: BigInt3, - a1: BigInt3, -} - -struct G2PointFull { - x: E2Full, - y: E2Full, -} - -struct PairingInput { - p: G1PointFull, - q: G2PointFull, -} - -namespace BN254Precompiles { - // - // Views - // - - func ec_add{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}( - a: G1PointFull, b: G1PointFull - ) -> (res: G1PointFull) { - alloc_locals; - let (res) = g1.add_full(a, b); - return (res=res); - } - - func ec_mul{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}( - a: G1PointFull, s: BigInt3 - ) -> (res: G1PointFull) { - alloc_locals; - let (__fp__, _) = get_fp_and_pc(); - local a_: G1Point = G1Point(x=&a.x, y=&a.y); - let (tmp) = g1.scalar_mul(&a_, s); - let res = G1PointFull(x=[tmp.x], y=[tmp.y]); - return (res=res); - } - - func ec_pairing{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}( - input_len: felt, input: PairingInput* - ) -> (res: felt) { - alloc_locals; - with_attr error_message("Garaga bn254: PairingInput cannot be empty.") { - assert_nn(input_len); - assert_le(1, input_len); - } - local unsafe_final_exp; - if (input_len == 1) { - unsafe_final_exp = 1; - } else { - unsafe_final_exp = 0; - } - - let one = e12.one(); - let pairing_value = pairing(input_len, input, one, unsafe_final_exp); - let (res) = verify_pairing(pairing_value); - return (res=res); - } - - // - // Internal functions. - // - - func verify_pairing{range_check_ptr}(pairing_value: E12*) -> (bool: felt) { - let one = e12.one(); - let tmp = e12.sub(pairing_value, one); - let res = e12.is_zero(tmp); - return (bool=res); - } - - func pairing{range_check_ptr}( - input_len: felt, input: PairingInput*, acc: E12*, unsafe_final_exp: felt - ) -> E12* { - alloc_locals; - let (__fp__, _) = get_fp_and_pc(); - - if (input_len == 0) { - return final_exponentiation(acc, unsafe_final_exp); - } - - let p: G1PointFull = input[0].p; - let q: G2PointFull = input[0].q; - - local p_: G1Point* = new G1Point(x=&p.x, y=&p.y); - local q_: G2Point* = new G2Point( - x=new E2(a0=&q.x.a0, a1=&q.x.a1), y=new E2(a0=&q.y.a0, a1=&q.y.a1) - ); - - let tmp = miller_loop(p_, q_); - let new_acc: E12* = e12.mul(tmp, acc); - return pairing(input_len - 1, input + PairingInput.SIZE, new_acc, unsafe_final_exp); - } -} diff --git a/contracts/lib/cairo_contracts b/contracts/lib/cairo_contracts deleted file mode 160000 index 70cbd05e..00000000 --- a/contracts/lib/cairo_contracts +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 70cbd05ed24ccd147f24b18c638dbd6e7fea88bb diff --git a/src/algebra.py b/src/algebra.py index b40d5b2c..c9c9584f 100644 --- a/src/algebra.py +++ b/src/algebra.py @@ -1,20 +1,6 @@ from dataclasses import dataclass -@dataclass(slots=True) -class BaseField: - p: int - - def __call__(self, integer): - return PyFelt(integer % self.p, self.p) - - def zero(self): - return PyFelt(0, self.p) - - def one(self): - return PyFelt(1, self.p) - - @dataclass(slots=True, frozen=True) class PyFelt: """ @@ -62,10 +48,14 @@ def __mul__(self, right): return PyFelt((self.value * right) % self.p, self.p) return NotImplemented + def __rmul__(self, left): + return self.__mul__(left) + def __inv__(self): return PyFelt(pow(self.value, -1, self.p), self.p) def __truediv__(self, right): + assert type(self) == type(right), f"Cannot divide {type(self)} by {type(right)}" return self * right.__inv__() def __pow__(self, exponent): @@ -87,9 +77,6 @@ def __radd__(self, left): def __rsub__(self, left): return -self.__sub__(left) - def __rmul__(self, left): - return self.__mul__(left) - def __rtruediv__(self, left): return self.__inv__().__mul__(left) @@ -97,6 +84,20 @@ def __rpow__(self, left): return PyFelt(pow(left, self.value, self.p), self.p) +@dataclass(slots=True) +class BaseField: + p: int + + def __call__(self, integer): + return PyFelt(integer % self.p, self.p) + + def zero(self): + return PyFelt(0, self.p) + + def one(self): + return PyFelt(1, self.p) + + @dataclass(slots=True, frozen=True) class ModuloCircuitElement: """ @@ -173,6 +174,15 @@ def __init__( ) self.field = BaseField(self.p) + def __repr__(self): + return f"Polynomial({[x.value for x in self.get_coeffs()]})" + + def __getitem__(self, i): + try: + return self.coefficients[i].value + except IndexError: + return 0 + def degree(self): if self.coefficients == []: return -1 @@ -189,8 +199,13 @@ def get_coeffs(self) -> list[PyFelt]: coeffs = self.coefficients.copy() while len(coeffs) > 0 and coeffs[-1] == 0: coeffs.pop() + if coeffs == []: + return [self.field.zero()] return coeffs + def get_value_coeffs(self) -> list[int]: + return [c.value for c in self.get_coeffs()] + def __add__(self, other): if self.degree() == -1: return other @@ -221,7 +236,7 @@ def __mul__(self, other): ) if self.coefficients == [] or other.coefficients == []: - return Polynomial([]) + return Polynomial([self.field.zero()]) zero = self.field.zero() buf = [zero] * (len(self.coefficients) + len(other.coefficients) - 1) for i in range(len(self.coefficients)): @@ -229,13 +244,15 @@ def __mul__(self, other): continue # optimization for sparse polynomials for j in range(len(other.coefficients)): buf[i + j] = buf[i + j] + self.coefficients[i] * other.coefficients[j] - return Polynomial(buf, raw_init=True) + res = Polynomial(Polynomial(buf).get_coeffs(), raw_init=True) + return res def __rmul__(self, other): return self.__mul__(other) def __truediv__(self, other): - quo, rem = Polynomial.divide(self, other) + quo, rem = Polynomial.__divmod__(self, other) + print(quo, rem) assert ( rem.is_zero() ), "cannot perform polynomial division because remainder is not zero" @@ -249,11 +266,11 @@ def __mod__(self, other): quo, rem = Polynomial.__divmod__(self, other) return rem - def __divmod__(self, denominator): + def __divmod__(self, denominator: "Polynomial"): if denominator.degree() == -1: return None if self.degree() < denominator.degree(): - return (Polynomial([]), self) + return (Polynomial([PyFelt(0, self.p)]), self) field = self.field remainder = Polynomial([n for n in self.coefficients]) quotient_coefficients = [ @@ -296,9 +313,6 @@ def is_zero(self): return True return False - def __str__(self): - return "[" + ",".join(s.__str__() for s in self.coefficients) + "]" - def leading_coefficient(self): return self.coefficients[self.degree()] @@ -306,7 +320,7 @@ def is_zero(self): if self.coefficients == []: return True for c in self.coefficients: - if not c.is_zero(): + if c != 0: return False return True @@ -320,10 +334,61 @@ def evaluate(self, point): def __pow__(self, exponent): if exponent == 0: - return Polynomial([self.coefficients[0].field.one()]) - acc = Polynomial([self.coefficients[0].field.one()]) + return Polynomial([self.field.one()]) + acc = Polynomial([self.field.one()]) for i in reversed(range(len(bin(exponent)[2:]))): acc = acc * acc if (1 << i) & exponent != 0: acc = acc * self return acc + + def pow(self, exponent: int, modulo_poly: "Polynomial"): + if exponent == 0: + return Polynomial([PyFelt(1, self.coefficients[0].p)]) + acc = Polynomial([PyFelt(1, self.coefficients[0].p)]) + for i in reversed(range(len(bin(exponent)[2:]))): + acc = acc * acc % modulo_poly + if (1 << i) & exponent != 0: + acc = (acc * self) % modulo_poly + return acc % modulo_poly + + @staticmethod + def xgcd(x, y): + """ + Extended Euclidean Algorithm for polynomials. + + This method computes the extended greatest common divisor (GCD) of two polynomials x and y. + It returns a tuple of three elements: (a, b, g) such that a * x + b * y = g, where g is the + greatest common divisor of x and y. This is particularly useful in contexts like + computational algebra or number theory where the coefficients of the polynomials are in a field. + + Parameters: + x (Polynomial): The first polynomial. + y (Polynomial): The second polynomial. + + Returns: + tuple: A tuple (a, b, g) where: + a (Polynomial): A polynomial such that a * x + b * y = g. + b (Polynomial): A polynomial such that a * x + b * y = g. + g (Polynomial): The greatest common divisor of x and y. + """ + one = Polynomial([x.field.one()]) + zero = Polynomial([x.field.zero()]) + old_r, r = (x, y) + old_s, s = (one, zero) + old_t, t = (zero, one) + + while not r.is_zero(): + quotient = old_r // r + old_r, r = (r, old_r - quotient * r) + old_s, s = (s, old_s - quotient * s) + old_t, t = (t, old_t - quotient * t) + + lcinv = old_r.coefficients[old_r.degree()].__inv__() + + # a, b, g + return ( + Polynomial([c * lcinv for c in old_s.coefficients]), + Polynomial([c * lcinv for c in old_t.coefficients]), + Polynomial([c * lcinv for c in old_r.coefficients]), + ) diff --git a/src/benchmarks.py b/src/benchmarks.py index 4b282358..3e495e51 100644 --- a/src/benchmarks.py +++ b/src/benchmarks.py @@ -1,7 +1,7 @@ from src.definitions import STARK, CurveID, CURVES, PyFelt, Curve from random import randint from src.extension_field_modulo_circuit import ExtensionFieldModuloCircuit -from src.precompiled_circuits.final_exp import FinalExpCircuit +from src.precompiled_circuits.final_exp import FinalExpTorusCircuit, test_final_exp def test_extf_mul_circuit_amortized(curve_id: CurveID, extension_degree: int): @@ -54,7 +54,7 @@ def test_square_torus_amortized(curve_id: CurveID, extension_degree: int): curve: Curve = CURVES[curve_id.value] p = curve.p X = [PyFelt(randint(0, p - 1), p) for _ in range(extension_degree)] - circuit = FinalExpCircuit( + circuit = FinalExpTorusCircuit( f"{curve_id.name}_square_torus_amortized_Fp{extension_degree}", curve_id=curve_id.value, extension_degree=extension_degree, @@ -69,7 +69,7 @@ def test_mul_torus_amortized(curve_id: CurveID, extension_degree: int): curve: Curve = CURVES[curve_id.value] p = curve.p X = [PyFelt(randint(0, p - 1), p) for _ in range(extension_degree)] - circuit = FinalExpCircuit( + circuit = FinalExpTorusCircuit( f"{curve_id.name}_mul_torus_amortized_Fp{extension_degree}", curve_id=curve_id.value, extension_degree=extension_degree, @@ -80,16 +80,28 @@ def test_mul_torus_amortized(curve_id: CurveID, extension_degree: int): return circuit.summarize() +def test_final_exp_circuit(curve_id: CurveID): + part1, part2 = test_final_exp(curve_id) + summ1 = part1.summarize() + summ1["circuit"] = summ1["circuit"] + "_pt_I" + summ2 = part2.summarize() + summ2["circuit"] = summ2["circuit"] + "_pt_II" + return summ1, summ2 + + if __name__ == "__main__": data = [] - for curveID in [CurveID.BLS12_381]: - data.append(test_extf_mul_circuit_amortized(curveID, 6)) - data.append(test_extf_mul_circuit_full(curveID, 6)) - data.append(test_extf_mul_circuit_amortized(curveID, 12)) - data.append(test_extf_mul_circuit_full(curveID, 12)) - data.append(test_square_torus_amortized(curveID, 6)) - data.append(test_extf_square_amortized(curveID, 12)) - data.append(test_mul_torus_amortized(curveID, 6)) + for curveID in [CurveID.BN254, CurveID.BLS12_381]: + # data.append(test_extf_mul_circuit_amortized(curveID, 6)) + # data.append(test_extf_mul_circuit_full(curveID, 6)) + # data.append(test_extf_mul_circuit_amortized(curveID, 12)) + # data.append(test_extf_mul_circuit_full(curveID, 12)) + # data.append(test_square_torus_amortized(curveID, 6)) + # data.append(test_extf_square_amortized(curveID, 12)) + # data.append(test_mul_torus_amortized(curveID, 6)) + summ1, summ2 = test_final_exp_circuit(curveID) + data.append(summ1) + data.append(summ2) import pandas as pd diff --git a/src/definitions.cairo b/src/definitions.cairo index 8ec50ee0..ea91f4c2 100644 --- a/src/definitions.cairo +++ b/src/definitions.cairo @@ -49,17 +49,56 @@ func get_P(curve_id: felt) -> (prime: UInt384) { } } -func get_final_exp_circuit(curve_id: felt) -> (add_offsets_ptr: felt*, mul_offsets_ptr: felt*) { - if (curve_id == bls.CURVE_ID) { - return (0, 0); - } else { - if (curve_id == bn.CURVE_ID) { - return (0, 0); - } else { - return (0, 0); - } - } -} +// func get_final_exp_circuit(curve_id: felt) -> ( +// constants_ptr: felt*, +// add_offsets_ptr: felt*, +// mul_offsets_ptr: felt*, +// left_assert_eq_offsets_ptr: felt*, +// right_assert_eq_offsets_ptr: felt*, +// poseidon_indexes_ptr: felt*, +// constants_ptr_len: felt, +// add_mod_n: felt, +// mul_mod_n: felt, +// commitments_len: felt, +// assert_eq_len: felt, +// N_Euclidean_equations: felt, +// ) { +// if (curve_id == bls.CURVE_ID) { +// return ( +// cast(0, felt*), +// cast(0, felt*), +// cast(0, felt*), +// cast(0, felt*), +// cast(0, felt*), +// cast(0, felt*), +// 0, +// 0, +// 0, +// 0, +// 0, +// 0, +// ); +// } else { +// if (curve_id == bn.CURVE_ID) { +// return get_GaragaBN254FinalExp_non_interactive_circuit(); +// } else { +// return ( +// cast(0, felt*), +// cast(0, felt*), +// cast(0, felt*), +// cast(0, felt*), +// cast(0, felt*), +// cast(0, felt*), +// 0, +// 0, +// 0, +// 0, +// 0, +// 0, +// ); +// } +// } +// } // Base for UInt384 / BigInt4 const BASE = 2 ** 96; @@ -102,7 +141,7 @@ struct ExtFCircuitInfo { } func zero_E12D() -> E12D { - return E12D( + let res = E12D( UInt384(0, 0, 0, 0), UInt384(0, 0, 0, 0), UInt384(0, 0, 0, 0), @@ -116,10 +155,11 @@ func zero_E12D() -> E12D { UInt384(0, 0, 0, 0), UInt384(0, 0, 0, 0), ); + return res; } func one_E12D() -> E12D { - return E12D( + let res = E12D( UInt384(0, 0, 0, 0), UInt384(1, 0, 0, 0), UInt384(0, 0, 0, 0), @@ -133,4 +173,5 @@ func one_E12D() -> E12D { UInt384(0, 0, 0, 0), UInt384(0, 0, 0, 0), ); + return res; } diff --git a/src/definitions.py b/src/definitions.py index 8c893bca..533ee017 100644 --- a/src/definitions.py +++ b/src/definitions.py @@ -1,5 +1,4 @@ -from src.algebra import Polynomial -from src.algebra import BaseField, PyFelt, ModuloCircuitElement +from src.algebra import Polynomial, BaseField, PyFelt, ModuloCircuitElement from dataclasses import dataclass from enum import Enum @@ -19,31 +18,46 @@ class CurveID(Enum): class Curve: id: int p: int + n: int # order irreducible_polys: dict[int, list[int]] - nr_a0: int + nr_a0: int # E2 non residue nr_a1: int + a: int # y^2 = x^3 + ax + b + b: int + b20: int + b21: int # E2: b is (b20, b21) CURVES = { BN254_ID: Curve( id=BN254_ID, p=0x30644E72E131A029B85045B68181585D97816A916871CA8D3C208C16D87CFD47, + n=0x30644E72E131A029B85045B68181585D2833E84879B9709143E1F593F0000001, irreducible_polys={ 6: [82, 0, 0, -18, 0, 0, 1], 12: [82, 0, 0, 0, 0, 0, -18, 0, 0, 0, 0, 0, 1], }, nr_a0=9, nr_a1=1, + a=0, + b=3, + b20=0x2B149D40CEB8AAAE81BE18991BE06AC3B5B4C5E559DBEFA33267E6DC24A138E5, + b21=0x9713B03AF0FED4CD2CAFADEED8FDF4A74FA084E52D1852E4A2BD0685C315D2, ), BLS12_381_ID: Curve( id=BLS12_381_ID, p=0x1A0111EA397FE69A4B1BA7B6434BACD764774B84F38512BF6730D2A0F6B0F6241EABFFFEB153FFFFB9FEFFFFFFFFAAAB, + n=0x73EDA753299D7D483339D80809A1D80553BDA402FFFE5BFEFFFFFFFF00000001, irreducible_polys={ 6: [2, 0, 0, -2, 0, 0, 1], 12: [2, 0, 0, 0, 0, 0, -2, 0, 0, 0, 0, 0, 1], }, nr_a0=1, nr_a1=1, + a=0, + b=4, + b20=4, + b21=4, ), } @@ -53,7 +67,7 @@ def get_base_field(curve_id: int) -> BaseField: def get_irreducible_poly(curve_id: int, extension_degree: int) -> Polynomial: - field = BaseField(CURVES[curve_id].p) + field = get_base_field(curve_id) return Polynomial( coefficients=[ field(x) for x in CURVES[curve_id].irreducible_polys[extension_degree] @@ -62,6 +76,61 @@ def get_irreducible_poly(curve_id: int, extension_degree: int) -> Polynomial: ) +@dataclass(frozen=True) +class G1Point: + """ + Represents a point on G1, the group of rational points on an elliptic curve over the base field. + """ + + x: int + y: int + curve_id: CurveID + + def __post_init__(self): + if not self.is_on_curve(): + raise ValueError(f"Point {self} is not on the curve") + + def is_on_curve(self) -> bool: + """ + Check if the point is on the curve using the curve equation y^2 = x^3 + ax + b. + """ + a = CURVES[self.curve_id.value].a + b = CURVES[self.curve_id.value].b + p = CURVES[self.curve_id.value].p + lhs = self.y**2 % p + rhs = (self.x**3 + a * self.x + b) % p + return lhs == rhs + + +@dataclass(frozen=True) +class G2Point: + """ + Represents a point on G2, the group of rational points on an elliptic curve over an extension field. + """ + + x: tuple[int, int] + y: tuple[int, int] + curve_id: CurveID + + def __post_init__(self): + if not self.is_on_curve(): + raise ValueError("Point is not on the curve") + + def is_on_curve(self) -> bool: + """ + Check if the point is on the curve using the curve equation y^2 = x^3 + ax + b in the extension field. + """ + from src.hints.tower_backup import E2 + + a = CURVES[self.curve_id.value].a + + p = CURVES[self.curve_id.value].p + b = E2(CURVES[self.curve_id.value].b20, CURVES[self.curve_id.value].b21, p) + y = E2(*self.y, p) + x = E2(*self.x, p) + return y**2 == x**3 + a * x + b + + # v^6 - 18v^3 + 82 # w^12 - 18w^6 + 82 # v^6 - 2v^3 + 2 @@ -172,10 +241,123 @@ def DT12(X: list[PyFelt], curve_id: int) -> list[PyFelt]: ] +def get_p_powers_of_V(curve_id: int, extension_degree: int, k: int) -> list[Polynomial]: + """ + Computes V^(i*p^k) for i in range(extension_degree), where V is the polynomial V(X) = X. + + Args: + curve_id (int): Identifier for the curve. + extension_degree (int): Degree of the field extension. + k (int): Exponent in p^k, must be 1, 2, or 3. + + Returns: + list[Polynomial]: List of polynomials representing V^(i*p^k) for i in range(extension_degree). + """ + assert k in [1, 2, 3], f"Supported k values are 1, 2, 3. Received: {k}" + + field = BaseField(CURVES[curve_id].p) + irr = get_irreducible_poly(curve_id, extension_degree) + + V = Polynomial( + [field.zero() if i != 1 else field.one() for i in range(extension_degree)] + ) + + V_pow = [V.pow(i * field.p**k, irr) for i in range(extension_degree)] + + return V_pow + + +def get_V_torus_powers(curve_id: int, extension_degree: int, k: int) -> Polynomial: + """ + Computes 1/V^((p^k - 1) // 2) where V is the polynomial V(X) = X. + This is used to compute the Frobenius automorphism in the Torus. + + Args: + curve_id (int): Identifier for the curve. + extension_degree (int): Degree of the field extension. + k (int): Exponent in p^k, must be 1, 2, or 3. + + Returns: + list[Polynomial]: List of polynomials representing V^(i*p^k) for i in range(extension_degree). + """ + assert k in [1, 2, 3], f"Supported k values are 1, 2, 3. Received: {k}" + + field = BaseField(CURVES[curve_id].p) + irr = get_irreducible_poly(curve_id, extension_degree) + + V = Polynomial( + [field.zero() if i != 1 else field.one() for i in range(extension_degree)] + ) + + V_pow = V.pow((field.p**k - 1) // 2, irr) + inverse, _, _ = Polynomial.xgcd(V_pow, irr) + return inverse + + +def frobenius( + F: list[PyFelt], V_pow: list[Polynomial], p: int, frob_power: int, irr: Polynomial +) -> Polynomial: + """ + Applies the Frobenius automorphism to a polynomial in a direct extension field. + + Args: + F (list[PyFelt]): Coefficients of the polynomial. + V_pow (list[Polynomial]): Precomputed powers of V. + p (int): Prime number of the base field. + frob_power (int): Power of the Frobenius automorphism. + irr (Polynomial): Irreducible polynomial for the field extension. + + Returns: + Polynomial: Result of applying Frobenius automorphism. + """ + assert len(F) == len(V_pow), "Mismatch in lengths of F and V_pow." + acc = Polynomial([PyFelt(0, p)]) + for i, f in enumerate(F): + acc += f * V_pow[i] + assert acc == ( + Polynomial(F).pow(p**frob_power, irr) + ), "Mismatch in expected result." + return acc + + +def generate_frobenius_maps( + curve_id, extension_degree: int, frob_power: int +) -> tuple[list[str], list[list[tuple[int, int]]]]: + """ + Generates symbolic expressions for Frobenius map coefficients and a list of tuples with constants. + + Args: + curve_id (CurveID): Identifier for the curve. + extension_degree (int): Degree of the field extension. + frob_power (int): Power of the Frobenius automorphism. + + Returns: + tuple[list[str], list[list[tuple[int, int]]]]: Symbolic expressions for each coefficient and a list of tuples with constants. + """ + curve_id = curve_id if type(curve_id) == int else curve_id.value + V_pow = get_p_powers_of_V(curve_id, extension_degree, frob_power) + + k_expressions = ["" for _ in range(extension_degree)] + constants_list = [[] for _ in range(extension_degree)] + for i in range(extension_degree): + for f_index, poly in enumerate(V_pow): + if poly[i] != 0: + hex_value = f"0x{poly[i]:x}" + compact_hex = ( + f"{hex_value[:6]}...{hex_value[-4:]}" + if len(hex_value) > 10 + else hex_value + ) + k_expressions[i] += f" + {compact_hex} * f_{f_index}" + constants_list[i].append((f_index, poly[i])) + return k_expressions, constants_list + + if __name__ == "__main__": from tools.extension_trick import v_to_gnark, gnark_to_v, w_to_gnark, gnark_to_w from random import randint + field = get_base_field(BN254_ID) x12i = [randint(0, CURVES[BN254_ID].p) for _ in range(12)] x12f = [PyFelt(x, CURVES[BN254_ID].p) for x in x12i] @@ -189,7 +371,52 @@ def DT12(X: list[PyFelt], curve_id: int) -> list[PyFelt]: TD1 = tower_to_direct(x12f[:6], BN254_ID, 6) TD2 = gnark_to_v(x12i) - print(f"TD1: {TD1}") - print(f"TD2: {TD2}") print(TD1 == TD2) print([x == y for x, y in zip(TD1, TD2)]) + + XD = [11, 22, 33, 44, 55, 66, 77, 88, 99, 100, 111, 122] + XD = [field(x) for x in XD] + XT = direct_to_tower(XD, BN254_ID, 12) + XT0, XT1 = XT[0:6], XT[6:] + XD0 = tower_to_direct(XT0, BN254_ID, 6) + XD1 = tower_to_direct(XT1, BN254_ID, 6) + + print(f"XD = {[x.value for x in XD]}") + print(f"XT = {[x.value for x in XT]}") + print(f"XT0 = {[x.value for x in XT0]}") + print(f"XT1 = {[x.value for x in XT1]}") + print(f"XD0 = {[x.value for x in XD0]}") + print(f"XD1 = {[x.value for x in XD1]}") + + # Frobenius maps + for extension_degree in [6, 12]: + for curve_id in [CurveID.BN254, CurveID.BLS12_381]: + p = CURVES[curve_id.value].p + for frob_power in [1, 2, 3]: + print( + f"\nFrobenius^{frob_power} for {curve_id.name} Fp{extension_degree}" + ) + irr = get_irreducible_poly(curve_id.value, extension_degree) + + V_pow = get_p_powers_of_V(curve_id.value, extension_degree, frob_power) + print( + f"Torus Inv: {get_V_torus_powers(curve_id.value, extension_degree, frob_power).get_value_coeffs()}" + ) + F = [PyFelt(randint(0, p - 1), p) for _ in range(extension_degree)] + acc = frobenius(F, V_pow, p, frob_power, irr) + + k_expressions, constants_list = generate_frobenius_maps( + curve_id, extension_degree, frob_power + ) + print( + f"f = f0 + f1v + ... + f{extension_degree-1}v^{extension_degree-1}" + ) + print( + f"Frob(f) = f^p = f0 + f1v^(p^{frob_power}) + f2v^(2p^{frob_power}) + ... + f{extension_degree-1}*v^(({extension_degree-1})p^{frob_power})" + ) + print( + f"Frob(f) = k0 + k1v + ... + k{extension_degree-1}v^{extension_degree-1}" + ) + for i, expr in enumerate(k_expressions): + print(f"k_{i} = {expr}") + print(f"Constants: {constants_list[i]}") diff --git a/src/extension_field_modulo_circuit.py b/src/extension_field_modulo_circuit.py index bedbc11a..119de4c0 100644 --- a/src/extension_field_modulo_circuit.py +++ b/src/extension_field_modulo_circuit.py @@ -38,9 +38,9 @@ def __init__( def _init_accumulator(self): return EuclideanPolyAccumulator( - xy=self.constants["ZERO"], + xy=None, nondeterministic_Q=Polynomial([self.field.zero()]), - R=[self.constants["ZERO"]] * self.extension_degree, + R=[None] * self.extension_degree, ) def write_commitment( @@ -63,31 +63,50 @@ def write_commitments( return vals, self.transcript.continuable_hash, self.transcript.s1 def create_powers_of_Z( - self, Z: PyFelt, mock: bool = False + self, Z: PyFelt, mock: bool = False, max_degree: int = None ) -> list[ModuloCircuitElement]: + if max_degree is None: + max_degree = self.extension_degree powers = [self.write_cairo_native_felt(Z)] if not mock: - for _ in range(1, self.extension_degree + 1): + for _ in range(2, max_degree + 1): powers.append(self.mul(powers[-1], powers[0])) else: powers = powers + [ self.write_cairo_native_felt(self.field(Z.value**i)) - for i in range(1, self.extension_degree + 1) + for i in range(2, max_degree + 1) ] self.z_powers = powers return powers def eval_poly_in_precomputed_Z( - self, X: list[ModuloCircuitElement] + self, X: list[ModuloCircuitElement], sparse: bool = False ) -> ModuloCircuitElement: """ Evaluates a polynomial with coefficients `X` at precomputed powers of z. X(z) = X_0 + X_1 * z + X_2 * z^2 + ... + X_n * z^n """ - assert len(X) <= len(self.z_powers), f"{len(X)} > {len(self.z_powers)}" - X_of_z = X[0] - for i in range(1, len(X)): - X_of_z = self.add(X_of_z, self.mul(X[i], self.z_powers[i - 1])) + assert len(X) <= len( + self.z_powers + ), f"{len(X)} > Zpowlen = {len(self.z_powers)}" + sparsity = [1 if elmt != self.field.zero() else 0 for elmt in X] + + if not sparse: + X_of_z = X[0] + for i in range(1, len(X)): + X_of_z = self.add(X_of_z, self.mul(X[i], self.z_powers[i - 1])) + else: + first_non_zero_idx = sparsity.index(1) + if first_non_zero_idx == 0: + X_of_z = X[0] + else: + X_of_z = self.mul( + X[first_non_zero_idx], self.z_powers[first_non_zero_idx - 1] + ) + for i in range(first_non_zero_idx + 1, len(X)): + if sparsity[i] == 1: + X_of_z = self.add(X_of_z, self.mul(X[i], self.z_powers[i - 1])) + return X_of_z def extf_add( @@ -131,32 +150,32 @@ def extf_mul( X: list[ModuloCircuitElement], Y: list[ModuloCircuitElement], extension_degree: int, + x_is_sparse: bool = False, + y_is_sparse: bool = False, ) -> list[ModuloCircuitElement]: """ Multiply in the extension field X * Y mod irreducible_poly Commit to R and accumulates Q. """ - assert len(X) == len(Y) == extension_degree Q, R = nondeterministic_extension_field_mul_divmod( - X, Y, self.curve_id, self.extension_degree + X, Y, self.curve_id, extension_degree ) - R, s0, s1 = self.write_commitments(R) - + R, _, _ = self.write_commitments(R) + s1 = self.transcript.RLC_coeff Q = Polynomial(Q) s1 = self.field(s1) Q_acc: Polynomial = self.acc.nondeterministic_Q + s1 * Q s1 = self.write_cairo_native_felt(s1) # Evaluate polynomials X(z), Y(z) inside circuit. - X_of_z = self.eval_poly_in_precomputed_Z(X) - Y_of_z = self.eval_poly_in_precomputed_Z(Y) + X_of_z = self.eval_poly_in_precomputed_Z(X, x_is_sparse) + Y_of_z = self.eval_poly_in_precomputed_Z(Y, y_is_sparse) XY_of_z = self.mul(X_of_z, Y_of_z) ci_XY_of_z = self.mul(s1, XY_of_z) - XY_acc = self.add(self.acc.xy, ci_XY_of_z) + XY_acc = self.add(self.acc.xy, ci_XY_of_z) # Computes R_acc = R_acc + s1 * R as a Polynomial inside circuit R_acc = [self.add(r_acc, self.mul(s1, r)) for r_acc, r in zip(self.acc.R, R)] - self.acc = EuclideanPolyAccumulator( xy=XY_acc, nondeterministic_Q=Q_acc, R=R_acc ) @@ -175,10 +194,11 @@ def extf_square( Q, R = nondeterministic_extension_field_mul_divmod( X, X, self.curve_id, self.extension_degree ) - R, s0, s1 = self.write_commitments(R) + R, _, _ = self.write_commitments(R) + s1 = self.transcript.RLC_coeff + s1 = self.field(s1) Q = Polynomial(Q) - s1 = self.field(s1) Q_acc: Polynomial = self.acc.nondeterministic_Q + s1 * Q s1 = self.write_cairo_native_felt(s1) @@ -246,7 +266,7 @@ def finalize_circuit(self): lhs = self.acc.xy rhs = self.add(self.mul(Q_of_Z, P_of_z), R_of_Z) assert lhs.value == rhs.value, f"{lhs.value} != {rhs.value}" - return rhs + return True def summarize(self): add_count, mul_count = self.values_segment.summarize() @@ -256,7 +276,112 @@ def summarize(self): "ADDMOD": add_count, "POSEIDON": self.transcript.permutations_count, } - # TODO : add Number of poseidon. - # pprint(summary, ) return summary + + def compile_circuit( + self, + returns: dict[str] = { + "ptr": [ + "constants_ptr", + "add_offsets_ptr", + "mul_offsets_ptr", + "left_assert_eq_offsets_ptr", + "right_assert_eq_offsets_ptr", + "poseidon_indexes_ptr", + ], + "len": [ + "constants_ptr_len", + "add_mod_n", + "mul_mod_n", + "commitments_len", + "assert_eq_len", + "N_Euclidean_equations", + ], + }, + ) -> str: + values_segment_non_interactive = self.values_segment.non_interactive_transform() + dw_arrays = values_segment_non_interactive.get_dw_lookups() + dw_arrays["poseidon_indexes_ptr"] = self.transcript.poseidon_ptr_indexes + name = values_segment_non_interactive.name + len_returns = ":felt, ".join(returns["len"]) + ":felt," + ptr_returns = ":felt*, ".join(returns["ptr"]) + ":felt*," + code = f"func get_{name}_circuit()->({ptr_returns} {len_returns})" + "{" + "\n" + + code += "alloc_locals;\n" + + code += f"let constants_ptr_len = {len(dw_arrays['constants_ptr'])};\n" + code += f"let add_mod_n = {len(dw_arrays['add_offsets_ptr'])};\n" + code += f"let mul_mod_n = {len(dw_arrays['mul_offsets_ptr'])};\n" + code += f"let commitments_len = {len(self.commitments)};\n" + code += f"let assert_eq_len = {len(dw_arrays['left_assert_eq_offsets_ptr'])};\n" + code += ( + f"let N_Euclidean_equations = {len(dw_arrays['poseidon_indexes_ptr'])};\n" + ) + + assert len(dw_arrays["left_assert_eq_offsets_ptr"]) == len( + dw_arrays["right_assert_eq_offsets_ptr"] + ) + + for dw_array_name in returns["ptr"]: + code += f"let ({dw_array_name}:felt*) = get_label_location({dw_array_name}_loc);\n" + + return_vals = 0 + code += f"return ({', '.join(returns['ptr'])}, {', '.join(returns['len'])});\n" + + for dw_array_name in returns["ptr"]: + dw_values = dw_arrays[dw_array_name] + code += f"\t {dw_array_name}_loc:\n" + if dw_array_name == "constants_ptr": + for bigint in dw_values: + for limb in bigint: + code += f"\t dw {limb};\n" + code += "\n" + + elif dw_array_name in ["add_offsets_ptr", "mul_offsets_ptr"]: + for left, right, result in dw_values: + code += ( + f"\t dw {left};\n" + f"\t dw {right};\n" + f"\t dw {result};\n" + ) + code += "\n" + elif dw_array_name in [ + "left_assert_eq_offsets_ptr", + "right_assert_eq_offsets_ptr", + "poseidon_indexes_ptr", + ]: + for val in dw_values: + code += f"\t dw {val};\n" + + code += "\n" + code += "}\n" + return code + + +if __name__ == "__main__": + from src.definitions import CURVES, CurveID + + def init_z_circuit(z: int = 2): + c = ExtensionFieldModuloCircuit("test", CurveID.BN254.value, 6) + c.create_powers_of_Z(c.field(z), mock=True) + return c + + def test_eval(): + c = init_z_circuit() + X = c.write_elements([PyFelt(1, c.field.p) for _ in range(6)]) + print("X(z)", [x.value for x in X]) + X = c.eval_poly_in_precomputed_Z(X) + print("X(z)", X.value) + c.print_value_segment() + print([hex(x.value) for x in c.z_powers], len(c.z_powers)) + + test_eval() + + def test_eval_sparse(): + c = init_z_circuit() + X = c.write_elements([c.field.one(), c.field.zero(), c.field.one()]) + X = c.eval_poly_in_precomputed_Z(X, sparse=True) + print("X(z)", X.value) + c.print_value_segment() + print([hex(x.value) for x in c.z_powers], len(c.z_powers)) + + test_eval_sparse() diff --git a/src/hints/io.py b/src/hints/io.py index 0ff39436..29e7f862 100644 --- a/src/hints/io.py +++ b/src/hints/io.py @@ -39,6 +39,13 @@ def bigint_pack(x: object, n_limbs: int, base: int) -> int: return val +def bigint_pack_ptr(memory: object, ptr: object, n_limbs: int, base: int): + val = 0 + for i in range(n_limbs): + val += as_int(memory[ptr + i], PRIME) * base**i + return val + + def bigint_limbs(x: object, n_limbs: int): limbs = [] for i in range(n_limbs): @@ -93,6 +100,19 @@ def pack_bigint_array( return val +def pack_bigint_ptr( + memory: object, + ptr: object, + n_limbs: int, + base: int, + n_elements: int, +): + val = [] + for i in range(n_elements): + val.append(bigint_pack_ptr(memory, ptr + i * n_limbs, n_limbs, base)) + return val + + #### WRITE HINTS @@ -141,6 +161,16 @@ def fill_uint256(x: int, ids: object): ### OTHERS +def flatten(t): + result = [] + for item in t: + if isinstance(item, (tuple, list)): + result.extend(flatten(item)) + else: + result.append(item) + return result + + def split_128(a): """Takes in value, returns uint256-ish tuple.""" return (a & ((1 << 128) - 1), a >> 128) diff --git a/src/hints/tower_backup.py b/src/hints/tower_backup.py index 697e8f88..2c3ff82c 100644 --- a/src/hints/tower_backup.py +++ b/src/hints/tower_backup.py @@ -1,7 +1,6 @@ import numpy as np from dataclasses import dataclass from src.definitions import CURVES -from typing import Union from src.algebra import PyFelt @@ -14,6 +13,9 @@ class E2: def __str__(self) -> str: return f"({(self.a0)} + {self.a1}*u)" + def __eq__(self, other): + return self.a0 == other.a0 and self.a1 == other.a1 + @staticmethod def zero(p: int): return E2(0, 0, p) @@ -83,8 +85,10 @@ def __pow__(self, p: int): return self # Start the computation. - result = (1, 0) # Initialize result as the multiplicative identity in F_p^2. - temp = self.copy() # Initialize temp as self. + result = self.one( + self.p + ) # Initialize result as the multiplicative identity in F_p^2. + temp = self # Initialize temp as self. # Loop through each bit of the exponent p. for bit in reversed(bin(p)[2:]): # [2:] to strip the "0b" prefix. @@ -256,6 +260,7 @@ def felt_coeffs(self) -> list[PyFelt]: return [PyFelt(c, self.c0.b0.p) for c in self.coeffs] def __init__(self, x: list[PyFelt | E6], curve_id: int): + self.curve_id = curve_id if type(x[0]) == PyFelt and len(x) == 12: self.c0 = E6(x=x[0:6], curve_id=curve_id) self.c1 = E6(x=x[6:12], curve_id=curve_id) @@ -265,15 +270,6 @@ def __init__(self, x: list[PyFelt | E6], curve_id: int): else: raise ValueError - def __inv__(self): - t0, t1 = self.c0 * self.c0, self.c1 * self.c1 - tmp = t1.mul_by_non_residue() - t0 = t0 - tmp - t1 = t0.__inv__() - c0 = self.c0 * t1 - c1 = -self.c1 * t1 - return E12([c0, c1], curve_id=self.curve_id) - def __mul__(self, other): if isinstance(other, E12): a = self.c0 + self.c1 @@ -286,8 +282,17 @@ def __mul__(self, other): z0 = c + b return E12([z0, z1], self.curve_id) + def __inv__(self): + t0, t1 = self.c0 * self.c0, self.c1 * self.c1 + tmp = t1.mul_by_non_residue() + t0 = t0 - tmp + t1 = t0.__inv__() + c0 = self.c0 * t1 + c1 = -self.c1 * t1 + return E12([c0, c1], curve_id=self.curve_id) + def div(self, other): - if isinstance(other, E6): + if isinstance(other, E12): return self * other.__inv__() raise NotImplementedError diff --git a/src/modulo_circuit.cairo b/src/modulo_circuit.cairo new file mode 100644 index 00000000..65533414 --- /dev/null +++ b/src/modulo_circuit.cairo @@ -0,0 +1,88 @@ +from starkware.cairo.common.cairo_builtins import PoseidonBuiltin, ModBuiltin, UInt384 +from starkware.cairo.common.registers import get_fp_and_pc +from starkware.cairo.common.memcpy import memcpy +from starkware.cairo.common.modulo import run_mod_p_circuit + +from src.definitions import get_P, BASE, N_LIMBS +from src.precompiled_circuits.sample import get_sample_circuit +from src.utils import ( + get_Z_and_RLC_from_transcript, + write_felts_to_value_segment, + assert_limbs_at_index_are_equal, +) +func run_extension_field_modulo_circuit{ + range_check_ptr, + poseidon_ptr: PoseidonBuiltin*, + range_check96_ptr: felt*, + add_mod_ptr: ModBuiltin*, + mul_mod_ptr: ModBuiltin*, +}(input: felt*, input_len: felt, curve_id: felt, circuit_id: felt) -> felt { + alloc_locals; + let (__fp__, _) = get_fp_and_pc(); + let p: UInt384 = get_P(curve_id); + let ( + constants_ptr: felt*, + add_offsets_ptr: felt*, + mul_offsets_ptr: felt*, + left_assert_eq_offsets_ptr: felt*, + right_assert_eq_offsets_ptr: felt*, + poseidon_indexes_ptr: felt*, + constants_ptr_len: felt, + add_mod_n: felt, + mul_mod_n: felt, + commitments_len: felt, + assert_eq_len: felt, + N_Euclidean_equations: felt, + ) = get_sample_circuit(circuit_id); + + local values_ptr: UInt384* = cast(range_check96_ptr, UInt384*); + memcpy(dst=range_check96_ptr, src=constants_ptr, len=constants_ptr_len * N_LIMBS); // write(Constants) + memcpy(dst=range_check96_ptr + constants_ptr_len * N_LIMBS, src=input, len=input_len); // write(Input) + + local commitments: felt*; + %{ + from src.precompiled_circuits.sample import get_sample_circuit + from src.hints.io import pack_bigint_ptr, flatten + from src.definitions import CURVES, PyFelt + p = CURVES[ids.curve_id].p + circuit_input = pack_bigint_ptr(memory, ids.input, ids.N_LIMBS, ids.BASE, ids.input_len//ids.N_LIMBS) + circuit_input = [PyFelt(x, p) for x in circuit_input] + EXTF_MOD_CIRCUIT = get_sample_circuit(ids.circuit_id, circuit_input) + commitments = flatten([bigint_split(x.value, ids.N_LIMBS, ids.BASE) for x in EXTF_MOD_CIRCUIT.commitments]) + ids.commitments = segments.gen_arg(commitments) + print(len(commitments), len(commitments)//4) + %} + + memcpy( + dst=range_check96_ptr + constants_ptr_len * N_LIMBS + input_len, + src=commitments, + len=commitments_len * N_LIMBS, + ); // write(Commitments) + + let (local Z: felt, local RLC_coeffs: felt*) = get_Z_and_RLC_from_transcript( + transcript_start=cast(values_ptr, felt*) + constants_ptr_len, + poseidon_indexes_ptr=poseidon_indexes_ptr, + n_elements_in_transcript=commitments_len, + n_equations=N_Euclidean_equations, + ); + + tempvar range_check96_ptr = range_check96_ptr + constants_ptr_len * N_LIMBS + input_len + + commitments_len * N_LIMBS; + write_felts_to_value_segment(values=&Z, n=1); + write_felts_to_value_segment(values=RLC_coeffs, n=N_Euclidean_equations); + + run_mod_p_circuit( + p=p, + values_ptr=values_ptr, + add_mod_offsets_ptr=add_offsets_ptr, + add_mod_n=add_mod_n, + mul_mod_offsets_ptr=mul_offsets_ptr, + mul_mod_n=mul_mod_n, + ); + + // assert_limbs_at_index_are_equal( + // values_ptr, left_assert_eq_offsets_ptr, right_assert_eq_offsets_ptr, assert_eq_len + // ); + + return 0; +} diff --git a/src/modulo_circuit.py b/src/modulo_circuit.py index 7a939785..fb292ec8 100644 --- a/src/modulo_circuit.py +++ b/src/modulo_circuit.py @@ -122,7 +122,11 @@ def non_interactive_transform(self) -> "ValueSegment": BUILTIN Order matters! """ - res = ValueSegment(self.name + "_non_interactive") + res = ValueSegment( + self.name + if self.name.endswith("_non_interactive") + else self.name + "_non_interactive" + ) offset_map = {} for stacks_key in [ WriteOps.CONSTANT, @@ -161,17 +165,19 @@ def non_interactive_transform(self) -> "ValueSegment": def get_dw_lookups(self) -> dict: dw_arrays = { - "constants": [], - "add_offsets": [], - "mul_offsets": [], - "left_assert_eq_offsets": [], - "right_assert_eq_offsets": [], - "poseidon_ptr_indexes": [], + "constants_ptr": [], + "add_offsets_ptr": [], + "mul_offsets_ptr": [], + "left_assert_eq_offsets_ptr": [], + "right_assert_eq_offsets_ptr": [], + "poseidon_indexes_ptr": [], } for _, item in self.segment_stacks[WriteOps.CONSTANT].items(): - dw_arrays["constants"].append(bigint_split(item.value, self.n_limbs, BASE)) + dw_arrays["constants_ptr"].append( + bigint_split(item.value, self.n_limbs, BASE) + ) for result_offset, item in self.segment_stacks[WriteOps.BUILTIN].items(): - dw_arrays[item.instruction.operation.name.lower() + "_offsets"].append( + dw_arrays[item.instruction.operation.name.lower() + "_offsets_ptr"].append( ( item.instruction.left_offset, item.instruction.right_offset, @@ -179,10 +185,10 @@ def get_dw_lookups(self) -> dict: ) ) for assert_eq_instruction in self.assert_eq_instructions: - dw_arrays["left_assert_eq_offsets"].append( + dw_arrays["left_assert_eq_offsets_ptr"].append( assert_eq_instruction.segment_left_offset ) - dw_arrays["right_assert_eq_offsets"].append( + dw_arrays["right_assert_eq_offsets_ptr"].append( assert_eq_instruction.segment_right_offset ) @@ -248,10 +254,11 @@ def __init__(self, name: str, curve_id: int) -> None: self.field = BaseField(CURVES[curve_id].p) self.N_LIMBS = 4 self.values_segment: ValueSegment = ValueSegment(name) - self.constants: dict[str, ModuloCircuitElement] = dict() - self.add_constant("ZERO", self.field.zero()) - self.add_constant("ONE", self.field.one()) - self.add_constant("MINUS_ONE", self.field(-1)) + self.constants: dict[int, ModuloCircuitElement] = dict() + + self.add_constant(self.field.zero()) + self.add_constant(self.field.one()) + self.add_constant(self.field(-1)) @property def values_offset(self): @@ -292,27 +299,31 @@ def write_cairo_native_felt(self, native_felt: PyFelt): def write_sparse_elements( self, elmts: list[PyFelt], operation: WriteOps - ) -> list[ModuloCircuitElement]: + ) -> (list[ModuloCircuitElement], list[int]): sparsity = [1 if elmt != self.field.zero() else 0 for elmt in elmts] - return [ - self.write_element(elmt, operation) - for elmt, not_sparse in zip(elmts, sparsity) - if not_sparse - ], sparsity - - def add_constant(self, name: str, value: PyFelt) -> None: - if name in self.constants: - print((f"/!\ Constant '{name}' already exists.")) - return self.constants[name] - self.constants[name] = self.write_element(value, WriteOps.CONSTANT) - return self.constants[name] - - def get_constant(self, name: str) -> ModuloCircuitElement: - if name not in self.constants: + elements = [] + for elmt, not_sparse in zip(elmts, sparsity): + if not_sparse: + if elmt.value not in self.constants: + elements.append(self.write_element(elmt, operation)) + else: + elements.append(self.get_constant(elmt.value)) + return elements, sparsity + + def add_constant(self, val: PyFelt) -> None: + if val.value in self.constants: + # print((f"/!\ Constant '{hex(val.value)}' already exists.")) + return self.constants[val.value] + self.constants[val.value] = self.write_element(val, WriteOps.CONSTANT) + return self.constants[val.value] + + def get_constant(self, val: int) -> ModuloCircuitElement: + val = val % self.field.p + if (val) not in self.constants: raise ValueError( - f"Constant '{name}' does not exist. Available constants : {list(self.constants.keys())}" + f"Constant '{val}' does not exist. Available constants : {list(self.constants.keys())}" ) - return self.constants[name] + return self.constants[val] def assert_eq(self, a: ModuloCircuitElement, b: ModuloCircuitElement): self.values_segment.assert_eq(a.offset, b.offset) @@ -331,16 +342,21 @@ def assert_eq_one(self, a: ModuloCircuitElement): def add( self, a: ModuloCircuitElement, b: ModuloCircuitElement ) -> ModuloCircuitElement: - assert ( - type(a) == type(b) == ModuloCircuitElement - ), f"Expected ModuloElement, got {type(a)}, {a} and {type(b)}, {b}" - - instruction = ModuloCircuitInstruction( - ModBuiltinOps.ADD, a.offset, b.offset, self.values_offset - ) - return self.write_element( - a.emulated_felt + b.emulated_felt, WriteOps.BUILTIN, instruction - ) + if a is None: + return b + elif b is None: + return a + else: + assert ( + type(a) == type(b) == ModuloCircuitElement + ), f"Expected ModuloElement, got {type(a)}, {a} and {type(b)}, {b}" + + instruction = ModuloCircuitInstruction( + ModBuiltinOps.ADD, a.offset, b.offset, self.values_offset + ) + return self.write_element( + a.emulated_felt + b.emulated_felt, WriteOps.BUILTIN, instruction + ) def mul( self, @@ -356,7 +372,11 @@ def mul( ) def neg(self, a: ModuloCircuitElement) -> ModuloCircuitElement: - return self.mul(a, self.constants["MINUS_ONE"]) + res = self.mul(a, self.get_constant(-1)) + assert ( + res.value == (-a.felt).value + ), f"Expected {res.value} to be equal to {(-a.felt).value}" + return res def sub(self, a: ModuloCircuitElement, b: ModuloCircuitElement): return self.add(a, self.neg(b)) @@ -383,54 +403,6 @@ def _check_sanity(self): def print_value_segment(self): self.values_segment.print() - def compile_circuit( - self, - returns: list[str] = [ - "constants", - "add_offsets", - "mul_offsets", - "left_assert_eq_offsets", - "right_assert_eq_offsets", - "poseidon_ptr_indexes", - ], - ) -> str: - values_segment_non_interactive = self.values_segment.non_interactive_transform() - dw_arrays = values_segment_non_interactive.get_dw_lookups() - name = values_segment_non_interactive.name - code = f"func get_{name}_circuit()->({':felt*, '.join(returns)})" + "{" + "\n" - - for dw_array_name in returns: - code += f"let ({dw_array_name}_ptr:felt*) = get_label_location({dw_array_name}_loc);\n" - - code += f"return ({'_ptr, '.join(returns)});\n" - - for dw_array_name in returns: - dw_values = dw_arrays[dw_array_name] - code += f"\t {dw_array_name}_loc:\n" - if dw_array_name == "constants": - for bigint in dw_values: - for limb in bigint: - code += f"\t dw {limb};\n" - code += "\n" - - elif dw_array_name in ["add_offsets", "mul_offsets"]: - for left, right, result in dw_values: - code += ( - f"\t dw {left};\n" + f"\t dw {right};\n" + f"\t dw {result};\n" - ) - code += "\n" - elif dw_array_name in [ - "left_assert_eq_offsets", - "right_assert_eq_offsets", - "poseidon_ptr_indexes", - ]: - for val in dw_values: - code += f"\t dw {val};\n" - - code += "\n" - code += "}\n" - return code - if __name__ == "__main__": from src.algebra import BaseField diff --git a/src/pairing.cairo b/src/pairing.cairo index 3ba02a4a..473c8a94 100644 --- a/src/pairing.cairo +++ b/src/pairing.cairo @@ -13,63 +13,3 @@ func multi_miller_loop{ range_check96_ptr: felt*, add_mod_ptr: ModBuiltin*, mul_mod_ptr: ModBuiltin* }() -> E12D { } - -func final_exponentiation{ - poseidon_ptr: PoseidonBuiltin*, - range_check96_ptr: felt*, - add_mod_ptr: ModBuiltin*, - mul_mod_ptr: ModBuiltin*, -}(input: E12D, curve_id: felt) -> felt { - alloc_locals; - let (__fp__, _) = get_fp_and_pc(); - let p: UInt384 = get_P(curve_id); - let ( - constants_ptr: felt*, - constants_ptr_len: felt, - add_offsets: felt*, - mul_offsets: felt*, - commitments_len: felt, - transcript_indexes: felt*, - N_Euclidean_equations: felt, - ) = get_final_exp_circuit(curve_id); - - let values_ptr = cast(range_check96_ptr, UInt384*); - memcpy(dst=range_check96_ptr, src=constants_ptr, len=constants_ptr_len); // write(Constants) - memcpy(dst=range_check96_ptr + constants_ptr_len, src=&input, len=E12D.SIZE); // write(Input) - - local commitments: felt*; - %{ - from src.precompiled_circuits.final_exp import get_final_exp_circuit - FinalExpCircuit = get_final_exp_circuit(ids.curve_id) - FinalExpCircuit.run( - range_check96_ptr=range_check96_ptr, - add_mod_ptr=add_mod_ptr, - mul_mod_ptr=mul_mod_ptr - ) - ids.commitments = segments.gen_arg(commitments) - %} - - memcpy( - dst=range_check96_ptr + constants_ptr_len + E12D.SIZE, src=commitments, len=commitments_len - ); // write(Commitments) - - let (local Z: felt, local RLC_coeffs: felt*) = get_Z_and_RLC_from_transcript( - transcript_start=cast(values_ptr, felt*) + constants_ptr_len, - poseidon_ptr_indexes=transcript_indexes, - n_elements_in_transcript=E12D.SIZE + commitments_len, - n_equations=N_Euclidean_equations, - ); - - tempvar range_check96_ptr = range_check96_ptr + constants_ptr_len + E12D.SIZE + commitments_len; - write_felts_to_value_segment(felts=&Z, n=1); - write_felts_to_value_segment(felts=RLC_coeffs, n=N_Euclidean_equations); - - run_mod_p_circuit( - p=p, - values_ptr=values_ptr, - add_mod_offsets_ptr=add_offsets, - add_mod_n=2, - mul_mod_offsets_ptr=mul_offsets, - mul_mod_n=2, - ); -} diff --git a/src/poseidon_transcript.py b/src/poseidon_transcript.py index 3da9fa97..49645bb6 100644 --- a/src/poseidon_transcript.py +++ b/src/poseidon_transcript.py @@ -15,6 +15,16 @@ def __init__(self, init_hash: int) -> None: self.continuable_hash = init_hash self.s1 = None self.permutations_count = 0 + self.poseidon_ptr_indexes = [] + + @property + def RLC_coeff(self): + """ + A function to retrieve the random linear combination coefficient after a permutation. + Stores the index of the last permutation in the poseidon_ptr_indexes list, to be used to retrieve RLC coefficients later. + """ + self.poseidon_ptr_indexes.append(self.permutations_count - 1) + return self.s1 def hash_value(self, x: int): s0, s1, _ = hades_permutation([x, self.continuable_hash, 2], self.params) @@ -41,23 +51,23 @@ def hash_limbs_multi( self.hash_value(combined_limbs) return self.continuable_hash, self.s1 - def generate_poseidon_assertions( - self, - continuable_hash_name: str, - num_pairs: int, - ) -> str: - cairo_code = "" - for i in range(num_pairs): - s0_index = i * 2 - s1_index = s0_index + 1 - if i == 0: - s1_previous_output = continuable_hash_name - else: - s1_previous_output = f"poseidon_ptr[{i-1}].output.s0" - cairo_code += ( - f" assert poseidon_ptr[{i}].input = PoseidonBuiltinState(\n" - f" s0=range_check96_ptr[{s0_index}] * range_check96_ptr[{s1_index}], " - f"s1={s1_previous_output}, s2=two\n" - " );\n" - ) - return cairo_code + # def generate_poseidon_assertions( + # self, + # continuable_hash_name: str, + # num_pairs: int, + # ) -> str: + # cairo_code = "" + # for i in range(num_pairs): + # s0_index = i * 2 + # s1_index = s0_index + 1 + # if i == 0: + # s1_previous_output = continuable_hash_name + # else: + # s1_previous_output = f"poseidon_ptr[{i-1}].output.s0" + # cairo_code += ( + # f" assert poseidon_ptr[{i}].input = PoseidonBuiltinState(\n" + # f" s0=range_check96_ptr[{s0_index}] * range_check96_ptr[{s1_index}], " + # f"s1={s1_previous_output}, s2=two\n" + # " );\n" + # ) + # return cairo_code diff --git a/src/precompiled_circuits/final_exp.cairo b/src/precompiled_circuits/final_exp.cairo index d522cee5..4b62e747 100644 --- a/src/precompiled_circuits/final_exp.cairo +++ b/src/precompiled_circuits/final_exp.cairo @@ -1,3 +1,5 @@ +from starkware.cairo.common.registers import get_label_location + func BN254_final_exp() -> (add_offsets_ptr: felt*, mul_offsets_ptr: felt*) { return (0, 0); } @@ -5,245 +7,3 @@ func BN254_final_exp() -> (add_offsets_ptr: felt*, mul_offsets_ptr: felt*) { func BLS12_381_final_exp() -> (add_offsets_ptr: felt*, mul_offsets_ptr: felt*) { return (0, 0); } - -func get_GaragaBN254FinalExp_non_interactive_circuit() -> ( - constants: felt*, - add_offsets: felt*, - mul_offsets: felt*, - left_assert_eq_offsets: felt*, - right_assert_eq_offsets: felt*, - poseidon_ptr_indexes, -) { - let (constants_ptr: felt*) = get_label_location(constants_loc); - let (add_offsets_ptr: felt*) = get_label_location(add_offsets_loc); - let (mul_offsets_ptr: felt*) = get_label_location(mul_offsets_loc); - let (left_assert_eq_offsets_ptr: felt*) = get_label_location(left_assert_eq_offsets_loc); - let (right_assert_eq_offsets_ptr: felt*) = get_label_location(right_assert_eq_offsets_loc); - let (poseidon_ptr_indexes_ptr: felt*) = get_label_location(poseidon_ptr_indexes_loc); - return ( - constants_ptr, - add_offsets_ptr, - mul_offsets_ptr, - left_assert_eq_offsets_ptr, - right_assert_eq_offsets_ptr, - poseidon_ptr_indexes, - ); - - constants_loc: - dw 0; - dw 0; - dw 0; - dw 0; - dw 1; - dw 0; - dw 0; - dw 0; - dw 32324006162389411176778628422; - dw 57042285082623239461879769745; - dw 3486998266802970665; - dw 0; - - add_offsets_loc: - dw 40; - dw 40; - dw 94; - - dw 44; - dw 44; - dw 98; - - dw 48; - dw 48; - dw 102; - - dw 52; - dw 52; - dw 106; - - dw 56; - dw 56; - dw 110; - - dw 60; - dw 60; - dw 114; - - dw 94; - dw 118; - dw 142; - - dw 98; - dw 122; - dw 146; - - dw 102; - dw 126; - dw 150; - - dw 106; - dw 130; - dw 154; - - dw 110; - dw 134; - dw 158; - - dw 114; - dw 138; - dw 162; - - dw 12; - dw 166; - dw 170; - - dw 170; - dw 174; - dw 178; - - dw 178; - dw 182; - dw 186; - - dw 186; - dw 190; - dw 194; - - dw 194; - dw 198; - dw 202; - - dw 142; - dw 206; - dw 210; - - dw 210; - dw 214; - dw 218; - - dw 218; - dw 222; - dw 226; - - dw 226; - dw 230; - dw 234; - - dw 234; - dw 238; - dw 242; - - dw 0; - dw 250; - dw 254; - - dw 0; - dw 69; - dw 258; - - mul_offsets_loc: - dw 64; - dw 64; - dw 74; - - dw 74; - dw 64; - dw 78; - - dw 78; - dw 64; - dw 82; - - dw 82; - dw 64; - dw 86; - - dw 86; - dw 64; - dw 90; - - dw 12; - dw 8; - dw 118; - - dw 16; - dw 8; - dw 122; - - dw 20; - dw 8; - dw 126; - - dw 24; - dw 8; - dw 130; - - dw 28; - dw 8; - dw 134; - - dw 32; - dw 8; - dw 138; - - dw 16; - dw 74; - dw 166; - - dw 20; - dw 78; - dw 174; - - dw 24; - dw 82; - dw 182; - - dw 28; - dw 86; - dw 190; - - dw 32; - dw 90; - dw 198; - - dw 146; - dw 74; - dw 206; - - dw 150; - dw 78; - dw 214; - - dw 154; - dw 82; - dw 222; - - dw 158; - dw 86; - dw 230; - - dw 162; - dw 90; - dw 238; - - dw 202; - dw 242; - dw 246; - - dw 69; - dw 246; - dw 250; - - left_assert_eq_offsets_loc: - dw 36; - dw 37; - dw 38; - dw 39; - - right_assert_eq_offsets_loc: - dw 4; - dw 5; - dw 6; - dw 7; - - poseidon_ptr_indexes_loc: -} diff --git a/src/precompiled_circuits/final_exp.py b/src/precompiled_circuits/final_exp.py index d1f9a926..fb1e6f80 100644 --- a/src/precompiled_circuits/final_exp.py +++ b/src/precompiled_circuits/final_exp.py @@ -7,18 +7,49 @@ PyFelt, Polynomial, ) -from src.definitions import CURVES, STARK, CurveID, BN254_ID, Curve +from src.poseidon_transcript import CairoPoseidonTranscript +from src.definitions import ( + CURVES, + STARK, + CurveID, + BN254_ID, + BLS12_381_ID, + Curve, + generate_frobenius_maps, + get_V_torus_powers, +) from src.hints.extf_mul import ( nondeterministic_square_torus, nondeterministic_extension_field_mul_divmod, ) +import random +from enum import Enum -class FinalExpCircuit(ExtensionFieldModuloCircuit): +class FinalExpTorusCircuit(ExtensionFieldModuloCircuit): def __init__(self, name: str, curve_id: int, extension_degree: int): super().__init__( name=name, curve_id=curve_id, extension_degree=extension_degree ) + self.frobenius_maps = {} + self.v_torus_powers_inv = {} + for i in [1, 2, 3]: + _, self.frobenius_maps[i] = generate_frobenius_maps( + curve_id=curve_id, extension_degree=extension_degree, frob_power=i + ) + self.v_torus_powers_inv[i] = get_V_torus_powers( + curve_id, extension_degree, i + ).get_coeffs() + + # Write to circuit. Note : add_constant will return existing circuit constant if it already exists. + self.v_torus_powers_inv[i] = [ + self.add_constant(v) for v in self.v_torus_powers_inv[i] + ] + + def final_exp_part1( + self, X: list[PyFelt], unsafe: bool + ) -> list[ModuloCircuitElement]: + return NotImplementedError def square_torus( self: ExtensionFieldModuloCircuit, X: list[ModuloCircuitElement] @@ -32,7 +63,8 @@ def square_torus( SQ: list[PyFelt] = nondeterministic_square_torus( X, self.curve_id, biject_from_direct=True ) - SQ, continuable_hash, s1 = self.write_commitments(SQ) + SQ, _, _ = self.write_commitments(SQ) + s1 = self.transcript.RLC_coeff s1 = self.write_cairo_native_felt(self.field(s1)) two_SQ = self.extf_add(SQ, SQ) two_SQ_min_X = self.extf_sub(two_SQ, X) @@ -65,6 +97,12 @@ def square_torus( return SQ + def n_square_torus(self, X: list[PyFelt], n: int) -> list[PyFelt]: + result = self.square_torus(X) + for _ in range(n - 1): + result = self.square_torus(result) + return result + def mul_torus( self, X: list[ModuloCircuitElement], Y: list[ModuloCircuitElement] ) -> list[ModuloCircuitElement]: @@ -74,7 +112,7 @@ def mul_torus( xy = self.extf_mul(X, Y, self.extension_degree) num = copy.deepcopy(xy) - num[1] = self.add(xy[1], self.constants["ONE"]) + num[1] = self.add(xy[1], self.constants[1]) den = self.extf_add(X, Y) return self.extf_div(num, den, self.extension_degree) @@ -88,10 +126,10 @@ def decompress_torus( """ Returns (X + w) / (X - w). Size is doubled. """ - zero = self.get_constant("ZERO") + zero = self.get_constant(0) num = [ X[0], - self.get_constant("ONE"), + self.get_constant(1), X[1], zero, X[2], @@ -104,44 +142,62 @@ def decompress_torus( zero, ] den = num.copy() - den[1] = self.get_constant("MINUS_ONE") + den[1] = self.get_constant(-1) return self.extf_div(num, den, 2 * self.extension_degree) - -class GaragaBN254FinalExp(FinalExpCircuit): - def __init__(self): - super().__init__( - name="GaragaBN254FinalExp", curve_id=BN254_ID, extension_degree=6 - ) - - def final_exp_part1( - self, X: list[PyFelt], unsafe: bool + def frobenius_torus( + self, X: list[ModuloCircuitElement], frob_power: int + ) -> list[ModuloCircuitElement]: + frob = [None] * self.extension_degree + for i, list_op in enumerate(self.frobenius_maps[frob_power]): + list_op_result = [] + for index, constant in list_op: + if constant == 1: + list_op_result.append(X[index]) + else: + # print(constant) + list_op_result.append( + self.mul(X[index], self.add_constant(self.field(constant))) + ) + frob[i] = list_op_result[0] + for op_res in list_op_result[1:]: + frob[i] = self.add(frob[i], op_res) + + if len(self.v_torus_powers_inv[frob_power]) == 1: + return self.extf_scalar_mul(frob, self.v_torus_powers_inv[frob_power][0]) + else: + return self.extf_mul( + X=frob, + Y=[self.add_constant(v) for v in self.v_torus_powers_inv[frob_power]], + extension_degree=self.extension_degree, + y_is_sparse=True, + ) + + def easy_part( + self, X: list[ModuloCircuitElement], unsafe: bool ) -> list[ModuloCircuitElement]: + """ + Computes the easy part of the final exponentiation. + """ self.write_elements(X, operation=WriteOps.INPUT) # Hash input. self.transcript.hash_limbs_multi(X) - MIN_ONE = self.get_constant("MINUS_ONE") - MIN_9 = circuit.add_constant("MIN_9", self.field(-9)) + num_indexes = [0, 2, 4, 6, 8, 10] + den_indexes = [1, 3, 5, 7, 9, 11] - num = circuit.write_elements(X[0:6]) - - num_full = [ - circuit.mul(MIN_ONE, circuit.add(num[0], circuit.mul(MIN_9, num[1]))), - circuit.mul(MIN_ONE, circuit.add(num[2], circuit.mul(MIN_9, num[3]))), - circuit.mul(MIN_ONE, circuit.add(num[4], circuit.mul(MIN_9, num[5]))), - circuit.mul(MIN_ONE, num[1]), - circuit.mul(MIN_ONE, num[3]), - circuit.mul(MIN_ONE, num[5]), - ] + num = self.write_elements([X[i] for i in num_indexes], operation=WriteOps.INPUT) + num = self.extf_scalar_mul(num, self.get_constant(-1)) if unsafe: - den = circuit.write_elements(X[6:12], operation=WriteOps.INPUT) + den = self.write_elements( + [X[i] for i in den_indexes], operation=WriteOps.INPUT + ) else: - if [x.value for x in X[6:12]] == [0, 0, 0, 0, 0, 0]: + if [x.value for x in [X[i] for i in den_indexes]] == [0, 0, 0, 0, 0, 0]: selector1 = 1 - den = circuit.write_elements( + den = self.write_elements( [ self.field.one(), self.field.zero(), @@ -153,20 +209,143 @@ def final_exp_part1( ) else: selector1 = 0 - den = circuit.write_elements(self.raw_input[6:12]) - - den_full = [ - circuit.add(den[0], circuit.mul(MIN_9, den[1])), - circuit.add(den[2], circuit.mul(MIN_9, den[3])), - circuit.add(den[4], circuit.mul(MIN_9, den[5])), - den[1], - den[3], - den[5], - ] + den = self.write_elements( + [X[i] for i in den_indexes], operation=WriteOps.INPUT + ) - c = self.extf_div(num_full, den_full, self.extension_degree) - t0 = self.frobenius_square_torus(c) + c = self.extf_div(num, den, self.extension_degree) + t0 = self.frobenius_torus(c, 2) + # c = self.extf_neg(c) c = self.mul_torus(t0, c) + return c + + def final_exp_finalize(self, t0: list[PyFelt], t2: list[PyFelt]): + # Computes Decompress Torus(MulTorus(t0, t2)). + # Only valid if (t0 + t2) != 0. + t0 = self.write_elements(t0, WriteOps.INPUT) + t2 = self.write_elements(t2, WriteOps.INPUT) + + mul = self.mul_torus(t0, t2) + return self.decompress_torus(mul) + + +class GaragaBLS12_381FinalExp(FinalExpTorusCircuit): + def __init__(self, init_hash: int = None): + super().__init__( + name="GaragaBLS12_381FinalExp", curve_id=BLS12_381_ID, extension_degree=6 + ) + if init_hash is not None: + self.transcript = CairoPoseidonTranscript(init_hash) + + def expt_half_torus(self, X: list[ModuloCircuitElement]): + z = self.square_torus(X) + z = self.mul_torus(X, z) + z = self.square_torus(z) + z = self.square_torus(z) + z = self.mul_torus(X, z) + z = self.n_square_torus(z, 3) + z = self.mul_torus(X, z) + z = self.n_square_torus(z, 9) + z = self.mul_torus(X, z) + z = self.n_square_torus(z, 32) + z = self.mul_torus(X, z) + z = self.n_square_torus(z, 15) + z = self.inverse_torus(z) + return z + + def expt_torus(self, X): + z = self.expt_half_torus(X) + return self.square_torus(z) + + def final_exp_part1(self, X: list[PyFelt], unsafe: bool) -> list[PyFelt]: + """ + X is a list of 12 elements in the tower of extension fields. + """ + c = self.easy_part(X, unsafe) + + # 2. Hard part (up to permutation) + # 3(p⁴-p²+1)/r + # Daiki Hayashida, Kenichiro Hayasaka and Tadanori Teruya + # https://eprint.iacr.org/2020/875.pdf + # performed in torus compressed form + t0 = self.square_torus(c) + t1 = self.expt_half_torus(t0) + t2 = self.inverse_torus(c) + t1 = self.mul_torus(t1, t2) + t2 = self.expt_torus(t1) + t1 = self.inverse_torus(t1) + t1 = self.mul_torus(t1, t2) + t2 = self.expt_torus(t1) + t1 = self.frobenius_torus(t1, 1) + t1 = self.mul_torus(t1, t2) + c = self.mul_torus(c, t0) + t0 = self.expt_torus(t1) + t2 = self.expt_torus(t0) + t0 = self.frobenius_torus(t1, 2) + t1 = self.inverse_torus(t1) + t1 = self.mul_torus(t1, t2) + t1 = self.mul_torus(t1, t0) + + # The final exp result is DecompressTorus(MulTorus(c, t1) + # MulTorus(c, t1) = (c*t1 + v)/(c + t1). + # (c+t1 = 0) ==> MulTorus(c, t1) is one in the Torus. + _sum = self.extf_add(c, t1) + # From this case we can conclude the result is 1 or !=1 without decompression. + # In case we want to decompress to get the result in GT, + # we might need to decompress with another circuit, if the result is not 1 (_sum!=0). + return _sum, c, t1 + + +class GaragaBN254FinalExp(FinalExpTorusCircuit): + def __init__(self, init_hash: int = None): + super().__init__( + name="GaragaBN254FinalExp", curve_id=BN254_ID, extension_degree=6 + ) + if init_hash is not None: + self.transcript = CairoPoseidonTranscript(init_hash) + + def expt_torus(self, X: list[PyFelt]): + t3 = self.square_torus(X) + t5 = self.square_torus(t3) + result = self.square_torus(t5) + t0 = self.square_torus(result) + t2 = self.mul_torus(X, t0) + t0 = self.mul_torus(t3, t2) + t1 = self.mul_torus(X, t0) + t4 = self.mul_torus(result, t2) + t6 = self.square_torus(t2) + t1 = self.mul_torus(t0, t1) + t0 = self.mul_torus(t3, t1) + t6 = self.n_square_torus(t6, 6) + t5 = self.mul_torus(t5, t6) + t5 = self.mul_torus(t4, t5) + t5 = self.n_square_torus(t5, 7) + t4 = self.mul_torus(t4, t5) + t4 = self.n_square_torus(t4, 8) + t4 = self.mul_torus(t0, t4) + t3 = self.mul_torus(t3, t4) + t3 = self.n_square_torus(t3, 6) + t2 = self.mul_torus(t2, t3) + t2 = self.n_square_torus(t2, 8) + t2 = self.mul_torus(t0, t2) + t2 = self.n_square_torus(t2, 6) + t2 = self.mul_torus(t0, t2) + t2 = self.n_square_torus(t2, 10) + t1 = self.mul_torus(t1, t2) + t1 = self.n_square_torus(t1, 6) + t0 = self.mul_torus(t0, t1) + z = self.mul_torus(result, t0) + return z + + def final_exp_part1( + self, X: list[PyFelt], unsafe: bool + ) -> list[ModuloCircuitElement]: + """ + single pairing -> unsafe = False + double pairing -> unsafe = True + """ + c = self.easy_part(X, unsafe) + # 2. Hard part (up to permutation) # 2x₀(6x₀²+3x₀+1)(p⁴-p²+1)/r # Duquesne and Ghammam @@ -174,7 +353,7 @@ def final_exp_part1( # Fuentes et al. (alg. 6) # performed in torus compressed form t0 = self.expt_torus(c) - t0 = self.inverse_torus(c) + t0 = self.inverse_torus(t0) t0 = self.square_torus(t0) t1 = self.square_torus(t0) t1 = self.mul_torus(t0, t1) @@ -182,18 +361,20 @@ def final_exp_part1( t2 = self.inverse_torus(t2) t3 = self.inverse_torus(t1) t1 = self.mul_torus(t2, t3) - t3 = self.square_torus(t3) + t3 = self.square_torus(t2) t4 = self.expt_torus(t3) t4 = self.mul_torus(t1, t4) t3 = self.mul_torus(t0, t4) + t0 = self.mul_torus(t2, t4) t0 = self.mul_torus(c, t0) - t2 = self.frobenius_torus(t3) + t2 = self.frobenius_torus(t3, 1) t0 = self.mul_torus(t2, t0) - t2 = self.frobenius_square_torus(t4) + t2 = self.frobenius_torus(t4, 2) t0 = self.mul_torus(t0, t2) t2 = self.inverse_torus(c) t2 = self.mul_torus(t2, t3) - t2 = self.frobenius_cube_torus(t2) + t2 = self.frobenius_torus(t2, 3) + # The final exp result is DecompressTorus(MulTorus(t0, t2)). # MulTorus(t0, t2) = (t0*t2 + v)/(t0 + t2). # (T0+T2 = 0) ==> MulTorus(t0, t2) is one in the Torus. _sum = self.extf_add(t0, t2) @@ -201,38 +382,83 @@ def final_exp_part1( # In case we want to decompress to get the result in GT, # we might need to decompress with another circuit, if the result is not 1 (_sum!=0). - pass + return _sum, t0, t2 - def final_exp_finalize(self, t0: list[PyFelt], t2: list[PyFelt]): - # Computes Decompress Torus(MulTorus(t0, t2)). - # Only valid if (t0 + t2) != 0. - t0 = self.write_elements(t0, WriteOps.INPUT) - t2 = self.write_elements(t2, WriteOps.INPUT) - self.transcript.hash_limbs_multi(t0) - self.transcript.hash_limbs_multi(t2) - mul = self.mul_torus(t0, t2) - return self.decompress_torus(mul) +class GaragaFinalExp(Enum): + BN254 = GaragaBN254FinalExp + BLS12_381 = GaragaBLS12_381FinalExp -if __name__ == "__main__": - from random import randint, seed +def test_final_exp(curve_id: CurveID): + from tools.gnark import GnarkCLI + from src.definitions import tower_to_direct + + cli = GnarkCLI(curve_id) + n = CURVES[curve_id.value].n + a, b = cli.nG1nG2_operation(random.randint(0, n - 1), random.randint(0, n - 1)) + a, b = cli.nG1nG2_operation(1, 1) + + base_class = GaragaFinalExp[curve_id.name].value + part1 = base_class() - seed(0) - curve: Curve = CURVES[BN254_ID] - p = curve.p - X = [PyFelt(randint(0, p - 1), p) for _ in range(12)] + XT = cli.miller([a], [b]) + ET = cli.pair([a], [b]) - circuit = GaragaBN254FinalExp() - circuit.create_powers_of_Z(PyFelt(11, STARK)) + XT = [part1.field(x) for x in XT] + ET = [part1.field(x) for x in ET] - M = circuit.write_elements(X[0:6]) + XD = tower_to_direct(XT, curve_id.value, 12) + ED = tower_to_direct(ET, curve_id.value, 12) - sqt = circuit.square_torus(M) - mtt = circuit.mul_torus(M, M) - circuit.print_value_segment() - print(circuit.compile_circuit()) - fiat = circuit.values_segment.non_interactive_transform() - fiat.print() + part1.create_powers_of_Z(part1.field(2)) + _sum, t0, t2 = part1.final_exp_part1(XD, unsafe=False) + _sum = [x.value for x in _sum] + t0 = [x.felt for x in t0] + t2 = [x.felt for x in t2] - print([x.value for x in sqt] == [x.value for x in mtt]) + part2 = base_class(init_hash=part1.transcript.s1) + part2.create_powers_of_Z(part2.field(2), max_degree=12) + if _sum == [0, 0, 0, 0, 0, 0]: + f = [part1.field.one()] + else: + f = part2.final_exp_finalize(t0, t2) + f = [f.value for f in f] + + assert f == [ + e.value for e in ED + ], f"Final exp in circuit and in Gnark do not match f={f}\ne={[e.value for e in ED]}" + print(f"{curve_id} Final Exp random test pass") + return part1, part2 + + +if __name__ == "__main__": + from src.definitions import ( + CurveID, + get_base_field, + Polynomial, + get_irreducible_poly, + ) + import random + + def test_frobenius_torus(): + from archive_tmp.bn254.pairing_final_exp import frobenius_torus + + field = get_base_field(CurveID.BN254.value) + X = [field(random.randint(0, field.p - 1)) for _ in range(6)] + t = FinalExpTorusCircuit("test", CurveID.BN254.value, 6) + t.create_powers_of_Z(field(2)) + X = t.write_elements(X) + XF = t.frobenius_torus(X, 1) + # Xpoly = Polynomial([x.felt for x in X]) + # XFpoly = Xpoly.pow(field.p, get_irreducible_poly(CurveID.BN254.value, )) + # assert t.finalize_circuit() + # t.values_segment = t.values_segment.non_interactive_transform() + + TT = frobenius_torus([x.value for x in X]) + assert all(x.value == y for x, y in zip(XF, TT)) + + t.print_value_segment() + + test_final_exp(CurveID.BN254) + test_final_exp(CurveID.BLS12_381) diff --git a/src/precompiled_circuits/sample.cairo b/src/precompiled_circuits/sample.cairo new file mode 100644 index 00000000..4ad32d29 --- /dev/null +++ b/src/precompiled_circuits/sample.cairo @@ -0,0 +1,342 @@ +from starkware.cairo.common.registers import get_label_location + + + func get_sample_circuit(id: felt) -> ( + constants_ptr: felt*, + add_offsets_ptr: felt*, + mul_offsets_ptr: felt*, + left_assert_eq_offsets_ptr: felt*, + right_assert_eq_offsets_ptr: felt*, + poseidon_indexes_ptr: felt*, + constants_ptr_len: felt, + add_mod_n: felt, + mul_mod_n: felt, + commitments_len: felt, + assert_eq_len: felt, + N_Euclidean_equations: felt, +) { + if (id == 1) { + return get_sample_circuit_1_non_interactive_circuit(); + } else { + return ( + cast(0, felt*), + cast(0, felt*), + cast(0, felt*), + cast(0, felt*), + cast(0, felt*), + cast(0, felt*), + 0, + 0, + 0, + 0, + 0, + 0, + ); + } +} +func get_sample_circuit_1_non_interactive_circuit()->(constants_ptr:felt*, add_offsets_ptr:felt*, mul_offsets_ptr:felt*, left_assert_eq_offsets_ptr:felt*, right_assert_eq_offsets_ptr:felt*, poseidon_indexes_ptr:felt*, constants_ptr_len:felt, add_mod_n:felt, mul_mod_n:felt, commitments_len:felt, assert_eq_len:felt, N_Euclidean_equations:felt,){ +alloc_locals; +let constants_ptr_len = 5; +let add_mod_n = 29; +let mul_mod_n = 36; +let commitments_len = 17; +let assert_eq_len = 0; +let N_Euclidean_equations = 2; +let (constants_ptr:felt*) = get_label_location(constants_ptr_loc); +let (add_offsets_ptr:felt*) = get_label_location(add_offsets_ptr_loc); +let (mul_offsets_ptr:felt*) = get_label_location(mul_offsets_ptr_loc); +let (left_assert_eq_offsets_ptr:felt*) = get_label_location(left_assert_eq_offsets_ptr_loc); +let (right_assert_eq_offsets_ptr:felt*) = get_label_location(right_assert_eq_offsets_ptr_loc); +let (poseidon_indexes_ptr:felt*) = get_label_location(poseidon_indexes_ptr_loc); +return (constants_ptr, add_offsets_ptr, mul_offsets_ptr, left_assert_eq_offsets_ptr, right_assert_eq_offsets_ptr, poseidon_indexes_ptr, constants_ptr_len, add_mod_n, mul_mod_n, commitments_len, assert_eq_len, N_Euclidean_equations); + constants_ptr_loc: + dw 0; + dw 0; + dw 0; + dw 0; + dw 1; + dw 0; + dw 0; + dw 0; + dw 32324006162389411176778628422; + dw 57042285082623239461879769745; + dw 3486998266802970665; + dw 0; + dw 82; + dw 0; + dw 0; + dw 0; + dw 32324006162389411176778628405; + dw 57042285082623239461879769745; + dw 3486998266802970665; + dw 0; + + add_offsets_ptr_loc: + dw 20; + dw 122; + dw 126; + + dw 126; + dw 130; + dw 134; + + dw 134; + dw 138; + dw 142; + + dw 142; + dw 146; + dw 150; + + dw 150; + dw 154; + dw 158; + + dw 20; + dw 162; + dw 166; + + dw 166; + dw 170; + dw 174; + + dw 174; + dw 178; + dw 182; + + dw 182; + dw 186; + dw 190; + + dw 190; + dw 194; + dw 198; + + dw 0; + dw 206; + dw 210; + + dw 0; + dw 214; + dw 218; + + dw 0; + dw 222; + dw 226; + + dw 0; + dw 230; + dw 234; + + dw 0; + dw 238; + dw 242; + + dw 0; + dw 246; + dw 250; + + dw 0; + dw 254; + dw 258; + + dw 68; + dw 262; + dw 266; + + dw 266; + dw 270; + dw 274; + + dw 274; + dw 278; + dw 282; + + dw 282; + dw 286; + dw 290; + + dw 12; + dw 294; + dw 298; + + dw 298; + dw 302; + dw 306; + + dw 218; + dw 310; + dw 314; + + dw 314; + dw 318; + dw 322; + + dw 322; + dw 326; + dw 330; + + dw 330; + dw 334; + dw 338; + + dw 338; + dw 342; + dw 346; + + dw 350; + dw 346; + dw 354; + + mul_offsets_ptr_loc: + dw 88; + dw 88; + dw 98; + + dw 98; + dw 88; + dw 102; + + dw 102; + dw 88; + dw 106; + + dw 106; + dw 88; + dw 110; + + dw 110; + dw 88; + dw 114; + + dw 114; + dw 88; + dw 118; + + dw 24; + dw 88; + dw 122; + + dw 28; + dw 98; + dw 130; + + dw 32; + dw 102; + dw 138; + + dw 36; + dw 106; + dw 146; + + dw 40; + dw 110; + dw 154; + + dw 24; + dw 88; + dw 162; + + dw 28; + dw 98; + dw 170; + + dw 32; + dw 102; + dw 178; + + dw 36; + dw 106; + dw 186; + + dw 40; + dw 110; + dw 194; + + dw 158; + dw 198; + dw 202; + + dw 93; + dw 202; + dw 206; + + dw 93; + dw 44; + dw 214; + + dw 93; + dw 48; + dw 222; + + dw 93; + dw 52; + dw 230; + + dw 93; + dw 56; + dw 238; + + dw 93; + dw 60; + dw 246; + + dw 93; + dw 64; + dw 254; + + dw 72; + dw 88; + dw 262; + + dw 76; + dw 98; + dw 270; + + dw 80; + dw 102; + dw 278; + + dw 84; + dw 106; + dw 286; + + dw 16; + dw 102; + dw 294; + + dw 4; + dw 114; + dw 302; + + dw 226; + dw 88; + dw 310; + + dw 234; + dw 98; + dw 318; + + dw 242; + dw 102; + dw 326; + + dw 250; + dw 106; + dw 334; + + dw 258; + dw 110; + dw 342; + + dw 290; + dw 306; + dw 350; + + left_assert_eq_offsets_ptr_loc: + right_assert_eq_offsets_ptr_loc: + poseidon_indexes_ptr_loc: + dw 11; + dw 23; + +} diff --git a/src/precompiled_circuits/sample.py b/src/precompiled_circuits/sample.py new file mode 100644 index 00000000..5c66ea15 --- /dev/null +++ b/src/precompiled_circuits/sample.py @@ -0,0 +1,82 @@ +from src.definitions import Curve, CURVES, BN254_ID, STARK, PyFelt +from src.precompiled_circuits.final_exp import ( + FinalExpTorusCircuit, + ExtensionFieldModuloCircuit, +) + + +def get_sample_circuit(id: int, input: list[PyFelt]) -> ExtensionFieldModuloCircuit: + if id == 1: + return sample_circuit_1(input) + else: + raise ValueError(f"Unknown circuit id: {id}") + + +def sample_circuit_1(input: list[PyFelt]) -> ExtensionFieldModuloCircuit: + assert len(input) == 6, f"Expected 6 elements in input, got {len(input)}" + circuit = FinalExpTorusCircuit("sample_circuit_1", BN254_ID, extension_degree=6) + circuit.create_powers_of_Z(PyFelt(11, STARK)) + X, _, _ = circuit.write_commitments(input) + s1 = circuit.transcript.RLC_coeff + sqt = circuit.extf_mul(X, X, 6) + # mtt = circuit.mul_torus(X, X) + circuit.finalize_circuit() + circuit.values_segment = circuit.values_segment.non_interactive_transform() + + circuit.print_value_segment() + + return circuit + + +if __name__ == "__main__": + from random import randint, seed + + seed(0) + curve: Curve = CURVES[BN254_ID] + p = curve.p + input = [PyFelt(randint(0, p - 1), p) for _ in range(6)] + + circuit = sample_circuit_1(input) + circuit.print_value_segment() + + cairo_code = circuit.compile_circuit() + with open("src/precompiled_circuits/sample.cairo", "w") as f: + f.write("from starkware.cairo.common.registers import get_label_location\n\n") + f.write( + """ + func get_sample_circuit(id: felt) -> ( + constants_ptr: felt*, + add_offsets_ptr: felt*, + mul_offsets_ptr: felt*, + left_assert_eq_offsets_ptr: felt*, + right_assert_eq_offsets_ptr: felt*, + poseidon_indexes_ptr: felt*, + constants_ptr_len: felt, + add_mod_n: felt, + mul_mod_n: felt, + commitments_len: felt, + assert_eq_len: felt, + N_Euclidean_equations: felt, +) { + if (id == 1) { + return get_sample_circuit_1_non_interactive_circuit(); + } else { + return ( + cast(0, felt*), + cast(0, felt*), + cast(0, felt*), + cast(0, felt*), + cast(0, felt*), + cast(0, felt*), + 0, + 0, + 0, + 0, + 0, + 0, + ); + } +} +""" + ) + f.write(cairo_code) diff --git a/src/utils.cairo b/src/utils.cairo index 13acf692..be5c2218 100644 --- a/src/utils.cairo +++ b/src/utils.cairo @@ -1,20 +1,21 @@ from starkware.cairo.common.alloc import alloc from starkware.cairo.common.cairo_builtins import PoseidonBuiltin from starkware.cairo.common.poseidon_state import PoseidonBuiltinState -from src.definitions import STARK_MIN_ONE_D2, N_LIMBS +from src.definitions import STARK_MIN_ONE_D2, N_LIMBS, BASE func get_Z_and_RLC_from_transcript{poseidon_ptr: PoseidonBuiltin*, range_check96_ptr: felt*}( transcript_start: felt*, - poseidon_ptr_indexes: felt*, + poseidon_indexes_ptr: felt*, n_elements_in_transcript: felt, n_equations: felt, ) -> (Z: felt, random_linear_combination_coefficients: felt*) { + alloc_locals; tempvar poseidon_start = poseidon_ptr; let (Z: felt) = hash_full_transcript_and_get_Z( limbs_ptr=transcript_start, n=n_elements_in_transcript ); let (RLC_coeffs: felt*) = retrieve_random_coefficients( - poseidon_start, poseidon_ptr_indexes=poseidon_ptr_indexes, n=n_equations + poseidon_start, poseidon_indexes_ptr=poseidon_indexes_ptr, n=n_equations ); return (Z=Z, random_linear_combination_coefficients=RLC_coeffs); } @@ -22,14 +23,28 @@ func hash_full_transcript_and_get_Z{poseidon_ptr: PoseidonBuiltin*}(limbs_ptr: f Z: felt ) { alloc_locals; + %{ print(f"N elemts in transcript : {ids.n} ") %} + local two = 2; + let input_hash = 14; - tempvar limbs_ptr = limb_ptr; - tempvar i = 0; + // Initialisation: + assert poseidon_ptr[0].input = PoseidonBuiltinState( + limbs_ptr[0] * limbs_ptr[1], input_hash, two + ); + assert poseidon_ptr[1].input = PoseidonBuiltinState( + limbs_ptr[2] * limbs_ptr[3], poseidon_ptr[0].output.s0, two + ); + + tempvar limbs_ptr: felt* = limbs_ptr + 4; + tempvar i = 2; hash_limbs_2_by_2: - let limb_ptr = [ap - 2]; + let limbs_ptr: felt* = cast([ap - 2], felt*); let i = [ap - 1]; - %{ memory[ap] = 1 if ids.i == ids.n else 0 %} + %{ + print(ids.i/2, "/", ids.n) + memory[ap] = 1 if ids.i == 2*ids.n else 0 + %} jmp end_loop if [ap] != 0, ap++; assert poseidon_ptr[i].input = PoseidonBuiltinState( @@ -39,18 +54,22 @@ func hash_full_transcript_and_get_Z{poseidon_ptr: PoseidonBuiltin*}(limbs_ptr: f limbs_ptr[2] * limbs_ptr[3], poseidon_ptr[i].output.s0, two ); - [ap] = limbs_ptr + 4; - [ap] = i + 1, ap++; + [ap] = limbs_ptr + 4, ap++; + [ap] = i + 2, ap++; + + jmp hash_limbs_2_by_2; end_loop: - assert i = n; - tempvar poseidon_ptr = poseidon_ptr + PoseidonBuiltin.SZIE * n * 2; - tempvar res = poseidon_ptr[0].output.s0; + // let i = [ap - 1]; + assert i = 2 * n; + %{ print(f"i: {ids.i}, n:{ids.n}") %} + tempvar poseidon_ptr = poseidon_ptr + PoseidonBuiltin.SIZE * n * 2; + tempvar res = [poseidon_ptr - PoseidonBuiltin.SIZE].output.s0; return (Z=res); } func retrieve_random_coefficients( - poseidon_ptr: PoseidonBuiltin*, poseidon_ptr_indexes: felt*, n: felt + poseidon_ptr: PoseidonBuiltin*, poseidon_indexes_ptr: felt*, n: felt ) -> (coefficients: felt*) { alloc_locals; let (local coefficients: felt*) = alloc(); @@ -61,7 +80,7 @@ func retrieve_random_coefficients( let i = [ap - 1]; %{ memory[ap] = 1 if ids.i == ids.n else 0 %} jmp end if [ap] != 0, ap++; - assert coefficients[i] = poseidon_ptr[poseidon_ptr_indexes[i]].output.s1; + assert coefficients[i] = poseidon_ptr[poseidon_indexes_ptr[i]].output.s1; [ap] = i + 1, ap++; jmp get_s1_loop; @@ -89,7 +108,9 @@ func write_felts_to_value_segment{range_check96_ptr: felt*}(values: felt*, n: fe let d2 = [range_check96_ptr + offset + 2]; %{ from src.hints.io import bigint_split - limbs = bigint_split(ids.values[ids.i], ids.N_LIMBS, ids.BASE) + felt_val = memory[ids.values+ids.i] + print(f"felt val : {felt_val}") + limbs = bigint_split(felt_val, ids.N_LIMBS, ids.BASE) assert limbs[3] == 0 ids.d0, ids.d1, ids.d2 = limbs[0], limbs[1], limbs[2] %} @@ -99,13 +120,17 @@ func write_felts_to_value_segment{range_check96_ptr: felt*}(values: felt*, n: fe if (d2 == stark_min_1_d2) { assert d1 = 0; assert d2 = 0; + [ap] = i + 1, ap++; + } else { + [ap] = i + 1, ap++; } - [ap] = i + 1, ap++; jmp loop; end: assert i = n; + %{ print(f"RangeCheckptr:{ids.range_check96_ptr}", ids.n, ids.n_rc_per_felt) %} tempvar range_check96_ptr = range_check96_ptr + n * n_rc_per_felt; + return (); } diff --git a/tests/cairo_programs/extf_circuit.cairo b/tests/cairo_programs/extf_circuit.cairo new file mode 100644 index 00000000..49afdd1c --- /dev/null +++ b/tests/cairo_programs/extf_circuit.cairo @@ -0,0 +1,38 @@ +%builtins range_check poseidon range_check96 add_mod mul_mod + +from starkware.cairo.common.cairo_builtins import PoseidonBuiltin, ModBuiltin +from starkware.cairo.common.registers import get_fp_and_pc +from starkware.cairo.common.alloc import alloc + +from src.modulo_circuit import run_extension_field_modulo_circuit +from src.definitions import bn, bls, UInt384, one_E12D, N_LIMBS, BASE +from src.precompiled_circuits.sample import get_sample_circuit +func main{ + range_check_ptr, + poseidon_ptr: PoseidonBuiltin*, + range_check96_ptr: felt*, + add_mod_ptr: ModBuiltin*, + mul_mod_ptr: ModBuiltin*, +}() { + alloc_locals; + let (__fp__, _) = get_fp_and_pc(); + let (local input: felt*) = alloc(); + local input_len: felt; + local circuit_id = 1; + %{ + from random import randint + import random + from src.definitions import CURVES, PyFelt + from src.hints.io import bigint_split, flatten + random.seed(0) + p = CURVES[ids.bn.CURVE_ID].p + X=[PyFelt(randint(0, p - 1), p) for _ in range(6)] + X=flatten([bigint_split(x.value, ids.N_LIMBS, ids.BASE) for x in X]) + print(X, len(X)) + segments.write_arg(ids.input, X) + ids.input_len = len(X) + %} + + let x = run_extension_field_modulo_circuit(input, input_len, bn.CURVE_ID, circuit_id); + return (); +} diff --git a/tests/cairo_programs/poseidon_chain.cairo b/tests/cairo_programs/poseidon_chain.cairo index 8c3b47d1..4b796738 100644 --- a/tests/cairo_programs/poseidon_chain.cairo +++ b/tests/cairo_programs/poseidon_chain.cairo @@ -1,6 +1,6 @@ %builtins poseidon -from starkware.cairo.common.cairo_builtins import PoseidonBuiltin, BitwiseBuiltin +from starkware.cairo.common.cairo_builtins import PoseidonBuiltin from starkware.cairo.common.poseidon_state import PoseidonBuiltinState from starkware.cairo.common.cairo_secp.bigint import BigInt3 diff --git a/tools/gnark.py b/tools/gnark.py new file mode 100644 index 00000000..03466135 --- /dev/null +++ b/tools/gnark.py @@ -0,0 +1,96 @@ +import re +import subprocess +from src.definitions import G1Point, G2Point, CurveID, CURVES + + +exec_path = { + CurveID.BN254: "tools/gnark/main", + CurveID.BLS12_381: "tools/gnark/bls12_381/cairo_test/main", +} + + +class GnarkCLI: + def __init__(self, curve_id: CurveID): + self.curve = CURVES[curve_id.value] + self.curve_id = curve_id + self.executable_path = exec_path[curve_id] + + def run_command(self, args): + process = subprocess.Popen( + [self.executable_path] + args, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + stdout, stderr = process.communicate() + if process.returncode != 0: + raise Exception(f"Error executing gnark-cli: {stderr.decode('utf-8')}") + return stdout.decode("utf-8") + + def parse_fp_elements(self, input_string: str): + pattern = re.compile(r"\[([^\[\]]+)\]") + substrings = pattern.findall(input_string) + sublists = [substring.split(" ") for substring in substrings] + sublists = [[int(x) for x in sublist] for sublist in sublists] + fp_elements = [] + for sublist in sublists: + element_value = sum(x * (2 ** (64 * i)) for i, x in enumerate(sublist)) + fp_elements.append(element_value) + return fp_elements + + def pair(self, P: list[G1Point], Q: list[G2Point]): + assert len(P) == len(Q) + args = ["n_pair", "pair", str(len(P))] + for p, q in zip(P, Q): + args += [ + str(p.x), + str(p.y), + str(q.x[0]), + str(q.x[1]), + str(q.y[0]), + str(q.y[1]), + ] + output = self.run_command(args) + res = self.parse_fp_elements(output) + assert len(res) == 12, f"Got {output}" + return res + + def miller(self, P: list[G1Point], Q: list[G2Point]): + assert len(P) == len(Q) + args = ["n_pair", "miller_loop", str(len(P))] + for p, q in zip(P, Q): + args += [ + str(p.x), + str(p.y), + str(q.x[0]), + str(q.x[1]), + str(q.y[0]), + str(q.y[1]), + ] + output = self.run_command(args) + res = self.parse_fp_elements(output) + assert len(res) == 12 + return res + + def nG1nG2_operation(self, n1: int, n2: int) -> tuple[G1Point, G2Point]: + args = ["nG1nG2", str(n1), str(n2)] + output = self.run_command(args) + fp_elements = self.parse_fp_elements(output) + assert len(fp_elements) == 6 + # print(fp_elements) + return G1Point(*fp_elements[:2], self.curve_id), G2Point( + tuple(fp_elements[2:4]), tuple(fp_elements[4:6]), self.curve_id + ) + + +if __name__ == "__main__": + for curve_id in [CurveID.BN254, CurveID.BLS12_381]: + print("\n\n", curve_id) + cli = GnarkCLI(curve_id) + curve = CURVES[curve_id.value] + + a, b = cli.nG1nG2_operation(1, 1) + + e = cli.pair([a], [b]) + m = cli.miller([a], [b]) + print(f"m={m}") + print(f"e={e}") diff --git a/tools/gnark/bls12_381/cairo_test/main.go b/tools/gnark/bls12_381/cairo_test/main.go index b3470010..f000b2ec 100644 --- a/tools/gnark/bls12_381/cairo_test/main.go +++ b/tools/gnark/bls12_381/cairo_test/main.go @@ -5,7 +5,7 @@ import ( "log" "math/big" "os" - "tools/gnark/bls12_381" + bls12381 "tools/gnark/bls12_381" "tools/gnark/bls12_381/fp" "tools/gnark/bls12_381/fptower" @@ -337,6 +337,69 @@ func main() { z.C1.B2.A1.FromMont() fmt.Println(z) + case "n_pair": + var z fptower.E12 + + n := new(big.Int) + n, _ = n.SetString(c.Args().Get(2), 10) + + g1_arr := make([]bls12381.G1Affine, int(n.Int64())) + g2_arr := make([]bls12381.G2Affine, int(n.Int64())) + + for i := 0; i < int(n.Int64()); i++ { + felt := new(big.Int) + + var X0, X1 fp.Element + var Y0, Y1, Y2, Y3 fp.Element + + felt, _ = felt.SetString(c.Args().Get(3+i*6), 10) + X0.SetBigInt(felt) + felt, _ = felt.SetString(c.Args().Get(4+i*6), 10) + X1.SetBigInt(felt) + felt, _ = felt.SetString(c.Args().Get(5+i*6), 10) + Y0.SetBigInt(felt) + felt, _ = felt.SetString(c.Args().Get(6+i*6), 10) + Y1.SetBigInt(felt) + felt, _ = felt.SetString(c.Args().Get(7+i*6), 10) + Y2.SetBigInt(felt) + felt, _ = felt.SetString(c.Args().Get(8+i*6), 10) + Y3.SetBigInt(felt) + + X := &g1_arr[i] + X.X = X0 + X.Y = X1 + + Y := &g2_arr[i] + Y.X.A0 = Y0 + Y.X.A1 = Y1 + Y.Y.A0 = Y2 + Y.Y.A1 = Y3 + } + + switch c.Args().Get(1) { + case "pair": + Z, _ := bls12381.Pair(g1_arr, g2_arr) + + z.Set(&Z) + + case "miller_loop": + Z, _ := bls12381.MillerLoop(g1_arr, g2_arr) + z.Set(&Z) + } + z.C0.B0.A0.FromMont() + z.C0.B0.A1.FromMont() + z.C0.B1.A0.FromMont() + z.C0.B1.A1.FromMont() + z.C0.B2.A0.FromMont() + z.C0.B2.A1.FromMont() + z.C1.B0.A0.FromMont() + z.C1.B0.A1.FromMont() + z.C1.B1.A0.FromMont() + z.C1.B1.A1.FromMont() + z.C1.B2.A0.FromMont() + z.C1.B2.A1.FromMont() + + fmt.Println(z) } return nil