Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Upgrade jwt_authorizer to axum 0.7 #44

Merged
merged 9 commits into from
Jan 21, 2024
227 changes: 181 additions & 46 deletions Cargo.lock

Large diffs are not rendered by default.

6 changes: 3 additions & 3 deletions demo-server/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ edition = "2021"

[dependencies]
anyhow = "1.0.75"
axum = { version = "0.6.20", features = ["headers"] }
headers = "0.3"
axum = { version = "0.7.1" }
headers = "0.4"
josekit = "0.8.3"
jsonwebtoken = "9.1.0"
once_cell = "1.18.0"
Expand All @@ -17,7 +17,7 @@ serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
thiserror = "1.0.47"
tokio = { version = "1.32.0", features = ["full"] }
tower-http = { version = "0.4.3", features = ["trace"] }
tower-http = { version = "0.5.0", features = ["trace"] }
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
jwt-authorizer = { path = "../jwt-authorizer" }
8 changes: 4 additions & 4 deletions demo-server/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use jwt_authorizer::{
error::InitError, AuthError, Authorizer, IntoLayer, JwtAuthorizer, JwtClaims, Refresh, RefreshStrategy,
};
use serde::Deserialize;
use std::net::SocketAddr;
use tokio::net::TcpListener;
use tower_http::trace::TraceLayer;
use tracing::info;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
Expand Down Expand Up @@ -62,10 +62,10 @@ async fn main() -> Result<(), InitError> {
.nest("/api", api)
.layer(TraceLayer::new_for_http());

let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
tracing::info!("listening on {}", addr);
let listener = TcpListener::bind("127.0.0.1:3000").await.unwrap();
tracing::info!("listening on {:?}", listener.local_addr());

axum::Server::bind(&addr).serve(app.into_make_service()).await.unwrap();
axum::serve(listener, app.into_make_service()).await.unwrap();

Ok(())
}
Expand Down
9 changes: 5 additions & 4 deletions demo-server/src/oidc_provider/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ use jsonwebtoken::{encode, Algorithm, EncodingKey, Header};
use jwt_authorizer::{NumericDate, OneOrArray, RegisteredClaims};
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use std::{net::SocketAddr, thread, time::Duration};
use std::{thread, time::Duration};
use tokio::net::TcpListener;

const ISSUER_URI: &str = "http://localhost:3001";

Expand Down Expand Up @@ -171,9 +172,9 @@ pub fn run_server() -> &'static str {
.route("/tokens", get(tokens));

tokio::spawn(async move {
let addr = SocketAddr::from(([127, 0, 0, 1], 3001));
tracing::info!("oidc provider starting on: {}", addr);
axum::Server::bind(&addr).serve(app.into_make_service()).await.unwrap();
let listener = TcpListener::bind("127.0.0.1:3001").await.unwrap();
tracing::info!("oidc provider starting on: {:?}", listener.local_addr());
axum::serve(listener, app.into_make_service()).await.unwrap();
});

thread::sleep(Duration::from_millis(200)); // waiting oidc to start
Expand Down
15 changes: 8 additions & 7 deletions jwt-authorizer/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,32 +7,33 @@ authors = ["cduvray <[email protected]>"]
license = "MIT"
readme = "docs/README.md"
repository = "https://github.com/cduvray/jwt-authorizer"
keywords = ["jwt","axum","authorisation","jwks"]
keywords = ["jwt", "axum", "authorisation", "jwks"]

[dependencies]
axum = { version = "0.6", features = ["headers"] }
axum = { version = "0.7.1" }
chrono = { version = "0.4", optional = true }
futures-util = "0.3"
futures-core = "0.3"
headers = "0.3"
headers = "0.4"
jsonwebtoken = "9.1.0"
http = "0.2"
http = "1.0"
pin-project = "1.0"
reqwest = { version = "0.11", default-features = false, features = ["json"] }
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
thiserror = "1.0"
tokio = { version = "1.25", features = ["full"] }
tower-http = { version = "0.4", features = ["trace", "auth"] }
tower-http = { version = "0.5.0", features = ["trace", "auth"] }
tower-layer = "0.3"
tower-service = "0.3"
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
tonic = { version = "0.10", optional = true }
time = { version = "0.3", optional = true }
http-body-util = "0.1.0"

[dev-dependencies]
hyper = { version = "0.14", features = ["full"] }
hyper = { version = "1.0.1", features = ["full"] }
lazy_static = "1.4.0"
prost = "0.12"
tower = { version = "0.4", features = ["util", "buffer"] }
Expand All @@ -53,4 +54,4 @@ chrono = ["dep:chrono"]

