Skip to content

Commit

Permalink
feat: add Limited body (#61)
Browse files Browse the repository at this point in the history
* feat: add `Limited` body

* fix: correct size_hint, remove const generic

* chore: use boxed error, pin project

Co-authored-by: Programatik <[email protected]>

Co-authored-by: Programatik <[email protected]>
  • Loading branch information
neoeinstein and programatik29 authored May 20, 2022
1 parent 730e9bd commit e17465c
Show file tree
Hide file tree
Showing 2 changed files with 301 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,15 @@
mod empty;
mod full;
mod limited;
mod next;
mod size_hint;

pub mod combinators;

pub use self::empty::Empty;
pub use self::full::Full;
pub use self::limited::{LengthLimitError, Limited};
pub use self::next::{Data, Trailers};
pub use self::size_hint::SizeHint;

Expand Down
299 changes: 299 additions & 0 deletions src/limited.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,299 @@
use crate::{Body, SizeHint};
use bytes::Buf;
use http::HeaderMap;
use pin_project_lite::pin_project;
use std::error::Error;
use std::fmt;
use std::pin::Pin;
use std::task::{Context, Poll};

pin_project! {
/// A length limited body.
///
/// This body will return an error if more than the configured number
/// of bytes are returned on polling the wrapped body.
#[derive(Clone, Copy, Debug)]
pub struct Limited<B> {
remaining: usize,
#[pin]
inner: B,
}
}

impl<B> Limited<B> {
/// Create a new `Limited`.
pub fn new(inner: B, limit: usize) -> Self {
Self {
remaining: limit,
inner,
}
}
}

impl<B> Body for Limited<B>
where
B: Body,
B::Error: Into<Box<dyn Error + Send + Sync>>,
{
type Data = B::Data;
type Error = Box<dyn Error + Send + Sync>;

fn poll_data(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Self::Data, Self::Error>>> {
let this = self.project();
let res = match this.inner.poll_data(cx) {
Poll::Pending => return Poll::Pending,
Poll::Ready(None) => None,
Poll::Ready(Some(Ok(data))) => {
if data.remaining() > *this.remaining {
*this.remaining = 0;
Some(Err(LengthLimitError.into()))
} else {
*this.remaining -= data.remaining();
Some(Ok(data))
}
}
Poll::Ready(Some(Err(err))) => Some(Err(err.into())),
};

Poll::Ready(res)
}

fn poll_trailers(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<Option<HeaderMap>, Self::Error>> {
let this = self.project();
let res = match this.inner.poll_trailers(cx) {
Poll::Pending => return Poll::Pending,
Poll::Ready(Ok(data)) => Ok(data),
Poll::Ready(Err(err)) => Err(err.into()),
};

Poll::Ready(res)
}

fn is_end_stream(&self) -> bool {
self.inner.is_end_stream()
}

fn size_hint(&self) -> SizeHint {
use std::convert::TryFrom;
match u64::try_from(self.remaining) {
Ok(n) => {
let mut hint = self.inner.size_hint();
if hint.lower() >= n {
hint.set_exact(n)
} else if let Some(max) = hint.upper() {
hint.set_upper(n.min(max))
} else {
hint.set_upper(n)
}
hint
}
Err(_) => self.inner.size_hint(),
}
}
}

/// An error returned when body length exceeds the configured limit.
#[derive(Debug)]
#[non_exhaustive]
pub struct LengthLimitError;

impl fmt::Display for LengthLimitError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.write_str("length limit exceeded")
}
}

impl Error for LengthLimitError {}

#[cfg(test)]
mod tests {
use super::*;
use crate::Full;
use bytes::Bytes;
use std::convert::Infallible;

#[tokio::test]
async fn read_for_body_under_limit_returns_data() {
const DATA: &[u8] = b"testing";
let inner = Full::new(Bytes::from(DATA));
let body = &mut Limited::new(inner, 8);

let mut hint = SizeHint::new();
hint.set_upper(7);
assert_eq!(body.size_hint().upper(), hint.upper());

let data = body.data().await.unwrap().unwrap();
assert_eq!(data, DATA);
hint.set_upper(0);
assert_eq!(body.size_hint().upper(), hint.upper());

assert!(matches!(body.data().await, None));
}

#[tokio::test]
async fn read_for_body_over_limit_returns_error() {
const DATA: &[u8] = b"testing a string that is too long";
let inner = Full::new(Bytes::from(DATA));
let body = &mut Limited::new(inner, 8);

let mut hint = SizeHint::new();
hint.set_upper(8);
assert_eq!(body.size_hint().upper(), hint.upper());

let error = body.data().await.unwrap().unwrap_err();
assert!(matches!(error.downcast_ref(), Some(LengthLimitError)));
}

struct Chunky(&'static [&'static [u8]]);

impl Body for Chunky {
type Data = &'static [u8];
type Error = Infallible;

fn poll_data(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
) -> Poll<Option<Result<Self::Data, Self::Error>>> {
let mut this = self;
match this.0.split_first().map(|(&head, tail)| (Ok(head), tail)) {
Some((data, new_tail)) => {
this.0 = new_tail;

Poll::Ready(Some(data))
}
None => Poll::Ready(None),
}
}

fn poll_trailers(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
) -> Poll<Result<Option<HeaderMap>, Self::Error>> {
Poll::Ready(Ok(Some(HeaderMap::new())))
}
}

#[tokio::test]
async fn read_for_chunked_body_around_limit_returns_first_chunk_but_returns_error_on_over_limit_chunk(
) {
const DATA: &[&[u8]] = &[b"testing ", b"a string that is too long"];
let inner = Chunky(DATA);
let body = &mut Limited::new(inner, 8);

let mut hint = SizeHint::new();
hint.set_upper(8);
assert_eq!(body.size_hint().upper(), hint.upper());

let data = body.data().await.unwrap().unwrap();
assert_eq!(data, DATA[0]);
hint.set_upper(0);
assert_eq!(body.size_hint().upper(), hint.upper());

let error = body.data().await.unwrap().unwrap_err();
assert!(matches!(error.downcast_ref(), Some(LengthLimitError)));
}

#[tokio::test]
async fn read_for_chunked_body_over_limit_on_first_chunk_returns_error() {
const DATA: &[&[u8]] = &[b"testing a string", b" that is too long"];
let inner = Chunky(DATA);
let body = &mut Limited::new(inner, 8);

let mut hint = SizeHint::new();
hint.set_upper(8);
assert_eq!(body.size_hint().upper(), hint.upper());

let error = body.data().await.unwrap().unwrap_err();
assert!(matches!(error.downcast_ref(), Some(LengthLimitError)));
}

#[tokio::test]
async fn read_for_chunked_body_under_limit_is_okay() {
const DATA: &[&[u8]] = &[b"test", b"ing!"];
let inner = Chunky(DATA);
let body = &mut Limited::new(inner, 8);

let mut hint = SizeHint::new();
hint.set_upper(8);
assert_eq!(body.size_hint().upper(), hint.upper());

let data = body.data().await.unwrap().unwrap();
assert_eq!(data, DATA[0]);
hint.set_upper(4);
assert_eq!(body.size_hint().upper(), hint.upper());

let data = body.data().await.unwrap().unwrap();
assert_eq!(data, DATA[1]);
hint.set_upper(0);
assert_eq!(body.size_hint().upper(), hint.upper());

assert!(matches!(body.data().await, None));
}

#[tokio::test]
async fn read_for_trailers_propagates_inner_trailers() {
const DATA: &[&[u8]] = &[b"test", b"ing!"];
let inner = Chunky(DATA);
let body = &mut Limited::new(inner, 8);
let trailers = body.trailers().await.unwrap();
assert_eq!(trailers, Some(HeaderMap::new()))
}

#[derive(Debug)]
enum ErrorBodyError {
Data,
Trailers,
}

impl fmt::Display for ErrorBodyError {
fn fmt(&self, _f: &mut fmt::Formatter) -> fmt::Result {
Ok(())
}
}

impl Error for ErrorBodyError {}

struct ErrorBody;

impl Body for ErrorBody {
type Data = &'static [u8];
type Error = ErrorBodyError;

fn poll_data(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
) -> Poll<Option<Result<Self::Data, Self::Error>>> {
Poll::Ready(Some(Err(ErrorBodyError::Data)))
}

fn poll_trailers(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
) -> Poll<Result<Option<HeaderMap>, Self::Error>> {
Poll::Ready(Err(ErrorBodyError::Trailers))
}
}

#[tokio::test]
async fn read_for_body_returning_error_propagates_error() {
let body = &mut Limited::new(ErrorBody, 8);
let error = body.data().await.unwrap().unwrap_err();
assert!(matches!(error.downcast_ref(), Some(ErrorBodyError::Data)));
}

#[tokio::test]
async fn trailers_for_body_returning_error_propagates_error() {
let body = &mut Limited::new(ErrorBody, 8);
let error = body.trailers().await.unwrap_err();
assert!(matches!(
error.downcast_ref(),
Some(ErrorBodyError::Trailers)
));
}
}

0 comments on commit e17465c

Please sign in to comment.