diff --git a/Cargo.lock b/Cargo.lock
index 4e0d14119..9a4d4ef3f 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -468,6 +468,7 @@ dependencies = [
  "generic-array",
  "hkdf",
  "hmac",
+ "memsec",
  "num-bigint",
  "num-traits",
  "pbkdf2",
@@ -2307,6 +2308,16 @@ dependencies = [
  "zeroize",
 ]
 
+[[package]]
+name = "memsec"
+version = "0.7.0"
+source = "git+https://github.com/dani-garcia/memsec?rev=3d2e272d284442637bac0a7d94f76883960db7e2#3d2e272d284442637bac0a7d94f76883960db7e2"
+dependencies = [
+ "getrandom",
+ "libc",
+ "windows-sys 0.59.0",
+]
+
 [[package]]
 name = "mime"
 version = "0.3.17"
diff --git a/crates/bitwarden-crypto/Cargo.toml b/crates/bitwarden-crypto/Cargo.toml
index 3254640fe..5fd9d6144 100644
--- a/crates/bitwarden-crypto/Cargo.toml
+++ b/crates/bitwarden-crypto/Cargo.toml
@@ -46,6 +46,11 @@ uniffi = { workspace = true, optional = true }
 uuid = { workspace = true }
 zeroize = { version = ">=1.7.0, <2.0", features = ["derive", "aarch64"] }
 
+[target.'cfg(not(target_arch = "wasm32"))'.dependencies]
+memsec = { version = "0.7.0", features = [
+    "alloc_ext",
+], git = "https://github.com/dani-garcia/memsec", rev = "3d2e272d284442637bac0a7d94f76883960db7e2" }
+
 [dev-dependencies]
 criterion = "0.5.1"
 rand_chacha = "0.3.1"
diff --git a/crates/bitwarden-crypto/src/error.rs b/crates/bitwarden-crypto/src/error.rs
index 2f9a58b8b..1fcd967b8 100644
--- a/crates/bitwarden-crypto/src/error.rs
+++ b/crates/bitwarden-crypto/src/error.rs
@@ -23,6 +23,10 @@ pub enum CryptoError {
     MissingKey(Uuid),
     #[error("The item was missing a required field: {0}")]
     MissingField(&'static str),
+    #[error("Missing Key for Ref. {0}")]
+    MissingKey2(String),
+    #[error("Crypto store is read-only")]
+    ReadOnlyCryptoStore,
 
     #[error("Insufficient KDF parameters")]
     InsufficientKdfParameters,
diff --git a/crates/bitwarden-crypto/src/keys/encryptable.rs b/crates/bitwarden-crypto/src/keys/encryptable.rs
new file mode 100644
index 000000000..66762a641
--- /dev/null
+++ b/crates/bitwarden-crypto/src/keys/encryptable.rs
@@ -0,0 +1,238 @@
+use super::key_ref::{AsymmetricKeyRef, KeyRef, SymmetricKeyRef};
+use crate::{service::CryptoServiceContext, AsymmetricEncString, CryptoError, EncString};
+
+/// This trait should be implemented by any struct capable of knowing which key it needs
+/// to encrypt or decrypt itself.
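+///
+/// A minimal sketch of an implementation, assuming a hypothetical `Cipher` item and a
+/// `MyKeyRef` enum generated with the `key_refs!` macro (neither is part of this change):
+/// ```ignore
+/// impl UsesKey<MyKeyRef> for Cipher {
+///     fn uses_key(&self) -> MyKeyRef {
+///         match self.organization_id {
+///             Some(id) => MyKeyRef::Organization(id),
+///             None => MyKeyRef::User,
+///         }
+///     }
+/// }
+/// ```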
+pub trait UsesKey { + fn uses_key(&self) -> Key; +} + +pub trait Encryptable< + SymmKeyRef: SymmetricKeyRef, + AsymmKeyRef: AsymmetricKeyRef, + Key: KeyRef, + Output, +> +{ + fn encrypt( + &self, + ctx: &mut CryptoServiceContext, + key: Key, + ) -> Result; +} + +pub trait Decryptable< + SymmKeyRef: SymmetricKeyRef, + AsymmKeyRef: AsymmetricKeyRef, + Key: KeyRef, + Output, +> +{ + fn decrypt( + &self, + ctx: &mut CryptoServiceContext, + key: Key, + ) -> Result; +} + +// Basic Encryptable/Decryptable implementations to and from bytes + +impl + Decryptable> for EncString +{ + fn decrypt( + &self, + ctx: &mut CryptoServiceContext, + key: SymmKeyRef, + ) -> Result, crate::CryptoError> { + ctx.decrypt_data_with_symmetric_key(key, self) + } +} + +impl + Decryptable> for AsymmetricEncString +{ + fn decrypt( + &self, + ctx: &mut CryptoServiceContext, + key: AsymmKeyRef, + ) -> Result, crate::CryptoError> { + ctx.decrypt_data_with_asymmetric_key(key, self) + } +} + +impl + Encryptable for &[u8] +{ + fn encrypt( + &self, + ctx: &mut CryptoServiceContext, + key: SymmKeyRef, + ) -> Result { + ctx.encrypt_data_with_symmetric_key(key, self) + } +} + +impl + Encryptable for &[u8] +{ + fn encrypt( + &self, + ctx: &mut CryptoServiceContext, + key: AsymmKeyRef, + ) -> Result { + ctx.encrypt_data_with_asymmetric_key(key, self) + } +} + +// Encryptable/Decryptable implementations to and from strings + +impl + Decryptable for EncString +{ + fn decrypt( + &self, + ctx: &mut CryptoServiceContext, + key: SymmKeyRef, + ) -> Result { + let bytes: Vec = self.decrypt(ctx, key)?; + String::from_utf8(bytes).map_err(|_| CryptoError::InvalidUtf8String) + } +} + +impl + Decryptable for AsymmetricEncString +{ + fn decrypt( + &self, + ctx: &mut CryptoServiceContext, + key: AsymmKeyRef, + ) -> Result { + let bytes: Vec = self.decrypt(ctx, key)?; + String::from_utf8(bytes).map_err(|_| CryptoError::InvalidUtf8String) + } +} + +impl + Encryptable for &str +{ + fn encrypt( + &self, + ctx: &mut CryptoServiceContext, + key: SymmKeyRef, + ) -> Result { + self.as_bytes().encrypt(ctx, key) + } +} + +impl + Encryptable for &str +{ + fn encrypt( + &self, + ctx: &mut CryptoServiceContext, + key: AsymmKeyRef, + ) -> Result { + self.as_bytes().encrypt(ctx, key) + } +} + +impl + Encryptable for String +{ + fn encrypt( + &self, + ctx: &mut CryptoServiceContext, + key: SymmKeyRef, + ) -> Result { + self.as_bytes().encrypt(ctx, key) + } +} + +impl + Encryptable for String +{ + fn encrypt( + &self, + ctx: &mut CryptoServiceContext, + key: AsymmKeyRef, + ) -> Result { + self.as_bytes().encrypt(ctx, key) + } +} + +// Generic implementations for Optional values + +impl< + SymmKeyRef: SymmetricKeyRef, + AsymmKeyRef: AsymmetricKeyRef, + Key: KeyRef, + T: Encryptable, + Output, + > Encryptable> for Option +{ + fn encrypt( + &self, + ctx: &mut CryptoServiceContext, + key: Key, + ) -> Result, crate::CryptoError> { + self.as_ref() + .map(|value| value.encrypt(ctx, key)) + .transpose() + } +} + +impl< + SymmKeyRef: SymmetricKeyRef, + AsymmKeyRef: AsymmetricKeyRef, + Key: KeyRef, + T: Decryptable, + Output, + > Decryptable> for Option +{ + fn decrypt( + &self, + ctx: &mut CryptoServiceContext, + key: Key, + ) -> Result, crate::CryptoError> { + self.as_ref() + .map(|value| value.decrypt(ctx, key)) + .transpose() + } +} + +// Generic implementations for Vec values + +impl< + SymmKeyRef: SymmetricKeyRef, + AsymmKeyRef: AsymmetricKeyRef, + Key: KeyRef, + T: Encryptable, + Output, + > Encryptable> for Vec +{ + fn encrypt( + &self, + ctx: &mut 
CryptoServiceContext, + key: Key, + ) -> Result, crate::CryptoError> { + self.iter().map(|value| value.encrypt(ctx, key)).collect() + } +} + +impl< + SymmKeyRef: SymmetricKeyRef, + AsymmKeyRef: AsymmetricKeyRef, + Key: KeyRef, + T: Decryptable, + Output, + > Decryptable> for Vec +{ + fn decrypt( + &self, + ctx: &mut CryptoServiceContext, + key: Key, + ) -> Result, crate::CryptoError> { + self.iter().map(|value| value.decrypt(ctx, key)).collect() + } +} diff --git a/crates/bitwarden-crypto/src/keys/key_ref.rs b/crates/bitwarden-crypto/src/keys/key_ref.rs new file mode 100644 index 000000000..39ac2cddc --- /dev/null +++ b/crates/bitwarden-crypto/src/keys/key_ref.rs @@ -0,0 +1,89 @@ +use crate::{AsymmetricCryptoKey, SymmetricCryptoKey}; + +// Hide the `KeyRef` trait from the public API, to avoid confusion +// the trait itself needs to be public to reference it in the macro, so wrap it in a hidden module +#[doc(hidden)] +pub mod __internal { + use std::{fmt::Debug, hash::Hash}; + + use zeroize::ZeroizeOnDrop; + + use crate::CryptoKey; + + /// This trait represents a key reference that can be used to identify cryptographic keys in the + /// key store. It is used to abstract over the different types of keys that can be used in + /// the system, an end user would not implement this trait directly, and would instead use + /// the `SymmetricKeyRef` and `AsymmetricKeyRef` traits. + pub trait KeyRef: + Debug + Clone + Copy + Hash + Eq + PartialEq + Ord + PartialOrd + Send + Sync + 'static + { + type KeyValue: CryptoKey + Send + Sync + ZeroizeOnDrop; + + /// Returns whether the key is local to the current context or shared globally by the + /// service. + fn is_local(&self) -> bool; + } +} +pub(crate) use __internal::KeyRef; + +// These traits below are just basic aliases of the `KeyRef` trait, but they allow us to have two +// separate trait bounds + +pub trait SymmetricKeyRef: KeyRef {} +pub trait AsymmetricKeyRef: KeyRef {} + +// Just a small derive_like macro that can be used to generate the key reference enums. +// Example usage: +// ```rust +// key_refs! { +// #[symmetric] +// pub enum KeyRef { +// User, +// Org(Uuid), +// #[local] +// Local(String), +// } +// } +#[macro_export] +macro_rules! key_refs { + ( $( + #[$meta_type:tt] + $(pub)? enum $name:ident { + $( + $( #[$variant_tag:tt] )? + $variant:ident $( ( $inner:ty ) )? + ,)+ + } + )+ ) => { $( + #[derive(std::fmt::Debug, Clone, Copy, std::hash::Hash, Eq, PartialEq, Ord, PartialOrd)] + pub enum $name { $( + $variant $( ($inner) )? + ,)+ } + + impl $crate::key_ref::__internal::KeyRef for $name { + type KeyValue = key_refs!(@key_type $meta_type); + + fn is_local(&self) -> bool { + use $name::*; + match self { $( + key_refs!(@variant_match $variant $( ( $inner ) )?) => + key_refs!(@variant_tag $( $variant_tag )? 
), + )+ } + } + } + + key_refs!(@key_trait $meta_type $name); + )+ }; + + ( @key_type symmetric ) => { $crate::SymmetricCryptoKey }; + ( @key_type asymmetric ) => { $crate::AsymmetricCryptoKey }; + + ( @key_trait symmetric $name:ident ) => { impl $crate::key_ref::SymmetricKeyRef for $name {} }; + ( @key_trait asymmetric $name:ident ) => { impl $crate::key_ref::AsymmetricKeyRef for $name {} }; + + ( @variant_match $variant:ident ( $inner:ty ) ) => { $variant (_) }; + ( @variant_match $variant:ident ) => { $variant }; + + ( @variant_tag local ) => { true }; + ( @variant_tag ) => { false }; +} diff --git a/crates/bitwarden-crypto/src/keys/mod.rs b/crates/bitwarden-crypto/src/keys/mod.rs index ac1732966..a75d9dba0 100644 --- a/crates/bitwarden-crypto/src/keys/mod.rs +++ b/crates/bitwarden-crypto/src/keys/mod.rs @@ -1,5 +1,10 @@ mod key_encryptable; pub use key_encryptable::{CryptoKey, KeyContainer, KeyDecryptable, KeyEncryptable, LocateKey}; +mod encryptable; +pub use encryptable::{Decryptable, Encryptable, UsesKey}; +pub mod key_ref; +pub(crate) use key_ref::KeyRef; +pub use key_ref::{AsymmetricKeyRef, SymmetricKeyRef}; mod master_key; pub use master_key::{ default_argon2_iterations, default_argon2_memory, default_argon2_parallelism, diff --git a/crates/bitwarden-crypto/src/lib.rs b/crates/bitwarden-crypto/src/lib.rs index 44efaac30..6eae550f3 100644 --- a/crates/bitwarden-crypto/src/lib.rs +++ b/crates/bitwarden-crypto/src/lib.rs @@ -82,6 +82,7 @@ mod wordlist; pub use wordlist::EFF_LONG_WORD_LIST; mod allocator; pub use allocator::ZeroizingAllocator; +pub mod service; #[cfg(feature = "uniffi")] uniffi::setup_scaffolding!(); diff --git a/crates/bitwarden-crypto/src/service/context.rs b/crates/bitwarden-crypto/src/service/context.rs new file mode 100644 index 000000000..3e966ccbd --- /dev/null +++ b/crates/bitwarden-crypto/src/service/context.rs @@ -0,0 +1,357 @@ +use std::sync::{RwLockReadGuard, RwLockWriteGuard}; + +use rsa::Oaep; +use zeroize::Zeroizing; + +use super::Keys; +use crate::{ + derive_shareable_key, + service::{key_store::KeyStore, AsymmetricKeyRef, SymmetricKeyRef}, + AsymmetricCryptoKey, AsymmetricEncString, CryptoError, EncString, Result, SymmetricCryptoKey, +}; + +// This is to abstract over the read-only and read-write access to the global keys +// inside the CryptoServiceContext. The read-write access should only be used internally +// in this crate to avoid users leaving the crypto store in an inconsistent state, +// but for the moment we have some operations that require access to it. 
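+//
+// A rough usage sketch of the two access modes, assuming the `context()`/`context_mut()`
+// constructors defined in service/mod.rs and an application-defined key reference enum
+// (the names here are illustrative only):
+//
+// ```ignore
+// let service: CryptoService<MySymmKey, MyAsymmKey> = CryptoService::new();
+//
+// // Read-only global access: global keys can be read, but inserting a non-local key
+// // fails with CryptoError::ReadOnlyCryptoStore.
+// let mut ctx = service.context();
+//
+// // Read-write global access: the same insert succeeds.
+// let mut ctx_mut = service.context_mut();
+// ```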
+pub trait GlobalAccessMode<'a, SymmKeyRef: SymmetricKeyRef, AsymmKeyRef: AsymmetricKeyRef> { + fn get(&self) -> &Keys; + fn get_mut(&mut self) -> Result<&mut Keys>; +} + +pub struct ReadOnlyGlobal<'a, SymmKeyRef: SymmetricKeyRef, AsymmKeyRef: AsymmetricKeyRef>( + pub(super) RwLockReadGuard<'a, Keys>, +); + +impl<'a, SymmKeyRef: SymmetricKeyRef, AsymmKeyRef: AsymmetricKeyRef> + GlobalAccessMode<'a, SymmKeyRef, AsymmKeyRef> for ReadOnlyGlobal<'a, SymmKeyRef, AsymmKeyRef> +{ + fn get(&self) -> &Keys { + &self.0 + } + + fn get_mut(&mut self) -> Result<&mut Keys> { + Err(crate::CryptoError::ReadOnlyCryptoStore) + } +} + +pub struct ReadWriteGlobal<'a, SymmKeyRef: SymmetricKeyRef, AsymmKeyRef: AsymmetricKeyRef>( + pub(super) RwLockWriteGuard<'a, Keys>, +); + +impl<'a, SymmKeyRef: SymmetricKeyRef, AsymmKeyRef: AsymmetricKeyRef> + GlobalAccessMode<'a, SymmKeyRef, AsymmKeyRef> for ReadWriteGlobal<'a, SymmKeyRef, AsymmKeyRef> +{ + fn get(&self) -> &Keys { + &self.0 + } + + fn get_mut(&mut self) -> Result<&mut Keys> { + Ok(&mut self.0) + } +} + +pub struct CryptoServiceContext< + 'a, + SymmKeyRef: SymmetricKeyRef, + AsymmKeyRef: AsymmetricKeyRef, + AccessMode: GlobalAccessMode<'a, SymmKeyRef, AsymmKeyRef> = ReadOnlyGlobal< + 'a, + SymmKeyRef, + AsymmKeyRef, + >, +> { + pub(super) global: AccessMode, + + pub(super) local_symmetric_keys: Box>, + pub(super) local_asymmetric_keys: Box>, + + pub(super) _phantom: std::marker::PhantomData<&'a ()>, +} + +impl< + 'a, + SymmKeyRef: SymmetricKeyRef, + AsymmKeyRef: AsymmetricKeyRef, + AccessMode: GlobalAccessMode<'a, SymmKeyRef, AsymmKeyRef>, + > CryptoServiceContext<'a, SymmKeyRef, AsymmKeyRef, AccessMode> +{ + pub fn clear(&mut self) { + // Clear global keys if we have write access + if let Ok(keys) = self.global.get_mut() { + keys.symmetric_keys.clear(); + keys.asymmetric_keys.clear(); + } + + self.local_symmetric_keys.clear(); + self.local_asymmetric_keys.clear(); + } + + pub fn retain_symmetric_keys(&mut self, f: fn(SymmKeyRef) -> bool) { + if let Ok(keys) = self.global.get_mut() { + keys.symmetric_keys.retain(f); + } + self.local_symmetric_keys.retain(f); + } + + pub fn retain_asymmetric_keys(&mut self, f: fn(AsymmKeyRef) -> bool) { + if let Ok(keys) = self.global.get_mut() { + keys.asymmetric_keys.retain(f); + } + self.local_asymmetric_keys.retain(f); + } + + /// TODO: All these encrypt x key with x key look like they need to be made generic, + /// but I haven't found the best way to do that yet. 
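+    //
+    // A usage sketch for the method below: unlocking an organization key with the already
+    // loaded user key (the `Symm` enum and the variable names are illustrative only):
+    //
+    // ```ignore
+    // let org_key = ctx.decrypt_symmetric_key_with_symmetric_key(
+    //     Symm::User,          // key used for decryption
+    //     Symm::Org(org_id),   // reference under which the decrypted key is stored
+    //     &encrypted_org_key,  // EncString received from the server
+    // )?;
+    // ```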
+ + pub fn decrypt_symmetric_key_with_symmetric_key( + &mut self, + encryption_key: SymmKeyRef, + new_key_ref: SymmKeyRef, + encrypted_key: &EncString, + ) -> Result { + let mut new_key_material = + self.decrypt_data_with_symmetric_key(encryption_key, encrypted_key)?; + + #[allow(deprecated)] + self.set_symmetric_key( + new_key_ref, + SymmetricCryptoKey::try_from(new_key_material.as_mut_slice())?, + )?; + + // Returning the new key reference for convenience + Ok(new_key_ref) + } + + pub fn encrypt_symmetric_key_with_symmetric_key( + &self, + encryption_key: SymmKeyRef, + key_to_encrypt: SymmKeyRef, + ) -> Result { + let key_to_encrypt = self.get_symmetric_key(key_to_encrypt)?; + self.encrypt_data_with_symmetric_key(encryption_key, &key_to_encrypt.to_vec()) + } + + pub fn decrypt_symmetric_key_with_asymmetric_key( + &mut self, + encryption_key: AsymmKeyRef, + new_key_ref: SymmKeyRef, + encrypted_key: &AsymmetricEncString, + ) -> Result { + let mut new_key_material = + self.decrypt_data_with_asymmetric_key(encryption_key, encrypted_key)?; + + #[allow(deprecated)] + self.set_symmetric_key( + new_key_ref, + SymmetricCryptoKey::try_from(new_key_material.as_mut_slice())?, + )?; + + // Returning the new key reference for convenience + Ok(new_key_ref) + } + + pub fn encrypt_symmetric_key_with_asymmetric_key( + &self, + encryption_key: AsymmKeyRef, + key_to_encrypt: SymmKeyRef, + ) -> Result { + let key_to_encrypt = self.get_symmetric_key(key_to_encrypt)?; + self.encrypt_data_with_asymmetric_key(encryption_key, &key_to_encrypt.to_vec()) + } + + pub fn decrypt_asymmetric_key( + &mut self, + encryption_key: AsymmKeyRef, + new_key_ref: AsymmKeyRef, + encrypted_key: &AsymmetricEncString, + ) -> Result { + let new_key_material = + self.decrypt_data_with_asymmetric_key(encryption_key, encrypted_key)?; + + #[allow(deprecated)] + self.set_asymmetric_key( + new_key_ref, + AsymmetricCryptoKey::from_der(&new_key_material)?, + )?; + + // Returning the new key reference for convenience + Ok(new_key_ref) + } + + pub fn encrypt_asymmetric_key( + &self, + encryption_key: AsymmKeyRef, + key_to_encrypt: AsymmKeyRef, + ) -> Result { + let encryption_key = self.get_asymmetric_key(encryption_key)?; + let key_to_encrypt = self.get_asymmetric_key(key_to_encrypt)?; + + AsymmetricEncString::encrypt_rsa2048_oaep_sha1( + key_to_encrypt.to_der()?.as_slice(), + encryption_key, + ) + } + + pub fn has_symmetric_key(&self, key_ref: SymmKeyRef) -> bool { + self.get_symmetric_key(key_ref).is_ok() + } + + pub fn has_asymmetric_key(&self, key_ref: AsymmKeyRef) -> bool { + self.get_asymmetric_key(key_ref).is_ok() + } + + pub fn generate_symmetric_key(&mut self, key_ref: SymmKeyRef) -> Result { + let key = SymmetricCryptoKey::generate(rand::thread_rng()); + #[allow(deprecated)] + self.set_symmetric_key(key_ref, key)?; + Ok(key_ref) + } + + pub fn derive_shareable_key( + &mut self, + key_ref: SymmKeyRef, + secret: Zeroizing<[u8; 16]>, + name: &str, + info: Option<&str>, + ) -> Result { + #[allow(deprecated)] + self.set_symmetric_key(key_ref, derive_shareable_key(secret, name, info))?; + Ok(key_ref) + } + + #[deprecated(note = "This function should ideally never be used outside this crate")] + pub fn dangerous_get_symmetric_key(&self, key_ref: SymmKeyRef) -> Result<&SymmetricCryptoKey> { + self.get_symmetric_key(key_ref) + } + + #[deprecated(note = "This function should ideally never be used outside this crate")] + pub fn dangerous_get_asymmetric_key( + &self, + key_ref: AsymmKeyRef, + ) -> Result<&AsymmetricCryptoKey> { + 
self.get_asymmetric_key(key_ref) + } + + fn get_symmetric_key(&self, key_ref: SymmKeyRef) -> Result<&SymmetricCryptoKey> { + if key_ref.is_local() { + self.local_symmetric_keys.get(key_ref) + } else { + self.global.get().symmetric_keys.get(key_ref) + } + .ok_or_else(|| crate::CryptoError::MissingKey2(format!("{key_ref:?}"))) + } + + fn get_asymmetric_key(&self, key_ref: AsymmKeyRef) -> Result<&AsymmetricCryptoKey> { + if key_ref.is_local() { + self.local_asymmetric_keys.get(key_ref) + } else { + self.global.get().asymmetric_keys.get(key_ref) + } + .ok_or_else(|| crate::CryptoError::MissingKey2(format!("{key_ref:?}"))) + } + + #[deprecated(note = "This function should ideally never be used outside this crate")] + pub fn set_symmetric_key( + &mut self, + key_ref: SymmKeyRef, + key: SymmetricCryptoKey, + ) -> Result<()> { + if key_ref.is_local() { + self.local_symmetric_keys.insert(key_ref, key); + } else { + self.global.get_mut()?.symmetric_keys.insert(key_ref, key); + } + Ok(()) + } + + #[deprecated(note = "This function should ideally never be used outside this crate")] + pub fn set_asymmetric_key( + &mut self, + key_ref: AsymmKeyRef, + key: AsymmetricCryptoKey, + ) -> Result<()> { + if key_ref.is_local() { + self.local_asymmetric_keys.insert(key_ref, key); + } else { + self.global.get_mut()?.asymmetric_keys.insert(key_ref, key); + } + Ok(()) + } + + pub(crate) fn decrypt_data_with_symmetric_key( + &self, + key: SymmKeyRef, + data: &EncString, + ) -> Result> { + let key = self.get_symmetric_key(key)?; + + match data { + EncString::AesCbc256_B64 { iv, data } => { + let dec = crate::aes::decrypt_aes256(iv, data.clone(), &key.key)?; + Ok(dec) + } + EncString::AesCbc128_HmacSha256_B64 { iv, mac, data } => { + // TODO: SymmetricCryptoKey is designed to handle 32 byte keys only, but this + // variant uses a 16 byte key This means the key+mac are going to be + // parsed as a single 32 byte key, at the moment we split it manually + // When refactoring the key handling, this should be fixed. + let enc_key = (&key.key[0..16]).into(); + let mac_key = (&key.key[16..32]).into(); + let dec = crate::aes::decrypt_aes128_hmac(iv, mac, data.clone(), mac_key, enc_key)?; + Ok(dec) + } + EncString::AesCbc256_HmacSha256_B64 { iv, mac, data } => { + let mac_key = key.mac_key.as_ref().ok_or(CryptoError::InvalidMac)?; + let dec = + crate::aes::decrypt_aes256_hmac(iv, mac, data.clone(), mac_key, &key.key)?; + Ok(dec) + } + } + } + + pub(crate) fn encrypt_data_with_symmetric_key( + &self, + key: SymmKeyRef, + data: &[u8], + ) -> Result { + let key = self.get_symmetric_key(key)?; + EncString::encrypt_aes256_hmac( + data, + key.mac_key.as_ref().ok_or(CryptoError::InvalidMac)?, + &key.key, + ) + } + + pub(crate) fn decrypt_data_with_asymmetric_key( + &self, + key: AsymmKeyRef, + data: &AsymmetricEncString, + ) -> Result> { + let key = self.get_asymmetric_key(key)?; + + use AsymmetricEncString::*; + match data { + Rsa2048_OaepSha256_B64 { data } => key.key.decrypt(Oaep::new::(), data), + Rsa2048_OaepSha1_B64 { data } => key.key.decrypt(Oaep::new::(), data), + #[allow(deprecated)] + Rsa2048_OaepSha256_HmacSha256_B64 { data, .. } => { + key.key.decrypt(Oaep::new::(), data) + } + #[allow(deprecated)] + Rsa2048_OaepSha1_HmacSha256_B64 { data, .. 
} => { + key.key.decrypt(Oaep::new::(), data) + } + } + .map_err(|_| CryptoError::KeyDecrypt) + } + + pub(crate) fn encrypt_data_with_asymmetric_key( + &self, + key: AsymmKeyRef, + data: &[u8], + ) -> Result { + let key = self.get_asymmetric_key(key)?; + AsymmetricEncString::encrypt_rsa2048_oaep_sha1(data, key) + } +} diff --git a/crates/bitwarden-crypto/src/service/key_store/implementation/linux_memfd_secret.rs b/crates/bitwarden-crypto/src/service/key_store/implementation/linux_memfd_secret.rs new file mode 100644 index 000000000..49e08c0ac --- /dev/null +++ b/crates/bitwarden-crypto/src/service/key_store/implementation/linux_memfd_secret.rs @@ -0,0 +1,111 @@ +use std::{mem::MaybeUninit, ptr::NonNull, sync::OnceLock}; + +use super::{ + slice::{KeyData, SliceKeyStore}, + KeyRef, +}; + +// This is an in-memory key store that is protected by memfd_secret on Linux 5.14+. +// This should be secure against memory dumps from anything except a malicious kernel driver. +// Note that not all 5.14+ systems have support for memfd_secret enabled, so +// LinuxMemfdSecretKeyStore::new returns an Option. +pub(crate) type LinuxMemfdSecretKeyStore = SliceKeyStore; + +pub(crate) struct MemfdSecretImplKeyData { + ptr: std::ptr::NonNull<[u8]>, + capacity: usize, +} + +// For Send+Sync to be safe, we need to ensure that the memory is only accessed mutably from one +// thread. To do this, we have to make sure that any funcion in `MemfdSecretImplKeyData` that +// accesses the pointer mutably is defined as &mut self, and that the pointer is never copied or +// moved outside the struct. +unsafe impl Send for MemfdSecretImplKeyData {} +unsafe impl Sync for MemfdSecretImplKeyData {} + +impl Drop for MemfdSecretImplKeyData { + fn drop(&mut self) { + unsafe { + memsec::free_memfd_secret(self.ptr); + } + } +} + +impl KeyData for MemfdSecretImplKeyData { + fn is_available() -> bool { + static IS_SUPPORTED: OnceLock = OnceLock::new(); + + *IS_SUPPORTED.get_or_init(|| unsafe { + let Some(ptr) = memsec::memfd_secret_sized(1) else { + return false; + }; + memsec::free_memfd_secret(ptr); + true + }) + } + + fn with_capacity(capacity: usize) -> Self { + let entry_size = std::mem::size_of::>(); + + unsafe { + let ptr: NonNull<[u8]> = memsec::memfd_secret_sized(capacity * entry_size) + .expect("memfd_secret_sized failed"); + + // Initialize the array with Nones using MaybeUninit + let uninit_slice: &mut [MaybeUninit<_>] = std::slice::from_raw_parts_mut( + ptr.as_ptr() as *mut MaybeUninit>, + capacity, + ); + for elem in uninit_slice { + elem.write(None); + } + + MemfdSecretImplKeyData { ptr, capacity } + } + } + + fn get_key_data(&self) -> &[Option<(Key, Key::KeyValue)>] { + let ptr = self.ptr.as_ptr() as *const Option<(Key, Key::KeyValue)>; + // SAFETY: The pointer is valid and points to a valid slice of the correct size. + // This function is &self so it only takes a immutable *const pointer. + unsafe { std::slice::from_raw_parts(ptr, self.capacity) } + } + + fn get_key_data_mut(&mut self) -> &mut [Option<(Key, Key::KeyValue)>] { + let ptr = self.ptr.as_ptr() as *mut Option<(Key, Key::KeyValue)>; + // SAFETY: The pointer is valid and points to a valid slice of the correct size. + // This function is &mut self so it can take a mutable *mut pointer. 
+ unsafe { std::slice::from_raw_parts_mut(ptr, self.capacity) } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::service::key_store::{util::tests::*, KeyStore as _}; + + #[test] + fn test_resize() { + let mut store = LinuxMemfdSecretKeyStore::::with_capacity(1).unwrap(); + + for (idx, key) in [ + TestKey::A, + TestKey::B(10), + TestKey::C, + TestKey::B(7), + TestKey::A, + TestKey::C, + ] + .into_iter() + .enumerate() + { + store.insert(key, TestKeyValue::new(idx)); + } + + assert_eq!(store.get(TestKey::A), Some(&TestKeyValue::new(4))); + assert_eq!(store.get(TestKey::B(10)), Some(&TestKeyValue::new(1))); + assert_eq!(store.get(TestKey::C), Some(&TestKeyValue::new(5))); + assert_eq!(store.get(TestKey::B(7)), Some(&TestKeyValue::new(3))); + assert_eq!(store.get(TestKey::B(20)), None); + } +} diff --git a/crates/bitwarden-crypto/src/service/key_store/implementation/mod.rs b/crates/bitwarden-crypto/src/service/key_store/implementation/mod.rs new file mode 100644 index 000000000..6ecff0cdd --- /dev/null +++ b/crates/bitwarden-crypto/src/service/key_store/implementation/mod.rs @@ -0,0 +1,15 @@ +use super::{slice, KeyStore}; +use crate::service::KeyRef; + +#[cfg(all(target_os = "linux", not(feature = "no-memory-hardening")))] +pub(crate) mod linux_memfd_secret; +pub(crate) mod rust_slice; + +pub fn create_key_store() -> Box> { + #[cfg(all(target_os = "linux", not(feature = "no-memory-hardening")))] + if let Some(key_store) = linux_memfd_secret::LinuxMemfdSecretKeyStore::::new() { + return Box::new(key_store); + } + + Box::new(rust_slice::RustKeyStore::new().expect("RustKeyStore should always be available")) +} diff --git a/crates/bitwarden-crypto/src/service/key_store/implementation/rust_slice.rs b/crates/bitwarden-crypto/src/service/key_store/implementation/rust_slice.rs new file mode 100644 index 000000000..e349a7703 --- /dev/null +++ b/crates/bitwarden-crypto/src/service/key_store/implementation/rust_slice.rs @@ -0,0 +1,101 @@ +use super::{ + slice::{KeyData, SliceKeyStore}, + KeyRef, +}; + +// This is a basic in-memory key store for the cases where we don't have a secure key store +// available. We still make use mlock to protect the memory from being swapped to disk, and we +// zeroize the values when dropped. +pub(crate) type RustKeyStore = SliceKeyStore>; + +pub(crate) struct RustImplKeyData { + #[allow(clippy::type_complexity)] + data: Box<[Option<(Key, Key::KeyValue)>]>, +} + +impl Drop for RustImplKeyData { + fn drop(&mut self) { + #[cfg(all(not(target_arch = "wasm32"), not(feature = "no-memory-hardening")))] + { + use std::mem::MaybeUninit; + + let entry_size = std::mem::size_of::>(); + unsafe { + memsec::munlock( + self.data.as_mut_ptr() as *mut u8, + self.data.len() * entry_size, + ); + + // Note: munlock is zeroing the memory, which leaves the data in an inconsistent + // state. So we need to set it to None again, in case any Drop impl + // expects the data to be correct. 
+ let uninit_slice: &mut [MaybeUninit<_>] = std::slice::from_raw_parts_mut( + self.data.as_mut_ptr() as *mut MaybeUninit>, + self.data.len(), + ); + for elem in uninit_slice { + elem.write(None); + } + } + } + } +} + +impl KeyData for RustImplKeyData { + fn is_available() -> bool { + true + } + + fn with_capacity(capacity: usize) -> Self { + #[allow(unused_mut)] + let mut data: Box<_> = std::iter::repeat_with(|| None).take(capacity).collect(); + + #[cfg(all(not(target_arch = "wasm32"), not(feature = "no-memory-hardening")))] + { + let entry_size = std::mem::size_of::>(); + unsafe { + memsec::mlock(data.as_mut_ptr() as *mut u8, capacity * entry_size); + } + } + RustImplKeyData { data } + } + + fn get_key_data(&self) -> &[Option<(Key, Key::KeyValue)>] { + self.data.as_ref() + } + + fn get_key_data_mut(&mut self) -> &mut [Option<(Key, Key::KeyValue)>] { + self.data.as_mut() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::service::key_store::{slice::tests::*, KeyStore as _}; + + #[test] + fn test_resize() { + let mut store = RustKeyStore::::with_capacity(1).unwrap(); + + for (idx, key) in [ + TestKey::A, + TestKey::B(10), + TestKey::C, + TestKey::B(7), + TestKey::A, + TestKey::C, + ] + .into_iter() + .enumerate() + { + store.insert(key, TestKeyValue::new(idx)); + } + + assert_eq!(store.get(TestKey::A), Some(&TestKeyValue::new(4))); + assert_eq!(store.get(TestKey::B(10)), Some(&TestKeyValue::new(1))); + assert_eq!(store.get(TestKey::C), Some(&TestKeyValue::new(5))); + assert_eq!(store.get(TestKey::B(7)), Some(&TestKeyValue::new(3))); + assert_eq!(store.get(TestKey::B(20)), None); + } +} diff --git a/crates/bitwarden-crypto/src/service/key_store/mod.rs b/crates/bitwarden-crypto/src/service/key_store/mod.rs new file mode 100644 index 000000000..e0d762926 --- /dev/null +++ b/crates/bitwarden-crypto/src/service/key_store/mod.rs @@ -0,0 +1,20 @@ +use zeroize::ZeroizeOnDrop; + +use crate::service::KeyRef; + +mod implementation; +mod slice; + +pub use implementation::create_key_store; + +/// This trait represents a platform that can securely store and return keys. The `SliceKeyStore` +/// implementation is a simple in-memory store with some platform-specific security features. Other +/// implementations could use secure enclaves or HSMs, or OS provided keychains. +pub trait KeyStore: ZeroizeOnDrop + Send + Sync { + fn insert(&mut self, key_ref: Key, key: Key::KeyValue); + fn get(&self, key_ref: Key) -> Option<&Key::KeyValue>; + fn remove(&mut self, key_ref: Key); + fn clear(&mut self); + + fn retain(&mut self, f: fn(Key) -> bool); +} diff --git a/crates/bitwarden-crypto/src/service/key_store/slice.rs b/crates/bitwarden-crypto/src/service/key_store/slice.rs new file mode 100644 index 000000000..4ad83c977 --- /dev/null +++ b/crates/bitwarden-crypto/src/service/key_store/slice.rs @@ -0,0 +1,742 @@ +use std::marker::PhantomData; + +use zeroize::ZeroizeOnDrop; + +use super::KeyStore; +use crate::KeyRef; + +/// This trait represents some data stored sequentially in memory, with a fixed size. +/// We use this to abstract the implementation over Vec/Box<[u8]/NonNull<[u8]>, which +/// helps contain any unsafe code to the implementations of this trait. +/// Implementations of this trait should ensure that the initialized data is protected +/// as much as possible. The data is already Zeroized on Drop, so implementations should +/// only need to worry about removing any protections they've added, or releasing any resources. 
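+///
+/// A minimal sketch of a possible heap-backed implementation with no extra protections
+/// (purely illustrative; the real implementations live under `implementation/`):
+/// ```ignore
+/// struct PlainKeyData<Key: KeyRef>(Box<[Option<(Key, Key::KeyValue)>]>);
+/// impl<Key: KeyRef> Drop for PlainKeyData<Key> {
+///     fn drop(&mut self) {} // Values are zeroized by their own Drop impls
+/// }
+/// impl<Key: KeyRef> KeyData<Key> for PlainKeyData<Key> {
+///     fn is_available() -> bool { true }
+///     fn with_capacity(capacity: usize) -> Self {
+///         Self(std::iter::repeat_with(|| None).take(capacity).collect())
+///     }
+///     fn get_key_data(&self) -> &[Option<(Key, Key::KeyValue)>] { &self.0 }
+///     fn get_key_data_mut(&mut self) -> &mut [Option<(Key, Key::KeyValue)>] { &mut self.0 }
+/// }
+/// ```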
+#[allow(drop_bounds)] +pub(crate) trait KeyData: Send + Sync + Sized + Drop { + /// Check if the data store is available on this platform. + fn is_available() -> bool; + + /// Initialize a new data store with the given capacity. + /// The data MUST be initialized to all None values, and + /// it's capacity must be equal or greater than the provided value. + fn with_capacity(capacity: usize) -> Self; + + /// Return an immutable slice of the data. It must return the full allocated capacity, with no + /// uninitialized values. + fn get_key_data(&self) -> &[Option<(Key, Key::KeyValue)>]; + + /// Return an mutable slice of the data. It must return the full allocated capacity, with no + /// uninitialized values. + fn get_key_data_mut(&mut self) -> &mut [Option<(Key, Key::KeyValue)>]; +} + +/// This represents a key store over an arbitrary fixed size slice. +/// This is meant to abstract over the different ways to store keys in memory, +/// whether we're using a Vec, a Box<[u8]> or a NonNull. +pub(crate) struct SliceKeyStore> { + // This represents the number of elements in the container, it's always less than or equal to + // the length of `data`. + length: usize, + + // This represents the maximum number of elements that can be stored in the container. + // This is always equal to the length of `data`, but we store it to avoid recomputing it. + capacity: usize, + + // This is the actual data that stores the keys, optional as we can have it not yet + // uninitialized + data: Option, + + _key: PhantomData, +} + +impl> ZeroizeOnDrop for SliceKeyStore {} + +impl> Drop for SliceKeyStore { + fn drop(&mut self) { + self.clear(); + } +} + +impl> KeyStore for SliceKeyStore { + fn insert(&mut self, key_ref: Key, key: Key::KeyValue) { + match self.find_by_key_ref(&key_ref) { + Ok(idx) => { + // Key already exists, we just need to replace the value + let slice = self.get_key_data_mut(); + slice[idx] = Some((key_ref, key)); + } + Err(idx) => { + // Make sure that we have enough capacity, and resize if needed + self.ensure_capacity(1); + + let len = self.length; + let slice = self.get_key_data_mut(); + if idx < len { + // If we're not right at the end, we have to shift all the following elements + // one position to the right + slice[idx..=len].rotate_right(1); + } + slice[idx] = Some((key_ref, key)); + self.length += 1; + } + } + } + + fn get(&self, key_ref: Key) -> Option<&Key::KeyValue> { + self.find_by_key_ref(&key_ref) + .ok() + .and_then(|idx| self.get_key_data().get(idx)) + .and_then(|f| f.as_ref().map(|f| &f.1)) + } + + fn remove(&mut self, key_ref: Key) { + if let Ok(idx) = self.find_by_key_ref(&key_ref) { + let len = self.length; + let slice = self.get_key_data_mut(); + slice[idx] = None; + slice[idx..len].rotate_left(1); + self.length -= 1; + } + } + + fn clear(&mut self) { + let len = self.length; + self.get_key_data_mut()[0..len].fill_with(|| None); + self.length = 0; + } + + fn retain(&mut self, f: fn(Key) -> bool) { + let len = self.length; + let slice = self.get_key_data_mut(); + + let mut removed_elements = 0; + + for value in slice.iter_mut().take(len) { + let key = value + .as_ref() + .map(|e| e.0) + .expect("Values in a slice are always Some"); + + if !f(key) { + *value = None; + removed_elements += 1; + } + } + + // If we haven't removed any elements, we don't need to compact the slice + if removed_elements == 0 { + return; + } + + // Remove all the None values from the middle of the slice + + for idx in 0..len { + if slice[idx].is_none() { + slice[idx..len].rotate_left(1); + } + } + + 
self.length -= removed_elements; + } +} + +impl> SliceKeyStore { + pub(crate) fn new() -> Option { + Self::with_capacity(0) + } + + pub(crate) fn with_capacity(capacity: usize) -> Option { + if !Data::is_available() { + return None; + } + + // If the capacity is 0, we don't need to allocate any memory. + // This allows us to initialize the container lazily. + if capacity == 0 { + return Some(Self { + length: 0, + capacity: 0, + data: None, + _key: PhantomData, + }); + } + + Some(Self { + length: 0, + capacity, + data: Some(Data::with_capacity(capacity)), + _key: PhantomData, + }) + } + + /// Check if the container has enough capacity to store `new_elements` more elements. + /// If the result is Ok, the container has enough capacity. + /// If it's Err, the container needs to be resized. + /// The error value returns a suggested new capacity. + fn check_capacity(&self, new_elements: usize) -> Result<(), usize> { + let new_size = self.length + new_elements; + + // We still have enough capacity + if new_size <= self.capacity { + Ok(()) + + // This is the first allocation + } else if self.capacity == 0 { + const PAGE_SIZE: usize = 4096; + let entry_size = std::mem::size_of::>(); + + // We're using mlock APIs to protect the memory, which lock at the page level. + // To avoid wasting memory, we want to allocate at least a page. + let entries_per_page = PAGE_SIZE / entry_size; + Err(entries_per_page) + + // We need to resize the container + } else { + // We want to increase the capacity by a multiple to be mostly aligned with page size, + // we also need to make sure that we have enough space for the new elements, so we round + // up + let increase_factor = usize::div_ceil(new_size, self.capacity); + Err(self.capacity * increase_factor) + } + } + + fn ensure_capacity(&mut self, new_elements: usize) { + if let Err(new_capacity) = self.check_capacity(new_elements) { + // Create a new store with the correct capacity and replace self with it + let mut new_self = + Self::with_capacity(new_capacity).expect("Could not allocate new store"); + new_self.copy_from(self); + *self = new_self; + } + } + + // These two are just helper functions to avoid having to deal with the optional Data + // When Data is None we just return empty slices, which don't allow any operations + fn get_key_data(&self) -> &[Option<(Key, Key::KeyValue)>] { + self.data.as_ref().map(|d| d.get_key_data()).unwrap_or(&[]) + } + fn get_key_data_mut(&mut self) -> &mut [Option<(Key, Key::KeyValue)>] { + self.data + .as_mut() + .map(|d| d.get_key_data_mut()) + .unwrap_or(&mut []) + } + + fn find_by_key_ref(&self, key_ref: &Key) -> Result { + // Because we know all the None's are at the end and all the Some values are at the + // beginning, we only need to search for the key in the first `size` elements. + let slice = &self.get_key_data()[..self.length]; + + // This structure is almost always used for reads instead of writes, so we can use a binary + // search to optimize for the read case. 
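+        //
+        // `binary_search_by` returns Ok(idx) when the key is found at `idx`, and Err(idx)
+        // with the insertion point that keeps the slice sorted otherwise; `insert` and
+        // `remove` above rely on exactly that contract.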
+ slice.binary_search_by(|k| { + debug_assert!( + k.is_some(), + "We should never have a None value in the middle of the slice" + ); + + match k { + Some((k, _)) => k.cmp(key_ref), + None => std::cmp::Ordering::Greater, + } + }) + } + + pub(crate) fn copy_from(&mut self, other: &mut Self) -> bool { + if other.capacity > self.capacity { + return false; + } + + // Empty the current container + self.clear(); + + let new_length = other.length; + + // Move the data from the other container + let this = self.get_key_data_mut(); + let that = other.get_key_data_mut(); + for idx in 0..new_length { + std::mem::swap(&mut this[idx], &mut that[idx]); + } + + // Update the length + self.length = new_length; + + true + } +} + +#[cfg(test)] +pub(crate) mod tests { + use zeroize::Zeroize; + + use super::*; + use crate::{service::key_store::implementation::rust_slice::RustKeyStore, CryptoKey, KeyRef}; + + #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] + pub enum TestKey { + A, + B(u8), + C, + } + #[derive(Debug, Clone, Hash, Eq, PartialEq, Ord, PartialOrd)] + pub struct TestKeyValue([u8; 16]); + impl zeroize::ZeroizeOnDrop for TestKeyValue {} + impl CryptoKey for TestKeyValue {} + impl TestKeyValue { + pub fn new(value: usize) -> Self { + // Just fill the array with some values + let mut key = [0; 16]; + key[0..8].copy_from_slice(&value.to_le_bytes()); + key[8..16].copy_from_slice(&value.to_be_bytes()); + Self(key) + } + } + + impl Drop for TestKeyValue { + fn drop(&mut self) { + self.0.as_mut().zeroize(); + } + } + + impl KeyRef for TestKey { + type KeyValue = TestKeyValue; + + fn is_local(&self) -> bool { + false + } + } + + #[test] + fn test_slice_container_insertion() { + let mut container = RustKeyStore::::with_capacity(5).unwrap(); + + assert_eq!(container.get_key_data(), [None, None, None, None, None]); + + // Insert one key, which should be at the beginning + container.insert(TestKey::B(10), TestKeyValue::new(110)); + assert_eq!( + container.get_key_data(), + [ + Some((TestKey::B(10), TestKeyValue::new(110))), + None, + None, + None, + None + ] + ); + + // Insert a key that should be right after the first one + container.insert(TestKey::C, TestKeyValue::new(1000)); + assert_eq!( + container.get_key_data(), + [ + Some((TestKey::B(10), TestKeyValue::new(110))), + Some((TestKey::C, TestKeyValue::new(1000))), + None, + None, + None + ] + ); + + // Insert a key in the middle + container.insert(TestKey::B(20), TestKeyValue::new(210)); + assert_eq!( + container.get_key_data(), + [ + Some((TestKey::B(10), TestKeyValue::new(110))), + Some((TestKey::B(20), TestKeyValue::new(210))), + Some((TestKey::C, TestKeyValue::new(1000))), + None, + None + ] + ); + + // Insert a key right at the start + container.insert(TestKey::A, TestKeyValue::new(0)); + assert_eq!( + container.get_key_data(), + [ + Some((TestKey::A, TestKeyValue::new(0))), + Some((TestKey::B(10), TestKeyValue::new(110))), + Some((TestKey::B(20), TestKeyValue::new(210))), + Some((TestKey::C, TestKeyValue::new(1000))), + None + ] + ); + + // Insert a key in the middle, which fills the container + container.insert(TestKey::B(30), TestKeyValue::new(310)); + assert_eq!( + container.get_key_data(), + [ + Some((TestKey::A, TestKeyValue::new(0))), + Some((TestKey::B(10), TestKeyValue::new(110))), + Some((TestKey::B(20), TestKeyValue::new(210))), + Some((TestKey::B(30), TestKeyValue::new(310))), + Some((TestKey::C, TestKeyValue::new(1000))), + ] + ); + + // Replacing an existing value at the start + container.insert(TestKey::A, 
TestKeyValue::new(1)); + assert_eq!( + container.get_key_data(), + [ + Some((TestKey::A, TestKeyValue::new(1))), + Some((TestKey::B(10), TestKeyValue::new(110))), + Some((TestKey::B(20), TestKeyValue::new(210))), + Some((TestKey::B(30), TestKeyValue::new(310))), + Some((TestKey::C, TestKeyValue::new(1000))), + ] + ); + + // Replacing an existing value at the middle + container.insert(TestKey::B(20), TestKeyValue::new(211)); + assert_eq!( + container.get_key_data(), + [ + Some((TestKey::A, TestKeyValue::new(1))), + Some((TestKey::B(10), TestKeyValue::new(110))), + Some((TestKey::B(20), TestKeyValue::new(211))), + Some((TestKey::B(30), TestKeyValue::new(310))), + Some((TestKey::C, TestKeyValue::new(1000))), + ] + ); + + // Replacing an existing value at the end + container.insert(TestKey::C, TestKeyValue::new(1001)); + assert_eq!( + container.get_key_data(), + [ + Some((TestKey::A, TestKeyValue::new(1))), + Some((TestKey::B(10), TestKeyValue::new(110))), + Some((TestKey::B(20), TestKeyValue::new(211))), + Some((TestKey::B(30), TestKeyValue::new(310))), + Some((TestKey::C, TestKeyValue::new(1001))), + ] + ); + } + + #[test] + fn test_slice_container_get() { + let mut container = RustKeyStore::::with_capacity(5).unwrap(); + + for (key, value) in [ + (TestKey::A, TestKeyValue::new(1)), + (TestKey::B(10), TestKeyValue::new(110)), + (TestKey::C, TestKeyValue::new(1000)), + ] { + container.insert(key, value); + } + + assert_eq!(container.get(TestKey::A), Some(&TestKeyValue::new(1))); + assert_eq!(container.get(TestKey::B(10)), Some(&TestKeyValue::new(110))); + assert_eq!(container.get(TestKey::B(20)), None); + assert_eq!(container.get(TestKey::C), Some(&TestKeyValue::new(1000))); + } + + #[test] + fn test_slice_container_clear() { + let mut container = RustKeyStore::::with_capacity(5).unwrap(); + + for (key, value) in [ + (TestKey::A, TestKeyValue::new(1)), + (TestKey::B(10), TestKeyValue::new(110)), + (TestKey::B(20), TestKeyValue::new(210)), + (TestKey::B(30), TestKeyValue::new(310)), + (TestKey::C, TestKeyValue::new(1000)), + ] { + container.insert(key, value); + } + + assert_eq!( + container.get_key_data(), + [ + Some((TestKey::A, TestKeyValue::new(1))), + Some((TestKey::B(10), TestKeyValue::new(110))), + Some((TestKey::B(20), TestKeyValue::new(210))), + Some((TestKey::B(30), TestKeyValue::new(310))), + Some((TestKey::C, TestKeyValue::new(1000))), + ] + ); + + container.clear(); + + assert_eq!(container.get_key_data(), [None, None, None, None, None]); + } + + #[test] + fn test_slice_container_ensure_capacity() { + let mut container = RustKeyStore::::with_capacity(5).unwrap(); + + assert_eq!(container.capacity, 5); + assert_eq!(container.length, 0); + + assert_eq!(container.check_capacity(0), Ok(())); + assert_eq!(container.check_capacity(6), Err(10)); + assert_eq!(container.check_capacity(10), Err(10)); + assert_eq!(container.check_capacity(11), Err(15)); + assert_eq!(container.check_capacity(51), Err(55)); + + for (key, value) in [ + (TestKey::A, TestKeyValue::new(1)), + (TestKey::B(10), TestKeyValue::new(110)), + (TestKey::B(20), TestKeyValue::new(210)), + (TestKey::B(30), TestKeyValue::new(310)), + (TestKey::C, TestKeyValue::new(1000)), + ] { + container.insert(key, value); + } + + assert_eq!(container.check_capacity(0), Ok(())); + assert_eq!(container.check_capacity(6), Err(15)); + assert_eq!(container.check_capacity(10), Err(15)); + assert_eq!(container.check_capacity(11), Err(20)); + assert_eq!(container.check_capacity(51), Err(60)); + } + + #[test] + fn test_slice_container_removal() { 
+ let mut container = RustKeyStore::::with_capacity(5).unwrap(); + + for (key, value) in [ + (TestKey::A, TestKeyValue::new(1)), + (TestKey::B(10), TestKeyValue::new(110)), + (TestKey::B(20), TestKeyValue::new(210)), + (TestKey::B(30), TestKeyValue::new(310)), + (TestKey::C, TestKeyValue::new(1000)), + ] { + container.insert(key, value); + } + + // Remove the last element + container.remove(TestKey::C); + assert_eq!( + container.get_key_data(), + [ + Some((TestKey::A, TestKeyValue::new(1))), + Some((TestKey::B(10), TestKeyValue::new(110))), + Some((TestKey::B(20), TestKeyValue::new(210))), + Some((TestKey::B(30), TestKeyValue::new(310))), + None, + ] + ); + + // Remove the first element + container.remove(TestKey::A); + assert_eq!( + container.get_key_data(), + [ + Some((TestKey::B(10), TestKeyValue::new(110))), + Some((TestKey::B(20), TestKeyValue::new(210))), + Some((TestKey::B(30), TestKeyValue::new(310))), + None, + None + ] + ); + + // Remove a non-existing element + container.remove(TestKey::A); + assert_eq!( + container.get_key_data(), + [ + Some((TestKey::B(10), TestKeyValue::new(110))), + Some((TestKey::B(20), TestKeyValue::new(210))), + Some((TestKey::B(30), TestKeyValue::new(310))), + None, + None + ] + ); + + // Remove an element in the middle + container.remove(TestKey::B(20)); + assert_eq!( + container.get_key_data(), + [ + Some((TestKey::B(10), TestKeyValue::new(110))), + Some((TestKey::B(30), TestKeyValue::new(310))), + None, + None, + None + ] + ); + + // Remove all the remaining elements + container.remove(TestKey::B(30)); + assert_eq!( + container.get_key_data(), + [ + Some((TestKey::B(10), TestKeyValue::new(110))), + None, + None, + None, + None + ] + ); + container.remove(TestKey::B(10)); + assert_eq!(container.get_key_data(), [None, None, None, None, None]); + + // Remove from an empty container + container.remove(TestKey::B(10)); + assert_eq!(container.get_key_data(), [None, None, None, None, None]); + } + + #[test] + fn test_slice_container_retain_removes_one() { + let mut container = RustKeyStore::::with_capacity(5).unwrap(); + + for (key, value) in [ + (TestKey::A, TestKeyValue::new(1)), + (TestKey::B(10), TestKeyValue::new(110)), + (TestKey::B(20), TestKeyValue::new(210)), + (TestKey::B(30), TestKeyValue::new(310)), + (TestKey::C, TestKeyValue::new(1000)), + ] { + container.insert(key, value); + } + + // Remove the last element + container.retain(|k| k != TestKey::C); + assert_eq!( + container.get_key_data(), + [ + Some((TestKey::A, TestKeyValue::new(1))), + Some((TestKey::B(10), TestKeyValue::new(110))), + Some((TestKey::B(20), TestKeyValue::new(210))), + Some((TestKey::B(30), TestKeyValue::new(310))), + None, + ] + ); + + // Remove the first element + container.retain(|k| k != TestKey::A); + assert_eq!( + container.get_key_data(), + [ + Some((TestKey::B(10), TestKeyValue::new(110))), + Some((TestKey::B(20), TestKeyValue::new(210))), + Some((TestKey::B(30), TestKeyValue::new(310))), + None, + None + ] + ); + + // Remove a non-existing element + container.retain(|k| k != TestKey::A); + assert_eq!( + container.get_key_data(), + [ + Some((TestKey::B(10), TestKeyValue::new(110))), + Some((TestKey::B(20), TestKeyValue::new(210))), + Some((TestKey::B(30), TestKeyValue::new(310))), + None, + None + ] + ); + + // Remove an element in the middle + container.retain(|k| k != TestKey::B(20)); + assert_eq!( + container.get_key_data(), + [ + Some((TestKey::B(10), TestKeyValue::new(110))), + Some((TestKey::B(30), TestKeyValue::new(310))), + None, + None, + None + ] + ); + + // 
Remove all the remaining elements + container.retain(|k| k != TestKey::B(30)); + assert_eq!( + container.get_key_data(), + [ + Some((TestKey::B(10), TestKeyValue::new(110))), + None, + None, + None, + None + ] + ); + container.retain(|k| k != TestKey::B(10)); + assert_eq!(container.get_key_data(), [None, None, None, None, None]); + + // Remove from an empty container + container.retain(|k| k != TestKey::B(10)); + assert_eq!(container.get_key_data(), [None, None, None, None, None]); + } + + #[test] + fn test_slice_container_retain_removes_none() { + let mut container = RustKeyStore::::with_capacity(5).unwrap(); + + for (key, value) in [ + (TestKey::A, TestKeyValue::new(1)), + (TestKey::B(10), TestKeyValue::new(110)), + (TestKey::B(20), TestKeyValue::new(210)), + (TestKey::B(30), TestKeyValue::new(310)), + (TestKey::C, TestKeyValue::new(1000)), + ] { + container.insert(key, value); + } + + container.retain(|_k| true); + assert_eq!( + container.get_key_data(), + [ + Some((TestKey::A, TestKeyValue::new(1))), + Some((TestKey::B(10), TestKeyValue::new(110))), + Some((TestKey::B(20), TestKeyValue::new(210))), + Some((TestKey::B(30), TestKeyValue::new(310))), + Some((TestKey::C, TestKeyValue::new(1000))), + ] + ); + } + + #[test] + fn test_slice_container_retain_removes_some() { + let mut container = RustKeyStore::::with_capacity(5).unwrap(); + + for (key, value) in [ + (TestKey::A, TestKeyValue::new(1)), + (TestKey::B(10), TestKeyValue::new(110)), + (TestKey::B(20), TestKeyValue::new(210)), + (TestKey::B(30), TestKeyValue::new(310)), + (TestKey::C, TestKeyValue::new(1000)), + ] { + container.insert(key, value); + } + + container.retain(|k| matches!(k, TestKey::A | TestKey::B(20) | TestKey::C)); + assert_eq!( + container.get_key_data(), + [ + Some((TestKey::A, TestKeyValue::new(1))), + Some((TestKey::B(20), TestKeyValue::new(210))), + Some((TestKey::C, TestKeyValue::new(1000))), + None, + None, + ] + ); + } + + #[test] + fn test_slice_container_retain_removes_all() { + let mut container = RustKeyStore::::with_capacity(5).unwrap(); + + for (key, value) in [ + (TestKey::A, TestKeyValue::new(1)), + (TestKey::B(10), TestKeyValue::new(110)), + (TestKey::B(20), TestKeyValue::new(210)), + (TestKey::B(30), TestKeyValue::new(310)), + (TestKey::C, TestKeyValue::new(1000)), + ] { + container.insert(key, value); + } + + container.retain(|_k| false); + assert_eq!(container.get_key_data(), [None, None, None, None, None]); + } +} diff --git a/crates/bitwarden-crypto/src/service/mod.rs b/crates/bitwarden-crypto/src/service/mod.rs new file mode 100644 index 000000000..6d94e29bc --- /dev/null +++ b/crates/bitwarden-crypto/src/service/mod.rs @@ -0,0 +1,197 @@ +use std::sync::{Arc, RwLock}; + +use crate::{AsymmetricKeyRef, Decryptable, Encryptable, KeyRef, SymmetricKeyRef, UsesKey}; + +mod context; + +mod key_store; + +use context::ReadWriteGlobal; +pub use context::{CryptoServiceContext, ReadOnlyGlobal}; +pub use key_store::create_key_store; +use key_store::KeyStore; + +#[derive(Clone)] +pub struct CryptoService { + // We use an Arc<> to make it easier to pass this service around, as we can + // clone it instead of passing references + key_stores: Arc>>, +} + +impl std::fmt::Debug + for CryptoService +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("CryptoService").finish() + } +} + +pub struct Keys { + symmetric_keys: Box>, + asymmetric_keys: Box>, +} + +impl + CryptoService +{ + #[allow(clippy::new_without_default)] + pub fn new() -> Self { + Self { + key_stores: 
Arc::new(RwLock::new(Keys { + symmetric_keys: create_key_store(), + asymmetric_keys: create_key_store(), + })), + } + } + + pub fn clear(&self) { + let mut keys = self.key_stores.write().expect("RwLock is poisoned"); + keys.symmetric_keys.clear(); + keys.asymmetric_keys.clear(); + } + + /// Initiate an encryption/decryption context. This context will have read only access to the + /// global keys, and will have its own local key stores with read/write access. This + /// context-local store will be cleared up when the context is dropped. + /// + /// This is an advanced API, use with care. Prefer to instead use + /// `encrypt`/`decrypt`/`encrypt_list`/`decrypt_list` methods. + /// + /// One of the pitfalls of the current implementations is that keys stored in the context-local + /// store only get cleared automatically when dropped, and not between operations. This + /// means that if you are using the same context for multiple operations, you may want to + /// clear it manually between them. + pub fn context(&'_ self) -> CryptoServiceContext<'_, SymmKeyRef, AsymmKeyRef> { + CryptoServiceContext { + global: ReadOnlyGlobal(self.key_stores.read().expect("RwLock is poisoned")), + local_symmetric_keys: create_key_store(), + local_asymmetric_keys: create_key_store(), + _phantom: std::marker::PhantomData, + } + } + + /// Initiate an encryption/decryption context. This context will have MUTABLE access to the + /// global keys, and will have its own local key stores with read/write access. This + /// context-local store will be cleared up when the context is dropped. + /// + /// This is an advanced API, use with care and ONLY when needing to modify the global keys. + /// + /// The same pitfalls as `context` apply here, but with the added risk of accidentally + /// modifying the global keys and leaving the service in an inconsistent state. + /// + /// TODO: We should work towards making this pub(crate) + pub fn context_mut( + &'_ self, + ) -> CryptoServiceContext< + '_, + SymmKeyRef, + AsymmKeyRef, + ReadWriteGlobal<'_, SymmKeyRef, AsymmKeyRef>, + > { + CryptoServiceContext { + global: ReadWriteGlobal(self.key_stores.write().expect("RwLock is poisoned")), + local_symmetric_keys: create_key_store(), + local_asymmetric_keys: create_key_store(), + _phantom: std::marker::PhantomData, + } + } + + // These are just convenience methods to avoid having to call `context` every time + pub fn decrypt< + Key: KeyRef, + Data: Decryptable + UsesKey, + Output, + >( + &self, + data: &Data, + ) -> Result { + let key = data.uses_key(); + data.decrypt(&mut self.context(), key) + } + + pub fn encrypt< + Key: KeyRef, + Data: Encryptable + UsesKey, + Output, + >( + &self, + data: Data, + ) -> Result { + let key = data.uses_key(); + data.encrypt(&mut self.context(), key) + } + + pub fn decrypt_list< + Key: KeyRef, + Data: Decryptable + UsesKey + Send + Sync, + Output: Send + Sync, + >( + &self, + data: &[Data], + ) -> Result, crate::CryptoError> { + use rayon::prelude::*; + + // We want to split all the data between available threads, but at the + // same time we don't want to split it too much if the amount of data is small. + + // In this case, the minimum chunk size is 50. 
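+        // For example, with 8 worker threads: 1,000 items produce chunks of
+        // 1 + 1000 / 8 = 126 items, while 100 items fall back to the minimum of 50,
+        // spawning only two chunks.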
+ let chunk_size = usize::max(1 + data.len() / rayon::current_num_threads(), 50); + + let res: Result, _> = data + .par_chunks(chunk_size) + .map(|chunk| { + let mut ctx = self.context(); + + let mut result = Vec::with_capacity(chunk.len()); + + for item in chunk { + let key = item.uses_key(); + result.push(item.decrypt(&mut ctx, key)); + ctx.clear(); + } + + result + }) + .flatten() + .collect(); + + res + } + + pub fn encrypt_list< + Key: KeyRef, + Data: Encryptable + UsesKey + Send + Sync, + Output: Send + Sync, + >( + &self, + data: &[Data], + ) -> Result, crate::CryptoError> { + use rayon::prelude::*; + + // We want to split all the data between available threads, but at the + // same time we don't want to split it too much if the amount of data is small. + + // In this case, the minimum chunk size is 50. + let chunk_size = usize::max(1 + data.len() / rayon::current_num_threads(), 50); + + let res: Result, _> = data + .par_chunks(chunk_size) + .map(|chunk| { + let mut ctx = self.context(); + + let mut result = Vec::with_capacity(chunk.len()); + + for item in chunk { + let key = item.uses_key(); + result.push(item.encrypt(&mut ctx, key)); + ctx.clear(); + } + + result + }) + .flatten() + .collect(); + + res + } +}
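A rough end-to-end sketch of how the pieces introduced here are meant to compose. The
`ExampleSymmKey`/`ExampleAsymmKey` enums and the surrounding function are illustrative
assumptions, not part of this diff:

use bitwarden_crypto::{
    key_refs, service::CryptoService, CryptoError, Decryptable, Encryptable, EncString,
};

key_refs! {
    #[symmetric]
    pub enum ExampleSymmKey {
        User,
    }
    #[asymmetric]
    pub enum ExampleAsymmKey {
        UserPrivate,
    }
}

fn roundtrip() -> Result<(), CryptoError> {
    let service: CryptoService<ExampleSymmKey, ExampleAsymmKey> = CryptoService::new();

    // Generating a global key requires mutable access to the shared store.
    service
        .context_mut()
        .generate_symmetric_key(ExampleSymmKey::User)?;

    // Encryption and decryption only need a read-only context.
    let mut ctx = service.context();
    let enc: EncString = "secret".encrypt(&mut ctx, ExampleSymmKey::User)?;
    let dec: String = enc.decrypt(&mut ctx, ExampleSymmKey::User)?;
    assert_eq!(dec, "secret");
    Ok(())
}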