diff --git a/hydra/hints/multi_miller_witness.py b/hydra/hints/multi_miller_witness.py index aba4cd94..7fb9c379 100644 --- a/hydra/hints/multi_miller_witness.py +++ b/hydra/hints/multi_miller_witness.py @@ -1,15 +1,32 @@ import math +from hydra.algebra import PyFelt from hydra.definitions import CURVES, CurveID, G1Point, G2Point from hydra.hints.bls import get_root_and_scaling_factor_bls from hydra.hints.tower_backup import E12 from tools.gnark_cli import GnarkCLI - +import garaga_rs def get_final_exp_witness(curve_id: int, f: E12) -> tuple[E12, E12]: """ Returns the witness for the final exponentiation step. """ + if curve_id != CurveID.BN254.value and curve_id != CurveID.BLS12_381.value: + raise ValueError(f"Curve ID {curve_id} not supported") + curve = CURVES[curve_id] + byte_size = (curve.p.bit_length() + 7) // 8 + input_data = [v.to_bytes(byte_size, "big") for v in f.value_coeffs] + output_data = garaga_rs.get_final_exp_witness( + curve_id, + input_data[0], input_data[1], input_data[2], + input_data[3], input_data[4], input_data[5], + input_data[6], input_data[7], input_data[8], + input_data[9], input_data[10], input_data[11], + ) + result = [int.from_bytes(v, "big") for v in output_data] + c = E12([PyFelt(v, curve.p) for v in result[:12]], curve_id) + wi = E12([PyFelt(v, curve.p) for v in result[12:]], curve_id) + if curve_id == CurveID.BN254.value: c, wi = find_c_e12(f, get_27th_bn254_root()) return c, wi diff --git a/tools/garaga_rs/src/lib.rs b/tools/garaga_rs/src/lib.rs index e4cf8b0c..2a1ffa64 100644 --- a/tools/garaga_rs/src/lib.rs +++ b/tools/garaga_rs/src/lib.rs @@ -12,10 +12,65 @@ use pyo3::{ #[pymodule] fn garaga_rs(_py: Python, m: &PyModule) -> PyResult<()> { + m.add_function(wrap_pyfunction!(get_final_exp_witness, m)?)?; m.add_function(wrap_pyfunction!(hades_permutation, m)?)?; Ok(()) } +#[pyfunction] +fn get_final_exp_witness( + py: Python, + curve_id: usize, + py_value_1: &PyBytes, + py_value_2: &PyBytes, + py_value_3: &PyBytes, + py_value_4: &PyBytes, + py_value_5: &PyBytes, + py_value_6: &PyBytes, + py_value_7: &PyBytes, + py_value_8: &PyBytes, + py_value_9: &PyBytes, + py_value_10: &PyBytes, + py_value_11: &PyBytes, + py_value_12: &PyBytes, +) -> PyResult { + let byte_slice_1: &[u8] = py_value_1.as_bytes(); + let byte_slice_2: &[u8] = py_value_2.as_bytes(); + let byte_slice_3: &[u8] = py_value_3.as_bytes(); + let byte_slice_4: &[u8] = py_value_4.as_bytes(); + let byte_slice_5: &[u8] = py_value_5.as_bytes(); + let byte_slice_6: &[u8] = py_value_6.as_bytes(); + let byte_slice_7: &[u8] = py_value_7.as_bytes(); + let byte_slice_8: &[u8] = py_value_8.as_bytes(); + let byte_slice_9: &[u8] = py_value_9.as_bytes(); + let byte_slice_10: &[u8] = py_value_10.as_bytes(); + let byte_slice_11: &[u8] = py_value_11.as_bytes(); + let byte_slice_12: &[u8] = py_value_12.as_bytes(); + + if curve_id == 0 { + // c, wi = find_c_e12(f, get_27th_bn254_root()) + } + else + if curve_id == 1 { + // c, wi = get_root_and_scaling_factor_bls(f) + } + else { + panic!("Curve ID {} not supported", curve_id); + } + + let py_tuple = PyTuple::new( + py, + &[ + byte_slice_1, byte_slice_2, byte_slice_3, byte_slice_4, byte_slice_5, byte_slice_6, + byte_slice_7, byte_slice_8, byte_slice_9, byte_slice_10, byte_slice_11, byte_slice_12, + byte_slice_1, byte_slice_2, byte_slice_3, byte_slice_4, byte_slice_5, byte_slice_6, + byte_slice_7, byte_slice_8, byte_slice_9, byte_slice_10, byte_slice_11, byte_slice_12, + ] + ); + + Ok(py_tuple.into()) +} + #[pyfunction] fn hades_permutation( py: Python,