diff --git a/src/elligator.rs b/src/elligator.rs index 8ee4771..c4f1f58 100644 --- a/src/elligator.rs +++ b/src/elligator.rs @@ -5,7 +5,8 @@ use crate::element::{Decaf377EdwardsConfig, EdwardsProjective}; use crate::{ constants::{ONE, TWO, ZETA}, - Element, Fq, OnCurve, Sign, SqrtRatioZeta, + sign::Sign, + Element, Fq, OnCurve, }; impl Element { @@ -68,20 +69,10 @@ impl Element { &R_1 + &R_2 } - #[deprecated(note = "please use `hash_to_curve` instead")] - pub fn map_to_group_uniform(r_1: &Fq, r_2: &Fq) -> Element { - Element::hash_to_curve(r_1, r_2) - } - /// Maps a field element to a decaf377 `Element` suitable for CDH challenges. pub fn encode_to_curve(r: &Fq) -> Element { Element::elligator_map(r) } - - #[deprecated(note = "please use `encode_to_curve` instead")] - pub fn map_to_group_cdh(r: &Fq) -> Element { - Element::encode_to_curve(r) - } } #[cfg(test)] diff --git a/src/encoding.rs b/src/encoding.rs index dcef08a..925934f 100644 --- a/src/encoding.rs +++ b/src/encoding.rs @@ -6,8 +6,8 @@ use ark_ec::twisted_edwards::TECurveConfig; use ark_serialize::{CanonicalDeserialize, CanonicalSerialize, Read, Write}; use crate::{ - constants::TWO, element::Decaf377EdwardsConfig, EdwardsProjective, Element, EncodingError, Fq, - OnCurve, Sign, SqrtRatioZeta, + constants::TWO, element::Decaf377EdwardsConfig, sign::Sign, EdwardsProjective, Element, + EncodingError, Fq, OnCurve, }; #[derive(Copy, Clone, Default, Eq, Ord, PartialOrd, PartialEq)] @@ -87,11 +87,6 @@ impl Element { Element { inner: -self.inner } } - #[deprecated(note = "please use `vartime_compress_to_field` instead")] - pub fn compress_to_field(&self) -> Fq { - self.vartime_compress_to_field() - } - pub fn vartime_compress_to_field(&self) -> Fq { // This isn't a constant, only because traits don't have const methods // yet and subtraction is only implemented as part of the Sub trait. @@ -118,11 +113,6 @@ impl Element { s } - #[deprecated(note = "please use `vartime_compress` instead")] - pub fn compress(&self) -> Encoding { - self.vartime_compress() - } - pub fn vartime_compress(&self) -> Encoding { let s = self.vartime_compress_to_field(); diff --git a/src/fields/fq/u32/wrapper.rs b/src/fields/fq/u32/wrapper.rs index 9534264..40fe9f5 100644 --- a/src/fields/fq/u32/wrapper.rs +++ b/src/fields/fq/u32/wrapper.rs @@ -1,4 +1,4 @@ -use subtle::ConditionallySelectable; +use subtle::{Choice, ConditionallySelectable, ConstantTimeEq}; use super::fiat; @@ -210,7 +210,7 @@ impl Fq { } impl ConditionallySelectable for Fq { - fn conditional_select(a: &Self, b: &Self, choice: subtle::Choice) -> Self { + fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self { let mut out = [0u32; 8]; for i in 0..8 { out[i] = u32::conditional_select(&a.0 .0[i], &b.0 .0[i], choice); @@ -218,3 +218,9 @@ impl ConditionallySelectable for Fq { Self(fiat::FqMontgomeryDomainFieldElement(out)) } } + +impl ConstantTimeEq for Fq { + fn ct_eq(&self, other: &Fq) -> Choice { + self.0 .0.ct_eq(&other.0 .0) + } +} diff --git a/src/fields/fq/u64/arkworks_constants.rs b/src/fields/fq/u64/arkworks_constants.rs index b5784bd..6a139be 100644 --- a/src/fields/fq/u64/arkworks_constants.rs +++ b/src/fields/fq/u64/arkworks_constants.rs @@ -30,6 +30,7 @@ pub const TRACE_MINUS_ONE_DIV_TWO_LIMBS: [u64; 4] = [ 4779, ]; +// c1 pub const TWO_ADICITY: u32 = 0x2f; pub const QUADRATIC_NON_RESIDUE_TO_TRACE: Fq = Fq::from_montgomery_limbs([ diff --git a/src/fields/fq/u64/wrapper.rs b/src/fields/fq/u64/wrapper.rs index c46cf68..a439720 100644 --- a/src/fields/fq/u64/wrapper.rs +++ b/src/fields/fq/u64/wrapper.rs @@ -1,4 +1,4 @@ -use subtle::ConditionallySelectable; +use subtle::{Choice, ConditionallySelectable, ConstantTimeEq}; use super::fiat; @@ -198,3 +198,9 @@ impl ConditionallySelectable for Fq { Self(fiat::FqMontgomeryDomainFieldElement(out)) } } + +impl ConstantTimeEq for Fq { + fn ct_eq(&self, other: &Fq) -> Choice { + self.0 .0.ct_eq(&other.0 .0) + } +} diff --git a/src/invsqrt.rs b/src/invsqrt.rs index 3ce70fb..dc1af38 100644 --- a/src/invsqrt.rs +++ b/src/invsqrt.rs @@ -9,16 +9,6 @@ use once_cell::sync::Lazy; use crate::constants::{G, M_MINUS_ONE_DIV_TWO, N, ONE, SQRT_W, ZETA_TO_ONE_MINUS_M_DIV_TWO}; -pub trait SqrtRatioZeta: Sized { - /// Computes the square root of a ratio of field elements, returning: - /// - /// - `(true, sqrt(num/den))` if `num` and `den` are both nonzero and `num/den` is square; - /// - `(true, 0)` if `num` is zero; - /// - `(false, 0)` if `den` is zero; - /// - `(false, sqrt(zeta*num/den))` if `num` and `den` are both nonzero and `num/den` is nonsquare; - fn sqrt_ratio_zeta(num: &Self, den: &Self) -> (bool, Self); -} - struct SquareRootTables { pub s_lookup: HashMap, pub nonsquare_lookup: [Fq; 2], @@ -73,8 +63,14 @@ impl SquareRootTables { static SQRT_LOOKUP_TABLES: Lazy = Lazy::new(|| SquareRootTables::new()); -impl SqrtRatioZeta for Fq { - fn sqrt_ratio_zeta(num: &Self, den: &Self) -> (bool, Self) { +impl Fq { + /// Computes the square root of a ratio of field elements, returning: + /// + /// - `(true, sqrt(num/den))` if `num` and `den` are both nonzero and `num/den` is square; + /// - `(true, 0)` if `num` is zero; + /// - `(false, 0)` if `den` is zero; + /// - `(false, sqrt(zeta*num/den))` if `num` and `den` are both nonzero and `num/den` is nonsquare; + pub fn sqrt_ratio_zeta(num: &Self, den: &Self) -> (bool, Self) { // This square root method is based on: // * [Sarkar2020](https://eprint.iacr.org/2020/1407) // * [Zcash Pasta](https://github.com/zcash/pasta_curves) diff --git a/src/lib.rs b/src/lib.rs index 5c03b3f..70db8b5 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -8,6 +8,11 @@ pub mod fields; pub mod smol_curve; pub use fields::{fp::Fp, fq::Fq, fr::Fr}; +mod sign; + +mod on_curve; +use on_curve::OnCurve; + cfg_if! { if #[cfg(feature = "arkworks")] { pub mod bls12_377; @@ -18,11 +23,9 @@ cfg_if! { mod error; mod field_ext; mod invsqrt; - mod on_curve; mod ops; pub mod rand; pub mod serialize; - mod sign; pub use constants::ZETA; pub use element::{AffineElement, Element}; @@ -36,9 +39,6 @@ cfg_if! { pub use bls12_377::Bls12_377; - pub use invsqrt::SqrtRatioZeta; - use on_curve::OnCurve; - use sign::Sign; /// Return the conventional generator for `decaf377`. pub fn basepoint() -> Element { diff --git a/src/on_curve.rs b/src/on_curve.rs index c37f76d..560489f 100644 --- a/src/on_curve.rs +++ b/src/on_curve.rs @@ -1,16 +1,22 @@ -use ark_ec::{ - models::{twisted_edwards::Projective, twisted_edwards::TECurveConfig}, - Group, -}; -use ark_ff::{BigInteger, Field, PrimeField, Zero}; -use ark_serialize::CanonicalSerialize; +use cfg_if::cfg_if; -use crate::constants; +cfg_if! { + if #[cfg(feature = "arkworks")] { + use ark_ec::{ + models::{twisted_edwards::Projective, twisted_edwards::TECurveConfig}, + Group, + }; + use ark_ff::{BigInteger, Field, PrimeField, Zero}; + use ark_serialize::CanonicalSerialize; + use crate::constants; + } +} pub trait OnCurve { fn is_on_curve(&self) -> bool; } +#[cfg(feature = "arkworks")] impl OnCurve for Projective

{ #[allow(non_snake_case)] fn is_on_curve(&self) -> bool { diff --git a/src/r1cs/fqvar_ext.rs b/src/r1cs/fqvar_ext.rs index 21cfc1a..f2bf009 100644 --- a/src/r1cs/fqvar_ext.rs +++ b/src/r1cs/fqvar_ext.rs @@ -4,7 +4,7 @@ use ark_r1cs_std::select::CondSelectGadget; use ark_r1cs_std::{R1CSVar, ToBitsGadget}; use ark_relations::r1cs::SynthesisError; -use crate::{constants::ZETA, r1cs::FqVar, Fq, SqrtRatioZeta}; +use crate::{constants::ZETA, r1cs::FqVar, Fq}; pub trait FqVarExtension: Sized { fn isqrt(&self) -> Result<(Boolean, FqVar), SynthesisError>; diff --git a/src/sign.rs b/src/sign.rs index ff801e2..b120586 100644 --- a/src/sign.rs +++ b/src/sign.rs @@ -18,10 +18,6 @@ pub trait Sign: core::ops::Neg + Sized { impl Sign for Fq { fn is_nonnegative(&self) -> bool { - use ark_serialize::CanonicalSerialize; - let mut bytes = [0u8; 32]; - self.serialize_compressed(&mut bytes[..]) - .expect("serialization into array should be infallible"); - bytes[0] & 1 == 0 + (self.to_le_limbs()[0] & 1) == 0 } } diff --git a/src/smol_curve/constants.rs b/src/smol_curve/constants.rs new file mode 100644 index 0000000..2193b74 --- /dev/null +++ b/src/smol_curve/constants.rs @@ -0,0 +1,39 @@ +use crate::Fq; + +pub const ZETA: Fq = Fq::from_montgomery_limbs_64([ + 5947794125541564500, + 11292571455564096885, + 11814268415718120036, + 155746270000486182, +]); + +pub const ZETA_TO_TRACE: Fq = Fq::from_montgomery_limbs_64([ + 6282505393754313363, + 14378628227555923904, + 9804873068900332207, + 302335131180501866, +]); + +/// COEFF_A = -1 +pub const COEFF_A: Fq = Fq::from_montgomery_limbs_64([ + 10157024534604021774, + 16668528035959406606, + 5322190058819395602, + 387181115924875961, +]); + +/// COEFF_D = 3021 +pub const COEFF_D: Fq = Fq::from_montgomery_limbs_64([ + 15008245758212136496, + 17341409599856531410, + 648869460136961410, + 719771289660577536, +]); + +/// -2 COEFF_D / COEFF_A = 6042 +pub const COEFF_K: Fq = Fq::from_montgomery_limbs_64([ + 10844245690243005535, + 9774967673803681700, + 12776203677742963460, + 94262208632981673, +]); diff --git a/src/smol_curve/element.rs b/src/smol_curve/element.rs index e195ce4..e5fd18d 100644 --- a/src/smol_curve/element.rs +++ b/src/smol_curve/element.rs @@ -1,31 +1,12 @@ use core::ops::{Add, Neg}; use subtle::{Choice, ConditionallySelectable}; -use crate::Fq; - -/// COEFF_A = -1 -const COEFF_A: Fq = Fq::from_montgomery_limbs_64([ - 10157024534604021774, - 16668528035959406606, - 5322190058819395602, - 387181115924875961, -]); - -/// COEFF_D = 3021 -const COEFF_D: Fq = Fq::from_montgomery_limbs_64([ - 15008245758212136496, - 17341409599856531410, - 648869460136961410, - 719771289660577536, -]); - -/// -2 COEFF_D / COEFF_A = 6042 -const COEFF_K: Fq = Fq::from_montgomery_limbs_64([ - 10844245690243005535, - 9774967673803681700, - 12776203677742963460, - 94262208632981673, -]); +use crate::{sign::Sign, smol_curve::constants::*, smol_curve::encoding::Encoding, Fq}; + +/// Error type for decompression +pub enum EncodingError { + InvalidEncoding, +} /// A point on an Edwards curve. /// @@ -103,6 +84,35 @@ impl Element { ]), }; + /// Construct a new element from the projective coordinates, checking on curve. + fn new_checked(x: Fq, y: Fq, z: Fq, t: Fq) -> Option { + let XX = x.square(); + let YY = y.square(); + let ZZ = z.square(); + let TT = t.square(); + + let on_curve = (YY + COEFF_A * XX) == (ZZ + COEFF_D * TT); + if on_curve { + Some(Self { x, y, z, t }) + } else { + None + } + } + + fn from_affine(x: Fq, y: Fq) -> Self { + let z = Fq::one(); + let t = x * y; + Self::new(x, y, z, t) + } + + fn new(x: Fq, y: Fq, z: Fq, t: Fq) -> Self { + if cfg!(debug_assertions) { + Element::new_checked(x, y, z, t).expect("decompression should be on curve") + } else { + Element { x, y, z, t } + } + } + pub fn double(self) -> Self { // https://eprint.iacr.org/2008/522 Section 3.3 let a = self.x.square(); @@ -119,12 +129,7 @@ impl Element { let y3 = g * h; let t3 = e * h; let z3 = f * g; - Self { - x: x3, - y: y3, - z: z3, - t: t3, - } + Self::new(x3, y3, z3, t3) } fn scalar_mul_both(self, le_bits: &[u64]) -> Self { @@ -151,6 +156,134 @@ impl Element { pub fn scalar_mul(self, le_bits: &[u64]) -> Self { Self::scalar_mul_both::(self, le_bits) } + + pub fn vartime_compress_to_field(&self) -> Fq { + let A_MINUS_D = COEFF_A - COEFF_D; + + // 1. + let u_1 = (self.x + self.t) * (self.x - self.t); + + // 2. + let (_always_square, v) = + Fq::non_arkworks_sqrt_ratio_zeta(&Fq::one(), &(u_1 * A_MINUS_D * self.x.square())); + + // 3. + let u_2 = (v * u_1).abs(); + + // 4. + let u_3 = u_2 * self.z - self.t; + + // 5. + (A_MINUS_D * v * u_3 * self.x).abs() + } + + pub fn vartime_compress(&self) -> Encoding { + let s = self.vartime_compress_to_field(); + let bytes = s.to_bytes_le(); + Encoding(bytes) + } + + /// Elligator 2 map to decaf377 point + fn elligator_map(r_0: &Fq) -> Self { + // Ref: `Decaf_1_1_Point.elligator` (optimized) in `ristretto.sage` + const A: Fq = COEFF_A; + const D: Fq = COEFF_D; + + let r = ZETA * r_0.square(); + + let den = (D * r - (D - A)) * ((D - A) * r - D); + let num = (r + Fq::one()) * (A - (Fq::one() + Fq::one()) * D); + + let x = num * den; + let (iss, mut isri) = Fq::non_arkworks_sqrt_ratio_zeta(&Fq::one(), &x); + + let sgn; + let twiddle; + if iss { + sgn = Fq::one(); + twiddle = Fq::one(); + } else { + sgn = -(Fq::one()); + twiddle = *r_0; + } + + isri *= twiddle; + + let mut s = isri * num; + let t = -(sgn) * isri * s * (r - Fq::one()) * (A - (Fq::one() + Fq::one()) * D).square() + - Fq::one(); + + if s.is_negative() == iss { + s = -s + } + + // Convert point to extended projective (X : Y : Z : T) + let E = (Fq::one() + Fq::one()) * s; + let F = Fq::one() + A * s.square(); + let G = Fq::one() - A * s.square(); + let H = t; + + Self::new(E * H, F * G, F * H, E * G) + } + + /// Maps two field elements to a uniformly distributed decaf377 `Element`. + /// + /// The two field elements provided as inputs should be independently chosen. + pub fn hash_to_curve(r_1: &Fq, r_2: &Fq) -> Element { + let R_1 = Element::elligator_map(r_1); + let R_2 = Element::elligator_map(r_2); + &R_1 + &R_2 + } + + /// Maps a field element to a decaf377 `Element` suitable for CDH challenges. + pub fn encode_to_curve(r: &Fq) -> Element { + Element::elligator_map(r) + } +} + +impl Encoding { + pub fn vartime_decompress(&self) -> Result { + // Top three bits of last byte must be zero + if self.0[31] >> 5 != 0u8 { + return Err(EncodingError::InvalidEncoding); + } + + // 1/2. Reject unless s is canonically encoded and nonnegative. + // Check bytes correspond to valid field element (i.e. less than field modulus) + let s = Fq::from_bytes_checked(&self.0).ok_or(EncodingError::InvalidEncoding)?; + if s.is_negative() { + return Err(EncodingError::InvalidEncoding); + } + + // 3. u_1 <- 1 - s^2 + let ss = s.square(); + let u_1 = Fq::one() - ss; + + // 4. u_2 <- u_1^2 - 4d s^2 + let u_2 = u_1.square() - (Fq::from(4u32) * COEFF_D) * ss; + + // 5. sqrt + let (was_square, mut v) = + Fq::non_arkworks_sqrt_ratio_zeta(&Fq::one(), &(u_2 * u_1.square())); + if !was_square { + return Err(EncodingError::InvalidEncoding); + } + + // 6. sign check + let two_s_u_1 = (Fq::one() + Fq::one()) * s * u_1; + let check = two_s_u_1 * v; + if check.is_negative() { + v = -v; + } + + // 7. coordinates + let x = two_s_u_1 * v.square() * u_2; + let y = (Fq::one() + ss) * v * u_1; + let z = Fq::one(); + let t = x * y; + + Ok(Element::new(x, y, z, t)) + } } impl Add for Element { @@ -182,12 +315,7 @@ impl Add for Element { let y3 = g * h; let t3 = e * h; let z3 = f * g; - Self { - x: x3, - y: y3, - z: z3, - t: t3, - } + Self::new(x3, y3, z3, t3) } } @@ -296,4 +424,138 @@ mod proptests { assert_eq!(G * (a * b), (G * a) * b); } } + + proptest! { + #[test] + fn group_encoding_round_trip_if_successful(bytes: [u8; 32]) { + let bytes = Encoding(bytes); + + if let Ok(element) = bytes.vartime_decompress() { + let bytes2 = element.vartime_compress(); + assert_eq!(bytes, bytes2); + } + } + } + + #[test] + fn test_elligator() { + // These are the test cases from testElligatorDeterministic in ristretto.sage + let inputs = [ + [ + 221, 101, 215, 58, 170, 229, 36, 124, 172, 234, 94, 214, 186, 163, 242, 30, 65, + 123, 76, 74, 56, 60, 24, 213, 240, 137, 49, 189, 138, 39, 90, 6, + ], + [ + 23, 203, 214, 51, 26, 149, 7, 160, 228, 239, 208, 147, 124, 109, 75, 72, 64, 16, + 64, 215, 53, 185, 249, 168, 188, 49, 22, 194, 118, 7, 242, 16, + ], + [ + 177, 123, 90, 180, 115, 7, 108, 183, 161, 167, 24, 15, 248, 218, 206, 227, 76, 137, + 162, 187, 148, 174, 66, 44, 205, 1, 211, 91, 140, 50, 144, 1, + ], + [ + 204, 225, 121, 228, 145, 30, 86, 208, 132, 242, 203, 9, 153, 90, 195, 150, 215, 49, + 166, 70, 78, 68, 47, 98, 30, 130, 115, 139, 168, 242, 238, 8, + ], + [ + 59, 150, 40, 159, 229, 96, 201, 47, 170, 163, 9, 208, 205, 201, 112, 241, 179, 82, + 198, 79, 207, 160, 184, 245, 63, 189, 101, 115, 217, 228, 74, 13, + ], + [ + 74, 159, 227, 190, 73, 213, 131, 200, 50, 102, 249, 230, 48, 103, 85, 168, 239, + 149, 7, 164, 12, 42, 217, 177, 189, 97, 214, 98, 102, 73, 10, 16, + ], + [ + 183, 227, 227, 192, 119, 10, 155, 143, 64, 60, 249, 165, 240, 39, 31, 197, 159, + 121, 64, 82, 10, 1, 34, 35, 121, 34, 146, 69, 226, 196, 156, 14, + ], + [ + 61, 21, 56, 224, 11, 181, 71, 186, 238, 126, 234, 240, 14, 168, 75, 73, 251, 111, + 175, 85, 108, 9, 77, 2, 88, 249, 24, 235, 53, 96, 51, 15, + ], + ]; + + let expected_xy_coordinates = [ + [ + ark_ff::MontFp!( + "1267955849280145133999011095767946180059440909377398529682813961428156596086" + ), + ark_ff::MontFp!( + "5356565093348124788258444273601808083900527100008973995409157974880178412098" + ), + ], + [ + ark_ff::MontFp!( + "1502379126429822955521756759528876454108853047288874182661923263559139887582" + ), + ark_ff::MontFp!( + "7074060208122316523843780248565740332109149189893811936352820920606931717751" + ), + ], + [ + ark_ff::MontFp!( + "2943006201157313879823661217587757631000260143892726691725524748591717287835" + ), + ark_ff::MontFp!( + "4988568968545687084099497807398918406354768651099165603393269329811556860241" + ), + ], + [ + ark_ff::MontFp!( + "2893226299356126359042735859950249532894422276065676168505232431940642875576" + ), + ark_ff::MontFp!( + "5540423804567408742733533031617546054084724133604190833318816134173899774745" + ), + ], + [ + ark_ff::MontFp!( + "2950911977149336430054248283274523588551527495862004038190631992225597951816" + ), + ark_ff::MontFp!( + "4487595759841081228081250163499667279979722963517149877172642608282938805393" + ), + ], + [ + ark_ff::MontFp!( + "3318574188155535806336376903248065799756521242795466350457330678746659358665" + ), + ark_ff::MontFp!( + "7706453242502782485686954136003233626318476373744684895503194201695334921001" + ), + ], + [ + ark_ff::MontFp!( + "3753408652523927772367064460787503971543824818235418436841486337042861871179" + ), + ark_ff::MontFp!( + "2820605049615187268236268737743168629279853653807906481532750947771625104256" + ), + ], + [ + ark_ff::MontFp!( + "7803875556376973796629423752730968724982795310878526731231718944925551226171" + ), + ark_ff::MontFp!( + "7033839813997913565841973681083930410776455889380940679209912201081069572111" + ), + ], + ]; + + use ark_serialize::CanonicalDeserialize; + + for (ind, input) in inputs.iter().enumerate() { + let input_element = + Fq::deserialize_compressed(&input[..]).expect("encoding of test vector is valid"); + + let expected: Element = crate::smol_curve::element::Element::from_affine( + crate::constants::from_ark_fq(expected_xy_coordinates[ind][0]), + crate::constants::from_ark_fq(expected_xy_coordinates[ind][1]), + ); + + let actual = Element::elligator_map(&input_element); + + assert_eq!(actual, expected); + } + } } diff --git a/src/smol_curve/encoding.rs b/src/smol_curve/encoding.rs new file mode 100644 index 0000000..8a10cc2 --- /dev/null +++ b/src/smol_curve/encoding.rs @@ -0,0 +1,2 @@ +#[derive(Copy, Clone, Default, Eq, Ord, PartialOrd, PartialEq, Debug)] +pub struct Encoding(pub [u8; 32]); diff --git a/src/smol_curve/invsqrt.rs b/src/smol_curve/invsqrt.rs new file mode 100644 index 0000000..4887348 --- /dev/null +++ b/src/smol_curve/invsqrt.rs @@ -0,0 +1,158 @@ +use subtle::{ConditionallySelectable, ConstantTimeEq}; + +use crate::{fields::fq::arkworks_constants::*, Fq}; + +use crate::smol_curve::constants::ZETA; + +impl Fq { + /// For square elements, calculate their square root, otherwise return an undefined element. + /// + /// Based on https://www.ietf.org/archive/id/draft-irtf-cfrg-hash-to-curve-10.html#name-constant-time-tonelli-shanks + fn our_sqrt(&self) -> Self { + // Constants c1,...,c5 used for square root computation as defined in the above Appendix: + // c1 = TWO_ADICITY + // c2 is not directly used in the computation, it's used to compute c3 + // c3 = TRACE_MINUS_ONE_DIV_TWO_LIMBS; + // c4 is not directly used in the computation, but should match ZETA- + // c5 = c4 ^ c2 + // c5 = QUADRATIC_NON_RESIDUE_TO_TRACE + + // Step 1: z = x^c3 + let mut z = self.pow_le_limbs(&TRACE_MINUS_ONE_DIV_TWO_LIMBS); + + // Step 2: t = z * z * x + let mut t = z * z * self; + + // Step 3: z = z * x; + z = z * self; + + // Step 4: b = t + let mut b = t; + + // Step 5: c = c5 + let mut c = QUADRATIC_NON_RESIDUE_TO_TRACE; + + // Step 6: for i in (c1, c1 - 1, ..., 2): + for i in (2..=TWO_ADICITY).rev() { + // Step 7: for j in (1, 2, ..., i - 2): + for _j in 1..=i - 2 { + // Step 8: b = b * b + b = b * b; + } + + // Step 9: z = CMOV(z, z * c, b != 1) + z = Fq::conditional_select(&z, &(z * c), !b.ct_eq(&Self::one())); + + // Step 10: c = c * c + c = c * c; + + // Step 11: t = CMOV(t, t * c, b != 1) + t = Fq::conditional_select(&t, &(t * c), !b.ct_eq(&Self::one())); + + // Step 12: b = t + b = t; + } + + // Step 13: return z + z + } + + fn pow_le_limbs(&self, limbs: &[u64]) -> Self { + let mut acc = Self::one(); + let mut insert = *self; + for limb in limbs { + for i in 0..64 { + if (limb >> i) & 1 == 1 { + acc *= insert; + } + insert *= insert; + } + } + acc + } + + /// Computes the square root of a ratio of field elements, returning: + /// + /// - `(true, sqrt(num/den))` if `num` and `den` are both nonzero and `num/den` is square; + /// - `(true, 0)` if `num` is zero; + /// - `(false, 0)` if `den` is zero; + /// - `(false, sqrt(zeta*num/den))` if `num` and `den` are both nonzero and `num/den` is nonsquare; + pub fn non_arkworks_sqrt_ratio_zeta(num: &Self, den: &Self) -> (bool, Self) { + if num == &Fq::zero() { + return (true, *num); + } + if den == &Fq::zero() { + return (false, *den); + } + let x = *num / *den; + // Because num was not zero, this will only be 1 or -1 + let symbol = x.pow_le_limbs(&MODULUS_MINUS_ONE_DIV_TWO_LIMBS); + if symbol == Self::one() { + (true, x.our_sqrt()) + } else { + (false, (ZETA * x).our_sqrt()) + } + } +} + +#[cfg(all(test, feature = "arkworks"))] +mod tests { + use super::*; + use ark_ff::Field; + use ark_ff::PrimeField; + use proptest::prelude::*; + + fn fq_strategy() -> impl Strategy { + any::<[u8; 32]>() + .prop_map(|bytes| Fq::from_le_bytes_mod_order(&bytes[..])) + .boxed() + } + + proptest! { + #![proptest_config(ProptestConfig::with_cases(10000))] + #[test] + fn sqrt_ratio_zeta(u in fq_strategy(), v in fq_strategy()) { + if u == Fq::zero() { + assert_eq!(Fq::non_arkworks_sqrt_ratio_zeta(&u, &v), (true, u)); + } else if v == Fq::zero() { + assert_eq!(Fq::non_arkworks_sqrt_ratio_zeta(&u, &v), (false, v)); + } else { + let (was_square, sqrt_zeta_uv) = Fq::non_arkworks_sqrt_ratio_zeta(&u, &v); + let zeta_uv = sqrt_zeta_uv * sqrt_zeta_uv; + if was_square { + // check zeta_uv = u/v + assert_eq!(u, v * zeta_uv); + } else { + // check zeta_uv = zeta * u / v + assert_eq!(ZETA * u, v * zeta_uv); + } + } + } + } + + #[test] + fn sqrt_ratio_edge_cases() { + // u = 0 + assert_eq!( + Fq::non_arkworks_sqrt_ratio_zeta(&Fq::zero(), &Fq::one()), + (true, Fq::zero()) + ); + + // v = 0 + assert_eq!( + Fq::non_arkworks_sqrt_ratio_zeta(&Fq::one(), &Fq::zero()), + (false, Fq::zero()) + ); + } + + proptest! { + #[test] + fn sqrt_matches_arkworks(x in fq_strategy()) { + let arkworks_sqrt = x.sqrt(); + let our_sqrt = x.our_sqrt(); + if arkworks_sqrt.is_some() { + assert_eq!(arkworks_sqrt.unwrap(), our_sqrt); + } + } + } +} diff --git a/src/smol_curve/mod.rs b/src/smol_curve/mod.rs index f7ef0f4..a59fc74 100644 --- a/src/smol_curve/mod.rs +++ b/src/smol_curve/mod.rs @@ -1,2 +1,5 @@ +mod constants; pub mod element; +pub mod encoding; +mod invsqrt; mod ops;