Skip to content

axum-extra: implement FromRequest for Either* #3323

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
164 changes: 162 additions & 2 deletions axum-extra/src/either.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
//! `Either*` types for combining extractors or responses into a single type.
//!
//! # As an extractor
//! # As an `FromRequestParts` extractor
//!
//! ```
//! use axum_extra::either::Either3;
Expand Down Expand Up @@ -54,6 +54,42 @@
//! Note that if all the inner extractors reject the request, the rejection from the last
//! extractor will be returned. For the example above that would be [`BytesRejection`].
//!
//! # As an `FromRequest` extractor
//!
//! In the following example, we can first try to deserialize the payload as JSON, if that fails try
//! to interpret it as a UTF-8 string, and lastly just take the raw bytes.
//!
//! It might be preferable to instead extract `Bytes` directly and then fallibly convert them to
//! `String` and then deserialize the data inside the handler.
//!
//! ```
//! use axum_extra::either::Either3;
//! use axum::{
//! body::Bytes,
//! Json,
//! Router,
//! routing::get,
//! extract::FromRequestParts,
//! };
//!
//! #[derive(serde::Deserialize)]
//! struct Payload {
//! user: String,
//! request_id: u32,
//! }
//!
//! async fn handler(
//! body: Either3<Json<Payload>, String, Bytes>,
//! ) {
//! match body {
//! Either3::E1(json) => { /* ... */ }
//! Either3::E2(string) => { /* ... */ }
//! Either3::E3(bytes) => { /* ... */ }
//! }
//! }
//! #
//! # let _: axum::routing::MethodRouter = axum::routing::get(handler);
//! ```
//! # As a response
//!
//! ```
Expand Down Expand Up @@ -93,9 +129,10 @@
use std::task::{Context, Poll};

use axum::{
extract::FromRequestParts,
extract::{rejection::BytesRejection, FromRequest, FromRequestParts, Request},
response::{IntoResponse, Response},
};
use bytes::Bytes;
use http::request::Parts;
use tower_layer::Layer;
use tower_service::Service;
Expand Down Expand Up @@ -226,6 +263,28 @@ pub enum Either8<E1, E2, E3, E4, E5, E6, E7, E8> {
E8(E8),
}

/// Rejection used for [`Either`], [`Either3`], etc.
///
/// Contains one variant for a case when the whole request could not be loaded and one variant
/// containing the rejection of the last variant if all extractors failed..
#[derive(Debug)]
pub enum EitherRejection<E> {
/// Buffering of the request body failed.
Bytes(BytesRejection),

/// All extractors failed. This contains the error returned by the last extractor.
LastRejection(E),
}

impl<E: IntoResponse> IntoResponse for EitherRejection<E> {
fn into_response(self) -> Response {
match self {
EitherRejection::Bytes(rejection) => rejection.into_response(),
EitherRejection::LastRejection(rejection) => rejection.into_response(),
}
}
}

macro_rules! impl_traits_for_either {
(
$either:ident =>
Expand All @@ -251,6 +310,43 @@ macro_rules! impl_traits_for_either {
}
}

impl<S, $($ident),*, $last> FromRequest<S> for $either<$($ident),*, $last>
where
S: Send + Sync,
$($ident: FromRequest<S>),*,
$last: FromRequest<S>,
$($ident::Rejection: Send),*,
$last::Rejection: IntoResponse + Send,
{
type Rejection = EitherRejection<$last::Rejection>;

async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
let (parts, body) = req.into_parts();
let bytes = Bytes::from_request(Request::from_parts(parts.clone(), body), state)
.await
.map_err(EitherRejection::Bytes)?;

$(
let req = Request::from_parts(
parts.clone(),
axum::body::Body::new(http_body_util::Full::new(bytes.clone())),
);
if let Ok(extracted) = $ident::from_request(req, state).await {
return Ok(Self::$ident(extracted));
}
)*

let req = Request::from_parts(
parts.clone(),
axum::body::Body::new(http_body_util::Full::new(bytes.clone())),
);
match $last::from_request(req, state).await {
Ok(extracted) => Ok(Self::$last(extracted)),
Err(error) => Err(EitherRejection::LastRejection(error)),
}
}
}

impl<$($ident),*, $last> IntoResponse for $either<$($ident),*, $last>
where
$($ident: IntoResponse),*,
Expand Down Expand Up @@ -312,3 +408,67 @@ where
}
}
}

#[cfg(test)]
mod tests {
use std::future::Future;

use axum::body::Body;
use axum::extract::rejection::StringRejection;
use axum::extract::{FromRequest, Request, State};
use bytes::Bytes;
use http_body_util::Full;

use super::*;

struct False;

impl<S> FromRequestParts<S> for False {
type Rejection = ();

fn from_request_parts(
_parts: &mut Parts,
_state: &S,
) -> impl Future<Output = Result<Self, Self::Rejection>> + Send {
std::future::ready(Err(()))
}
}

#[tokio::test]
async fn either_from_request() {
// The body is by design not valid UTF-8.
let request = Request::new(Body::new(Full::new(Bytes::from_static(&[255]))));

let either = Either4::<String, String, Request, Bytes>::from_request(request, &())
.await
.unwrap();

assert!(matches!(either, Either4::E3(_)));
}

#[tokio::test]
async fn either_from_request_rejection() {
// The body is by design not valid UTF-8.
let request = Request::new(Body::new(Full::new(Bytes::from_static(&[255]))));

let either = Either::<String, String>::from_request(request, &())
.await
.unwrap_err();

assert!(matches!(
either,
EitherRejection::LastRejection(StringRejection::InvalidUtf8(_))
));
}

#[tokio::test]
async fn either_from_request_parts() {
let (mut parts, _) = Request::new(Body::empty()).into_parts();

let either = Either3::<False, False, State<()>>::from_request_parts(&mut parts, &())
.await
.unwrap();

assert!(matches!(either, Either3::E3(State(()))));
}
}
Loading