Skip to content

Commit

Permalink
http: don't use bufreader for socket
Browse files Browse the repository at this point in the history
  • Loading branch information
2bc4 committed Feb 17, 2024
1 parent 43006aa commit 1cfa15b
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 49 deletions.
55 changes: 28 additions & 27 deletions src/http/decoder.rs
Original file line number Diff line number Diff line change
@@ -1,29 +1,30 @@
use std::io::{self, BufReader, Read};
use std::io::{self, Read};

use anyhow::{bail, Result};
use chunked_transfer::Decoder as ChunkDecoder;
use flate2::read::GzDecoder;
use log::debug;

use super::request::Transport;

enum Encoding<'a> {
Unencoded(&'a mut BufReader<Transport>, u64),
Chunked(ChunkDecoder<&'a mut BufReader<Transport>>),
ChunkedGzip(GzDecoder<ChunkDecoder<&'a mut BufReader<Transport>>>),
Gzip(GzDecoder<&'a mut BufReader<Transport>>),
/// The transfer/content encodings an HTTP response body can arrive in,
/// each variant wrapping the underlying reader `T` in the adapter that
/// undoes that encoding.
enum Encoding<T>
where
    T: Read,
{
    /// Plain body: the raw reader plus the Content-Length to stop at.
    Unencoded(T, u64),
    /// Chunked transfer encoding, decoded by `chunked_transfer`.
    Chunked(ChunkDecoder<T>),
    /// Chunked transfer encoding carrying a gzip-compressed body:
    /// de-chunk first, then gunzip.
    ChunkedGzip(GzDecoder<ChunkDecoder<T>>),
    /// Gzip-compressed body without chunking.
    Gzip(GzDecoder<T>),
}

pub struct Decoder<'a> {
kind: Encoding<'a>,
/// Streaming decoder for an HTTP response body: a `Read` adapter that
/// transparently reverses whatever encoding was detected from the headers.
pub struct Decoder<T: Read> {
    // The decoding pipeline selected for this response.
    kind: Encoding<T>,
    // Bytes handed out so far; the unencoded path uses this to cap reads
    // at the declared Content-Length.
    consumed: u64,
}

impl Read for Decoder<'_> {
impl<T: Read> Read for Decoder<T> {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
match &mut self.kind {
Encoding::Unencoded(stream, length) => {
let consumed = stream.take(*length - self.consumed).read(buf)?;
Encoding::Unencoded(reader, length) => {
let consumed = reader.take(*length - self.consumed).read(buf)?;
self.consumed += consumed as u64;

Ok(consumed)
Expand All @@ -43,8 +44,8 @@ impl Read for Decoder<'_> {
}
}

impl<'a> Decoder<'a> {
pub fn new(stream: &'a mut BufReader<Transport>, headers: &str) -> Result<Decoder<'a>> {
impl<T: Read> Decoder<T> {
pub fn new(reader: T, headers: &str) -> Result<Decoder<T>> {
let headers = headers.to_lowercase();
let content_length = headers
.lines()
Expand All @@ -58,35 +59,35 @@ impl<'a> Decoder<'a> {
(true, true) => {
debug!("Body is chunked and gzipped");

return Ok(Self {
kind: Encoding::ChunkedGzip(GzDecoder::new(ChunkDecoder::new(stream))),
Ok(Self {
kind: Encoding::ChunkedGzip(GzDecoder::new(ChunkDecoder::new(reader))),
consumed: u64::default(),
});
})
}
(true, false) => {
debug!("Body is chunked");

return Ok(Self {
kind: Encoding::Chunked(ChunkDecoder::new(stream)),
Ok(Self {
kind: Encoding::Chunked(ChunkDecoder::new(reader)),
consumed: u64::default(),
});
})
}
(false, true) => {
debug!("Body is gzipped");

return Ok(Self {
kind: Encoding::Gzip(GzDecoder::new(stream)),
Ok(Self {
kind: Encoding::Gzip(GzDecoder::new(reader)),
consumed: u64::default(),
});
})
}
_ => match content_length {
Some(length) => {
debug!("Content length: {length}");

return Ok(Self {
kind: Encoding::Unencoded(stream, length),
Ok(Self {
kind: Encoding::Unencoded(reader, length),
consumed: u64::default(),
});
})
}
_ => bail!("Could not resolve encoding of HTTP response"),
},
Expand Down
48 changes: 26 additions & 22 deletions src/http/request.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::{
io::{
self, BufRead, BufReader,
self,
ErrorKind::{InvalidInput, Other, UnexpectedEof},
Read, Write,
},
Expand Down Expand Up @@ -169,7 +169,7 @@ struct Request<T>
where
T: Write,
{
stream: BufReader<Transport>,
stream: Transport,
handler: Handler<T>,
raw: String,

Expand All @@ -184,7 +184,7 @@ where
impl<T: Write> Request<T> {
fn new(writer: T, method: Method, url: Url, data: String, agent: Agent) -> Result<Self> {
let mut request = Self {
stream: BufReader::new(Transport::new(&url, agent.clone())?),
stream: Transport::new(&url, agent.clone())?,
handler: Handler::new(writer),
raw: String::default(),

Expand Down Expand Up @@ -265,27 +265,31 @@ impl<T: Write> Request<T> {
}

fn do_request(&mut self) -> Result<()> {
//Will break if server sends more than this in headers, but protects against OOM
const MAX_HEADERS_SIZE: usize = 2048;
//Read only \r\n
const HEADERS_END_SIZE: usize = 2;
const BUF_SIZE: usize = 2048;

debug!("Request:\n{}", self.raw);
self.stream.get_mut().write_all(self.raw.as_bytes())?;

let mut headers = String::default();
let mut consumed = 0;
while consumed != HEADERS_END_SIZE {
if self.stream.fill_buf()?.is_empty() {
self.stream.write_all(self.raw.as_bytes())?;
self.stream.flush()?;

//Read into buf and search for the header terminator string,
//then split buf there and feed remaining half into decoder
let mut buf = [0u8; BUF_SIZE];
let mut written = 0;
let (headers, remaining) = loop {
let consumed = self.stream.read(&mut buf[written..])?;
if consumed == 0 {
return Err(io::Error::from(UnexpectedEof).into());
}

consumed = self
.stream
.by_ref()
.take(MAX_HEADERS_SIZE as u64)
.read_line(&mut headers)?;
}
written += consumed;

if let Some(mut headers_end) = buf.windows(4).position(|w| w == b"\r\n\r\n") {
headers_end += 4; //pass \r\n\r\n
break (
String::from_utf8_lossy(&buf[..headers_end]),
&buf[headers_end..written],
);
}
};
debug!("Response:\n{headers}");

let code = headers
Expand All @@ -302,12 +306,12 @@ impl<T: Write> Request<T> {
}

match io::copy(
&mut Decoder::new(&mut self.stream, &headers)?,
&mut Decoder::new(remaining.chain(&mut self.stream), &headers)?,
&mut self.handler,
) {
Ok(_) => Ok(()),
//Chunk decoder returns InvalidInput on some segment servers, can be ignored
Err(e) if matches!(e.kind(), InvalidInput) => Ok(()),
Err(e) if e.kind() == InvalidInput => Ok(()),
Err(e) => Err(e.into()),
}
}
Expand Down

0 comments on commit 1cfa15b

Please sign in to comment.