Skip to content

Commit

Permalink
axum/routing: Merge fallbacks with the rest of the router
Browse files Browse the repository at this point in the history
  • Loading branch information
mladedav committed Jan 7, 2025
1 parent f84105a commit 550562f
Show file tree
Hide file tree
Showing 4 changed files with 246 additions and 102 deletions.
82 changes: 48 additions & 34 deletions axum/src/routing/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use crate::extract::connect_info::IntoMakeServiceWithConnectInfo;
use crate::{
body::{Body, HttpBody},
boxed::BoxedIntoRoute,
extract::MatchedPath,
handler::Handler,
util::try_downcast,
};
Expand All @@ -20,7 +21,8 @@ use std::{
sync::Arc,
task::{Context, Poll},
};
use tower_layer::Layer;
use tower::service_fn;
use tower_layer::{layer_fn, Layer};
use tower_service::Service;

pub mod future;
Expand Down Expand Up @@ -72,8 +74,7 @@ impl<S> Clone for Router<S> {
}

struct RouterInner<S> {
path_router: PathRouter<S, false>,
fallback_router: PathRouter<S, true>,
path_router: PathRouter<S>,
default_fallback: bool,
catch_all_fallback: Fallback<S>,
}
Expand All @@ -91,7 +92,6 @@ impl<S> fmt::Debug for Router<S> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Router")
.field("path_router", &self.inner.path_router)
.field("fallback_router", &self.inner.fallback_router)
.field("default_fallback", &self.inner.default_fallback)
.field("catch_all_fallback", &self.inner.catch_all_fallback)
.finish()
Expand Down Expand Up @@ -141,7 +141,6 @@ where
Self {
inner: Arc::new(RouterInner {
path_router: Default::default(),
fallback_router: PathRouter::new_fallback(),
default_fallback: true,
catch_all_fallback: Fallback::Default(Route::new(NotFound)),
}),
Expand All @@ -153,7 +152,6 @@ where
Ok(inner) => inner,
Err(arc) => RouterInner {
path_router: arc.path_router.clone(),
fallback_router: arc.fallback_router.clone(),
default_fallback: arc.default_fallback,
catch_all_fallback: arc.catch_all_fallback.clone(),
},
Expand Down Expand Up @@ -207,8 +205,7 @@ where

let RouterInner {
path_router,
fallback_router,
default_fallback,
default_fallback: _,
// we don't need to inherit the catch-all fallback. It is only used for CONNECT
// requests with an empty path. If we were to inherit the catch-all fallback
// it would end up matching `/{path}/*` which doesn't match empty paths.
Expand All @@ -217,10 +214,6 @@ where

tap_inner!(self, mut this => {
panic_on_err!(this.path_router.nest(path, path_router));

if !default_fallback {
panic_on_err!(this.fallback_router.nest(path, fallback_router));
}
})
}

Expand All @@ -247,43 +240,33 @@ where
where
R: Into<Router<S>>,
{
const PANIC_MSG: &str =
"Failed to merge fallbacks. This is a bug in axum. Please file an issue";

let other: Router<S> = other.into();
let RouterInner {
path_router,
fallback_router: mut other_fallback,
default_fallback,
catch_all_fallback,
} = other.into_inner();

map_inner!(self, mut this => {
panic_on_err!(this.path_router.merge(path_router));

match (this.default_fallback, default_fallback) {
// both have the default fallback
// use the one from other
(true, true) => {
this.fallback_router.merge(other_fallback).expect(PANIC_MSG);
}
(true, true) => {}
// this has default fallback, other has a custom fallback
(true, false) => {
this.fallback_router.merge(other_fallback).expect(PANIC_MSG);
this.default_fallback = false;
}
// this has a custom fallback, other has a default
(false, true) => {
let fallback_router = std::mem::take(&mut this.fallback_router);
other_fallback.merge(fallback_router).expect(PANIC_MSG);
this.fallback_router = other_fallback;
}
// both have a custom fallback, not allowed
(false, false) => {
panic!("Cannot merge two `Router`s that both have a fallback")
}
};

panic_on_err!(this.path_router.merge(path_router));

this.catch_all_fallback = this
.catch_all_fallback
.merge(catch_all_fallback)
Expand All @@ -304,7 +287,6 @@ where
{
map_inner!(self, this => RouterInner {
path_router: this.path_router.layer(layer.clone()),
fallback_router: this.fallback_router.layer(layer.clone()),
default_fallback: this.default_fallback,
catch_all_fallback: this.catch_all_fallback.map(|route| route.layer(layer)),
})
Expand All @@ -322,7 +304,6 @@ where
{
map_inner!(self, this => RouterInner {
path_router: this.path_router.route_layer(layer),
fallback_router: this.fallback_router,
default_fallback: this.default_fallback,
catch_all_fallback: this.catch_all_fallback,
})
Expand Down Expand Up @@ -376,8 +357,47 @@ where
}

fn fallback_endpoint(self, endpoint: Endpoint<S>) -> Self {
// TODO make this better, get rid of the `unwrap`s.
// We need the returned `Service` to be `Clone` and the function inside `service_fn` to be
// `FnMut` so instead of just using the owned service, we do this trick with `Option`. We
// know this will be called just once so it's fine. We're doing that so that we avoid one
// clone inside `oneshot_inner` so that the `Router` and subsequently the `State` is not
// cloned too much.
tap_inner!(self, mut this => {
this.fallback_router.set_fallback(endpoint);
_ = this.path_router.route_endpoint(
"/",
endpoint.clone().layer(
layer_fn(
|service: Route| {
let mut service = Some(service);
service_fn(
move |mut request: Request| {
request.extensions_mut().remove::<MatchedPath>();
service.take().unwrap().oneshot_inner_owned(request)
}
)
}
)
)
);

_ = this.path_router.route_endpoint(
FALLBACK_PARAM_PATH,
endpoint.layer(
layer_fn(
|service: Route| {
let mut service = Some(service);
service_fn(
move |mut request: Request| {
request.extensions_mut().remove::<MatchedPath>();
service.take().unwrap().oneshot_inner_owned(request)
}
)
}
)
)
);

this.default_fallback = false;
})
}
Expand All @@ -386,7 +406,6 @@ where
pub fn with_state<S2>(self, state: S) -> Router<S2> {
map_inner!(self, this => RouterInner {
path_router: this.path_router.with_state(state.clone()),
fallback_router: this.fallback_router.with_state(state.clone()),
default_fallback: this.default_fallback,
catch_all_fallback: this.catch_all_fallback.with_state(state),
})
Expand All @@ -398,11 +417,6 @@ where
Err((req, state)) => (req, state),
};

let (req, state) = match self.inner.fallback_router.call_with_state(req, state) {
Ok(future) => return future,
Err((req, state)) => (req, state),
};

self.inner
.catch_all_fallback
.clone()
Expand Down
88 changes: 20 additions & 68 deletions axum/src/routing/path_router.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,33 +9,17 @@ use tower_layer::Layer;
use tower_service::Service;

use super::{
future::RouteFuture, not_found::NotFound, strip_prefix::StripPrefix, url_params, Endpoint,
MethodRouter, Route, RouteId, FALLBACK_PARAM_PATH, NEST_TAIL_PARAM,
future::RouteFuture, strip_prefix::StripPrefix, url_params, Endpoint, MethodRouter, Route,
RouteId, NEST_TAIL_PARAM,
};

pub(super) struct PathRouter<S, const IS_FALLBACK: bool> {
pub(super) struct PathRouter<S> {
routes: HashMap<RouteId, Endpoint<S>>,
node: Arc<Node>,
prev_route_id: RouteId,
v7_checks: bool,
}

impl<S> PathRouter<S, true>
where
S: Clone + Send + Sync + 'static,
{
pub(super) fn new_fallback() -> Self {
let mut this = Self::default();
this.set_fallback(Endpoint::Route(Route::new(NotFound)));
this
}

pub(super) fn set_fallback(&mut self, endpoint: Endpoint<S>) {
self.replace_endpoint("/", endpoint.clone());
self.replace_endpoint(FALLBACK_PARAM_PATH, endpoint);
}
}

fn validate_path(v7_checks: bool, path: &str) -> Result<(), &'static str> {
if path.is_empty() {
return Err("Paths must start with a `/`. Use \"/\" for root routes");
Expand Down Expand Up @@ -72,7 +56,7 @@ fn validate_v07_paths(path: &str) -> Result<(), &'static str> {
.unwrap_or(Ok(()))
}

impl<S, const IS_FALLBACK: bool> PathRouter<S, IS_FALLBACK>
impl<S> PathRouter<S>
where
S: Clone + Send + Sync + 'static,
{
Expand Down Expand Up @@ -159,10 +143,7 @@ where
.map_err(|err| format!("Invalid route {path:?}: {err}"))
}

pub(super) fn merge(
&mut self,
other: PathRouter<S, IS_FALLBACK>,
) -> Result<(), Cow<'static, str>> {
pub(super) fn merge(&mut self, other: PathRouter<S>) -> Result<(), Cow<'static, str>> {
let PathRouter {
routes,
node,
Expand All @@ -179,24 +160,9 @@ where
.get(&id)
.expect("no path for route id. This is a bug in axum. Please file an issue");

if IS_FALLBACK && (&**path == "/" || &**path == FALLBACK_PARAM_PATH) {
// when merging two routers it doesn't matter if you do `a.merge(b)` or
// `b.merge(a)`. This must also be true for fallbacks.
//
// However all fallback routers will have routes for `/` and `/*` so when merging
// we have to ignore the top level fallbacks on one side otherwise we get
// conflicts.
//
// `Router::merge` makes sure that when merging fallbacks `other` always has the
// fallback we want to keep. It panics if both routers have a custom fallback. Thus
// it is always okay to ignore one fallback and `Router::merge` also makes sure the
// one we can ignore is that of `self`.
self.replace_endpoint(path, route);
} else {
match route {
Endpoint::MethodRouter(method_router) => self.route(path, method_router)?,
Endpoint::Route(route) => self.route_service(path, route)?,
}
match route {
Endpoint::MethodRouter(method_router) => self.route(path, method_router)?,
Endpoint::Route(route) => self.route_service(path, route)?,
}
}

Expand All @@ -206,7 +172,7 @@ where
pub(super) fn nest(
&mut self,
path_to_nest_at: &str,
router: PathRouter<S, IS_FALLBACK>,
router: PathRouter<S>,
) -> Result<(), Cow<'static, str>> {
let prefix = validate_nest_path(self.v7_checks, path_to_nest_at);

Expand Down Expand Up @@ -282,7 +248,7 @@ where
Ok(())
}

pub(super) fn layer<L>(self, layer: L) -> PathRouter<S, IS_FALLBACK>
pub(super) fn layer<L>(self, layer: L) -> PathRouter<S>
where
L: Layer<Route> + Clone + Send + Sync + 'static,
L::Service: Service<Request> + Clone + Send + Sync + 'static,
Expand Down Expand Up @@ -344,7 +310,7 @@ where
!self.routes.is_empty()
}

pub(super) fn with_state<S2>(self, state: S) -> PathRouter<S2, IS_FALLBACK> {
pub(super) fn with_state<S2>(self, state: S) -> PathRouter<S2> {
let routes = self
.routes
.into_iter()
Expand Down Expand Up @@ -388,14 +354,12 @@ where
Ok(match_) => {
let id = *match_.value;

if !IS_FALLBACK {
#[cfg(feature = "matched-path")]
crate::extract::matched_path::set_matched_path_for_request(
id,
&self.node.route_id_to_path,
&mut parts.extensions,
);
}
#[cfg(feature = "matched-path")]
crate::extract::matched_path::set_matched_path_for_request(
id,
&self.node.route_id_to_path,
&mut parts.extensions,
);

url_params::insert_url_params(&mut parts.extensions, match_.params);

Expand All @@ -418,18 +382,6 @@ where
}
}

pub(super) fn replace_endpoint(&mut self, path: &str, endpoint: Endpoint<S>) {
match self.node.at(path) {
Ok(match_) => {
let id = *match_.value;
self.routes.insert(id, endpoint);
}
Err(_) => self
.route_endpoint(path, endpoint)
.expect("path wasn't matched so endpoint shouldn't exist"),
}
}

fn next_route_id(&mut self) -> RouteId {
let next_id = self
.prev_route_id
Expand All @@ -441,7 +393,7 @@ where
}
}

impl<S, const IS_FALLBACK: bool> Default for PathRouter<S, IS_FALLBACK> {
impl<S> Default for PathRouter<S> {
fn default() -> Self {
Self {
routes: Default::default(),
Expand All @@ -452,7 +404,7 @@ impl<S, const IS_FALLBACK: bool> Default for PathRouter<S, IS_FALLBACK> {
}
}

impl<S, const IS_FALLBACK: bool> fmt::Debug for PathRouter<S, IS_FALLBACK> {
impl<S> fmt::Debug for PathRouter<S> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("PathRouter")
.field("routes", &self.routes)
Expand All @@ -461,7 +413,7 @@ impl<S, const IS_FALLBACK: bool> fmt::Debug for PathRouter<S, IS_FALLBACK> {
}
}

impl<S, const IS_FALLBACK: bool> Clone for PathRouter<S, IS_FALLBACK> {
impl<S> Clone for PathRouter<S> {
fn clone(&self) -> Self {
Self {
routes: self.routes.clone(),
Expand Down
Loading

0 comments on commit 550562f

Please sign in to comment.