Skip to content

Commit

Permalink
add derive_point_from_X circuit
Browse files Browse the repository at this point in the history
  • Loading branch information
feltroidprime committed May 7, 2024
1 parent 3da0586 commit 36c9edd
Show file tree
Hide file tree
Showing 11 changed files with 491 additions and 37 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ make run

| circuit | MULMOD | ADDMOD | ASSERT_EQ | POSEIDON | RLC | ~steps |
|-------------------------------------------|----------|----------|-------------|------------|-------|----------|
| Derive Point From X | 5 | 1 | 0 | 0 | 0 | 44 |
| Double Step BLS12_381 | 22 | 9 | 2 | 0 | 0 | 216 |
| Double Step BN254 | 24 | 11 | 2 | 0 | 0 | 240 |
| Fp6 SQUARE_TORUS | 12 | 16 | 0 | 7 | 1 | 300 |
Expand Down
35 changes: 35 additions & 0 deletions src/definitions.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,13 @@ namespace bls {
// Non residue constants:
const NON_RESIDUE_E2_a0 = 1;
const NON_RESIDUE_E2_a1 = 1;
// Curve equation parameters:
const a = 0;
const b = 4;
// Fp generator :
const g = 3;
}

namespace bn {
Expand All @@ -36,6 +43,11 @@ namespace bn {
// Non residue constants:
const NON_RESIDUE_E2_a0 = 9;
const NON_RESIDUE_E2_a1 = 1;
// Curve equation parameters:
const a = 0;
const b = 3;
// Fp Generator :
const g = 3;
}

func get_P(curve_id: felt) -> (prime: UInt384) {
Expand All @@ -50,6 +62,29 @@ func get_P(curve_id: felt) -> (prime: UInt384) {
}
}

func get_b(curve_id: felt) -> (res: UInt384) {
if (curve_id == bls.CURVE_ID) {
return (res=UInt384(bls.b, 0, 0, 0));
} else {
if (curve_id == bn.CURVE_ID) {
return (res=UInt384(bn.b, 0, 0, 0));
} else {
return (res=UInt384(-1, 0, 0, 0));
}
}
}

func get_fp_gen(curve_id: felt) -> (res: UInt384) {
if (curve_id == bls.CURVE_ID) {
return (res=UInt384(bls.g, 0, 0, 0));
} else {
if (curve_id == bn.CURVE_ID) {
return (res=UInt384(bn.g, 0, 0, 0));
} else {
return (res=UInt384(-1, 0, 0, 0));
}
}
}
const SUPPORTED_CURVE_ID = 0;
const UNSUPPORTED_CURVE_ID = 1;

Expand Down
89 changes: 86 additions & 3 deletions src/ec_ops.cairo
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
from src.definitions import is_zero_mod_P, get_P
from src.precompiled_circuits.ec import get_IS_ON_CURVE_G1_G2_circuit
from src.definitions import is_zero_mod_P, get_P, G1Point, get_b, get_fp_gen, verify_zero4, BASE
from src.precompiled_circuits.ec import (
get_IS_ON_CURVE_G1_G2_circuit,
get_DERIVE_POINT_FROM_X_circuit,
)
from src.modulo_circuit import run_modulo_circuit, ModuloCircuit
from starkware.cairo.common.cairo_builtins import ModBuiltin, UInt384
from starkware.cairo.common.cairo_builtins import ModBuiltin, UInt384, PoseidonBuiltin
from starkware.cairo.common.alloc import alloc
from src.utils import felt_to_UInt384

func is_on_curve_g1_g2{
range_check_ptr, range_check96_ptr: felt*, add_mod_ptr: ModBuiltin*, mul_mod_ptr: ModBuiltin*
Expand All @@ -26,3 +31,81 @@ func is_on_curve_g1_g2{
}
return (res=1);
}

struct DerivePointFromXOutput {
rhs: UInt384, // x^3 + ax + b
grhs: UInt384, // g * (x^3+ax+b)
should_be_rhs: UInt384,
should_be_grhs: UInt384,
y_try: UInt384,
}

struct DerivePointFromXInput {
entropy: UInt384,
y: UInt384,
rhs_from_x_is_a_square_residue: felt,
}
func derive_EC_point_from_entropy{
range_check_ptr,
range_check96_ptr: felt*,
add_mod_ptr: ModBuiltin*,
mul_mod_ptr: ModBuiltin*,
poseidon_ptr: PoseidonBuiltin*,
}(curve_id: felt, entropy: felt, attempt: felt) -> (res: G1Point) {
%{ print(f"Attempt : {ids.attempt}") %}
alloc_locals;
local rhs_from_x_is_a_square_residue: felt;
%{
from starkware.python.math_utils import is_quad_residue
from src.definitions import CURVES
a = CURVES[ids.curve_id].a
b = CURVES[ids.curve_id].b
p = CURVES[ids.curve_id].p
rhs = (ids.entropy**3 + a*ids.entropy + b) % p
ids.rhs_from_x_is_a_square_residue = is_quad_residue(rhs, p)
%}
let (x_384: UInt384) = felt_to_UInt384(entropy);
let (b_Weirstrass: UInt384) = get_b(curve_id);
let (fp_generator: UInt384) = get_fp_gen(curve_id);
let (P: UInt384) = get_P(curve_id);
let (circuit) = get_DERIVE_POINT_FROM_X_circuit(curve_id);

let (input: UInt384*) = alloc();
assert input[0] = x_384;
assert input[1] = b_Weirstrass;
assert input[2] = fp_generator;

let (output_array: felt*) = run_modulo_circuit(circuit, cast(input, felt*));
let output: DerivePointFromXOutput* = cast(output_array, DerivePointFromXOutput*);

if (rhs_from_x_is_a_square_residue != 0) {
// Assert should_be_rhs == rhs
verify_zero4(
UInt384(
output.rhs.d0 - output.should_be_rhs.d0,
output.rhs.d1 - output.should_be_rhs.d1,
output.rhs.d2 - output.should_be_rhs.d2,
output.rhs.d3 - output.should_be_rhs.d3,
),
P,
);
return (res=G1Point(x_384, output.y_try));
} else {
// Assert should_be_grhs == grhs & Retry.
verify_zero4(
UInt384(
output.grhs.d0 - output.should_be_grhs.d0,
output.grhs.d1 - output.should_be_grhs.d1,
output.grhs.d2 - output.should_be_grhs.d2,
output.grhs.d3 - output.should_be_grhs.d3,
),
P,
);
assert poseidon_ptr[0].input.s0 = entropy;
assert poseidon_ptr[0].input.s1 = attempt;
assert poseidon_ptr[0].input.s2 = 2;
let new_entropy = poseidon_ptr[0].output.s0;
let poseidon_ptr = poseidon_ptr + PoseidonBuiltin.SIZE;
return derive_EC_point_from_entropy(curve_id, new_entropy, attempt + 1);
}
}
2 changes: 1 addition & 1 deletion src/modulo_circuit.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ struct ModuloCircuit {
mul_offsets_ptr: felt*,
output_offsets_ptr: felt*,
constants_ptr_len: felt,
witnesses_len: felt,
input_len: felt,
witnesses_len: felt,
output_len: felt,
continuous_output: felt,
add_mod_n: felt,
Expand Down
29 changes: 26 additions & 3 deletions src/modulo_circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ class ModuloCircuit:
constants (dict[str, ModuloElement]): A dictionary mapping constant names to their ModuloElement representations.
"""

def __init__(self, name: str, curve_id: int) -> None:
def __init__(self, name: str, curve_id: int, generic_circuit: bool = False) -> None:
assert len(name) < 30, f"Name '{name}' is too long."
self.name = name
self.class_name = "ModuloCircuit"
Expand All @@ -259,6 +259,7 @@ def __init__(self, name: str, curve_id: int) -> None:
self.N_LIMBS = 4
self.values_segment: ValueSegment = ValueSegment(name)
self.constants: dict[int, ModuloCircuitElement] = dict()
self.generic_circuit = generic_circuit

self.set_or_get_constant(self.field.zero())
self.set_or_get_constant(self.field.one())
Expand Down Expand Up @@ -537,7 +538,14 @@ def compile_circuit(
dw_arrays = self.values_segment.get_dw_lookups()
name = function_name or self.values_segment.name
function_name = f"get_{name}_circuit"
code = f"func {function_name}()->(circuit:{self.class_name}*)" + "{" + "\n"
if self.generic_circuit:
code = (
f"func {function_name}(curve_id:felt)->(circuit:{self.class_name}*)"
+ "{"
+ "\n"
)
else:
code = f"func {function_name}()->(circuit:{self.class_name}*)" + "{" + "\n"

code += "alloc_locals;\n"
code += "let (__fp__, _) = get_fp_and_pc();\n"
Expand All @@ -557,7 +565,9 @@ def compile_circuit(
f"let n_assert_eq = {len(self.values_segment.assert_eq_instructions)};\n"
)
code += f"let name = '{self.name}';\n"
code += f"let curve_id = {self.curve_id};\n"
code += (
f"let curve_id = {'curve_id' if self.generic_circuit else self.curve_id};\n"
)

code += f"local circuit:{self.class_name} = {self.class_name}({', '.join(returns['felt*'])}, {', '.join(returns['felt'])});\n"
code += f"return (&circuit,);\n"
Expand Down Expand Up @@ -603,6 +613,19 @@ def compile_circuit(
"code": code,
}

def summarize(self):
add_count, mul_count, assert_eq_count = self.values_segment.summarize()
summary = {
"circuit": self.name,
"MULMOD": mul_count,
"ADDMOD": add_count,
"ASSERT_EQ": assert_eq_count,
"POSEIDON": 0,
"RLC": 0,
}

return summary


if __name__ == "__main__":
from src.algebra import BaseField
Expand Down
89 changes: 71 additions & 18 deletions src/precompiled_circuits/all_circuits.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from src.precompiled_circuits import multi_miller_loop, final_exp
from src.precompiled_circuits.ec import IsOnCurveCircuit
from src.precompiled_circuits.ec import IsOnCurveCircuit, DerivePointFromX
from src.extension_field_modulo_circuit import (
ExtensionFieldModuloCircuit,
ModuloCircuit,
Expand All @@ -14,6 +14,7 @@
BLS12_381_ID,
get_base_field,
CurveID,
STARK,
)
from random import seed, randint
from enum import Enum
Expand All @@ -38,12 +39,22 @@ class CircuitID(Enum):
MILLER_LOOP_N2 = int.from_bytes(b"miller_loop_n2", "big")
MILLER_LOOP_N3 = int.from_bytes(b"miller_loop_n3", "big")
IS_ON_CURVE_G1_G2 = int.from_bytes(b"is_on_curve_g1_g2", "big")
DERIVE_POINT_FROM_X = int.from_bytes(b"derive_point_from_x", "big")


from abc import ABC, abstractmethod


class BaseModuloCircuit(ABC):
"""
Base class for all modulo circuits.
Parameters:
- name: str, the name of the circuit
- input_len: int, the number of input elements (/!\ of total felt252 values)
- curve_id: int, the id of the curve
- auto_run: bool, whether to run the circuit automatically at initialization.
"""

def __init__(
self,
name: str,
Expand All @@ -57,6 +68,7 @@ def __init__(
self.field = get_base_field(curve_id)
self.input_len = input_len
self.init_hash = None
self.generic_circuit = False
if auto_run:
self.circuit: ModuloCircuit = self._run_circuit_inner(self.build_input())

Expand Down Expand Up @@ -117,6 +129,34 @@ def _run_circuit_inner(self, input: list[PyFelt]) -> ModuloCircuit:
return circuit


class DerivePointFromXCircuit(BaseModuloCircuit):
def __init__(self, curve_id: int, auto_run: bool = True) -> None:
super().__init__(
name="derive_point_from_x",
input_len=4 * 3, # X + b + G + SQ(X^3+ + SQG
curve_id=curve_id,
auto_run=auto_run,
)
self.generic_circuit = True

def build_input(self) -> list[PyFelt]:
input = []
input.append(self.field(randint(0, STARK - 1)))
input.append(self.field(CURVES[self.curve_id].b)) # y^2 = x^3 + b
input.append(self.field(CURVES[self.curve_id].fp_generator))
return input

def _run_circuit_inner(self, input: list[PyFelt]) -> ModuloCircuit:
circuit = DerivePointFromX(self.name, self.curve_id)
x, b, g = circuit.write_elements(input[0:3], WriteOps.INPUT)
rhs, grhs, should_be_rhs, should_be_grhs, y_try = circuit._derive_point_from_x(
x, b, g
)
circuit.extend_output([rhs, grhs, should_be_rhs, should_be_grhs, y_try])
circuit.values_segment = circuit.values_segment.non_interactive_transform()
return circuit


class FP12MulCircuit(BaseEXTFCircuit):
def __init__(self, curve_id: int, auto_run: bool = True, init_hash: int = None):
super().__init__("fp12_mul", 24, curve_id, auto_run, init_hash)
Expand Down Expand Up @@ -412,6 +452,7 @@ def _run_circuit_inner(self, input: list[PyFelt]):

ALL_EXTF_CIRCUITS = {
CircuitID.IS_ON_CURVE_G1_G2: IsOnCurveG1G2Circuit,
CircuitID.DERIVE_POINT_FROM_X: DerivePointFromXCircuit,
CircuitID.FP12_MUL: FP12MulCircuit,
CircuitID.FINAL_EXP_PART_1: FinalExpPart1Circuit,
CircuitID.FINAL_EXP_PART_2: FinalExpPart2Circuit,
Expand All @@ -433,6 +474,7 @@ def to_snake_case(s: str) -> str:
filenames = ["final_exp", "multi_miller_loop", "extf_mul", "ec"]
circuit_name_to_filename = {
CircuitID.IS_ON_CURVE_G1_G2: "ec",
CircuitID.DERIVE_POINT_FROM_X: "ec",
CircuitID.FP12_MUL: "extf_mul",
CircuitID.FINAL_EXP_PART_1: "final_exp",
CircuitID.FINAL_EXP_PART_2: "final_exp",
Expand All @@ -458,32 +500,43 @@ def to_snake_case(s: str) -> str:
file.write(header)

# Instantiate and compile circuits for each curve

for curve_id in [CurveID.BN254, CurveID.BLS12_381]:
for i, curve_id in enumerate([CurveID.BN254, CurveID.BLS12_381]):
for circuit_id, circuit_class in ALL_EXTF_CIRCUITS.items():
circuit_instance: BaseModuloCircuit = circuit_class(curve_id.value)
print(f"Compiling {curve_id.name}:{circuit_instance.name} ...")

compiled_circuit = circuit_instance.circuit.compile_circuit(
function_name=f"{curve_id.name}_{circuit_id.name}"
)
if circuit_instance.generic_circuit == True and i == 0:
compiled_circuit = circuit_instance.circuit.compile_circuit(
function_name=f"{circuit_id.name}"
)
elif circuit_instance.generic_circuit == False:
compiled_circuit = circuit_instance.circuit.compile_circuit(
function_name=f"{curve_id.name}_{circuit_id.name}"
)
else:
compiled_circuit = {"function_name": "", "code": ""}

filename_key = circuit_name_to_filename[circuit_id]
codes[filename_key].append(compiled_circuit)
struct_name = circuit_instance.circuit.class_name
# Add the selector function for this circuit if it doesn't already exist in the list
if circuit_id not in selector_functions[filename_key]:
selector_function = f"""
func get_{circuit_id.name}_circuit(curve_id:felt) -> (circuit:{struct_name}*){{
if (curve_id == bn.CURVE_ID) {{
return get_BN254_{circuit_id.name}_circuit();
}}
if (curve_id == bls.CURVE_ID) {{
return get_BLS12_381_{circuit_id.name}_circuit();
}}
return get_void_{to_snake_case(struct_name)}();
}}
"""
selector_functions[filename_key].add(selector_function)
if circuit_instance.generic_circuit == True:
selector_function = ""
else:
selector_function = f"""
func get_{circuit_id.name}_circuit(curve_id:felt) -> (circuit:{struct_name}*){{
if (curve_id == bn.CURVE_ID) {{
return get_BN254_{circuit_id.name}_circuit();
}}
if (curve_id == bls.CURVE_ID) {{
return get_BLS12_381_{circuit_id.name}_circuit();
}}
return get_void_{to_snake_case(struct_name)}();
}}
"""

selector_functions[filename_key].add(selector_function)

# Write selector functions and compiled circuit codes to their respective files
print(f"Writing circuits and selectors to .cairo files...")
Expand Down
Loading

0 comments on commit 36c9edd

Please sign in to comment.