Skip to content

Commit

Permalink
feat(esp-mbedtls): Add initial support for using esp-mbedtls in the c…
Browse files Browse the repository at this point in the history
…lient instead of embedded_tls

`esp-mbedtls` requires a specific arch to be passed.
Enable the feature using the following currently supported arch:
 - esp32
 - esp32c3
 - esp32s3
 - esp32s2
  • Loading branch information
AnthonyGrondin committed Feb 26, 2024
1 parent d93adfd commit 5ea4bbb
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 11 deletions.
8 changes: 8 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ defmt = { version = "0.3", optional = true }
embedded-tls = { version = "0.17", default-features = false, optional = true }
rand_chacha = { version = "0.3", default-features = false }
nourl = "0.1.1"
esp-mbedtls = { git = "https://github.com/AnthonyGrondin/esp-mbedtls.git", features = ["async"], optional = true }

[dev-dependencies]
hyper = { version = "0.14.23", features = ["full"] }
Expand All @@ -50,3 +51,10 @@ defmt = [
"embedded-tls?/defmt",
"nourl/defmt",
]

# For esp32, re-export those features as required by Cargo.
# This will also automatically enable `esp-mbedtls`.
esp32 = ["esp-mbedtls/esp32"]
esp32c3 = ["esp-mbedtls/esp32c3"]
esp32s2 = ["esp-mbedtls/esp32s2"]
esp32s3 = ["esp-mbedtls/esp32s3"]
70 changes: 59 additions & 11 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,23 @@ where
{
client: &'a T,
dns: &'a D,
#[cfg(feature = "embedded-tls")]
#[cfg(any(feature = "embedded-tls", feature = "esp-mbedtls"))]
tls: Option<TlsConfig<'a>>,
}

/// Type for TLS configuration of HTTP client.
#[cfg(feature = "esp-mbedtls")]
pub struct TlsConfig<'a> {
/// Minimum TLS version for the connection
version: crate::TlsVersion,

/// Client certificates. See [esp_mbedtls::Certificates]
certificates: crate::Certificates<'a>,

/// Will use hardware acceleration on the ESP32 if it contains the RSA peripheral.
rsa: Option<&'a mut esp_mbedtls::Rsa<'a>>,
}

/// Type for TLS configuration of HTTP client.
#[cfg(feature = "embedded-tls")]
pub struct TlsConfig<'a> {
Expand Down Expand Up @@ -54,6 +67,21 @@ impl<'a> TlsConfig<'a> {
}
}

#[cfg(feature = "esp-mbedtls")]
impl<'a> TlsConfig<'a> {
pub fn new(
version: crate::TlsVersion,
certificates: crate::Certificates<'a>,
rsa: Option<&'a mut esp_mbedtls::Rsa<'a>>,
) -> Self {
Self {
version,
certificates,
rsa,
}
}
}

impl<'a, T, D> HttpClient<'a, T, D>
where
T: TcpConnect + 'a,
Expand All @@ -64,13 +92,13 @@ where
Self {
client,
dns,
#[cfg(feature = "embedded-tls")]
#[cfg(any(feature = "embedded-tls", feature = "esp-mbedtls"))]
tls: None,
}
}

