Skip to content

Commit

Permalink
Check Prover Points are on curve & include scalars in Fiat-Shamir
Browse files Browse the repository at this point in the history
  • Loading branch information
feltroidprime committed Jun 3, 2024
1 parent aad04d1 commit e1ae840
Showing 1 changed file with 47 additions and 12 deletions.
59 changes: 47 additions & 12 deletions src/ec_ops.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ from src.definitions import (
)
from src.precompiled_circuits.ec import (
get_IS_ON_CURVE_G1_G2_circuit,
get_IS_ON_CURVE_G1_circuit,
get_DERIVE_POINT_FROM_X_circuit,
get_SLOPE_INTERCEPT_SAME_POINT_circuit,
get_ACCUMULATE_EVAL_POINT_CHALLENGE_SIGNED_circuit,
Expand All @@ -27,6 +28,7 @@ from starkware.cairo.common.cairo_builtins import ModBuiltin, UInt384, PoseidonB
from starkware.cairo.common.alloc import alloc
from starkware.cairo.common.memcpy import memcpy
from starkware.cairo.common.registers import get_fp_and_pc
from starkware.cairo.common.builtin_poseidon.poseidon import poseidon_hash, poseidon_hash_many

from src.utils import (
felt_to_UInt384,
Expand All @@ -36,6 +38,23 @@ from src.utils import (
hash_full_transcript_and_get_Z,
)

func is_on_curve_g1{
range_check_ptr, range_check96_ptr: felt*, add_mod_ptr: ModBuiltin*, mul_mod_ptr: ModBuiltin*
}(curve_id: felt, point: G1Point) -> (res: felt) {
alloc_locals;
let (P) = get_P(curve_id);
let (A) = get_a(curve_id);
let (B) = get_b(curve_id);
let (circuit) = get_IS_ON_CURVE_G1_circuit(curve_id);
let (input: UInt384*) = alloc();
assert input[0] = point.x;
assert input[1] = point.y;
assert input[2] = A;
assert input[3] = B;
let (output: felt*) = run_modulo_circuit(circuit, input);
let (check_g1: felt) = is_zero_mod_P([cast(output, UInt384*)], P);
return (res=check_g1);
}
func is_on_curve_g1_g2{
range_check_ptr, range_check96_ptr: felt*, add_mod_ptr: ModBuiltin*, mul_mod_ptr: ModBuiltin*
}(curve_id: felt, input: felt*) -> (res: felt) {
Expand Down Expand Up @@ -80,14 +99,21 @@ func add_ec_points{
return (res=G1Point(UInt384(0, 0, 0, 0), UInt384(0, 0, 0, 0)));
} else {
let (circuit) = get_DOUBLE_EC_POINT_circuit(curve_id);
let (res) = run_modulo_circuit(circuit, cast(&P, felt*));
let (A) = get_a(curve_id);
let (input: UInt384*) = alloc();
assert input[0] = P.x;
assert input[1] = P.y;
assert input[2] = A;
let (res) = run_modulo_circuit(circuit, cast(input, felt*));
return (res=[cast(res, G1Point*)]);
}
} else {
let (circuit) = get_ADD_EC_POINT_circuit(curve_id);
let (input: G1Point*) = alloc();
assert input[0] = P;
assert input[1] = Q;
let (input: UInt384*) = alloc();
assert input[0] = P.x;
assert input[1] = P.y;
assert input[2] = Q.x;
assert input[3] = Q.y;
let (res) = run_modulo_circuit(circuit, cast(input, felt*));
return (res=[cast(res, G1Point*)]);
}
Expand Down Expand Up @@ -203,6 +229,7 @@ func compute_slope_intercept_same_point{
return (res=output);
}

// Compute RHS of eq 3 in https://eprint.iacr.org/2022/596.pdf , without accounting for result point
func compute_RHS_basis_sum{
range_check_ptr, range_check96_ptr: felt*, add_mod_ptr: ModBuiltin*, mul_mod_ptr: ModBuiltin*
}(
Expand Down Expand Up @@ -267,6 +294,7 @@ func compute_RHS_basis_sum{
);
}

// Finalize RHS computation of eq 3 in https://eprint.iacr.org/2022/596.pdf , accounting for the result point (-Q)
func finalize_RHS{
range_check_ptr, range_check96_ptr: felt*, add_mod_ptr: ModBuiltin*, mul_mod_ptr: ModBuiltin*
}(acc_circuit: ModuloCircuit*, Q: G1Point, sum: UInt384, constants: SlopeInterceptOutput*) -> (
Expand Down Expand Up @@ -380,6 +408,13 @@ func msm{
print(f"Hashing Z = Poseidon(Input, Commitments) = Hash(Points, scalars, Q_low, Q_high, Q_high_shifted, SumDlogDivLow, SumDlogDivHigh, SumDlogDivShifted)...")
%}
let (is_on_curve_low: felt) = is_on_curve_g1(curve_id, Q_low);
let (is_on_curve_high: felt) = is_on_curve_g1(curve_id, Q_high);
let (is_on_curve_shifted: felt) = is_on_curve_g1(curve_id, Q_high_shifted);
assert is_on_curve_low = 1;
assert is_on_curve_high = 1;
assert is_on_curve_shifted = 1;

let (Z: felt) = hash_full_transcript_and_get_Z(
limbs_ptr=cast(points, felt*), n=n * 2, init_hash='MSM', curve_id=curve_id
);
Expand All @@ -392,6 +427,8 @@ func msm{
let (Z: felt) = hash_full_transcript_and_get_Z(
limbs_ptr=cast(&Q_high_shifted, felt*), n=2, init_hash=Z, curve_id=curve_id
);
let (Zt: felt) = poseidon_hash_many(2 * n, cast(scalars, felt*));
let (Z: felt) = poseidon_hash(Zt, Z);
let (Z: felt) = hash_sum_dlog_div(f=SumDlogDivLow, msm_size=n, init_hash=Z, curve_id=curve_id);
let (Z: felt) = hash_sum_dlog_div(f=SumDlogDivHigh, msm_size=n, init_hash=Z, curve_id=curve_id);
let (Z: felt) = hash_sum_dlog_div(
Expand All @@ -402,14 +439,6 @@ func msm{
// Sample random EC point for challenge.
let (random_point: G1Point) = derive_EC_point_from_entropy(curve_id, Z, 0);

// let random_point: G1Point = G1Point(
// UInt384(
// 17787641035291477414016626793, 37797302034529340534579361905, 2935882634526754611, 0
// ),
// UInt384(
// 39243316047810539479482107377, 68275252041657531689158113746, 1897820647238684479, 0
// ),
// );
// Get slope, intercept and other constants from random EC point
let (local mb: SlopeInterceptOutput*) = compute_slope_intercept_same_point(
curve_id, random_point
Expand All @@ -430,6 +459,8 @@ func msm{
) = get_EVAL_FUNCTION_CHALLENGE_DUPL_circuit(curve_id, 1);

%{ print(f"Verifying ZK-ECIP equations evaluated at the random point...") %}

// Verify Q_low = sum(scalar_low * P for scalar_low,P in zip(scalars_low, points))
let (Q_low_is_zero) = zk_ecip_check(
curve_id=curve_id,
points=points,
Expand All @@ -443,6 +474,7 @@ func msm{
sum_dlog_div=SumDlogDivLow,
);

// Verify Q_high = sum(scalar_high * P for scalar_high,P in zip(scalars_high, points))
let (_) = zk_ecip_check(
curve_id=curve_id,
points=points,
Expand All @@ -456,6 +488,7 @@ func msm{
sum_dlog_div=SumDlogDivHigh,
);

// Verify Q_high_shifted = 2^128 * Q_high
let (Q_high_shifted_is_zero) = zk_ecip_check(
curve_id=curve_id,
points=&Q_high,
Expand Down Expand Up @@ -489,6 +522,8 @@ func msm{
}
}

// Compute LHS of eq 3 in https://eprint.iacr.org/2022/596.pdf
// Uses sum_dlog_div = sum((-3)^j * Dlog(Dj) for j in [0, 81]).
func compute_LHS{
range_check_ptr, range_check96_ptr: felt*, add_mod_ptr: ModBuiltin*, mul_mod_ptr: ModBuiltin*
}(
Expand Down

0 comments on commit e1ae840

Please sign in to comment.