[[test]]
name = "tonic"
required-features = [ "tonic" ]
required-features = ["tonic"]
7 changes: 3 additions & 4 deletions jwt-authorizer/docs/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ JWT authoriser Layer for Axum and Tonic.
# use jwt_authorizer::{AuthError, Authorizer, JwtAuthorizer, JwtClaims, RegisteredClaims, IntoLayer};
# use axum::{routing::get, Router};
# use serde::Deserialize;

# use tokio::net::TcpListener;
# async {

// let's create an authorizer builder from a JWKS Endpoint
Expand All @@ -41,9 +41,8 @@ JWT authoriser Layer for Axum and Tonic.
// Send the protected data to the user
Ok(format!("Welcome: {:?}", user.sub))
}

axum::Server::bind(&"0.0.0.0:3000".parse().unwrap())
.serve(app.into_make_service()).await.expect("server failed");
let listener = TcpListener::bind("0.0.0.0:3000").await.unwrap();
axum::serve(listener, app.into_make_service()).await.expect("server failed");
# };
```

Expand Down
2 changes: 1 addition & 1 deletion jwt-authorizer/src/authorizer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ pub enum KeySourceType {

impl<C> Authorizer<C>
where
C: DeserializeOwned + Clone + Send + Sync,
C: DeserializeOwned + Clone + Send,
{
pub(crate) async fn build(
key_source_type: KeySourceType,
Expand Down
10 changes: 5 additions & 5 deletions jwt-authorizer/src/error.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use axum::{
body::{self, BoxBody, Empty},
body::Body,
http::StatusCode,
response::{IntoResponse, Response},
};
Expand Down Expand Up @@ -64,8 +64,8 @@ pub enum AuthError {
NoAuthorizerLayer(),
}

fn response_wwwauth(status: StatusCode, bearer: &str) -> Response<BoxBody> {
let mut res = Response::new(body::boxed(Empty::new()));
fn response_wwwauth(status: StatusCode, bearer: &str) -> Response<Body> {
let mut res = Response::new(Body::empty());
*res.status_mut() = status;
let h = if bearer.is_empty() {
"Bearer".to_owned()
Expand All @@ -77,8 +77,8 @@ fn response_wwwauth(status: StatusCode, bearer: &str) -> Response<BoxBody> {
res
}

fn response_500() -> Response<BoxBody> {
let mut res = Response::new(body::boxed(Empty::new()));
fn response_500() -> Response<Body> {
let mut res = Response::new(Body::empty());
*res.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;

res
Expand Down
43 changes: 19 additions & 24 deletions jwt-authorizer/src/layer.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use axum::http::Request;
use axum::extract::Request;
use futures_core::ready;
use futures_util::future::{self, BoxFuture};
use jsonwebtoken::TokenData;
Expand All @@ -15,28 +15,25 @@ use crate::authorizer::Authorizer;
use crate::AuthError;

/// Trait for authorizing requests.
pub trait Authorize<B> {
type RequestBody;
type Future: Future<Output = Result<Request<Self::RequestBody>, AuthError>>;
pub trait Authorize {
type Future: Future<Output = Result<Request, AuthError>>;

/// Authorize the request.
///
/// If the future resolves to `Ok(request)` then the request is allowed through, otherwise not.
fn authorize(&self, request: Request<B>) -> Self::Future;
fn authorize(&self, request: Request) -> Self::Future;
}

impl<B, S, C> Authorize<B> for AuthorizationService<S, C>
impl<S, C> Authorize for AuthorizationService<S, C>
where
B: Send + Sync + 'static,
C: Clone + DeserializeOwned + Send + Sync + 'static,
{
type RequestBody = B;
type Future = BoxFuture<'static, Result<Request<B>, AuthError>>;
type Future = BoxFuture<'static, Result<Request, AuthError>>;

/// The authorizers are sequentially applied (check_auth) until one of them validates the token.
/// If no authorizer validates the token the request is rejected.
///
fn authorize(&self, mut request: Request<B>) -> Self::Future {
fn authorize(&self, mut request: Request) -> Self::Future {
let tkns_auths: Vec<(String, Arc<Authorizer<C>>)> = self
.auths
.iter()
Expand All @@ -59,6 +56,7 @@ where
Ok(tdata) => {
// Set `token_data` as a request extension so it can be accessed by other
// services down the stack.

request.extensions_mut().insert(tdata);

Ok(request)
Expand Down Expand Up @@ -119,15 +117,15 @@ pub enum JwtSource {
#[derive(Clone)]
pub struct AuthorizationService<S, C>
where
C: Clone + DeserializeOwned + Send + Sync,
C: Clone + DeserializeOwned + Send,
{
pub inner: S,
pub auths: Vec<Arc<Authorizer<C>>>,
}

impl<S, C> AuthorizationService<S, C>
where
C: Clone + DeserializeOwned + Send + Sync,
C: Clone + DeserializeOwned + Send,
{
pub fn get_ref(&self) -> &S {
&self.inner
Expand Down Expand Up @@ -156,22 +154,21 @@ where
}
}

impl<ReqBody, S, C> Service<Request<ReqBody>> for AuthorizationService<S, C>
impl<S, C> Service<Request> for AuthorizationService<S, C>
where
ReqBody: Send + Sync + 'static,
S: Service<Request<ReqBody>> + Clone,
S: Service<Request> + Clone,
S::Response: From<AuthError>,
C: Clone + DeserializeOwned + Send + Sync + 'static,
{
type Response = S::Response;
type Error = S::Error;
type Future = ResponseFuture<S, ReqBody, C>;
type Future = ResponseFuture<S, C>;

fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}

fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
fn call(&mut self, req: Request) -> Self::Future {
let inner = self.inner.clone();
// take the service that was ready
let inner = std::mem::replace(&mut self.inner, inner);
Expand All @@ -187,14 +184,13 @@ where

#[pin_project]
/// Response future for [`AuthorizationService`].
pub struct ResponseFuture<S, ReqBody, C>
pub struct ResponseFuture<S, C>
where
S: Service<Request<ReqBody>>,
ReqBody: Send + Sync + 'static,
S: Service<Request>,
C: Clone + DeserializeOwned + Send + Sync + 'static,
{
#[pin]
state: State<<AuthorizationService<S, C> as Authorize<ReqBody>>::Future, S::Future>,
state: State<<AuthorizationService<S, C> as Authorize>::Future, S::Future>,
service: S,
}

Expand All @@ -210,11 +206,10 @@ enum State<A, SFut> {
},
}

impl<S, ReqBody, C> Future for ResponseFuture<S, ReqBody, C>
impl<S, C> Future for ResponseFuture<S, C>
where
S: Service<Request<ReqBody>>,
S: Service<Request>,
S::Response: From<AuthError>,
ReqBody: Send + Sync + 'static,
C: Clone + DeserializeOwned + Send + Sync,
{
type Output = Result<S::Response, S::Error>;
Expand Down
16 changes: 7 additions & 9 deletions jwt-authorizer/tests/integration_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,8 @@ use std::{
time::Duration,
};

use axum::{response::Response, routing::get, Json, Router};
use axum::{body::Body, response::Response, routing::get, Json, Router};
use http::{header::AUTHORIZATION, Request, StatusCode};
use hyper::Body;
dsgallups marked this conversation as resolved.
Show resolved Hide resolved
use jwt_authorizer::{IntoLayer, JwtAuthorizer, JwtClaims, Refresh, RefreshStrategy, Validation};
use lazy_static::lazy_static;
use serde::{Deserialize, Serialize};
Expand Down Expand Up @@ -77,11 +76,8 @@ fn run_jwks_server() -> String {
.route("/jwks", get(jwks));

tokio::spawn(async move {
axum::Server::from_tcp(listener)
.unwrap()
.serve(app.into_make_service())
.await
.unwrap();
let listener: tokio::net::TcpListener = tokio::net::TcpListener::from_std(listener).unwrap();
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here you should use the listener from the line 69, I think creating a new listener on the same port makes the tests stuck.

Copy link
Contributor Author

@dsgallups dsgallups Dec 5, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Resolved by making the std listener non-blocking...the integration tests should no longer hang! Much appreciated for figuring this out!

Reference: tokio docs

axum::serve(listener, app.into_make_service()).await.unwrap();
});

url
Expand Down Expand Up @@ -130,7 +126,8 @@ fn init_test() {
}

async fn make_proteced_request(app: &mut Router, bearer: &str) -> Response {
app.ready()
app.as_service()
.ready()
.await
.unwrap()
.call(
Expand All @@ -145,7 +142,8 @@ async fn make_proteced_request(app: &mut Router, bearer: &str) -> Response {
}

async fn make_public_request(app: &mut Router) -> Response {
app.ready()
app.as_service()
.ready()
.await
.unwrap()
.call(Request::builder().uri("/public").body(Body::empty()).unwrap())
Copy link
Owner

@cduvray cduvray Dec 7, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here I think you could call call() directly on as_service() (type RouterAsService).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll squash several of the prior commits and push this change when the tonic upgrade is complete.

Expand Down
Loading
Loading