From 208610dfceadc7cc1c6b90369ac836f289bda0c4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?D=C3=A1niel=20Buga?= Date: Wed, 18 Oct 2023 10:24:31 +0200 Subject: [PATCH] Implement BufRead for BodyReader --- src/response.rs | 353 +++++++++++++++++++++++++++++++----------------- 1 file changed, 228 insertions(+), 125 deletions(-) diff --git a/src/response.rs b/src/response.rs index c5e9793..2f465d9 100644 --- a/src/response.rs +++ b/src/response.rs @@ -1,5 +1,5 @@ use embedded_io::{Error as _, ErrorType}; -use embedded_io_async::Read; +use embedded_io_async::{BufRead, Read}; use heapless::Vec; use crate::headers::{ContentType, KeepAlive, TransferEncoding}; @@ -217,8 +217,7 @@ where }), ReaderHint::Chunked => BodyReader::Chunked(ChunkedBodyReader { raw_body, - chunk_remaining: 0, - empty_chunk_received: false, + chunk_remaining: ChunkState::NoChunk, }), ReaderHint::ToEnd => BodyReader::ToEnd(raw_body), } @@ -228,7 +227,7 @@ where impl<'buf, 'conn, C> ResponseBody<'buf, 'conn, C> where C: Read, - BufferingReader<'buf, &'conn mut C>: Read, + BufferingReader<'buf, &'conn mut C>: BufRead + Read, { /// Read the entire body into the buffer originally provided [`Response::read()`]. /// This requires that this original buffer is large enough to contain the entire body. @@ -288,7 +287,7 @@ pub enum BodyReader { impl BodyReader where - B: Read, + B: BufRead + Read, { /// Read the entire body pub async fn read_to_end(&mut self, buf: &mut [u8]) -> Result { @@ -304,7 +303,7 @@ where let is_done = match self { BodyReader::Empty => true, BodyReader::FixedLength(reader) => reader.remaining == 0, - BodyReader::Chunked(reader) => reader.empty_chunk_received, + BodyReader::Chunked(reader) => reader.chunk_remaining == ChunkState::Empty, BodyReader::ToEnd(_) => true, }; @@ -318,28 +317,26 @@ where async fn discard(&mut self) -> Result { let mut body_len = 0; loop { - let mut trash = [0; 256]; - let len = self.read(&mut trash).await?; - if len == 0 { + let buf = self.fill_buf().await?; + if buf.is_empty() { break; } - body_len += len; + let buf_len = buf.len(); + body_len += buf_len; + self.consume(buf_len); } Ok(body_len) } } -impl ErrorType for BodyReader -where - B: Read, -{ +impl ErrorType for BodyReader { type Error = Error; } impl Read for BodyReader where - B: Read, + B: BufRead + Read, { async fn read(&mut self, buf: &mut [u8]) -> Result { match self { @@ -351,28 +348,107 @@ where } } +impl BufRead for BodyReader +where + B: BufRead + Read, +{ + async fn fill_buf(&mut self) -> Result<&[u8], Self::Error> { + match self { + BodyReader::Empty => Ok(&[]), + BodyReader::FixedLength(reader) => reader.fill_buf().await, + BodyReader::Chunked(reader) => reader.fill_buf().await, + BodyReader::ToEnd(conn) => conn.fill_buf().await.map_err(|e| Error::Network(e.kind())), + } + } + + fn consume(&mut self, amt: usize) { + match self { + BodyReader::Empty => {} + BodyReader::FixedLength(reader) => reader.consume(amt), + BodyReader::Chunked(reader) => reader.consume(amt), + BodyReader::ToEnd(conn) => conn.consume(amt), + } + } +} + /// Fixed length response body reader pub struct FixedLengthBodyReader { raw_body: B, remaining: usize, } -impl ErrorType for FixedLengthBodyReader { +impl ErrorType for FixedLengthBodyReader { type Error = Error; } -impl Read for FixedLengthBodyReader { +impl Read for FixedLengthBodyReader +where + C: BufRead + Read, +{ async fn read(&mut self, buf: &mut [u8]) -> Result { + let loaded = self.fill_buf().await?; + let len = loaded.len().min(buf.len()); + + buf[..len].copy_from_slice(&loaded[..len]); + self.consume(len); + + Ok(len) + } +} + +impl BufRead for FixedLengthBodyReader +where + C: BufRead + Read, +{ + async fn fill_buf(&mut self) -> Result<&[u8], Self::Error> { if self.remaining == 0 { - return Ok(0); + return Ok(&[]); } - let to_read = usize::min(self.remaining, buf.len()); - let len = self.raw_body.read(&mut buf[..to_read]).await.map_err(|e| e.kind())?; - if len > 0 { - self.remaining -= len; - Ok(len) + + let loaded = self + .raw_body + .fill_buf() + .await + .map_err(|e| Error::Network(e.kind())) + .map(|data| &data[..data.len().min(self.remaining)])?; + + if loaded.is_empty() { + return Err(Error::ConnectionClosed); + } + + Ok(loaded) + } + + fn consume(&mut self, amt: usize) { + let amt = amt.min(self.remaining); + self.remaining -= amt; + self.raw_body.consume(amt) + } +} + +#[derive(Clone, Copy, PartialEq, Eq)] +pub enum ChunkState { + NoChunk, + NotEmpty(u32), + Empty, +} + +impl ChunkState { + fn consume(&mut self, amt: usize) -> usize { + if let ChunkState::NotEmpty(remaining) = self { + let consumed = (amt as u32).min(*remaining); + *remaining -= consumed; + consumed as usize } else { - Err(Error::ConnectionClosed) + 0 + } + } + + fn len(self) -> usize { + if let ChunkState::NotEmpty(len) = self { + len as usize + } else { + 0 } } } @@ -380,11 +456,63 @@ impl Read for FixedLengthBodyReader { /// Chunked response body reader pub struct ChunkedBodyReader { raw_body: B, - chunk_remaining: u32, - empty_chunk_received: bool, + chunk_remaining: ChunkState, } -impl ChunkedBodyReader { +impl ChunkedBodyReader +where + C: BufRead + Read, +{ + async fn read_next_chunk_length(&mut self) -> Result<(), Error> { + let mut header_buf = [0; 8 + 2]; // 32 bit hex + \r + \n + let mut total_read = 0; + + 'read_size: loop { + let buf = self.raw_body.fill_buf().await.map_err(|e| e.kind())?; + for (i, byte) in buf.iter().enumerate() { + if *byte != b'\n' { + header_buf[total_read] = *byte; + total_read += 1; + + if total_read == header_buf.len() { + self.raw_body.consume(i + 1); + return Err(Error::Codec); + } + } else { + self.raw_body.consume(i + 1); + break 'read_size; + } + } + + let consumed = buf.len(); + self.raw_body.consume(consumed); + } + + if header_buf[total_read - 1] != b'\r' { + return Err(Error::Codec); + } + + let hex_digits = total_read - 1; + + // Prepend hex with zeros + let mut hex = [b'0'; 8]; + hex[8 - hex_digits..].copy_from_slice(&header_buf[..hex_digits]); + + let mut bytes = [0; 4]; + hex::decode_to_slice(hex, &mut bytes).map_err(|_| Error::Codec)?; + + let chunk_length = u32::from_be_bytes(bytes); + + debug!("Chunk length: {}", chunk_length); + + self.chunk_remaining = match chunk_length { + 0 => ChunkState::Empty, + other => ChunkState::NotEmpty(other), + }; + + Ok(()) + } + async fn read_chunk_end(&mut self) -> Result<(), Error> { // All chunks are terminated with a \r\n let mut newline_buf = [0; 2]; @@ -397,106 +525,65 @@ impl ChunkedBodyReader { } } -impl ErrorType for ChunkedBodyReader { +impl ErrorType for ChunkedBodyReader { type Error = Error; } -impl Read for ChunkedBodyReader { +impl Read for ChunkedBodyReader +where + C: BufRead + Read, +{ async fn read(&mut self, buf: &mut [u8]) -> Result { - if buf.is_empty() || self.empty_chunk_received { + if buf.is_empty() { return Ok(0); } - if self.chunk_remaining == 0 { - // The current chunk is currently empty, advance into a new chunk... - - let mut header_buf = [0; 8 + 2]; // 32 bit hex + \r + \n - let mut total_read = 0; - - // For now, limit the number of bytes that we can read to avoid reading into a header after the current - let mut max_read = 3; // Single hex digit + \r + \n - loop { - let read = self - .raw_body - .read(&mut header_buf[total_read..max_read]) - .await - .map_err(|e| e.kind())?; - if read == 0 { - return Err(Error::ConnectionClosed); - } - total_read += read; + // If we receive an empty buffer here, the body includes an empty chunk. + // `fill_buf` will return an Err if the connection is closed. + let loaded = self.fill_buf().await?; - // Decode the chunked header - let header_and_body = &header_buf[..total_read]; - if let Some(nl) = header_and_body.iter().position(|x| *x == b'\n') { - let header = &header_and_body[..nl + 1]; - if nl == 0 || header[nl - 1] != b'\r' { - return Err(Error::Codec); - } - let hex_digits = nl - 1; - // Prepend hex with zeros - let mut hex = [b'0'; 8]; - hex[8 - hex_digits..].copy_from_slice(&header[..hex_digits]); - let mut bytes = [0; 4]; - hex::decode_to_slice(hex, &mut bytes).map_err(|_| Error::Codec)?; - self.chunk_remaining = u32::from_be_bytes(bytes); - - if self.chunk_remaining == 0 { - self.empty_chunk_received = true; - } + let len = loaded.len().min(buf.len()); - // Return the excess body bytes read during the header, if any - let excess_body_read = header_and_body.len() - header.len(); - if excess_body_read > 0 { - if excess_body_read > self.chunk_remaining as usize { - // We have read chunk bytes that exceed the size of the chunk - return Err(Error::Codec); - } - - buf[..excess_body_read].copy_from_slice(&header_and_body[header.len()..]); - self.chunk_remaining -= excess_body_read as u32; - return Ok(excess_body_read); - } + buf[..len].copy_from_slice(&loaded[..len]); + self.consume(len); - break; - } + Ok(len) + } +} - if total_read >= 3 { - // At least three bytes were read and a \n was not found - // This means that the chunk length is at least double-digit hex - // which in turn means that it is impossible for another header to - // be present within the 10 bytes header buffer. - // 10 is the length of the max header "ffffffff\r\n". - // For example, 10\r\nXXXXXXYYYYYYYYYY is more than 10 bytes - // - 10\r\n is the header - // - XXXXXX are the excess body 6 bytes that we may read - // - YYYYYYYYYY are the remaining unread chunk bytes. - // However, for reading these excess bytes into the actual chunk payload, - // the user buffer must be large enough to actually contain the excess read bytes. - // A \n was not found, and we can read that + buf.len(). - max_read = core::cmp::min(total_read + 1 + buf.len(), 10); - } - } - } +impl BufRead for ChunkedBodyReader +where + C: BufRead + Read, +{ + async fn fill_buf(&mut self) -> Result<&[u8], Self::Error> { + match self.chunk_remaining { + ChunkState::NoChunk => self.read_next_chunk_length().await?, - if self.empty_chunk_received { - self.read_chunk_end().await?; - Ok(0) - } else { - let max_len = usize::min(self.chunk_remaining as usize, buf.len()); - let len = self.raw_body.read(&mut buf[..max_len]).await.map_err(|e| e.kind())?; - if len == 0 { - return Err(Error::ConnectionClosed); + ChunkState::NotEmpty(0) => { + // The current chunk is currently empty, advance into a new chunk... + self.read_chunk_end().await?; + self.read_next_chunk_length().await?; } - self.chunk_remaining -= len as u32; + ChunkState::NotEmpty(_) => {} - if self.chunk_remaining == 0 { - self.read_chunk_end().await?; - } + ChunkState::Empty => return Ok(&[]), + } - Ok(len) + let remaining = self.chunk_remaining.len(); + + let buf = self.raw_body.fill_buf().await.map_err(|e| Error::Network(e.kind()))?; + if buf.is_empty() { + return Err(Error::ConnectionClosed); } + + let len = buf.len().min(remaining); + Ok(&buf[..len]) + } + + fn consume(&mut self, amt: usize) { + let consumed = self.chunk_remaining.consume(amt); + self.raw_body.consume(consumed); } } @@ -595,13 +682,14 @@ impl From for Status { #[cfg(test)] mod tests { - use embedded_io::ErrorType; - use embedded_io_async::{Read, Write}; + use embedded_io::{ErrorKind, ErrorType}; + use embedded_io_async::{BufRead, Read, Write}; use crate::{ client::HttpConnection, + reader::BufferingReader, request::Method, - response::{ChunkedBodyReader, Response}, + response::{ChunkState, ChunkedBodyReader, Response}, }; struct Buffer { @@ -614,18 +702,33 @@ mod tests { } impl ErrorType for Buffer { - type Error = embedded_io::ErrorKind; + type Error = ErrorKind; } impl Read for Buffer { async fn read(&mut self, buf: &mut [u8]) -> Result { - let len = buf.len().min(self.b.len()); - buf[..len].copy_from_slice(&self.b[..len]); - self.b.drain(..len); + let loaded = self.fill_buf().await?; + + let len = loaded.len().min(buf.len()); + + buf[..len].copy_from_slice(&loaded[..len]); + self.consume(len); + Ok(len) } } + impl BufRead for Buffer { + async fn fill_buf(&mut self) -> Result<&[u8], Self::Error> { + Ok(self.b.as_slice()) + } + + fn consume(&mut self, amt: usize) { + let len = amt.min(self.b.len()); + self.b.drain(..len); + } + } + impl Write for Buffer { async fn write(&mut self, buf: &[u8]) -> Result { self.b.extend_from_slice(buf); @@ -750,11 +853,11 @@ mod tests { #[tokio::test] async fn chunked_body_reader_can_read_with_large_buffer() { - let raw_body = HttpConnection::Plain(Buffer::from(b"1\r\nX\r\n10\r\nYYYYYYYYYYYYYYYY\r\n0\r\n\r\n")); + let mut raw_body = HttpConnection::Plain(Buffer::from(b"1\r\nX\r\n10\r\nYYYYYYYYYYYYYYYY\r\n0\r\n\r\n")); + let mut read_buffer = [0; 128]; let mut reader = ChunkedBodyReader { - raw_body, - chunk_remaining: 0, - empty_chunk_received: false, + raw_body: BufferingReader::new(&mut read_buffer, 0, &mut raw_body), + chunk_remaining: ChunkState::NoChunk, }; let mut body = [0; 17]; @@ -767,11 +870,11 @@ mod tests { #[tokio::test] async fn chunked_body_reader_can_read_with_tiny_buffer() { - let raw_body = HttpConnection::Plain(Buffer::from(b"1\r\nX\r\n10\r\nYYYYYYYYYYYYYYYY\r\n0\r\n\r\n")); + let mut raw_body = HttpConnection::Plain(Buffer::from(b"1\r\nX\r\n10\r\nYYYYYYYYYYYYYYYY\r\n0\r\n\r\n")); + let mut read_buffer = [0; 128]; let mut reader = ChunkedBodyReader { - raw_body, - chunk_remaining: 0, - empty_chunk_received: false, + raw_body: BufferingReader::new(&mut read_buffer, 0, &mut raw_body), + chunk_remaining: ChunkState::NoChunk, }; let mut body = heapless::Vec::::new();