From 68d65e0f33a247595528c49462503e834369a70e Mon Sep 17 00:00:00 2001 From: Constance Beguier Date: Sun, 7 Jul 2024 16:56:20 +0200 Subject: [PATCH] Refactor --- halo2_gadgets/src/ecc/chip/mul_fixed/short.rs | 3 +- halo2_gadgets/src/sinsemilla.rs | 2 +- .../src/sinsemilla/chip/hash_to_point.rs | 156 +++++++++--------- halo2_gadgets/src/sinsemilla/merkle.rs | 9 +- .../src/utilities/lookup_range_check.rs | 57 +++++-- 5 files changed, 124 insertions(+), 103 deletions(-) diff --git a/halo2_gadgets/src/ecc/chip/mul_fixed/short.rs b/halo2_gadgets/src/ecc/chip/mul_fixed/short.rs index 5f801fdf3..d5ea2f07e 100644 --- a/halo2_gadgets/src/ecc/chip/mul_fixed/short.rs +++ b/halo2_gadgets/src/ecc/chip/mul_fixed/short.rs @@ -588,7 +588,7 @@ pub mod tests { fn test_invalid_magnitude_sign() { // Magnitude larger than 64 bits should fail { - let circuits: Vec> = vec![ + let circuits = [ // 2^64 MyMagnitudeSignCircuit:: { magnitude: Value::known(pallas::Base::from_u128(1 << 64)), @@ -742,7 +742,6 @@ pub mod tests { } } - #[cfg(feature = "test-dev-graph")] #[test] fn invalid_magnitude_sign() { MyMagnitudeSignCircuit::::test_invalid_magnitude_sign(); diff --git a/halo2_gadgets/src/sinsemilla.rs b/halo2_gadgets/src/sinsemilla.rs index d6dd82e4a..6f56e35da 100644 --- a/halo2_gadgets/src/sinsemilla.rs +++ b/halo2_gadgets/src/sinsemilla.rs @@ -835,7 +835,7 @@ pub(crate) mod tests { } #[test] - fn test_against_stored_sinsemilla_chip() { + fn test_sinsemilla_chip_against_stored_circuit() { let circuit: MyCircuit = MyCircuit { _lookup_marker: PhantomData, }; diff --git a/halo2_gadgets/src/sinsemilla/chip/hash_to_point.rs b/halo2_gadgets/src/sinsemilla/chip/hash_to_point.rs index e6ce0e9d4..ac81a5e48 100644 --- a/halo2_gadgets/src/sinsemilla/chip/hash_to_point.rs +++ b/halo2_gadgets/src/sinsemilla/chip/hash_to_point.rs @@ -50,7 +50,7 @@ where ), Error, > { - let (offset, x_a, y_a) = self.public_initialization(region, Q)?; + let (offset, x_a, y_a) = self.public_q_initialization(region, Q)?; let (x_a, y_a, zs_sum) = self.hash_all_pieces(region, offset, message, x_a, y_a)?; @@ -80,87 +80,13 @@ where return Err(Error::HashFromPrivatePoint); } - let (offset, x_a, y_a) = self.private_initialization(region, Q)?; + let (offset, x_a, y_a) = self.private_q_initialization(region, Q)?; let (x_a, y_a, zs_sum) = self.hash_all_pieces(region, offset, message, x_a, y_a)?; self.check_hash_result(EccPointQ::PrivatePoint(Q), message, x_a, y_a, zs_sum) } - #[allow(unused_variables)] - #[allow(non_snake_case)] - #[allow(clippy::type_complexity)] - fn check_hash_result( - &self, - Q: EccPointQ, - message: &>::Message, - x_a: X, - y_a: AssignedCell, pallas::Base>, - zs_sum: Vec>>, - ) -> Result< - ( - NonIdentityEccPoint, - Vec>>, - ), - Error, - > { - #[cfg(test)] - // Check equivalence to result from primitives::sinsemilla::hash_to_point - { - use crate::sinsemilla::primitives::{K, S_PERSONALIZATION}; - - use group::{prime::PrimeCurveAffine, Curve}; - use pasta_curves::arithmetic::CurveExt; - - let field_elems: Value> = message - .iter() - .map(|piece| piece.field_elem().map(|elem| (elem, piece.num_words()))) - .collect(); - - let value_Q = match Q { - EccPointQ::PublicPoint(p) => Value::known(p), - EccPointQ::PrivatePoint(p) => p.point(), - }; - - field_elems - .zip(x_a.value().zip(y_a.value())) - .zip(value_Q) - .assert_if_known(|((field_elems, (x_a, y_a)), value_Q)| { - // Get message as a bitstring. - let bitstring: Vec = field_elems - .iter() - .flat_map(|(elem, num_words)| { - elem.to_le_bits().into_iter().take(K * num_words) - }) - .collect(); - - let hasher_S = pallas::Point::hash_to_curve(S_PERSONALIZATION); - let S = |chunk: &[bool]| hasher_S(&lebs2ip_k(chunk).to_le_bytes()); - - // We can use complete addition here because it differs from - // incomplete addition with negligible probability. - let expected_point = bitstring - .chunks(K) - .fold(value_Q.to_curve(), |acc, chunk| (acc + S(chunk)) + acc); - let actual_point = - pallas::Affine::from_xy(x_a.evaluate(), y_a.evaluate()).unwrap(); - expected_point.to_affine() == actual_point - }); - } - - x_a.value() - .zip(y_a.value()) - .error_if_known_and(|(x_a, y_a)| x_a.is_zero_vartime() || y_a.is_zero_vartime())?; - Ok(( - NonIdentityEccPoint::from_coordinates_unchecked(x_a.0, y_a), - zs_sum, - )) - } - #[allow(non_snake_case)] /// Assign the coordinates of the initial public point `Q`. /// @@ -174,7 +100,7 @@ where /// -------------------------------------- /// | 0 | | y_Q | | /// | 1 | x_Q | | 1 | - fn public_initialization( + fn public_q_initialization( &self, region: &mut Region<'_, pallas::Base>, Q: pallas::Affine, @@ -235,7 +161,7 @@ where /// -------------------------------------- /// | 0 | | y_Q | | /// | 1 | x_Q | | 1 | - fn private_initialization( + fn private_q_initialization( &self, region: &mut Region<'_, pallas::Base>, Q: &NonIdentityEccPoint, @@ -545,6 +471,80 @@ where Ok((x_a, y_a, zs)) } + + #[allow(unused_variables)] + #[allow(non_snake_case)] + #[allow(clippy::type_complexity)] + fn check_hash_result( + &self, + Q: EccPointQ, + message: &>::Message, + x_a: X, + y_a: AssignedCell, pallas::Base>, + zs_sum: Vec>>, + ) -> Result< + ( + NonIdentityEccPoint, + Vec>>, + ), + Error, + > { + #[cfg(test)] + // Check equivalence to result from primitives::sinsemilla::hash_to_point + { + use crate::sinsemilla::primitives::{K, S_PERSONALIZATION}; + + use group::{prime::PrimeCurveAffine, Curve}; + use pasta_curves::arithmetic::CurveExt; + + let field_elems: Value> = message + .iter() + .map(|piece| piece.field_elem().map(|elem| (elem, piece.num_words()))) + .collect(); + + let value_Q = match Q { + EccPointQ::PublicPoint(p) => Value::known(p), + EccPointQ::PrivatePoint(p) => p.point(), + }; + + field_elems + .zip(x_a.value().zip(y_a.value())) + .zip(value_Q) + .assert_if_known(|((field_elems, (x_a, y_a)), value_Q)| { + // Get message as a bitstring. + let bitstring: Vec = field_elems + .iter() + .flat_map(|(elem, num_words)| { + elem.to_le_bits().into_iter().take(K * num_words) + }) + .collect(); + + let hasher_S = pallas::Point::hash_to_curve(S_PERSONALIZATION); + let S = |chunk: &[bool]| hasher_S(&lebs2ip_k(chunk).to_le_bytes()); + + // We can use complete addition here because it differs from + // incomplete addition with negligible probability. + let expected_point = bitstring + .chunks(K) + .fold(value_Q.to_curve(), |acc, chunk| (acc + S(chunk)) + acc); + let actual_point = + pallas::Affine::from_xy(x_a.evaluate(), y_a.evaluate()).unwrap(); + expected_point.to_affine() == actual_point + }); + } + + x_a.value() + .zip(y_a.value()) + .error_if_known_and(|(x_a, y_a)| x_a.is_zero_vartime() || y_a.is_zero_vartime())?; + Ok(( + NonIdentityEccPoint::from_coordinates_unchecked(x_a.0, y_a), + zs_sum, + )) + } } /// The x-coordinate of the accumulator in a Sinsemilla hash instance. diff --git a/halo2_gadgets/src/sinsemilla/merkle.rs b/halo2_gadgets/src/sinsemilla/merkle.rs index 27b00cb97..adf1f02b3 100644 --- a/halo2_gadgets/src/sinsemilla/merkle.rs +++ b/halo2_gadgets/src/sinsemilla/merkle.rs @@ -187,7 +187,10 @@ pub mod tests { tests::test_utils::test_against_stored_circuit, utilities::{ i2lebsp, - lookup_range_check::{PallasLookupRangeCheck45BConfig, PallasLookupRangeCheckConfig}, + lookup_range_check::{ + PallasLookupRangeCheck, PallasLookupRangeCheck45BConfig, + PallasLookupRangeCheckConfig, + }, UtilitiesInstructions, }, }; @@ -200,10 +203,8 @@ pub mod tests { plonk::{Circuit, ConstraintSystem, Error}, }; - use crate::utilities::lookup_range_check::PallasLookupRangeCheck; use rand::{rngs::OsRng, RngCore}; - use std::marker::PhantomData; - use std::{convert::TryInto, iter}; + use std::{convert::TryInto, iter, marker::PhantomData}; const MERKLE_DEPTH: usize = 32; diff --git a/halo2_gadgets/src/utilities/lookup_range_check.rs b/halo2_gadgets/src/utilities/lookup_range_check.rs index a7dde6f2e..83ed67ee3 100644 --- a/halo2_gadgets/src/utilities/lookup_range_check.rs +++ b/halo2_gadgets/src/utilities/lookup_range_check.rs @@ -34,8 +34,8 @@ impl RangeConstrained> { /// # Panics /// /// Panics if `bitrange.len() >= K`. - pub fn witness_short>( - lookup_config: &L, + pub fn witness_short>( + lookup_config: &Lookup, layouter: impl Layouter, value: Value<&F>, bitrange: Range, @@ -817,6 +817,16 @@ mod tests { let prover = MockProver::::run(11, &circuit, vec![]).unwrap(); assert_eq!(prover.verify(), Ok(())); + } + + #[test] + fn test_lookup_range_check_against_stored_circuit() { + let circuit: MyLookupCircuit = + MyLookupCircuit { + num_words: 6, + _field_marker: PhantomData, + _lookup_marker: PhantomData, + }; test_against_stored_circuit(circuit, "lookup_range_check", 1888); } @@ -832,7 +842,16 @@ mod tests { let prover = MockProver::::run(11, &circuit, vec![]).unwrap(); assert_eq!(prover.verify(), Ok(())); + } + #[test] + fn test_lookup_range_check_against_stored_circuit_4_5_b() { + let circuit: MyLookupCircuit = + MyLookupCircuit { + num_words: 6, + _field_marker: PhantomData, + _lookup_marker: PhantomData, + }; test_against_stored_circuit(circuit, "lookup_range_check_4_5_b", 2048); } @@ -907,6 +926,9 @@ mod tests { #[test] fn short_range_check() { + let proof_size_10_bits = 1888; + let proof_size_4_5_10_bits = 2048; + // Edge case: zero bits (case 0) let element = pallas::Base::ZERO; let num_bits = 0; @@ -915,14 +937,14 @@ mod tests { num_bits, &Ok(()), "short_range_check_case0", - 1888, + proof_size_10_bits, ); test_short_range_check::( element, num_bits, &Ok(()), "short_range_check_4_5_b_case0", - 2048, + proof_size_4_5_10_bits, ); // Edge case: K bits (case 1) @@ -933,14 +955,14 @@ mod tests { num_bits, &Ok(()), "short_range_check_case1", - 1888, + proof_size_10_bits, ); test_short_range_check::( element, num_bits, &Ok(()), "short_range_check_4_5_b_case1", - 2048, + proof_size_4_5_10_bits, ); // Element within `num_bits` (case 2) @@ -951,14 +973,14 @@ mod tests { num_bits, &Ok(()), "short_range_check_case2", - 1888, + proof_size_10_bits, ); test_short_range_check::( element, num_bits, &Ok(()), "short_range_check_4_5_b_case2", - 2048, + proof_size_4_5_10_bits, ); // Element larger than `num_bits` but within K bits @@ -976,14 +998,14 @@ mod tests { num_bits, &error, "not_saved", - 0, + proof_size_10_bits, ); test_short_range_check::( element, num_bits, &error, "not_saved", - 0, + proof_size_4_5_10_bits, ); // Element larger than K bits @@ -1010,18 +1032,17 @@ mod tests { num_bits, &error, "not_saved", - 0, + proof_size_10_bits, ); test_short_range_check::( element, num_bits, &error, "not_saved", - 0, + proof_size_4_5_10_bits, ); - // Element which is not within `num_bits`, but which has a shifted value within - // num_bits + // Element which is not within `num_bits`, but which has a shifted value within num_bits let num_bits = 6; let shifted = pallas::Base::from((1 << num_bits) - 1); // Recall that shifted = element * 2^{K-s} @@ -1042,14 +1063,14 @@ mod tests { num_bits as usize, &error, "not_saved", - 0, + proof_size_10_bits, ); test_short_range_check::( element, num_bits as usize, &error, "not_saved", - 0, + proof_size_4_5_10_bits, ); // Element within 4 bits @@ -1058,7 +1079,7 @@ mod tests { 4, &Ok(()), "short_range_check_4_5_b_case3", - 2048, + proof_size_4_5_10_bits, ); // Element larger than 5 bits @@ -1073,7 +1094,7 @@ mod tests { }, }]), "not_saved", - 0, + proof_size_4_5_10_bits, ); } }