From 36a420b3fa4f36e65bf617aa4ef8fccd83cbc8e7 Mon Sep 17 00:00:00 2001 From: Thomas Klapwijk Date: Thu, 26 Sep 2024 10:56:03 +0100 Subject: [PATCH] Improves client header handling (#46) * Improves header handling and adds support for user defined headers for client upgrades * Improves header handling and logging --- ratchet_core/src/errors.rs | 5 +- ratchet_core/src/handshake/client/encoding.rs | 205 ++++++++---------- ratchet_core/src/handshake/client/mod.rs | 12 +- ratchet_core/src/handshake/client/tests.rs | 121 ++++++++++- ratchet_core/src/handshake/io.rs | 4 +- ratchet_core/src/handshake/mod.rs | 19 +- ratchet_core/src/handshake/server/encoding.rs | 45 +++- ratchet_core/src/handshake/server/tests.rs | 47 +++- 8 files changed, 313 insertions(+), 145 deletions(-) diff --git a/ratchet_core/src/errors.rs b/ratchet_core/src/errors.rs index 393055f..5ef0e44 100644 --- a/ratchet_core/src/errors.rs +++ b/ratchet_core/src/errors.rs @@ -160,7 +160,7 @@ pub enum HttpError { Status(StatusCode), /// An invalid HTTP version was received in a request. #[error("Invalid HTTP version: `{0:?}`")] - HttpVersion(Option), + HttpVersion(String), /// A request or response was missing an expected header. #[error("Missing header: `{0}`")] MissingHeader(HeaderName), @@ -176,6 +176,9 @@ pub enum HttpError { /// A provided header was malformatted #[error("A provided header was malformatted")] MalformattedHeader(String), + /// A request was missing the authority. + #[error("Missing authority")] + MissingAuthority, } impl From for Error { diff --git a/ratchet_core/src/handshake/client/encoding.rs b/ratchet_core/src/handshake/client/encoding.rs index 6796119..e5ef65d 100644 --- a/ratchet_core/src/handshake/client/encoding.rs +++ b/ratchet_core/src/handshake/client/encoding.rs @@ -13,10 +13,10 @@ // limitations under the License. use base64::Engine; -use bytes::{BufMut, BytesMut}; -use http::header::{AsHeaderName, HeaderName, IntoHeaderName}; +use bytes::BytesMut; +use http::header::{HOST, SEC_WEBSOCKET_EXTENSIONS, SEC_WEBSOCKET_KEY, SEC_WEBSOCKET_PROTOCOL}; use http::request::Parts; -use http::{header, HeaderMap, HeaderValue, Method, Request, Version}; +use http::{header, HeaderMap, HeaderName, HeaderValue, Method, Request, Version}; use ratchet_ext::ExtensionProvider; @@ -27,13 +27,13 @@ use crate::handshake::{ }; use base64::engine::general_purpose::STANDARD; +use log::error; pub fn encode_request(dst: &mut BytesMut, request: ValidatedRequest, nonce_buffer: &mut Nonce) { let ValidatedRequest { version, headers, path_and_query, - host, } = request; let nonce = rand::random::<[u8; 16]>(); @@ -49,76 +49,34 @@ pub fn encode_request(dst: &mut BytesMut, request: ValidatedRequest, nonce_buffe let request = format!( "\ GET {path} {version:?}\r\n\ -Host: {host}\r\n\ -Connection: Upgrade\r\n\ -Upgrade: websocket\r\n\ -sec-websocket-version: 13\r\n\ sec-websocket-key: {nonce}", version = version, path = path_and_query, - host = host, nonce = nonce_str ); - // 28 = request terminator + nonce buffer len - let mut len = 28 + request.len(); + extend(dst, request.as_bytes()); - let origin = write_header(&headers, header::ORIGIN); - let protocol = write_header(&headers, header::SEC_WEBSOCKET_PROTOCOL); - let ext = write_header(&headers, header::SEC_WEBSOCKET_EXTENSIONS); - let auth = write_header(&headers, header::AUTHORIZATION); - - if let Some((name, value)) = &origin { - len += name.len() + value.len() + 2; - } - if let Some((name, value)) = &protocol { - len += name.len() + value.len() + 2; - } - if let Some((name, value)) = &ext { - len += name.len() + value.len() + 2; + for (name, value) in &headers { + extend(dst, b"\r\n"); + extend(dst, name.as_str().as_bytes()); + extend(dst, b": "); + extend(dst, value.as_bytes()); } - if let Some((name, value)) = &auth { - len += name.len() + value.len() + 2; - } - - dst.reserve(len); - dst.put_slice(request.as_bytes()); - if let Some((name, value)) = origin { - dst.put_slice(b"\r\n"); - dst.put_slice(name.as_bytes()); - dst.put_slice(value); - } - if let Some((name, value)) = protocol { - dst.put_slice(b"\r\n"); - dst.put_slice(name.as_bytes()); - dst.put_slice(value); - } - if let Some((name, value)) = ext { - dst.put_slice(b"\r\n"); - dst.put_slice(name.as_bytes()); - dst.put_slice(value); - } - if let Some((name, value)) = auth { - dst.put_slice(b"\r\n"); - dst.put_slice(name.as_bytes()); - dst.put_slice(value); - } - - dst.put_slice(b"\r\n\r\n"); + extend(dst, b"\r\n\r\n"); } -fn write_header(headers: &HeaderMap, name: HeaderName) -> Option<(String, &[u8])> { - headers - .get(&name) - .map(|value| (format!("{}: ", name), value.as_bytes())) +#[inline] +fn extend(dst: &mut BytesMut, data: &[u8]) { + dst.extend_from_slice(data); } +#[derive(Debug)] pub struct ValidatedRequest { version: Version, headers: HeaderMap, path_and_query: String, - host: String, } // rfc6455 § 4.2.1 @@ -149,20 +107,46 @@ where if version != Version::HTTP_11 { return Err(Error::with_cause( ErrorKind::Http, - HttpError::HttpVersion(None), + HttpError::HttpVersion(format!("{version:?}")), )); } - let authority = uri - .authority() - .ok_or_else(|| Error::with_cause(ErrorKind::Http, "Missing authority"))? - .as_str() - .to_string(); - validate_or_insert( - &mut headers, - header::HOST, - HeaderValue::from_str(authority.as_ref())?, - )?; + if headers.get(SEC_WEBSOCKET_EXTENSIONS).is_some() { + error!( + "{} should only be set by extensions", + SEC_WEBSOCKET_EXTENSIONS + ); + return Err(Error::with_cause( + ErrorKind::Http, + HttpError::InvalidHeader(SEC_WEBSOCKET_EXTENSIONS), + )); + } + + // Run this first to ensure that the extension doesn't invalidate the headers. + extension.apply_headers(&mut headers); + + match validate_host_header(&headers) { + Ok(()) => { + // The request should only contain *one* 'host' header, and it must be a single value, + // not a comma seperated list. If the request doesn't already have one then derive it + // from the URI if it contains an authority. If it doesn't, then the request is invalid + // and any correct server implementation would reject it - including Ratchet. + let authority = uri + .authority() + .ok_or_else(|| Error::with_cause(ErrorKind::Http, HttpError::MissingAuthority))? + .as_str() + .to_string(); + validate_or_insert( + &mut headers, + header::HOST, + HeaderValue::from_str(authority.as_ref())?, + )?; + } + Err(e) => { + error!("Request should only contain one 'host' header. {e}"); + return Err(e); + } + } validate_or_insert( &mut headers, @@ -180,54 +164,28 @@ where HeaderValue::from_static(WEBSOCKET_VERSION_STR), )?; - if headers.get(header::SEC_WEBSOCKET_EXTENSIONS).is_some() { + if headers.get(SEC_WEBSOCKET_PROTOCOL).is_some() { + error!( + "{} should only be set by extensions", + SEC_WEBSOCKET_PROTOCOL + ); + // WebSocket protocols can only be applied using a ProtocolRegistry return Err(Error::with_cause( ErrorKind::Http, - HttpError::InvalidHeader(header::SEC_WEBSOCKET_EXTENSIONS), + HttpError::InvalidHeader(SEC_WEBSOCKET_PROTOCOL), )); } - extension.apply_headers(&mut headers); + apply_to(subprotocols, &mut headers); - if headers.get(header::SEC_WEBSOCKET_PROTOCOL).is_some() { - // WebSocket protocols can only be applied using a ProtocolRegistry + if headers.get(SEC_WEBSOCKET_KEY).is_some() { + error!("{} should not be set", SEC_WEBSOCKET_KEY); return Err(Error::with_cause( ErrorKind::Http, - HttpError::InvalidHeader(header::SEC_WEBSOCKET_PROTOCOL), + HttpError::InvalidHeader(SEC_WEBSOCKET_KEY), )); } - apply_to(subprotocols, &mut headers); - - let option = headers - .get(header::SEC_WEBSOCKET_KEY) - .map(|head| head.to_str()); - match option { - Some(Ok(version)) if version == WEBSOCKET_VERSION_STR => {} - None => { - headers.insert( - header::SEC_WEBSOCKET_VERSION, - HeaderValue::from_static(WEBSOCKET_VERSION_STR), - ); - } - _ => { - return Err(Error::with_cause( - ErrorKind::Http, - HttpError::InvalidHeader(header::SEC_WEBSOCKET_KEY), - )); - } - } - - let host = uri - .authority() - .ok_or_else(|| { - Error::with_cause( - ErrorKind::Http, - HttpError::MalformattedUri(Some("Missing authority".to_string())), - ) - })? - .to_string(); - let path_and_query = uri .path_and_query() .map(ToString::to_string) @@ -237,25 +195,46 @@ where version, headers, path_and_query, - host, }) } -fn validate_or_insert( +fn validate_or_insert( headers: &mut HeaderMap, - header_name: A, + header_name: HeaderName, expected: HeaderValue, -) -> Result<(), Error> -where - A: AsHeaderName + IntoHeaderName + Clone, -{ +) -> Result<(), HttpError> { if let Some(header_value) = headers.get(header_name.clone()) { match header_value.to_str() { Ok(v) if v.as_bytes().eq_ignore_ascii_case(expected.as_bytes()) => Ok(()), - _ => Err(Error::new(ErrorKind::Http)), + _ => { + error!("Invalid header set: {} -> {:?}", header_name, header_value); + Err(HttpError::InvalidHeader(header_name)) + } } } else { headers.insert(header_name, expected); Ok(()) } } + +/// Validates that 'headers' contains at most one 'host' header and that it is not a seperated list. +fn validate_host_header(headers: &HeaderMap) -> Result<(), Error> { + let len = headers + .iter() + .filter_map(|(name, value)| { + if name.as_str().eq_ignore_ascii_case(HOST.as_str()) { + Some(value.as_bytes().split(|c| c == &b' ' || c == &b',')) + } else { + None + } + }) + .count(); + if len <= 1 { + Ok(()) + } else { + Err(Error::with_cause( + ErrorKind::Http, + HttpError::InvalidHeader(HOST), + )) + } +} diff --git a/ratchet_core/src/handshake/client/mod.rs b/ratchet_core/src/handshake/client/mod.rs index 0a7039e..d2d995e 100644 --- a/ratchet_core/src/handshake/client/mod.rs +++ b/ratchet_core/src/handshake/client/mod.rs @@ -20,7 +20,7 @@ mod encoding; use base64::engine::general_purpose::STANDARD; use base64::Engine; use bytes::BytesMut; -use http::{header, Request, StatusCode}; +use http::{header, Request, StatusCode, Version}; use httparse::{Response, Status}; use log::{error, trace}; use sha1::{Digest, Sha1}; @@ -227,12 +227,14 @@ where subprotocols, } = self; + trace!("Encoding request: {request:?}"); let validated_request = build_request(request, extension, subprotocols)?; encode_request(buffered.buffer, validated_request, nonce); Ok(()) } async fn write(&mut self) -> Result<(), Error> { + trace!("Writing buffered data"); self.buffered.write().await } @@ -285,10 +287,10 @@ fn check_partial_response(response: &Response) -> Result<(), Error> { // httparse sets this to 0 for HTTP/1.0 or 1 for HTTP/1.1 // rfc6455 § 4.2.1.1: must be HTTP/1.1 or higher Some(1) | None => {} - Some(v) => { + Some(_) => { return Err(Error::with_cause( ErrorKind::Http, - HttpError::HttpVersion(Some(v)), + HttpError::HttpVersion(format!("{:?}", Version::HTTP_10)), )) } } @@ -340,10 +342,10 @@ where match response.version { // rfc6455 § 4.2.1.1: must be HTTP/1.1 or higher Some(1) => {} - v => { + _ => { return Err(Error::with_cause( ErrorKind::Http, - HttpError::HttpVersion(v), + HttpError::HttpVersion(format!("{:?}", Version::HTTP_10)), )) } } diff --git a/ratchet_core/src/handshake/client/tests.rs b/ratchet_core/src/handshake/client/tests.rs index 565f7fb..76b5d25 100644 --- a/ratchet_core/src/handshake/client/tests.rs +++ b/ratchet_core/src/handshake/client/tests.rs @@ -14,6 +14,7 @@ use crate::errors::{Error, HttpError}; use crate::ext::NoExt; +use crate::handshake::client::encoding::build_request; use crate::handshake::client::{ClientHandshake, HandshakeResult}; use crate::handshake::{ProtocolRegistry, ACCEPT_KEY, UPGRADE_STR, WEBSOCKET_STR}; use crate::test_fixture::mock; @@ -22,8 +23,11 @@ use base64::engine::{general_purpose::STANDARD, Engine}; use bytes::BytesMut; use futures::future::join; use futures::FutureExt; -use http::header::HeaderName; -use http::{header, HeaderMap, HeaderValue, Request, Response, StatusCode, Version}; +use http::header::{ + HeaderName, CONNECTION, HOST, SEC_WEBSOCKET_EXTENSIONS, SEC_WEBSOCKET_KEY, + SEC_WEBSOCKET_PROTOCOL, SEC_WEBSOCKET_VERSION, UPGRADE, +}; +use http::{header, HeaderMap, HeaderValue, Method, Request, Response, StatusCode, Version}; use httparse::{Header, Status}; use ratchet_ext::{ Extension, ExtensionDecoder, ExtensionEncoder, ExtensionProvider, FrameHeader, @@ -36,6 +40,7 @@ use tokio::io::AsyncReadExt; use tokio::sync::Notify; const TEST_URL: &str = "ws://127.0.0.1:9001/test"; +const ERR: &str = "Expected an error"; #[tokio::test] async fn handshake_sends_valid_request() { @@ -172,8 +177,6 @@ async fn expect_server_error(response: Response<()>, expected_error: HttpError) let handshake_result = machine.read().await; - const ERR: &str = "Expected an error"; - handshake_result .err() .map(|e| { @@ -243,7 +246,11 @@ async fn incorrect_version() { .body(()) .unwrap(); - expect_server_error(response, HttpError::HttpVersion(Some(0))).await; + expect_server_error( + response, + HttpError::HttpVersion(format!("{:?}", Version::HTTP_10)), + ) + .await; } #[tokio::test] @@ -687,3 +694,107 @@ async fn negotiates_no_extension() { ) .await; } + +#[test] +fn fails_to_build_request() { + fn test(request: Request<()>, expected_error: E) { + match build_request(request, &NoExtProvider, &ProtocolRegistry::default()) { + Ok(r) => { + panic!("Expected a test failure of {}. Got {:?}", expected_error, r); + } + Err(e) => { + let error = e.downcast_ref::().expect(ERR); + assert_eq!(error, &expected_error); + } + } + } + + test( + Request::builder() + .method(Method::POST) + .version(Version::HTTP_11) + .uri(TEST_URL) + .body(()) + .unwrap(), + HttpError::HttpMethod(Some(Method::POST.to_string())), + ); + test( + Request::builder() + .method(Method::GET) + .version(Version::HTTP_10) + .uri(TEST_URL) + .body(()) + .unwrap(), + HttpError::HttpVersion(format!("{:?}", Version::HTTP_10)), + ); + + let mut request = Request::builder() + .method(Method::GET) + .version(Version::HTTP_11) + .uri("/doot/doot") + .body(()) + .unwrap(); + request + .headers_mut() + .insert(HOST, HeaderValue::from_static("hosty")); + + test(request, HttpError::MissingAuthority); + + let mut request = Request::builder() + .method(Method::GET) + .version(Version::HTTP_11) + .uri(TEST_URL) + .body(()) + .unwrap(); + request + .headers_mut() + .insert(CONNECTION, HeaderValue::from_static("downgrade")); + request + .headers_mut() + .insert(HOST, HeaderValue::from_static("127.0.0.1:9001")); + + test(request, HttpError::InvalidHeader(CONNECTION)); + + let headers = [ + UPGRADE, + SEC_WEBSOCKET_VERSION, + SEC_WEBSOCKET_EXTENSIONS, + SEC_WEBSOCKET_PROTOCOL, + SEC_WEBSOCKET_KEY, + ]; + + for header in headers { + let mut request = Request::builder() + .method(Method::GET) + .version(Version::HTTP_11) + .uri(TEST_URL) + .body(()) + .unwrap(); + request + .headers_mut() + .insert(UPGRADE, HeaderValue::from_static("websocket")); + request + .headers_mut() + .insert(HOST, HeaderValue::from_static("127.0.0.1:9001")); + request + .headers_mut() + .insert(header.clone(), HeaderValue::from_static("socketweb")); + + test(request, HttpError::InvalidHeader(header)); + } + + let mut request = Request::builder() + .method(Method::GET) + .version(Version::HTTP_11) + .uri(TEST_URL) + .body(()) + .unwrap(); + request + .headers_mut() + .insert(HOST, HeaderValue::from_static("hostedbyhosts")); + request + .headers_mut() + .insert(HOST, HeaderValue::from_static("hostymchostface")); + + test(request, HttpError::InvalidHeader(HOST)); +} diff --git a/ratchet_core/src/handshake/io.rs b/ratchet_core/src/handshake/io.rs index 944b603..e1f084a 100644 --- a/ratchet_core/src/handshake/io.rs +++ b/ratchet_core/src/handshake/io.rs @@ -40,7 +40,7 @@ impl<'s, S> BufferedIo<'s, S> { Ok(()) } - pub async fn read(&mut self) -> Result<(), Error> + pub async fn read(&mut self) -> Result where S: AsyncRead + Unpin, { @@ -51,7 +51,7 @@ impl<'s, S> BufferedIo<'s, S> { let read_count = socket.read(&mut buffer[len..]).await?; buffer.truncate(len + read_count); - Ok(()) + Ok(read_count) } pub fn advance(&mut self, count: usize) { diff --git a/ratchet_core/src/handshake/mod.rs b/ratchet_core/src/handshake/mod.rs index 541782c..faa7cb9 100644 --- a/ratchet_core/src/handshake/mod.rs +++ b/ratchet_core/src/handshake/mod.rs @@ -28,6 +28,7 @@ use bytes::Bytes; use http::header::HeaderName; use http::Uri; use http::{HeaderMap, HeaderValue}; +use log::{error, trace, warn}; use std::str::FromStr; use tokio::io::AsyncRead; use tokio_util::codec::Decoder; @@ -62,15 +63,29 @@ where let StreamingParser { io, mut parser } = self; loop { - io.read().await?; + let n = io.read().await?; + + if n == 0 { + warn!("Received early EOF"); + return Err(Error::with_cause( + ErrorKind::IO, + std::io::Error::from(std::io::ErrorKind::UnexpectedEof), + )); + } else { + trace!("Read {n} bytes. Attempting to decode"); + } match parser.decode(io.buffer) { Ok(Some((out, count))) => { + trace!("Decoded: {count} bytes"); io.advance(count); return Ok(out); } Ok(None) => continue, - Err(e) => return Err(e), + Err(e) => { + error!("Failed to decode response. Error: {e}"); + return Err(e); + } } } } diff --git a/ratchet_core/src/handshake/server/encoding.rs b/ratchet_core/src/handshake/server/encoding.rs index 5953eca..ecc3b81 100644 --- a/ratchet_core/src/handshake/server/encoding.rs +++ b/ratchet_core/src/handshake/server/encoding.rs @@ -15,14 +15,16 @@ use crate::handshake::io::BufferedIo; use crate::handshake::server::HandshakeResult; use crate::handshake::{ - get_header, validate_header, validate_header_any, validate_header_value, ParseResult, - METHOD_GET, UPGRADE_STR, WEBSOCKET_STR, WEBSOCKET_VERSION_STR, + get_header, validate_header_any, validate_header_value, ParseResult, METHOD_GET, UPGRADE_STR, + WEBSOCKET_STR, WEBSOCKET_VERSION_STR, }; use crate::handshake::{negotiate_request, TryMap}; use crate::{Error, ErrorKind, HttpError, ProtocolRegistry}; use bytes::{BufMut, BytesMut}; -use http::{HeaderMap, StatusCode}; -use httparse::Status; +use http::header::HOST; +use http::{HeaderMap, StatusCode, Version}; +use httparse::{Header, Status}; +use log::error; use ratchet_ext::ExtensionProvider; use tokio::io::AsyncWrite; use tokio_util::codec::Decoder; @@ -144,10 +146,10 @@ where pub fn check_partial_request(request: &httparse::Request) -> Result<(), Error> { match request.version { Some(HTTP_VERSION_INT) | None => {} - Some(v) => { + Some(_) => { return Err(Error::with_cause( ErrorKind::Http, - HttpError::HttpVersion(Some(v)), + HttpError::HttpVersion(format!("{:?}", Version::HTTP_10)), )) } } @@ -176,10 +178,10 @@ where { match request.version { Some(HTTP_VERSION_INT) => {} - v => { + _ => { return Err(Error::with_cause( ErrorKind::Http, - HttpError::HttpVersion(v), + HttpError::HttpVersion(format!("{:?}", Version::HTTP_10)), )) } } @@ -203,7 +205,10 @@ where WEBSOCKET_VERSION_STR, )?; - validate_header(headers, http::header::HOST, |_, _| Ok(()))?; + if let Err(e) = validate_host_header(headers) { + error!("Server responded with invalid 'host' headers"); + return Err(e); + } let key = get_header(headers, http::header::SEC_WEBSOCKET_KEY)?; let subprotocol = negotiate_request(subprotocols, request)?; @@ -220,3 +225,25 @@ where extension_header, }) } + +/// Validates that 'headers' contains one 'host' header and that it is not a seperated list. +fn validate_host_header(headers: &[Header]) -> Result<(), Error> { + let len = headers + .iter() + .filter_map(|header| { + if header.name.eq_ignore_ascii_case(HOST.as_str()) { + Some(header.value.split(|c| c == &b' ' || c == &b',')) + } else { + None + } + }) + .count(); + if len == 1 { + Ok(()) + } else { + Err(Error::with_cause( + ErrorKind::Http, + HttpError::MissingHeader(HOST), + )) + } +} diff --git a/ratchet_core/src/handshake/server/tests.rs b/ratchet_core/src/handshake/server/tests.rs index c22e8ca..b57fa9b 100644 --- a/ratchet_core/src/handshake/server/tests.rs +++ b/ratchet_core/src/handshake/server/tests.rs @@ -33,7 +33,10 @@ impl From> for Error { } } -async fn exec_request(request: Request<()>) -> Result, Error> { +async fn exec_request(request: Request<()>, f: F) -> Result, Error> +where + F: FnOnce(HeaderMap), +{ let (mut client, server) = mock(); client.write_request(request).await?; @@ -46,13 +49,15 @@ async fn exec_request(request: Request<()>) -> Result, Error> { ) .await?; + f(upgrader.request.headers().clone()); + let _upgraded = upgrader.upgrade().await?; client.read_response().await.map_err(Into::into) } #[tokio::test] async fn valid_response() { - let response = exec_request(valid_request()).await.unwrap(); + let response = exec_request(valid_request(), |_| {}).await.unwrap(); let expected = Response::builder() .status(101) @@ -72,7 +77,7 @@ async fn valid_response() { #[tokio::test] async fn bad_request() { async fn t(request: Request<()>, name: HeaderName) { - match exec_request(request).await { + match exec_request(request, |_| {}).await { Ok(o) => panic!("Expected a test failure. Got: {:?}", o), Err(e) => match e.downcast_ref::() { Some(err) => { @@ -85,7 +90,6 @@ async fn bad_request() { } } - // request doesn't implement clone t( Request::builder() .uri("/test") @@ -158,11 +162,14 @@ async fn bad_request() { .body(()) .unwrap(); - match exec_request(request).await { + match exec_request(request, |_| {}).await { Ok(o) => panic!("Expected a test failure. Got: {:?}", o), Err(e) => match e.downcast_ref::() { Some(err) => { - assert_eq!(err, &HttpError::HttpVersion(Some(0))); + assert_eq!( + err, + &HttpError::HttpVersion(format!("{:?}", Version::HTTP_10)) + ); } None => { panic!("Expected a HTTP error. Got: {:?}", e) @@ -181,7 +188,7 @@ async fn bad_request() { .body(()) .unwrap(); - match exec_request(request).await { + match exec_request(request, |_| {}).await { Ok(o) => panic!("Expected a test failure. Got: {:?}", o), Err(e) => match e.downcast_ref::() { Some(err) => { @@ -332,7 +339,7 @@ async fn multiple_connection_headers() { .body(()) .unwrap(); - let response = exec_request(request).await.unwrap(); + let response = exec_request(request, |_| {}).await.unwrap(); let expected = Response::builder() .status(101) @@ -348,3 +355,27 @@ async fn multiple_connection_headers() { assert_response_eq(response, expected); } + +#[tokio::test] +async fn user_defined_headers() { + let mut req = valid_request(); + req.headers_mut() + .insert(http::header::ACCEPT_RANGES, "aaaaa".parse().unwrap()); + req.headers_mut().insert( + http::header::ACCESS_CONTROL_ALLOW_CREDENTIALS, + "bbbbbb".parse().unwrap(), + ); + + exec_request(req, |headers| { + assert_eq!( + headers.get(http::header::ACCEPT_RANGES), + Some(&HeaderValue::from_static("aaaaa")) + ); + assert_eq!( + headers.get(http::header::ACCESS_CONTROL_ALLOW_CREDENTIALS), + Some(&HeaderValue::from_static("bbbbbb")) + ); + }) + .await + .unwrap(); +}