Skip to content

Commit

Permalink
Pass custom http client through for the KeyStoreManager and jwks disc…
Browse files Browse the repository at this point in the history
…overy
  • Loading branch information
NotNorom committed Jun 10, 2024
1 parent d87b7ef commit 6bd1ef7
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 16 deletions.
8 changes: 4 additions & 4 deletions jwt-authorizer/src/authorizer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ where
refresh: Option<Refresh>,
validation: crate::validation::Validation,
jwt_source: JwtSource,
http_client: Option<Client>,
http_client: Client,
) -> Result<Authorizer<C>, InitError> {
Ok(match key_source_type {
KeySourceType::RSA(path) => {
Expand Down Expand Up @@ -201,7 +201,7 @@ where
}
KeySourceType::Jwks(url) => {
let jwks_url = Url::parse(url.as_str()).map_err(|e| InitError::JwksUrlError(e.to_string()))?;
let key_store_manager = KeyStoreManager::new(jwks_url, refresh.unwrap_or_default());
let key_store_manager = KeyStoreManager::new(http_client, jwks_url, refresh.unwrap_or_default());
Authorizer {
key_source: KeySource::KeyStoreSource(key_store_manager),
claims_checker,
Expand All @@ -210,10 +210,10 @@ where
}
}
KeySourceType::Discovery(issuer_url) => {
let jwks_url = Url::parse(&oidc::discover_jwks(issuer_url.as_str(), http_client).await?)
let jwks_url = Url::parse(&oidc::discover_jwks(issuer_url.as_str(), &http_client).await?)
.map_err(|e| InitError::JwksUrlError(e.to_string()))?;

let key_store_manager = KeyStoreManager::new(jwks_url, refresh.unwrap_or_default());
let key_store_manager = KeyStoreManager::new(http_client, jwks_url, refresh.unwrap_or_default());
Authorizer {
key_source: KeySource::KeyStoreSource(key_store_manager),
claims_checker,
Expand Down
4 changes: 2 additions & 2 deletions jwt-authorizer/src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ where
self.refresh,
val,
self.jwt_source,
None,
self.http_client.unwrap_or_default(),
)
.await?,
);
Expand All @@ -249,7 +249,7 @@ where
self.refresh,
val,
self.jwt_source,
self.http_client,
self.http_client.unwrap_or_default(),
)
.await
}
Expand Down
17 changes: 10 additions & 7 deletions jwt-authorizer/src/jwks/key_store_manager.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use jsonwebtoken::{jwk::JwkSet, Algorithm};
use reqwest::Url;
use reqwest::{Client, Url};
use std::{
sync::Arc,
time::{Duration, Instant},
Expand Down Expand Up @@ -51,6 +51,7 @@ impl Default for Refresh {

#[derive(Clone)]
pub struct KeyStoreManager {
http_client: Client,
key_url: Url,
/// in case of fail loading (error or key not found), minimal interval
refresh: Refresh,
Expand All @@ -67,8 +68,9 @@ pub struct KeyStore {
}

impl KeyStoreManager {
pub(crate) fn new(key_url: Url, refresh: Refresh) -> KeyStoreManager {
pub(crate) fn new(http_client: Client, key_url: Url, refresh: Refresh) -> KeyStoreManager {
KeyStoreManager {
http_client,
key_url,
refresh,
keystore: Arc::new(Mutex::new(KeyStore {
Expand All @@ -85,7 +87,7 @@ impl KeyStoreManager {
let key = match self.refresh.strategy {
RefreshStrategy::Interval => {
if ks_gard.can_refresh(self.refresh.refresh_interval, self.refresh.retry_interval) {
ks_gard.refresh(&self.key_url, &[]).await?;
ks_gard.refresh(&self.http_client, &self.key_url, &[]).await?;
}
ks_gard.get_key(header)?
}
Expand All @@ -95,7 +97,7 @@ impl KeyStoreManager {
if let Some(jwk) = jwk_opt {
jwk
} else if ks_gard.can_refresh(self.refresh.refresh_interval, self.refresh.retry_interval) {
ks_gard.refresh(&self.key_url, &[("kid", kid)]).await?;
ks_gard.refresh(&self.http_client, &self.key_url, &[("kid", kid)]).await?;
ks_gard.find_kid(kid).ok_or_else(|| AuthError::InvalidKid(kid.to_owned()))?
} else {
return Err(AuthError::InvalidKid(kid.to_owned()));
Expand All @@ -107,6 +109,7 @@ impl KeyStoreManager {
} else if ks_gard.can_refresh(self.refresh.refresh_interval, self.refresh.retry_interval) {
ks_gard
.refresh(
&self.http_client,
&self.key_url,
&[(
"alg",
Expand All @@ -127,7 +130,7 @@ impl KeyStoreManager {
// if jwks endpoint is down for the loading, respect retry_interval
&& ks_gard.can_refresh(self.refresh.refresh_interval, self.refresh.retry_interval)
{
ks_gard.refresh(&self.key_url, &[]).await?;
ks_gard.refresh(&self.http_client, &self.key_url, &[]).await?;
}
ks_gard.get_key(header)?
}
Expand All @@ -151,8 +154,8 @@ impl KeyStore {
}
}

async fn refresh(&mut self, key_url: &Url, qparam: &[(&str, &str)]) -> Result<(), AuthError> {
reqwest::Client::new()
async fn refresh(&mut self, http_client: &Client, key_url: &Url, qparam: &[(&str, &str)]) -> Result<(), AuthError> {
http_client
.get(key_url.as_ref())
.query(qparam)
.send()
Expand Down
4 changes: 1 addition & 3 deletions jwt-authorizer/src/oidc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,7 @@ fn discovery_url(issuer: &str) -> Result<Url, InitError> {
Ok(url)
}

pub async fn discover_jwks(issuer: &str, client: Option<Client>) -> Result<String, InitError> {
let client = client.unwrap_or_default();

pub async fn discover_jwks(issuer: &str, client: &Client) -> Result<String, InitError> {
client
.get(discovery_url(issuer)?)
.send()
Expand Down

0 comments on commit 6bd1ef7

Please sign in to comment.