Skip to content

Commit

Permalink
Upgrades new Rust ECIP code to PyO3 0.22
Browse files Browse the repository at this point in the history
  • Loading branch information
raugfer committed Aug 19, 2024
1 parent b1857a8 commit 3fe17ed
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 46 deletions.
65 changes: 20 additions & 45 deletions tools/garaga_rs/src/ecip/core.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,21 +19,16 @@ use super::curve::CurveParamsProvider;
#[pyfunction]
pub fn zk_ecip_hint(
py: Python,
py_list_1: &PyList,
py_list_2: &PyList,
py_list_1: &Bound<'_, PyList>,
py_list_2: &Bound<'_, PyList>,
curve_id: usize,
) -> PyResult<PyObject> {
match curve_id {
0 => {
let list_bytes: Vec<&[u8]> = py_list_1
.into_iter()
.map(|x| x.extract())
.collect::<Result<Vec<&[u8]>, _>>()?;

let list_felts: Vec<FieldElement<BN254PrimeField>> = list_bytes
let list_felts: Vec<FieldElement<BN254PrimeField>> = py_list_1
.into_iter()
.map(|x| {
FieldElement::<BN254PrimeField>::from_bytes_be(x).map_err(|e| {
FieldElement::<BN254PrimeField>::from_bytes_be(x.extract()?).map_err(|e| {
PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
"Byte conversion error: {:?}",
e
Expand All @@ -51,15 +46,10 @@ pub fn zk_ecip_hint(
run_ecip::<BN254PrimeField>(py, points, scalars)
}
1 => {
let list_bytes: Vec<&[u8]> = py_list_1
.into_iter()
.map(|x| x.extract())
.collect::<Result<Vec<&[u8]>, _>>()?;

let list_felts: Vec<FieldElement<BLS12381PrimeField>> = list_bytes
let list_felts: Vec<FieldElement<BLS12381PrimeField>> = py_list_1
.into_iter()
.map(|x| {
FieldElement::<BLS12381PrimeField>::from_bytes_be(x).map_err(|e| {
FieldElement::<BLS12381PrimeField>::from_bytes_be(x.extract()?).map_err(|e| {
PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
"Byte conversion error: {:?}",
e
Expand All @@ -77,15 +67,10 @@ pub fn zk_ecip_hint(
run_ecip::<BLS12381PrimeField>(py, points, scalars)
}
2 => {
let list_bytes: Vec<&[u8]> = py_list_1
.into_iter()
.map(|x| x.extract())
.collect::<Result<Vec<&[u8]>, _>>()?;

let list_felts: Vec<FieldElement<SECP256K1PrimeField>> = list_bytes
let list_felts: Vec<FieldElement<SECP256K1PrimeField>> = py_list_1
.into_iter()
.map(|x| {
FieldElement::<SECP256K1PrimeField>::from_bytes_be(x).map_err(|e| {
FieldElement::<SECP256K1PrimeField>::from_bytes_be(x.extract()?).map_err(|e| {
PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
"Byte conversion error: {:?}",
e
Expand All @@ -103,15 +88,10 @@ pub fn zk_ecip_hint(
run_ecip::<SECP256K1PrimeField>(py, points, scalars)
}
3 => {
let list_bytes: Vec<&[u8]> = py_list_1
.into_iter()
.map(|x| x.extract())
.collect::<Result<Vec<&[u8]>, _>>()?;

let list_felts: Vec<FieldElement<SECP256R1PrimeField>> = list_bytes
let list_felts: Vec<FieldElement<SECP256R1PrimeField>> = py_list_1
.into_iter()
.map(|x| {
FieldElement::<SECP256R1PrimeField>::from_bytes_be(x).map_err(|e| {
FieldElement::<SECP256R1PrimeField>::from_bytes_be(x.extract()?).map_err(|e| {
PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
"Byte conversion error: {:?}",
e
Expand All @@ -129,15 +109,10 @@ pub fn zk_ecip_hint(
run_ecip::<SECP256R1PrimeField>(py, points, scalars)
}
4 => {
let list_bytes: Vec<&[u8]> = py_list_1
.into_iter()
.map(|x| x.extract())
.collect::<Result<Vec<&[u8]>, _>>()?;

let list_felts: Vec<FieldElement<X25519PrimeField>> = list_bytes
let list_felts: Vec<FieldElement<X25519PrimeField>> = py_list_1
.into_iter()
.map(|x| {
FieldElement::<X25519PrimeField>::from_bytes_be(x).map_err(|e| {
FieldElement::<X25519PrimeField>::from_bytes_be(x.extract()?).map_err(|e| {
PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
"Byte conversion error: {:?}",
e
Expand All @@ -161,12 +136,12 @@ pub fn zk_ecip_hint(
}

fn extract_scalars<F: IsPrimeField + CurveParamsProvider<F>>(
py_list: &PyList,
py_list: &Bound<'_, PyList>,
) -> Result<Vec<Vec<i8>>, PyErr> {
let mut dss_ = Vec::new();

for i in 0..py_list.len() {
let scalar_biguint: BigUint = py_list[i].extract()?;
let scalar_biguint: BigUint = py_list.get_item(i)?.extract()?;
let neg_3_digits = neg_3_base_le(scalar_biguint);
dss_.push(neg_3_digits);
}
Expand Down Expand Up @@ -246,15 +221,15 @@ where

// let sum_dlog = sum_dlog.simplify();

let q_tuple = PyList::new(
let q_tuple = PyList::new_bound(
py,
[
q.x.representative().to_string(),
q.y.representative().to_string(),
],
);

let a_num_list = PyList::new(
let a_num_list = PyList::new_bound(
py,
sum_dlog
.a
Expand All @@ -263,7 +238,7 @@ where
.iter()
.map(|c| c.representative().to_string()),
);
let a_den_list = PyList::new(
let a_den_list = PyList::new_bound(
py,
sum_dlog
.a
Expand All @@ -272,7 +247,7 @@ where
.iter()
.map(|c| c.representative().to_string()),
);
let b_num_list = PyList::new(
let b_num_list = PyList::new_bound(
py,
sum_dlog
.b
Expand All @@ -281,7 +256,7 @@ where
.iter()
.map(|c| c.representative().to_string()),
);
let b_den_list = PyList::new(
let b_den_list = PyList::new_bound(
py,
sum_dlog
.b
Expand All @@ -291,7 +266,7 @@ where
.map(|c| c.representative().to_string()),
);

let result_tuple = PyList::new(
let result_tuple = PyList::new_bound(
py,
[q_tuple, a_num_list, a_den_list, b_num_list, b_den_list],
);
Expand Down
2 changes: 1 addition & 1 deletion tools/garaga_rs/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use pyo3::{
{prelude::*, wrap_pyfunction},
};

use crate::ecip::core::__pyo3_get_function_zk_ecip_hint;
use crate::ecip::core::zk_ecip_hint;

#[pymodule]
fn garaga_rs(m: &Bound<'_, PyModule>) -> PyResult<()> {
Expand Down

0 comments on commit 3fe17ed

Please sign in to comment.