/// Create a new HTTP client for a given connection handle and a target host.
#[cfg(feature = "embedded-tls")]
#[cfg(any(feature = "embedded-tls", feature = "esp-mbedtls"))]
pub fn new_with_tls(client: &'a T, dns: &'a D, tls: TlsConfig<'a>) -> Self {
Self {
client,
Expand Down Expand Up @@ -99,6 +127,24 @@ where
.map_err(|e| e.kind())?;

if url.scheme() == UrlScheme::HTTPS {
#[cfg(feature = "esp-mbedtls")]
if let Some(tls) = self.tls.as_mut() {
let session = esp_mbedtls::asynch::Session::new(
conn,
host,
esp_mbedtls::Mode::Client,
tls.version,
tls.certificates,
// Create a inner Some(&mut Rsa) because Rsa doesn't implement Copy
tls.rsa.as_mut().map(|inner| inner as &mut esp_mbedtls::Rsa),
)?
.connect()
.await?;
Ok(HttpConnection::Tls(session))
} else {
Ok(HttpConnection::Plain(conn))
}

#[cfg(feature = "embedded-tls")]
if let Some(tls) = self.tls.as_mut() {
use embedded_tls::{TlsConfig, TlsContext};
Expand All @@ -118,7 +164,7 @@ where
} else {
Ok(HttpConnection::Plain(conn))
}
#[cfg(not(feature = "embedded-tls"))]
#[cfg(all(not(feature = "embedded-tls"), not(feature = "esp-mbedtls")))]
Err(Error::InvalidUrl(nourl::Error::UnsupportedScheme))
} else {
#[cfg(feature = "embedded-tls")]
Expand Down Expand Up @@ -172,9 +218,11 @@ where
{
Plain(C),
PlainBuffered(BufferedWrite<'conn, C>),
#[cfg(feature = "esp-mbedtls")]
Tls(esp_mbedtls::asynch::AsyncConnectedSession<C, 4096>),
#[cfg(feature = "embedded-tls")]
Tls(embedded_tls::TlsConnection<'conn, C, embedded_tls::Aes128GcmSha256>),
#[cfg(not(feature = "embedded-tls"))]
#[cfg(all(not(feature = "embedded-tls"), not(feature = "esp-mbedtls")))]
Tls((&'conn mut (), core::convert::Infallible)), // Variant is impossible to create, but we need it to avoid "unused lifetime" warning
}

Expand Down Expand Up @@ -255,9 +303,9 @@ where
match self {
Self::Plain(conn) => conn.read(buf).await.map_err(|e| e.kind()),
Self::PlainBuffered(conn) => conn.read(buf).await.map_err(|e| e.kind()),
#[cfg(feature = "embedded-tls")]
#[cfg(any(feature = "embedded-tls", feature = "esp-mbedtls"))]
Self::Tls(conn) => conn.read(buf).await.map_err(|e| e.kind()),
#[cfg(not(feature = "embedded-tls"))]
#[cfg(not(any(feature = "embedded-tls", feature = "esp-mbedtls")))]
_ => unreachable!(),
}
}
Expand All @@ -271,9 +319,9 @@ where
match self {
Self::Plain(conn) => conn.write(buf).await.map_err(|e| e.kind()),
Self::PlainBuffered(conn) => conn.write(buf).await.map_err(|e| e.kind()),
#[cfg(feature = "embedded-tls")]
#[cfg(any(feature = "embedded-tls", feature = "esp-mbedtls"))]
Self::Tls(conn) => conn.write(buf).await.map_err(|e| e.kind()),
#[cfg(not(feature = "embedded-tls"))]
#[cfg(not(any(feature = "embedded-tls", feature = "esp-mbedtls")))]
_ => unreachable!(),
}
}
Expand All @@ -282,9 +330,9 @@ where
match self {
Self::Plain(conn) => conn.flush().await.map_err(|e| e.kind()),
Self::PlainBuffered(conn) => conn.flush().await.map_err(|e| e.kind()),
#[cfg(feature = "embedded-tls")]
#[cfg(any(feature = "embedded-tls", feature = "esp-mbedtls"))]
Self::Tls(conn) => conn.flush().await.map_err(|e| e.kind()),
#[cfg(not(feature = "embedded-tls"))]
#[cfg(not(any(feature = "embedded-tls", feature = "esp-mbedtls")))]
_ => unreachable!(),
}
}
Expand Down
14 changes: 14 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ pub enum Error {
/// Tls Error
#[cfg(feature = "embedded-tls")]
Tls(embedded_tls::TlsError),
/// Tls Error
#[cfg(feature = "esp-mbedtls")]
Tls(esp_mbedtls::TlsError),
/// The provided buffer is too small
BufferTooSmall,
/// The request is already sent
Expand Down Expand Up @@ -70,6 +73,17 @@ impl From<embedded_tls::TlsError> for Error {
}
}

/// Re-export those members since they're used for [client::TlsConfig].
#[cfg(feature = "esp-mbedtls")]
pub use esp_mbedtls::{Certificates, Rsa, TlsVersion, X509};

#[cfg(feature = "esp-mbedtls")]
impl From<esp_mbedtls::TlsError> for Error {
fn from(e: esp_mbedtls::TlsError) -> Error {
Error::Tls(e)
}
}

impl From<ParseIntError> for Error {
fn from(_: ParseIntError) -> Error {
Error::Codec
Expand Down

0 comments on commit 5ea4bbb

Please sign in to comment.