From 36c9edd9bd8e90fb3ff4d7528b37089b7e8375d7 Mon Sep 17 00:00:00 2001 From: feltroidprime Date: Tue, 7 May 2024 11:31:47 +0200 Subject: [PATCH] add derive_point_from_X circuit --- README.md | 1 + src/definitions.cairo | 35 +++++++ src/ec_ops.cairo | 89 ++++++++++++++++- src/modulo_circuit.cairo | 2 +- src/modulo_circuit.py | 29 +++++- src/precompiled_circuits/all_circuits.py | 89 +++++++++++++---- src/precompiled_circuits/ec.cairo | 109 +++++++++++++++++++++ src/precompiled_circuits/ec.py | 68 ++++++++++--- src/utils.cairo | 26 ++++- tests/benchmarks.py | 18 ++++ tests/cairo_programs/derive_ec_point.cairo | 62 ++++++++++++ 11 files changed, 491 insertions(+), 37 deletions(-) create mode 100644 tests/cairo_programs/derive_ec_point.cairo diff --git a/README.md b/README.md index 203bdc910..19b345649 100644 --- a/README.md +++ b/README.md @@ -85,6 +85,7 @@ make run | circuit | MULMOD | ADDMOD | ASSERT_EQ | POSEIDON | RLC | ~steps | |-------------------------------------------|----------|----------|-------------|------------|-------|----------| +| Derive Point From X | 5 | 1 | 0 | 0 | 0 | 44 | | Double Step BLS12_381 | 22 | 9 | 2 | 0 | 0 | 216 | | Double Step BN254 | 24 | 11 | 2 | 0 | 0 | 240 | | Fp6 SQUARE_TORUS | 12 | 16 | 0 | 7 | 1 | 300 | diff --git a/src/definitions.cairo b/src/definitions.cairo index 058cc1ef9..d6c6ac61d 100644 --- a/src/definitions.cairo +++ b/src/definitions.cairo @@ -17,6 +17,13 @@ namespace bls { // Non residue constants: const NON_RESIDUE_E2_a0 = 1; const NON_RESIDUE_E2_a1 = 1; + + // Curve equation parameters: + const a = 0; + const b = 4; + + // Fp generator : + const g = 3; } namespace bn { @@ -36,6 +43,11 @@ namespace bn { // Non residue constants: const NON_RESIDUE_E2_a0 = 9; const NON_RESIDUE_E2_a1 = 1; + // Curve equation parameters: + const a = 0; + const b = 3; + // Fp Generator : + const g = 3; } func get_P(curve_id: felt) -> (prime: UInt384) { @@ -50,6 +62,29 @@ func get_P(curve_id: felt) -> (prime: UInt384) { } } +func get_b(curve_id: felt) -> (res: UInt384) { + if (curve_id == bls.CURVE_ID) { + return (res=UInt384(bls.b, 0, 0, 0)); + } else { + if (curve_id == bn.CURVE_ID) { + return (res=UInt384(bn.b, 0, 0, 0)); + } else { + return (res=UInt384(-1, 0, 0, 0)); + } + } +} + +func get_fp_gen(curve_id: felt) -> (res: UInt384) { + if (curve_id == bls.CURVE_ID) { + return (res=UInt384(bls.g, 0, 0, 0)); + } else { + if (curve_id == bn.CURVE_ID) { + return (res=UInt384(bn.g, 0, 0, 0)); + } else { + return (res=UInt384(-1, 0, 0, 0)); + } + } +} const SUPPORTED_CURVE_ID = 0; const UNSUPPORTED_CURVE_ID = 1; diff --git a/src/ec_ops.cairo b/src/ec_ops.cairo index e51c416a6..8b160009a 100644 --- a/src/ec_ops.cairo +++ b/src/ec_ops.cairo @@ -1,7 +1,12 @@ -from src.definitions import is_zero_mod_P, get_P -from src.precompiled_circuits.ec import get_IS_ON_CURVE_G1_G2_circuit +from src.definitions import is_zero_mod_P, get_P, G1Point, get_b, get_fp_gen, verify_zero4, BASE +from src.precompiled_circuits.ec import ( + get_IS_ON_CURVE_G1_G2_circuit, + get_DERIVE_POINT_FROM_X_circuit, +) from src.modulo_circuit import run_modulo_circuit, ModuloCircuit -from starkware.cairo.common.cairo_builtins import ModBuiltin, UInt384 +from starkware.cairo.common.cairo_builtins import ModBuiltin, UInt384, PoseidonBuiltin +from starkware.cairo.common.alloc import alloc +from src.utils import felt_to_UInt384 func is_on_curve_g1_g2{ range_check_ptr, range_check96_ptr: felt*, add_mod_ptr: ModBuiltin*, mul_mod_ptr: ModBuiltin* @@ -26,3 +31,81 @@ func is_on_curve_g1_g2{ } return (res=1); } + +struct DerivePointFromXOutput { + rhs: UInt384, // x^3 + ax + b + grhs: UInt384, // g * (x^3+ax+b) + should_be_rhs: UInt384, + should_be_grhs: UInt384, + y_try: UInt384, +} + +struct DerivePointFromXInput { + entropy: UInt384, + y: UInt384, + rhs_from_x_is_a_square_residue: felt, +} +func derive_EC_point_from_entropy{ + range_check_ptr, + range_check96_ptr: felt*, + add_mod_ptr: ModBuiltin*, + mul_mod_ptr: ModBuiltin*, + poseidon_ptr: PoseidonBuiltin*, +}(curve_id: felt, entropy: felt, attempt: felt) -> (res: G1Point) { + %{ print(f"Attempt : {ids.attempt}") %} + alloc_locals; + local rhs_from_x_is_a_square_residue: felt; + %{ + from starkware.python.math_utils import is_quad_residue + from src.definitions import CURVES + a = CURVES[ids.curve_id].a + b = CURVES[ids.curve_id].b + p = CURVES[ids.curve_id].p + rhs = (ids.entropy**3 + a*ids.entropy + b) % p + ids.rhs_from_x_is_a_square_residue = is_quad_residue(rhs, p) + %} + let (x_384: UInt384) = felt_to_UInt384(entropy); + let (b_Weirstrass: UInt384) = get_b(curve_id); + let (fp_generator: UInt384) = get_fp_gen(curve_id); + let (P: UInt384) = get_P(curve_id); + let (circuit) = get_DERIVE_POINT_FROM_X_circuit(curve_id); + + let (input: UInt384*) = alloc(); + assert input[0] = x_384; + assert input[1] = b_Weirstrass; + assert input[2] = fp_generator; + + let (output_array: felt*) = run_modulo_circuit(circuit, cast(input, felt*)); + let output: DerivePointFromXOutput* = cast(output_array, DerivePointFromXOutput*); + + if (rhs_from_x_is_a_square_residue != 0) { + // Assert should_be_rhs == rhs + verify_zero4( + UInt384( + output.rhs.d0 - output.should_be_rhs.d0, + output.rhs.d1 - output.should_be_rhs.d1, + output.rhs.d2 - output.should_be_rhs.d2, + output.rhs.d3 - output.should_be_rhs.d3, + ), + P, + ); + return (res=G1Point(x_384, output.y_try)); + } else { + // Assert should_be_grhs == grhs & Retry. + verify_zero4( + UInt384( + output.grhs.d0 - output.should_be_grhs.d0, + output.grhs.d1 - output.should_be_grhs.d1, + output.grhs.d2 - output.should_be_grhs.d2, + output.grhs.d3 - output.should_be_grhs.d3, + ), + P, + ); + assert poseidon_ptr[0].input.s0 = entropy; + assert poseidon_ptr[0].input.s1 = attempt; + assert poseidon_ptr[0].input.s2 = 2; + let new_entropy = poseidon_ptr[0].output.s0; + let poseidon_ptr = poseidon_ptr + PoseidonBuiltin.SIZE; + return derive_EC_point_from_entropy(curve_id, new_entropy, attempt + 1); + } +} diff --git a/src/modulo_circuit.cairo b/src/modulo_circuit.cairo index d3b265966..6bf604233 100644 --- a/src/modulo_circuit.cairo +++ b/src/modulo_circuit.cairo @@ -32,8 +32,8 @@ struct ModuloCircuit { mul_offsets_ptr: felt*, output_offsets_ptr: felt*, constants_ptr_len: felt, - witnesses_len: felt, input_len: felt, + witnesses_len: felt, output_len: felt, continuous_output: felt, add_mod_n: felt, diff --git a/src/modulo_circuit.py b/src/modulo_circuit.py index 76969468e..6a2dcc8d3 100644 --- a/src/modulo_circuit.py +++ b/src/modulo_circuit.py @@ -250,7 +250,7 @@ class ModuloCircuit: constants (dict[str, ModuloElement]): A dictionary mapping constant names to their ModuloElement representations. """ - def __init__(self, name: str, curve_id: int) -> None: + def __init__(self, name: str, curve_id: int, generic_circuit: bool = False) -> None: assert len(name) < 30, f"Name '{name}' is too long." self.name = name self.class_name = "ModuloCircuit" @@ -259,6 +259,7 @@ def __init__(self, name: str, curve_id: int) -> None: self.N_LIMBS = 4 self.values_segment: ValueSegment = ValueSegment(name) self.constants: dict[int, ModuloCircuitElement] = dict() + self.generic_circuit = generic_circuit self.set_or_get_constant(self.field.zero()) self.set_or_get_constant(self.field.one()) @@ -537,7 +538,14 @@ def compile_circuit( dw_arrays = self.values_segment.get_dw_lookups() name = function_name or self.values_segment.name function_name = f"get_{name}_circuit" - code = f"func {function_name}()->(circuit:{self.class_name}*)" + "{" + "\n" + if self.generic_circuit: + code = ( + f"func {function_name}(curve_id:felt)->(circuit:{self.class_name}*)" + + "{" + + "\n" + ) + else: + code = f"func {function_name}()->(circuit:{self.class_name}*)" + "{" + "\n" code += "alloc_locals;\n" code += "let (__fp__, _) = get_fp_and_pc();\n" @@ -557,7 +565,9 @@ def compile_circuit( f"let n_assert_eq = {len(self.values_segment.assert_eq_instructions)};\n" ) code += f"let name = '{self.name}';\n" - code += f"let curve_id = {self.curve_id};\n" + code += ( + f"let curve_id = {'curve_id' if self.generic_circuit else self.curve_id};\n" + ) code += f"local circuit:{self.class_name} = {self.class_name}({', '.join(returns['felt*'])}, {', '.join(returns['felt'])});\n" code += f"return (&circuit,);\n" @@ -603,6 +613,19 @@ def compile_circuit( "code": code, } + def summarize(self): + add_count, mul_count, assert_eq_count = self.values_segment.summarize() + summary = { + "circuit": self.name, + "MULMOD": mul_count, + "ADDMOD": add_count, + "ASSERT_EQ": assert_eq_count, + "POSEIDON": 0, + "RLC": 0, + } + + return summary + if __name__ == "__main__": from src.algebra import BaseField diff --git a/src/precompiled_circuits/all_circuits.py b/src/precompiled_circuits/all_circuits.py index 49107f54b..c6deed15c 100644 --- a/src/precompiled_circuits/all_circuits.py +++ b/src/precompiled_circuits/all_circuits.py @@ -1,5 +1,5 @@ from src.precompiled_circuits import multi_miller_loop, final_exp -from src.precompiled_circuits.ec import IsOnCurveCircuit +from src.precompiled_circuits.ec import IsOnCurveCircuit, DerivePointFromX from src.extension_field_modulo_circuit import ( ExtensionFieldModuloCircuit, ModuloCircuit, @@ -14,6 +14,7 @@ BLS12_381_ID, get_base_field, CurveID, + STARK, ) from random import seed, randint from enum import Enum @@ -38,12 +39,22 @@ class CircuitID(Enum): MILLER_LOOP_N2 = int.from_bytes(b"miller_loop_n2", "big") MILLER_LOOP_N3 = int.from_bytes(b"miller_loop_n3", "big") IS_ON_CURVE_G1_G2 = int.from_bytes(b"is_on_curve_g1_g2", "big") + DERIVE_POINT_FROM_X = int.from_bytes(b"derive_point_from_x", "big") from abc import ABC, abstractmethod class BaseModuloCircuit(ABC): + """ + Base class for all modulo circuits. + Parameters: + - name: str, the name of the circuit + - input_len: int, the number of input elements (/!\ of total felt252 values) + - curve_id: int, the id of the curve + - auto_run: bool, whether to run the circuit automatically at initialization. + """ + def __init__( self, name: str, @@ -57,6 +68,7 @@ def __init__( self.field = get_base_field(curve_id) self.input_len = input_len self.init_hash = None + self.generic_circuit = False if auto_run: self.circuit: ModuloCircuit = self._run_circuit_inner(self.build_input()) @@ -117,6 +129,34 @@ def _run_circuit_inner(self, input: list[PyFelt]) -> ModuloCircuit: return circuit +class DerivePointFromXCircuit(BaseModuloCircuit): + def __init__(self, curve_id: int, auto_run: bool = True) -> None: + super().__init__( + name="derive_point_from_x", + input_len=4 * 3, # X + b + G + SQ(X^3+ + SQG + curve_id=curve_id, + auto_run=auto_run, + ) + self.generic_circuit = True + + def build_input(self) -> list[PyFelt]: + input = [] + input.append(self.field(randint(0, STARK - 1))) + input.append(self.field(CURVES[self.curve_id].b)) # y^2 = x^3 + b + input.append(self.field(CURVES[self.curve_id].fp_generator)) + return input + + def _run_circuit_inner(self, input: list[PyFelt]) -> ModuloCircuit: + circuit = DerivePointFromX(self.name, self.curve_id) + x, b, g = circuit.write_elements(input[0:3], WriteOps.INPUT) + rhs, grhs, should_be_rhs, should_be_grhs, y_try = circuit._derive_point_from_x( + x, b, g + ) + circuit.extend_output([rhs, grhs, should_be_rhs, should_be_grhs, y_try]) + circuit.values_segment = circuit.values_segment.non_interactive_transform() + return circuit + + class FP12MulCircuit(BaseEXTFCircuit): def __init__(self, curve_id: int, auto_run: bool = True, init_hash: int = None): super().__init__("fp12_mul", 24, curve_id, auto_run, init_hash) @@ -412,6 +452,7 @@ def _run_circuit_inner(self, input: list[PyFelt]): ALL_EXTF_CIRCUITS = { CircuitID.IS_ON_CURVE_G1_G2: IsOnCurveG1G2Circuit, + CircuitID.DERIVE_POINT_FROM_X: DerivePointFromXCircuit, CircuitID.FP12_MUL: FP12MulCircuit, CircuitID.FINAL_EXP_PART_1: FinalExpPart1Circuit, CircuitID.FINAL_EXP_PART_2: FinalExpPart2Circuit, @@ -433,6 +474,7 @@ def to_snake_case(s: str) -> str: filenames = ["final_exp", "multi_miller_loop", "extf_mul", "ec"] circuit_name_to_filename = { CircuitID.IS_ON_CURVE_G1_G2: "ec", + CircuitID.DERIVE_POINT_FROM_X: "ec", CircuitID.FP12_MUL: "extf_mul", CircuitID.FINAL_EXP_PART_1: "final_exp", CircuitID.FINAL_EXP_PART_2: "final_exp", @@ -458,32 +500,43 @@ def to_snake_case(s: str) -> str: file.write(header) # Instantiate and compile circuits for each curve - - for curve_id in [CurveID.BN254, CurveID.BLS12_381]: + for i, curve_id in enumerate([CurveID.BN254, CurveID.BLS12_381]): for circuit_id, circuit_class in ALL_EXTF_CIRCUITS.items(): circuit_instance: BaseModuloCircuit = circuit_class(curve_id.value) print(f"Compiling {curve_id.name}:{circuit_instance.name} ...") - compiled_circuit = circuit_instance.circuit.compile_circuit( - function_name=f"{curve_id.name}_{circuit_id.name}" - ) + if circuit_instance.generic_circuit == True and i == 0: + compiled_circuit = circuit_instance.circuit.compile_circuit( + function_name=f"{circuit_id.name}" + ) + elif circuit_instance.generic_circuit == False: + compiled_circuit = circuit_instance.circuit.compile_circuit( + function_name=f"{curve_id.name}_{circuit_id.name}" + ) + else: + compiled_circuit = {"function_name": "", "code": ""} + filename_key = circuit_name_to_filename[circuit_id] codes[filename_key].append(compiled_circuit) struct_name = circuit_instance.circuit.class_name # Add the selector function for this circuit if it doesn't already exist in the list if circuit_id not in selector_functions[filename_key]: - selector_function = f""" -func get_{circuit_id.name}_circuit(curve_id:felt) -> (circuit:{struct_name}*){{ - if (curve_id == bn.CURVE_ID) {{ - return get_BN254_{circuit_id.name}_circuit(); - }} - if (curve_id == bls.CURVE_ID) {{ - return get_BLS12_381_{circuit_id.name}_circuit(); - }} - return get_void_{to_snake_case(struct_name)}(); -}} -""" - selector_functions[filename_key].add(selector_function) + if circuit_instance.generic_circuit == True: + selector_function = "" + else: + selector_function = f""" + func get_{circuit_id.name}_circuit(curve_id:felt) -> (circuit:{struct_name}*){{ + if (curve_id == bn.CURVE_ID) {{ + return get_BN254_{circuit_id.name}_circuit(); + }} + if (curve_id == bls.CURVE_ID) {{ + return get_BLS12_381_{circuit_id.name}_circuit(); + }} + return get_void_{to_snake_case(struct_name)}(); + }} + """ + + selector_functions[filename_key].add(selector_function) # Write selector functions and compiled circuit codes to their respective files print(f"Writing circuits and selectors to .cairo files...") diff --git a/src/precompiled_circuits/ec.cairo b/src/precompiled_circuits/ec.cairo index 05ce5bc4e..8e1f1216d 100644 --- a/src/precompiled_circuits/ec.cairo +++ b/src/precompiled_circuits/ec.cairo @@ -157,6 +157,115 @@ func get_BN254_IS_ON_CURVE_G1_G2_circuit() -> (circuit: ModuloCircuit*) { dw 60; } +func get_DERIVE_POINT_FROM_X_circuit(curve_id: felt) -> (circuit: ModuloCircuit*) { + alloc_locals; + let (__fp__, _) = get_fp_and_pc(); + let (constants_ptr: felt*) = get_label_location(constants_ptr_loc); + let (add_offsets_ptr: felt*) = get_label_location(add_offsets_ptr_loc); + let (mul_offsets_ptr: felt*) = get_label_location(mul_offsets_ptr_loc); + let (output_offsets_ptr: felt*) = get_label_location(output_offsets_ptr_loc); + let constants_ptr_len = 3; + let input_len = 12; + let witnesses_len = 8; + let output_len = 20; + let continuous_output = 0; + let add_mod_n = 1; + let mul_mod_n = 5; + let n_assert_eq = 0; + let name = 'derive_point_from_x'; + let curve_id = curve_id; + local circuit: ModuloCircuit = ModuloCircuit( + constants_ptr, + add_offsets_ptr, + mul_offsets_ptr, + output_offsets_ptr, + constants_ptr_len, + input_len, + witnesses_len, + output_len, + continuous_output, + add_mod_n, + mul_mod_n, + n_assert_eq, + name, + curve_id, + ); + return (&circuit,); + + constants_ptr_loc: + dw 0; + dw 0; + dw 0; + dw 0; + dw 1; + dw 0; + dw 0; + dw 0; + dw 32324006162389411176778628422; + dw 57042285082623239461879769745; + dw 3486998266802970665; + dw 0; + + add_offsets_ptr_loc: + dw 36; + dw 16; + dw 40; + dw 36; + dw 16; + dw 40; + dw 36; + dw 16; + dw 40; + dw 36; + dw 16; + dw 40; + dw 36; + dw 16; + dw 40; + dw 36; + dw 16; + dw 40; + dw 36; + dw 16; + dw 40; + dw 36; + dw 16; + dw 40; + + mul_offsets_ptr_loc: + dw 12; + dw 12; + dw 32; + dw 12; + dw 32; + dw 36; + dw 20; + dw 40; + dw 44; + dw 24; + dw 24; + dw 48; + dw 28; + dw 28; + dw 52; + dw 12; + dw 12; + dw 32; + dw 12; + dw 12; + dw 32; + dw 12; + dw 12; + dw 32; + + output_offsets_ptr_loc: + dw 40; + dw 44; + dw 48; + dw 52; + dw 24; +} + func get_BLS12_381_IS_ON_CURVE_G1_G2_circuit() -> (circuit: ModuloCircuit*) { alloc_locals; let (__fp__, _) = get_fp_and_pc(); diff --git a/src/precompiled_circuits/ec.py b/src/precompiled_circuits/ec.py index 8f2fffc43..03caf0099 100644 --- a/src/precompiled_circuits/ec.py +++ b/src/precompiled_circuits/ec.py @@ -4,9 +4,7 @@ ModuloCircuitElement, PyFelt, Polynomial, - AccPolyInstructionType, ) -from src.poseidon_transcript import CairoPoseidonTranscript from src.definitions import ( CURVES, STARK, @@ -14,17 +12,13 @@ BN254_ID, BLS12_381_ID, Curve, - generate_frobenius_maps, - get_V_torus_powers, - get_sparsity, -) -from src.hints.extf_mul import ( - nondeterministic_square_torus, - nondeterministic_extension_field_mul_divmod, ) + import random from enum import Enum +from starkware.python.math_utils import is_quad_residue, sqrt as sqrt_mod_p + class IsOnCurveCircuit(ModuloCircuit): def __init__(self, name: str, curve_id: int): @@ -33,7 +27,7 @@ def __init__(self, name: str, curve_id: int): def _is_on_curve_G1( self, x: ModuloCircuitElement, y: ModuloCircuitElement - ) -> ModuloCircuitElement: + ) -> tuple[ModuloCircuitElement, ModuloCircuitElement]: # y^2 = x^3 + ax + b a = self.set_or_get_constant(self.field(self.curve.a)) b = self.set_or_get_constant(self.field(self.curve.b)) @@ -49,7 +43,13 @@ def _is_on_curve_G1( return y2, x3_ax_b - def _is_on_curve_G2(self, x0, x1, y0, y1): + def _is_on_curve_G2( + self, + x0: ModuloCircuitElement, + x1: ModuloCircuitElement, + y0: ModuloCircuitElement, + y1: ModuloCircuitElement, + ): # y^2 = x^3 + ax + b [Fp2] a = self.set_or_get_constant(self.field(self.curve.a)) b0 = self.set_or_get_constant(self.field(self.curve.b20)) @@ -68,3 +68,49 @@ def _is_on_curve_G2(self, x0, x1, y0, y1): x3_ax_b = [self.add(x3[0], ax_b[0]), self.add(x3[1], ax_b[1])] return y2, x3_ax_b + + +class DerivePointFromX(ModuloCircuit): + def __init__(self, name: str, curve_id: int): + super().__init__(name=name, curve_id=curve_id, generic_circuit=True) + self.curve = CURVES[curve_id] + + def _derive_point_from_x( + self, + x: ModuloCircuitElement, + b: ModuloCircuitElement, + g: ModuloCircuitElement, + ) -> list[ModuloCircuitElement]: + # y^2 = x^3 + ax + b + # Assumes a == 0. + x3 = self.mul(x, self.mul(x, x)) + rhs = self.add(x3, b) + + grhs = self.mul(g, rhs) + + # WRITE g*rhs and rhs "square roots" to circuit. + # If rhs is a square, write zero to gx and the square root of rhs to x3_ax_b_sqrt. + # Otherwise, write the square root of gx to gx_sqrt and zero to x3_ax_b_sqrt. + ## %{ + if is_quad_residue(rhs.value, self.field.p): + rhs_sqrt = self.write_element( + self.field(sqrt_mod_p(rhs.value, self.field.p)), + WriteOps.WITNESS, + ) + grhs_sqrt = self.write_element(self.field.zero(), WriteOps.WITNESS) + + else: + assert is_quad_residue(grhs.value, self.field.p) # Sanity check. + rhs_sqrt = self.write_element(self.field.zero(), WriteOps.WITNESS) + + grhs_sqrt = self.write_element( + self.field(sqrt_mod_p(grhs.value, self.field.p)), + WriteOps.WITNESS, + ) + + ## %} + should_be_rhs = self.mul(rhs_sqrt, rhs_sqrt) + + should_be_grhs = self.mul(grhs_sqrt, grhs_sqrt) + + return (rhs, grhs, should_be_rhs, should_be_grhs, rhs_sqrt) diff --git a/src/utils.cairo b/src/utils.cairo index e4035481d..efed72035 100644 --- a/src/utils.cairo +++ b/src/utils.cairo @@ -1,7 +1,7 @@ from starkware.cairo.common.alloc import alloc from starkware.cairo.common.cairo_builtins import PoseidonBuiltin from starkware.cairo.common.poseidon_state import PoseidonBuiltinState -from src.definitions import STARK_MIN_ONE_D2, N_LIMBS, BASE, bn +from src.definitions import STARK_MIN_ONE_D2, N_LIMBS, BASE, bn, UInt384 func get_Z_and_RLC_from_transcript{poseidon_ptr: PoseidonBuiltin*, range_check96_ptr: felt*}( transcript_start: felt*, @@ -291,6 +291,30 @@ func write_felts_to_value_segment{range_check96_ptr: felt*}(values_start: felt*, return (); } +func felt_to_UInt384{range_check96_ptr: felt*}(x: felt) -> (res: UInt384) { + let d0 = [range_check96_ptr]; + let d1 = [range_check96_ptr + 1]; + let d2 = [range_check96_ptr + 2]; + %{ + from src.hints.io import bigint_split + limbs = bigint_split(ids.x, ids.N_LIMBS, ids.BASE) + assert limbs[3] == 0 + ids.d0, ids.d1, ids.d2 = limbs[0], limbs[1], limbs[2] + %} + assert [range_check96_ptr + 3] = STARK_MIN_ONE_D2 - d2; + assert x = d0 + d1 * BASE + d2 * BASE ** 2; + if (d2 == STARK_MIN_ONE_D2) { + // Take advantage of Cairo prime structure. STARK_MIN_ONE = 0 + 0 * BASE + stark_min_1_d2 * (BASE)**2. + assert d1 = 0; + assert d2 = 0; + tempvar range_check96_ptr = range_check96_ptr + 4; + return (res=UInt384(d0, d1, d2, 0)); + } else { + tempvar range_check96_ptr = range_check96_ptr + 4; + return (res=UInt384(d0, d1, d2, 0)); + } +} + func retrieve_output{}( values_segment: felt*, output_offsets_ptr: felt*, n: felt, continuous_output: felt ) -> (output: felt*) { diff --git a/tests/benchmarks.py b/tests/benchmarks.py index fc40dcefd..97290b0f9 100644 --- a/tests/benchmarks.py +++ b/tests/benchmarks.py @@ -12,12 +12,16 @@ precompute_lineline_sparsity, ) from random import randint +import random from src.extension_field_modulo_circuit import ExtensionFieldModuloCircuit, WriteOps from src.precompiled_circuits.final_exp import FinalExpTorusCircuit, test_final_exp from src.precompiled_circuits.multi_miller_loop import MultiMillerLoopCircuit +from src.precompiled_circuits.ec import DerivePointFromX from tools.gnark_cli import GnarkCLI from src.hints.tower_backup import E12 +random.seed(0) + def test_extf_mul(curve_id: CurveID, extension_degree: int): curve: Curve = CURVES[curve_id.value] @@ -346,6 +350,19 @@ def test_miller_n(curve_id, n): return c.summarize(), c.ops_counter +def test_derive_point_from_x(curve_id: CurveID): + field = get_base_field(curve_id.value) + c = DerivePointFromX( + f"Derive Point From X", + curve_id.value, + ) + x = c.write_element(field(randint(0, STARK - 1))) + b = c.write_element(field(CURVES[curve_id.value].b)) + g = c.write_element(field(CURVES[curve_id.value].fp_generator)) + c._derive_point_from_x(x, b, g) + return c.summarize(), None + + if __name__ == "__main__": import pandas as pd from tabulate import tabulate @@ -387,6 +404,7 @@ def test_miller_n(curve_id, n): builtin_ops_data.append(builtin_ops) for test_func, curve_id in [ + (test_derive_point_from_x, CurveID.BN254), (test_double_step, CurveID.BLS12_381), (test_double_and_add_step, CurveID.BLS12_381), (test_double_step, CurveID.BN254), diff --git a/tests/cairo_programs/derive_ec_point.cairo b/tests/cairo_programs/derive_ec_point.cairo new file mode 100644 index 000000000..8f9027438 --- /dev/null +++ b/tests/cairo_programs/derive_ec_point.cairo @@ -0,0 +1,62 @@ +%builtins range_check poseidon range_check96 add_mod mul_mod + +from starkware.cairo.common.cairo_builtins import PoseidonBuiltin, ModBuiltin +from starkware.cairo.common.registers import get_fp_and_pc +from starkware.cairo.common.alloc import alloc + +from src.definitions import bn, bls, UInt384, one_E12D, N_LIMBS, BASE, G1Point + +from src.ec_ops import derive_EC_point_from_entropy +from src.modulo_circuit import ExtensionFieldModuloCircuit + +func main{ + range_check_ptr, + poseidon_ptr: PoseidonBuiltin*, + range_check96_ptr: felt*, + add_mod_ptr: ModBuiltin*, + mul_mod_ptr: ModBuiltin*, +}() { + alloc_locals; + let (__fp__, _) = get_fp_and_pc(); + local entropy0: felt; + local entropy1: felt; + local entropy2: felt; + local entropy3: felt; + local entropy4: felt; + local entropy5: felt; + local entropy6: felt; + local entropy7: felt; + local entropy8: felt; + local entropy9: felt; + %{ + from random import randint + from src.definitions import STARK + entropies = [randint(0, STARK-1) for _ in range(10)] + for i in range(10): + setattr(ids, f"entropy{i}", entropies[i]) + %} + + let (random_point: G1Point) = derive_EC_point_from_entropy(bn.CURVE_ID, entropy0, 0); + let (random_point: G1Point) = derive_EC_point_from_entropy(bn.CURVE_ID, entropy1, 0); + let (random_point: G1Point) = derive_EC_point_from_entropy(bn.CURVE_ID, entropy2, 0); + let (random_point: G1Point) = derive_EC_point_from_entropy(bn.CURVE_ID, entropy3, 0); + let (random_point: G1Point) = derive_EC_point_from_entropy(bn.CURVE_ID, entropy4, 0); + let (random_point: G1Point) = derive_EC_point_from_entropy(bn.CURVE_ID, entropy5, 0); + let (random_point: G1Point) = derive_EC_point_from_entropy(bn.CURVE_ID, entropy6, 0); + let (random_point: G1Point) = derive_EC_point_from_entropy(bn.CURVE_ID, entropy7, 0); + let (random_point: G1Point) = derive_EC_point_from_entropy(bn.CURVE_ID, entropy8, 0); + let (random_point: G1Point) = derive_EC_point_from_entropy(bn.CURVE_ID, entropy9, 0); + + let (random_point: G1Point) = derive_EC_point_from_entropy(bls.CURVE_ID, entropy0, 0); + let (random_point: G1Point) = derive_EC_point_from_entropy(bls.CURVE_ID, entropy1, 0); + let (random_point: G1Point) = derive_EC_point_from_entropy(bls.CURVE_ID, entropy2, 0); + let (random_point: G1Point) = derive_EC_point_from_entropy(bls.CURVE_ID, entropy3, 0); + let (random_point: G1Point) = derive_EC_point_from_entropy(bls.CURVE_ID, entropy4, 0); + let (random_point: G1Point) = derive_EC_point_from_entropy(bls.CURVE_ID, entropy5, 0); + let (random_point: G1Point) = derive_EC_point_from_entropy(bls.CURVE_ID, entropy6, 0); + let (random_point: G1Point) = derive_EC_point_from_entropy(bls.CURVE_ID, entropy7, 0); + let (random_point: G1Point) = derive_EC_point_from_entropy(bls.CURVE_ID, entropy8, 0); + let (random_point: G1Point) = derive_EC_point_from_entropy(bls.CURVE_ID, entropy9, 0); + + return (); +}