diff --git a/src/proto/h1/io.rs b/src/proto/h1/io.rs index 02d8a4a9ec..ac494b9387 100644 --- a/src/proto/h1/io.rs +++ b/src/proto/h1/io.rs @@ -40,6 +40,7 @@ const MAX_BUF_LIST_BUFFERS: usize = 16; pub(crate) struct Buffered { flush_pipeline: bool, io: T, + partial_len: Option, read_blocked: bool, read_buf: BytesMut, read_buf_strategy: ReadStrategy, @@ -73,6 +74,7 @@ where Buffered { flush_pipeline: false, io, + partial_len: None, read_blocked: false, read_buf: BytesMut::with_capacity(0), read_buf_strategy: ReadStrategy::default(), @@ -184,6 +186,7 @@ where loop { match super::role::parse_headers::( &mut self.read_buf, + self.partial_len, ParseContext { cached_headers: parse_ctx.cached_headers, req_method: parse_ctx.req_method, @@ -220,11 +223,13 @@ where .reset(Instant::now() + Duration::from_secs(30 * 24 * 60 * 60)); } } + self.partial_len = None; return Poll::Ready(Ok(msg)); } None => { let max = self.read_buf_strategy.max(); - if self.read_buf.len() >= max { + let curr_len = self.read_buf.len(); + if curr_len >= max { debug!("max_buf_size ({}) reached, closing", max); return Poll::Ready(Err(crate::Error::new_too_large())); } @@ -242,6 +247,9 @@ where } } } + if curr_len > 0 { + self.partial_len = Some(curr_len); + } } } if ready!(self.poll_read_from_io(cx)).map_err(crate::Error::new_io)? == 0 { diff --git a/src/proto/h1/role.rs b/src/proto/h1/role.rs index 7a4544d989..1c00d7445d 100644 --- a/src/proto/h1/role.rs +++ b/src/proto/h1/role.rs @@ -62,6 +62,7 @@ macro_rules! maybe_panic { pub(super) fn parse_headers( bytes: &mut BytesMut, + prev_len: Option, ctx: ParseContext<'_>, ) -> ParseResult where @@ -97,9 +98,37 @@ where let span = trace_span!("parse_headers"); let _s = span.enter(); + if let Some(prev_len) = prev_len { + if !is_complete_fast(bytes, prev_len) { + return Ok(None); + } + } + T::parse(bytes, ctx) } +/// A fast scan for the end of a message. +/// Used when there was a partial read, to skip full parsing on a +/// a slow connection. +fn is_complete_fast(bytes: &[u8], prev_len: usize) -> bool { + let start = if prev_len < 3 { 0 } else { prev_len - 3 }; + let bytes = &bytes[start..]; + + for (i, b) in bytes.iter().copied().enumerate() { + if b == b'\r' { + if bytes[i + 1..].chunks(3).next() == Some(&b"\n\r\n"[..]) { + return true; + } + } else if b == b'\n' { + if bytes.get(i + 1) == Some(&b'\n') { + return true; + } + } + } + + false +} + pub(super) fn encode_headers( enc: Encode<'_, T::Outgoing>, dst: &mut Vec, @@ -2635,6 +2664,28 @@ mod tests { assert_eq!(parsed.head.headers["server"], "hello\tworld"); } + #[test] + fn test_is_complete_fast() { + let s = b"GET / HTTP/1.1\r\na: b\r\n\r\n"; + for n in 0..s.len() { + assert!(is_complete_fast(s, n), "{:?}; {}", s, n); + } + let s = b"GET / HTTP/1.1\na: b\n\n"; + for n in 0..s.len() { + assert!(is_complete_fast(s, n)); + } + + // Not + let s = b"GET / HTTP/1.1\r\na: b\r\n\r"; + for n in 0..s.len() { + assert!(!is_complete_fast(s, n)); + } + let s = b"GET / HTTP/1.1\na: b\n"; + for n in 0..s.len() { + assert!(!is_complete_fast(s, n)); + } + } + #[test] fn test_write_headers_orig_case_empty_value() { let mut headers = HeaderMap::new();