diff --git a/CHANGELOG.md b/CHANGELOG.md index a61ee91d..3c17e4f2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,7 @@ # Unreleased -- Implement `AsRef<[u8]>` for `Payload`, `AsRef<[u8]>`, `AsRef` for `Utf8Payload`. -- Implement `&str`-like `PartialEq` for `Utf8Payload`. +- Simplify `Message` to use `Bytes` payload directly with simpler `Utf8Bytes` for text. +- Change `CloseFrame` to use `Utf8Bytes` for `reason`. +- Re-export `Bytes`. # 0.25.0 diff --git a/benches/read.rs b/benches/read.rs index 252ef3e3..75b405bf 100644 --- a/benches/read.rs +++ b/benches/read.rs @@ -64,7 +64,7 @@ fn benchmark(c: &mut Criterion) { while sum != expected_sum { match ws.read().unwrap() { Message::Binary(v) => { - let a: &[u8; 8] = v.as_slice().try_into().unwrap(); + let a: &[u8; 8] = v.as_ref().try_into().unwrap(); sum += u64::from_le_bytes(*a); } Message::Text(msg) => { diff --git a/src/lib.rs b/src/lib.rs index 6b79d5be..be169d40 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -37,8 +37,10 @@ type ReadBuffer = buffer::ReadBuffer; pub use crate::{ error::{Error, Result}, - protocol::{Message, WebSocket}, + protocol::{frame::Utf8Bytes, Message, WebSocket}, }; +// re-export bytes since used in `Message` API. +pub use bytes::Bytes; #[cfg(feature = "handshake")] pub use crate::{ diff --git a/src/protocol/frame/frame.rs b/src/protocol/frame/frame.rs index 11cfe7fd..31a4c004 100644 --- a/src/protocol/frame/frame.rs +++ b/src/protocol/frame/frame.rs @@ -1,44 +1,35 @@ +use byteorder::{NetworkEndian, ReadBytesExt}; +use log::*; use std::{ - borrow::Cow, default::Default, fmt, io::{Cursor, ErrorKind, Read, Write}, + mem, result::Result as StdResult, str::Utf8Error, string::String, }; -use byteorder::{NetworkEndian, ReadBytesExt}; -use log::*; - use super::{ coding::{CloseCode, Control, Data, OpCode}, mask::{apply_mask, generate_mask}, - Payload, }; use crate::{ error::{Error, ProtocolError, Result}, - protocol::frame::Utf8Payload, + protocol::frame::Utf8Bytes, }; -use bytes::{Buf, BytesMut}; +use bytes::{Bytes, BytesMut}; /// A struct representing the close command. #[derive(Debug, Clone, Eq, PartialEq)] -pub struct CloseFrame<'t> { +pub struct CloseFrame { /// The reason as a code. pub code: CloseCode, /// The reason as text string. - pub reason: Cow<'t, str>, + pub reason: Utf8Bytes, } -impl CloseFrame<'_> { - /// Convert into a owned string. - pub fn into_owned(self) -> CloseFrame<'static> { - CloseFrame { code: self.code, reason: self.reason.into_owned().into() } - } -} - -impl fmt::Display for CloseFrame<'_> { +impl fmt::Display for CloseFrame { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "{} ({})", self.reason, self.code) } @@ -213,7 +204,7 @@ impl FrameHeader { #[derive(Debug, Clone, Eq, PartialEq)] pub struct Frame { header: FrameHeader, - payload: Payload, + payload: Bytes, } impl Frame { @@ -246,13 +237,7 @@ impl Frame { /// Get a reference to the frame's payload. #[inline] pub fn payload(&self) -> &[u8] { - self.payload.as_slice() - } - - /// Get a mutable reference to the frame's payload. - #[inline] - pub fn payload_mut(&mut self) -> &mut [u8] { - self.payload.as_mut_slice() + &self.payload } /// Test whether the frame is masked. @@ -270,52 +255,41 @@ impl Frame { self.header.set_random_mask(); } - /// This method unmasks the payload and should only be called on frames that are actually - /// masked. In other words, those frames that have just been received from a client endpoint. - #[inline] - pub(crate) fn apply_mask(&mut self) { - if let Some(mask) = self.header.mask.take() { - apply_mask(self.payload.as_mut_slice(), mask); - } - } - /// Consume the frame into its payload as string. #[inline] - pub fn into_text(self) -> StdResult { - self.payload.into_text() + pub fn into_text(self) -> StdResult { + self.payload.try_into() } /// Consume the frame into its payload. #[inline] - pub fn into_payload(self) -> Payload { + pub fn into_payload(self) -> Bytes { self.payload } /// Get frame payload as `&str`. #[inline] pub fn to_text(&self) -> Result<&str, Utf8Error> { - std::str::from_utf8(self.payload.as_slice()) + std::str::from_utf8(&self.payload) } /// Consume the frame into a closing frame. #[inline] - pub(crate) fn into_close(self) -> Result>> { + pub(crate) fn into_close(self) -> Result> { match self.payload.len() { 0 => Ok(None), 1 => Err(Error::Protocol(ProtocolError::InvalidCloseSequence)), _ => { - let mut data = self.payload.as_slice(); - let code = u16::from_be_bytes([data[0], data[1]]).into(); - data.advance(2); - let text = String::from_utf8(data.to_vec())?; - Ok(Some(CloseFrame { code, reason: text.into() })) + let code = u16::from_be_bytes([self.payload[0], self.payload[1]]).into(); + let reason = Utf8Bytes::try_from(self.payload.slice(2..))?; + Ok(Some(CloseFrame { code, reason })) } } } /// Create a new data frame. #[inline] - pub fn message(data: impl Into, opcode: OpCode, is_final: bool) -> Frame { + pub fn message(data: impl Into, opcode: OpCode, is_final: bool) -> Frame { debug_assert!(matches!(opcode, OpCode::Data(_)), "Invalid opcode for data frame."); Frame { header: FrameHeader { is_final, opcode, ..FrameHeader::default() }, @@ -325,7 +299,7 @@ impl Frame { /// Create a new Pong control frame. #[inline] - pub fn pong(data: impl Into) -> Frame { + pub fn pong(data: impl Into) -> Frame { Frame { header: FrameHeader { opcode: OpCode::Control(Control::Pong), @@ -337,7 +311,7 @@ impl Frame { /// Create a new Ping control frame. #[inline] - pub fn ping(data: impl Into) -> Frame { + pub fn ping(data: impl Into) -> Frame { Frame { header: FrameHeader { opcode: OpCode::Control(Control::Ping), @@ -359,19 +333,39 @@ impl Frame { <_>::default() }; - Frame { header: FrameHeader::default(), payload: Payload::Owned(payload) } + Frame { header: FrameHeader::default(), payload: payload.into() } } /// Create a frame from given header and data. - pub fn from_payload(header: FrameHeader, payload: Payload) -> Self { + pub fn from_payload(header: FrameHeader, payload: Bytes) -> Self { Frame { header, payload } } /// Write a frame out to a buffer pub fn format(mut self, output: &mut impl Write) -> Result<()> { self.header.format(self.payload.len() as u64, output)?; - self.apply_mask(); - output.write_all(self.payload())?; + + if let Some(mask) = self.header.mask.take() { + let mut data = Vec::from(mem::take(&mut self.payload)); + apply_mask(&mut data, mask); + output.write_all(&data)?; + } else { + output.write_all(&self.payload)?; + } + + Ok(()) + } + + pub(crate) fn format_into_buf(mut self, buf: &mut Vec) -> Result<()> { + self.header.format(self.payload.len() as u64, buf)?; + + let len = buf.len(); + buf.extend_from_slice(&self.payload); + + if let Some(mask) = self.header.mask.take() { + apply_mask(&mut buf[len..], mask); + } + Ok(()) } } @@ -399,7 +393,7 @@ payload: 0x{} // self.mask.map(|mask| format!("{:?}", mask)).unwrap_or("NONE".into()), self.len(), self.payload.len(), - self.payload.as_slice().iter().fold(String::new(), |mut output, byte| { + self.payload.iter().fold(String::new(), |mut output, byte| { _ = write!(output, "{byte:02x}"); output }) @@ -474,7 +468,7 @@ mod tests { let mut payload = Vec::new(); raw.read_to_end(&mut payload).unwrap(); let frame = Frame::from_payload(header, payload.into()); - assert_eq!(frame.into_payload(), &[0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]); + assert_eq!(frame.into_payload(), &[0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07][..]); } #[test] @@ -485,9 +479,17 @@ mod tests { assert_eq!(buf, vec![0x89, 0x02, 0x01, 0x02]); } + #[test] + fn format_into_buf() { + let frame = Frame::ping(vec![0x01, 0x02]); + let mut buf = Vec::with_capacity(frame.len()); + frame.format_into_buf(&mut buf).unwrap(); + assert_eq!(buf, vec![0x89, 0x02, 0x01, 0x02]); + } + #[test] fn display() { - let f = Frame::message(Payload::from_static(b"hi there"), OpCode::Data(Data::Text), true); + let f = Frame::message(Bytes::from_static(b"hi there"), OpCode::Data(Data::Text), true); let view = format!("{f}"); assert!(view.contains("payload:")); } diff --git a/src/protocol/frame/mod.rs b/src/protocol/frame/mod.rs index fd3b4524..2d76830d 100644 --- a/src/protocol/frame/mod.rs +++ b/src/protocol/frame/mod.rs @@ -5,15 +5,16 @@ pub mod coding; #[allow(clippy::module_inception)] mod frame; mod mask; -mod payload; +mod utf8; pub use self::{ frame::{CloseFrame, Frame, FrameHeader}, - payload::{Payload, Utf8Payload}, + utf8::Utf8Bytes, }; use crate::{ - error::{CapacityError, Error, Result}, + error::{CapacityError, Error, ProtocolError, Result}, + protocol::frame::mask::apply_mask, Message, }; use bytes::BytesMut; @@ -65,7 +66,7 @@ where { /// Read a frame from stream. pub fn read(&mut self, max_size: Option) -> Result> { - self.codec.read_frame(&mut self.stream, max_size) + self.codec.read_frame(&mut self.stream, max_size, false, true) } } @@ -158,13 +159,15 @@ impl FrameCodec { &mut self, stream: &mut Stream, max_size: Option, + unmask: bool, + accept_unmasked: bool, ) -> Result> where Stream: Read, { let max_size = max_size.unwrap_or_else(usize::max_value); - let payload = loop { + let mut payload = loop { { if self.header.is_none() { let mut cursor = Cursor::new(&mut self.in_buffer); @@ -205,9 +208,24 @@ impl FrameCodec { } }; - let (header, length) = self.header.take().expect("Bug: no frame header"); + let (mut header, length) = self.header.take().expect("Bug: no frame header"); debug_assert_eq!(payload.len() as u64, length); - let frame = Frame::from_payload(header, Payload::Owned(payload)); + + if unmask { + if let Some(mask) = header.mask.take() { + // A server MUST remove masking for data frames received from a client + // as described in Section 5.3. (RFC 6455) + apply_mask(&mut payload, mask); + } else if !accept_unmasked { + // The server MUST close the connection upon receiving a + // frame that is not masked. (RFC 6455) + // The only exception here is if the user explicitly accepts given + // stream by setting WebSocketConfig.accept_unmasked_frames to true + return Err(Error::Protocol(ProtocolError::UnmaskedFrameFromClient)); + } + } + + let frame = Frame::from_payload(header, payload.freeze()); trace!("received frame {frame}"); Ok(Some(frame)) } @@ -230,7 +248,7 @@ impl FrameCodec { trace!("writing frame {frame}"); self.out_buffer.reserve(frame.len()); - frame.format(&mut self.out_buffer).expect("Bug: can't write to vector"); + frame.format_into_buf(&mut self.out_buffer).expect("Bug: can't write to vector"); if self.out_buffer.len() > self.out_buffer_write_len { self.write_out_buffer(stream) @@ -284,9 +302,9 @@ mod tests { assert_eq!( sock.read(None).unwrap().unwrap().into_payload(), - &[0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07] + &[0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07][..] ); - assert_eq!(sock.read(None).unwrap().unwrap().into_payload(), &[0x03, 0x02, 0x01]); + assert_eq!(sock.read(None).unwrap().unwrap().into_payload(), &[0x03, 0x02, 0x01][..]); assert!(sock.read(None).unwrap().is_none()); let (_, rest) = sock.into_inner(); @@ -299,7 +317,7 @@ mod tests { let mut sock = FrameSocket::from_partially_read(raw, vec![0x82, 0x07, 0x01]); assert_eq!( sock.read(None).unwrap().unwrap().into_payload(), - &[0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07] + &[0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07][..] ); } diff --git a/src/protocol/frame/payload.rs b/src/protocol/frame/payload.rs deleted file mode 100644 index 5c70e5f0..00000000 --- a/src/protocol/frame/payload.rs +++ /dev/null @@ -1,306 +0,0 @@ -use bytes::{Bytes, BytesMut}; -use core::str; -use std::{fmt::Display, mem}; - -/// Utf8 payload. -#[derive(Debug, Default, Clone, Eq, PartialEq)] -pub struct Utf8Payload(Payload); - -impl Utf8Payload { - /// Creates from a static str. - #[inline] - pub const fn from_static(str: &'static str) -> Self { - Self(Payload::Shared(Bytes::from_static(str.as_bytes()))) - } - - /// Returns a slice of the payload. - #[inline] - pub fn as_slice(&self) -> &[u8] { - self.0.as_slice() - } - - /// Returns as a string slice. - #[inline] - pub fn as_str(&self) -> &str { - // safety: is valid uft8 - unsafe { str::from_utf8_unchecked(self.as_slice()) } - } - - /// Returns length in bytes. - #[inline] - pub fn len(&self) -> usize { - self.as_slice().len() - } - - /// Returns true if the length is 0. - #[inline] - pub fn is_empty(&self) -> bool { - self.len() == 0 - } - - /// If owned converts into [`Bytes`] internals & then clones (cheaply). - #[inline] - pub fn share(&mut self) -> Self { - Self(self.0.share()) - } -} - -impl TryFrom for Utf8Payload { - type Error = str::Utf8Error; - - #[inline] - fn try_from(payload: Payload) -> Result { - str::from_utf8(payload.as_slice())?; - Ok(Self(payload)) - } -} - -impl TryFrom for Utf8Payload { - type Error = str::Utf8Error; - - #[inline] - fn try_from(bytes: Bytes) -> Result { - Payload::from(bytes).try_into() - } -} - -impl TryFrom for Utf8Payload { - type Error = str::Utf8Error; - - #[inline] - fn try_from(bytes: BytesMut) -> Result { - Payload::from(bytes).try_into() - } -} - -impl TryFrom> for Utf8Payload { - type Error = str::Utf8Error; - - #[inline] - fn try_from(bytes: Vec) -> Result { - Payload::from(bytes).try_into() - } -} - -impl From for Utf8Payload { - #[inline] - fn from(s: String) -> Self { - Self(s.into()) - } -} - -impl From<&str> for Utf8Payload { - #[inline] - fn from(s: &str) -> Self { - Self(Payload::Owned(s.as_bytes().into())) - } -} - -impl From<&String> for Utf8Payload { - #[inline] - fn from(s: &String) -> Self { - s.as_str().into() - } -} - -impl From for Payload { - #[inline] - fn from(Utf8Payload(payload): Utf8Payload) -> Self { - payload - } -} - -impl Display for Utf8Payload { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.write_str(self.as_str()) - } -} - -impl AsRef for Utf8Payload { - #[inline] - fn as_ref(&self) -> &str { - self.as_str() - } -} - -impl AsRef<[u8]> for Utf8Payload { - #[inline] - fn as_ref(&self) -> &[u8] { - self.as_slice() - } -} - -impl PartialEq for Utf8Payload -where - for<'a> &'a str: PartialEq, -{ - /// ``` - /// use tungstenite::protocol::frame::Utf8Payload; - /// let payload = Utf8Payload::from_static("foo123"); - /// assert_eq!(payload, "foo123"); - /// assert_eq!(payload, "foo123".to_string()); - /// assert_eq!(payload, &"foo123".to_string()); - /// assert_eq!(payload, std::borrow::Cow::from("foo123")); - /// ``` - #[inline] - fn eq(&self, other: &T) -> bool { - self.as_str() == *other - } -} - -/// A payload of a WebSocket frame. -#[derive(Debug, Clone)] -pub enum Payload { - /// Owned data with unique ownership. - Owned(BytesMut), - /// Shared data with shared ownership. - Shared(Bytes), - /// Owned vec data. - Vec(Vec), -} - -impl Payload { - /// Creates from static bytes. - #[inline] - pub const fn from_static(bytes: &'static [u8]) -> Self { - Self::Shared(Bytes::from_static(bytes)) - } - - /// Converts into [`Bytes`] internals & then clones (cheaply). - pub fn share(&mut self) -> Self { - match self { - Self::Owned(data) => { - *self = Self::Shared(mem::take(data).freeze()); - } - Self::Vec(data) => { - *self = Self::Shared(Bytes::from(mem::take(data))); - } - Self::Shared(_) => {} - } - self.clone() - } - - /// Returns a slice of the payload. - #[inline] - pub fn as_slice(&self) -> &[u8] { - match self { - Payload::Owned(v) => v, - Payload::Shared(v) => v, - Payload::Vec(v) => v, - } - } - - /// Returns a mutable slice of the payload. - /// - /// Note that this will internally allocate if the payload is shared - /// and there are other references to the same data. No allocation - /// would happen if the payload is owned or if there is only one - /// `Bytes` instance referencing the data. - #[inline] - pub fn as_mut_slice(&mut self) -> &mut [u8] { - match self { - Payload::Owned(v) => &mut *v, - Payload::Vec(v) => &mut *v, - Payload::Shared(v) => { - // Using `Bytes::to_vec()` or `Vec::from(bytes.as_ref())` would mean making a copy. - // `Bytes::into()` would not make a copy if our `Bytes` instance is the only one. - let data = mem::take(v).into(); - *self = Payload::Owned(data); - match self { - Payload::Owned(v) => v, - _ => unreachable!(), - } - } - } - } - - /// Returns the length of the payload. - #[inline] - pub fn len(&self) -> usize { - self.as_slice().len() - } - - /// Returns true if the payload has a length of 0. - #[inline] - pub fn is_empty(&self) -> bool { - self.len() == 0 - } - - /// Consumes the payload and returns the underlying data as a string. - #[inline] - pub fn into_text(self) -> Result { - self.try_into() - } -} - -impl Default for Payload { - #[inline] - fn default() -> Self { - Self::Owned(<_>::default()) - } -} - -impl From> for Payload { - #[inline] - fn from(v: Vec) -> Self { - Payload::Vec(v) - } -} - -impl From for Payload { - #[inline] - fn from(v: String) -> Self { - v.into_bytes().into() - } -} - -impl From for Payload { - #[inline] - fn from(v: Bytes) -> Self { - Payload::Shared(v) - } -} - -impl From for Payload { - #[inline] - fn from(v: BytesMut) -> Self { - Payload::Owned(v) - } -} - -impl From<&[u8]> for Payload { - #[inline] - fn from(v: &[u8]) -> Self { - Self::Owned(v.into()) - } -} - -impl PartialEq for Payload { - #[inline] - fn eq(&self, other: &Payload) -> bool { - self.as_slice() == other.as_slice() - } -} - -impl Eq for Payload {} - -impl PartialEq<[u8]> for Payload { - #[inline] - fn eq(&self, other: &[u8]) -> bool { - self.as_slice() == other - } -} - -impl PartialEq<&[u8; N]> for Payload { - #[inline] - fn eq(&self, other: &&[u8; N]) -> bool { - self.as_slice() == *other - } -} - -impl AsRef<[u8]> for Payload { - #[inline] - fn as_ref(&self) -> &[u8] { - self.as_slice() - } -} diff --git a/src/protocol/frame/utf8.rs b/src/protocol/frame/utf8.rs new file mode 100644 index 00000000..845c96bf --- /dev/null +++ b/src/protocol/frame/utf8.rs @@ -0,0 +1,123 @@ +use bytes::{Bytes, BytesMut}; +use core::str; +use std::fmt::Display; + +/// Utf8 payload. +#[derive(Debug, Default, Clone, Eq, PartialEq)] +pub struct Utf8Bytes(Bytes); + +impl Utf8Bytes { + /// Creates from a static str. + #[inline] + pub const fn from_static(str: &'static str) -> Self { + Self(Bytes::from_static(str.as_bytes())) + } + + /// Returns as a string slice. + #[inline] + pub fn as_str(&self) -> &str { + // SAFETY: is valid uft8 + unsafe { str::from_utf8_unchecked(&self.0) } + } +} + +impl std::ops::Deref for Utf8Bytes { + type Target = str; + + /// ``` + /// /// Example fn that takes a str slice + /// fn a(s: &str) {} + /// + /// let data = tungstenite::Utf8Bytes::from_static("foo123"); + /// + /// // auto-deref as arg + /// a(&data); + /// + /// // deref to str methods + /// assert_eq!(data.len(), 6); + /// ``` + #[inline] + fn deref(&self) -> &Self::Target { + self.as_str() + } +} + +impl PartialEq for Utf8Bytes +where + for<'a> &'a str: PartialEq, +{ + /// ``` + /// let payload = tungstenite::Utf8Bytes::from_static("foo123"); + /// assert_eq!(payload, "foo123"); + /// assert_eq!(payload, "foo123".to_string()); + /// assert_eq!(payload, &"foo123".to_string()); + /// assert_eq!(payload, std::borrow::Cow::from("foo123")); + /// ``` + #[inline] + fn eq(&self, other: &T) -> bool { + self.as_str() == *other + } +} + +impl Display for Utf8Bytes { + #[inline] + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str(self.as_str()) + } +} + +impl TryFrom for Utf8Bytes { + type Error = str::Utf8Error; + + #[inline] + fn try_from(bytes: Bytes) -> Result { + str::from_utf8(&bytes)?; + Ok(Self(bytes)) + } +} + +impl TryFrom for Utf8Bytes { + type Error = str::Utf8Error; + + #[inline] + fn try_from(bytes: BytesMut) -> Result { + bytes.freeze().try_into() + } +} + +impl TryFrom> for Utf8Bytes { + type Error = str::Utf8Error; + + #[inline] + fn try_from(v: Vec) -> Result { + Bytes::from(v).try_into() + } +} + +impl From for Utf8Bytes { + #[inline] + fn from(s: String) -> Self { + Self(s.into()) + } +} + +impl From<&str> for Utf8Bytes { + #[inline] + fn from(s: &str) -> Self { + Self(Bytes::copy_from_slice(s.as_bytes())) + } +} + +impl From<&String> for Utf8Bytes { + #[inline] + fn from(s: &String) -> Self { + s.as_str().into() + } +} + +impl From for Bytes { + #[inline] + fn from(Utf8Bytes(bytes): Utf8Bytes) -> Self { + bytes + } +} diff --git a/src/protocol/message.rs b/src/protocol/message.rs index ee098c09..57539e09 100644 --- a/src/protocol/message.rs +++ b/src/protocol/message.rs @@ -1,7 +1,9 @@ -use std::{borrow::Cow, fmt, result::Result as StdResult, str}; - -use super::frame::{CloseFrame, Frame, Payload, Utf8Payload}; -use crate::error::{CapacityError, Error, Result}; +use super::frame::{CloseFrame, Frame}; +use crate::{ + error::{CapacityError, Error, Result}, + protocol::frame::Utf8Bytes, +}; +use std::{fmt, result::Result as StdResult, str}; mod string_collect { use utf8::DecodeError; @@ -74,6 +76,7 @@ mod string_collect { } use self::string_collect::StringCollector; +use bytes::Bytes; /// A struct representing the incomplete message. #[derive(Debug)] @@ -154,19 +157,19 @@ pub enum IncompleteMessageType { #[derive(Debug, Eq, PartialEq, Clone)] pub enum Message { /// A text WebSocket message - Text(Utf8Payload), + Text(Utf8Bytes), /// A binary WebSocket message - Binary(Payload), + Binary(Bytes), /// A ping message with the specified payload /// /// The payload here must have a length less than 125 bytes - Ping(Payload), + Ping(Bytes), /// A pong message with the specified payload /// /// The payload here must have a length less than 125 bytes - Pong(Payload), + Pong(Bytes), /// A close message with the optional close frame. - Close(Option>), + Close(Option), /// Raw frame. Note, that you're not going to get this value while reading the message. Frame(Frame), } @@ -175,7 +178,7 @@ impl Message { /// Create a new text WebSocket message from a stringable. pub fn text(string: S) -> Message where - S: Into, + S: Into, { Message::Text(string.into()) } @@ -183,7 +186,7 @@ impl Message { /// Create a new binary WebSocket message by converting to `Vec`. pub fn binary(bin: B) -> Message where - B: Into, + B: Into, { Message::Binary(bin.into()) } @@ -232,31 +235,25 @@ impl Message { } /// Consume the WebSocket and return it as binary data. - pub fn into_data(self) -> Payload { + pub fn into_data(self) -> Bytes { match self { - Message::Text(string) => string.into(), + Message::Text(utf8) => utf8.into(), Message::Binary(data) | Message::Ping(data) | Message::Pong(data) => data, Message::Close(None) => <_>::default(), - Message::Close(Some(frame)) => match frame.reason { - Cow::Borrowed(s) => Payload::from_static(s.as_bytes()), - Cow::Owned(s) => s.into(), - }, + Message::Close(Some(frame)) => frame.reason.into(), Message::Frame(frame) => frame.into_payload(), } } /// Attempt to consume the WebSocket message and convert it to a String. - pub fn into_text(self) -> Result { + pub fn into_text(self) -> Result { match self { Message::Text(txt) => Ok(txt), Message::Binary(data) | Message::Ping(data) | Message::Pong(data) => { Ok(data.try_into()?) } Message::Close(None) => Ok(<_>::default()), - Message::Close(Some(frame)) => Ok(match frame.reason { - Cow::Borrowed(s) => Utf8Payload::from_static(s), - Cow::Owned(s) => s.into(), - }), + Message::Close(Some(frame)) => Ok(frame.reason), Message::Frame(frame) => Ok(frame.into_text()?), } } @@ -267,7 +264,7 @@ impl Message { match *self { Message::Text(ref string) => Ok(string.as_str()), Message::Binary(ref data) | Message::Ping(ref data) | Message::Pong(ref data) => { - Ok(str::from_utf8(data.as_slice())?) + Ok(str::from_utf8(data)?) } Message::Close(None) => Ok(""), Message::Close(Some(ref frame)) => Ok(&frame.reason), @@ -293,7 +290,7 @@ impl<'s> From<&'s str> for Message { impl<'b> From<&'b [u8]> for Message { #[inline] fn from(data: &'b [u8]) -> Self { - Message::binary(data) + Message::binary(Bytes::copy_from_slice(data)) } } @@ -304,10 +301,10 @@ impl From> for Message { } } -impl From for Vec { +impl From for Bytes { #[inline] fn from(message: Message) -> Self { - message.into_data().as_slice().into() + message.into_data() } } @@ -351,11 +348,11 @@ mod tests { } #[test] - fn binary_convert_into_vec() { + fn binary_convert_into_bytes() { let bin = vec![6u8, 7, 8, 9, 10, 241]; let bin_copy = bin.clone(); let msg = Message::from(bin); - let serialized: Vec = msg.into(); + let serialized: Bytes = msg.into(); assert_eq!(bin_copy, serialized); } diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs index d057121e..6d316941 100644 --- a/src/protocol/mod.rs +++ b/src/protocol/mod.rs @@ -13,7 +13,10 @@ use self::{ }, message::{IncompleteMessage, IncompleteMessageType}, }; -use crate::error::{CapacityError, Error, ProtocolError, Result}; +use crate::{ + error::{CapacityError, Error, ProtocolError, Result}, + protocol::frame::Utf8Bytes, +}; use log::*; use std::{ io::{self, Read, Write}, @@ -586,13 +589,15 @@ impl WebSocketContext { } /// Try to decode one message frame. May return None. - fn read_message_frame(&mut self, stream: &mut Stream) -> Result> - where - Stream: Read + Write, - { - if let Some(mut frame) = self + fn read_message_frame(&mut self, stream: &mut impl Read) -> Result> { + if let Some(frame) = self .frame - .read_frame(stream, self.config.max_frame_size) + .read_frame( + stream, + self.config.max_frame_size, + matches!(self.role, Role::Server), + self.config.accept_unmasked_frames, + ) .check_connection_reset(self.state)? { if !self.state.can_read() { @@ -610,26 +615,9 @@ impl WebSocketContext { } } - match self.role { - Role::Server => { - if frame.is_masked() { - // A server MUST remove masking for data frames received from a client - // as described in Section 5.3. (RFC 6455) - frame.apply_mask(); - } else if !self.config.accept_unmasked_frames { - // The server MUST close the connection upon receiving a - // frame that is not masked. (RFC 6455) - // The only exception here is if the user explicitly accepts given - // stream by setting WebSocketConfig.accept_unmasked_frames to true - return Err(Error::Protocol(ProtocolError::UnmaskedFrameFromClient)); - } - } - Role::Client => { - if frame.is_masked() { - // A client MUST close a connection if it detects a masked frame. (RFC 6455) - return Err(Error::Protocol(ProtocolError::MaskedFrameFromServer)); - } - } + if self.role == Role::Client && frame.is_masked() { + // A client MUST close a connection if it detects a masked frame. (RFC 6455) + return Err(Error::Protocol(ProtocolError::MaskedFrameFromServer)); } match frame.header().opcode { @@ -648,10 +636,10 @@ impl WebSocketContext { Err(Error::Protocol(ProtocolError::UnknownControlFrameType(i))) } OpCtl::Ping => { - let mut data = frame.into_payload(); + let data = frame.into_payload(); // No ping processing after we sent a close frame. if self.state.is_active() { - self.set_additional(Frame::pong(data.share())); + self.set_additional(Frame::pong(data.clone())); } Ok(Some(Message::Ping(data))) } @@ -718,7 +706,7 @@ impl WebSocketContext { /// Received a close frame. Tells if we need to return a close frame to the user. #[allow(clippy::option_option)] - fn do_close<'t>(&mut self, close: Option>) -> Option>> { + fn do_close(&mut self, close: Option) -> Option> { debug!("Received close frame: {close:?}"); match self.state { WebSocketState::Active => { @@ -728,7 +716,7 @@ impl WebSocketContext { if !frame.code.is_allowed() { CloseFrame { code: CloseCode::Protocol, - reason: "Protocol violation".into(), + reason: Utf8Bytes::from_static("Protocol violation"), } } else { frame diff --git a/tests/client_headers.rs b/tests/client_headers.rs index f943f8ae..2f9693b4 100644 --- a/tests/client_headers.rs +++ b/tests/client_headers.rs @@ -80,7 +80,7 @@ fn test_headers() { // This read should succeed even though we already initiated a close let message = client_handler.read().unwrap(); - assert_eq!(message.into_data(), b"Hello WebSocket"); + assert_eq!(message.into_data(), b"Hello WebSocket"[..]); assert!(client_handler.read().unwrap().is_close()); // receive acknowledgement diff --git a/tests/connection_reset.rs b/tests/connection_reset.rs index 8c55d7e6..5357eedc 100644 --- a/tests/connection_reset.rs +++ b/tests/connection_reset.rs @@ -64,7 +64,7 @@ fn test_server_close() { }, |mut srv_sock| { let message = srv_sock.read().unwrap(); - assert_eq!(message.into_data(), b"Hello WebSocket"); + assert_eq!(message.into_data(), b"Hello WebSocket"[..]); srv_sock.close(None).unwrap(); // send close to client @@ -100,7 +100,7 @@ fn test_evil_server_close() { }, |mut srv_sock| { let message = srv_sock.read().unwrap(); - assert_eq!(message.into_data(), b"Hello WebSocket"); + assert_eq!(message.into_data(), b"Hello WebSocket"[..]); srv_sock.close(None).unwrap(); // send close to client @@ -121,7 +121,7 @@ fn test_client_close() { cli_sock.send(Message::Text("Hello WebSocket".into())).unwrap(); let message = cli_sock.read().unwrap(); // receive answer from server - assert_eq!(message.into_data(), b"From Server"); + assert_eq!(message.into_data(), b"From Server"[..]); cli_sock.close(None).unwrap(); // send close to server @@ -136,7 +136,7 @@ fn test_client_close() { }, |mut srv_sock| { let message = srv_sock.read().unwrap(); - assert_eq!(message.into_data(), b"Hello WebSocket"); + assert_eq!(message.into_data(), b"Hello WebSocket"[..]); srv_sock.send(Message::Text("From Server".into())).unwrap(); diff --git a/tests/receive_after_init_close.rs b/tests/receive_after_init_close.rs index 3dd0fb02..dfc1db4c 100644 --- a/tests/receive_after_init_close.rs +++ b/tests/receive_after_init_close.rs @@ -46,7 +46,7 @@ fn test_receive_after_init_close() { // This read should succeed even though we already initiated a close let message = client_handler.read().unwrap(); - assert_eq!(message.into_data(), b"Hello WebSocket"); + assert_eq!(message.into_data(), b"Hello WebSocket"[..]); assert!(client_handler.read().unwrap().is_close()); // receive acknowledgement