Skip to content

Commit

Permalink
Refactors Rust ECIP code to better isolate PyO3 interface from genera…
Browse files Browse the repository at this point in the history
…l implementation
  • Loading branch information
raugfer committed Aug 19, 2024
1 parent b6d731e commit 80c579e
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 83 deletions.
6 changes: 1 addition & 5 deletions hydra/garaga/hints/ecip.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,12 +107,8 @@ def zk_ecip_hint(
if ec_group_class == G1Point and use_rust:
pts = []
c_id = Bs[0].curve_id
if c_id == CurveID.BLS12_381:
nb = 48
else:
nb = 32
for pt in Bs:
pts.extend([pt.x.to_bytes(nb, "big"), pt.y.to_bytes(nb, "big")])
pts.extend([pt.x, pt.y])
field_type = get_field_type_from_ec_point(Bs[0])
field = get_base_field(c_id.value, field_type)

Expand Down
134 changes: 58 additions & 76 deletions tools/garaga_rs/src/ecip/core.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,48 +12,45 @@ use crate::ecip::rational_function::FunctionFelt;
use crate::ecip::rational_function::RationalFunction;

use num_bigint::{BigInt, BigUint, ToBigInt};
use pyo3::{prelude::*, types::PyList};

use super::curve::CurveParamsProvider;

#[pyfunction]
pub fn zk_ecip_hint(
py: Python,
py_list_1: &Bound<'_, PyList>,
py_list_2: &Bound<'_, PyList>,
list_values: Vec<BigUint>,
list_scalars: Vec<BigUint>,
curve_id: usize,
) -> PyResult<PyObject> {
) -> Result<[Vec<String>; 5], String> {
match curve_id {
0 => {
let list_felts: Vec<FieldElement<BN254PrimeField>> = py_list_1
let list_felts: Vec<FieldElement<BN254PrimeField>> = list_values
.into_iter()
.map(|x| {
FieldElement::<BN254PrimeField>::from_bytes_be(x.extract()?).map_err(|e| {
PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
FieldElement::<BN254PrimeField>::from_bytes_be(&x.to_bytes_be()).map_err(|e| {
format!(
"Byte conversion error: {:?}",
e
))
)
})
})
.collect::<Result<Vec<FieldElement<BN254PrimeField>>, _>>()?;
.collect::<Result<Vec<FieldElement<BN254PrimeField>>, String>>()?;

let points: Vec<G1Point<BN254PrimeField>> = list_felts
.chunks(2)
.map(|chunk| G1Point::new(chunk[0].clone(), chunk[1].clone()))
.collect();

let scalars: Vec<Vec<i8>> = extract_scalars::<BN254PrimeField>(py_list_2)?;
run_ecip::<BN254PrimeField>(py, points, scalars)
let scalars = extract_scalars::<BN254PrimeField>(list_scalars);
Ok(run_ecip::<BN254PrimeField>(points, scalars))
}
1 => {
let list_felts: Vec<FieldElement<BLS12381PrimeField>> = py_list_1
let list_felts: Vec<FieldElement<BLS12381PrimeField>> = list_values
.into_iter()
.map(|x| {
FieldElement::<BLS12381PrimeField>::from_bytes_be(x.extract()?).map_err(|e| {
PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
FieldElement::<BLS12381PrimeField>::from_bytes_be(&x.to_bytes_be()).map_err(|e| {
format!(
"Byte conversion error: {:?}",
e
))
)
})
})
.collect::<Result<Vec<FieldElement<BLS12381PrimeField>>, _>>()?;
Expand All @@ -63,18 +60,18 @@ pub fn zk_ecip_hint(
.map(|chunk| G1Point::new(chunk[0].clone(), chunk[1].clone()))
.collect();

let scalars: Vec<Vec<i8>> = extract_scalars::<BLS12381PrimeField>(py_list_2)?;
run_ecip::<BLS12381PrimeField>(py, points, scalars)
let scalars = extract_scalars::<BLS12381PrimeField>(list_scalars);
Ok(run_ecip::<BLS12381PrimeField>(points, scalars))
}
2 => {
let list_felts: Vec<FieldElement<SECP256K1PrimeField>> = py_list_1
let list_felts: Vec<FieldElement<SECP256K1PrimeField>> = list_values
.into_iter()
.map(|x| {
FieldElement::<SECP256K1PrimeField>::from_bytes_be(x.extract()?).map_err(|e| {
PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
FieldElement::<SECP256K1PrimeField>::from_bytes_be(&x.to_bytes_be()).map_err(|e| {
format!(
"Byte conversion error: {:?}",
e
))
)
})
})
.collect::<Result<Vec<FieldElement<SECP256K1PrimeField>>, _>>()?;
Expand All @@ -84,18 +81,18 @@ pub fn zk_ecip_hint(
.map(|chunk| G1Point::new(chunk[0].clone(), chunk[1].clone()))
.collect();

let scalars: Vec<Vec<i8>> = extract_scalars::<SECP256K1PrimeField>(py_list_2)?;
run_ecip::<SECP256K1PrimeField>(py, points, scalars)
let scalars = extract_scalars::<SECP256K1PrimeField>(list_scalars);
Ok(run_ecip::<SECP256K1PrimeField>(points, scalars))
}
3 => {
let list_felts: Vec<FieldElement<SECP256R1PrimeField>> = py_list_1
let list_felts: Vec<FieldElement<SECP256R1PrimeField>> = list_values
.into_iter()
.map(|x| {
FieldElement::<SECP256R1PrimeField>::from_bytes_be(x.extract()?).map_err(|e| {
PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
FieldElement::<SECP256R1PrimeField>::from_bytes_be(&x.to_bytes_be()).map_err(|e| {
format!(
"Byte conversion error: {:?}",
e
))
)
})
})
.collect::<Result<Vec<FieldElement<SECP256R1PrimeField>>, _>>()?;
Expand All @@ -105,18 +102,18 @@ pub fn zk_ecip_hint(
.map(|chunk| G1Point::new(chunk[0].clone(), chunk[1].clone()))
.collect();

let scalars: Vec<Vec<i8>> = extract_scalars::<SECP256R1PrimeField>(py_list_2)?;
run_ecip::<SECP256R1PrimeField>(py, points, scalars)
let scalars = extract_scalars::<SECP256R1PrimeField>(list_scalars);
Ok(run_ecip::<SECP256R1PrimeField>(points, scalars))
}
4 => {
let list_felts: Vec<FieldElement<X25519PrimeField>> = py_list_1
let list_felts: Vec<FieldElement<X25519PrimeField>> = list_values
.into_iter()
.map(|x| {
FieldElement::<X25519PrimeField>::from_bytes_be(x.extract()?).map_err(|e| {
PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
FieldElement::<X25519PrimeField>::from_bytes_be(&x.to_bytes_be()).map_err(|e| {
format!(
"Byte conversion error: {:?}",
e
))
)
})
})
.collect::<Result<Vec<FieldElement<X25519PrimeField>>, _>>()?;
Expand All @@ -126,22 +123,20 @@ pub fn zk_ecip_hint(
.map(|chunk| G1Point::new(chunk[0].clone(), chunk[1].clone()))
.collect();

let scalars: Vec<Vec<i8>> = extract_scalars::<X25519PrimeField>(py_list_2)?;
run_ecip::<X25519PrimeField>(py, points, scalars)
let scalars = extract_scalars::<X25519PrimeField>(list_scalars);
Ok(run_ecip::<X25519PrimeField>(points, scalars))
}
_ => Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
"Invalid curve ID",
)),
_ => Err(String::from("Invalid curve ID")),
}
}

fn extract_scalars<F: IsPrimeField + CurveParamsProvider<F>>(
py_list: &Bound<'_, PyList>,
) -> Result<Vec<Vec<i8>>, PyErr> {
list: Vec<BigUint>,
) -> Vec<Vec<i8>> {
let mut dss_ = Vec::new();

for i in 0..py_list.len() {
let scalar_biguint: BigUint = py_list.get_item(i)?.extract()?;
for i in 0..list.len() {
let scalar_biguint = list[i].clone();
let neg_3_digits = neg_3_base_le(scalar_biguint);
dss_.push(neg_3_digits);
}
Expand All @@ -160,7 +155,7 @@ fn extract_scalars<F: IsPrimeField + CurveParamsProvider<F>>(
dss.push(ds);
}

Ok(dss)
dss
}

fn neg_3_base_le(scalar: BigUint) -> Vec<i8> {
Expand Down Expand Up @@ -202,7 +197,7 @@ fn floor_division(a: BigInt, b: BigInt) -> BigInt {
}
}

fn run_ecip<F>(py: Python, points: Vec<G1Point<F>>, dss: Vec<Vec<i8>>) -> PyResult<PyObject>
fn run_ecip<F>(points: Vec<G1Point<F>>, dss: Vec<Vec<i8>>) -> [Vec<String>; 5]
where
F: IsPrimeField + CurveParamsProvider<F>,
{
Expand All @@ -221,57 +216,44 @@ where

// let sum_dlog = sum_dlog.simplify();

let q_tuple = PyList::new_bound(
py,
[
q.x.representative().to_string(),
q.y.representative().to_string(),
],
);

let a_num_list = PyList::new_bound(
py,
let q_tuple = vec![
q.x.representative().to_string(),
q.y.representative().to_string(),
];
let a_num_list =
sum_dlog
.a
.numerator
.coefficients
.iter()
.map(|c| c.representative().to_string()),
);
let a_den_list = PyList::new_bound(
py,
.map(|c| c.representative().to_string())
.collect();
let a_den_list =
sum_dlog
.a
.denominator
.coefficients
.iter()
.map(|c| c.representative().to_string()),
);
let b_num_list = PyList::new_bound(
py,
.map(|c| c.representative().to_string())
.collect();
let b_num_list =
sum_dlog
.b
.numerator
.coefficients
.iter()
.map(|c| c.representative().to_string()),
);
let b_den_list = PyList::new_bound(
py,
.map(|c| c.representative().to_string())
.collect();
let b_den_list =
sum_dlog
.b
.denominator
.coefficients
.iter()
.map(|c| c.representative().to_string()),
);

let result_tuple = PyList::new_bound(
py,
[q_tuple, a_num_list, a_den_list, b_num_list, b_den_list],
);
.map(|c| c.representative().to_string())
.collect();

Ok(result_tuple.into())
[q_tuple, a_num_list, a_den_list, b_num_list, b_den_list]
}

fn line<F: IsPrimeField + CurveParamsProvider<F>>(p: G1Point<F>, q: G1Point<F>) -> FF<F> {
Expand Down
30 changes: 28 additions & 2 deletions tools/garaga_rs/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@ use pyo3::{
{prelude::*, wrap_pyfunction},
};

use crate::ecip::core::zk_ecip_hint;

#[pymodule]
fn garaga_rs(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_function(wrap_pyfunction!(g2_add, m)?)?;
Expand All @@ -31,6 +29,34 @@ fn garaga_rs(m: &Bound<'_, PyModule>) -> PyResult<()> {
Ok(())
}

#[pyfunction]
pub fn zk_ecip_hint(
py: Python,
py_list_1: &Bound<'_, PyList>,
py_list_2: &Bound<'_, PyList>,
curve_id: usize,
) -> PyResult<PyObject> {
let list_values = py_list_1
.into_iter()
.map(|x| x.extract())
.collect::<Result<Vec<BigUint>, _>>()?;

let list_scalars = py_list_2
.into_iter()
.map(|x| x.extract())
.collect::<Result<Vec<BigUint>, _>>()?;

let v = ecip::core::zk_ecip_hint(list_values, list_scalars, curve_id)
.map_err(|e| PyErr::new::<pyo3::exceptions::PyValueError, _>(e))?;

let py_list = PyList::new_bound(
py,
v.into_iter().map(|x| PyList::new_bound(py, x)),
);

Ok(py_list.into())
}

const CURVE_BN254: usize = 0;
const CURVE_BLS12_381: usize = 1;

Expand Down

0 comments on commit 80c579e

Please sign in to comment.