Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

rsa bigint: Don't store CPU features in modulus/key #1856

Merged
merged 2 commits into from
Dec 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 16 additions & 19 deletions src/arithmetic/bigint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -719,8 +719,8 @@ mod tests {
|section, test_case| {
assert_eq!(section, "");

let m = consume_modulus::<M>(test_case, "M", cpu_features);
let m = m.modulus();
let m = consume_modulus::<M>(test_case, "M");
let m = m.modulus(cpu_features);
let expected_result = consume_elem(test_case, "ModExp", &m);
let base = consume_elem(test_case, "A", &m);
let e = {
Expand Down Expand Up @@ -749,8 +749,8 @@ mod tests {
|section, test_case| {
assert_eq!(section, "");

let m = consume_modulus::<M>(test_case, "M", cpu_features);
let m = m.modulus();
let m = consume_modulus::<M>(test_case, "M");
let m = m.modulus(cpu_features);
let expected_result = consume_elem(test_case, "ModMul", &m);
let a = consume_elem(test_case, "A", &m);
let b = consume_elem(test_case, "B", &m);
Expand All @@ -774,8 +774,8 @@ mod tests {
|section, test_case| {
assert_eq!(section, "");

let m = consume_modulus::<M>(test_case, "M", cpu_features);
let m = m.modulus();
let m = consume_modulus::<M>(test_case, "M");
let m = m.modulus(cpu_features);
let expected_result = consume_elem(test_case, "ModSquare", &m);
let a = consume_elem(test_case, "A", &m);

Expand All @@ -799,8 +799,8 @@ mod tests {

struct M {}

let m_ = consume_modulus::<M>(test_case, "M", cpu_features);
let m = m_.modulus();
let m_ = consume_modulus::<M>(test_case, "M");
let m = m_.modulus(cpu_features);
let expected_result = consume_elem(test_case, "R", &m);
let a =
consume_elem_unchecked::<M>(test_case, "A", expected_result.limbs.len() * 2);
Expand All @@ -826,12 +826,13 @@ mod tests {

struct M {}
struct O {}
let m = consume_modulus::<M>(test_case, "m", cpu_features);
let a = consume_elem_unchecked::<O>(test_case, "a", m.modulus().limbs().len());
let expected_result = consume_elem::<M>(test_case, "r", &m.modulus());
let other_modulus_len_bits = m.modulus().len_bits();
let m = consume_modulus::<M>(test_case, "m");
let m = m.modulus(cpu_features);
let a = consume_elem_unchecked::<O>(test_case, "a", m.limbs().len());
let expected_result = consume_elem::<M>(test_case, "r", &m);
let other_modulus_len_bits = m.len_bits();

let actual_result = elem_reduced_once(&a, &m.modulus(), other_modulus_len_bits);
let actual_result = elem_reduced_once(&a, &m, other_modulus_len_bits);
assert_elem_eq(&actual_result, &expected_result);

Ok(())
Expand Down Expand Up @@ -863,13 +864,9 @@ mod tests {
}
}

fn consume_modulus<M>(
test_case: &mut test::TestCase,
name: &str,
cpu_features: cpu::Features,
) -> OwnedModulus<M> {
fn consume_modulus<M>(test_case: &mut test::TestCase, name: &str) -> OwnedModulus<M> {
let value = test_case.consume_bytes(name);
OwnedModulus::from_be_bytes(untrusted::Input::from(&value), cpu_features).unwrap()
OwnedModulus::from_be_bytes(untrusted::Input::from(&value)).unwrap()
}

fn assert_elem_eq<M, E>(a: &Elem<M, E>, b: &Elem<M, E>) {
Expand Down
13 changes: 3 additions & 10 deletions src/arithmetic/bigint/modulus.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,6 @@ pub struct OwnedModulus<M> {
n0: N0,

len_bits: BitLength,

cpu_features: cpu::Features,
}

impl<M: PublicModulus> Clone for OwnedModulus<M> {
Expand All @@ -85,16 +83,12 @@ impl<M: PublicModulus> Clone for OwnedModulus<M> {
limbs: self.limbs.clone(),
n0: self.n0,
len_bits: self.len_bits,
cpu_features: self.cpu_features,
}
}
}

impl<M> OwnedModulus<M> {
pub(crate) fn from_be_bytes(
input: untrusted::Input,
cpu_features: cpu::Features,
) -> Result<Self, error::KeyRejected> {
pub(crate) fn from_be_bytes(input: untrusted::Input) -> Result<Self, error::KeyRejected> {
let n = BoxedLimbs::positive_minimal_width_from_be_bytes(input)?;
if n.len() > MODULUS_MAX_LIMBS {
return Err(error::KeyRejected::too_large());
Expand Down Expand Up @@ -135,7 +129,6 @@ impl<M> OwnedModulus<M> {
limbs: n,
n0,
len_bits,
cpu_features,
})
}

Expand All @@ -158,13 +151,13 @@ impl<M> OwnedModulus<M> {
encoding: PhantomData,
})
}
pub fn modulus(&self) -> Modulus<M> {
pub(crate) fn modulus(&self, cpu_features: cpu::Features) -> Modulus<M> {
Modulus {
limbs: &self.limbs,
n0: self.n0,
len_bits: self.len_bits,
m: PhantomData,
cpu_features: self.cpu_features,
cpu_features,
}
}

Expand Down
48 changes: 30 additions & 18 deletions src/rsa/keypair.rs
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ impl KeyPair {
)?;

let n_one = public_key.inner().n().oneRR();
let n = &public_key.inner().n().value();
let n = &public_key.inner().n().value(cpu_features);

// 6.4.1.4.3 says to skip 6.4.1.2.1 Step 2.

Expand Down Expand Up @@ -338,7 +338,7 @@ impl KeyPair {
// First, validate `2**half_n_bits < d`. Since 2**half_n_bits has a bit
// length of half_n_bits + 1, this check gives us 2**half_n_bits <= d,
// and knowing d is odd makes the inequality strict.
let d = bigint::OwnedModulus::<D>::from_be_bytes(d, cpu_features)
let d = bigint::OwnedModulus::<D>::from_be_bytes(d)
.map_err(|_| error::KeyRejected::invalid_component())?;
if !(n_bits.half_rounded_up() < d.len_bits()) {
return Err(KeyRejected::inconsistent_components());
Expand All @@ -350,7 +350,7 @@ impl KeyPair {

// Step 6.b is omitted as explained above.

let pm = &p.modulus.modulus();
let pm = &p.modulus.modulus(cpu_features);

// 6.4.1.4.3 - Step 7.

Expand All @@ -371,8 +371,8 @@ impl KeyPair {

// This should never fail since `n` and `e` were validated above.

let p = PrivateCrtPrime::new(p, dP)?;
let q = PrivateCrtPrime::new(q, dQ)?;
let p = PrivateCrtPrime::new(p, dP, cpu_features)?;
let q = PrivateCrtPrime::new(q, dQ, cpu_features)?;

Ok(Self {
p,
Expand Down Expand Up @@ -416,7 +416,7 @@ impl<M> PrivatePrime<M> {
n_bits: BitLength,
cpu_features: cpu::Features,
) -> Result<Self, KeyRejected> {
let p = bigint::OwnedModulus::from_be_bytes(p, cpu_features)?;
let p = bigint::OwnedModulus::from_be_bytes(p)?;

// 5.c / 5.g:
//
Expand All @@ -438,7 +438,7 @@ impl<M> PrivatePrime<M> {

// Steps 5.e and 5.f are omitted as explained above.

let oneRR = bigint::One::newRR(&p.modulus());
let oneRR = bigint::One::newRR(&p.modulus(cpu_features));

Ok(Self { modulus: p, oneRR })
}
Expand All @@ -453,8 +453,12 @@ struct PrivateCrtPrime<M> {
impl<M> PrivateCrtPrime<M> {
/// Constructs a `PrivateCrtPrime` from the private prime `p` and `dP` where
/// dP == d % (p - 1).
fn new(p: PrivatePrime<M>, dP: untrusted::Input) -> Result<Self, KeyRejected> {
let m = &p.modulus.modulus();
fn new(
p: PrivatePrime<M>,
dP: untrusted::Input,
cpu_features: cpu::Features,
) -> Result<Self, KeyRejected> {
let m = &p.modulus.modulus(cpu_features);
// [NIST SP-800-56B rev. 1] 6.4.1.4.3 - Steps 7.a & 7.b.
let dP = bigint::PrivateExponent::from_be_bytes_padded(dP, m)
.map_err(|error::Unspecified| KeyRejected::inconsistent_components())?;
Expand Down Expand Up @@ -482,8 +486,9 @@ fn elem_exp_consttime<M>(
c: &bigint::Elem<N>,
p: &PrivateCrtPrime<M>,
other_prime_len_bits: BitLength,
cpu_features: cpu::Features,
) -> Result<bigint::Elem<M>, error::Unspecified> {
let m = &p.modulus.modulus();
let m = &p.modulus.modulus(cpu_features);
let c_mod_m = bigint::elem_reduced(c, m, other_prime_len_bits);
let c_mod_m = bigint::elem_mul(p.oneRRR.as_ref(), c_mod_m, m);
bigint::elem_exp_consttime(c_mod_m, &p.exponent, m)
Expand Down Expand Up @@ -523,6 +528,8 @@ impl KeyPair {
msg: &[u8],
signature: &mut [u8],
) -> Result<(), error::Unspecified> {
let cpu_features = cpu::features();

if signature.len() != self.public().modulus_len() {
return Err(error::Unspecified);
}
Expand All @@ -537,7 +544,7 @@ impl KeyPair {
// with Garner's algorithm.

// Steps 1 and 2.
let m = self.private_exponentiate(signature)?;
let m = self.private_exponentiate(signature, cpu_features)?;

// Step 3.
m.fill_be_bytes(signature);
Expand All @@ -552,13 +559,17 @@ impl KeyPair {
/// leaked that would endanger the private key.
///
/// Panics if `in_out` is not `self.public().modulus_len()`.
fn private_exponentiate(&self, base: &[u8]) -> Result<bigint::Elem<N>, error::Unspecified> {
fn private_exponentiate(
&self,
base: &[u8],
cpu_features: cpu::Features,
) -> Result<bigint::Elem<N>, error::Unspecified> {
assert_eq!(base.len(), self.public().modulus_len());

// RFC 8017 Section 5.1.2: RSADP, using the Chinese Remainder Theorem
// with Garner's algorithm.

let n = &self.public.inner().n().value();
let n = &self.public.inner().n().value(cpu_features);
let n_one = self.public.inner().n().oneRR();

// Step 1. The value zero is also rejected.
Expand All @@ -569,14 +580,14 @@ impl KeyPair {

// Step 2.b.i.
let q_bits = self.q.modulus.len_bits();
let m_1 = elem_exp_consttime(&c, &self.p, q_bits)?;
let m_2 = elem_exp_consttime(&c, &self.q, self.p.modulus.len_bits())?;
let m_1 = elem_exp_consttime(&c, &self.p, q_bits, cpu_features)?;
let m_2 = elem_exp_consttime(&c, &self.q, self.p.modulus.len_bits(), cpu_features)?;

// Step 2.b.ii isn't needed since there are only two primes.

// Step 2.b.iii.
let h = {
let p = &self.p.modulus.modulus();
let p = &self.p.modulus.modulus(cpu_features);
let m_2 = bigint::elem_reduced_once(&m_2, p, q_bits);
let m_1_minus_m_2 = bigint::elem_sub(m_1, &m_2, p);
bigint::elem_mul(&self.qInv, m_1_minus_m_2, p)
Expand Down Expand Up @@ -605,7 +616,7 @@ impl KeyPair {
// minimum value, since the relationship of `e` to `d`, `p`, and `q` is
// not verified during `KeyPair` construction.
{
let verify = self.public.inner().exponentiate_elem(&m);
let verify = self.public.inner().exponentiate_elem(&m, cpu_features);
bigint::elem_verify_equal_consttime(&verify, &c)?;
}

Expand All @@ -623,6 +634,7 @@ mod tests {

#[test]
fn test_rsakeypair_private_exponentiate() {
let cpu = cpu::features();
test::run(
test_file!("keypair_private_exponentiate_tests.txt"),
|section, test_case| {
Expand All @@ -645,7 +657,7 @@ mod tests {
let mut padded = vec![0; key.public.modulus_len()];
let zeroes = padded.len() - test_case.len();
padded[zeroes..].copy_from_slice(test_case);
let _: bigint::Elem<_> = key.private_exponentiate(&padded).unwrap();
let _: bigint::Elem<_> = key.private_exponentiate(&padded, cpu).unwrap();
}
Ok(())
},
Expand Down
13 changes: 9 additions & 4 deletions src/rsa/public_key.rs
Original file line number Diff line number Diff line change
Expand Up @@ -144,8 +144,9 @@ impl Inner {
&self,
base: untrusted::Input,
out_buffer: &'out mut [u8; PUBLIC_KEY_PUBLIC_MODULUS_MAX_LEN],
cpu_features: cpu::Features,
) -> Result<&'out [u8], error::Unspecified> {
let n = &self.n.value();
let n = &self.n.value(cpu_features);

// The encoded value of the base must be the same length as the modulus,
// in bytes.
Expand All @@ -162,7 +163,7 @@ impl Inner {
}

// Step 2.
let m = self.exponentiate_elem(&s);
let m = self.exponentiate_elem(&s, cpu_features);

// Step 3.
Ok(fill_be_bytes_n(m, self.n.len_bits(), out_buffer))
Expand All @@ -171,13 +172,17 @@ impl Inner {
/// Calculates base**e (mod n).
///
/// This is constant-time with respect to `base` only.
pub(super) fn exponentiate_elem(&self, base: &bigint::Elem<N>) -> bigint::Elem<N> {
pub(super) fn exponentiate_elem(
&self,
base: &bigint::Elem<N>,
cpu_features: cpu::Features,
) -> bigint::Elem<N> {
// The exponent was already checked to be at least 3.
let exponent_without_low_bit = NonZeroU64::try_from(self.e.value().get() & !1).unwrap();
// The exponent was already checked to be odd.
debug_assert_ne!(exponent_without_low_bit, self.e.value());

let n = &self.n.value();
let n = &self.n.value(cpu_features);

let base_r = bigint::elem_mul(self.n.oneRR(), base.clone(), n);

Expand Down
8 changes: 4 additions & 4 deletions src/rsa/public_modulus.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ impl PublicModulus {
const MIN_BITS: bits::BitLength = bits::BitLength::from_usize_bits(1024);

// Step 3 / Step c for `n` (out of order).
let value = bigint::OwnedModulus::from_be_bytes(n, cpu_features)?;
let value = bigint::OwnedModulus::from_be_bytes(n)?;
let bits = value.len_bits();

// Step 1 / Step a. XXX: SP800-56Br1 and SP800-89 require the length of
Expand All @@ -52,7 +52,7 @@ impl PublicModulus {
if bits > max_bits {
return Err(error::KeyRejected::too_large());
}
let oneRR = bigint::One::newRR(&value.modulus());
let oneRR = bigint::One::newRR(&value.modulus(cpu_features));

Ok(Self { value, oneRR })
}
Expand All @@ -69,8 +69,8 @@ impl PublicModulus {
self.value.len_bits()
}

pub(super) fn value(&self) -> bigint::Modulus<N> {
self.value.modulus()
pub(super) fn value(&self, cpu_features: cpu::Features) -> bigint::Modulus<N> {
self.value.modulus(cpu_features)
}

pub(super) fn oneRR(&self) -> &bigint::Elem<N, RR> {
Expand Down
Loading