Skip to content

Commit

Permalink
Pass custom http client through for the KeyStoreManager and jwks disc…
Browse files Browse the repository at this point in the history
…overy
  • Loading branch information
NotNorom committed Jun 10, 2024
1 parent d87b7ef commit b01e7ec
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 29 deletions.
33 changes: 17 additions & 16 deletions jwt-authorizer/src/authorizer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ where
refresh: Option<Refresh>,
validation: crate::validation::Validation,
jwt_source: JwtSource,
http_client: Option<Client>,
http_client: Client,
) -> Result<Authorizer<C>, InitError> {
Ok(match key_source_type {
KeySourceType::RSA(path) => {
Expand Down Expand Up @@ -201,7 +201,7 @@ where
}
KeySourceType::Jwks(url) => {
let jwks_url = Url::parse(url.as_str()).map_err(|e| InitError::JwksUrlError(e.to_string()))?;
let key_store_manager = KeyStoreManager::new(jwks_url, refresh.unwrap_or_default());
let key_store_manager = KeyStoreManager::new(http_client, jwks_url, refresh.unwrap_or_default());
Authorizer {
key_source: KeySource::KeyStoreSource(key_store_manager),
claims_checker,
Expand All @@ -210,10 +210,10 @@ where
}
}
KeySourceType::Discovery(issuer_url) => {
let jwks_url = Url::parse(&oidc::discover_jwks(issuer_url.as_str(), http_client).await?)
let jwks_url = Url::parse(&oidc::discover_jwks(issuer_url.as_str(), &http_client).await?)
.map_err(|e| InitError::JwksUrlError(e.to_string()))?;

let key_store_manager = KeyStoreManager::new(jwks_url, refresh.unwrap_or_default());
let key_store_manager = KeyStoreManager::new(http_client, jwks_url, refresh.unwrap_or_default());
Authorizer {
key_source: KeySource::KeyStoreSource(key_store_manager),
claims_checker,
Expand Down Expand Up @@ -318,6 +318,7 @@ where
mod tests {

use jsonwebtoken::{Algorithm, Header};
use reqwest::Client;
use serde_json::Value;

use crate::{layer::JwtSource, validation::Validation};
Expand All @@ -333,7 +334,7 @@ mod tests {
None,
Validation::new(),
JwtSource::AuthorizationHeader,
None,
Client::default(),
)
.await
.unwrap();
Expand All @@ -359,7 +360,7 @@ mod tests {
None,
Validation::new(),
JwtSource::AuthorizationHeader,
None,
Client::default(),
)
.await
.unwrap();
Expand All @@ -375,7 +376,7 @@ mod tests {
None,
Validation::new(),
JwtSource::AuthorizationHeader,
None,
Client::default(),
)
.await
.unwrap();
Expand All @@ -388,7 +389,7 @@ mod tests {
None,
Validation::new(),
JwtSource::AuthorizationHeader,
None,
Client::default(),
)
.await
.unwrap();
Expand All @@ -401,7 +402,7 @@ mod tests {
None,
Validation::new(),
JwtSource::AuthorizationHeader,
None,
Client::default(),
)
.await
.unwrap();
Expand All @@ -414,7 +415,7 @@ mod tests {
None,
Validation::new(),
JwtSource::AuthorizationHeader,
None,
Client::default(),
)
.await
.unwrap();
Expand All @@ -440,7 +441,7 @@ mod tests {
None,
Validation::new(),
JwtSource::AuthorizationHeader,
None,
Client::default(),
)
.await
.unwrap();
Expand All @@ -453,7 +454,7 @@ mod tests {
None,
Validation::new(),
JwtSource::AuthorizationHeader,
None,
Client::default(),
)
.await
.unwrap();
Expand All @@ -466,7 +467,7 @@ mod tests {
None,
Validation::new(),
JwtSource::AuthorizationHeader,
None,
Client::default(),
)
.await
.unwrap();
Expand All @@ -482,7 +483,7 @@ mod tests {
None,
Validation::new(),
JwtSource::AuthorizationHeader,
None,
Client::default(),
)
.await;
println!("{:?}", a.as_ref().err());
Expand All @@ -497,7 +498,7 @@ mod tests {
None,
Validation::default(),
JwtSource::AuthorizationHeader,
None,
Client::default(),
)
.await;
println!("{:?}", a.as_ref().err());
Expand All @@ -512,7 +513,7 @@ mod tests {
None,
Validation::default(),
JwtSource::AuthorizationHeader,
None,
Client::default(),
)
.await;
println!("{:?}", a.as_ref().err());
Expand Down
4 changes: 2 additions & 2 deletions jwt-authorizer/src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ where
self.refresh,
val,
self.jwt_source,
None,
self.http_client.unwrap_or_default(),
)
.await?,
);
Expand All @@ -249,7 +249,7 @@ where
self.refresh,
val,
self.jwt_source,
self.http_client,
self.http_client.unwrap_or_default(),
)
.await
}
Expand Down
22 changes: 14 additions & 8 deletions jwt-authorizer/src/jwks/key_store_manager.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use jsonwebtoken::{jwk::JwkSet, Algorithm};
use reqwest::Url;
use reqwest::{Client, Url};
use std::{
sync::Arc,
time::{Duration, Instant},
Expand Down Expand Up @@ -51,6 +51,7 @@ impl Default for Refresh {

#[derive(Clone)]
pub struct KeyStoreManager {
http_client: Client,
key_url: Url,
/// in case of fail loading (error or key not found), minimal interval
refresh: Refresh,
Expand All @@ -67,8 +68,9 @@ pub struct KeyStore {
}

impl KeyStoreManager {
pub(crate) fn new(key_url: Url, refresh: Refresh) -> KeyStoreManager {
pub(crate) fn new(http_client: Client, key_url: Url, refresh: Refresh) -> KeyStoreManager {
KeyStoreManager {
http_client,
key_url,
refresh,
keystore: Arc::new(Mutex::new(KeyStore {
Expand All @@ -85,7 +87,7 @@ impl KeyStoreManager {
let key = match self.refresh.strategy {
RefreshStrategy::Interval => {
if ks_gard.can_refresh(self.refresh.refresh_interval, self.refresh.retry_interval) {
ks_gard.refresh(&self.key_url, &[]).await?;
ks_gard.refresh(&self.http_client, &self.key_url, &[]).await?;
}
ks_gard.get_key(header)?
}
Expand All @@ -95,7 +97,7 @@ impl KeyStoreManager {
if let Some(jwk) = jwk_opt {
jwk
} else if ks_gard.can_refresh(self.refresh.refresh_interval, self.refresh.retry_interval) {
ks_gard.refresh(&self.key_url, &[("kid", kid)]).await?;
ks_gard.refresh(&self.http_client, &self.key_url, &[("kid", kid)]).await?;
ks_gard.find_kid(kid).ok_or_else(|| AuthError::InvalidKid(kid.to_owned()))?
} else {
return Err(AuthError::InvalidKid(kid.to_owned()));
Expand All @@ -107,6 +109,7 @@ impl KeyStoreManager {
} else if ks_gard.can_refresh(self.refresh.refresh_interval, self.refresh.retry_interval) {
ks_gard
.refresh(
&self.http_client,
&self.key_url,
&[(
"alg",
Expand All @@ -127,7 +130,7 @@ impl KeyStoreManager {
// if jwks endpoint is down for the loading, respect retry_interval
&& ks_gard.can_refresh(self.refresh.refresh_interval, self.refresh.retry_interval)
{
ks_gard.refresh(&self.key_url, &[]).await?;
ks_gard.refresh(&self.http_client, &self.key_url, &[]).await?;
}
ks_gard.get_key(header)?
}
Expand All @@ -151,8 +154,8 @@ impl KeyStore {
}
}

async fn refresh(&mut self, key_url: &Url, qparam: &[(&str, &str)]) -> Result<(), AuthError> {
reqwest::Client::new()
async fn refresh(&mut self, http_client: &Client, key_url: &Url, qparam: &[(&str, &str)]) -> Result<(), AuthError> {
http_client
.get(key_url.as_ref())
.query(qparam)
.send()
Expand Down Expand Up @@ -216,7 +219,7 @@ mod tests {

use jsonwebtoken::Algorithm;
use jsonwebtoken::{jwk::Jwk, Header};
use reqwest::Url;
use reqwest::{Client, Url};
use wiremock::{
matchers::{method, path},
Mock, MockServer, ResponseTemplate,
Expand Down Expand Up @@ -366,6 +369,7 @@ mod tests {
mock_jwks_response_once(&mock_server, JWK_ED01).await;

let ksm = KeyStoreManager::new(
Client::default(),
Url::parse(&mock_server.uri()).unwrap(),
Refresh {
strategy: RefreshStrategy::Interval,
Expand Down Expand Up @@ -413,6 +417,7 @@ mod tests {
mock_jwks_response_once(&mock_server, JWK_ED01).await;

let ksm = KeyStoreManager::new(
Client::default(),
Url::parse(&mock_server.uri()).unwrap(),
Refresh {
strategy: RefreshStrategy::KeyNotFound,
Expand Down Expand Up @@ -472,6 +477,7 @@ mod tests {
mock_jwks_response_once(&mock_server, JWK_ED01).await;

let ksm = KeyStoreManager::new(
Client::default(),
Url::parse(&mock_server.uri()).unwrap(),
Refresh {
strategy: RefreshStrategy::NoRefresh,
Expand Down
4 changes: 1 addition & 3 deletions jwt-authorizer/src/oidc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,7 @@ fn discovery_url(issuer: &str) -> Result<Url, InitError> {
Ok(url)
}

pub async fn discover_jwks(issuer: &str, client: Option<Client>) -> Result<String, InitError> {
let client = client.unwrap_or_default();

pub async fn discover_jwks(issuer: &str, client: &Client) -> Result<String, InitError> {
client
.get(discovery_url(issuer)?)
.send()
Expand Down

0 comments on commit b01e7ec

Please sign in to comment.