From ff67b0b12e94d6ddbc8e3398c1782b2142bada20 Mon Sep 17 00:00:00 2001 From: kazk Date: Mon, 13 Sep 2021 13:37:05 -0700 Subject: [PATCH] Add `SecWebsocketExtensions` --- src/common/mod.rs | 2 + src/common/sec_websocket_extensions.rs | 399 +++++++++++++++++++++++++ src/util/mod.rs | 2 +- src/util/value_string.rs | 9 + 4 files changed, 411 insertions(+), 1 deletion(-) create mode 100644 src/common/sec_websocket_extensions.rs diff --git a/src/common/mod.rs b/src/common/mod.rs index 3a1e9c0f..915051a5 100644 --- a/src/common/mod.rs +++ b/src/common/mod.rs @@ -56,6 +56,7 @@ pub use self::referer::Referer; pub use self::referrer_policy::ReferrerPolicy; pub use self::retry_after::RetryAfter; pub use self::sec_websocket_accept::SecWebsocketAccept; +pub use self::sec_websocket_extensions::{SecWebsocketExtensions, WebsocketExtension}; pub use self::sec_websocket_key::SecWebsocketKey; pub use self::sec_websocket_version::SecWebsocketVersion; pub use self::server::Server; @@ -175,6 +176,7 @@ mod referer; mod referrer_policy; mod retry_after; mod sec_websocket_accept; +mod sec_websocket_extensions; mod sec_websocket_key; mod sec_websocket_version; mod server; diff --git a/src/common/sec_websocket_extensions.rs b/src/common/sec_websocket_extensions.rs new file mode 100644 index 00000000..a9218606 --- /dev/null +++ b/src/common/sec_websocket_extensions.rs @@ -0,0 +1,399 @@ +use std::convert::TryFrom; + +use bytes::BytesMut; +use http::header::SEC_WEBSOCKET_EXTENSIONS; + +use util::{Comma, FlatCsv, HeaderValueString, SemiColon}; +use {Error, Header, HeaderValue}; + +/// `Sec-WebSocket-Extensions` header, defined in [RFC6455][RFC6455_11.3.2] +/// +/// The `Sec-WebSocket-Extensions` header field is used in the WebSocket +/// opening handshake. It is initially sent from the client to the +/// server, and then subsequently sent from the server to the client, to +/// agree on a set of protocol-level extensions to use for the duration +/// of the connection. +/// +/// ## ABNF +/// +/// ```text +/// Sec-WebSocket-Extensions = extension-list +/// extension-list = 1#extension +/// extension = extension-token *( ";" extension-param ) +/// extension-token = registered-token +/// registered-token = token +/// extension-param = token [ "=" (token | quoted-string) ] +/// ``` +/// +/// ## Example Values +/// +/// * `permessage-deflate` (defined in [RFC7692][RFC7692_7]) +/// * `permessage-deflate; server_max_window_bits=10` +/// * `permessage-deflate; server_max_window_bits=10, permessage-deflate` +/// +/// ## Example +/// +/// ```rust +/// # extern crate headers; +/// use headers::SecWebsocketExtensions; +/// +/// let extensions = SecWebsocketExtensions::from_static("permessage-deflate"); +/// ``` +/// +/// ## Splitting and Combining +/// +/// Note that `Sec-WebSocket-Extensions` may be split or combined across multiple headers. +/// The following are equivalent: +/// ```text +/// Sec-WebSocket-Extensions: foo +/// Sec-WebSocket-Extensions: bar; baz=2 +/// ``` +/// ```text +/// Sec-WebSocket-Extensions: foo, bar; baz=2 +/// ``` +/// +/// `SecWebsocketExtensions` splits extensions when decoding and combines them into a single +/// value when encoding. +/// +/// [RFC6455_11.3.2]: https://tools.ietf.org/html/rfc6455#section-11.3.2 +/// [RFC7692_7]: https://tools.ietf.org/html/rfc7692#section-7 +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct SecWebsocketExtensions(Vec); + +impl Header for SecWebsocketExtensions { + fn name() -> &'static ::HeaderName { + &SEC_WEBSOCKET_EXTENSIONS + } + + fn decode<'i, I: Iterator>(values: &mut I) -> Result { + let extensions = values + .cloned() + .flat_map(|v| { + FlatCsv::::from(v) + .iter() + .map(WebsocketExtension::try_from) + .collect::>() + }) + .collect::, _>>()?; + if extensions.is_empty() { + Err(Error::invalid()) + } else { + Ok(SecWebsocketExtensions(extensions)) + } + } + + fn encode>(&self, values: &mut E) { + if !self.0.is_empty() { + values.extend(std::iter::once(self.to_value())); + } + } +} + +impl SecWebsocketExtensions { + /// Construct a `SecWebSocketExtensions` from a static string. + /// + /// ## Panic + /// + /// Panics if the static string is not a valid extensions valie. + pub fn from_static(s: &'static str) -> Self { + let value = HeaderValue::from_static(s); + Self::try_from(&value).expect("valid static string") + } + + /// Convert this `SecWebsocketExtensions` to a single `HeaderValue`. + pub fn to_value(&self) -> HeaderValue { + let values = self.0.iter().map(HeaderValue::from).collect::(); + HeaderValue::from(&values) + } + + /// An iterator over the `WebsocketExtension`s in `SecWebsocketExtensions` header(s). + pub fn iter(&self) -> impl Iterator { + self.0.iter() + } + + /// Get the number of extensions. + pub fn len(&self) -> usize { + self.0.len() + } + + /// Returns `true` if headers contain no extensions. + pub fn is_empty(&self) -> bool { + self.0.is_empty() + } + + /// Get a shared reference to the extensions. + pub fn extensions(&self) -> &Vec { + &self.0 + } + + /// Get a mutable reference to the extensions. + pub fn extensions_mut(&mut self) -> &mut Vec { + self.0.as_mut() + } +} + +impl TryFrom<&str> for SecWebsocketExtensions { + type Error = Error; + + fn try_from(value: &str) -> Result { + let value = HeaderValue::from_str(value).map_err(|_| Error::invalid())?; + Self::try_from(&value) + } +} + +impl TryFrom<&HeaderValue> for SecWebsocketExtensions { + type Error = Error; + + fn try_from(value: &HeaderValue) -> Result { + let mut values = std::iter::once(value); + SecWebsocketExtensions::decode(&mut values) + } +} + +/// A WebSocket extension containing the name and parameters. +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct WebsocketExtension { + name: HeaderValueString, + params: Vec<(HeaderValueString, Option)>, +} + +impl WebsocketExtension { + /// Construct a `WebSocketExtension` from a static string. + /// + /// ## Panics + /// + /// This function panics if the argument is invalid. + pub fn from_static(src: &'static str) -> Self { + WebsocketExtension::try_from(HeaderValue::from_static(src)).expect("valid static value") + } + + /// Get the name of the extension. + pub fn name(&self) -> &str { + self.name.as_str() + } + + /// Get the paramaters of the extension. + pub fn parameters(&self) -> Vec<(&str, Option<&str>)> { + self.params + .iter() + .map(|(k, v)| (k.as_str(), v.as_ref().map(|v| v.as_str()))) + .collect() + } +} + +impl TryFrom<&str> for WebsocketExtension { + type Error = Error; + + fn try_from(value: &str) -> Result { + if value.is_empty() { + Err(Error::invalid()) + } else { + let value = HeaderValue::from_str(value).map_err(|_| Error::invalid())?; + WebsocketExtension::try_from(value) + } + } +} + +impl TryFrom for WebsocketExtension { + type Error = Error; + + fn try_from(value: HeaderValue) -> Result { + let csv = FlatCsv::::from(value); + // More than one extension was found + if csv.iter().count() > 1 { + return Err(Error::invalid()); + } + + let params = FlatCsv::::from(csv.value); + let mut params_iter = params.iter(); + let name = params_iter + .next() + .ok_or_else(Error::invalid) + .and_then(HeaderValueString::from_str)?; + let params = params_iter + .map(|p| { + let mut kv = p.splitn(2, '='); + let key = kv + .next() + .ok_or_else(Error::invalid) + .and_then(HeaderValueString::from_str)?; + let val = kv + .next() + .map(|v| HeaderValueString::from_str(v.trim_matches('"'))) + .transpose()?; + Ok((key, val)) + }) + .collect::, _>>()?; + Ok(WebsocketExtension { name, params }) + } +} + +impl From<&WebsocketExtension> for HeaderValue { + fn from(extension: &WebsocketExtension) -> Self { + let mut buf = BytesMut::from(extension.name.as_str().as_bytes()); + for (key, val) in &extension.params { + buf.extend_from_slice(b"; "); + buf.extend_from_slice(key.as_str().as_bytes()); + if let Some(val) = val { + buf.extend_from_slice(b"="); + buf.extend_from_slice(val.as_str().as_bytes()); + } + } + + HeaderValue::from_maybe_shared(buf.freeze()) + .expect("semicolon separated HeaderValueStrings are valid") + } +} + +#[cfg(test)] +mod tests { + use super::super::{test_decode, test_encode}; + use super::*; + + #[test] + fn extensions_decode() { + let extensions = + test_decode::(&["key1; val1", "key2; val2"]).unwrap(); + assert_eq!(extensions.0.len(), 2); + assert_eq!( + extensions.0[0], + WebsocketExtension::try_from("key1; val1").unwrap() + ); + assert_eq!( + extensions.0[1], + WebsocketExtension::try_from("key2; val2").unwrap() + ); + + assert_eq!(test_decode::(&[""]), None); + } + + #[test] + fn extensions_decode_split() { + // Split each extension into separate headers + let extensions = + test_decode::(&["key1; val1, key2; val2", "key3; val3"]) + .unwrap(); + assert_eq!(extensions.0.len(), 3); + assert_eq!( + extensions.0[0], + WebsocketExtension::try_from("key1; val1").unwrap() + ); + assert_eq!( + extensions.0[1], + WebsocketExtension::try_from("key2; val2").unwrap() + ); + assert_eq!( + extensions.0[2], + WebsocketExtension::try_from("key3; val3").unwrap() + ); + } + + #[test] + fn extensions_encode() { + let extensions = + SecWebsocketExtensions(vec![WebsocketExtension::from_static("foo; bar; baz=1")]); + let headers = test_encode(extensions); + let mut vals = headers.get_all(SEC_WEBSOCKET_EXTENSIONS).into_iter(); + assert_eq!(vals.next().unwrap(), "foo; bar; baz=1"); + assert_eq!(vals.next(), None); + + let extensions = SecWebsocketExtensions(vec![]); + let headers = test_encode(extensions); + let mut vals = headers.get_all(SEC_WEBSOCKET_EXTENSIONS).into_iter(); + assert_eq!(vals.next(), None); + } + + #[test] + fn extensions_encode_combine() { + // Multiple extensions are combined into a single header + let extensions = SecWebsocketExtensions(vec![ + WebsocketExtension::from_static("foo1; bar"), + WebsocketExtension::from_static("foo2; bar"), + WebsocketExtension::from_static("baz; quux"), + ]); + let headers = test_encode(extensions); + let mut vals = headers.get_all(SEC_WEBSOCKET_EXTENSIONS).into_iter(); + assert_eq!(vals.next().unwrap(), "foo1; bar, foo2; bar, baz; quux"); + assert_eq!(vals.next(), None); + } + + #[test] + fn extensions_iter() { + let extensions = SecWebsocketExtensions(vec![ + WebsocketExtension::from_static("foo; bar1; bar2=3"), + WebsocketExtension::from_static("baz; quux"), + ]); + assert_eq!(extensions.len(), 2); + + let mut iter = extensions.iter(); + let extension = iter.next().unwrap(); + assert_eq!(extension.name(), "foo"); + let mut params = extension.parameters().into_iter(); + assert_eq!(params.next(), Some(("bar1", None))); + assert_eq!(params.next(), Some(("bar2", Some("3")))); + assert!(params.next().is_none()); + + let extension = iter.next().unwrap(); + assert_eq!(extension.name(), "baz"); + let mut params = extension.parameters().into_iter(); + assert_eq!(params.next(), Some(("quux", None))); + assert!(params.next().is_none()); + + assert!(iter.next().is_none()); + } + + #[test] + fn extensions_get_extensions() { + let ext1 = WebsocketExtension::from_static("foo; bar1; bar2=3"); + let ext2 = WebsocketExtension::from_static("baz; quux"); + let exts = vec![ext1, ext2]; + let extensions = SecWebsocketExtensions(exts.clone()); + assert_eq!(extensions.extensions(), &exts); + } + + #[test] + fn extensions_get_extensions_mut() { + let ext1 = WebsocketExtension::from_static("foo; bar1; bar2=3"); + let ext2 = WebsocketExtension::from_static("baz; quux"); + let mut extensions = SecWebsocketExtensions(vec![ext1, ext2]); + assert_eq!(extensions.len(), 2); + let exts = extensions.extensions_mut(); + exts.push(WebsocketExtension::from_static("baz; quux")); + assert_eq!(extensions.len(), 3); + } + + #[test] + fn extension_try_from_str_ok() { + let ext = WebsocketExtension::try_from("permessage-deflate").unwrap(); + assert_eq!(ext.name(), "permessage-deflate"); + assert_eq!(ext.parameters(), vec![]); + + let ext = + WebsocketExtension::try_from("permessage-deflate; client_max_window_bits").unwrap(); + assert_eq!(ext.name(), "permessage-deflate"); + assert_eq!(ext.parameters(), vec![("client_max_window_bits", None)]); + + let ext = + WebsocketExtension::try_from("permessage-deflate; server_max_window_bits=10").unwrap(); + assert_eq!(ext.name(), "permessage-deflate"); + assert_eq!( + ext.parameters(), + vec![("server_max_window_bits", Some("10"))] + ); + + let ext = WebsocketExtension::try_from("permessage-deflate; server_max_window_bits=\"10\"") + .unwrap(); + assert_eq!(ext.name(), "permessage-deflate"); + assert_eq!( + ext.parameters(), + vec![("server_max_window_bits", Some("10"))] + ); + } + + #[test] + fn extension_try_from_str_err() { + assert!(WebsocketExtension::try_from("").is_err()); + // Only single extension is allowed + assert!(WebsocketExtension::try_from("permessage-deflate, permessage-snappy").is_err()); + } +} diff --git a/src/util/mod.rs b/src/util/mod.rs index 07fddbfb..4ebb754c 100644 --- a/src/util/mod.rs +++ b/src/util/mod.rs @@ -3,7 +3,7 @@ use HeaderValue; //pub use self::charset::Charset; //pub use self::encoding::Encoding; pub(crate) use self::entity::{EntityTag, EntityTagRange}; -pub(crate) use self::flat_csv::{FlatCsv, SemiColon}; +pub(crate) use self::flat_csv::{Comma, FlatCsv, SemiColon}; pub(crate) use self::fmt::fmt; pub(crate) use self::http_date::HttpDate; pub(crate) use self::iter::IterExt; diff --git a/src/util/value_string.rs b/src/util/value_string.rs index 865a3558..0e2d797c 100644 --- a/src/util/value_string.rs +++ b/src/util/value_string.rs @@ -26,6 +26,15 @@ impl HeaderValueString { } } + pub(crate) fn from_str(src: &str) -> Result { + let value = HeaderValue::from_str(src).map_err(|_| ::Error::invalid())?; + if value.to_str().is_ok() { + Ok(HeaderValueString { value }) + } else { + Err(::Error::invalid()) + } + } + pub(crate) fn from_string(src: String) -> Option { // A valid `str` (the argument)... let bytes = Bytes::from(src);