diff --git a/CHANGELOG.md b/CHANGELOG.md index 38ba4ac1..7a4605b8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## Unreleased +- Add missing implementation to support Client Certificate Authorization (#135) + ## 0.17.0 - 2024-01-06 - Update to stable rust diff --git a/Cargo.toml b/Cargo.toml index 344fd729..9869aff2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,13 +13,19 @@ exclude = [".github"] [dependencies] atomic-polyfill = "1" -p256 = { version = "0.13.2", default-features = false, features = [ "ecdh", "arithmetic" ] } +p256 = { version = "0.13.2", default-features = false, features = [ + "ecdh", + "ecdsa", + "sha256", +] } rand_core = { version = "0.6.3", default-features = false } hkdf = "0.12.3" hmac = "0.12.1" sha2 = { version = "0.10.2", default-features = false } aes-gcm = { version = "0.10.1", default-features = false, features = ["aes"] } -digest = { version = "0.10.3", default-features = false, features = ["core-api"] } +digest = { version = "0.10.3", default-features = false, features = [ + "core-api", +] } typenum = { version = "1.15.0", default-features = false } heapless = { version = "0.8", default-features = false } heapless_typenum = { package = "heapless", version = "0.6", default-features = false } @@ -28,13 +34,15 @@ embedded-io-async = "0.6" embedded-io-adapters = { version = "0.6", optional = true } generic-array = { version = "0.14", default-features = false } webpki = { package = "rustls-webpki", version = "0.101.7", default-features = false, optional = true } +signature = { version = "2.2", default-features = false } +ecdsa = { version = "0.16.9", default-features = false } # Logging alternatives log = { version = "0.4", optional = true } defmt = { version = "0.3", optional = true } [dev-dependencies] -env_logger = "0.10" +env_logger = "0.11" tokio = { version = "1", features = ["full"] } mio = { version = "0.8.3", features = ["os-poll", "net"] } rustls = "0.21.6" diff --git a/examples/blocking/src/main.rs b/examples/blocking/src/main.rs index a429fa45..1afcaf9d 100644 --- a/examples/blocking/src/main.rs +++ b/examples/blocking/src/main.rs @@ -6,6 +6,27 @@ use rand::rngs::OsRng; use std::net::TcpStream; use std::time::SystemTime; +struct Provider { + rng: OsRng, + verifier: CertVerifier, +} + +impl CryptoProvider for Provider { + type CipherSuite = Aes128GcmSha256; + + type Signature = &'static [u8]; + + fn rng(&mut self) -> impl embedded_tls::CryptoRngCore { + &mut self.rng + } + + fn verifier( + &mut self, + ) -> Result<&mut impl TlsVerifier, embedded_tls::TlsError> { + Ok(&mut self.verifier) + } +} + fn main() { env_logger::init(); let stream = TcpStream::connect("127.0.0.1:12345").expect("error connecting to server"); @@ -14,15 +35,18 @@ fn main() { let mut read_record_buffer = [0; 16384]; let mut write_record_buffer = [0; 16384]; let config = TlsConfig::new().with_server_name("localhost"); - let mut tls: TlsConnection, Aes128GcmSha256> = TlsConnection::new( + let mut tls = TlsConnection::new( FromStd::new(stream), &mut read_record_buffer, &mut write_record_buffer, ); - let mut rng = OsRng; - tls.open::>(TlsContext::new( - &config, &mut rng, + tls.open(TlsContext::new( + &config, + Provider { + rng: OsRng, + verifier: CertVerifier::new(), + }, )) .expect("error establishing TLS connection"); diff --git a/examples/embassy/src/main.rs b/examples/embassy/src/main.rs index 535efd77..8daee608 100644 --- a/examples/embassy/src/main.rs +++ b/examples/embassy/src/main.rs @@ -5,7 +5,7 @@ use embassy_net::{Config, Ipv4Address, Ipv4Cidr, Stack, StackResources}; use embassy_net_tuntap::TunTapDevice; use embassy_time::Duration; use embedded_io_async::Write; -use embedded_tls::{Aes128GcmSha256, NoVerify, TlsConfig, TlsConnection, TlsContext}; +use embedded_tls::{Aes128GcmSha256, TlsConfig, TlsConnection, TlsContext, UnsecureProvider}; use heapless::Vec; use log::*; use rand::{rngs::OsRng, RngCore}; @@ -57,7 +57,7 @@ async fn main_task(spawner: Spawner) { device, config, RESOURCES.init(StackResources::<3>::new()), - seed + seed, )); // Launch network task @@ -81,14 +81,15 @@ async fn main_task(spawner: Spawner) { let mut read_record_buffer = [0; 16384]; let mut write_record_buffer = [0; 16384]; - let mut rng = OsRng; let config = TlsConfig::new().with_server_name("example.com"); - let mut tls: TlsConnection = - TlsConnection::new(socket, &mut read_record_buffer, &mut write_record_buffer); - - tls.open::(TlsContext::new(&config, &mut rng)) - .await - .expect("error establishing TLS connection"); + let mut tls = TlsConnection::new(socket, &mut read_record_buffer, &mut write_record_buffer); + + tls.open(TlsContext::new( + &config, + UnsecureProvider::new::(OsRng), + )) + .await + .expect("error establishing TLS connection"); tls.write_all(b"ping").await.expect("error writing data"); tls.flush().await.expect("error flushing data"); diff --git a/examples/nrf52/src/main.rs b/examples/nrf52/src/main.rs index b46de457..742c06b4 100644 --- a/examples/nrf52/src/main.rs +++ b/examples/nrf52/src/main.rs @@ -17,16 +17,18 @@ use hal::rng::Rng; #[entry] fn main() -> ! { let p = hal::pac::Peripherals::take().unwrap(); - let mut rng = Rng::new(p.RNG); + let rng = Rng::new(p.RNG); defmt::info!("Connected"); let mut read_record_buffer = [0; 16384]; let mut write_record_buffer = [0; 16384]; let config = TlsConfig::new().with_server_name("example.com"); - let mut tls: TlsConnection = - TlsConnection::new(Dummy {}, &mut read_record_buffer, &mut write_record_buffer); + let mut tls = TlsConnection::new(Dummy {}, &mut read_record_buffer, &mut write_record_buffer); - tls.open::(TlsContext::new(&config, &mut rng)) - .expect("error establishing TLS connection"); + tls.open(TlsContext::new( + &config, + UnsecureProvider::new::(rng), + )) + .expect("error establishing TLS connection"); tls.write_all(b"ping").expect("error writing data"); tls.flush().expect("error flushing data"); diff --git a/examples/tokio-psk/src/main.rs b/examples/tokio-psk/src/main.rs index e188cf95..20d97222 100644 --- a/examples/tokio-psk/src/main.rs +++ b/examples/tokio-psk/src/main.rs @@ -19,16 +19,18 @@ async fn main() -> Result<(), Box> { let config = TlsConfig::new() .with_server_name("localhost") .with_psk(&[0xaa, 0xbb, 0xcc, 0xdd], &[b"vader"]); - let mut rng = OsRng; - let mut tls: TlsConnection, Aes128GcmSha256> = TlsConnection::new( + let mut tls = TlsConnection::new( FromTokio::new(stream), &mut read_record_buffer, &mut write_record_buffer, ); - tls.open::(TlsContext::new(&config, &mut rng)) - .await - .expect("error establishing TLS connection"); + tls.open(TlsContext::new( + &config, + UnsecureProvider::new::(OsRng), + )) + .await + .expect("error establishing TLS connection"); tls.write_all(b"ping").await.expect("error writing data"); tls.flush().await.expect("error flushing data"); diff --git a/examples/tokio/src/main.rs b/examples/tokio/src/main.rs index 46bb68fc..1f101dfe 100644 --- a/examples/tokio/src/main.rs +++ b/examples/tokio/src/main.rs @@ -17,16 +17,18 @@ async fn main() -> Result<(), Box> { let mut read_record_buffer = [0; 16384]; let mut write_record_buffer = [0; 16384]; let config = TlsConfig::new().with_server_name("localhost"); - let mut rng = OsRng; - let mut tls: TlsConnection, Aes128GcmSha256> = TlsConnection::new( + let mut tls = TlsConnection::new( FromTokio::new(stream), &mut read_record_buffer, &mut write_record_buffer, ); - tls.open::(TlsContext::new(&config, &mut rng)) - .await - .expect("error establishing TLS connection"); + tls.open(TlsContext::new( + &config, + UnsecureProvider::new::(OsRng), + )) + .await + .expect("error establishing TLS connection"); tls.write_all(b"ping").await.expect("error writing data"); tls.flush().await.expect("error flushing data"); diff --git a/src/asynch.rs b/src/asynch.rs index 92b344b6..5d05dc83 100644 --- a/src/asynch.rs +++ b/src/asynch.rs @@ -12,7 +12,6 @@ use crate::TlsError; use embedded_io::Error as _; use embedded_io::ErrorType; use embedded_io_async::{BufRead, Read as AsyncRead, Write as AsyncWrite}; -use rand_core::{CryptoRng, RngCore}; pub use crate::config::*; #[cfg(feature = "std")] @@ -71,16 +70,20 @@ where /// /// Returns an error if the handshake does not proceed. If an error occurs, the connection /// instance must be recreated. - pub async fn open<'v, RNG, Verifier>( + pub async fn open<'v, Provider>( &mut self, - context: TlsContext<'v, CipherSuite, RNG>, + mut context: TlsContext<'v, Provider>, ) -> Result<(), TlsError> where - RNG: CryptoRng + RngCore, - Verifier: TlsVerifier<'v, CipherSuite>, + Provider: CryptoProvider, { - let mut handshake: Handshake = - Handshake::new(Verifier::new(context.config.server_name)); + let mut handshake: Handshake = Handshake::new(); + if let (Ok(verifier), Some(server_name)) = ( + context.crypto_provider.verifier(), + context.config.server_name, + ) { + verifier.set_hostname_verification(server_name)?; + } let mut state = State::ClientHello; while state != State::ApplicationData { @@ -92,7 +95,7 @@ where &mut self.record_write_buf, &mut self.key_schedule, context.config, - context.rng, + &mut context.crypto_provider, ) .await?; trace!("State {:?} -> {:?}", state, next_state); diff --git a/src/blocking.rs b/src/blocking.rs index 42c7326d..2e6fc83d 100644 --- a/src/blocking.rs +++ b/src/blocking.rs @@ -10,7 +10,6 @@ use crate::split::{SplitState, SplitStateContainer}; use crate::write_buffer::WriteBuffer; use embedded_io::Error as _; use embedded_io::{BufRead, ErrorType, Read, Write}; -use rand_core::{CryptoRng, RngCore}; pub use crate::config::*; #[cfg(feature = "std")] @@ -70,16 +69,20 @@ where /// /// Returns an error if the handshake does not proceed. If an error occurs, the connection /// instance must be recreated. - pub fn open<'v, RNG, Verifier>( + pub fn open<'v, Provider>( &mut self, - context: TlsContext<'v, CipherSuite, RNG>, + mut context: TlsContext<'v, Provider>, ) -> Result<(), TlsError> where - RNG: CryptoRng + RngCore, - Verifier: TlsVerifier<'v, CipherSuite>, + Provider: CryptoProvider, { - let mut handshake: Handshake = - Handshake::new(Verifier::new(context.config.server_name)); + let mut handshake: Handshake = Handshake::new(); + if let (Ok(verifier), Some(server_name)) = ( + context.crypto_provider.verifier(), + context.config.server_name, + ) { + verifier.set_hostname_verification(server_name)?; + } let mut state = State::ClientHello; while state != State::ApplicationData { @@ -90,7 +93,7 @@ where &mut self.record_write_buf, &mut self.key_schedule, context.config, - context.rng, + &mut context.crypto_provider, )?; trace!("State {:?} -> {:?}", state, next_state); state = next_state; diff --git a/src/config.rs b/src/config.rs index 0ed4402e..e147c179 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,17 +1,21 @@ +use core::marker::PhantomData; + use crate::cipher_suites::CipherSuite; use crate::extensions::extension_data::signature_algorithms::SignatureScheme; use crate::extensions::extension_data::supported_groups::NamedGroup; use crate::handshake::certificate::CertificateRef; -pub use crate::handshake::certificate_verify::CertificateVerify; +pub use crate::handshake::certificate_verify::CertificateVerifyRef; use crate::TlsError; use aes_gcm::{AeadInPlace, Aes128Gcm, Aes256Gcm, KeyInit}; -use core::marker::PhantomData; use digest::core_api::BlockSizeUser; use digest::{Digest, FixedOutput, OutputSizeUser, Reset}; +use ecdsa::elliptic_curve::SecretKey; use generic_array::ArrayLength; use heapless::Vec; -use rand_core::{CryptoRng, RngCore}; +use p256::ecdsa::SigningKey; +use rand_core::CryptoRngCore; pub use sha2::Sha256; + pub use sha2::Sha384; use typenum::{Sum, U10, U12, U16, U32}; @@ -66,16 +70,12 @@ impl TlsCipherSuite for Aes256GcmSha384 { /// The verifier is responsible for verifying certificates and signatures. Since certificate verification is /// an expensive process, this trait allows clients to choose how much verification should take place, /// and also to skip the verification if the server is verified through other means (I.e. a pre-shared key). -pub trait TlsVerifier<'a, CipherSuite> +pub trait TlsVerifier where CipherSuite: TlsCipherSuite, { - /// Create a new verification instance. - /// - /// This method is called for every TLS handshake. - /// /// Host verification is enabled by passing a server hostname. - fn new(host: Option<&'a str>) -> Self; + fn set_hostname_verification(&mut self, hostname: &str) -> Result<(), crate::TlsError>; /// Verify a certificate. /// @@ -92,17 +92,17 @@ where /// Verify the certificate signature. /// /// The signature verification uses the transcript and certificate provided earlier to decode the provided signature. - fn verify_signature(&mut self, verify: CertificateVerify) -> Result<(), crate::TlsError>; + fn verify_signature(&mut self, verify: CertificateVerifyRef) -> Result<(), crate::TlsError>; } pub struct NoVerify; -impl<'a, CipherSuite> TlsVerifier<'a, CipherSuite> for NoVerify +impl TlsVerifier for NoVerify where CipherSuite: TlsCipherSuite, { - fn new(_host: Option<&str>) -> Self { - Self + fn set_hostname_verification(&mut self, _hostname: &str) -> Result<(), crate::TlsError> { + Ok(()) } fn verify_certificate( @@ -114,26 +114,22 @@ where Ok(()) } - fn verify_signature(&mut self, _verify: CertificateVerify) -> Result<(), crate::TlsError> { + fn verify_signature(&mut self, _verify: CertificateVerifyRef) -> Result<(), crate::TlsError> { Ok(()) } } #[derive(Debug)] #[cfg_attr(feature = "defmt", derive(defmt::Format))] -pub struct TlsConfig<'a, CipherSuite> -where - CipherSuite: TlsCipherSuite, -{ - //pub(crate) cipher_suites: Vec, +pub struct TlsConfig<'a> { pub(crate) server_name: Option<&'a str>, pub(crate) psk: Option<(&'a [u8], Vec<&'a [u8], 4>)>, - pub(crate) cipher_suite: PhantomData, pub(crate) signature_schemes: Vec, pub(crate) named_groups: Vec, pub(crate) max_fragment_length: Option, pub(crate) ca: Option>, pub(crate) cert: Option>, + pub(crate) priv_key: &'a [u8], } pub trait TlsClock { @@ -148,35 +144,121 @@ impl TlsClock for NoClock { } } +pub trait CryptoProvider { + type CipherSuite: TlsCipherSuite; + type Signature: AsRef<[u8]>; + + fn rng(&mut self) -> impl CryptoRngCore; + + fn verifier(&mut self) -> Result<&mut impl TlsVerifier, crate::TlsError> { + Err::<&mut NoVerify, _>(crate::TlsError::Unimplemented) + } + + /// Decode and validate a private signing key from `key_der`. + fn signer( + &mut self, + _key_der: &[u8], + ) -> Result<(impl signature::SignerMut, SignatureScheme), crate::TlsError> + { + Err::<(NoSign, _), crate::TlsError>(crate::TlsError::Unimplemented) + } +} + +impl CryptoProvider for &mut T { + type CipherSuite = T::CipherSuite; + + type Signature = T::Signature; + + fn rng(&mut self) -> impl CryptoRngCore { + T::rng(self) + } + + fn verifier(&mut self) -> Result<&mut impl TlsVerifier, crate::TlsError> { + T::verifier(self) + } + + fn signer( + &mut self, + key_der: &[u8], + ) -> Result<(impl signature::SignerMut, SignatureScheme), crate::TlsError> + { + T::signer(self, key_der) + } +} + +pub struct NoSign; + +impl signature::Signer for NoSign { + fn try_sign(&self, _msg: &[u8]) -> Result { + unimplemented!() + } +} + +pub struct UnsecureProvider { + rng: RNG, + _marker: PhantomData, +} + +impl UnsecureProvider<(), RNG> { + pub fn new(rng: RNG) -> UnsecureProvider { + UnsecureProvider { + rng, + _marker: PhantomData, + } + } +} + +impl CryptoProvider + for UnsecureProvider +{ + type CipherSuite = CipherSuite; + type Signature = p256::ecdsa::DerSignature; + + fn rng(&mut self) -> impl CryptoRngCore { + &mut self.rng + } + + fn signer( + &mut self, + key_der: &[u8], + ) -> Result<(impl signature::SignerMut, SignatureScheme), crate::TlsError> + { + let secret_key = + SecretKey::from_sec1_der(key_der).map_err(|_| TlsError::InvalidPrivateKey)?; + + Ok(( + SigningKey::from(&secret_key), + SignatureScheme::EcdsaSecp256r1Sha256, + )) + } +} + #[derive(Debug)] #[cfg_attr(feature = "defmt", derive(defmt::Format))] -pub struct TlsContext<'a, CipherSuite, RNG> +pub struct TlsContext<'a, Provider> where - CipherSuite: TlsCipherSuite, - RNG: CryptoRng + RngCore + 'a, + Provider: CryptoProvider, { - pub(crate) config: &'a TlsConfig<'a, CipherSuite>, - pub(crate) rng: &'a mut RNG, + pub(crate) config: &'a TlsConfig<'a>, + pub(crate) crypto_provider: Provider, } -impl<'a, CipherSuite, RNG> TlsContext<'a, CipherSuite, RNG> +impl<'a, Provider> TlsContext<'a, Provider> where - CipherSuite: TlsCipherSuite, - RNG: CryptoRng + RngCore + 'a, + Provider: CryptoProvider, { - /// Create a new context with a given config and random number generator reference. - pub fn new(config: &'a TlsConfig<'a, CipherSuite>, rng: &'a mut RNG) -> Self { - Self { config, rng } + /// Create a new context with a given config and a crypto provider. + pub fn new(config: &'a TlsConfig<'a>, crypto_provider: Provider) -> Self { + Self { + config, + crypto_provider, + } } } -impl<'a, CipherSuite> TlsConfig<'a, CipherSuite> -where - CipherSuite: TlsCipherSuite, -{ +impl<'a> TlsConfig<'a> { pub fn new() -> Self { let mut config = Self { - cipher_suite: PhantomData, signature_schemes: Vec::new(), named_groups: Vec::new(), max_fragment_length: None, @@ -184,10 +266,9 @@ where server_name: None, ca: None, cert: None, + priv_key: &[], }; - //config.cipher_suites.push(CipherSuite::TlsAes128GcmSha256); - // if cfg!(feature = "alloc") { config = config.enable_rsa_signatures(); } @@ -280,6 +361,11 @@ where self } + pub fn with_priv_key(mut self, priv_key: &'a [u8]) -> Self { + self.priv_key = priv_key; + self + } + pub fn with_psk(mut self, psk: &'a [u8], identities: &[&'a [u8]]) -> Self { // TODO: Remove potential panic self.psk = Some((psk, unwrap!(Vec::from_slice(identities).ok()))); @@ -287,10 +373,7 @@ where } } -impl<'a, CipherSuite> Default for TlsConfig<'a, CipherSuite> -where - CipherSuite: TlsCipherSuite, -{ +impl<'a> Default for TlsConfig<'a> { fn default() -> Self { TlsConfig::new() } diff --git a/src/connection.rs b/src/connection.rs index 419ffcd2..5a970cf6 100644 --- a/src/connection.rs +++ b/src/connection.rs @@ -1,38 +1,27 @@ -use crate::config::{TlsCipherSuite, TlsConfig, TlsVerifier}; +use crate::config::{TlsCipherSuite, TlsConfig}; use crate::handshake::{ClientHandshake, ServerHandshake}; use crate::key_schedule::{KeySchedule, ReadKeySchedule, WriteKeySchedule}; use crate::record::{ClientRecord, ServerRecord}; use crate::record_reader::RecordReader; use crate::write_buffer::WriteBuffer; -use crate::TlsError; use crate::{ alert::*, handshake::{certificate::CertificateRef, certificate_request::CertificateRequest}, }; +use crate::{CertificateVerify, CryptoProvider, TlsError, TlsVerifier}; use core::fmt::Debug; +use digest::Digest; use embedded_io::Error as _; use embedded_io::{Read as BlockingRead, Write as BlockingWrite}; use embedded_io_async::{Read as AsyncRead, Write as AsyncWrite}; -use rand_core::{CryptoRng, RngCore}; use crate::application_data::ApplicationData; -// use crate::handshake::certificate_request::CertificateRequest; -// use crate::handshake::certificate_verify::CertificateVerify; -// use crate::handshake::encrypted_extensions::EncryptedExtensions; -// use crate::handshake::finished::Finished; -// use crate::handshake::new_session_ticket::NewSessionTicket; -// use crate::handshake::server_hello::ServerHello; use crate::buffer::CryptoBuffer; use digest::generic_array::typenum::Unsigned; use p256::ecdh::EphemeralSecret; +use signature::SignerMut; use crate::content_types::ContentType; -// use crate::handshake::certificate_request::CertificateRequest; -// use crate::handshake::certificate_verify::CertificateVerify; -// use crate::handshake::encrypted_extensions::EncryptedExtensions; -// use crate::handshake::finished::Finished; -// use crate::handshake::new_session_ticket::NewSessionTicket; -// use crate::handshake::server_hello::ServerHello; use crate::parse_buffer::ParseBuffer; use aes_gcm::aead::{AeadCore, AeadInPlace, KeyInit}; @@ -139,27 +128,24 @@ where .map_err(|_| TlsError::InvalidApplicationData) } -pub struct Handshake +pub struct Handshake where CipherSuite: TlsCipherSuite, { traffic_hash: Option, secret: Option, certificate_request: Option, - verifier: Verifier, } -impl<'v, CipherSuite, Verifier> Handshake +impl<'v, CipherSuite> Handshake where CipherSuite: TlsCipherSuite, - Verifier: TlsVerifier<'v, CipherSuite>, { - pub fn new(verifier: Verifier) -> Handshake { + pub fn new() -> Handshake { Handshake { traffic_hash: None, secret: None, certificate_request: None, - verifier, } } } @@ -171,31 +157,31 @@ pub enum State { ServerHello, ServerVerify, ClientCert, + ClientCertVerify, ClientFinished, ApplicationData, } impl<'a> State { #[allow(clippy::too_many_arguments)] - pub async fn process<'v, Transport, CipherSuite, RNG, Verifier>( + pub async fn process<'v, Transport, Provider>( self, transport: &mut Transport, - handshake: &mut Handshake, - record_reader: &mut RecordReader<'_, CipherSuite>, + handshake: &mut Handshake, + record_reader: &mut RecordReader<'_, Provider::CipherSuite>, tx_buf: &mut WriteBuffer<'_>, - key_schedule: &mut KeySchedule, - config: &TlsConfig<'a, CipherSuite>, - rng: &mut RNG, + key_schedule: &mut KeySchedule, + config: &TlsConfig<'a>, + crypto_provider: &mut Provider, ) -> Result where Transport: AsyncRead + AsyncWrite + 'a, - RNG: CryptoRng + RngCore + 'a, - CipherSuite: TlsCipherSuite, - Verifier: TlsVerifier<'v, CipherSuite>, + Provider: CryptoProvider, { match self { State::ClientHello => { - let (state, tx) = client_hello(key_schedule, config, rng, tx_buf, handshake)?; + let (state, tx) = + client_hello(key_schedule, config, crypto_provider, tx_buf, handshake)?; respond(tx, transport, key_schedule).await?; @@ -215,7 +201,8 @@ impl<'a> State { .read(transport, key_schedule.read_state()) .await?; - let result = process_server_verify(handshake, key_schedule, config, record); + let result = + process_server_verify(handshake, key_schedule, config, crypto_provider, record); handle_processing_error(result, transport, key_schedule, tx_buf).await } @@ -226,6 +213,14 @@ impl<'a> State { Ok(state) } + State::ClientCertVerify => { + let (result, tx) = + client_cert_verify(key_schedule, config, crypto_provider, tx_buf)?; + + respond(tx, transport, key_schedule).await?; + + result + } State::ClientFinished => { let tx = client_finished(key_schedule, tx_buf)?; @@ -238,25 +233,24 @@ impl<'a> State { } #[allow(clippy::too_many_arguments)] - pub fn process_blocking<'v, Transport, CipherSuite, RNG, Verifier>( + pub fn process_blocking<'v, Transport, Provider>( self, transport: &mut Transport, - handshake: &mut Handshake, - record_reader: &mut RecordReader<'_, CipherSuite>, + handshake: &mut Handshake, + record_reader: &mut RecordReader<'_, Provider::CipherSuite>, tx_buf: &mut WriteBuffer, - key_schedule: &mut KeySchedule, - config: &TlsConfig<'a, CipherSuite>, - rng: &mut RNG, + key_schedule: &mut KeySchedule, + config: &TlsConfig<'a>, + crypto_provider: &mut Provider, ) -> Result where Transport: BlockingRead + BlockingWrite + 'a, - RNG: CryptoRng + RngCore, - CipherSuite: TlsCipherSuite + 'static, - Verifier: TlsVerifier<'v, CipherSuite>, + Provider: CryptoProvider, { match self { State::ClientHello => { - let (state, tx) = client_hello(key_schedule, config, rng, tx_buf, handshake)?; + let (state, tx) = + client_hello(key_schedule, config, crypto_provider, tx_buf, handshake)?; respond_blocking(tx, transport, key_schedule)?; @@ -272,7 +266,8 @@ impl<'a> State { State::ServerVerify => { let record = record_reader.read_blocking(transport, key_schedule.read_state())?; - let result = process_server_verify(handshake, key_schedule, config, record); + let result = + process_server_verify(handshake, key_schedule, config, crypto_provider, record); handle_processing_error_blocking(result, transport, key_schedule, tx_buf) } @@ -283,6 +278,14 @@ impl<'a> State { Ok(state) } + State::ClientCertVerify => { + let (result, tx) = + client_cert_verify(key_schedule, config, crypto_provider, tx_buf)?; + + respond_blocking(tx, transport, key_schedule)?; + + result + } State::ClientFinished => { let tx = client_finished(key_schedule, tx_buf)?; @@ -383,20 +386,19 @@ where Ok(()) } -fn client_hello<'r, CipherSuite, RNG, Verifier>( - key_schedule: &mut KeySchedule, - config: &TlsConfig, - rng: &mut RNG, +fn client_hello<'r, Provider>( + key_schedule: &mut KeySchedule, + config: &TlsConfig, + crypto_provider: &mut Provider, tx_buf: &'r mut WriteBuffer, - handshake: &mut Handshake, + handshake: &mut Handshake, ) -> Result<(State, &'r [u8]), TlsError> where - RNG: CryptoRng + RngCore, - CipherSuite: TlsCipherSuite, + Provider: CryptoProvider, { key_schedule.initialize_early_secret(config.psk.as_ref().map(|p| p.0))?; let (write_key_schedule, read_key_schedule) = key_schedule.as_split(); - let client_hello = ClientRecord::client_hello(config, rng); + let client_hello = ClientRecord::client_hello(config, crypto_provider); let slice = tx_buf.write_record(&client_hello, write_key_schedule, Some(read_key_schedule))?; if let ClientRecord::Handshake(ClientHandshake::ClientHello(client_hello), _) = client_hello { @@ -407,8 +409,8 @@ where } } -fn process_server_hello( - handshake: &mut Handshake, +fn process_server_hello( + handshake: &mut Handshake, key_schedule: &mut KeySchedule, record: ServerRecord<'_, CipherSuite>, ) -> Result @@ -435,15 +437,15 @@ where } } -fn process_server_verify<'a, 'v, CipherSuite, Verifier>( - handshake: &mut Handshake, - key_schedule: &mut KeySchedule, - config: &TlsConfig<'a, CipherSuite>, - record: ServerRecord<'_, CipherSuite>, +fn process_server_verify<'a, 'v, Provider>( + handshake: &mut Handshake, + key_schedule: &mut KeySchedule, + config: &TlsConfig<'a>, + crypto_provider: &mut Provider, + record: ServerRecord<'_, Provider::CipherSuite>, ) -> Result where - CipherSuite: TlsCipherSuite, - Verifier: TlsVerifier<'v, CipherSuite>, + Provider: CryptoProvider, { let mut state = State::ServerVerify; decrypt_record(key_schedule.read_state(), record, |key_schedule, record| { @@ -453,16 +455,20 @@ where ServerHandshake::EncryptedExtensions(_) => {} ServerHandshake::Certificate(certificate) => { let transcript = key_schedule.transcript_hash(); - handshake.verifier.verify_certificate( - transcript, - &config.ca, - certificate, - )?; - debug!("Certificate verified!"); + if let Ok(verifier) = crypto_provider.verifier() { + verifier.verify_certificate(transcript, &config.ca, certificate)?; + debug!("Certificate verified!"); + } else { + debug!("Certificate verification skipped due to no verifier!"); + } } ServerHandshake::CertificateVerify(verify) => { - handshake.verifier.verify_signature(verify)?; - debug!("Signature verified!"); + if let Ok(verifier) = crypto_provider.verifier() { + verifier.verify_signature(verify)?; + debug!("Signature verified!"); + } else { + debug!("Signature verification skipped due to no verifier!"); + } } ServerHandshake::CertificateRequest(request) => { handshake.certificate_request.replace(request.try_into()?); @@ -495,10 +501,10 @@ where Ok(state) } -fn client_cert<'r, CipherSuite, Verifier>( - handshake: &mut Handshake, +fn client_cert<'r, CipherSuite>( + handshake: &mut Handshake, key_schedule: &mut KeySchedule, - config: &TlsConfig, + config: &TlsConfig, buffer: &'r mut WriteBuffer, ) -> Result<(State, &'r [u8]), TlsError> where @@ -515,9 +521,12 @@ where .request_context; let mut certificate = CertificateRef::with_context(request_context); - if let Some(cert) = &config.cert { + let next_state = if let Some(cert) = &config.cert { certificate.add(cert.into())?; - } + State::ClientCertVerify + } else { + State::ClientFinished + }; let (write_key_schedule, read_key_schedule) = key_schedule.as_split(); buffer @@ -526,7 +535,60 @@ where write_key_schedule, Some(read_key_schedule), ) - .map(|slice| (State::ClientFinished, slice)) + .map(|slice| (next_state, slice)) +} + +fn client_cert_verify<'r, Provider>( + key_schedule: &mut KeySchedule, + config: &TlsConfig, + crypto_provider: &mut Provider, + buffer: &'r mut WriteBuffer, +) -> Result<(Result, &'r [u8]), TlsError> +where + Provider: CryptoProvider, +{ + let (result, record) = match crypto_provider.signer(config.priv_key) { + Ok((mut signing_key, signature_scheme)) => { + let ctx_str = b"TLS 1.3, client CertificateVerify\x00"; + let mut msg: heapless::Vec = heapless::Vec::new(); + msg.resize(64, 0x20).map_err(|_| TlsError::EncodeError)?; + msg.extend_from_slice(ctx_str) + .map_err(|_| TlsError::EncodeError)?; + msg.extend_from_slice(&key_schedule.transcript_hash().clone().finalize()) + .map_err(|_| TlsError::EncodeError)?; + + let signature = signing_key.sign(&msg); + + let certificate_verify = CertificateVerify { + signature_scheme, + signature: heapless::Vec::from_slice(signature.as_ref()).unwrap(), + }; + + ( + Ok(State::ClientFinished), + ClientRecord::Handshake( + ClientHandshake::ClientCertVerify(certificate_verify), + true, + ), + ) + } + Err(e) => { + error!("Failed to obtain signing key: {:?}", e); + ( + Err(e), + ClientRecord::Alert( + Alert::new(AlertLevel::Warning, AlertDescription::CloseNotify), + true, + ), + ) + } + }; + + let (write_key_schedule, read_key_schedule) = key_schedule.as_split(); + + buffer + .write_record(&record, write_key_schedule, Some(read_key_schedule)) + .map(|slice| (result, slice)) } fn client_finished<'r, CipherSuite>( @@ -549,9 +611,9 @@ where ) } -fn client_finished_finalize( +fn client_finished_finalize( key_schedule: &mut KeySchedule, - handshake: &mut Handshake, + handshake: &mut Handshake, ) -> Result where CipherSuite: TlsCipherSuite, diff --git a/src/extensions/messages.rs b/src/extensions/messages.rs index 8fa025d8..c5e4e98b 100644 --- a/src/extensions/messages.rs +++ b/src/extensions/messages.rs @@ -71,7 +71,7 @@ extension_group! { extension_group! { pub enum CertificateRequestExtension<'a> { StatusRequest(Unimplemented<'a>), - SignatureAlgorithms(SignatureAlgorithms<4>), + SignatureAlgorithms(SignatureAlgorithms<16>), SignedCertificateTimestamp(Unimplemented<'a>), CertificateAuthorities(Unimplemented<'a>), OidFilters(Unimplemented<'a>), diff --git a/src/handshake/certificate.rs b/src/handshake/certificate.rs index 862fab11..e3bd9e0c 100644 --- a/src/handshake/certificate.rs +++ b/src/handshake/certificate.rs @@ -49,15 +49,14 @@ impl<'a> CertificateRef<'a> { } pub(crate) fn encode(&self, buf: &mut CryptoBuffer<'_>) -> Result<(), TlsError> { - buf.push(self.request_context.len() as u8) - .map_err(|_| TlsError::EncodeError)?; - buf.extend_from_slice(self.request_context) - .map_err(|_| TlsError::EncodeError)?; - - buf.push_u24(self.entries.len() as u32)?; - for entry in self.entries.iter() { - entry.encode(buf)?; - } + buf.with_u8_length(|buf| buf.extend_from_slice(self.request_context))?; + buf.with_u24_length(|buf| { + for entry in self.entries.iter() { + entry.encode(buf)?; + } + Ok(()) + })?; + Ok(()) } } @@ -100,19 +99,20 @@ impl<'a> CertificateEntryRef<'a> { Ok(result) } - pub(crate) fn encode(&self, _buf: &mut CryptoBuffer<'_>) -> Result<(), TlsError> { - todo!("not implemented"); - /* + pub(crate) fn encode(&self, buf: &mut CryptoBuffer<'_>) -> Result<(), TlsError> { match self { - CertificateEntry::RawPublicKey(key) => { - let entry_len = (key.len() as u32).to_be_bytes(); + &CertificateEntryRef::RawPublicKey(_key) => { + todo!("ASN1_subjectPublicKeyInfo encoding?"); + // buf.with_u24_length(|buf| buf.extend_from_slice(key))?; } - CertificateEntry::X509(cert) => { - let entry_len = (cert.len() as u32).to_be_bytes(); + &CertificateEntryRef::X509(cert) => { + buf.with_u24_length(|buf| buf.extend_from_slice(cert))?; } } + + // Zero extensions for now + buf.push_u16(0)?; Ok(()) - */ } } diff --git a/src/handshake/certificate_request.rs b/src/handshake/certificate_request.rs index 1f9fef1b..1c78e485 100644 --- a/src/handshake/certificate_request.rs +++ b/src/handshake/certificate_request.rs @@ -1,3 +1,4 @@ +use crate::extensions::extension_data::signature_algorithms::SignatureAlgorithms; use crate::extensions::messages::CertificateRequestExtension; use crate::parse_buffer::ParseBuffer; use crate::TlsError; @@ -7,6 +8,7 @@ use heapless::Vec; #[cfg_attr(feature = "defmt", derive(defmt::Format))] pub struct CertificateRequestRef<'a> { pub(crate) request_context: &'a [u8], + pub(crate) extensions: Vec, 6>, } impl<'a> CertificateRequestRef<'a> { @@ -19,10 +21,11 @@ impl<'a> CertificateRequestRef<'a> { .map_err(|_| TlsError::InvalidCertificateRequest)?; // Validate extensions - CertificateRequestExtension::parse_vector::<6>(buf)?; + let extensions = CertificateRequestExtension::parse_vector::<6>(buf)?; Ok(Self { request_context: request_context.as_slice(), + extensions, }) } } @@ -31,6 +34,7 @@ impl<'a> CertificateRequestRef<'a> { #[cfg_attr(feature = "defmt", derive(defmt::Format))] pub struct CertificateRequest { pub(crate) request_context: Vec, + pub(crate) signature_algorithms: Option>, } impl<'a> TryFrom> for CertificateRequest { @@ -43,6 +47,18 @@ impl<'a> TryFrom> for CertificateRequest { error!("CertificateRequest: InsufficientSpace"); TlsError::InsufficientSpace })?; - Ok(Self { request_context }) + + let mut signature_algorithms = None; + + for ext in cert.extensions { + if let CertificateRequestExtension::SignatureAlgorithms(algos) = ext { + signature_algorithms = Some(algos) + } + } + + Ok(Self { + request_context, + signature_algorithms, + }) } } diff --git a/src/handshake/certificate_verify.rs b/src/handshake/certificate_verify.rs index 4d176d9e..0ffed4c4 100644 --- a/src/handshake/certificate_verify.rs +++ b/src/handshake/certificate_verify.rs @@ -2,15 +2,17 @@ use crate::extensions::extension_data::signature_algorithms::SignatureScheme; use crate::parse_buffer::ParseBuffer; use crate::TlsError; +use super::CryptoBuffer; + #[derive(Debug)] #[cfg_attr(feature = "defmt", derive(defmt::Format))] -pub struct CertificateVerify<'a> { +pub struct CertificateVerifyRef<'a> { pub(crate) signature_scheme: SignatureScheme, pub(crate) signature: &'a [u8], } -impl<'a> CertificateVerify<'a> { - pub fn parse(buf: &mut ParseBuffer<'a>) -> Result, TlsError> { +impl<'a> CertificateVerifyRef<'a> { + pub fn parse(buf: &mut ParseBuffer<'a>) -> Result, TlsError> { let signature_scheme = SignatureScheme::parse(buf).map_err(|_| TlsError::InvalidSignatureScheme)?; @@ -25,3 +27,18 @@ impl<'a> CertificateVerify<'a> { }) } } + +#[derive(Debug)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub struct CertificateVerify { + pub(crate) signature_scheme: SignatureScheme, + pub(crate) signature: heapless::Vec, +} + +impl CertificateVerify { + pub(crate) fn encode(&self, buf: &mut CryptoBuffer<'_>) -> Result<(), TlsError> { + buf.push_u16(self.signature_scheme as _)?; + buf.with_u16_length(|buf| buf.extend_from_slice(self.signature.as_slice()))?; + Ok(()) + } +} diff --git a/src/handshake/client_hello.rs b/src/handshake/client_hello.rs index 7a0c96b5..0c943936 100644 --- a/src/handshake/client_hello.rs +++ b/src/handshake/client_hello.rs @@ -1,11 +1,12 @@ +use core::marker::PhantomData; + use digest::{Digest, OutputSizeUser}; use heapless::Vec; use p256::ecdh::EphemeralSecret; -use p256::elliptic_curve::rand_core::{CryptoRng, RngCore}; +use p256::elliptic_curve::rand_core::RngCore; use p256::EncodedPoint; use typenum::Unsigned; -use crate::buffer::*; use crate::config::{TlsCipherSuite, TlsConfig}; use crate::extensions::extension_data::key_share::{KeyShareClientHello, KeyShareEntry}; use crate::extensions::extension_data::pre_shared_key::PreSharedKeyClientHello; @@ -20,13 +21,15 @@ use crate::extensions::messages::ClientHelloExtension; use crate::handshake::{Random, LEGACY_VERSION}; use crate::key_schedule::{HashOutputSize, WriteKeySchedule}; use crate::TlsError; +use crate::{buffer::*, CryptoProvider}; pub struct ClientHello<'config, CipherSuite> where CipherSuite: TlsCipherSuite, { - pub(crate) config: &'config TlsConfig<'config, CipherSuite>, + pub(crate) config: &'config TlsConfig<'config>, random: Random, + cipher_suite: PhantomData, pub(crate) secret: EphemeralSecret, } @@ -34,17 +37,18 @@ impl<'config, CipherSuite> ClientHello<'config, CipherSuite> where CipherSuite: TlsCipherSuite, { - pub fn new(config: &'config TlsConfig<'config, CipherSuite>, rng: &mut RNG) -> Self + pub fn new(config: &'config TlsConfig<'config>, mut provider: Provider) -> Self where - RNG: CryptoRng + RngCore, + Provider: CryptoProvider, { let mut random = [0; 32]; - rng.fill_bytes(&mut random); + provider.rng().fill_bytes(&mut random); Self { config, random, - secret: EphemeralSecret::random(rng), + cipher_suite: PhantomData, + secret: EphemeralSecret::random(&mut provider.rng()), } } diff --git a/src/handshake/mod.rs b/src/handshake/mod.rs index 733356fb..9cf7f0a5 100644 --- a/src/handshake/mod.rs +++ b/src/handshake/mod.rs @@ -2,7 +2,7 @@ use crate::config::TlsCipherSuite; use crate::handshake::certificate::CertificateRef; use crate::handshake::certificate_request::CertificateRequestRef; -use crate::handshake::certificate_verify::CertificateVerify; +use crate::handshake::certificate_verify::{CertificateVerify, CertificateVerifyRef}; use crate::handshake::client_hello::ClientHello; use crate::handshake::encrypted_extensions::EncryptedExtensions; use crate::handshake::finished::Finished; @@ -70,6 +70,7 @@ where CipherSuite: TlsCipherSuite, { ClientCert(CertificateRef<'a>), + ClientCertVerify(CertificateVerify), ClientHello(ClientHello<'config, CipherSuite>), Finished(Finished>), } @@ -83,6 +84,7 @@ where ClientHandshake::ClientHello(_) => HandshakeType::ClientHello, ClientHandshake::Finished(_) => HandshakeType::Finished, ClientHandshake::ClientCert(_) => HandshakeType::Certificate, + ClientHandshake::ClientCertVerify(_) => HandshakeType::CertificateVerify, } } @@ -91,6 +93,7 @@ where ClientHandshake::ClientHello(inner) => inner.encode(buf), ClientHandshake::Finished(inner) => inner.encode(buf), ClientHandshake::ClientCert(inner) => inner.encode(buf), + ClientHandshake::ClientCertVerify(inner) => inner.encode(buf), } } @@ -123,8 +126,7 @@ where ) -> Result<(), TlsError> { let enc_buf = buf.as_slice(); let end = enc_buf.len(); - // Don't include the content type in the slice - transcript.update(&enc_buf[0..end - 1]); + transcript.update(&enc_buf[0..end]); Ok(()) } } @@ -135,7 +137,7 @@ pub enum ServerHandshake<'a, CipherSuite: TlsCipherSuite> { NewSessionTicket(NewSessionTicket<'a>), Certificate(CertificateRef<'a>), CertificateRequest(CertificateRequestRef<'a>), - CertificateVerify(CertificateVerify<'a>), + CertificateVerify(CertificateVerifyRef<'a>), Finished(Finished>), } @@ -224,7 +226,7 @@ impl<'a, CipherSuite: TlsCipherSuite> ServerHandshake<'a, CipherSuite> { } HandshakeType::CertificateVerify => { - ServerHandshake::CertificateVerify(CertificateVerify::parse(buf)?) + ServerHandshake::CertificateVerify(CertificateVerifyRef::parse(buf)?) } HandshakeType::Finished => { ServerHandshake::Finished(Finished::parse(buf, content_len)?) diff --git a/src/lib.rs b/src/lib.rs index c4eb405b..5874f4c0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -13,19 +13,28 @@ use tokio::net::TcpStream; #[tokio::main] async fn main() { - let stream = TcpStream::connect("http.sandbox.drogue.cloud:443").await.expect("error creating TCP connection"); + let stream = TcpStream::connect("http.sandbox.drogue.cloud:443") + .await + .expect("error creating TCP connection"); println!("TCP connection opened"); let mut read_record_buffer = [0; 16384]; let mut write_record_buffer = [0; 16384]; - let config = TlsConfig::new() - .with_server_name("http.sandbox.drogue.cloud"); - let mut tls: TlsConnection, Aes128GcmSha256> = - TlsConnection::new(FromTokio::new(stream), &mut read_record_buffer, &mut write_record_buffer); + let config = TlsConfig::new().with_server_name("http.sandbox.drogue.cloud"); + let mut tls = TlsConnection::new( + FromTokio::new(stream), + &mut read_record_buffer, + &mut write_record_buffer, + ); // Allows disabling cert verification, in case you are using PSK and don't need it, or are just testing. // otherwise, use embedded_tls::webpki::CertVerifier, which only works on std for now. - tls.open::(TlsContext::new(&config, &mut OsRng)).await.expect("error establishing TLS connection"); + tls.open(TlsContext::new( + &config, + UnsecureProvider::new::(OsRng), + )) + .await + .expect("error establishing TLS connection"); println!("TLS session opened"); } @@ -57,6 +66,11 @@ mod record_reader; mod split; mod write_buffer; +pub use config::UnsecureProvider; +pub use extensions::extension_data::signature_algorithms::SignatureScheme; +pub use handshake::certificate_verify::CertificateVerify; +pub use rand_core::{CryptoRng, CryptoRngCore}; + #[cfg(feature = "webpki")] pub mod webpki; @@ -91,6 +105,7 @@ pub enum TlsError { InvalidCertificate, InvalidCertificateEntry, InvalidCertificateRequest, + InvalidPrivateKey, UnableToInitializeCryptoEngine, ParseError(ParseError), OutOfMemory, diff --git a/src/record.rs b/src/record.rs index 24e6fee6..aa12a179 100644 --- a/src/record.rs +++ b/src/record.rs @@ -1,5 +1,4 @@ use crate::application_data::ApplicationData; -use crate::buffer::*; use crate::change_cipher_spec::ChangeCipherSpec; use crate::config::{TlsCipherSuite, TlsConfig}; use crate::content_types::ContentType; @@ -8,8 +7,8 @@ use crate::handshake::{ClientHandshake, ServerHandshake}; use crate::key_schedule::WriteKeySchedule; use crate::TlsError; use crate::{alert::*, parse_buffer::ParseBuffer}; +use crate::{buffer::*, CryptoProvider}; use core::fmt::Debug; -use rand_core::{CryptoRng, RngCore}; pub type Encrypted = bool; @@ -103,15 +102,15 @@ where } } - pub fn client_hello( - config: &'config TlsConfig<'config, CipherSuite>, - rng: &mut RNG, + pub fn client_hello( + config: &'config TlsConfig<'config>, + provider: &mut Provider, ) -> Self where - RNG: CryptoRng + RngCore, + Provider: CryptoProvider, { ClientRecord::Handshake( - ClientHandshake::ClientHello(ClientHello::new(config, rng)), + ClientHandshake::ClientHello(ClientHello::new(config, provider)), false, ) } diff --git a/src/webpki.rs b/src/webpki.rs index 9079cee7..2734d76b 100644 --- a/src/webpki.rs +++ b/src/webpki.rs @@ -4,7 +4,7 @@ use crate::handshake::{ certificate::{ Certificate as OwnedCertificate, CertificateEntryRef, CertificateRef as ServerCertificate, }, - certificate_verify::CertificateVerify, + certificate_verify::CertificateVerifyRef, }; use crate::TlsError; use core::marker::PhantomData; @@ -89,31 +89,44 @@ static ALL_SIGALGS: &[&webpki::SignatureAlgorithm] = &[ &webpki::ED25519, ]; -pub struct CertVerifier<'a, CipherSuite, Clock, const CERT_SIZE: usize> +pub struct CertVerifier where Clock: TlsClock, CipherSuite: TlsCipherSuite, { - host: Option<&'a str>, + host: Option>, certificate_transcript: Option, certificate: Option>, _clock: PhantomData, } -impl<'a, CipherSuite, Clock, const CERT_SIZE: usize> TlsVerifier<'a, CipherSuite> - for CertVerifier<'a, CipherSuite, Clock, CERT_SIZE> +impl CertVerifier where - CipherSuite: TlsCipherSuite, Clock: TlsClock, + CipherSuite: TlsCipherSuite, { - fn new(host: Option<&'a str>) -> Self { + pub fn new() -> Self { Self { - host, + host: None, certificate_transcript: None, certificate: None, _clock: PhantomData, } } +} + +impl TlsVerifier + for CertVerifier +where + CipherSuite: TlsCipherSuite, + Clock: TlsClock, +{ + fn set_hostname_verification(&mut self, hostname: &str) -> Result<(), TlsError> { + self.host.replace( + heapless::String::try_from(hostname).map_err(|_| TlsError::InsufficientSpace)?, + ); + Ok(()) + } fn verify_certificate( &mut self, @@ -121,13 +134,13 @@ where ca: &Option, cert: ServerCertificate, ) -> Result<(), TlsError> { - verify_certificate(self.host.clone(), ca, &cert, Clock::now())?; + verify_certificate(self.host.as_deref(), ca, &cert, Clock::now())?; self.certificate.replace(cert.try_into()?); self.certificate_transcript.replace(transcript.clone()); Ok(()) } - fn verify_signature(&mut self, verify: CertificateVerify) -> Result<(), TlsError> { + fn verify_signature(&mut self, verify: CertificateVerifyRef) -> Result<(), TlsError> { let handshake_hash = unwrap!(self.certificate_transcript.take()); let ctx_str = b"TLS 1.3, server CertificateVerify\x00"; let mut msg: Vec = Vec::new(); @@ -146,7 +159,7 @@ where fn verify_signature( message: &[u8], certificate: ServerCertificate, - verify: CertificateVerify, + verify: CertificateVerifyRef, ) -> Result<(), TlsError> { let mut verified = false; if !certificate.entries.is_empty() { diff --git a/tests/client_cert_test.rs b/tests/client_cert_test.rs new file mode 100644 index 00000000..d8b6d9a0 --- /dev/null +++ b/tests/client_cert_test.rs @@ -0,0 +1,170 @@ +use ecdsa::elliptic_curve::SecretKey; +use embedded_io_adapters::tokio_1::FromTokio; +use embedded_tls::{CryptoProvider, SignatureScheme}; +use p256::ecdsa::SigningKey; +use rand::rngs::OsRng; +use rand_core::CryptoRngCore; +use rustls::server::AllowAnyAuthenticatedClient; +use std::net::SocketAddr; +use std::sync::Once; + +mod tlsserver; + +static LOG_INIT: Once = Once::new(); +static INIT: Once = Once::new(); +static mut ADDR: Option = None; + +fn init_log() { + LOG_INIT.call_once(|| { + env_logger::init(); + }); +} + +fn setup() -> SocketAddr { + use mio::net::TcpListener; + init_log(); + INIT.call_once(|| { + let addr: SocketAddr = "127.0.0.1:12345".parse().unwrap(); + + let listener = TcpListener::bind(addr).expect("cannot listen on port"); + let addr = listener + .local_addr() + .expect("error retrieving socket address"); + + std::thread::spawn(move || { + use tlsserver::*; + + let versions = &[&rustls::version::TLS13]; + + let test_dir = std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("tests"); + + let ca = load_certs(&test_dir.join("data").join("ca-cert.pem")); + let certs = load_certs(&test_dir.join("data").join("server-cert.pem")); + let privkey = load_private_key(&test_dir.join("data").join("server-key.pem")); + + let mut client_auth_roots = rustls::RootCertStore::empty(); + for root in ca.iter() { + client_auth_roots.add(root).unwrap() + } + + let client_cert_verifier = AllowAnyAuthenticatedClient::new(client_auth_roots); + + let config = rustls::ServerConfig::builder() + .with_cipher_suites(rustls::ALL_CIPHER_SUITES) + .with_kx_groups(&rustls::ALL_KX_GROUPS) + .with_protocol_versions(versions) + .unwrap() + .with_client_cert_verifier(client_cert_verifier.boxed()) + .with_single_cert(certs, privkey) + .unwrap(); + + run_with_config(listener, config); + }); + unsafe { ADDR.replace(addr) }; + }); + unsafe { ADDR.unwrap() } +} + +#[derive(Default)] +struct Provider { + rng: OsRng, +} + +impl CryptoProvider for Provider { + type CipherSuite = embedded_tls::Aes128GcmSha256; + type Signature = p256::ecdsa::DerSignature; + + fn rng(&mut self) -> impl CryptoRngCore { + &mut self.rng + } + + fn signer( + &mut self, + key_der: &[u8], + ) -> Result<(impl signature::SignerMut, SignatureScheme), embedded_tls::TlsError> + { + let secret_key = SecretKey::from_sec1_der(key_der) + .map_err(|_| embedded_tls::TlsError::InvalidPrivateKey)?; + + Ok(( + SigningKey::from(&secret_key), + SignatureScheme::EcdsaSecp256r1Sha256, + )) + } +} + +#[tokio::test] +async fn test_client_certificate_auth() { + use embedded_tls::*; + use tokio::net::TcpStream; + let addr = setup(); + + let ca_pem = include_str!("data/ca-cert.pem"); + let ca_der = pem_parser::pem_to_der(ca_pem); + + let client_cert_pem = include_str!("data/client-cert.pem"); + let client_cert_der = pem_parser::pem_to_der(client_cert_pem); + + let private_key_pem = include_str!("data/client-key.pem"); + let private_key_der = pem_parser::pem_to_der(private_key_pem); + + let stream = TcpStream::connect(addr) + .await + .expect("error connecting to server"); + + log::info!("Connected"); + let mut read_record_buffer = [0; 16384]; + let mut write_record_buffer = [0; 16384]; + let config = TlsConfig::new() + .with_ca(Certificate::X509(&ca_der)) + .with_cert(Certificate::X509(&client_cert_der)) + .with_priv_key(&private_key_der) + .with_server_name("factbird.com"); + + let mut tls = TlsConnection::new( + FromTokio::new(stream), + &mut read_record_buffer, + &mut write_record_buffer, + ); + + log::info!("SIZE of connection is {}", core::mem::size_of_val(&tls)); + + let mut provider = Provider::default(); + let open_fut = tls.open(TlsContext::new(&config, &mut provider)); + log::info!("SIZE of open fut is {}", core::mem::size_of_val(&open_fut)); + open_fut.await.expect("error establishing TLS connection"); + log::info!("Established"); + + let write_fut = tls.write(b"ping"); + log::info!( + "SIZE of write fut is {}", + core::mem::size_of_val(&write_fut) + ); + write_fut.await.expect("error writing data"); + tls.flush().await.expect("error flushing data"); + + // Make sure reading into a 0 length buffer doesn't loop + let mut rx_buf = [0; 0]; + let read_fut = tls.read(&mut rx_buf); + log::info!("SIZE of read fut is {}", core::mem::size_of_val(&read_fut)); + let sz = read_fut.await.expect("error reading data"); + assert_eq!(sz, 0); + + let mut rx_buf = [0; 4096]; + let read_fut = tls.read(&mut rx_buf); + log::info!("SIZE of read fut is {}", core::mem::size_of_val(&read_fut)); + let sz = read_fut.await.expect("error reading data"); + assert_eq!(4, sz); + assert_eq!(b"ping", &rx_buf[..sz]); + log::info!("Read {} bytes: {:?}", sz, &rx_buf[..sz]); + + // Test that embedded-tls doesn't block if the buffer is empty. + let mut rx_buf = [0; 0]; + let sz = tls.read(&mut rx_buf).await.expect("error reading data"); + assert_eq!(sz, 0); + + tls.close() + .await + .map_err(|(_, e)| e) + .expect("error closing session"); +} diff --git a/tests/client_test.rs b/tests/client_test.rs index 63483559..6c33b6b1 100644 --- a/tests/client_test.rs +++ b/tests/client_test.rs @@ -38,7 +38,6 @@ fn setup() -> SocketAddr { unsafe { ADDR.unwrap() } } -#[ignore] #[tokio::test] async fn test_google() { use embedded_tls::*; @@ -55,14 +54,16 @@ async fn test_google() { let mut write_record_buffer = [0; 16384]; let config = TlsConfig::new().with_server_name("google.com"); - let mut tls: TlsConnection, Aes128GcmSha256> = TlsConnection::new( + let mut tls = TlsConnection::new( FromTokio::new(stream), &mut read_record_buffer, &mut write_record_buffer, ); - let mut rng = OsRng; - let open_fut = tls.open::(TlsContext::new(&config, &mut rng)); + let open_fut = tls.open(TlsContext::new( + &config, + UnsecureProvider::new::(OsRng), + )); log::info!("SIZE of open fut is {}", core::mem::size_of_val(&open_fut)); open_fut.await.expect("error establishing TLS connection"); log::info!("Established"); @@ -101,7 +102,7 @@ async fn test_ping() { .with_ca(Certificate::X509(&der[..])) .with_server_name("localhost"); - let mut tls: TlsConnection, Aes128GcmSha256> = TlsConnection::new( + let mut tls = TlsConnection::new( FromTokio::new(stream), &mut read_record_buffer, &mut write_record_buffer, @@ -109,8 +110,10 @@ async fn test_ping() { log::info!("SIZE of connection is {}", core::mem::size_of_val(&tls)); - let mut rng = OsRng; - let open_fut = tls.open::(TlsContext::new(&config, &mut rng)); + let open_fut = tls.open(TlsContext::new( + &config, + UnsecureProvider::new::(OsRng), + )); log::info!("SIZE of open fut is {}", core::mem::size_of_val(&open_fut)); open_fut.await.expect("error establishing TLS connection"); log::info!("Established"); @@ -168,7 +171,7 @@ async fn test_ping_nocopy() { .with_ca(Certificate::X509(&der[..])) .with_server_name("localhost"); - let mut tls: TlsConnection, Aes128GcmSha256> = TlsConnection::new( + let mut tls = TlsConnection::new( FromTokio::new(stream), &mut read_record_buffer, &mut write_record_buffer, @@ -176,8 +179,10 @@ async fn test_ping_nocopy() { log::info!("SIZE of connection is {}", core::mem::size_of_val(&tls)); - let mut rng = OsRng; - let open_fut = tls.open::(TlsContext::new(&config, &mut rng)); + let open_fut = tls.open(TlsContext::new( + &config, + UnsecureProvider::new::(OsRng), + )); log::info!("SIZE of open fut is {}", core::mem::size_of_val(&open_fut)); open_fut.await.expect("error establishing TLS connection"); log::info!("Established"); @@ -238,15 +243,17 @@ async fn test_ping_nocopy_bufread() { .with_ca(Certificate::X509(&der[..])) .with_server_name("localhost"); - let mut tls: TlsConnection, Aes128GcmSha256> = TlsConnection::new( + let mut tls = TlsConnection::new( FromTokio::new(stream), &mut read_record_buffer, &mut write_record_buffer, ); - - tls.open::(TlsContext::new(&config, &mut OsRng)) - .await - .expect("error establishing TLS connection"); + tls.open(TlsContext::new( + &config, + UnsecureProvider::new::(OsRng), + )) + .await + .expect("error establishing TLS connection"); log::info!("Established"); tls.write(b"ping").await.expect("error writing data"); @@ -288,9 +295,11 @@ fn test_blocking_ping() { &mut read_record_buffer, &mut write_record_buffer, ); - - tls.open::(TlsContext::new(&config, &mut OsRng)) - .expect("error establishing TLS connection"); + tls.open(TlsContext::new( + &config, + UnsecureProvider::new::(OsRng), + )) + .expect("error establishing TLS connection"); log::info!("Established"); tls.write(b"ping").expect("error writing data"); @@ -339,9 +348,11 @@ fn test_blocking_ping_nocopy() { &mut read_record_buffer, &mut write_record_buffer, ); - - tls.open::(TlsContext::new(&config, &mut OsRng)) - .expect("error establishing TLS connection"); + tls.open(TlsContext::new( + &config, + UnsecureProvider::new::(OsRng), + )) + .expect("error establishing TLS connection"); log::info!("Established"); tls.write(b"ping").expect("error writing data"); @@ -384,9 +395,11 @@ fn test_blocking_ping_nocopy_bufread() { &mut read_record_buffer, &mut write_record_buffer, ); - - tls.open::(TlsContext::new(&config, &mut OsRng)) - .expect("error establishing TLS connection"); + tls.open(TlsContext::new( + &config, + UnsecureProvider::new::(OsRng), + )) + .expect("error establishing TLS connection"); log::info!("Established"); tls.write(b"ping").expect("error writing data"); diff --git a/tests/data/client-cert.pem b/tests/data/client-cert.pem new file mode 100644 index 00000000..cb04a1a3 --- /dev/null +++ b/tests/data/client-cert.pem @@ -0,0 +1,12 @@ +-----BEGIN CERTIFICATE----- +MIIBzDCCAXGgAwIBAgIUVB+wKMT9vfrrgAOVt5qON8J8onMwCgYIKoZIzj0EAwIw +QjELMAkGA1UEBhMCWFgxFTATBgNVBAcMDERlZmF1bHQgQ2l0eTEcMBoGA1UECgwT +RGVmYXVsdCBDb21wYW55IEx0ZDAeFw0yNDAyMDkwOTI3NDlaFw0yNDAzMTAwOTI3 +NDlaMEUxCzAJBgNVBAYTAkFVMRMwEQYDVQQIDApTb21lLVN0YXRlMSEwHwYDVQQK +DBhJbnRlcm5ldCBXaWRnaXRzIFB0eSBMdGQwWTATBgcqhkjOPQIBBggqhkjOPQMB +BwNCAAQzXKrX05qlw3NP1k6+kSiTnmI6Mo3ffT6VY71oPQIcqYiD1+hY7tIkk9kV +ke11ZNdGZR0r/o+4TzYJcxcgkNhLo0IwQDAdBgNVHQ4EFgQUBH7ViSdnDzmkYtsO +/f+BpHjeJHcwHwYDVR0jBBgwFoAU7HQ64pisg1MasN9wSLE/LC6PcjowCgYIKoZI +zj0EAwIDSQAwRgIhAONbHGkd+/wpgELOk/az5ELfrB7YO2o4a6Uix5KQOnARAiEA +tDGyTnCEmHjB/GGsLwLa8DRplNXFESDH2erfhutw8ME= +-----END CERTIFICATE----- diff --git a/tests/data/client-key.pem b/tests/data/client-key.pem new file mode 100644 index 00000000..3f0ba5af --- /dev/null +++ b/tests/data/client-key.pem @@ -0,0 +1,5 @@ +-----BEGIN EC PRIVATE KEY----- +MHcCAQEEIIMoxSnX9BbbgLSGk2rVi0o+NLwzisbbfce/pLGkHwvooAoGCCqGSM49 +AwEHoUQDQgAEM1yq19OapcNzT9ZOvpEok55iOjKN330+lWO9aD0CHKmIg9foWO7S +JJPZFZHtdWTXRmUdK/6PuE82CXMXIJDYSw== +-----END EC PRIVATE KEY----- diff --git a/tests/early_data_test.rs b/tests/early_data_test.rs index a7aa5b9b..647f5627 100644 --- a/tests/early_data_test.rs +++ b/tests/early_data_test.rs @@ -68,14 +68,17 @@ fn early_data_ignored() { .with_ca(Certificate::X509(&der[..])) .with_server_name("localhost"); - let mut tls: TlsConnection, Aes128GcmSha256> = TlsConnection::new( + let mut tls = TlsConnection::new( FromStd::new(stream), &mut read_record_buffer, &mut write_record_buffer, ); - tls.open::(TlsContext::new(&config, &mut OsRng)) - .expect("error establishing TLS connection"); + tls.open(TlsContext::new( + &config, + UnsecureProvider::new::(OsRng), + )) + .expect("error establishing TLS connection"); tls.write_all(b"ping").expect("Failed to write data"); tls.flush().expect("Failed to flush"); diff --git a/tests/psk_test.rs b/tests/psk_test.rs index 74510aec..d82bd462 100644 --- a/tests/psk_test.rs +++ b/tests/psk_test.rs @@ -75,15 +75,17 @@ async fn test_psk_open() { .with_psk(&[0xaa, 0xbb, 0xcc, 0xdd], &[b"vader"]) .with_server_name("localhost"); - let mut tls: TlsConnection, Aes128GcmSha256> = TlsConnection::new( + let mut tls = TlsConnection::new( FromTokio::new(stream), &mut read_record_buffer, &mut write_record_buffer, ); - let mut rng = OsRng; assert!(tls - .open::(TlsContext::new(&config, &mut rng)) + .open(TlsContext::new( + &config, + UnsecureProvider::new::(OsRng) + )) .await .is_ok()); println!("TLS session opened"); diff --git a/tests/split_test.rs b/tests/split_test.rs index 69d8d4a3..866442e5 100644 --- a/tests/split_test.rs +++ b/tests/split_test.rs @@ -78,14 +78,17 @@ fn test_blocking_borrowed() { .with_ca(Certificate::X509(&der[..])) .with_server_name("localhost"); - let mut tls: TlsConnection, Aes128GcmSha256> = TlsConnection::new( + let mut tls = TlsConnection::new( Clonable(Arc::new(stream)), &mut read_record_buffer, &mut write_record_buffer, ); - tls.open::(TlsContext::new(&config, &mut OsRng)) - .expect("error establishing TLS connection"); + tls.open(TlsContext::new( + &config, + UnsecureProvider::new::(OsRng), + )) + .expect("error establishing TLS connection"); let mut state = SplitConnectionState::default(); let (mut reader, mut writer) = tls.split_with(&mut state); @@ -126,14 +129,17 @@ fn test_blocking_managed() { .with_ca(Certificate::X509(&der[..])) .with_server_name("localhost"); - let mut tls: TlsConnection, Aes128GcmSha256> = TlsConnection::new( + let mut tls = TlsConnection::new( Clonable(Arc::new(stream)), &mut read_record_buffer, &mut write_record_buffer, ); - tls.open::(TlsContext::new(&config, &mut OsRng)) - .expect("error establishing TLS connection"); + tls.open(TlsContext::new( + &config, + UnsecureProvider::new::(OsRng), + )) + .expect("error establishing TLS connection"); let (mut reader, mut writer) = tls.split(); diff --git a/tests/tlsserver.rs b/tests/tlsserver.rs index ad6f4973..46510fb0 100644 --- a/tests/tlsserver.rs +++ b/tests/tlsserver.rs @@ -342,6 +342,7 @@ pub fn load_private_key(filename: &PathBuf) -> rustls::PrivateKey { match rustls_pemfile::read_one(&mut reader).expect("cannot parse private key .pem file") { Some(rustls_pemfile::Item::RSAKey(key)) => return rustls::PrivateKey(key), Some(rustls_pemfile::Item::PKCS8Key(key)) => return rustls::PrivateKey(key), + Some(rustls_pemfile::Item::ECKey(key)) => return rustls::PrivateKey(key), None => break, _ => {} }