Skip to content

Commit

Permalink
fix: avoid setting content-length before middleware
Browse files Browse the repository at this point in the history
Axum currently sets content-length automatically inside Route. However,
if this occurs before middleware is run then should the middleware add a
body to the request, Axum will avoid overwriting the content-length and
so the user is stuck with an incorrect content-length, leading to
panics in Hyper.
  • Loading branch information
SabrinaJewson committed Sep 1, 2024
1 parent 1ac617a commit cb69cde
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 65 deletions.
8 changes: 3 additions & 5 deletions axum/src/routing/method_routing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1034,13 +1034,11 @@ where
match $svc {
MethodEndpoint::None => {}
MethodEndpoint::Route(route) => {
return RouteFuture::from_future(route.clone().oneshot_inner($req))
.strip_body($method == Method::HEAD);
return route.clone().oneshot_inner($req);
}
MethodEndpoint::BoxedHandler(handler) => {
let route = handler.clone().into_route(state);
return RouteFuture::from_future(route.clone().oneshot_inner($req))
.strip_body($method == Method::HEAD);
let mut route = handler.clone().into_route(state);
return route.oneshot_inner($req);
}
}
}
Expand Down
6 changes: 2 additions & 4 deletions axum/src/routing/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -658,12 +658,10 @@ where

fn call_with_state(&mut self, req: Request, state: S) -> RouteFuture<E> {
match self {
Fallback::Default(route) | Fallback::Service(route) => {
RouteFuture::from_future(route.oneshot_inner(req))
}
Fallback::Default(route) | Fallback::Service(route) => route.oneshot_inner(req),
Fallback::BoxedHandler(handler) => {
let mut route = handler.clone().into_route(state);
RouteFuture::from_future(route.oneshot_inner(req))
route.oneshot_inner(req)
}
}
}
Expand Down
82 changes: 26 additions & 56 deletions axum/src/routing/route.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,18 @@ use axum_core::{extract::Request, response::IntoResponse};
use bytes::Bytes;
use http::{
header::{self, CONTENT_LENGTH},
HeaderMap, HeaderValue,
HeaderMap, HeaderValue, Method,
};
use pin_project_lite::pin_project;
use std::{
convert::Infallible,
fmt,
future::Future,
pin::Pin,
task::{Context, Poll},
task::{ready, Context, Poll},
};
use tower::{
util::{BoxCloneService, MapErrLayer, MapRequestLayer, MapResponseLayer, Oneshot},
util::{BoxCloneService, MapErrLayer, MapResponseLayer, Oneshot},
ServiceExt,
};
use tower_layer::Layer;
Expand All @@ -42,11 +42,9 @@ impl<E> Route<E> {
)))
}

pub(crate) fn oneshot_inner(
&mut self,
req: Request,
) -> Oneshot<BoxCloneService<Request, Response, E>, Request> {
self.0.get_mut().unwrap().clone().oneshot(req)
pub(crate) fn oneshot_inner(&mut self, req: Request) -> RouteFuture<E> {
let method = req.method().clone();
RouteFuture::new(method, self.0.get_mut().unwrap().clone().oneshot(req))
}

pub(crate) fn layer<L, NewError>(self, layer: L) -> Route<NewError>
Expand All @@ -59,7 +57,6 @@ impl<E> Route<E> {
NewError: 'static,
{
let layer = (
MapRequestLayer::new(|req: Request<_>| req.map(Body::new)),
MapErrLayer::new(Into::into),
MapResponseLayer::new(IntoResponse::into_response),
layer,
Expand Down Expand Up @@ -98,55 +95,38 @@ where

#[inline]
fn call(&mut self, req: Request<B>) -> Self::Future {
let req = req.map(Body::new);
RouteFuture::from_future(self.oneshot_inner(req))
self.oneshot_inner(req.map(Body::new)).not_top_level()
}
}

pin_project! {
/// Response future for [`Route`].
pub struct RouteFuture<E> {
#[pin]
kind: RouteFutureKind<E>,
inner: Oneshot<BoxCloneService<Request, Response, E>, Request>,
strip_body: bool,
allow_header: Option<Bytes>,
}
}

pin_project! {
#[project = RouteFutureKindProj]
enum RouteFutureKind<E> {
Future {
#[pin]
future: Oneshot<
BoxCloneService<Request, Response, E>,
Request,
>,
},
Response {
response: Option<Response>,
}
top_level: bool,
}
}

impl<E> RouteFuture<E> {
pub(crate) fn from_future(
future: Oneshot<BoxCloneService<Request, Response, E>, Request>,
) -> Self {
fn new(method: Method, inner: Oneshot<BoxCloneService<Request, Response, E>, Request>) -> Self {
Self {
kind: RouteFutureKind::Future { future },
strip_body: false,
inner,
strip_body: method == Method::HEAD,
allow_header: None,
top_level: true,
}
}

pub(crate) fn strip_body(mut self, strip_body: bool) -> Self {
self.strip_body = strip_body;
pub(crate) fn allow_header(mut self, allow_header: Bytes) -> Self {
self.allow_header = Some(allow_header);
self
}

pub(crate) fn allow_header(mut self, allow_header: Bytes) -> Self {
self.allow_header = Some(allow_header);
pub(crate) fn not_top_level(mut self) -> Self {
self.top_level = false;
self
}
}
Expand All @@ -157,28 +137,18 @@ impl<E> Future for RouteFuture<E> {
#[inline]
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.project();
let mut res = ready!(this.inner.poll(cx))?;

let mut res = match this.kind.project() {
RouteFutureKindProj::Future { future } => match future.poll(cx) {
Poll::Ready(Ok(res)) => res,
Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
Poll::Pending => return Poll::Pending,
},
RouteFutureKindProj::Response { response } => {
response.take().expect("future polled after completion")
}
};

set_allow_header(res.headers_mut(), this.allow_header);
if *this.top_level {
set_allow_header(res.headers_mut(), this.allow_header);

// make sure to set content-length before removing the body
set_content_length(res.size_hint(), res.headers_mut());
// make sure to set content-length before removing the body
set_content_length(res.size_hint(), res.headers_mut());

let res = if *this.strip_body {
res.map(|_| Body::empty())
} else {
res
};
if *this.strip_body {
*res.body_mut() = Body::empty();
}
}

Poll::Ready(Ok(res))
}
Expand Down
19 changes: 19 additions & 0 deletions axum/src/routing/tests/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1099,3 +1099,22 @@ async fn locks_mutex_very_little() {
assert_eq!(num, 1);
}
}

#[crate::test]
async fn middleware_adding_body() {
let app = Router::new()
.route("/", get(()))
.layer(MapResponseLayer::new(|mut res: Response| -> Response {
// If there is a content-length header, its value will be zero and Axum will avoid
// overwriting it. But this means our content-length doesn’t match the length of the
// body, which leads to panics in Hyper. Thus we have to ensure that Axum doesn’t add
// on content-length headers until after middleware has been run.
assert!(!res.headers().contains_key("content-length"));
*res.body_mut() = "…".into();
res
}));

let client = TestClient::new(app);
let res = client.get("/").await;
assert_eq!(res.text().await, "…");
}

0 comments on commit cb69cde

Please sign in to comment.