From caec355c1a8b68984d2d2e3e5a2b6d235524ebef Mon Sep 17 00:00:00 2001 From: SirCipher Date: Tue, 24 Sep 2024 13:29:22 +0100 Subject: [PATCH 01/13] Removes NegotiatedExtension --- ratchet_core/src/ext.rs | 118 ------------------ ratchet_core/src/handshake/client/mod.rs | 16 ++- ratchet_core/src/handshake/client/tests.rs | 4 +- ratchet_core/src/handshake/server/encoding.rs | 8 +- ratchet_core/src/handshake/server/mod.rs | 6 +- ratchet_core/src/lib.rs | 2 +- ratchet_core/src/split/mod.rs | 9 +- ratchet_core/src/split/tests.rs | 30 ++--- ratchet_core/src/ws.rs | 29 ++--- ratchet_ext/src/lib.rs | 79 ++++++++++++ ratchet_fixture/src/lib.rs | 15 ++- 11 files changed, 126 insertions(+), 190 deletions(-) diff --git a/ratchet_core/src/ext.rs b/ratchet_core/src/ext.rs index 55329b7..dc60b8e 100644 --- a/ratchet_core/src/ext.rs +++ b/ratchet_core/src/ext.rs @@ -134,121 +134,3 @@ impl ExtensionDecoder for NoExtDecoder { Ok(()) } } - -#[derive(Debug)] -#[allow(missing_docs)] -pub struct NegotiatedExtension(Option); - -impl NegotiatedExtension -where - E: Extension, -{ - #[allow(missing_docs)] - pub fn take(self) -> Option { - self.0 - } -} - -impl From> for NegotiatedExtension -where - E: Extension, -{ - fn from(ext: Option) -> Self { - NegotiatedExtension(ext) - } -} - -impl From for NegotiatedExtension -where - E: Extension, -{ - fn from(ext: E) -> Self { - NegotiatedExtension::from(Some(ext)) - } -} - -impl ExtensionEncoder for NegotiatedExtension -where - E: ExtensionEncoder, -{ - type Error = E::Error; - - fn encode( - &mut self, - payload: &mut BytesMut, - header: &mut FrameHeader, - ) -> Result<(), Self::Error> { - match &mut self.0 { - Some(ext) => ext.encode(payload, header), - None => Ok(()), - } - } -} - -impl ExtensionDecoder for NegotiatedExtension -where - E: ExtensionDecoder, -{ - type Error = E::Error; - - fn decode( - &mut self, - payload: &mut BytesMut, - header: &mut FrameHeader, - ) -> Result<(), Self::Error> { - match &mut self.0 { - Some(ext) => ext.decode(payload, header), - None => Ok(()), - } - } -} - -impl Extension for NegotiatedExtension -where - E: Extension, -{ - fn bits(&self) -> RsvBits { - match &self.0 { - Some(ext) => ext.bits(), - None => RsvBits { - rsv1: false, - rsv2: false, - rsv3: false, - }, - } - } -} - -impl SplittableExtension for NegotiatedExtension -where - E: SplittableExtension, -{ - type SplitEncoder = NegotiatedExtension; - type SplitDecoder = NegotiatedExtension; - - fn split(self) -> (Self::SplitEncoder, Self::SplitDecoder) { - match self.0 { - Some(ext) => { - let (enc, dec) = ext.split(); - ( - NegotiatedExtension(Some(enc)), - NegotiatedExtension(Some(dec)), - ) - } - None => (NegotiatedExtension(None), NegotiatedExtension(None)), - } - } -} - -impl ReunitableExtension for NegotiatedExtension -where - E: ReunitableExtension, -{ - fn reunite(encoder: Self::SplitEncoder, decoder: Self::SplitDecoder) -> Self { - match (encoder.0, decoder.0) { - (Some(enc), Some(dec)) => NegotiatedExtension(Some(E::reunite(enc, dec))), - (None, None) => NegotiatedExtension(None), - _ => panic!("Illegal state"), - } - } -} diff --git a/ratchet_core/src/handshake/client/mod.rs b/ratchet_core/src/handshake/client/mod.rs index 39cc65c..0a7039e 100644 --- a/ratchet_core/src/handshake/client/mod.rs +++ b/ratchet_core/src/handshake/client/mod.rs @@ -27,7 +27,6 @@ use sha1::{Digest, Sha1}; use std::convert::TryFrom; use crate::errors::{Error, ErrorKind, HttpError}; -use crate::ext::NegotiatedExtension; use crate::handshake::client::encoding::{build_request, encode_request}; use crate::handshake::io::BufferedIo; use crate::handshake::{ @@ -168,13 +167,13 @@ struct ClientHandshake<'s, S, E> { extension: &'s E, } -pub struct ResponseParser<'b, E> { +pub struct StreamingResponseParser<'b, E> { nonce: &'b Nonce, extension: &'b E, subprotocols: &'b mut ProtocolRegistry, } -impl<'b, E> Decoder for ResponseParser<'b, E> +impl<'b, E> Decoder for StreamingResponseParser<'b, E> where E: ExtensionProvider, { @@ -182,14 +181,14 @@ where type Error = Error; fn decode(&mut self, buf: &mut BytesMut) -> Result, Self::Error> { - let ResponseParser { + let StreamingResponseParser { nonce, extension, subprotocols, } = self; let mut headers = [httparse::EMPTY_HEADER; 32]; - let mut response = httparse::Response::new(&mut headers); + let mut response = Response::new(&mut headers); match try_parse_response(buf, &mut response, nonce, extension, subprotocols)? { ParseResult::Complete(result, count) => Ok(Some((result, count))), @@ -251,7 +250,7 @@ where let parser = StreamingParser::new( buffered, - ResponseParser { + StreamingResponseParser { nonce, extension, subprotocols, @@ -276,7 +275,7 @@ where #[derive(Debug)] pub struct HandshakeResult { pub subprotocol: Option, - pub extension: NegotiatedExtension, + pub extension: Option, } /// Quickly checks a partial response in the order of the expected HTTP response declaration to see @@ -399,7 +398,6 @@ where subprotocol: negotiate_response(subprotocols, response)?, extension: extension .negotiate_client(response.headers) - .map_err(|e| Error::with_cause(ErrorKind::Extension, e))? - .into(), + .map_err(|e| Error::with_cause(ErrorKind::Extension, e))?, }) } diff --git a/ratchet_core/src/handshake/client/tests.rs b/ratchet_core/src/handshake/client/tests.rs index 95ced6f..565f7fb 100644 --- a/ratchet_core/src/handshake/client/tests.rs +++ b/ratchet_core/src/handshake/client/tests.rs @@ -661,7 +661,7 @@ async fn negotiates_extension() { ); }, |result| match result { - Ok(handshake_result) => assert!(handshake_result.extension.take().unwrap().0), + Ok(mut handshake_result) => assert!(handshake_result.extension.take().unwrap().0), Err(e) => { panic!("Expected a valid upgrade: {:?}", e) } @@ -679,7 +679,7 @@ async fn negotiates_no_extension() { extension_proxy, |_| {}, |result| match result { - Ok(handshake_result) => assert!(!handshake_result.extension.take().unwrap().0), + Ok(mut handshake_result) => assert!(!handshake_result.extension.take().unwrap().0), Err(e) => { panic!("Expected a valid upgrade: {:?}", e) } diff --git a/ratchet_core/src/handshake/server/encoding.rs b/ratchet_core/src/handshake/server/encoding.rs index 667dcf3..5953eca 100644 --- a/ratchet_core/src/handshake/server/encoding.rs +++ b/ratchet_core/src/handshake/server/encoding.rs @@ -12,7 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -use crate::ext::NegotiatedExtension; use crate::handshake::io::BufferedIo; use crate::handshake::server::HandshakeResult; use crate::handshake::{ @@ -208,13 +207,10 @@ where let key = get_header(headers, http::header::SEC_WEBSOCKET_KEY)?; let subprotocol = negotiate_request(subprotocols, request)?; - let extension_opt = extension + let (extension, extension_header) = extension .negotiate_server(request.headers) + .map(Option::unzip) .map_err(|e| Error::with_cause(ErrorKind::Extension, e))?; - let (extension, extension_header) = match extension_opt { - Some((extension, header)) => (NegotiatedExtension::from(Some(extension)), Some(header)), - None => (NegotiatedExtension::from(None), None), - }; Ok(HandshakeResult { key, diff --git a/ratchet_core/src/handshake/server/mod.rs b/ratchet_core/src/handshake/server/mod.rs index 7c78339..c031cf1 100644 --- a/ratchet_core/src/handshake/server/mod.rs +++ b/ratchet_core/src/handshake/server/mod.rs @@ -16,7 +16,7 @@ mod encoding; #[cfg(test)] mod tests; -use crate::ext::{NegotiatedExtension, NoExt}; +use crate::ext::NoExt; use crate::handshake::io::BufferedIo; use crate::handshake::server::encoding::{write_response, RequestParser}; use crate::handshake::{StreamingParser, ACCEPT_KEY}; @@ -195,7 +195,7 @@ pub struct WebSocketUpgrader { subprotocol: Option, buf: BytesMut, stream: S, - extension: NegotiatedExtension, + extension: Option, extension_header: Option, config: WebSocketConfig, } @@ -315,7 +315,7 @@ where pub struct HandshakeResult { key: Bytes, subprotocol: Option, - extension: NegotiatedExtension, + extension: Option, request: Request, extension_header: Option, } diff --git a/ratchet_core/src/lib.rs b/ratchet_core/src/lib.rs index 11984c0..f32911f 100644 --- a/ratchet_core/src/lib.rs +++ b/ratchet_core/src/lib.rs @@ -53,7 +53,7 @@ pub mod fixture { pub use builder::{WebSocketClientBuilder, WebSocketServerBuilder}; pub use errors::*; -pub use ext::{NegotiatedExtension, NoExt, NoExtDecoder, NoExtEncoder, NoExtProvider}; +pub use ext::{NoExt, NoExtDecoder, NoExtEncoder, NoExtProvider}; pub use handshake::{ accept, accept_with, subscribe, subscribe_with, ProtocolRegistry, TryIntoRequest, UpgradedClient, UpgradedServer, WebSocketResponse, WebSocketUpgrader, diff --git a/ratchet_core/src/split/mod.rs b/ratchet_core/src/split/mod.rs index 50379b6..3978761 100644 --- a/ratchet_core/src/split/mod.rs +++ b/ratchet_core/src/split/mod.rs @@ -25,7 +25,6 @@ use tokio::io::AsyncWriteExt; use bilock::{bilock, BiLock}; use ratchet_ext::{ExtensionDecoder, ExtensionEncoder, ReunitableExtension, SplittableExtension}; -use crate::ext::NegotiatedExtension; use crate::framed::{ read_next, write_close, write_fragmented, CodecFlags, FramedIoParts, FramedRead, FramedWrite, Item, @@ -61,7 +60,7 @@ const STATE_CLOSED: u8 = 2; pub fn split( framed: framed::FramedIo, control_buffer: BytesMut, - extension: NegotiatedExtension, + extension: Option, ) -> (Sender, Receiver) where S: WebSocketStream, @@ -228,7 +227,7 @@ struct FramedIo { read_half: BiLock, reader: FramedRead, split_writer: BiLock>, - ext_decoder: NegotiatedExtension, + ext_decoder: Option, } /// An owned write half of a WebSocket connection. @@ -237,7 +236,7 @@ pub struct Sender { role: Role, close_state: Arc, split_writer: BiLock>, - ext_encoder: NegotiatedExtension, + ext_encoder: Option, } impl Sender @@ -747,7 +746,7 @@ where Ok(WebSocket::from_parts( framed, control_buffer, - NegotiatedExtension::reunite(ext_encoder, ext_decoder), + Option::::reunite(ext_encoder, ext_decoder), close_state, )) } else { diff --git a/ratchet_core/src/split/tests.rs b/ratchet_core/src/split/tests.rs index 742045f..cedbfe0 100644 --- a/ratchet_core/src/split/tests.rs +++ b/ratchet_core/src/split/tests.rs @@ -17,8 +17,8 @@ use crate::protocol::{ControlCode, DataCode, HeaderFlags, OpCode}; use crate::split::{FramedIo, Receiver, Sender, WriteHalf}; use crate::ws::extension_encode; use crate::{ - CloseCause, CloseCode, CloseReason, Error, Message, NegotiatedExtension, NoExt, NoExtDecoder, - NoExtEncoder, Role, WebSocket, WebSocketConfig, WebSocketStream, + CloseCause, CloseCode, CloseReason, Error, Message, NoExt, NoExtDecoder, NoExtEncoder, Role, + WebSocket, WebSocketConfig, WebSocketStream, }; use bytes::{Bytes, BytesMut}; use ratchet_ext::{ExtensionDecoder, ExtensionEncoder}; @@ -105,24 +105,14 @@ fn fixture() -> (Channel, Channel) { let (server, client) = duplex(512); let config = WebSocketConfig::default(); - let server = WebSocket::from_upgraded( - config, - server, - NegotiatedExtension::from(NoExt), - BytesMut::new(), - Role::Server, - ) - .split() - .unwrap(); - let client = WebSocket::from_upgraded( - config, - client, - NegotiatedExtension::from(NoExt), - BytesMut::new(), - Role::Client, - ) - .split() - .unwrap(); + let server = + WebSocket::from_upgraded(config, server, Some(NoExt), BytesMut::new(), Role::Server) + .split() + .unwrap(); + let client = + WebSocket::from_upgraded(config, client, Some(NoExt), BytesMut::new(), Role::Client) + .split() + .unwrap(); (client, server) } diff --git a/ratchet_core/src/ws.rs b/ratchet_core/src/ws.rs index 94939c5..c235cd7 100644 --- a/ratchet_core/src/ws.rs +++ b/ratchet_core/src/ws.rs @@ -13,7 +13,6 @@ // limitations under the License. use crate::errors::{CloseCause, Error, ErrorKind, ProtocolError}; -use crate::ext::NegotiatedExtension; use crate::framed::{FramedIo, Item}; use crate::protocol::{ CloseReason, ControlCode, DataCode, HeaderFlags, Message, MessageType, OpCode, PayloadType, @@ -79,7 +78,7 @@ type SplitSocket = ( pub struct WebSocket { framed: FramedIo, control_buffer: BytesMut, - extension: NegotiatedExtension, + extension: Option, close_state: CloseState, } @@ -102,7 +101,7 @@ where pub(crate) fn from_parts( framed: FramedIo, control_buffer: BytesMut, - extension: NegotiatedExtension, + extension: Option, close_state: CloseState, ) -> WebSocket { WebSocket { @@ -125,7 +124,7 @@ where pub fn from_upgraded( config: WebSocketConfig, stream: S, - extension: NegotiatedExtension, + extension: Option, read_buffer: BytesMut, role: Role, ) -> WebSocket { @@ -559,8 +558,8 @@ mod tests { use crate::protocol::{ControlCode, DataCode, HeaderFlags, OpCode}; use crate::ws::extension_encode; use crate::{ - CloseCause, CloseCode, CloseReason, Error, Message, NegotiatedExtension, NoExt, Role, - WebSocket, WebSocketConfig, WebSocketStream, + CloseCause, CloseCode, CloseReason, Error, Message, NoExt, Role, WebSocket, + WebSocketConfig, WebSocketStream, }; use bytes::{Bytes, BytesMut}; use ratchet_ext::Extension; @@ -613,20 +612,10 @@ mod tests { let (server, client) = duplex(512); let config = WebSocketConfig::default(); - let server = WebSocket::from_upgraded( - config, - server, - NegotiatedExtension::from(NoExt), - BytesMut::new(), - Role::Server, - ); - let client = WebSocket::from_upgraded( - config, - client, - NegotiatedExtension::from(NoExt), - BytesMut::new(), - Role::Client, - ); + let server = + WebSocket::from_upgraded(config, server, Some(NoExt), BytesMut::new(), Role::Server); + let client = + WebSocket::from_upgraded(config, client, Some(NoExt), BytesMut::new(), Role::Client); (client, server) } diff --git a/ratchet_ext/src/lib.rs b/ratchet_ext/src/lib.rs index 8f6b696..13509d3 100644 --- a/ratchet_ext/src/lib.rs +++ b/ratchet_ext/src/lib.rs @@ -267,3 +267,82 @@ pub trait ReunitableExtension: SplittableExtension { /// Reunite this encoder and decoder back into a single extension. fn reunite(encoder: Self::SplitEncoder, decoder: Self::SplitDecoder) -> Self; } + +impl Extension for Option +where + E: Extension, +{ + fn bits(&self) -> RsvBits { + match self { + Some(ext) => ext.bits(), + None => RsvBits { + rsv1: false, + rsv2: false, + rsv3: false, + }, + } + } +} + +impl ExtensionEncoder for Option +where + E: ExtensionEncoder, +{ + type Error = E::Error; + + fn encode( + &mut self, + payload: &mut BytesMut, + header: &mut FrameHeader, + ) -> Result<(), Self::Error> { + match self { + Some(e) => e.encode(payload, header), + None => Ok(()), + } + } +} + +impl ExtensionDecoder for Option +where + E: ExtensionDecoder, +{ + type Error = E::Error; + + fn decode( + &mut self, + payload: &mut BytesMut, + header: &mut FrameHeader, + ) -> Result<(), Self::Error> { + match self { + Some(e) => e.decode(payload, header), + None => Ok(()), + } + } +} + +impl ReunitableExtension for Option +where + E: ReunitableExtension, +{ + fn reunite(encoder: Self::SplitEncoder, decoder: Self::SplitDecoder) -> Self { + Option::zip(encoder, decoder).map(|(encoder, decoder)| E::reunite(encoder, decoder)) + } +} + +impl SplittableExtension for Option +where + E: SplittableExtension, +{ + type SplitEncoder = Option; + type SplitDecoder = Option; + + fn split(self) -> (Self::SplitEncoder, Self::SplitDecoder) { + match self { + Some(ext) => { + let (encoder, decoder) = ext.split(); + (Some(encoder), (Some(decoder))) + } + None => (None, None), + } + } +} diff --git a/ratchet_fixture/src/lib.rs b/ratchet_fixture/src/lib.rs index 6ee3e41..efbbbb4 100644 --- a/ratchet_fixture/src/lib.rs +++ b/ratchet_fixture/src/lib.rs @@ -14,25 +14,28 @@ pub mod duplex { use bytes::BytesMut; - use ratchet::{Extension, NegotiatedExtension, Role, WebSocketConfig}; + use ratchet::{Extension, Role, WebSocketConfig}; use tokio::io::DuplexStream; pub type MockWebSocket = ratchet::WebSocket; - pub fn make_websocket(stream: DuplexStream, role: Role, ext: E) -> MockWebSocket + pub fn make_websocket(stream: DuplexStream, role: Role, ext: Option) -> MockWebSocket where E: Extension, { ratchet::WebSocket::from_upgraded( WebSocketConfig::default(), stream, - NegotiatedExtension::from(Some(ext)), + ext, BytesMut::new(), role, ) } - pub fn websocket_pair(left_ext: L, right_ext: R) -> (MockWebSocket, MockWebSocket) + pub fn websocket_pair( + left_ext: Option, + right_ext: Option, + ) -> (MockWebSocket, MockWebSocket) where L: Extension, R: Extension, @@ -44,7 +47,7 @@ pub mod duplex { ) } - pub async fn websocket_for(role: Role, ext: E) -> (MockWebSocket, DuplexStream) + pub async fn websocket_for(role: Role, ext: Option) -> (MockWebSocket, DuplexStream) where E: Extension, { @@ -53,7 +56,7 @@ pub mod duplex { ratchet::WebSocket::from_upgraded( WebSocketConfig::default(), tx, - NegotiatedExtension::from(Some(ext)), + ext, BytesMut::new(), role, ), From dca92d5ac4d0162294d293325f851ea17b4065d3 Mon Sep 17 00:00:00 2001 From: SirCipher Date: Tue, 24 Sep 2024 16:48:33 +0100 Subject: [PATCH 02/13] Initial restructure to expose request parsing --- ratchet_core/src/errors.rs | 5 +- ratchet_core/src/ext.rs | 5 +- ratchet_core/src/handshake/client/mod.rs | 86 +++++----- ratchet_core/src/handshake/client/tests.rs | 27 +-- ratchet_core/src/handshake/mod.rs | 90 ++++++---- ratchet_core/src/handshake/server/encoding.rs | 79 ++++----- ratchet_core/src/handshake/server/tests.rs | 13 +- ratchet_core/src/handshake/subprotocols.rs | 22 +-- ratchet_core/src/handshake/tests.rs | 82 ++++------ ratchet_deflate/src/handshake.rs | 25 ++- ratchet_deflate/src/lib.rs | 11 +- ratchet_deflate/src/tests.rs | 154 +++++++----------- ratchet_ext/src/lib.rs | 19 ++- 13 files changed, 301 insertions(+), 317 deletions(-) diff --git a/ratchet_core/src/errors.rs b/ratchet_core/src/errors.rs index 393055f..d0c015f 100644 --- a/ratchet_core/src/errors.rs +++ b/ratchet_core/src/errors.rs @@ -16,7 +16,6 @@ use crate::protocol::{CloseCodeParseErr, OpCodeParseErr}; use http::header::{HeaderName, InvalidHeaderValue}; use http::status::InvalidStatusCode; use http::uri::InvalidUri; -use http::StatusCode; use std::any::Any; use std::error::Error as StdError; use std::fmt::{Display, Formatter}; @@ -156,8 +155,8 @@ pub enum HttpError { #[error("Redirected: `{0}`")] Redirected(String), /// The peer returned with a status code other than 101. - #[error("Status code: `{0}`")] - Status(StatusCode), + #[error("Status code: `{0:?}`")] + Status(Option), /// An invalid HTTP version was received in a request. #[error("Invalid HTTP version: `{0:?}`")] HttpVersion(Option), diff --git a/ratchet_core/src/ext.rs b/ratchet_core/src/ext.rs index dc60b8e..c416edc 100644 --- a/ratchet_core/src/ext.rs +++ b/ratchet_core/src/ext.rs @@ -15,7 +15,6 @@ use crate::Error; use bytes::BytesMut; use http::{HeaderMap, HeaderValue}; -use httparse::Header; use ratchet_ext::{ Extension, ExtensionDecoder, ExtensionEncoder, ExtensionProvider, FrameHeader, ReunitableExtension, RsvBits, SplittableExtension, @@ -61,14 +60,14 @@ impl ExtensionProvider for NoExtProvider { fn negotiate_client( &self, - _headers: &[Header], + _headers: &HeaderMap, ) -> Result, Self::Error> { Ok(None) } fn negotiate_server( &self, - _headers: &[Header], + _headers: &HeaderMap, ) -> Result, Self::Error> { Ok(None) } diff --git a/ratchet_core/src/handshake/client/mod.rs b/ratchet_core/src/handshake/client/mod.rs index 0a7039e..65299d7 100644 --- a/ratchet_core/src/handshake/client/mod.rs +++ b/ratchet_core/src/handshake/client/mod.rs @@ -17,26 +17,26 @@ mod tests; mod encoding; -use base64::engine::general_purpose::STANDARD; -use base64::Engine; -use bytes::BytesMut; -use http::{header, Request, StatusCode}; -use httparse::{Response, Status}; -use log::{error, trace}; -use sha1::{Digest, Sha1}; -use std::convert::TryFrom; - use crate::errors::{Error, ErrorKind, HttpError}; use crate::handshake::client::encoding::{build_request, encode_request}; use crate::handshake::io::BufferedIo; use crate::handshake::{ negotiate_response, validate_header, validate_header_value, ParseResult, ProtocolRegistry, - StreamingParser, ACCEPT_KEY, BAD_STATUS_CODE, UPGRADE_STR, WEBSOCKET_STR, + StreamingParser, TryMap, ACCEPT_KEY, BAD_STATUS_CODE, UPGRADE_STR, WEBSOCKET_STR, }; use crate::{ NoExt, NoExtProvider, Role, TryIntoRequest, WebSocket, WebSocketConfig, WebSocketStream, }; +use base64::engine::general_purpose::STANDARD; +use base64::Engine; +use bytes::BytesMut; +use http::header::LOCATION; +use http::{header, Request, StatusCode, Version}; +use httparse::{Response, Status}; +use log::{error, trace}; use ratchet_ext::ExtensionProvider; +use sha1::{Digest, Sha1}; +use std::convert::TryFrom; use tokio_util::codec::Decoder; type Nonce = [u8; 24]; @@ -188,11 +188,11 @@ where } = self; let mut headers = [httparse::EMPTY_HEADER; 32]; - let mut response = Response::new(&mut headers); + let response = Response::new(&mut headers); - match try_parse_response(buf, &mut response, nonce, extension, subprotocols)? { + match try_parse_response(buf, response, nonce, extension, subprotocols)? { ParseResult::Complete(result, count) => Ok(Some((result, count))), - ParseResult::Partial => { + ParseResult::Partial(response) => { check_partial_response(&response)?; Ok(None) } @@ -301,35 +301,39 @@ fn check_partial_response(response: &Response) -> Result<(), Error> { Ok(()) } Some(code) => match StatusCode::try_from(code) { - Ok(code) => Err(Error::with_cause(ErrorKind::Http, HttpError::Status(code))), + Ok(code) => Err(Error::with_cause( + ErrorKind::Http, + HttpError::Status(Some(code.as_u16())), + )), Err(_) => Err(Error::with_cause(ErrorKind::Http, BAD_STATUS_CODE)), }, None => Ok(()), } } -fn try_parse_response<'l, E>( - buffer: &'l [u8], - response: &mut Response<'_, 'l>, +fn try_parse_response<'h, 'b, E>( + buffer: &'h [u8], + mut response: Response<'h, 'b>, expected_nonce: &Nonce, extension: E, subprotocols: &mut ProtocolRegistry, -) -> Result>, Error> +) -> Result, HandshakeResult>, Error> where + 'h: 'b, E: ExtensionProvider, { match response.parse(buffer) { Ok(Status::Complete(count)) => { - parse_response(response, expected_nonce, extension, subprotocols) + parse_response(response.try_map()?, expected_nonce, extension, subprotocols) .map(|r| ParseResult::Complete(r, count)) } - Ok(Status::Partial) => Ok(ParseResult::Partial), + Ok(Status::Partial) => Ok(ParseResult::Partial(response)), Err(e) => Err(e.into()), } } fn parse_response( - response: &Response, + response: http::Response<()>, expected_nonce: &Nonce, extension: E, subprotocols: &mut ProtocolRegistry, @@ -337,48 +341,48 @@ fn parse_response( where E: ExtensionProvider, { - match response.version { + if response.version() < Version::HTTP_11 { // rfc6455 § 4.2.1.1: must be HTTP/1.1 or higher - Some(1) => {} - v => { - return Err(Error::with_cause( - ErrorKind::Http, - HttpError::HttpVersion(v), - )) - } + // this will implicitly be 0 as httparse only parses HTTP/1.x and 1.1 is 0. + return Err(Error::with_cause( + ErrorKind::Http, + HttpError::HttpVersion(Some(0)), + )); } - let raw_status_code = response.code.ok_or_else(|| Error::new(ErrorKind::Http))?; - let status_code = StatusCode::from_u16(raw_status_code)?; + let status_code = response.status(); match status_code { c if c == StatusCode::SWITCHING_PROTOCOLS => {} c if c.is_redirection() => { - return match response.headers.iter().find(|h| h.name == header::LOCATION) { - Some(header) => { + return match response.headers().get(LOCATION) { + Some(value) => { // the value _should_ be valid UTF-8 - let location = String::from_utf8(header.value.to_vec()) + let location = String::from_utf8(value.as_bytes().to_vec()) .map_err(|_| Error::new(ErrorKind::Http))?; Err(Error::with_cause( ErrorKind::Http, HttpError::Redirected(location), )) } - None => Err(Error::with_cause(ErrorKind::Http, HttpError::Status(c))), + None => Err(Error::with_cause( + ErrorKind::Http, + HttpError::Status(Some(c.as_u16())), + )), }; } status_code => { return Err(Error::with_cause( ErrorKind::Http, - HttpError::Status(status_code), + HttpError::Status(Some(status_code.as_u16())), )) } } - validate_header_value(response.headers, header::UPGRADE, WEBSOCKET_STR)?; - validate_header_value(response.headers, header::CONNECTION, UPGRADE_STR)?; + validate_header_value(response.headers(), header::UPGRADE, WEBSOCKET_STR)?; + validate_header_value(response.headers(), header::CONNECTION, UPGRADE_STR)?; validate_header( - response.headers, + response.headers(), header::SEC_WEBSOCKET_ACCEPT, |_name, actual| { let mut digest = Sha1::new(); @@ -395,9 +399,9 @@ where )?; Ok(HandshakeResult { - subprotocol: negotiate_response(subprotocols, response)?, + subprotocol: negotiate_response(subprotocols, response.headers())?, extension: extension - .negotiate_client(response.headers) + .negotiate_client(response.headers()) .map_err(|e| Error::with_cause(ErrorKind::Extension, e))?, }) } diff --git a/ratchet_core/src/handshake/client/tests.rs b/ratchet_core/src/handshake/client/tests.rs index 565f7fb..50b32fa 100644 --- a/ratchet_core/src/handshake/client/tests.rs +++ b/ratchet_core/src/handshake/client/tests.rs @@ -233,7 +233,11 @@ async fn bad_status_code() { .body(()) .unwrap(); - expect_server_error(response, HttpError::Status(StatusCode::IM_A_TEAPOT)).await; + expect_server_error( + response, + HttpError::Status(Some(StatusCode::IM_A_TEAPOT.as_u16())), + ) + .await; } #[tokio::test] @@ -479,11 +483,11 @@ impl From for Error { struct MockExtensionProxy(&'static [(HeaderName, &'static str)], R) where - R: for<'h> Fn(&'h [Header]) -> Result, ExtHandshakeErr>; + R: for<'h> Fn(&'h HeaderMap) -> Result, ExtHandshakeErr>; impl ExtensionProvider for MockExtensionProxy where - R: for<'h> Fn(&'h [Header]) -> Result, ExtHandshakeErr>, + R: for<'h> Fn(&'h HeaderMap) -> Result, ExtHandshakeErr>, { type Extension = MockExtension; type Error = ExtHandshakeErr; @@ -494,13 +498,16 @@ where } } - fn negotiate_client(&self, headers: &[Header]) -> Result, Self::Error> { - (self.1)(headers) + fn negotiate_client( + &self, + headers: &HeaderMap, + ) -> Result, Self::Error> { + self.1(headers) } fn negotiate_server( &self, - _headers: &[Header], + _headers: &HeaderMap, ) -> Result, ExtHandshakeErr> { panic!("Unexpected server-side extension negotiation") } @@ -629,13 +636,13 @@ async fn negotiates_extension() { const HEADERS: &[(HeaderName, &str)] = &[(header::SEC_WEBSOCKET_EXTENSIONS, EXT)]; let extension_proxy = MockExtensionProxy(HEADERS, |headers| { - let ext = headers.iter().find(|h| { - h.name + let ext = headers.iter().find(|(name, _value)| { + name.as_str() .eq_ignore_ascii_case(header::SEC_WEBSOCKET_EXTENSIONS.as_str()) }); match ext { - Some(header) => { - let value = String::from_utf8(header.value.to_vec()) + Some((_name, value)) => { + let value = String::from_utf8(value.as_bytes().to_vec()) .expect("Server returned invalid UTF-8"); if value == EXT { Ok(Some(MockExtension(true))) diff --git a/ratchet_core/src/handshake/mod.rs b/ratchet_core/src/handshake/mod.rs index 541782c..f506117 100644 --- a/ratchet_core/src/handshake/mod.rs +++ b/ratchet_core/src/handshake/mod.rs @@ -24,10 +24,9 @@ use crate::errors::Error; use crate::errors::{ErrorKind, HttpError}; use crate::handshake::io::BufferedIo; use crate::{InvalidHeader, Request}; -use bytes::Bytes; use http::header::HeaderName; -use http::Uri; -use http::{HeaderMap, HeaderValue}; +use http::{HeaderMap, HeaderValue, Method, Version}; +use http::{Response, StatusCode, Uri}; use std::str::FromStr; use tokio::io::AsyncRead; use tokio_util::codec::Decoder; @@ -76,9 +75,9 @@ where } } -pub enum ParseResult { +pub enum ParseResult { Complete(O, usize), - Partial, + Partial(R), } /// A trait for creating a request from a type. @@ -136,12 +135,12 @@ impl TryIntoRequest for Request { } fn validate_header_value( - headers: &[httparse::Header], + headers: &HeaderMap, name: HeaderName, expected: &str, ) -> Result<(), Error> { validate_header(headers, name, |name, actual| { - if actual.eq_ignore_ascii_case(expected.as_bytes()) { + if actual.as_bytes().eq_ignore_ascii_case(expected.as_bytes()) { Ok(()) } else { Err(Error::with_cause( @@ -152,15 +151,12 @@ fn validate_header_value( }) } -fn validate_header(headers: &[httparse::Header], name: HeaderName, f: F) -> Result<(), Error> +fn validate_header(headers: &HeaderMap, name: HeaderName, f: F) -> Result<(), Error> where - F: Fn(HeaderName, &[u8]) -> Result<(), Error>, + F: Fn(HeaderName, &HeaderValue) -> Result<(), Error>, { - match headers - .iter() - .find(|h| h.name.eq_ignore_ascii_case(name.as_str())) - { - Some(header) => f(name, header.value), + match headers.get(&name) { + Some(value) => f(name, value), None => Err(Error::with_cause( ErrorKind::Http, HttpError::MissingHeader(name), @@ -168,13 +164,10 @@ where } } -fn validate_header_any( - headers: &[httparse::Header], - name: HeaderName, - expected: &str, -) -> Result<(), Error> { +fn validate_header_any(headers: &HeaderMap, name: HeaderName, expected: &str) -> Result<(), Error> { validate_header(headers, name, |name, actual| { if actual + .as_bytes() .split(|c| c == &b' ' || c == &b',') .any(|v| v.eq_ignore_ascii_case(expected.as_bytes())) { @@ -188,19 +181,6 @@ fn validate_header_any( }) } -fn get_header(headers: &[httparse::Header], name: HeaderName) -> Result { - match headers - .iter() - .find(|h| h.name.eq_ignore_ascii_case(name.as_str())) - { - Some(header) => Ok(Bytes::from(header.value.to_vec())), - None => Err(Error::with_cause( - ErrorKind::Http, - HttpError::MissingHeader(name), - )), - } -} - /// Local replacement for TryInto that can be implemented for httparse::Header and httparse::Request pub trait TryMap { /// Error type returned if the mapping fails @@ -245,11 +225,57 @@ impl<'l, 'h, 'buf: 'h> TryMap for &'l httparse::Request<'h, 'buf> { ))) } }; + let method = match self.method { + Some(m) => { + Method::from_str(m).map_err(|_| HttpError::HttpMethod(Some(m.to_string())))? + } + None => return Err(HttpError::HttpMethod(None)), + }; + let version = match self.version { + Some(v) => match v { + 0 => Version::HTTP_10, + 1 => Version::HTTP_11, + n => return Err(HttpError::HttpVersion(Some(n))), + }, + None => return Err(HttpError::HttpVersion(None)), + }; let headers = &self.headers; *request.headers_mut() = headers.try_map()?; *request.uri_mut() = path; + *request.version_mut() = version; + *request.method_mut() = method; Ok(request) } } + +impl<'l, 'h, 'buf: 'h> TryMap> for &'l httparse::Response<'h, 'buf> { + type Error = HttpError; + + fn try_map(self) -> Result, Self::Error> { + let mut response = Response::new(()); + let code = match self.code { + Some(c) => match StatusCode::from_u16(c) { + Ok(status) => status, + Err(_) => return Err(HttpError::Status(Some(c))), + }, + None => return Err(HttpError::Status(None)), + }; + let version = match self.version { + Some(v) => match v { + 0 => Version::HTTP_10, + 1 => Version::HTTP_11, + n => return Err(HttpError::HttpVersion(Some(n))), + }, + None => return Err(HttpError::HttpVersion(None)), + }; + let headers = &self.headers; + + *response.headers_mut() = headers.try_map()?; + *response.status_mut() = code; + *response.version_mut() = version; + + Ok(response) + } +} diff --git a/ratchet_core/src/handshake/server/encoding.rs b/ratchet_core/src/handshake/server/encoding.rs index 5953eca..c368d6d 100644 --- a/ratchet_core/src/handshake/server/encoding.rs +++ b/ratchet_core/src/handshake/server/encoding.rs @@ -14,14 +14,15 @@ use crate::handshake::io::BufferedIo; use crate::handshake::server::HandshakeResult; +use crate::handshake::{negotiate_request, TryMap}; use crate::handshake::{ - get_header, validate_header, validate_header_any, validate_header_value, ParseResult, - METHOD_GET, UPGRADE_STR, WEBSOCKET_STR, WEBSOCKET_VERSION_STR, + validate_header, validate_header_any, validate_header_value, ParseResult, METHOD_GET, + UPGRADE_STR, WEBSOCKET_STR, WEBSOCKET_VERSION_STR, }; -use crate::handshake::{negotiate_request, TryMap}; use crate::{Error, ErrorKind, HttpError, ProtocolRegistry}; -use bytes::{BufMut, BytesMut}; -use http::{HeaderMap, StatusCode}; +use bytes::{BufMut, Bytes, BytesMut}; +use http::header::SEC_WEBSOCKET_KEY; +use http::{HeaderMap, Method, StatusCode, Version}; use httparse::Status; use ratchet_ext::ExtensionProvider; use tokio::io::AsyncWrite; @@ -29,11 +30,12 @@ use tokio_util::codec::Decoder; /// The maximum number of headers that will be parsed. const MAX_HEADERS: usize = 32; -const HTTP_VERSION: &[u8] = b"HTTP/1.1 "; +const HTTP_VERSION_STR: &[u8] = b"HTTP/1.1 "; const STATUS_TERMINATOR_LEN: usize = 2; const TERMINATOR_NO_HEADERS: &[u8] = b"\r\n\r\n"; const TERMINATOR_WITH_HEADER: &[u8] = b"\r\n"; const HTTP_VERSION_INT: u8 = 1; +const HTTP_VERSION: Version = Version::HTTP_11; pub struct RequestParser { pub subprotocols: ProtocolRegistry, @@ -53,11 +55,11 @@ where extension, } = self; let mut headers = [httparse::EMPTY_HEADER; MAX_HEADERS]; - let mut request = httparse::Request::new(&mut headers); + let request = httparse::Request::new(&mut headers); - match try_parse_request(buf, &mut request, extension, subprotocols)? { + match try_parse_request(buf, request, extension, subprotocols)? { ParseResult::Complete(result, count) => Ok(Some((result, count))), - ParseResult::Partial => { + ParseResult::Partial(request) => { check_partial_request(&request)?; Ok(None) } @@ -77,7 +79,7 @@ where { buf.clear(); - let version_count = HTTP_VERSION.len(); + let version_count = HTTP_VERSION_STR.len(); let status_bytes = status.as_str().as_bytes(); let reason_len = status .canonical_reason() @@ -95,7 +97,7 @@ where buf.reserve(version_count + status_bytes.len() + reason_len + headers_len + terminator_len); - buf.put_slice(HTTP_VERSION); + buf.put_slice(HTTP_VERSION_STR); buf.put_slice(status.as_str().as_bytes()); match status.canonical_reason() { @@ -123,20 +125,21 @@ where buffered.write().await } -pub fn try_parse_request<'l, E>( - buffer: &'l [u8], - request: &mut httparse::Request<'_, 'l>, +pub fn try_parse_request<'h, 'b, E>( + buffer: &'b [u8], + mut request: httparse::Request<'h, 'b>, extension: E, subprotocols: &mut ProtocolRegistry, -) -> Result>, Error> +) -> Result, HandshakeResult>, Error> where E: ExtensionProvider, { match request.parse(buffer) { Ok(Status::Complete(count)) => { + let request = request.try_map()?; parse_request(request, extension, subprotocols).map(|r| ParseResult::Complete(r, count)) } - Ok(Status::Partial) => Ok(ParseResult::Partial), + Ok(Status::Partial) => Ok(ParseResult::Partial(request)), Err(e) => Err(e.into()), } } @@ -167,34 +170,29 @@ pub fn check_partial_request(request: &httparse::Request) -> Result<(), Error> { } pub fn parse_request( - request: &mut httparse::Request<'_, '_>, + request: http::Request<()>, extension: E, subprotocols: &mut ProtocolRegistry, ) -> Result, Error> where E: ExtensionProvider, { - match request.version { - Some(HTTP_VERSION_INT) => {} - v => { - return Err(Error::with_cause( - ErrorKind::Http, - HttpError::HttpVersion(v), - )) - } + if request.version() < HTTP_VERSION { + // this will implicitly be 0 as httparse only parses HTTP/1.x and 1.1 is 0. + return Err(Error::with_cause( + ErrorKind::Http, + HttpError::HttpVersion(Some(0)), + )); } - match request.method { - Some(m) if m.eq_ignore_ascii_case(METHOD_GET) => {} - m => { - return Err(Error::with_cause( - ErrorKind::Http, - HttpError::HttpMethod(m.map(ToString::to_string)), - )); - } + if request.method() != Method::GET { + return Err(Error::with_cause( + ErrorKind::Http, + HttpError::HttpMethod(Some(request.method().to_string())), + )); } - let headers = &request.headers; + let headers = request.headers(); validate_header_any(headers, http::header::CONNECTION, UPGRADE_STR)?; validate_header_value(headers, http::header::UPGRADE, WEBSOCKET_STR)?; validate_header_value( @@ -205,16 +203,21 @@ where validate_header(headers, http::header::HOST, |_, _| Ok(()))?; - let key = get_header(headers, http::header::SEC_WEBSOCKET_KEY)?; - let subprotocol = negotiate_request(subprotocols, request)?; + let key = headers + .get(SEC_WEBSOCKET_KEY) + .map(|v| Bytes::from(v.as_bytes().to_vec())) + .ok_or_else(|| { + Error::with_cause(ErrorKind::Http, HttpError::MissingHeader(SEC_WEBSOCKET_KEY)) + })?; + let subprotocol = negotiate_request(subprotocols, headers)?; let (extension, extension_header) = extension - .negotiate_server(request.headers) + .negotiate_server(headers) .map(Option::unzip) .map_err(|e| Error::with_cause(ErrorKind::Extension, e))?; Ok(HandshakeResult { key, - request: request.try_map()?, + request, extension, subprotocol, extension_header, diff --git a/ratchet_core/src/handshake/server/tests.rs b/ratchet_core/src/handshake/server/tests.rs index c22e8ca..b84c578 100644 --- a/ratchet_core/src/handshake/server/tests.rs +++ b/ratchet_core/src/handshake/server/tests.rs @@ -19,8 +19,7 @@ use crate::{ }; use bytes::BytesMut; use http::header::HeaderName; -use http::{HeaderMap, HeaderValue, Request, Response, Version}; -use httparse::Header; +use http::{HeaderMap, HeaderValue, Method, Request, Response, Version}; use ratchet_ext::{ Extension, ExtensionDecoder, ExtensionEncoder, ExtensionProvider, FrameHeader, ReunitableExtension, RsvBits, SplittableExtension, @@ -172,7 +171,7 @@ async fn bad_request() { let request = Request::builder() .uri("/test") - .method("post") + .method(Method::POST) .header(http::header::CONNECTION, UPGRADE_STR) .header(http::header::UPGRADE, WEBSOCKET_STR) .header(http::header::SEC_WEBSOCKET_VERSION, WEBSOCKET_VERSION_STR) @@ -185,7 +184,7 @@ async fn bad_request() { Ok(o) => panic!("Expected a test failure. Got: {:?}", o), Err(e) => match e.downcast_ref::() { Some(err) => { - assert_eq!(err, &HttpError::HttpMethod(Some("post".to_string()))); + assert_eq!(err, &HttpError::HttpMethod(Some("POST".to_string()))); } None => { panic!("Expected a HTTP error. Got: {:?}", e) @@ -218,14 +217,14 @@ impl ExtensionProvider for BadExtProvider { fn negotiate_client( &self, - _headers: &[Header], + _headers: &HeaderMap, ) -> Result, Self::Error> { - panic!("Unexpected client negotitation request") + panic!("Unexpected client negotiation request") } fn negotiate_server( &self, - _headers: &[Header], + _headers: &HeaderMap, ) -> Result, Self::Error> { Err(ExtErr) } diff --git a/ratchet_core/src/handshake/subprotocols.rs b/ratchet_core/src/handshake/subprotocols.rs index f8b7371..9408be9 100644 --- a/ratchet_core/src/handshake/subprotocols.rs +++ b/ratchet_core/src/handshake/subprotocols.rs @@ -16,7 +16,6 @@ use crate::{Error, ErrorKind, HttpError, ProtocolError}; use fnv::FnvHashSet; use http::header::SEC_WEBSOCKET_PROTOCOL; use http::{HeaderMap, HeaderValue}; -use httparse::Header; use std::borrow::Cow; /// A subprotocol registry that is used for negotiating a possible subprotocol to use for a @@ -65,11 +64,10 @@ fn negotiate<'h, I>( bias: Bias, ) -> Result, ProtocolError> where - I: Iterator>, + I: Iterator, { for header in headers { - let value = - String::from_utf8(header.value.to_vec()).map_err(|_| ProtocolError::Encoding)?; + let value = std::str::from_utf8(header.as_bytes()).map_err(|_| ProtocolError::Encoding)?; let protocols = value .split(',') .map(|s| s.trim().into()) @@ -103,25 +101,17 @@ where pub fn negotiate_response( registry: &ProtocolRegistry, - response: &httparse::Response, + header_map: &HeaderMap, ) -> Result, ProtocolError> { - let it = response - .headers - .iter() - .filter(|h| h.name.eq_ignore_ascii_case(SEC_WEBSOCKET_PROTOCOL.as_str())); - + let it = header_map.get_all(SEC_WEBSOCKET_PROTOCOL).into_iter(); negotiate(registry, it, Bias::Client) } pub fn negotiate_request( registry: &ProtocolRegistry, - request: &httparse::Request, + header_map: &HeaderMap, ) -> Result, ProtocolError> { - let it = request - .headers - .iter() - .filter(|h| h.name.eq_ignore_ascii_case(SEC_WEBSOCKET_PROTOCOL.as_str())); - + let it = header_map.get_all(SEC_WEBSOCKET_PROTOCOL).into_iter(); negotiate(registry, it, Bias::Server) } diff --git a/ratchet_core/src/handshake/tests.rs b/ratchet_core/src/handshake/tests.rs index 121e591..243a872 100644 --- a/ratchet_core/src/handshake/tests.rs +++ b/ratchet_core/src/handshake/tests.rs @@ -15,92 +15,72 @@ use crate::handshake::{negotiate_request, ProtocolRegistry}; use crate::ProtocolError; use http::header::SEC_WEBSOCKET_PROTOCOL; +use http::{HeaderMap, HeaderValue}; #[test] fn selects_protocol_ok() { - let mut headers = [httparse::Header { - name: SEC_WEBSOCKET_PROTOCOL.as_str(), - value: b"warp, warps", - }]; - let request = httparse::Request::new(&mut headers); - + let headers = HeaderMap::from_iter([( + SEC_WEBSOCKET_PROTOCOL, + HeaderValue::from_static("warp, warps"), + )]); let registry = ProtocolRegistry::new(vec!["warps", "warp"]).unwrap(); assert_eq!( - negotiate_request(®istry, &request), + negotiate_request(®istry, &headers), Ok(Some("warp".to_string())) ); } #[test] fn multiple_headers() { - let mut headers = [ - httparse::Header { - name: SEC_WEBSOCKET_PROTOCOL.as_str(), - value: b"warp", - }, - httparse::Header { - name: SEC_WEBSOCKET_PROTOCOL.as_str(), - value: b"warps", - }, - ]; - let request = httparse::Request::new(&mut headers); - + let headers = HeaderMap::from_iter([ + (SEC_WEBSOCKET_PROTOCOL, HeaderValue::from_static("warp")), + (SEC_WEBSOCKET_PROTOCOL, HeaderValue::from_static("warps")), + ]); let registry = ProtocolRegistry::new(vec!["warps", "warp"]).unwrap(); + assert_eq!( - negotiate_request(®istry, &request), + negotiate_request(®istry, &headers), Ok(Some("warp".to_string())) ); } #[test] fn mixed_headers() { - let mut headers = [ - httparse::Header { - name: SEC_WEBSOCKET_PROTOCOL.as_str(), - value: b"warp1.0", - }, - httparse::Header { - name: SEC_WEBSOCKET_PROTOCOL.as_str(), - value: b"warps2.0,warp3.0", - }, - httparse::Header { - name: SEC_WEBSOCKET_PROTOCOL.as_str(), - value: b"warps4.0", - }, - ]; - let request = httparse::Request::new(&mut headers); - + let headers = HeaderMap::from_iter([ + (SEC_WEBSOCKET_PROTOCOL, HeaderValue::from_static("warp1.0")), + ( + SEC_WEBSOCKET_PROTOCOL, + HeaderValue::from_static("warps2.0,warp3.0"), + ), + (SEC_WEBSOCKET_PROTOCOL, HeaderValue::from_static("warps4.0")), + ]); let registry = ProtocolRegistry::new(vec!["warps", "warp", "warps2.0"]).unwrap(); + assert_eq!( - negotiate_request(®istry, &request), + negotiate_request(®istry, &headers), Ok(Some("warps2.0".to_string())) ); } #[test] fn malformatted() { - let mut headers = [httparse::Header { - name: SEC_WEBSOCKET_PROTOCOL.as_str(), - value: &[255, 255, 255, 255], - }]; - let request = httparse::Request::new(&mut headers); - + let headers = HeaderMap::from_iter([(SEC_WEBSOCKET_PROTOCOL, unsafe { + HeaderValue::from_maybe_shared_unchecked([255, 255, 255, 255]) + })]); let registry = ProtocolRegistry::new(vec!["warps", "warp", "warps2.0"]).unwrap(); + assert_eq!( - negotiate_request(®istry, &request), + negotiate_request(®istry, &headers), Err(ProtocolError::Encoding) ); } #[test] fn no_match() { - let mut headers = [httparse::Header { - name: SEC_WEBSOCKET_PROTOCOL.as_str(), - value: b"a,b,c", - }]; - let request = httparse::Request::new(&mut headers); - + let headers = + HeaderMap::from_iter([(SEC_WEBSOCKET_PROTOCOL, HeaderValue::from_static("a,b,c"))]); let registry = ProtocolRegistry::new(vec!["d"]).unwrap(); - assert_eq!(negotiate_request(®istry, &request), Ok(None)); + + assert_eq!(negotiate_request(®istry, &headers), Ok(None)); } diff --git a/ratchet_deflate/src/handshake.rs b/ratchet_deflate/src/handshake.rs index b912149..f0195b2 100644 --- a/ratchet_deflate/src/handshake.rs +++ b/ratchet_deflate/src/handshake.rs @@ -18,7 +18,6 @@ use bytes::BytesMut; use flate2::Compression; use http::header::SEC_WEBSOCKET_EXTENSIONS; use http::{HeaderMap, HeaderValue}; -use ratchet_ext::Header; use std::fmt::Write; use std::str::Utf8Error; @@ -152,7 +151,7 @@ pub fn apply_headers(header_map: &mut HeaderMap, config: &DeflateConfig) { } pub fn negotiate_client( - headers: &[Header], + headers: &HeaderMap, config: &DeflateConfig, ) -> Result, DeflateExtensionError> { match on_response(headers, config) { @@ -163,7 +162,7 @@ pub fn negotiate_client( } pub fn negotiate_server( - headers: &[Header], + headers: &HeaderMap, config: &DeflateConfig, ) -> Result, DeflateExtensionError> { match on_request(headers, config) { @@ -177,17 +176,17 @@ pub fn negotiate_server( } pub fn on_request( - headers: &[Header], + headers: &HeaderMap, config: &DeflateConfig, ) -> Result<(InitialisedDeflateConfig, HeaderValue), NegotiationErr> { - let header_iter = headers.iter().filter(|h| { - h.name + let header_iter = headers.iter().filter(|(name, _value)| { + name.as_str() .eq_ignore_ascii_case(SEC_WEBSOCKET_EXTENSIONS.as_str()) }); - for header in header_iter { + for (_, value) in header_iter { let header_value = - std::str::from_utf8(header.value).map_err(DeflateExtensionError::from)?; + std::str::from_utf8(value.as_bytes()).map_err(DeflateExtensionError::from)?; for part in header_value.split(',') { match validate_request_header(part, config) { @@ -319,7 +318,7 @@ impl From for NegotiationErr { } pub fn on_response( - headers: &[Header], + headers: &HeaderMap, config: &DeflateConfig, ) -> Result { let mut seen_extension_name = false; @@ -335,13 +334,13 @@ pub fn on_response( let mut client_max_window_bits = config.client_max_window_bits; let accept_no_context_takeover = config.accept_no_context_takeover; - let header_iter = headers.iter().filter(|h| { - h.name + let header_iter = headers.iter().filter(|(name, _value)| { + name.as_str() .eq_ignore_ascii_case(SEC_WEBSOCKET_EXTENSIONS.as_str()) }); - for header in header_iter { - let header_value = std::str::from_utf8(header.value)?; + for (_name, value) in header_iter { + let header_value = std::str::from_utf8(value.as_bytes())?; let mut param_iter = header_value.split(';'); if let Some(param) = param_iter.next() { diff --git a/ratchet_deflate/src/lib.rs b/ratchet_deflate/src/lib.rs index 53dafd1..b8a6b1b 100644 --- a/ratchet_deflate/src/lib.rs +++ b/ratchet_deflate/src/lib.rs @@ -29,8 +29,8 @@ use thiserror::Error; pub use error::DeflateExtensionError; use ratchet_ext::{ - Extension, ExtensionDecoder, ExtensionEncoder, ExtensionProvider, FrameHeader, Header, - HeaderMap, HeaderValue, OpCode, ReunitableExtension, RsvBits, SplittableExtension, + Extension, ExtensionDecoder, ExtensionEncoder, ExtensionProvider, FrameHeader, HeaderMap, + HeaderValue, OpCode, ReunitableExtension, RsvBits, SplittableExtension, }; use crate::codec::{BufCompress, BufDecompress}; @@ -81,13 +81,16 @@ impl ExtensionProvider for DeflateExtProvider { apply_headers(headers, &self.config); } - fn negotiate_client(&self, headers: &[Header]) -> Result, Self::Error> { + fn negotiate_client( + &self, + headers: &HeaderMap, + ) -> Result, Self::Error> { negotiate_client(headers, &self.config) } fn negotiate_server( &self, - headers: &[Header], + headers: &HeaderMap, ) -> Result, Self::Error> { negotiate_server(headers, &self.config) } diff --git a/ratchet_deflate/src/tests.rs b/ratchet_deflate/src/tests.rs index 28df6bf..3ebe1e3 100644 --- a/ratchet_deflate/src/tests.rs +++ b/ratchet_deflate/src/tests.rs @@ -17,8 +17,7 @@ use crate::handshake::{apply_headers, on_request, on_response, NegotiationErr}; use crate::{DeflateConfig, InitialisedDeflateConfig, WindowBits}; use flate2::Compression; use http::header::SEC_WEBSOCKET_EXTENSIONS; -use http::HeaderMap; -use ratchet_ext::Header; +use http::{HeaderMap, HeaderValue}; fn test_headers(config: DeflateConfig, expected: &str) { let mut header_map = HeaderMap::new(); @@ -96,14 +95,14 @@ fn applies_headers() { #[test] fn request_negotiates_nothing() { - match on_request(&[], &DeflateConfig::default()) { + match on_request(&HeaderMap::new(), &DeflateConfig::default()) { Err(NegotiationErr::Failed) => {} _ => panic!("Expected no extension"), } } -fn request_test_valid_default(headers: &[Header]) { - match on_request(headers, &DeflateConfig::default()) { +fn request_test_valid_default(headers: HeaderMap) { + match on_request(&headers, &DeflateConfig::default()) { Ok((config, header)) => { let value = header.to_str().expect("Malformatted header produced"); assert_eq!( @@ -128,36 +127,27 @@ fn request_test_valid_default(headers: &[Header]) { #[test] fn request_negotiates_default_spaces() { request_test_valid_default( - &[Header { - name: SEC_WEBSOCKET_EXTENSIONS.as_str(), - value: b"permessage-deflate; client_max_window_bits; server_no_context_takeover; client_no_context_takeover", - }] + HeaderMap::from_iter([(SEC_WEBSOCKET_EXTENSIONS, HeaderValue::from_static("permessage-deflate; client_max_window_bits; server_no_context_takeover; client_no_context_takeover"))]) ); request_test_valid_default( - &[Header { - name: SEC_WEBSOCKET_EXTENSIONS.as_str(), - value: b"permessage-deflate; client_max_window_bits ; server_no_context_takeover ; client_no_context_takeover", - }] + HeaderMap::from_iter([(SEC_WEBSOCKET_EXTENSIONS, HeaderValue::from_static("permessage-deflate; client_max_window_bits ; server_no_context_takeover ; client_no_context_takeover"))]) ); } #[test] fn request_negotiates_no_spaces() { request_test_valid_default( - &[Header { - name: SEC_WEBSOCKET_EXTENSIONS.as_str(), - value: b"permessage-deflate;client_max_window_bits;server_no_context_takeover;client_no_context_takeover", - }] + HeaderMap::from_iter([(SEC_WEBSOCKET_EXTENSIONS, HeaderValue::from_static("permessage-deflate;client_max_window_bits;server_no_context_takeover;client_no_context_takeover"))]) ); } #[test] fn request_unknown_header() { match on_request( - &[Header { - name: SEC_WEBSOCKET_EXTENSIONS.as_str(), - value: b"permessage-bzip", - }], + &HeaderMap::from_iter([( + SEC_WEBSOCKET_EXTENSIONS, + HeaderValue::from_static("permessage-bzip"), + )]), &DeflateConfig::default(), ) { Err(NegotiationErr::Failed) => {} @@ -167,55 +157,34 @@ fn request_unknown_header() { #[test] fn request_mixed_headers_with_unknown() { - let headers = &[ - Header { - name: SEC_WEBSOCKET_EXTENSIONS.as_str(), - value: b"permessage-bzip", - }, - Header { - name: SEC_WEBSOCKET_EXTENSIONS.as_str(), - value: b"permessage-deflate; client_max_window_bits; server_no_context_takeover; client_no_context_takeover", - } - ]; - + let headers = HeaderMap::from_iter([ + (SEC_WEBSOCKET_EXTENSIONS, HeaderValue::from_static("permessage-bzip")), + (SEC_WEBSOCKET_EXTENSIONS, HeaderValue::from_static("permessage-deflate; client_max_window_bits; server_no_context_takeover; client_no_context_takeover")) + ]); request_test_valid_default(headers); } #[test] fn request_mixed_headers_with_unnegotiable() { - let headers = &[ - Header { - name: SEC_WEBSOCKET_EXTENSIONS.as_str(), - value: b"permessage-bzip", - }, - Header { - name: SEC_WEBSOCKET_EXTENSIONS.as_str(), - value: b"permessage-deflate; client_max_window_bits=7; server_max_window_bits=8; server_no_context_takeover; client_no_context_takeover", - }, - Header { - name: SEC_WEBSOCKET_EXTENSIONS.as_str(), - value: b"permessage-deflate; client_max_window_bits; server_no_context_takeover; client_no_context_takeover", - } - ]; + let headers = HeaderMap::from_iter([ + (SEC_WEBSOCKET_EXTENSIONS, HeaderValue::from_static("permessage-bzip")), + (SEC_WEBSOCKET_EXTENSIONS, HeaderValue::from_static("permessage-deflate; client_max_window_bits=7; server_max_window_bits=8; server_no_context_takeover; client_no_context_takeover")), + (SEC_WEBSOCKET_EXTENSIONS, HeaderValue::from_static("permessage-deflate; client_max_window_bits; server_no_context_takeover; client_no_context_takeover")), + ]); request_test_valid_default(headers); } #[test] fn request_truncated_headers() { - request_test_valid_default(&[ - Header { - name: SEC_WEBSOCKET_EXTENSIONS.as_str(), - value: b"permessage-deflate; client_max_window_bits=7; server_max_window_bits=8; server_no_context_takeover; client_no_context_takeover, permessage-deflate; client_max_window_bits; server_no_context_takeover; client_no_context_takeover", } - ]) + request_test_valid_default(HeaderMap::from_iter([(SEC_WEBSOCKET_EXTENSIONS, HeaderValue::from_static("permessage-deflate; client_max_window_bits=7; server_max_window_bits=8; server_no_context_takeover; client_no_context_takeover, permessage-deflate; client_max_window_bits; server_no_context_takeover; client_no_context_takeover"))])) } #[test] fn request_no_accept_no_context_takeover() { - let header = Header { - name: SEC_WEBSOCKET_EXTENSIONS.as_str(), - value: b"permessage-deflate; client_max_window_bits=7; server_max_window_bits=8; server_no_context_takeover; client_no_context_takeover, permessage-deflate; client_max_window_bits; server_no_context_takeover; client_no_context_takeover", - }; + let headers = HeaderMap::from_iter([ + (SEC_WEBSOCKET_EXTENSIONS, HeaderValue::from_static("permessage-deflate; client_max_window_bits=7; server_max_window_bits=8; server_no_context_takeover; client_no_context_takeover, permessage-deflate; client_max_window_bits; server_no_context_takeover; client_no_context_takeover")) + ]); let config = DeflateConfig { server_max_window_bits: WindowBits::fifteen(), client_max_window_bits: WindowBits::fifteen(), @@ -225,7 +194,7 @@ fn request_no_accept_no_context_takeover() { compression_level: Compression::fast(), }; - match on_request(&[header], &config) { + match on_request(&headers, &config) { Ok((config, header)) => { let value = header.to_str().expect("Malformatted header produced"); assert_eq!(value, "permessage-deflate; client_no_context_takeover"); @@ -244,8 +213,8 @@ fn request_no_accept_no_context_takeover() { } } -fn request_test_malformatted_default(headers: &[Header], expected: DeflateExtensionError) { - match on_request(headers, &DeflateConfig::default()) { +fn request_test_malformatted_default(headers: HeaderMap, expected: DeflateExtensionError) { + match on_request(&headers, &DeflateConfig::default()) { Err(NegotiationErr::Err(e)) => assert_eq!(e.to_string(), expected.to_string()), e => panic!("Expected: `{:?}`. Got: {:?}", expected, e), } @@ -254,19 +223,16 @@ fn request_test_malformatted_default(headers: &[Header], expected: DeflateExtens #[test] fn request_malformatted_window_bits() { request_test_malformatted_default( - &[Header { - name: SEC_WEBSOCKET_EXTENSIONS.as_str(), - value: - b"permessage-deflate; client_max_window_bits=2.71828; server_max_window_bits=3.14159", - }], + HeaderMap::from_iter([(SEC_WEBSOCKET_EXTENSIONS, HeaderValue::from_static("permessage-deflate; client_max_window_bits=2.71828; server_max_window_bits=3.14159"))]), DeflateExtensionError::InvalidMaxWindowBits, ); request_test_malformatted_default( - &[Header { - name: SEC_WEBSOCKET_EXTENSIONS.as_str(), - value: - b"permessage-deflate; client_max_window_bits=666; server_max_window_bits=3.14159", - }], + HeaderMap::from_iter([( + SEC_WEBSOCKET_EXTENSIONS, + HeaderValue::from_static( + "permessage-deflate; client_max_window_bits=666; server_max_window_bits=3.14159", + ), + )]), DeflateExtensionError::InvalidMaxWindowBits, ) } @@ -274,10 +240,10 @@ fn request_malformatted_window_bits() { #[test] fn request_unknown_parameter() { request_test_malformatted_default( - &[Header { - name: SEC_WEBSOCKET_EXTENSIONS.as_str(), - value: b"permessage-deflate; peer_max_window_bits", - }], + HeaderMap::from_iter([( + SEC_WEBSOCKET_EXTENSIONS, + HeaderValue::from_static("permessage-deflate; peer_max_window_bits"), + )]), DeflateExtensionError::NegotiationError( "Unknown permessage-deflate parameter: peer_max_window_bits".to_string(), ), @@ -286,7 +252,7 @@ fn request_unknown_parameter() { #[test] fn response_no_ext() { - match on_response(&[], &DeflateConfig::default()) { + match on_response(&HeaderMap::new(), &DeflateConfig::default()) { Err(NegotiationErr::Failed) => {} _ => panic!("Expected no extension"), } @@ -295,10 +261,10 @@ fn response_no_ext() { #[test] fn response_unknown_ext() { match on_response( - &[Header { - name: SEC_WEBSOCKET_EXTENSIONS.as_str(), - value: b"permessage-bzip", - }], + &HeaderMap::from_iter([( + SEC_WEBSOCKET_EXTENSIONS, + HeaderValue::from_static("permessage-bzip"), + )]), &DeflateConfig::default(), ) { Err(NegotiationErr::Failed) => {} @@ -309,10 +275,12 @@ fn response_unknown_ext() { #[test] fn response_duplicate_param() { match on_response( - &[Header { - name: SEC_WEBSOCKET_EXTENSIONS.as_str(), - value: b"permessage-deflate;server_no_context_takeover;server_no_context_takeover", - }], + &HeaderMap::from_iter([( + SEC_WEBSOCKET_EXTENSIONS, + HeaderValue::from_static( + "permessage-deflate;server_no_context_takeover;server_no_context_takeover", + ), + )]), &DeflateConfig::default(), ) { Err(NegotiationErr::Err(DeflateExtensionError::NegotiationError(s))) @@ -325,10 +293,10 @@ fn response_duplicate_param() { #[test] fn response_invalid_max_bits() { match on_response( - &[Header { - name: SEC_WEBSOCKET_EXTENSIONS.as_str(), - value: b"permessage-deflate;server_max_window_bits=666", - }], + &HeaderMap::from_iter([( + SEC_WEBSOCKET_EXTENSIONS, + HeaderValue::from_static("permessage-deflate;server_max_window_bits=666"), + )]), &DeflateConfig::default(), ) { Err(NegotiationErr::Err(DeflateExtensionError::InvalidMaxWindowBits)) => {} @@ -339,10 +307,10 @@ fn response_invalid_max_bits() { #[test] fn response_unknown_param() { match on_response( - &[Header { - name: SEC_WEBSOCKET_EXTENSIONS.as_str(), - value: b"permessage-deflate;invalid=param", - }], + &HeaderMap::from_iter([( + SEC_WEBSOCKET_EXTENSIONS, + HeaderValue::from_static("permessage-deflate;invalid=param"), + )]), &DeflateConfig::default(), ) { Err(NegotiationErr::Err(DeflateExtensionError::NegotiationError(s))) @@ -360,10 +328,10 @@ fn response_no_context_takeover() { }; match on_response( - &[Header { - name: SEC_WEBSOCKET_EXTENSIONS.as_str(), - value: b"permessage-deflate;client_no_context_takeover", - }], + &HeaderMap::from_iter([( + SEC_WEBSOCKET_EXTENSIONS, + HeaderValue::from_static("permessage-deflate;client_no_context_takeover"), + )]), &config, ) { Err(NegotiationErr::Err(DeflateExtensionError::NegotiationError(s))) diff --git a/ratchet_ext/src/lib.rs b/ratchet_ext/src/lib.rs index 13509d3..16e5e74 100644 --- a/ratchet_ext/src/lib.rs +++ b/ratchet_ext/src/lib.rs @@ -65,7 +65,8 @@ pub trait ExtensionProvider { /// /// Returning `Err` from this will *fail* the connection with the reason being the error's /// `to_string()` value. - fn negotiate_client(&self, headers: &[Header]) -> Result, Self::Error>; + fn negotiate_client(&self, headers: &HeaderMap) + -> Result, Self::Error>; /// Negotiate the headers that a client has sent. /// @@ -80,7 +81,7 @@ pub trait ExtensionProvider { /// `to_string()` value. fn negotiate_server( &self, - headers: &[Header], + headers: &HeaderMap, ) -> Result, Self::Error>; } @@ -95,13 +96,16 @@ where E::apply_headers(self, headers) } - fn negotiate_client(&self, headers: &[Header]) -> Result, Self::Error> { + fn negotiate_client( + &self, + headers: &HeaderMap, + ) -> Result, Self::Error> { E::negotiate_client(self, headers) } fn negotiate_server( &self, - headers: &[Header], + headers: &HeaderMap, ) -> Result, Self::Error> { E::negotiate_server(self, headers) } @@ -118,13 +122,16 @@ where E::apply_headers(self, headers) } - fn negotiate_client(&self, headers: &[Header]) -> Result, Self::Error> { + fn negotiate_client( + &self, + headers: &HeaderMap, + ) -> Result, Self::Error> { E::negotiate_client(self, headers) } fn negotiate_server( &self, - headers: &[Header], + headers: &HeaderMap, ) -> Result, Self::Error> { E::negotiate_server(self, headers) } From 43a3662261b86bef0e2484bb5aa454d5cd3b27d5 Mon Sep 17 00:00:00 2001 From: SirCipher Date: Wed, 25 Sep 2024 15:44:35 +0100 Subject: [PATCH 03/13] Extracts core WebSocket upgrade logic for servers --- .github/workflows/ci.yml | 2 +- Cargo.toml | 4 +- ratchet_core/src/builder.rs | 19 +- ratchet_core/src/errors.rs | 8 +- ratchet_core/src/handshake/client/encoding.rs | 8 +- ratchet_core/src/handshake/client/mod.rs | 24 +-- ratchet_core/src/handshake/client/tests.rs | 21 +- ratchet_core/src/handshake/mod.rs | 5 +- ratchet_core/src/handshake/server/encoding.rs | 57 +++-- ratchet_core/src/handshake/server/mod.rs | 202 ++++++++++++++++-- ratchet_core/src/handshake/server/tests.rs | 6 +- ratchet_core/src/handshake/subprotocols.rs | 150 ++++++------- ratchet_core/src/handshake/tests.rs | 22 +- ratchet_core/src/lib.rs | 12 +- ratchet_rs/examples/autobahn-client.rs | 4 +- ratchet_rs/examples/autobahn-server.rs | 4 +- ratchet_rs/examples/autobahn-split-client.rs | 4 +- ratchet_rs/examples/autobahn-split-server.rs | 4 +- ratchet_rs/examples/server.rs | 5 +- ratchet_rs/src/lib.rs | 8 +- 20 files changed, 396 insertions(+), 173 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 0bc381d..8aec7d7 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -19,7 +19,7 @@ jobs: - name: Build Documentation run: cargo doc --lib --no-deps --all-features --workspace env: - RUSTDOCFLAGS: --cfg docsrs -Dwarnings + RUSTDOCFLAGS="--cfg docsrs -Dwarnings" cargo doc --lib --no-deps --all-features --workspace testmsrv: name: Test Suite Latest diff --git a/Cargo.toml b/Cargo.toml index d35d747..2900c22 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -24,8 +24,8 @@ license = "Apache-2.0" url = "2.1.1" http = "1.1.0" tokio = "=1.38.1" -tokio-util = "=0.7.1" -tokio-stream = "=0.1.15" +tokio-util = "0.7.4" +tokio-stream = "0.1.11" futures = "0.3.4" futures-util = "0.3.4" derive_more = "0.99.14" diff --git a/ratchet_core/src/builder.rs b/ratchet_core/src/builder.rs index 21c3dfa..fa9f9eb 100644 --- a/ratchet_core/src/builder.rs +++ b/ratchet_core/src/builder.rs @@ -14,10 +14,9 @@ use crate::errors::Error; use crate::ext::NoExtProvider; -use crate::handshake::{ProtocolRegistry, UpgradedServer}; +use crate::handshake::{SubprotocolRegistry, UpgradedServer}; use crate::{subscribe_with, TryIntoRequest, UpgradedClient, WebSocketConfig, WebSocketStream}; use ratchet_ext::ExtensionProvider; -use std::borrow::Cow; /// A builder to construct WebSocket clients. /// @@ -28,7 +27,7 @@ use std::borrow::Cow; pub struct WebSocketClientBuilder { config: Option, extension: E, - subprotocols: ProtocolRegistry, + subprotocols: SubprotocolRegistry, } impl Default for WebSocketClientBuilder { @@ -36,7 +35,7 @@ impl Default for WebSocketClientBuilder { WebSocketClientBuilder { config: None, extension: NoExtProvider, - subprotocols: ProtocolRegistry::default(), + subprotocols: SubprotocolRegistry::default(), } } } @@ -95,9 +94,9 @@ impl WebSocketClientBuilder { pub fn subprotocols(mut self, subprotocols: I) -> Result where I: IntoIterator, - I::Item: Into>, + I::Item: Into, { - self.subprotocols = ProtocolRegistry::new(subprotocols)?; + self.subprotocols = SubprotocolRegistry::new(subprotocols)?; Ok(self) } } @@ -110,7 +109,7 @@ impl WebSocketClientBuilder { #[derive(Debug)] pub struct WebSocketServerBuilder { config: Option, - subprotocols: ProtocolRegistry, + subprotocols: SubprotocolRegistry, extension: E, } @@ -119,7 +118,7 @@ impl Default for WebSocketServerBuilder { WebSocketServerBuilder { config: None, extension: NoExtProvider, - subprotocols: ProtocolRegistry::default(), + subprotocols: SubprotocolRegistry::default(), } } } @@ -168,9 +167,9 @@ impl WebSocketServerBuilder { pub fn subprotocols(mut self, subprotocols: I) -> Result where I: IntoIterator, - I::Item: Into>, + I::Item: Into, { - self.subprotocols = ProtocolRegistry::new(subprotocols)?; + self.subprotocols = SubprotocolRegistry::new(subprotocols)?; Ok(self) } } diff --git a/ratchet_core/src/errors.rs b/ratchet_core/src/errors.rs index d0c015f..8ed358a 100644 --- a/ratchet_core/src/errors.rs +++ b/ratchet_core/src/errors.rs @@ -268,14 +268,11 @@ pub enum CloseCause { } /// WebSocket protocol errors. -#[derive(Copy, Clone, Debug, Eq, PartialEq, Error)] +#[derive(Clone, Debug, Eq, PartialEq, Error)] pub enum ProtocolError { /// Invalid encoding was received. #[error("Not valid UTF-8 encoding")] Encoding, - /// A peer selected a protocol that was not sent. - #[error("Received an unknown subprotocol")] - UnknownProtocol, /// An invalid OpCode was received. #[error("Bad OpCode: `{0}`")] OpCode(OpCodeParseErr), @@ -309,6 +306,9 @@ pub enum ProtocolError { /// An invalid control frame was received. #[error("Received an invalid control frame")] InvalidControlFrame, + /// Failed to build subprotocol header. + #[error("Invalid subprotocol header: `{0}`")] + InvalidSubprotocolHeader(String), } impl From for Error { diff --git a/ratchet_core/src/handshake/client/encoding.rs b/ratchet_core/src/handshake/client/encoding.rs index 6796119..52b23c7 100644 --- a/ratchet_core/src/handshake/client/encoding.rs +++ b/ratchet_core/src/handshake/client/encoding.rs @@ -22,9 +22,7 @@ use ratchet_ext::ExtensionProvider; use crate::errors::{Error, ErrorKind, HttpError}; use crate::handshake::client::Nonce; -use crate::handshake::{ - apply_to, ProtocolRegistry, UPGRADE_STR, WEBSOCKET_STR, WEBSOCKET_VERSION_STR, -}; +use crate::handshake::{SubprotocolRegistry, UPGRADE_STR, WEBSOCKET_STR, WEBSOCKET_VERSION_STR}; use base64::engine::general_purpose::STANDARD; @@ -125,7 +123,7 @@ pub struct ValidatedRequest { pub fn build_request( request: Request<()>, extension: &E, - subprotocols: &ProtocolRegistry, + subprotocols: &SubprotocolRegistry, ) -> Result where E: ExtensionProvider, @@ -197,7 +195,7 @@ where )); } - apply_to(subprotocols, &mut headers); + subprotocols.apply_to(&mut headers); let option = headers .get(header::SEC_WEBSOCKET_KEY) diff --git a/ratchet_core/src/handshake/client/mod.rs b/ratchet_core/src/handshake/client/mod.rs index 65299d7..92b3836 100644 --- a/ratchet_core/src/handshake/client/mod.rs +++ b/ratchet_core/src/handshake/client/mod.rs @@ -21,8 +21,8 @@ use crate::errors::{Error, ErrorKind, HttpError}; use crate::handshake::client::encoding::{build_request, encode_request}; use crate::handshake::io::BufferedIo; use crate::handshake::{ - negotiate_response, validate_header, validate_header_value, ParseResult, ProtocolRegistry, - StreamingParser, TryMap, ACCEPT_KEY, BAD_STATUS_CODE, UPGRADE_STR, WEBSOCKET_STR, + validate_header, validate_header_value, ParseResult, StreamingParser, SubprotocolRegistry, + TryMap, ACCEPT_KEY, BAD_STATUS_CODE, UPGRADE_STR, WEBSOCKET_STR, }; use crate::{ NoExt, NoExtProvider, Role, TryIntoRequest, WebSocket, WebSocketConfig, WebSocketStream, @@ -80,7 +80,7 @@ where &mut stream, request.try_into_request()?, NoExtProvider, - ProtocolRegistry::default(), + SubprotocolRegistry::default(), &mut read_buffer, ) .await?; @@ -98,7 +98,7 @@ pub async fn subscribe_with( mut stream: S, request: R, extension: E, - subprotocols: ProtocolRegistry, + subprotocols: SubprotocolRegistry, ) -> Result, Error> where S: WebSocketStream, @@ -128,7 +128,7 @@ async fn exec_client_handshake( stream: &mut S, request: Request<()>, extension: E, - subprotocols: ProtocolRegistry, + subprotocols: SubprotocolRegistry, buf: &mut BytesMut, ) -> Result, Error> where @@ -163,14 +163,14 @@ where struct ClientHandshake<'s, S, E> { buffered: BufferedIo<'s, S>, nonce: Nonce, - subprotocols: ProtocolRegistry, + subprotocols: SubprotocolRegistry, extension: &'s E, } pub struct StreamingResponseParser<'b, E> { nonce: &'b Nonce, extension: &'b E, - subprotocols: &'b mut ProtocolRegistry, + subprotocols: &'b mut SubprotocolRegistry, } impl<'b, E> Decoder for StreamingResponseParser<'b, E> @@ -207,7 +207,7 @@ where { pub fn new( socket: &'s mut S, - subprotocols: ProtocolRegistry, + subprotocols: SubprotocolRegistry, extension: &'s E, buf: &'s mut BytesMut, ) -> ClientHandshake<'s, S, E> { @@ -316,7 +316,7 @@ fn try_parse_response<'h, 'b, E>( mut response: Response<'h, 'b>, expected_nonce: &Nonce, extension: E, - subprotocols: &mut ProtocolRegistry, + subprotocols: &mut SubprotocolRegistry, ) -> Result, HandshakeResult>, Error> where 'h: 'b, @@ -336,14 +336,14 @@ fn parse_response( response: http::Response<()>, expected_nonce: &Nonce, extension: E, - subprotocols: &mut ProtocolRegistry, + subprotocols: &SubprotocolRegistry, ) -> Result, Error> where E: ExtensionProvider, { if response.version() < Version::HTTP_11 { // rfc6455 § 4.2.1.1: must be HTTP/1.1 or higher - // this will implicitly be 0 as httparse only parses HTTP/1.x and 1.1 is 0. + // this will implicitly be 0 as httparse only parses HTTP/1.x and 1.0 is 0. return Err(Error::with_cause( ErrorKind::Http, HttpError::HttpVersion(Some(0)), @@ -399,7 +399,7 @@ where )?; Ok(HandshakeResult { - subprotocol: negotiate_response(subprotocols, response.headers())?, + subprotocol: subprotocols.validate_accepted_subprotocol(response.headers())?, extension: extension .negotiate_client(response.headers()) .map_err(|e| Error::with_cause(ErrorKind::Extension, e))?, diff --git a/ratchet_core/src/handshake/client/tests.rs b/ratchet_core/src/handshake/client/tests.rs index 50b32fa..8208540 100644 --- a/ratchet_core/src/handshake/client/tests.rs +++ b/ratchet_core/src/handshake/client/tests.rs @@ -15,7 +15,7 @@ use crate::errors::{Error, HttpError}; use crate::ext::NoExt; use crate::handshake::client::{ClientHandshake, HandshakeResult}; -use crate::handshake::{ProtocolRegistry, ACCEPT_KEY, UPGRADE_STR, WEBSOCKET_STR}; +use crate::handshake::{SubprotocolRegistry, ACCEPT_KEY, UPGRADE_STR, WEBSOCKET_STR}; use crate::test_fixture::mock; use crate::{ErrorKind, NoExtProvider, ProtocolError, TryIntoRequest}; use base64::engine::{general_purpose::STANDARD, Engine}; @@ -44,7 +44,7 @@ async fn handshake_sends_valid_request() { let mut buf = BytesMut::new(); let mut machine = ClientHandshake::new( &mut stream, - ProtocolRegistry::new(vec!["warp"]).unwrap(), + SubprotocolRegistry::new(vec!["warp"]).unwrap(), &NoExtProvider, &mut buf, ); @@ -84,7 +84,7 @@ async fn handshake_invalid_requests() { let mut buf = BytesMut::new(); let mut machine = ClientHandshake::new( &mut stream, - ProtocolRegistry::default(), + SubprotocolRegistry::default(), &NoExtProvider, &mut buf, ); @@ -157,7 +157,7 @@ async fn expect_server_error(response: Response<()>, expected_error: HttpError) let mut buf = BytesMut::new(); let mut machine = ClientHandshake::new( &mut stream, - ProtocolRegistry::default(), + SubprotocolRegistry::default(), &NoExtProvider, &mut buf, ); @@ -261,7 +261,7 @@ async fn ok_nonce() { let mut buf = BytesMut::new(); let mut machine = ClientHandshake::new( &mut stream, - ProtocolRegistry::default(), + SubprotocolRegistry::default(), &NoExtProvider, &mut buf, ); @@ -331,7 +331,7 @@ async fn redirection() { let mut buf = BytesMut::new(); let mut machine = ClientHandshake::new( &mut stream, - ProtocolRegistry::default(), + SubprotocolRegistry::default(), &NoExtProvider, &mut buf, ); @@ -391,7 +391,7 @@ where let mut machine = ClientHandshake::new( &mut stream, - ProtocolRegistry::new(registry).unwrap(), + SubprotocolRegistry::new(registry).unwrap(), &NoExtProvider, &mut buf, ); @@ -458,7 +458,10 @@ async fn invalid_subprotocol() { let protocol_error = err .downcast_ref::() .expect("Expected a protocol error"); - assert_eq!(protocol_error, &ProtocolError::UnknownProtocol); + assert_eq!( + protocol_error, + &ProtocolError::InvalidSubprotocolHeader("warpy".to_string()) + ); }) .await; } @@ -579,7 +582,7 @@ where let client_task = async move { let mut buf = BytesMut::new(); let mut machine = - ClientHandshake::new(&mut stream, ProtocolRegistry::default(), &ext, &mut buf); + ClientHandshake::new(&mut stream, SubprotocolRegistry::default(), &ext, &mut buf); machine .encode(Request::get(TEST_URL).body(()).unwrap()) .unwrap(); diff --git a/ratchet_core/src/handshake/mod.rs b/ratchet_core/src/handshake/mod.rs index f506117..327168a 100644 --- a/ratchet_core/src/handshake/mod.rs +++ b/ratchet_core/src/handshake/mod.rs @@ -33,7 +33,10 @@ use tokio_util::codec::Decoder; use url::Url; pub use client::{subscribe, subscribe_with, UpgradedClient}; -pub use server::{accept, accept_with, UpgradedServer, WebSocketResponse, WebSocketUpgrader}; +pub use server::{ + accept, accept_with, build_response, handshake, parse_request, UpgradeRequest, UpgradedServer, + WebSocketResponse, WebSocketUpgrader, +}; pub use subprotocols::*; const WEBSOCKET_STR: &str = "websocket"; diff --git a/ratchet_core/src/handshake/server/encoding.rs b/ratchet_core/src/handshake/server/encoding.rs index c368d6d..774f01d 100644 --- a/ratchet_core/src/handshake/server/encoding.rs +++ b/ratchet_core/src/handshake/server/encoding.rs @@ -13,13 +13,13 @@ // limitations under the License. use crate::handshake::io::BufferedIo; -use crate::handshake::server::HandshakeResult; -use crate::handshake::{negotiate_request, TryMap}; +use crate::handshake::server::UpgradeRequest; +use crate::handshake::TryMap; use crate::handshake::{ validate_header, validate_header_any, validate_header_value, ParseResult, METHOD_GET, UPGRADE_STR, WEBSOCKET_STR, WEBSOCKET_VERSION_STR, }; -use crate::{Error, ErrorKind, HttpError, ProtocolRegistry}; +use crate::{Error, ErrorKind, HttpError, SubprotocolRegistry}; use bytes::{BufMut, Bytes, BytesMut}; use http::header::SEC_WEBSOCKET_KEY; use http::{HeaderMap, Method, StatusCode, Version}; @@ -38,7 +38,7 @@ const HTTP_VERSION_INT: u8 = 1; const HTTP_VERSION: Version = Version::HTTP_11; pub struct RequestParser { - pub subprotocols: ProtocolRegistry, + pub subprotocols: SubprotocolRegistry, pub extension: E, } @@ -46,7 +46,7 @@ impl Decoder for RequestParser where E: ExtensionProvider, { - type Item = (HandshakeResult, usize); + type Item = (UpgradeRequest, usize); type Error = Error; fn decode(&mut self, buf: &mut BytesMut) -> Result, Self::Error> { @@ -129,8 +129,8 @@ pub fn try_parse_request<'h, 'b, E>( buffer: &'b [u8], mut request: httparse::Request<'h, 'b>, extension: E, - subprotocols: &mut ProtocolRegistry, -) -> Result, HandshakeResult>, Error> + subprotocols: &mut SubprotocolRegistry, +) -> Result, UpgradeRequest>, Error> where E: ExtensionProvider, { @@ -169,16 +169,43 @@ pub fn check_partial_request(request: &httparse::Request) -> Result<(), Error> { Ok(()) } -pub fn parse_request( - request: http::Request<()>, +/// Parses an HTTP request to extract WebSocket upgrade information. +/// +/// This function validates and processes an incoming HTTP request to ensure it meets the +/// requirements for a WebSocket upgrade. It checks the HTTP version, method, and necessary headers +/// to determine if the request can be successfully upgraded to a WebSocket connection. It also +/// negotiates the subprotocols and extensions specified in the request. +/// +/// # Arguments +/// - `request`: An `http::Request` representing the incoming HTTP request from the client, which +/// is expected to contain WebSocket-specific headers. While it is discouraged for GET requests to +/// have a body it is not technically incorrect and the use of this function is lowering the +/// guardrails to allow for Ratchet to be more easily integrated into other libraries. It is the +/// implementors responsibility to perform any validation on the body. +/// - `extension`: An instance of a type that implements the `ExtensionProvider` +/// trait. This object is responsible for negotiating any server-supported +/// extensions requested by the client. +/// - `subprotocols`: A `SubprotocolRegistry`, which manages the supported subprotocols and attempts +/// to negotiate one with the client. +/// +/// # Returns +/// This function returns a `Result, Error>`, where: +/// - `Ok(UpgradeRequest)`: Contains the parsed information needed for the WebSocket +/// handshake, including the WebSocket key, negotiated subprotocol, optional +/// extensions, and the original HTTP request. +/// - `Err(Error)`: Contains an error if the request is invalid or cannot be parsed. +/// This could include issues such as unsupported HTTP versions, invalid methods, +/// missing required headers, or failed negotiations for subprotocols or extensions. +pub fn parse_request( + request: http::Request, extension: E, - subprotocols: &mut ProtocolRegistry, -) -> Result, Error> + subprotocols: &SubprotocolRegistry, +) -> Result, Error> where E: ExtensionProvider, { if request.version() < HTTP_VERSION { - // this will implicitly be 0 as httparse only parses HTTP/1.x and 1.1 is 0. + // this will implicitly be 0 as httparse only parses HTTP/1.x and 1.0 is 0. return Err(Error::with_cause( ErrorKind::Http, HttpError::HttpVersion(Some(0)), @@ -209,17 +236,17 @@ where .ok_or_else(|| { Error::with_cause(ErrorKind::Http, HttpError::MissingHeader(SEC_WEBSOCKET_KEY)) })?; - let subprotocol = negotiate_request(subprotocols, headers)?; + let subprotocol = subprotocols.negotiate_client(headers)?; let (extension, extension_header) = extension .negotiate_server(headers) .map(Option::unzip) .map_err(|e| Error::with_cause(ErrorKind::Extension, e))?; - Ok(HandshakeResult { + Ok(UpgradeRequest { key, - request, extension, subprotocol, extension_header, + request, }) } diff --git a/ratchet_core/src/handshake/server/mod.rs b/ratchet_core/src/handshake/server/mod.rs index c031cf1..8aad5a2 100644 --- a/ratchet_core/src/handshake/server/mod.rs +++ b/ratchet_core/src/handshake/server/mod.rs @@ -16,20 +16,22 @@ mod encoding; #[cfg(test)] mod tests; -use crate::ext::NoExt; -use crate::handshake::io::BufferedIo; -use crate::handshake::server::encoding::{write_response, RequestParser}; -use crate::handshake::{StreamingParser, ACCEPT_KEY}; -use crate::handshake::{UPGRADE_STR, WEBSOCKET_STR}; -use crate::protocol::Role; +pub use encoding::parse_request; + use crate::{ - Error, HttpError, NoExtProvider, ProtocolRegistry, Request, WebSocket, WebSocketConfig, + ext::NoExt, + handshake::io::BufferedIo, + handshake::server::encoding::{write_response, RequestParser}, + handshake::{StreamingParser, ACCEPT_KEY}, + handshake::{UPGRADE_STR, WEBSOCKET_STR}, + protocol::Role, + Error, HttpError, NoExtProvider, Request, SubprotocolRegistry, WebSocket, WebSocketConfig, WebSocketStream, }; use base64::engine::{general_purpose::STANDARD, Engine}; use bytes::{Bytes, BytesMut}; use http::status::InvalidStatusCode; -use http::{HeaderMap, HeaderValue, StatusCode, Uri}; +use http::{HeaderMap, HeaderValue, StatusCode, Uri, Version}; use log::{error, trace}; use ratchet_ext::{Extension, ExtensionProvider}; use sha1::{Digest, Sha1}; @@ -72,7 +74,13 @@ pub async fn accept( where S: WebSocketStream, { - accept_with(stream, config, NoExtProvider, ProtocolRegistry::default()).await + accept_with( + stream, + config, + NoExtProvider, + SubprotocolRegistry::default(), + ) + .await } /// Execute a server handshake on the provided stream. An attempt will be made to negotiate the @@ -85,7 +93,7 @@ pub async fn accept_with( mut stream: S, config: WebSocketConfig, extension: E, - subprotocols: ProtocolRegistry, + subprotocols: SubprotocolRegistry, ) -> Result, Error> where S: WebSocketStream, @@ -102,14 +110,14 @@ where ); match parser.parse().await { - Ok(result) => { - let HandshakeResult { + Ok(request) => { + let UpgradeRequest { key, subprotocol, extension, request, extension_header, - } = result; + } = request; trace!( "{}for: {}. Selected subprotocol: {:?} and extension: {:?}", @@ -311,11 +319,173 @@ where } } +/// Represents a parsed WebSocket connection upgrade HTTP request. #[derive(Debug)] -pub struct HandshakeResult { +#[non_exhaustive] +pub struct UpgradeRequest { + /// The security key provided by the client during the WebSocket handshake. + /// + /// This key is used by the server to generate a response key, confirming that the server + /// accepts the WebSocket upgrade request. + pub key: Bytes, + + /// The optional WebSocket subprotocol agreed upon during the handshake. + /// + /// The subprotocol is used to define the application-specific communication on top of the + /// WebSocket connection, such as `wamp` or `graphql-ws`. If no subprotocol is requested or + /// agreed upon, this will be `None`. + pub subprotocol: Option, + + /// The optional WebSocket extension negotiated during the handshake. + /// + /// Extensions allow WebSocket connections to have additional functionality, such as compression + /// or multiplexing. This field represents any such negotiated extension, or `None` if no + /// extensions were negotiated. + pub extension: Option, + + /// The original HTTP request that initiated the WebSocket upgrade. + pub request: http::Request, + + /// The optional `Sec-WebSocket-Extensions` header value from the HTTP request. + /// + /// This header may contain the raw extension details sent by the client during the handshake. + /// If no extension was requested, this field will be `None`. + pub extension_header: Option, +} + +/// Builds an HTTP response to a WebSocket connection upgrade request. +/// +/// No validation is performed by this function and it is only guaranteed to be correct if the +/// arguments are derived by previously calling [`parse_request`]. +/// +/// # Arguments +/// +/// - `key`: The WebSocket security key provided by the client in the handshake request. +/// - `subprotocol`: An optional subprotocol that the server and client agree upon, if any. +/// This header is added only if a subprotocol is provided. +/// - `extension_header`: An optional WebSocket extension header. If present, the header +/// is included to specify any negotiated WebSocket extensions. +/// +/// # Returns +/// +/// This function returns an `http::Response<()>` that represents a valid HTTP 101 Switching +/// Protocols response required to establish a WebSocket connection. The response includes: +/// +/// - `Sec-WebSocket-Accept`: The hashed and encoded value of the provided WebSocket `key` +/// according to the WebSocket protocol specification. +/// - `Upgrade`: Set to `websocket`, indicating that the connection is being upgraded to +/// the WebSocket protocol. +/// - `Connection`: Set to `Upgrade`, as required by the HTTP upgrade process. +/// +/// Optionally, the response may also include: +/// +/// - `Sec-WebSocket-Protocol`: The negotiated subprotocol, if provided. +/// - `Sec-WebSocket-Extensions`: The WebSocket extension header, if an extension was negotiated. +/// +/// # Errors +/// +/// This function returns an `Error` if: +/// - There is an issue converting the headers (like subprotocol or extension) to +/// valid `HeaderValue` types. +/// - There is an issue building the HTTP response. +pub fn build_response( key: Bytes, subprotocol: Option, - extension: Option, - request: Request, extension_header: Option, +) -> Result, Error> { + let mut digest = Sha1::new(); + Digest::update(&mut digest, key); + Digest::update(&mut digest, ACCEPT_KEY); + + let sec_websocket_accept = STANDARD.encode(digest.finalize()); + + let mut response = http::Response::builder() + .version(Version::HTTP_11) + .header( + http::header::SEC_WEBSOCKET_ACCEPT, + HeaderValue::try_from(sec_websocket_accept)?, + ) + .header( + http::header::UPGRADE, + HeaderValue::from_static(WEBSOCKET_STR), + ) + .header( + http::header::CONNECTION, + HeaderValue::from_static(UPGRADE_STR), + ); + + if let Some(subprotocol) = &subprotocol { + response = response.header( + http::header::SEC_WEBSOCKET_PROTOCOL, + HeaderValue::try_from(subprotocol)?, + ); + } + if let Some(extension_header) = extension_header { + response = response.header(http::header::SEC_WEBSOCKET_EXTENSIONS, extension_header); + } + + Ok(response.body(())?) +} + +/// Processes a WebSocket handshake request and generates the appropriate response. +/// +/// This function handles the server-side part of a WebSocket handshake. It parses the incoming HTTP +/// request that seeks to upgrade the connection to WebSocket, negotiates extensions and +/// subprotocols, and constructs an appropriate HTTP response +/// to complete the WebSocket handshake. +/// +/// # Arguments +/// +/// - `request`: The incoming HTTP request from the client, which contains headers related to the +/// WebSocket upgrade request. +/// - `extension`: An extension that may be negotiated for the connection. +/// - `subprotocols`: A `SubprotocolRegistry`, which will be used to attempt to negotiate a +/// subprotocol. +/// +/// # Returns +/// +/// This function returns a `Result` containing: +/// - A tuple consisting of: +/// - An `http::Response<()>`, which represents the WebSocket handshake response. +/// The response includes headers such as `Sec-WebSocket-Accept` to confirm the upgrade. +/// - An optional `E::Extension`, which represents the negotiated extension, if any. +/// +/// If the handshake fails, an `Error` is returned, which may be caused by invalid +/// requests, issues parsing headers, or problems negotiating the WebSocket subprotocols +/// or extensions. +/// +/// # Type Parameters +/// +/// - `E`: The type of the extension provider, which must implement the `ExtensionProvider` +/// trait. This defines how WebSocket extensions (like compression) are handled. +/// - `B`: The body type of the HTTP request. While it is discouraged for GET requests to have a body +/// it is not technically incorrect and the use of this function is lowering the guardrails to +/// allow for Ratchet to be more easily integrated into other libraries. It is the implementors +/// responsibility to perform any validation on the body. +/// +/// # Errors +/// +/// The function returns an `Error` in cases such as: +/// - Failure to parse the WebSocket upgrade request. +/// - Issues building the response, such as invalid subprotocol or extension headers. +/// - Failure to negotiate the WebSocket extensions or subprotocols. +pub fn handshake( + request: http::Request, + extension: &E, + subprotocols: &SubprotocolRegistry, +) -> Result<(http::Response<()>, Option), Error> +where + E: ExtensionProvider, +{ + let UpgradeRequest { + key, + subprotocol, + extension, + extension_header, + .. + } = parse_request(request, extension, subprotocols)?; + Ok(( + build_response(key, subprotocol, extension_header)?, + extension, + )) } diff --git a/ratchet_core/src/handshake/server/tests.rs b/ratchet_core/src/handshake/server/tests.rs index b84c578..1440c1c 100644 --- a/ratchet_core/src/handshake/server/tests.rs +++ b/ratchet_core/src/handshake/server/tests.rs @@ -15,7 +15,7 @@ use crate::handshake::{UPGRADE_STR, WEBSOCKET_STR, WEBSOCKET_VERSION_STR}; use crate::test_fixture::{mock, ReadError}; use crate::{ - accept_with, Error, ErrorKind, HttpError, NoExtProvider, ProtocolRegistry, WebSocketConfig, + accept_with, Error, ErrorKind, HttpError, NoExtProvider, SubprotocolRegistry, WebSocketConfig, }; use bytes::BytesMut; use http::header::HeaderName; @@ -41,7 +41,7 @@ async fn exec_request(request: Request<()>) -> Result, Error> { server, WebSocketConfig::default(), NoExtProvider, - ProtocolRegistry::default(), + SubprotocolRegistry::default(), ) .await?; @@ -303,7 +303,7 @@ async fn bad_extension() { server, WebSocketConfig::default(), BadExtProvider, - ProtocolRegistry::default(), + SubprotocolRegistry::default(), ) .await; diff --git a/ratchet_core/src/handshake/subprotocols.rs b/ratchet_core/src/handshake/subprotocols.rs index 9408be9..bc421e0 100644 --- a/ratchet_core/src/handshake/subprotocols.rs +++ b/ratchet_core/src/handshake/subprotocols.rs @@ -16,107 +16,113 @@ use crate::{Error, ErrorKind, HttpError, ProtocolError}; use fnv::FnvHashSet; use http::header::SEC_WEBSOCKET_PROTOCOL; use http::{HeaderMap, HeaderValue}; -use std::borrow::Cow; +use std::sync::Arc; /// A subprotocol registry that is used for negotiating a possible subprotocol to use for a /// connection. #[derive(Default, Debug, Clone)] -pub struct ProtocolRegistry { - registrants: FnvHashSet>, +pub struct SubprotocolRegistry { + inner: Arc, +} + +#[derive(Debug, Default)] +struct Inner { + registrants: FnvHashSet, header: Option, } -impl ProtocolRegistry { - /// Construct a new protocol registry that will allow the provided items. - pub fn new(i: I) -> Result +impl SubprotocolRegistry { + /// Construct a new protocol registry that will allow the provided subprotocols. The priority + /// of the subprotocols is specified by the order that the iterator yields items. + pub fn new(i: I) -> Result where I: IntoIterator, - I::Item: Into>, + I::Item: Into, { - let registrants = i - .into_iter() - .map(Into::into) - .collect::>>(); + let registrants = i.into_iter().map(Into::into).collect::>(); let header_str = registrants .clone() .into_iter() .collect::>() .join(", "); let header = HeaderValue::from_str(&header_str).map_err(|_| { - crate::Error::with_cause(ErrorKind::Http, HttpError::MalformattedHeader(header_str)) + Error::with_cause(ErrorKind::Http, HttpError::MalformattedHeader(header_str)) })?; - Ok(ProtocolRegistry { - registrants, - header: Some(header), + Ok(SubprotocolRegistry { + inner: Arc::new(Inner { + registrants, + header: Some(header), + }), }) } -} -enum Bias { - Client, - Server, -} + /// Attempts to negotiate a subprotocol offered by a client. + /// + /// # Returns + /// The subprotocol that was negotiated if one was offered. Or an error if the client send a + /// malformed header. + pub fn negotiate_client( + &self, + header_map: &HeaderMap, + ) -> Result, ProtocolError> { + let SubprotocolRegistry { inner } = self; -fn negotiate<'h, I>( - registry: &ProtocolRegistry, - headers: I, - bias: Bias, -) -> Result, ProtocolError> -where - I: Iterator, -{ - for header in headers { - let value = std::str::from_utf8(header.as_bytes()).map_err(|_| ProtocolError::Encoding)?; - let protocols = value - .split(',') - .map(|s| s.trim().into()) - .collect::>(); - - let selected = match bias { - Bias::Client => { - if !registry.registrants.is_superset(&protocols) { - return Err(ProtocolError::UnknownProtocol); + for header in header_map.get_all(SEC_WEBSOCKET_PROTOCOL) { + let header_str = header.to_str().map_err(|_| ProtocolError::Encoding)?; + + for protocol in header_str.split(',') { + if let Some(supported_protocol) = inner.registrants.get(protocol.trim()) { + return Ok(Some(supported_protocol.clone())); } - protocols - .intersection(®istry.registrants) - .next() - .map(|s| s.to_string()) } - Bias::Server => registry - .registrants - .intersection(&protocols) - .next() - .map(|s| s.to_string()), - }; - - match selected { - Some(selected) => return Ok(Some(selected)), - None => continue, } + + Ok(None) } - Ok(None) -} + /// Validate a server's response for SEC_WEBSOCKET_PROTOCOL. A server may send at most one + /// sec-websocket-protocol header, and it must contain a subprotocol that was offered by the + /// client. + /// + /// # Returns + /// The subprotocol that was accepted by the server if one was offered. Or an error if the + /// server responded with a malformed header. + pub fn validate_accepted_subprotocol( + &self, + header_map: &HeaderMap, + ) -> Result, ProtocolError> { + let SubprotocolRegistry { inner } = self; -pub fn negotiate_response( - registry: &ProtocolRegistry, - header_map: &HeaderMap, -) -> Result, ProtocolError> { - let it = header_map.get_all(SEC_WEBSOCKET_PROTOCOL).into_iter(); - negotiate(registry, it, Bias::Client) -} + let protocols: Vec<_> = header_map.get_all(SEC_WEBSOCKET_PROTOCOL).iter().collect(); -pub fn negotiate_request( - registry: &ProtocolRegistry, - header_map: &HeaderMap, -) -> Result, ProtocolError> { - let it = header_map.get_all(SEC_WEBSOCKET_PROTOCOL).into_iter(); - negotiate(registry, it, Bias::Server) -} + if protocols.len() > 1 { + return Err(ProtocolError::InvalidSubprotocolHeader( + "Server returned too many subprotocols".to_string(), + )); + } + + if protocols.is_empty() { + return Ok(None); + } -pub fn apply_to(registry: &ProtocolRegistry, target: &mut HeaderMap) { - if let Some(header) = ®istry.header { - target.insert(SEC_WEBSOCKET_PROTOCOL, header.clone()); + let server_protocol = protocols[0].to_str().map_err(|_| ProtocolError::Encoding)?; + + if inner.registrants.contains(server_protocol) { + Ok(Some(server_protocol.to_string())) + } else { + Err(ProtocolError::InvalidSubprotocolHeader( + server_protocol.to_string(), + )) + } + } + + /// Applies a sec-websocket-protocol header to `target` if one has been registered. + pub fn apply_to(&self, target: &mut HeaderMap) { + let SubprotocolRegistry { inner } = self; + + if let Some(header) = &inner.header { + target.insert(SEC_WEBSOCKET_PROTOCOL, header.clone()); + } } } diff --git a/ratchet_core/src/handshake/tests.rs b/ratchet_core/src/handshake/tests.rs index 243a872..f8f5922 100644 --- a/ratchet_core/src/handshake/tests.rs +++ b/ratchet_core/src/handshake/tests.rs @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -use crate::handshake::{negotiate_request, ProtocolRegistry}; +use crate::handshake::SubprotocolRegistry; use crate::ProtocolError; use http::header::SEC_WEBSOCKET_PROTOCOL; use http::{HeaderMap, HeaderValue}; @@ -23,10 +23,10 @@ fn selects_protocol_ok() { SEC_WEBSOCKET_PROTOCOL, HeaderValue::from_static("warp, warps"), )]); - let registry = ProtocolRegistry::new(vec!["warps", "warp"]).unwrap(); + let registry = SubprotocolRegistry::new(vec!["warps", "warp"]).unwrap(); assert_eq!( - negotiate_request(®istry, &headers), + registry.negotiate_client(&headers), Ok(Some("warp".to_string())) ); } @@ -37,10 +37,10 @@ fn multiple_headers() { (SEC_WEBSOCKET_PROTOCOL, HeaderValue::from_static("warp")), (SEC_WEBSOCKET_PROTOCOL, HeaderValue::from_static("warps")), ]); - let registry = ProtocolRegistry::new(vec!["warps", "warp"]).unwrap(); + let registry = SubprotocolRegistry::new(vec!["warps", "warp"]).unwrap(); assert_eq!( - negotiate_request(®istry, &headers), + registry.negotiate_client(&headers), Ok(Some("warp".to_string())) ); } @@ -55,10 +55,10 @@ fn mixed_headers() { ), (SEC_WEBSOCKET_PROTOCOL, HeaderValue::from_static("warps4.0")), ]); - let registry = ProtocolRegistry::new(vec!["warps", "warp", "warps2.0"]).unwrap(); + let registry = SubprotocolRegistry::new(vec!["warps", "warp", "warps2.0"]).unwrap(); assert_eq!( - negotiate_request(®istry, &headers), + registry.negotiate_client(&headers), Ok(Some("warps2.0".to_string())) ); } @@ -68,10 +68,10 @@ fn malformatted() { let headers = HeaderMap::from_iter([(SEC_WEBSOCKET_PROTOCOL, unsafe { HeaderValue::from_maybe_shared_unchecked([255, 255, 255, 255]) })]); - let registry = ProtocolRegistry::new(vec!["warps", "warp", "warps2.0"]).unwrap(); + let registry = SubprotocolRegistry::new(vec!["warps", "warp", "warps2.0"]).unwrap(); assert_eq!( - negotiate_request(®istry, &headers), + registry.negotiate_client(&headers), Err(ProtocolError::Encoding) ); } @@ -80,7 +80,7 @@ fn malformatted() { fn no_match() { let headers = HeaderMap::from_iter([(SEC_WEBSOCKET_PROTOCOL, HeaderValue::from_static("a,b,c"))]); - let registry = ProtocolRegistry::new(vec!["d"]).unwrap(); + let registry = SubprotocolRegistry::new(vec!["d"]).unwrap(); - assert_eq!(negotiate_request(®istry, &headers), Ok(None)); + assert_eq!(registry.negotiate_client(&headers), Ok(None)); } diff --git a/ratchet_core/src/lib.rs b/ratchet_core/src/lib.rs index f32911f..645f3f2 100644 --- a/ratchet_core/src/lib.rs +++ b/ratchet_core/src/lib.rs @@ -55,7 +55,7 @@ pub use builder::{WebSocketClientBuilder, WebSocketServerBuilder}; pub use errors::*; pub use ext::{NoExt, NoExtDecoder, NoExtEncoder, NoExtProvider}; pub use handshake::{ - accept, accept_with, subscribe, subscribe_with, ProtocolRegistry, TryIntoRequest, + accept, accept_with, subscribe, subscribe_with, SubprotocolRegistry, TryIntoRequest, UpgradedClient, UpgradedServer, WebSocketResponse, WebSocketUpgrader, }; pub use protocol::{ @@ -70,3 +70,13 @@ pub(crate) type Request = http::Request<()>; /// A stream representing a WebSocket connection. pub trait WebSocketStream: AsyncRead + AsyncWrite + Send + Unpin + 'static {} impl WebSocketStream for S where S: AsyncRead + AsyncWrite + Send + Unpin + 'static {} + +/// Provides utilities for handling WebSocket handshakes on the server side. +/// +/// This module includes the necessary components to parse, negotiate, and respond to WebSocket +/// connection upgrade requests from clients. +/// +/// It should generally not be required unless integrating Ratchet into other libraries. +pub mod server { + pub use crate::handshake::{build_response, handshake, parse_request, UpgradeRequest}; +} diff --git a/ratchet_rs/examples/autobahn-client.rs b/ratchet_rs/examples/autobahn-client.rs index 48ce60b..7575b81 100644 --- a/ratchet_rs/examples/autobahn-client.rs +++ b/ratchet_rs/examples/autobahn-client.rs @@ -15,7 +15,7 @@ use bytes::BytesMut; use ratchet_deflate::{Deflate, DeflateExtProvider}; use ratchet_rs::UpgradedClient; -use ratchet_rs::{Error, Message, PayloadType, ProtocolRegistry, WebSocketConfig}; +use ratchet_rs::{Error, Message, PayloadType, SubprotocolRegistry, WebSocketConfig}; use tokio::io::{BufReader, BufWriter}; use tokio::net::TcpStream; @@ -32,7 +32,7 @@ async fn subscribe( BufReader::new(BufWriter::new(stream)), url, &DeflateExtProvider::default(), - ProtocolRegistry::default(), + SubprotocolRegistry::default(), ) .await } diff --git a/ratchet_rs/examples/autobahn-server.rs b/ratchet_rs/examples/autobahn-server.rs index ce53fac..1e87734 100644 --- a/ratchet_rs/examples/autobahn-server.rs +++ b/ratchet_rs/examples/autobahn-server.rs @@ -15,7 +15,7 @@ use bytes::BytesMut; use log::trace; use ratchet_rs::deflate::DeflateExtProvider; -use ratchet_rs::{Error, Message, PayloadType, ProtocolRegistry, WebSocketConfig}; +use ratchet_rs::{Error, Message, PayloadType, SubprotocolRegistry, WebSocketConfig}; use tokio::io::{BufReader, BufWriter}; use tokio::net::{TcpListener, TcpStream}; @@ -34,7 +34,7 @@ async fn run(stream: TcpStream) -> Result<(), Error> { BufReader::new(BufWriter::new(stream)), WebSocketConfig::default(), DeflateExtProvider::default(), - ProtocolRegistry::default(), + SubprotocolRegistry::default(), ) .await .unwrap() diff --git a/ratchet_rs/examples/autobahn-split-client.rs b/ratchet_rs/examples/autobahn-split-client.rs index 9da25d0..08752f2 100644 --- a/ratchet_rs/examples/autobahn-split-client.rs +++ b/ratchet_rs/examples/autobahn-split-client.rs @@ -15,7 +15,7 @@ use bytes::BytesMut; use ratchet_deflate::{Deflate, DeflateExtProvider}; use ratchet_rs::UpgradedClient; -use ratchet_rs::{Error, Message, PayloadType, ProtocolRegistry, WebSocketConfig}; +use ratchet_rs::{Error, Message, PayloadType, SubprotocolRegistry, WebSocketConfig}; use tokio::io::{BufReader, BufWriter}; use tokio::net::TcpStream; @@ -32,7 +32,7 @@ async fn subscribe( BufReader::new(BufWriter::new(stream)), url, &DeflateExtProvider::default(), - ProtocolRegistry::default(), + SubprotocolRegistry::default(), ) .await } diff --git a/ratchet_rs/examples/autobahn-split-server.rs b/ratchet_rs/examples/autobahn-split-server.rs index 5b1b8c6..3a0ea65 100644 --- a/ratchet_rs/examples/autobahn-split-server.rs +++ b/ratchet_rs/examples/autobahn-split-server.rs @@ -14,7 +14,7 @@ use bytes::BytesMut; use ratchet_rs::deflate::DeflateExtProvider; -use ratchet_rs::{Error, Message, PayloadType, ProtocolRegistry, WebSocketConfig}; +use ratchet_rs::{Error, Message, PayloadType, SubprotocolRegistry, WebSocketConfig}; use tokio::io::{BufReader, BufWriter}; use tokio::net::{TcpListener, TcpStream}; @@ -33,7 +33,7 @@ async fn run(stream: TcpStream) -> Result<(), Error> { BufReader::new(BufWriter::new(stream)), WebSocketConfig::default(), DeflateExtProvider::default(), - ProtocolRegistry::default(), + SubprotocolRegistry::default(), ) .await .unwrap() diff --git a/ratchet_rs/examples/server.rs b/ratchet_rs/examples/server.rs index f1c84dd..298cd76 100644 --- a/ratchet_rs/examples/server.rs +++ b/ratchet_rs/examples/server.rs @@ -14,7 +14,8 @@ use bytes::BytesMut; use ratchet_rs::{ - Error, Message, NoExtProvider, PayloadType, ProtocolRegistry, UpgradedServer, WebSocketConfig, + Error, Message, NoExtProvider, PayloadType, SubprotocolRegistry, UpgradedServer, + WebSocketConfig, }; use tokio::net::TcpListener; use tokio_stream::{wrappers::TcpListenerStream, StreamExt}; @@ -31,7 +32,7 @@ async fn main() -> Result<(), Error> { socket, WebSocketConfig::default(), NoExtProvider, - ProtocolRegistry::default(), + SubprotocolRegistry::default(), ) .await?; diff --git a/ratchet_rs/src/lib.rs b/ratchet_rs/src/lib.rs index f2af7fa..c1d2f9e 100644 --- a/ratchet_rs/src/lib.rs +++ b/ratchet_rs/src/lib.rs @@ -44,7 +44,13 @@ unused_import_braces )] -pub use ratchet_core::{self, *}; +pub use ratchet_core::{ + accept, accept_with, subscribe, subscribe_with, CloseCode, CloseReason, CloseState, Error, + ErrorKind, HttpError, Message, MessageType, NoExt, NoExtDecoder, NoExtEncoder, NoExtProvider, + PayloadType, ProtocolError, Receiver, ReuniteError, Role, Sender, SubprotocolRegistry, + TryIntoRequest, UpgradedClient, UpgradedServer, WebSocket, WebSocketClientBuilder, + WebSocketConfig, WebSocketResponse, WebSocketServerBuilder, WebSocketStream, WebSocketUpgrader, +}; pub use ratchet_ext::{self, *}; /// Per-message deflate. From 1788aa3b1ca7ea7422da2ff33eb419d2e357cfba Mon Sep 17 00:00:00 2001 From: SirCipher Date: Wed, 25 Sep 2024 15:45:07 +0100 Subject: [PATCH 04/13] Rolls back CI changes --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 8aec7d7..0bc381d 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -19,7 +19,7 @@ jobs: - name: Build Documentation run: cargo doc --lib --no-deps --all-features --workspace env: - RUSTDOCFLAGS="--cfg docsrs -Dwarnings" cargo doc --lib --no-deps --all-features --workspace + RUSTDOCFLAGS: --cfg docsrs -Dwarnings testmsrv: name: Test Suite Latest From e1aa61927602fd3365b834cfa23f15a80dcf1d1b Mon Sep 17 00:00:00 2001 From: SirCipher Date: Wed, 25 Sep 2024 15:56:29 +0100 Subject: [PATCH 05/13] Resolves incorrect response building --- ratchet_core/src/handshake/mod.rs | 4 +++- ratchet_core/src/handshake/server/mod.rs | 1 + 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/ratchet_core/src/handshake/mod.rs b/ratchet_core/src/handshake/mod.rs index 327168a..cae7699 100644 --- a/ratchet_core/src/handshake/mod.rs +++ b/ratchet_core/src/handshake/mod.rs @@ -261,7 +261,9 @@ impl<'l, 'h, 'buf: 'h> TryMap> for &'l httparse::Response<'h, 'buf> let code = match self.code { Some(c) => match StatusCode::from_u16(c) { Ok(status) => status, - Err(_) => return Err(HttpError::Status(Some(c))), + Err(_) => { + return Err(HttpError::Status(Some(c))); + } }, None => return Err(HttpError::Status(None)), }; diff --git a/ratchet_core/src/handshake/server/mod.rs b/ratchet_core/src/handshake/server/mod.rs index fa513e6..c5a76c0 100644 --- a/ratchet_core/src/handshake/server/mod.rs +++ b/ratchet_core/src/handshake/server/mod.rs @@ -401,6 +401,7 @@ pub fn build_response( let mut response = http::Response::builder() .version(Version::HTTP_11) + .status(StatusCode::SWITCHING_PROTOCOLS) .header( http::header::SEC_WEBSOCKET_ACCEPT, HeaderValue::try_from(sec_websocket_accept)?, From d583a04ab54541cbb7517820cc9c221e2e01bfab Mon Sep 17 00:00:00 2001 From: SirCipher Date: Wed, 25 Sep 2024 18:16:18 +0100 Subject: [PATCH 06/13] Resolves PR comments --- ratchet_core/src/handshake/client/mod.rs | 24 +++-- ratchet_core/src/handshake/mod.rs | 100 +++++++++--------- ratchet_core/src/handshake/server/encoding.rs | 17 ++- 3 files changed, 70 insertions(+), 71 deletions(-) diff --git a/ratchet_core/src/handshake/client/mod.rs b/ratchet_core/src/handshake/client/mod.rs index 92b3836..7c33cb6 100644 --- a/ratchet_core/src/handshake/client/mod.rs +++ b/ratchet_core/src/handshake/client/mod.rs @@ -22,7 +22,7 @@ use crate::handshake::client::encoding::{build_request, encode_request}; use crate::handshake::io::BufferedIo; use crate::handshake::{ validate_header, validate_header_value, ParseResult, StreamingParser, SubprotocolRegistry, - TryMap, ACCEPT_KEY, BAD_STATUS_CODE, UPGRADE_STR, WEBSOCKET_STR, + TryFromWrapper, ACCEPT_KEY, BAD_STATUS_CODE, UPGRADE_STR, WEBSOCKET_STR, }; use crate::{ NoExt, NoExtProvider, Role, TryIntoRequest, WebSocket, WebSocketConfig, WebSocketStream, @@ -311,22 +311,24 @@ fn check_partial_response(response: &Response) -> Result<(), Error> { } } -fn try_parse_response<'h, 'b, E>( - buffer: &'h [u8], - mut response: Response<'h, 'b>, +fn try_parse_response<'b, E>( + buffer: &'b [u8], + mut response: Response<'b, 'b>, expected_nonce: &Nonce, extension: E, subprotocols: &mut SubprotocolRegistry, -) -> Result, HandshakeResult>, Error> +) -> Result, HandshakeResult>, Error> where - 'h: 'b, E: ExtensionProvider, { match response.parse(buffer) { - Ok(Status::Complete(count)) => { - parse_response(response.try_map()?, expected_nonce, extension, subprotocols) - .map(|r| ParseResult::Complete(r, count)) - } + Ok(Status::Complete(count)) => parse_response( + TryFromWrapper(response).try_into()?, + expected_nonce, + extension, + subprotocols, + ) + .map(|r| ParseResult::Complete(r, count)), Ok(Status::Partial) => Ok(ParseResult::Partial(response)), Err(e) => Err(e.into()), } @@ -343,7 +345,7 @@ where { if response.version() < Version::HTTP_11 { // rfc6455 § 4.2.1.1: must be HTTP/1.1 or higher - // this will implicitly be 0 as httparse only parses HTTP/1.x and 1.0 is 0. + // this will always be 0 as httparse only parses HTTP/1.x and 1.0 is 0. return Err(Error::with_cause( ErrorKind::Http, HttpError::HttpVersion(Some(0)), diff --git a/ratchet_core/src/handshake/mod.rs b/ratchet_core/src/handshake/mod.rs index cae7699..43d023e 100644 --- a/ratchet_core/src/handshake/mod.rs +++ b/ratchet_core/src/handshake/mod.rs @@ -27,6 +27,7 @@ use crate::{InvalidHeader, Request}; use http::header::HeaderName; use http::{HeaderMap, HeaderValue, Method, Version}; use http::{Response, StatusCode, Uri}; +use httparse::Header; use std::str::FromStr; use tokio::io::AsyncRead; use tokio_util::codec::Decoder; @@ -184,21 +185,16 @@ fn validate_header_any(headers: &HeaderMap, name: HeaderName, expected: &str) -> }) } -/// Local replacement for TryInto that can be implemented for httparse::Header and httparse::Request -pub trait TryMap { - /// Error type returned if the mapping fails - type Error: Into; +struct TryFromWrapper(pub T); - /// Try and map this into `Target` - fn try_map(self) -> Result; -} +impl<'h> TryFrom]>> for HeaderMap { + type Error = HttpError; -impl<'h> TryMap for &'h [httparse::Header<'h>] { - type Error = InvalidHeader; + fn try_from(value: TryFromWrapper<&'h mut [Header<'h>]>) -> Result { + let parsed_headers = value.0; - fn try_map(self) -> Result { - let mut header_map = HeaderMap::with_capacity(self.len()); - for header in self { + let mut header_map = HeaderMap::with_capacity(parsed_headers.len()); + for header in parsed_headers { let header_string = || { let value = String::from_utf8_lossy(header.value); format!("{}: {}", header.name, value) @@ -215,12 +211,47 @@ impl<'h> TryMap for &'h [httparse::Header<'h>] { } } -impl<'l, 'h, 'buf: 'h> TryMap for &'l httparse::Request<'h, 'buf> { +impl<'b> TryFrom>> for Response<()> { + type Error = HttpError; + + fn try_from(value: TryFromWrapper>) -> Result { + let parsed_response = value.0; + + let mut response = Response::new(()); + let code = match parsed_response.code { + Some(c) => match StatusCode::from_u16(c) { + Ok(status) => status, + Err(_) => { + return Err(HttpError::Status(Some(c))); + } + }, + None => return Err(HttpError::Status(None)), + }; + let version = match parsed_response.version { + Some(v) => match v { + 0 => Version::HTTP_10, + 1 => Version::HTTP_11, + n => return Err(HttpError::HttpVersion(Some(n))), + }, + None => return Err(HttpError::HttpVersion(None)), + }; + + *response.headers_mut() = HeaderMap::try_from(TryFromWrapper(parsed_response.headers))?; + *response.status_mut() = code; + *response.version_mut() = version; + + Ok(response) + } +} + +impl<'b> TryFrom>> for Request { type Error = HttpError; - fn try_map(self) -> Result { + fn try_from(value: TryFromWrapper>) -> Result { + let parsed_request = value.0; + let mut request = Request::new(()); - let path = match self.path { + let path = match parsed_request.path { Some(path) => path.parse()?, None => { return Err(HttpError::MalformattedUri(Some( @@ -228,13 +259,13 @@ impl<'l, 'h, 'buf: 'h> TryMap for &'l httparse::Request<'h, 'buf> { ))) } }; - let method = match self.method { + let method = match parsed_request.method { Some(m) => { Method::from_str(m).map_err(|_| HttpError::HttpMethod(Some(m.to_string())))? } None => return Err(HttpError::HttpMethod(None)), }; - let version = match self.version { + let version = match parsed_request.version { Some(v) => match v { 0 => Version::HTTP_10, 1 => Version::HTTP_11, @@ -242,9 +273,8 @@ impl<'l, 'h, 'buf: 'h> TryMap for &'l httparse::Request<'h, 'buf> { }, None => return Err(HttpError::HttpVersion(None)), }; - let headers = &self.headers; - *request.headers_mut() = headers.try_map()?; + *request.headers_mut() = HeaderMap::try_from(TryFromWrapper(parsed_request.headers))?; *request.uri_mut() = path; *request.version_mut() = version; *request.method_mut() = method; @@ -252,35 +282,3 @@ impl<'l, 'h, 'buf: 'h> TryMap for &'l httparse::Request<'h, 'buf> { Ok(request) } } - -impl<'l, 'h, 'buf: 'h> TryMap> for &'l httparse::Response<'h, 'buf> { - type Error = HttpError; - - fn try_map(self) -> Result, Self::Error> { - let mut response = Response::new(()); - let code = match self.code { - Some(c) => match StatusCode::from_u16(c) { - Ok(status) => status, - Err(_) => { - return Err(HttpError::Status(Some(c))); - } - }, - None => return Err(HttpError::Status(None)), - }; - let version = match self.version { - Some(v) => match v { - 0 => Version::HTTP_10, - 1 => Version::HTTP_11, - n => return Err(HttpError::HttpVersion(Some(n))), - }, - None => return Err(HttpError::HttpVersion(None)), - }; - let headers = &self.headers; - - *response.headers_mut() = headers.try_map()?; - *response.status_mut() = code; - *response.version_mut() = version; - - Ok(response) - } -} diff --git a/ratchet_core/src/handshake/server/encoding.rs b/ratchet_core/src/handshake/server/encoding.rs index 774f01d..4f7168d 100644 --- a/ratchet_core/src/handshake/server/encoding.rs +++ b/ratchet_core/src/handshake/server/encoding.rs @@ -14,15 +14,14 @@ use crate::handshake::io::BufferedIo; use crate::handshake::server::UpgradeRequest; -use crate::handshake::TryMap; use crate::handshake::{ - validate_header, validate_header_any, validate_header_value, ParseResult, METHOD_GET, - UPGRADE_STR, WEBSOCKET_STR, WEBSOCKET_VERSION_STR, + validate_header, validate_header_any, validate_header_value, ParseResult, TryFromWrapper, + METHOD_GET, UPGRADE_STR, WEBSOCKET_STR, WEBSOCKET_VERSION_STR, }; use crate::{Error, ErrorKind, HttpError, SubprotocolRegistry}; use bytes::{BufMut, Bytes, BytesMut}; use http::header::SEC_WEBSOCKET_KEY; -use http::{HeaderMap, Method, StatusCode, Version}; +use http::{HeaderMap, Method, Request, StatusCode, Version}; use httparse::Status; use ratchet_ext::ExtensionProvider; use tokio::io::AsyncWrite; @@ -125,18 +124,18 @@ where buffered.write().await } -pub fn try_parse_request<'h, 'b, E>( +pub fn try_parse_request<'b, E>( buffer: &'b [u8], - mut request: httparse::Request<'h, 'b>, + mut request: httparse::Request<'b, 'b>, extension: E, subprotocols: &mut SubprotocolRegistry, -) -> Result, UpgradeRequest>, Error> +) -> Result, UpgradeRequest>, Error> where E: ExtensionProvider, { match request.parse(buffer) { Ok(Status::Complete(count)) => { - let request = request.try_map()?; + let request = Request::try_from(TryFromWrapper(request))?; parse_request(request, extension, subprotocols).map(|r| ParseResult::Complete(r, count)) } Ok(Status::Partial) => Ok(ParseResult::Partial(request)), @@ -205,7 +204,7 @@ where E: ExtensionProvider, { if request.version() < HTTP_VERSION { - // this will implicitly be 0 as httparse only parses HTTP/1.x and 1.0 is 0. + // this will always be 0 as httparse only parses HTTP/1.x and 1.0 is 0. return Err(Error::with_cause( ErrorKind::Http, HttpError::HttpVersion(Some(0)), From c426fc72ecb2605a5ee5499bb78d7365de6ba21e Mon Sep 17 00:00:00 2001 From: SirCipher Date: Thu, 26 Sep 2024 10:43:45 +0100 Subject: [PATCH 07/13] Resolves PR comments --- ratchet_core/src/errors.rs | 5 ++++- ratchet_core/src/handshake/client/mod.rs | 17 ++++++++++------- ratchet_core/src/handshake/client/tests.rs | 2 +- ratchet_core/src/handshake/mod.rs | 4 ++-- 4 files changed, 17 insertions(+), 11 deletions(-) diff --git a/ratchet_core/src/errors.rs b/ratchet_core/src/errors.rs index 8ed358a..f15fae9 100644 --- a/ratchet_core/src/errors.rs +++ b/ratchet_core/src/errors.rs @@ -156,7 +156,10 @@ pub enum HttpError { Redirected(String), /// The peer returned with a status code other than 101. #[error("Status code: `{0:?}`")] - Status(Option), + Status(u16), + /// A request was missing its status code + #[error("Missing status code")] + MissingStatus, /// An invalid HTTP version was received in a request. #[error("Invalid HTTP version: `{0:?}`")] HttpVersion(Option), diff --git a/ratchet_core/src/handshake/client/mod.rs b/ratchet_core/src/handshake/client/mod.rs index 7c33cb6..c32c35c 100644 --- a/ratchet_core/src/handshake/client/mod.rs +++ b/ratchet_core/src/handshake/client/mod.rs @@ -33,7 +33,7 @@ use bytes::BytesMut; use http::header::LOCATION; use http::{header, Request, StatusCode, Version}; use httparse::{Response, Status}; -use log::{error, trace}; +use log::{error, trace, warn}; use ratchet_ext::ExtensionProvider; use sha1::{Digest, Sha1}; use std::convert::TryFrom; @@ -303,7 +303,7 @@ fn check_partial_response(response: &Response) -> Result<(), Error> { Some(code) => match StatusCode::try_from(code) { Ok(code) => Err(Error::with_cause( ErrorKind::Http, - HttpError::Status(Some(code.as_u16())), + HttpError::Status(code.as_u16()), )), Err(_) => Err(Error::with_cause(ErrorKind::Http, BAD_STATUS_CODE)), }, @@ -366,16 +366,19 @@ where HttpError::Redirected(location), )) } - None => Err(Error::with_cause( - ErrorKind::Http, - HttpError::Status(Some(c.as_u16())), - )), + None => { + warn!("Received a redirection status code with no location header"); + Err(Error::with_cause( + ErrorKind::Http, + HttpError::Status(c.as_u16()), + )) + } }; } status_code => { return Err(Error::with_cause( ErrorKind::Http, - HttpError::Status(Some(status_code.as_u16())), + HttpError::Status(status_code.as_u16()), )) } } diff --git a/ratchet_core/src/handshake/client/tests.rs b/ratchet_core/src/handshake/client/tests.rs index 8208540..7ddcdb6 100644 --- a/ratchet_core/src/handshake/client/tests.rs +++ b/ratchet_core/src/handshake/client/tests.rs @@ -235,7 +235,7 @@ async fn bad_status_code() { expect_server_error( response, - HttpError::Status(Some(StatusCode::IM_A_TEAPOT.as_u16())), + HttpError::Status(StatusCode::IM_A_TEAPOT.as_u16()), ) .await; } diff --git a/ratchet_core/src/handshake/mod.rs b/ratchet_core/src/handshake/mod.rs index 43d023e..42304da 100644 --- a/ratchet_core/src/handshake/mod.rs +++ b/ratchet_core/src/handshake/mod.rs @@ -222,10 +222,10 @@ impl<'b> TryFrom>> for Response<()> { Some(c) => match StatusCode::from_u16(c) { Ok(status) => status, Err(_) => { - return Err(HttpError::Status(Some(c))); + return Err(HttpError::Status(c)); } }, - None => return Err(HttpError::Status(None)), + None => return Err(HttpError::MissingStatus), }; let version = match parsed_response.version { Some(v) => match v { From 2e626758aa09b1fc3001a67b0b7fd9b4af13b467 Mon Sep 17 00:00:00 2001 From: SirCipher Date: Thu, 26 Sep 2024 10:45:14 +0100 Subject: [PATCH 08/13] Fixes incorrect HTTP Error documentation --- ratchet_core/src/errors.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ratchet_core/src/errors.rs b/ratchet_core/src/errors.rs index f15fae9..1c71200 100644 --- a/ratchet_core/src/errors.rs +++ b/ratchet_core/src/errors.rs @@ -157,10 +157,10 @@ pub enum HttpError { /// The peer returned with a status code other than 101. #[error("Status code: `{0:?}`")] Status(u16), - /// A request was missing its status code + /// A request or response was missing its status code. #[error("Missing status code")] MissingStatus, - /// An invalid HTTP version was received in a request. + /// An invalid HTTP version was received in a request or response. #[error("Invalid HTTP version: `{0:?}`")] HttpVersion(Option), /// A request or response was missing an expected header. From 63c1ce44048d328d7d5560b7e4973901a441ea1e Mon Sep 17 00:00:00 2001 From: SirCipher Date: Thu, 26 Sep 2024 11:24:34 +0100 Subject: [PATCH 09/13] Updates to new API --- ratchet_core/src/handshake/client/mod.rs | 20 +++-------- ratchet_core/src/handshake/client/tests.rs | 2 +- ratchet_core/src/handshake/mod.rs | 36 +++++++++---------- ratchet_core/src/handshake/server/encoding.rs | 16 ++++----- 4 files changed, 30 insertions(+), 44 deletions(-) diff --git a/ratchet_core/src/handshake/client/mod.rs b/ratchet_core/src/handshake/client/mod.rs index 239cb9f..741c77b 100644 --- a/ratchet_core/src/handshake/client/mod.rs +++ b/ratchet_core/src/handshake/client/mod.rs @@ -36,16 +36,9 @@ use crate::handshake::{ use crate::{ NoExt, NoExtProvider, Role, TryIntoRequest, WebSocket, WebSocketConfig, WebSocketStream, }; -use base64::engine::general_purpose::STANDARD; -use base64::Engine; -use bytes::BytesMut; use http::header::LOCATION; -use http::{header, Request, StatusCode, Version}; -use httparse::{Response, Status}; -use log::{error, trace, warn}; +use log::warn; use ratchet_ext::ExtensionProvider; -use sha1::{Digest, Sha1}; -use std::convert::TryFrom; use tokio_util::codec::Decoder; type Nonce = [u8; 24]; @@ -356,13 +349,10 @@ where { if response.version() < Version::HTTP_11 { // rfc6455 § 4.2.1.1: must be HTTP/1.1 or higher - Some(1) => {} - _ => { - return Err(Error::with_cause( - ErrorKind::Http, - HttpError::HttpVersion(format!("{:?}", Version::HTTP_10)), - )) - } + return Err(Error::with_cause( + ErrorKind::Http, + HttpError::HttpVersion(format!("{:?}", Version::HTTP_10)), + )); } let status_code = response.status(); diff --git a/ratchet_core/src/handshake/client/tests.rs b/ratchet_core/src/handshake/client/tests.rs index e9c1e24..5398fcd 100644 --- a/ratchet_core/src/handshake/client/tests.rs +++ b/ratchet_core/src/handshake/client/tests.rs @@ -708,7 +708,7 @@ async fn negotiates_no_extension() { #[test] fn fails_to_build_request() { fn test(request: Request<()>, expected_error: E) { - match build_request(request, &NoExtProvider, &ProtocolRegistry::default()) { + match build_request(request, &NoExtProvider, &SubprotocolRegistry::default()) { Ok(r) => { panic!("Expected a test failure of {}. Got {:?}", expected_error, r); } diff --git a/ratchet_core/src/handshake/mod.rs b/ratchet_core/src/handshake/mod.rs index e445719..ddcd272 100644 --- a/ratchet_core/src/handshake/mod.rs +++ b/ratchet_core/src/handshake/mod.rs @@ -25,8 +25,9 @@ use crate::errors::{ErrorKind, HttpError}; use crate::handshake::io::BufferedIo; use crate::{InvalidHeader, Request}; use http::header::HeaderName; -use http::{StatusCode, Uri}; -use http::{HeaderMap, HeaderValue}; +use http::{HeaderMap, HeaderValue, Method, Response}; +use http::{StatusCode, Uri, Version}; +use httparse::Header; use log::{error, trace, warn}; use std::str::FromStr; use tokio::io::AsyncRead; @@ -241,18 +242,10 @@ impl<'b> TryFrom>> for Response<()> { }, None => return Err(HttpError::MissingStatus), }; - let version = match parsed_response.version { - Some(v) => match v { - 0 => Version::HTTP_10, - 1 => Version::HTTP_11, - n => return Err(HttpError::HttpVersion(Some(n))), - }, - None => return Err(HttpError::HttpVersion(None)), - }; *response.headers_mut() = HeaderMap::try_from(TryFromWrapper(parsed_response.headers))?; *response.status_mut() = code; - *response.version_mut() = version; + *response.version_mut() = parse_version(parsed_response.version)?; Ok(response) } @@ -279,20 +272,23 @@ impl<'b> TryFrom>> for Request { } None => return Err(HttpError::HttpMethod(None)), }; - let version = match parsed_request.version { - Some(v) => match v { - 0 => Version::HTTP_10, - 1 => Version::HTTP_11, - n => return Err(HttpError::HttpVersion(Some(n))), - }, - None => return Err(HttpError::HttpVersion(None)), - }; *request.headers_mut() = HeaderMap::try_from(TryFromWrapper(parsed_request.headers))?; *request.uri_mut() = path; - *request.version_mut() = version; + *request.version_mut() = parse_version(parsed_request.version)?; *request.method_mut() = method; Ok(request) } } + +fn parse_version(version: Option) -> Result { + match version { + Some(v) => match v { + 0 => Ok(Version::HTTP_10), + 1 => Ok(Version::HTTP_11), + n => Err(HttpError::HttpVersion(n.to_string())), + }, + None => Err(HttpError::HttpVersion("Missing HTTP version".to_string())), + } +} diff --git a/ratchet_core/src/handshake/server/encoding.rs b/ratchet_core/src/handshake/server/encoding.rs index 894e17a..07b65fc 100644 --- a/ratchet_core/src/handshake/server/encoding.rs +++ b/ratchet_core/src/handshake/server/encoding.rs @@ -15,14 +15,14 @@ use crate::handshake::io::BufferedIo; use crate::handshake::server::UpgradeRequest; use crate::handshake::{ - validate_header, validate_header_any, validate_header_value, ParseResult, TryFromWrapper, - METHOD_GET, UPGRADE_STR, WEBSOCKET_STR, WEBSOCKET_VERSION_STR, + validate_header_any, validate_header_value, ParseResult, TryFromWrapper, METHOD_GET, + UPGRADE_STR, WEBSOCKET_STR, WEBSOCKET_VERSION_STR, }; use crate::{Error, ErrorKind, HttpError, SubprotocolRegistry}; use bytes::{BufMut, Bytes, BytesMut}; use http::header::{HOST, SEC_WEBSOCKET_KEY}; use http::{HeaderMap, Method, Request, StatusCode, Version}; -use httparse::{Header, Status}; +use httparse::Status; use log::error; use ratchet_ext::ExtensionProvider; use tokio::io::AsyncWrite; @@ -208,7 +208,7 @@ where return Err(Error::with_cause( ErrorKind::Http, HttpError::HttpVersion(format!("{:?}", Version::HTTP_10)), - )) + )); } if request.method() != Method::GET { @@ -254,12 +254,12 @@ where } /// Validates that 'headers' contains one 'host' header and that it is not a seperated list. -fn validate_host_header(headers: &[Header]) -> Result<(), Error> { +fn validate_host_header(headers: &HeaderMap) -> Result<(), Error> { let len = headers .iter() - .filter_map(|header| { - if header.name.eq_ignore_ascii_case(HOST.as_str()) { - Some(header.value.split(|c| c == &b' ' || c == &b',')) + .filter_map(|(name, value)| { + if name.as_str().eq_ignore_ascii_case(HOST.as_str()) { + Some(value.as_bytes().split(|c| c == &b' ' || c == &b',')) } else { None } From 41e8cd2f497fb4f655b9cd9076fe03cb6f2146ed Mon Sep 17 00:00:00 2001 From: SirCipher Date: Thu, 26 Sep 2024 17:30:48 +0100 Subject: [PATCH 10/13] Adds more server upgrade functions --- ratchet_core/src/handshake/mod.rs | 5 +- ratchet_core/src/handshake/server/encoding.rs | 47 +++- ratchet_core/src/handshake/server/mod.rs | 203 +++++++++++++++--- ratchet_core/src/lib.rs | 5 +- 4 files changed, 215 insertions(+), 45 deletions(-) diff --git a/ratchet_core/src/handshake/mod.rs b/ratchet_core/src/handshake/mod.rs index ddcd272..88e926c 100644 --- a/ratchet_core/src/handshake/mod.rs +++ b/ratchet_core/src/handshake/mod.rs @@ -36,8 +36,9 @@ use url::Url; pub use client::{subscribe, subscribe_with, UpgradedClient}; pub use server::{ - accept, accept_with, build_response, handshake, parse_request, UpgradeRequest, UpgradedServer, - WebSocketResponse, WebSocketUpgrader, + accept, accept_with, build_response, build_response_headers, handshake, handshake_from_parts, + parse_request, UpgradeRequest, UpgradeRequestParts, UpgradedServer, WebSocketResponse, + WebSocketUpgrader, }; pub use subprotocols::*; diff --git a/ratchet_core/src/handshake/server/encoding.rs b/ratchet_core/src/handshake/server/encoding.rs index 07b65fc..8f22735 100644 --- a/ratchet_core/src/handshake/server/encoding.rs +++ b/ratchet_core/src/handshake/server/encoding.rs @@ -13,7 +13,7 @@ // limitations under the License. use crate::handshake::io::BufferedIo; -use crate::handshake::server::UpgradeRequest; +use crate::handshake::server::{UpgradeRequest, UpgradeRequestParts}; use crate::handshake::{ validate_header_any, validate_header_value, ParseResult, TryFromWrapper, METHOD_GET, UPGRADE_STR, WEBSOCKET_STR, WEBSOCKET_VERSION_STR, @@ -21,6 +21,7 @@ use crate::handshake::{ use crate::{Error, ErrorKind, HttpError, SubprotocolRegistry}; use bytes::{BufMut, Bytes, BytesMut}; use http::header::{HOST, SEC_WEBSOCKET_KEY}; +use http::request::Parts; use http::{HeaderMap, Method, Request, StatusCode, Version}; use httparse::Status; use log::error; @@ -137,7 +138,31 @@ where match request.parse(buffer) { Ok(Status::Complete(count)) => { let request = Request::try_from(TryFromWrapper(request))?; - parse_request(request, extension, subprotocols).map(|r| ParseResult::Complete(r, count)) + let (parts, body) = request.into_parts(); + let Parts { + method, + version, + headers, + .. + } = &parts; + + let UpgradeRequestParts { + key, + subprotocol, + extension, + extension_header, + } = parse_request(*version, method, headers, extension, subprotocols)?; + + Ok(ParseResult::Complete( + UpgradeRequest { + key, + subprotocol, + extension, + request: Request::from_parts(parts, body), + extension_header, + }, + count, + )) } Ok(Status::Partial) => Ok(ParseResult::Partial(request)), Err(e) => Err(e.into()), @@ -196,29 +221,30 @@ pub fn check_partial_request(request: &httparse::Request) -> Result<(), Error> { /// - `Err(Error)`: Contains an error if the request is invalid or cannot be parsed. /// This could include issues such as unsupported HTTP versions, invalid methods, /// missing required headers, or failed negotiations for subprotocols or extensions. -pub fn parse_request( - request: http::Request, +pub fn parse_request( + version: Version, + method: &Method, + headers: &HeaderMap, extension: E, subprotocols: &SubprotocolRegistry, -) -> Result, Error> +) -> Result, Error> where E: ExtensionProvider, { - if request.version() < HTTP_VERSION { + if version < HTTP_VERSION { return Err(Error::with_cause( ErrorKind::Http, HttpError::HttpVersion(format!("{:?}", Version::HTTP_10)), )); } - if request.method() != Method::GET { + if method != Method::GET { return Err(Error::with_cause( ErrorKind::Http, - HttpError::HttpMethod(Some(request.method().to_string())), + HttpError::HttpMethod(Some(method.to_string())), )); } - let headers = request.headers(); validate_header_any(headers, http::header::CONNECTION, UPGRADE_STR)?; validate_header_value(headers, http::header::UPGRADE, WEBSOCKET_STR)?; validate_header_value( @@ -244,12 +270,11 @@ where .map(Option::unzip) .map_err(|e| Error::with_cause(ErrorKind::Extension, e))?; - Ok(UpgradeRequest { + Ok(UpgradeRequestParts { key, extension, subprotocol, extension_header, - request, }) } diff --git a/ratchet_core/src/handshake/server/mod.rs b/ratchet_core/src/handshake/server/mod.rs index c5a76c0..6fcbd05 100644 --- a/ratchet_core/src/handshake/server/mod.rs +++ b/ratchet_core/src/handshake/server/mod.rs @@ -30,8 +30,9 @@ use crate::{ }; use base64::engine::{general_purpose::STANDARD, Engine}; use bytes::{Bytes, BytesMut}; +use http::request::Parts; use http::status::InvalidStatusCode; -use http::{HeaderMap, HeaderValue, StatusCode, Uri, Version}; +use http::{HeaderMap, HeaderValue, Method, Response, StatusCode, Uri, Version}; use log::{error, trace}; use ratchet_ext::{Extension, ExtensionProvider}; use sha1::{Digest, Sha1}; @@ -319,6 +320,38 @@ where } } +/// Represents a parsed WebSocket connection upgrade HTTP request without the context of the +/// request that it is responding to. +#[derive(Debug)] +#[non_exhaustive] +pub struct UpgradeRequestParts { + /// The security key provided by the client during the WebSocket handshake. + /// + /// This key is used by the server to generate a response key, confirming that the server + /// accepts the WebSocket upgrade request. + pub key: Bytes, + + /// The optional WebSocket subprotocol agreed upon during the handshake. + /// + /// The subprotocol is used to define the application-specific communication on top of the + /// WebSocket connection, such as `wamp` or `graphql-ws`. If no subprotocol is requested or + /// agreed upon, this will be `None`. + pub subprotocol: Option, + + /// The optional WebSocket extension negotiated during the handshake. + /// + /// Extensions allow WebSocket connections to have additional functionality, such as compression + /// or multiplexing. This field represents any such negotiated extension, or `None` if no + /// extensions were negotiated. + pub extension: Option, + + /// The optional `Sec-WebSocket-Extensions` header value from the HTTP request. + /// + /// This header may contain the raw extension details sent by the client during the handshake. + /// If no extension was requested, this field will be `None`. + pub extension_header: Option, +} + /// Represents a parsed WebSocket connection upgrade HTTP request. #[derive(Debug)] #[non_exhaustive] @@ -393,36 +426,12 @@ pub fn build_response( subprotocol: Option, extension_header: Option, ) -> Result, Error> { - let mut digest = Sha1::new(); - Digest::update(&mut digest, key); - Digest::update(&mut digest, ACCEPT_KEY); - - let sec_websocket_accept = STANDARD.encode(digest.finalize()); - let mut response = http::Response::builder() .version(Version::HTTP_11) - .status(StatusCode::SWITCHING_PROTOCOLS) - .header( - http::header::SEC_WEBSOCKET_ACCEPT, - HeaderValue::try_from(sec_websocket_accept)?, - ) - .header( - http::header::UPGRADE, - HeaderValue::from_static(WEBSOCKET_STR), - ) - .header( - http::header::CONNECTION, - HeaderValue::from_static(UPGRADE_STR), - ); + .status(StatusCode::SWITCHING_PROTOCOLS); - if let Some(subprotocol) = &subprotocol { - response = response.header( - http::header::SEC_WEBSOCKET_PROTOCOL, - HeaderValue::try_from(subprotocol)?, - ); - } - if let Some(extension_header) = extension_header { - response = response.header(http::header::SEC_WEBSOCKET_EXTENSIONS, extension_header); + if let Some(headers) = response.headers_mut() { + *headers = build_response_headers(key, subprotocol, extension_header)?; } Ok(response.body(())?) @@ -474,19 +483,151 @@ pub fn handshake( request: http::Request, extension: &E, subprotocols: &SubprotocolRegistry, -) -> Result<(http::Response<()>, Option), Error> +) -> Result<(Response<()>, Option), Error> where E: ExtensionProvider, { - let UpgradeRequest { + let (parts, _body) = request.into_parts(); + let Parts { + method, + version, + headers, + .. + } = parts; + let UpgradeRequestParts { key, subprotocol, extension, extension_header, .. - } = parse_request(request, extension, subprotocols)?; + } = parse_request(version, &method, &headers, extension, subprotocols)?; Ok(( build_response(key, subprotocol, extension_header)?, extension, )) } + +/// Processes a WebSocket handshake request from its parts and generates the appropriate response. +/// +/// This function handles the server-side part of a WebSocket handshake. It parses the incoming HTTP +/// request parts that seeks to upgrade the connection to WebSocket, negotiates extensions and +/// subprotocols, and constructs an appropriate HTTP response to complete the WebSocket handshake. +/// +/// # Arguments +/// +/// - `version`: The HTTP `Version` of the request. +/// - `method`: The HTTP `Method` of the request, which must be `GET` for WebSocket handshakes. +/// - `headers`: A reference to the request's `HeaderMap` containing the HTTP headers. These headers +/// must include the necessary WebSocket headers such as `Sec-WebSocket-Key` and `Upgrade`. +/// - `extension`: An extension that may be negotiated for the connection. +/// - `subprotocols`: A `SubprotocolRegistry`, which will be used to attempt to negotiate a +/// subprotocol. +/// +/// # Returns +/// +/// This function returns a `Result` containing: +/// - A tuple consisting of: +/// - An `http::Response<()>`, which represents the WebSocket handshake response. +/// The response includes headers such as `Sec-WebSocket-Accept` to confirm the upgrade. +/// - An optional `E::Extension`, which represents the negotiated extension, if any. +/// +/// If the handshake fails, an `Error` is returned, which may be caused by invalid +/// requests, issues parsing headers, or problems negotiating the WebSocket subprotocols +/// or extensions. +/// +/// # Type Parameters +/// +/// - `E`: The type of the extension provider, which must implement the `ExtensionProvider` +/// trait. This defines how WebSocket extensions (like compression) are handled. +/// - `B`: The body type of the HTTP request. While it is discouraged for GET requests to have a body +/// it is not technically incorrect and the use of this function is lowering the guardrails to +/// allow for Ratchet to be more easily integrated into other libraries. It is the implementors +/// responsibility to perform any validation on the body. +/// +/// # Errors +/// +/// The function returns an `Error` in cases such as: +/// - Failure to parse the WebSocket upgrade request. +/// - Issues building the response, such as invalid subprotocol or extension headers. +/// - Failure to negotiate the WebSocket extensions or subprotocols. +pub fn handshake_from_parts( + version: Version, + method: &Method, + headers: &HeaderMap, + extension: &E, + subprotocols: &SubprotocolRegistry, +) -> Result<(Response<()>, Option), Error> +where + E: ExtensionProvider, +{ + let UpgradeRequestParts { + key, + subprotocol, + extension, + extension_header, + .. + } = parse_request(version, method, headers, extension, subprotocols)?; + + let mut response = http::Response::builder() + .version(Version::HTTP_11) + .status(StatusCode::SWITCHING_PROTOCOLS) + .body(())?; + *response.headers_mut() = build_response_headers(key, subprotocol, extension_header)?; + + Ok((response, extension)) +} + +/// Constructs the HTTP response headers for a WebSocket handshake response. +/// +/// This function builds the necessary headers for completing a WebSocket handshake, including +/// the `Sec-WebSocket-Accept` header, which is derived by hashing the client's WebSocket key +/// and appending the WebSocket GUID. Additionally, it adds optional headers for subprotocols and +/// extensions if they were negotiated during the handshake. +/// +/// # Arguments +/// - `key`: The WebSocket key provided by the client as part of the handshake request (`Sec-WebSocket-Key`). +/// - `subprotocol`: An optional `String` that represents the WebSocket subprotocol negotiated between +/// the client and server. If provided, it will be included in the `Sec-WebSocket-Protocol` header. +/// - `extension_header`: An optional `HeaderValue` representing any WebSocket extensions negotiated +/// during the handshake. If provided, it will be added to the `Sec-WebSocket-Extensions` header. +/// +/// # Returns +/// - `Result`: A result that contains either a `HeaderMap` with the constructed +/// headers or an `Error` if an issue occurs while creating the headers. +pub fn build_response_headers( + key: Bytes, + subprotocol: Option, + extension_header: Option, +) -> Result { + let mut digest = Sha1::new(); + Digest::update(&mut digest, key); + Digest::update(&mut digest, ACCEPT_KEY); + + let sec_websocket_accept = STANDARD.encode(digest.finalize()); + + let mut map = HeaderMap::default(); + + map.insert( + http::header::SEC_WEBSOCKET_ACCEPT, + HeaderValue::try_from(sec_websocket_accept)?, + ); + map.insert( + http::header::UPGRADE, + HeaderValue::from_static(WEBSOCKET_STR), + ); + map.insert( + http::header::CONNECTION, + HeaderValue::from_static(UPGRADE_STR), + ); + if let Some(subprotocol) = subprotocol { + map.insert( + http::header::SEC_WEBSOCKET_PROTOCOL, + HeaderValue::try_from(subprotocol)?, + ); + } + if let Some(extension_header) = extension_header { + map.insert(http::header::SEC_WEBSOCKET_EXTENSIONS, extension_header); + } + + Ok(map) +} diff --git a/ratchet_core/src/lib.rs b/ratchet_core/src/lib.rs index 645f3f2..19320e0 100644 --- a/ratchet_core/src/lib.rs +++ b/ratchet_core/src/lib.rs @@ -78,5 +78,8 @@ impl WebSocketStream for S where S: AsyncRead + AsyncWrite + Send + Unpin + ' /// /// It should generally not be required unless integrating Ratchet into other libraries. pub mod server { - pub use crate::handshake::{build_response, handshake, parse_request, UpgradeRequest}; + pub use crate::handshake::{ + build_response, build_response_headers, handshake, handshake_from_parts, parse_request, + UpgradeRequest, UpgradeRequestParts, + }; } From e8f98b3a7b3991b94de064ec26ea7784540186b0 Mon Sep 17 00:00:00 2001 From: SirCipher Date: Thu, 26 Sep 2024 17:53:53 +0100 Subject: [PATCH 11/13] Exposes simple function to build a response from request headers --- ratchet_core/src/handshake/mod.rs | 6 +-- ratchet_core/src/handshake/server/encoding.rs | 37 ++++++++++++------- ratchet_core/src/handshake/server/mod.rs | 33 +++++++---------- ratchet_core/src/lib.rs | 4 +- 4 files changed, 41 insertions(+), 39 deletions(-) diff --git a/ratchet_core/src/handshake/mod.rs b/ratchet_core/src/handshake/mod.rs index 88e926c..6b4db48 100644 --- a/ratchet_core/src/handshake/mod.rs +++ b/ratchet_core/src/handshake/mod.rs @@ -36,9 +36,9 @@ use url::Url; pub use client::{subscribe, subscribe_with, UpgradedClient}; pub use server::{ - accept, accept_with, build_response, build_response_headers, handshake, handshake_from_parts, - parse_request, UpgradeRequest, UpgradeRequestParts, UpgradedServer, WebSocketResponse, - WebSocketUpgrader, + accept, accept_with, build_response, build_response_headers, handshake, parse_request, + response_from_headers, validate_method_and_version, UpgradeRequest, UpgradeRequestParts, + UpgradedServer, WebSocketResponse, WebSocketUpgrader, }; pub use subprotocols::*; diff --git a/ratchet_core/src/handshake/server/encoding.rs b/ratchet_core/src/handshake/server/encoding.rs index 8f22735..477482f 100644 --- a/ratchet_core/src/handshake/server/encoding.rs +++ b/ratchet_core/src/handshake/server/encoding.rs @@ -194,6 +194,28 @@ pub fn check_partial_request(request: &httparse::Request) -> Result<(), Error> { Ok(()) } +/// Validates that `version` and `method` are correct for a WebSocket upgrade. +/// +/// # Returns +/// `Ok(())` if they are correct or `Err(e)` if they are not. +pub fn validate_method_and_version(version: Version, method: &Method) -> Result<(), Error> { + if version < HTTP_VERSION { + return Err(Error::with_cause( + ErrorKind::Http, + HttpError::HttpVersion(format!("{:?}", Version::HTTP_10)), + )); + } + + if method != Method::GET { + return Err(Error::with_cause( + ErrorKind::Http, + HttpError::HttpMethod(Some(method.to_string())), + )); + } + + Ok(()) +} + /// Parses an HTTP request to extract WebSocket upgrade information. /// /// This function validates and processes an incoming HTTP request to ensure it meets the @@ -231,20 +253,7 @@ pub fn parse_request( where E: ExtensionProvider, { - if version < HTTP_VERSION { - return Err(Error::with_cause( - ErrorKind::Http, - HttpError::HttpVersion(format!("{:?}", Version::HTTP_10)), - )); - } - - if method != Method::GET { - return Err(Error::with_cause( - ErrorKind::Http, - HttpError::HttpMethod(Some(method.to_string())), - )); - } - + validate_method_and_version(version, method)?; validate_header_any(headers, http::header::CONNECTION, UPGRADE_STR)?; validate_header_value(headers, http::header::UPGRADE, WEBSOCKET_STR)?; validate_header_value( diff --git a/ratchet_core/src/handshake/server/mod.rs b/ratchet_core/src/handshake/server/mod.rs index 6fcbd05..325880d 100644 --- a/ratchet_core/src/handshake/server/mod.rs +++ b/ratchet_core/src/handshake/server/mod.rs @@ -16,7 +16,7 @@ mod encoding; #[cfg(test)] mod tests; -pub use encoding::parse_request; +pub use encoding::{parse_request, validate_method_and_version}; use crate::{ ext::NoExt, @@ -410,7 +410,7 @@ pub struct UpgradeRequest { /// the WebSocket protocol. /// - `Connection`: Set to `Upgrade`, as required by the HTTP upgrade process. /// -/// Optionally, the response may also include: +/// Optionally, the response headers may also include: /// /// - `Sec-WebSocket-Protocol`: The negotiated subprotocol, if provided. /// - `Sec-WebSocket-Extensions`: The WebSocket extension header, if an extension was negotiated. @@ -425,7 +425,7 @@ pub fn build_response( key: Bytes, subprotocol: Option, extension_header: Option, -) -> Result, Error> { +) -> Result, Error> { let mut response = http::Response::builder() .version(Version::HTTP_11) .status(StatusCode::SWITCHING_PROTOCOLS); @@ -507,16 +507,10 @@ where )) } -/// Processes a WebSocket handshake request from its parts and generates the appropriate response. -/// -/// This function handles the server-side part of a WebSocket handshake. It parses the incoming HTTP -/// request parts that seeks to upgrade the connection to WebSocket, negotiates extensions and -/// subprotocols, and constructs an appropriate HTTP response to complete the WebSocket handshake. +/// Generates a WebSocket upgrade response from the provided headers. /// /// # Arguments /// -/// - `version`: The HTTP `Version` of the request. -/// - `method`: The HTTP `Method` of the request, which must be `GET` for WebSocket handshakes. /// - `headers`: A reference to the request's `HeaderMap` containing the HTTP headers. These headers /// must include the necessary WebSocket headers such as `Sec-WebSocket-Key` and `Upgrade`. /// - `extension`: An extension that may be negotiated for the connection. @@ -539,10 +533,6 @@ where /// /// - `E`: The type of the extension provider, which must implement the `ExtensionProvider` /// trait. This defines how WebSocket extensions (like compression) are handled. -/// - `B`: The body type of the HTTP request. While it is discouraged for GET requests to have a body -/// it is not technically incorrect and the use of this function is lowering the guardrails to -/// allow for Ratchet to be more easily integrated into other libraries. It is the implementors -/// responsibility to perform any validation on the body. /// /// # Errors /// @@ -550,11 +540,9 @@ where /// - Failure to parse the WebSocket upgrade request. /// - Issues building the response, such as invalid subprotocol or extension headers. /// - Failure to negotiate the WebSocket extensions or subprotocols. -pub fn handshake_from_parts( - version: Version, - method: &Method, +pub fn response_from_headers( headers: &HeaderMap, - extension: &E, + extension: E, subprotocols: &SubprotocolRegistry, ) -> Result<(Response<()>, Option), Error> where @@ -565,8 +553,13 @@ where subprotocol, extension, extension_header, - .. - } = parse_request(version, method, headers, extension, subprotocols)?; + } = parse_request( + Version::HTTP_11, + &Method::GET, + headers, + extension, + subprotocols, + )?; let mut response = http::Response::builder() .version(Version::HTTP_11) diff --git a/ratchet_core/src/lib.rs b/ratchet_core/src/lib.rs index 19320e0..b1ba829 100644 --- a/ratchet_core/src/lib.rs +++ b/ratchet_core/src/lib.rs @@ -79,7 +79,7 @@ impl WebSocketStream for S where S: AsyncRead + AsyncWrite + Send + Unpin + ' /// It should generally not be required unless integrating Ratchet into other libraries. pub mod server { pub use crate::handshake::{ - build_response, build_response_headers, handshake, handshake_from_parts, parse_request, - UpgradeRequest, UpgradeRequestParts, + build_response, build_response_headers, handshake, parse_request, response_from_headers, + validate_method_and_version, UpgradeRequest, UpgradeRequestParts, }; } From 2035b147046f756bc7d04ad2d953315538e13633 Mon Sep 17 00:00:00 2001 From: SirCipher Date: Fri, 27 Sep 2024 09:18:36 +0100 Subject: [PATCH 12/13] Restructures server handshake parsing functions --- ratchet_core/src/handshake/mod.rs | 2 +- ratchet_core/src/handshake/server/encoding.rs | 160 +----------------- ratchet_core/src/handshake/server/mod.rs | 157 ++++++++++++++++- ratchet_core/src/lib.rs | 4 +- 4 files changed, 160 insertions(+), 163 deletions(-) diff --git a/ratchet_core/src/handshake/mod.rs b/ratchet_core/src/handshake/mod.rs index 6b4db48..de1b544 100644 --- a/ratchet_core/src/handshake/mod.rs +++ b/ratchet_core/src/handshake/mod.rs @@ -36,7 +36,7 @@ use url::Url; pub use client::{subscribe, subscribe_with, UpgradedClient}; pub use server::{ - accept, accept_with, build_response, build_response_headers, handshake, parse_request, + accept, accept_with, build_response, build_response_headers, handshake, parse_request_parts, response_from_headers, validate_method_and_version, UpgradeRequest, UpgradeRequestParts, UpgradedServer, WebSocketResponse, WebSocketUpgrader, }; diff --git a/ratchet_core/src/handshake/server/encoding.rs b/ratchet_core/src/handshake/server/encoding.rs index 477482f..e9143ee 100644 --- a/ratchet_core/src/handshake/server/encoding.rs +++ b/ratchet_core/src/handshake/server/encoding.rs @@ -13,18 +13,14 @@ // limitations under the License. use crate::handshake::io::BufferedIo; -use crate::handshake::server::{UpgradeRequest, UpgradeRequestParts}; -use crate::handshake::{ - validate_header_any, validate_header_value, ParseResult, TryFromWrapper, METHOD_GET, - UPGRADE_STR, WEBSOCKET_STR, WEBSOCKET_VERSION_STR, -}; -use crate::{Error, ErrorKind, HttpError, SubprotocolRegistry}; -use bytes::{BufMut, Bytes, BytesMut}; -use http::header::{HOST, SEC_WEBSOCKET_KEY}; +use crate::handshake::server::{check_partial_request, UpgradeRequest, UpgradeRequestParts}; +use crate::handshake::{ParseResult, TryFromWrapper}; +use crate::server::parse_request_parts; +use crate::{Error, SubprotocolRegistry}; +use bytes::{BufMut, BytesMut}; use http::request::Parts; -use http::{HeaderMap, Method, Request, StatusCode, Version}; +use http::{HeaderMap, Request, StatusCode}; use httparse::Status; -use log::error; use ratchet_ext::ExtensionProvider; use tokio::io::AsyncWrite; use tokio_util::codec::Decoder; @@ -35,8 +31,6 @@ const HTTP_VERSION_STR: &[u8] = b"HTTP/1.1 "; const STATUS_TERMINATOR_LEN: usize = 2; const TERMINATOR_NO_HEADERS: &[u8] = b"\r\n\r\n"; const TERMINATOR_WITH_HEADER: &[u8] = b"\r\n"; -const HTTP_VERSION_INT: u8 = 1; -const HTTP_VERSION: Version = Version::HTTP_11; pub struct RequestParser { pub subprotocols: SubprotocolRegistry, @@ -151,7 +145,7 @@ where subprotocol, extension, extension_header, - } = parse_request(*version, method, headers, extension, subprotocols)?; + } = parse_request_parts(*version, method, headers, extension, subprotocols)?; Ok(ParseResult::Complete( UpgradeRequest { @@ -168,143 +162,3 @@ where Err(e) => Err(e.into()), } } - -pub fn check_partial_request(request: &httparse::Request) -> Result<(), Error> { - match request.version { - Some(HTTP_VERSION_INT) | None => {} - Some(_) => { - return Err(Error::with_cause( - ErrorKind::Http, - HttpError::HttpVersion(format!("{:?}", Version::HTTP_10)), - )) - } - } - - match request.method { - Some(m) if m.eq_ignore_ascii_case(METHOD_GET) => {} - None => {} - m => { - return Err(Error::with_cause( - ErrorKind::Http, - HttpError::HttpMethod(m.map(ToString::to_string)), - )); - } - } - - Ok(()) -} - -/// Validates that `version` and `method` are correct for a WebSocket upgrade. -/// -/// # Returns -/// `Ok(())` if they are correct or `Err(e)` if they are not. -pub fn validate_method_and_version(version: Version, method: &Method) -> Result<(), Error> { - if version < HTTP_VERSION { - return Err(Error::with_cause( - ErrorKind::Http, - HttpError::HttpVersion(format!("{:?}", Version::HTTP_10)), - )); - } - - if method != Method::GET { - return Err(Error::with_cause( - ErrorKind::Http, - HttpError::HttpMethod(Some(method.to_string())), - )); - } - - Ok(()) -} - -/// Parses an HTTP request to extract WebSocket upgrade information. -/// -/// This function validates and processes an incoming HTTP request to ensure it meets the -/// requirements for a WebSocket upgrade. It checks the HTTP version, method, and necessary headers -/// to determine if the request can be successfully upgraded to a WebSocket connection. It also -/// negotiates the subprotocols and extensions specified in the request. -/// -/// # Arguments -/// - `request`: An `http::Request` representing the incoming HTTP request from the client, which -/// is expected to contain WebSocket-specific headers. While it is discouraged for GET requests to -/// have a body it is not technically incorrect and the use of this function is lowering the -/// guardrails to allow for Ratchet to be more easily integrated into other libraries. It is the -/// implementors responsibility to perform any validation on the body. -/// - `extension`: An instance of a type that implements the `ExtensionProvider` -/// trait. This object is responsible for negotiating any server-supported -/// extensions requested by the client. -/// - `subprotocols`: A `SubprotocolRegistry`, which manages the supported subprotocols and attempts -/// to negotiate one with the client. -/// -/// # Returns -/// This function returns a `Result, Error>`, where: -/// - `Ok(UpgradeRequest)`: Contains the parsed information needed for the WebSocket -/// handshake, including the WebSocket key, negotiated subprotocol, optional -/// extensions, and the original HTTP request. -/// - `Err(Error)`: Contains an error if the request is invalid or cannot be parsed. -/// This could include issues such as unsupported HTTP versions, invalid methods, -/// missing required headers, or failed negotiations for subprotocols or extensions. -pub fn parse_request( - version: Version, - method: &Method, - headers: &HeaderMap, - extension: E, - subprotocols: &SubprotocolRegistry, -) -> Result, Error> -where - E: ExtensionProvider, -{ - validate_method_and_version(version, method)?; - validate_header_any(headers, http::header::CONNECTION, UPGRADE_STR)?; - validate_header_value(headers, http::header::UPGRADE, WEBSOCKET_STR)?; - validate_header_value( - headers, - http::header::SEC_WEBSOCKET_VERSION, - WEBSOCKET_VERSION_STR, - )?; - - if let Err(e) = validate_host_header(headers) { - error!("Server responded with invalid 'host' headers"); - return Err(e); - } - - let key = headers - .get(SEC_WEBSOCKET_KEY) - .map(|v| Bytes::from(v.as_bytes().to_vec())) - .ok_or_else(|| { - Error::with_cause(ErrorKind::Http, HttpError::MissingHeader(SEC_WEBSOCKET_KEY)) - })?; - let subprotocol = subprotocols.negotiate_client(headers)?; - let (extension, extension_header) = extension - .negotiate_server(headers) - .map(Option::unzip) - .map_err(|e| Error::with_cause(ErrorKind::Extension, e))?; - - Ok(UpgradeRequestParts { - key, - extension, - subprotocol, - extension_header, - }) -} - -/// Validates that 'headers' contains one 'host' header and that it is not a seperated list. -fn validate_host_header(headers: &HeaderMap) -> Result<(), Error> { - let len = headers - .iter() - .filter_map(|(name, value)| { - if name.as_str().eq_ignore_ascii_case(HOST.as_str()) { - Some(value.as_bytes().split(|c| c == &b' ' || c == &b',')) - } else { - None - } - }) - .count(); - if len == 1 { - Ok(()) - } else { - Err(Error::with_cause( - ErrorKind::Http, - HttpError::MissingHeader(HOST), - )) - } -} diff --git a/ratchet_core/src/handshake/server/mod.rs b/ratchet_core/src/handshake/server/mod.rs index 325880d..96506d1 100644 --- a/ratchet_core/src/handshake/server/mod.rs +++ b/ratchet_core/src/handshake/server/mod.rs @@ -16,8 +16,9 @@ mod encoding; #[cfg(test)] mod tests; -pub use encoding::{parse_request, validate_method_and_version}; - +use crate::handshake::{ + validate_header_any, validate_header_value, METHOD_GET, WEBSOCKET_VERSION_STR, +}; use crate::{ ext::NoExt, handshake::io::BufferedIo, @@ -25,11 +26,12 @@ use crate::{ handshake::{StreamingParser, ACCEPT_KEY}, handshake::{UPGRADE_STR, WEBSOCKET_STR}, protocol::Role, - Error, HttpError, NoExtProvider, Request, SubprotocolRegistry, WebSocket, WebSocketConfig, - WebSocketStream, + Error, ErrorKind, HttpError, NoExtProvider, Request, SubprotocolRegistry, WebSocket, + WebSocketConfig, WebSocketStream, }; use base64::engine::{general_purpose::STANDARD, Engine}; use bytes::{Bytes, BytesMut}; +use http::header::{HOST, SEC_WEBSOCKET_KEY}; use http::request::Parts; use http::status::InvalidStatusCode; use http::{HeaderMap, HeaderValue, Method, Response, StatusCode, Uri, Version}; @@ -43,6 +45,7 @@ const MSG_HANDSHAKE_COMPLETED: &str = "Server handshake completed"; const MSG_HANDSHAKE_FAILED: &str = "Server handshake failed"; const UPGRADED_MSG: &str = "Upgraded connection"; const REJECT_MSG: &str = "Rejected connection"; +const HTTP_VERSION_INT: u8 = 1; /// A structure representing an upgraded WebSocket session and an optional subprotocol that was /// negotiated during the upgrade. @@ -389,7 +392,7 @@ pub struct UpgradeRequest { /// Builds an HTTP response to a WebSocket connection upgrade request. /// /// No validation is performed by this function and it is only guaranteed to be correct if the -/// arguments are derived by previously calling [`parse_request`]. +/// arguments are derived by previously calling [`parse_request_parts`]. /// /// # Arguments /// @@ -500,7 +503,7 @@ where extension, extension_header, .. - } = parse_request(version, &method, &headers, extension, subprotocols)?; + } = parse_request_parts(version, &method, &headers, extension, subprotocols)?; Ok(( build_response(key, subprotocol, extension_header)?, extension, @@ -553,7 +556,7 @@ where subprotocol, extension, extension_header, - } = parse_request( + } = parse_request_parts( Version::HTTP_11, &Method::GET, headers, @@ -624,3 +627,143 @@ pub fn build_response_headers( Ok(map) } + +/// Parses an HTTP request from its parts to extract WebSocket upgrade information. +/// +/// This function validates and processes an incoming HTTP request to ensure it meets the +/// requirements for a WebSocket upgrade. It checks the HTTP version, method, and necessary headers +/// to determine if the request can be successfully upgraded to a WebSocket connection. It also +/// negotiates the subprotocols and extensions specified in the request. +/// +/// # Arguments +/// - `request`: An `http::Request` representing the incoming HTTP request from the client, which +/// is expected to contain WebSocket-specific headers. While it is discouraged for GET requests to +/// have a body it is not technically incorrect and the use of this function is lowering the +/// guardrails to allow for Ratchet to be more easily integrated into other libraries. It is the +/// implementors responsibility to perform any validation on the body. +/// - `extension`: An instance of a type that implements the `ExtensionProvider` +/// trait. This object is responsible for negotiating any server-supported +/// extensions requested by the client. +/// - `subprotocols`: A `SubprotocolRegistry`, which manages the supported subprotocols and attempts +/// to negotiate one with the client. +/// +/// # Returns +/// This function returns a `Result, Error>`, where: +/// - `Ok(UpgradeRequest)`: Contains the parsed information needed for the WebSocket +/// handshake, including the WebSocket key, negotiated subprotocol, optional +/// extensions, and the original HTTP request. +/// - `Err(Error)`: Contains an error if the request is invalid or cannot be parsed. +/// This could include issues such as unsupported HTTP versions, invalid methods, +/// missing required headers, or failed negotiations for subprotocols or extensions. +pub fn parse_request_parts( + version: Version, + method: &Method, + headers: &HeaderMap, + extension: E, + subprotocols: &SubprotocolRegistry, +) -> Result, Error> +where + E: ExtensionProvider, +{ + validate_method_and_version(version, method)?; + validate_header_any(headers, http::header::CONNECTION, UPGRADE_STR)?; + validate_header_value(headers, http::header::UPGRADE, WEBSOCKET_STR)?; + validate_header_value( + headers, + http::header::SEC_WEBSOCKET_VERSION, + WEBSOCKET_VERSION_STR, + )?; + + if let Err(e) = validate_host_header(headers) { + error!("Server responded with invalid 'host' headers"); + return Err(e); + } + + let key = headers + .get(SEC_WEBSOCKET_KEY) + .map(|v| Bytes::from(v.as_bytes().to_vec())) + .ok_or_else(|| { + Error::with_cause(ErrorKind::Http, HttpError::MissingHeader(SEC_WEBSOCKET_KEY)) + })?; + let subprotocol = subprotocols.negotiate_client(headers)?; + let (extension, extension_header) = extension + .negotiate_server(headers) + .map(Option::unzip) + .map_err(|e| Error::with_cause(ErrorKind::Extension, e))?; + + Ok(UpgradeRequestParts { + key, + extension, + subprotocol, + extension_header, + }) +} + +fn check_partial_request(request: &httparse::Request) -> Result<(), Error> { + match request.version { + Some(HTTP_VERSION_INT) | None => {} + Some(_) => { + return Err(Error::with_cause( + ErrorKind::Http, + HttpError::HttpVersion(format!("{:?}", Version::HTTP_10)), + )) + } + } + + match request.method { + Some(m) if m.eq_ignore_ascii_case(METHOD_GET) => {} + None => {} + m => { + return Err(Error::with_cause( + ErrorKind::Http, + HttpError::HttpMethod(m.map(ToString::to_string)), + )); + } + } + + Ok(()) +} + +/// Validates that `version` and `method` are correct for a WebSocket upgrade. +/// +/// # Returns +/// `Ok(())` if they are correct or `Err(e)` if they are not. +pub fn validate_method_and_version(version: Version, method: &Method) -> Result<(), Error> { + if version < Version::HTTP_11 { + return Err(Error::with_cause( + ErrorKind::Http, + HttpError::HttpVersion(format!("{:?}", Version::HTTP_10)), + )); + } + + if method != Method::GET { + return Err(Error::with_cause( + ErrorKind::Http, + HttpError::HttpMethod(Some(method.to_string())), + )); + } + + Ok(()) +} + +/// Validates that 'headers' contains one 'host' header and that it is not a seperated list. +fn validate_host_header(headers: &HeaderMap) -> Result<(), Error> { + let len = headers + .iter() + .filter_map(|(name, value)| { + if name.as_str().eq_ignore_ascii_case(HOST.as_str()) { + Some(value.as_bytes().split(|c| c == &b' ' || c == &b',')) + } else { + None + } + }) + .count(); + if len == 1 { + Ok(()) + } else { + Err(Error::with_cause( + ErrorKind::Http, + HttpError::MissingHeader(HOST), + )) + } +} diff --git a/ratchet_core/src/lib.rs b/ratchet_core/src/lib.rs index b1ba829..96bd5ae 100644 --- a/ratchet_core/src/lib.rs +++ b/ratchet_core/src/lib.rs @@ -79,7 +79,7 @@ impl WebSocketStream for S where S: AsyncRead + AsyncWrite + Send + Unpin + ' /// It should generally not be required unless integrating Ratchet into other libraries. pub mod server { pub use crate::handshake::{ - build_response, build_response_headers, handshake, parse_request, response_from_headers, - validate_method_and_version, UpgradeRequest, UpgradeRequestParts, + build_response, build_response_headers, handshake, parse_request_parts, + response_from_headers, validate_method_and_version, UpgradeRequest, UpgradeRequestParts, }; } From ab6406f789186c56d49100934bd24544d675ce67 Mon Sep 17 00:00:00 2001 From: SirCipher Date: Fri, 27 Sep 2024 09:21:18 +0100 Subject: [PATCH 13/13] Updates incorrect documentation --- ratchet_core/src/handshake/server/mod.rs | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/ratchet_core/src/handshake/server/mod.rs b/ratchet_core/src/handshake/server/mod.rs index 96506d1..7f58c76 100644 --- a/ratchet_core/src/handshake/server/mod.rs +++ b/ratchet_core/src/handshake/server/mod.rs @@ -636,11 +636,10 @@ pub fn build_response_headers( /// negotiates the subprotocols and extensions specified in the request. /// /// # Arguments -/// - `request`: An `http::Request` representing the incoming HTTP request from the client, which -/// is expected to contain WebSocket-specific headers. While it is discouraged for GET requests to -/// have a body it is not technically incorrect and the use of this function is lowering the -/// guardrails to allow for Ratchet to be more easily integrated into other libraries. It is the -/// implementors responsibility to perform any validation on the body. +/// - `version`: The HTTP version of the request. +/// - `method`: The HTTP method of the request. +/// - `headers`: A reference to the request's `HeaderMap` containing the HTTP headers. These headers +/// must include the necessary WebSocket headers such as `Sec-WebSocket-Key` and `Upgrade`. /// - `extension`: An instance of a type that implements the `ExtensionProvider` /// trait. This object is responsible for negotiating any server-supported /// extensions requested by the client. @@ -648,11 +647,10 @@ pub fn build_response_headers( /// to negotiate one with the client. /// /// # Returns -/// This function returns a `Result, Error>`, where: -/// - `Ok(UpgradeRequest)`: Contains the parsed information needed for the WebSocket -/// handshake, including the WebSocket key, negotiated subprotocol, optional -/// extensions, and the original HTTP request. -/// - `Err(Error)`: Contains an error if the request is invalid or cannot be parsed. +/// This function returns a `Result, Error>`, where: +/// - `Ok(UpgradeRequestParts)`: Contains the parsed information needed for the WebSocket +/// handshake, including the WebSocket key, negotiated subprotocol, and an optional extension. +/// - `Err(Error)`: Contains an error if the request parts are invalid or cannot be parsed. /// This could include issues such as unsupported HTTP versions, invalid methods, /// missing required headers, or failed negotiations for subprotocols or extensions. pub fn parse_request_parts(