From b2449d788be10239dcb465407a2858c70269577e Mon Sep 17 00:00:00 2001 From: Alex Butler Date: Tue, 17 Dec 2024 10:58:55 +0000 Subject: [PATCH 01/10] Rework Payload to only use Bytes --- src/protocol/frame/frame.rs | 14 ++--- src/protocol/frame/mod.rs | 2 +- src/protocol/frame/payload.rs | 96 +++++++++++++++++++---------------- 3 files changed, 60 insertions(+), 52 deletions(-) diff --git a/src/protocol/frame/frame.rs b/src/protocol/frame/frame.rs index 11cfe7fd..c23092c5 100644 --- a/src/protocol/frame/frame.rs +++ b/src/protocol/frame/frame.rs @@ -249,11 +249,11 @@ impl Frame { self.payload.as_slice() } - /// Get a mutable reference to the frame's payload. - #[inline] - pub fn payload_mut(&mut self) -> &mut [u8] { - self.payload.as_mut_slice() - } + // /// Get a mutable reference to the frame's payload. + // #[inline] + // pub fn payload_mut(&mut self) -> &mut [u8] { + // self.payload.as_mut_slice() + // } /// Test whether the frame is masked. #[inline] @@ -275,7 +275,7 @@ impl Frame { #[inline] pub(crate) fn apply_mask(&mut self) { if let Some(mask) = self.header.mask.take() { - apply_mask(self.payload.as_mut_slice(), mask); + self.payload.mutate(|data| apply_mask(data, mask)); } } @@ -359,7 +359,7 @@ impl Frame { <_>::default() }; - Frame { header: FrameHeader::default(), payload: Payload::Owned(payload) } + Frame { header: FrameHeader::default(), payload: payload.into() } } /// Create a frame from given header and data. diff --git a/src/protocol/frame/mod.rs b/src/protocol/frame/mod.rs index fd3b4524..87bec71c 100644 --- a/src/protocol/frame/mod.rs +++ b/src/protocol/frame/mod.rs @@ -207,7 +207,7 @@ impl FrameCodec { let (header, length) = self.header.take().expect("Bug: no frame header"); debug_assert_eq!(payload.len() as u64, length); - let frame = Frame::from_payload(header, Payload::Owned(payload)); + let frame = Frame::from_payload(header, payload.into()); trace!("received frame {frame}"); Ok(Some(frame)) } diff --git a/src/protocol/frame/payload.rs b/src/protocol/frame/payload.rs index 5c70e5f0..a1b85ee3 100644 --- a/src/protocol/frame/payload.rs +++ b/src/protocol/frame/payload.rs @@ -1,6 +1,6 @@ use bytes::{Bytes, BytesMut}; use core::str; -use std::{fmt::Display, mem}; +use std::fmt::Display; /// Utf8 payload. #[derive(Debug, Default, Clone, Eq, PartialEq)] @@ -92,7 +92,7 @@ impl From for Utf8Payload { impl From<&str> for Utf8Payload { #[inline] fn from(s: &str) -> Self { - Self(Payload::Owned(s.as_bytes().into())) + Self(s.as_bytes().into()) } } @@ -151,12 +151,12 @@ where /// A payload of a WebSocket frame. #[derive(Debug, Clone)] pub enum Payload { - /// Owned data with unique ownership. - Owned(BytesMut), + // /// Owned data with unique ownership. + // Owned(BytesMut), /// Shared data with shared ownership. Shared(Bytes), - /// Owned vec data. - Vec(Vec), + // /// Owned vec data. + // Vec(Vec), } impl Payload { @@ -166,17 +166,25 @@ impl Payload { Self::Shared(Bytes::from_static(bytes)) } + #[inline] + pub(crate) fn mutate(&mut self, f: impl FnOnce(&mut [u8])) { + let Self::Shared(bytes) = self; + let mut bytes_mut = BytesMut::from(std::mem::take(bytes)); + f(&mut bytes_mut); + *bytes = bytes_mut.freeze(); + } + /// Converts into [`Bytes`] internals & then clones (cheaply). pub fn share(&mut self) -> Self { - match self { - Self::Owned(data) => { - *self = Self::Shared(mem::take(data).freeze()); - } - Self::Vec(data) => { - *self = Self::Shared(Bytes::from(mem::take(data))); - } - Self::Shared(_) => {} - } + // match self { + // Self::Owned(data) => { + // *self = Self::Shared(mem::take(data).freeze()); + // } + // // Self::Vec(data) => { + // // *self = Self::Shared(Bytes::from(mem::take(data))); + // // } + // Self::Shared(_) => {} + // } self.clone() } @@ -184,35 +192,35 @@ impl Payload { #[inline] pub fn as_slice(&self) -> &[u8] { match self { - Payload::Owned(v) => v, + // Payload::Owned(v) => v, Payload::Shared(v) => v, - Payload::Vec(v) => v, + // Payload::Vec(v) => v, } } - /// Returns a mutable slice of the payload. - /// - /// Note that this will internally allocate if the payload is shared - /// and there are other references to the same data. No allocation - /// would happen if the payload is owned or if there is only one - /// `Bytes` instance referencing the data. - #[inline] - pub fn as_mut_slice(&mut self) -> &mut [u8] { - match self { - Payload::Owned(v) => &mut *v, - Payload::Vec(v) => &mut *v, - Payload::Shared(v) => { - // Using `Bytes::to_vec()` or `Vec::from(bytes.as_ref())` would mean making a copy. - // `Bytes::into()` would not make a copy if our `Bytes` instance is the only one. - let data = mem::take(v).into(); - *self = Payload::Owned(data); - match self { - Payload::Owned(v) => v, - _ => unreachable!(), - } - } - } - } + // /// Returns a mutable slice of the payload. + // /// + // /// Note that this will internally allocate if the payload is shared + // /// and there are other references to the same data. No allocation + // /// would happen if the payload is owned or if there is only one + // /// `Bytes` instance referencing the data. + // #[inline] + // pub fn as_mut_slice(&mut self) -> &mut [u8] { + // match self { + // Payload::Owned(v) => &mut *v, + // // Payload::Vec(v) => &mut *v, + // Payload::Shared(v) => { + // // Using `Bytes::to_vec()` or `Vec::from(bytes.as_ref())` would mean making a copy. + // // `Bytes::into()` would not make a copy if our `Bytes` instance is the only one. + // let data = mem::take(v).into(); + // *self = Payload::Owned(data); + // match self { + // Payload::Owned(v) => v, + // _ => unreachable!(), + // } + // } + // } + // } /// Returns the length of the payload. #[inline] @@ -236,14 +244,14 @@ impl Payload { impl Default for Payload { #[inline] fn default() -> Self { - Self::Owned(<_>::default()) + Self::Shared(<_>::default()) } } impl From> for Payload { #[inline] fn from(v: Vec) -> Self { - Payload::Vec(v) + Payload::Shared(Bytes::from(v)) } } @@ -264,14 +272,14 @@ impl From for Payload { impl From for Payload { #[inline] fn from(v: BytesMut) -> Self { - Payload::Owned(v) + Payload::Shared(v.freeze()) } } impl From<&[u8]> for Payload { #[inline] fn from(v: &[u8]) -> Self { - Self::Owned(v.into()) + Self::Shared(Bytes::copy_from_slice(v)) } } From c117f1a544c2472d1b9ab505518dba079e5c98ca Mon Sep 17 00:00:00 2001 From: Alex Butler Date: Tue, 17 Dec 2024 14:41:07 +0000 Subject: [PATCH 02/10] Use Bytes directly in Message --- benches/read.rs | 2 +- src/protocol/frame/frame.rs | 37 ++-- src/protocol/frame/mod.rs | 38 +++- src/protocol/frame/payload.rs | 314 ------------------------------ src/protocol/frame/utf8.rs | 94 +++++++++ src/protocol/message.rs | 41 ++-- src/protocol/mod.rs | 41 ++-- tests/client_headers.rs | 2 +- tests/connection_reset.rs | 8 +- tests/receive_after_init_close.rs | 2 +- 10 files changed, 183 insertions(+), 396 deletions(-) delete mode 100644 src/protocol/frame/payload.rs create mode 100644 src/protocol/frame/utf8.rs diff --git a/benches/read.rs b/benches/read.rs index 252ef3e3..75b405bf 100644 --- a/benches/read.rs +++ b/benches/read.rs @@ -64,7 +64,7 @@ fn benchmark(c: &mut Criterion) { while sum != expected_sum { match ws.read().unwrap() { Message::Binary(v) => { - let a: &[u8; 8] = v.as_slice().try_into().unwrap(); + let a: &[u8; 8] = v.as_ref().try_into().unwrap(); sum += u64::from_le_bytes(*a); } Message::Text(msg) => { diff --git a/src/protocol/frame/frame.rs b/src/protocol/frame/frame.rs index c23092c5..6875cf4b 100644 --- a/src/protocol/frame/frame.rs +++ b/src/protocol/frame/frame.rs @@ -14,13 +14,12 @@ use log::*; use super::{ coding::{CloseCode, Control, Data, OpCode}, mask::{apply_mask, generate_mask}, - Payload, }; use crate::{ error::{Error, ProtocolError, Result}, - protocol::frame::Utf8Payload, + protocol::frame::Utf8Bytes, }; -use bytes::{Buf, BytesMut}; +use bytes::{Buf, Bytes, BytesMut}; /// A struct representing the close command. #[derive(Debug, Clone, Eq, PartialEq)] @@ -213,7 +212,7 @@ impl FrameHeader { #[derive(Debug, Clone, Eq, PartialEq)] pub struct Frame { header: FrameHeader, - payload: Payload, + payload: Bytes, } impl Frame { @@ -246,7 +245,7 @@ impl Frame { /// Get a reference to the frame's payload. #[inline] pub fn payload(&self) -> &[u8] { - self.payload.as_slice() + &self.payload } // /// Get a mutable reference to the frame's payload. @@ -275,26 +274,28 @@ impl Frame { #[inline] pub(crate) fn apply_mask(&mut self) { if let Some(mask) = self.header.mask.take() { - self.payload.mutate(|data| apply_mask(data, mask)); + let mut bytes_mut = BytesMut::from(std::mem::take(&mut self.payload)); + apply_mask(&mut bytes_mut, mask); + self.payload = bytes_mut.freeze(); } } /// Consume the frame into its payload as string. #[inline] - pub fn into_text(self) -> StdResult { - self.payload.into_text() + pub fn into_text(self) -> StdResult { + self.payload.try_into() } /// Consume the frame into its payload. #[inline] - pub fn into_payload(self) -> Payload { + pub fn into_payload(self) -> Bytes { self.payload } /// Get frame payload as `&str`. #[inline] pub fn to_text(&self) -> Result<&str, Utf8Error> { - std::str::from_utf8(self.payload.as_slice()) + std::str::from_utf8(&self.payload) } /// Consume the frame into a closing frame. @@ -304,7 +305,7 @@ impl Frame { 0 => Ok(None), 1 => Err(Error::Protocol(ProtocolError::InvalidCloseSequence)), _ => { - let mut data = self.payload.as_slice(); + let mut data = BytesMut::from(self.payload); let code = u16::from_be_bytes([data[0], data[1]]).into(); data.advance(2); let text = String::from_utf8(data.to_vec())?; @@ -315,7 +316,7 @@ impl Frame { /// Create a new data frame. #[inline] - pub fn message(data: impl Into, opcode: OpCode, is_final: bool) -> Frame { + pub fn message(data: impl Into, opcode: OpCode, is_final: bool) -> Frame { debug_assert!(matches!(opcode, OpCode::Data(_)), "Invalid opcode for data frame."); Frame { header: FrameHeader { is_final, opcode, ..FrameHeader::default() }, @@ -325,7 +326,7 @@ impl Frame { /// Create a new Pong control frame. #[inline] - pub fn pong(data: impl Into) -> Frame { + pub fn pong(data: impl Into) -> Frame { Frame { header: FrameHeader { opcode: OpCode::Control(Control::Pong), @@ -337,7 +338,7 @@ impl Frame { /// Create a new Ping control frame. #[inline] - pub fn ping(data: impl Into) -> Frame { + pub fn ping(data: impl Into) -> Frame { Frame { header: FrameHeader { opcode: OpCode::Control(Control::Ping), @@ -363,7 +364,7 @@ impl Frame { } /// Create a frame from given header and data. - pub fn from_payload(header: FrameHeader, payload: Payload) -> Self { + pub fn from_payload(header: FrameHeader, payload: Bytes) -> Self { Frame { header, payload } } @@ -399,7 +400,7 @@ payload: 0x{} // self.mask.map(|mask| format!("{:?}", mask)).unwrap_or("NONE".into()), self.len(), self.payload.len(), - self.payload.as_slice().iter().fold(String::new(), |mut output, byte| { + self.payload.iter().fold(String::new(), |mut output, byte| { _ = write!(output, "{byte:02x}"); output }) @@ -474,7 +475,7 @@ mod tests { let mut payload = Vec::new(); raw.read_to_end(&mut payload).unwrap(); let frame = Frame::from_payload(header, payload.into()); - assert_eq!(frame.into_payload(), &[0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]); + assert_eq!(frame.into_payload(), &[0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07][..]); } #[test] @@ -487,7 +488,7 @@ mod tests { #[test] fn display() { - let f = Frame::message(Payload::from_static(b"hi there"), OpCode::Data(Data::Text), true); + let f = Frame::message(Bytes::from_static(b"hi there"), OpCode::Data(Data::Text), true); let view = format!("{f}"); assert!(view.contains("payload:")); } diff --git a/src/protocol/frame/mod.rs b/src/protocol/frame/mod.rs index 87bec71c..d13161b2 100644 --- a/src/protocol/frame/mod.rs +++ b/src/protocol/frame/mod.rs @@ -5,15 +5,16 @@ pub mod coding; #[allow(clippy::module_inception)] mod frame; mod mask; -mod payload; +mod utf8; pub use self::{ frame::{CloseFrame, Frame, FrameHeader}, - payload::{Payload, Utf8Payload}, + utf8::Utf8Bytes, }; use crate::{ - error::{CapacityError, Error, Result}, + error::{CapacityError, Error, ProtocolError, Result}, + protocol::frame::mask::apply_mask, Message, }; use bytes::BytesMut; @@ -65,7 +66,7 @@ where { /// Read a frame from stream. pub fn read(&mut self, max_size: Option) -> Result> { - self.codec.read_frame(&mut self.stream, max_size) + self.codec.read_frame(&mut self.stream, max_size, false, true) } } @@ -158,13 +159,15 @@ impl FrameCodec { &mut self, stream: &mut Stream, max_size: Option, + unmask: bool, + accept_unmasked: bool, ) -> Result> where Stream: Read, { let max_size = max_size.unwrap_or_else(usize::max_value); - let payload = loop { + let mut payload = loop { { if self.header.is_none() { let mut cursor = Cursor::new(&mut self.in_buffer); @@ -205,9 +208,24 @@ impl FrameCodec { } }; - let (header, length) = self.header.take().expect("Bug: no frame header"); + let (mut header, length) = self.header.take().expect("Bug: no frame header"); debug_assert_eq!(payload.len() as u64, length); - let frame = Frame::from_payload(header, payload.into()); + + if unmask { + if let Some(mask) = header.mask.take() { + // A server MUST remove masking for data frames received from a client + // as described in Section 5.3. (RFC 6455) + apply_mask(&mut payload, mask); + } else if !accept_unmasked { + // The server MUST close the connection upon receiving a + // frame that is not masked. (RFC 6455) + // The only exception here is if the user explicitly accepts given + // stream by setting WebSocketConfig.accept_unmasked_frames to true + return Err(Error::Protocol(ProtocolError::UnmaskedFrameFromClient)); + } + } + + let frame = Frame::from_payload(header, payload.freeze()); trace!("received frame {frame}"); Ok(Some(frame)) } @@ -284,9 +302,9 @@ mod tests { assert_eq!( sock.read(None).unwrap().unwrap().into_payload(), - &[0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07] + &[0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07][..] ); - assert_eq!(sock.read(None).unwrap().unwrap().into_payload(), &[0x03, 0x02, 0x01]); + assert_eq!(sock.read(None).unwrap().unwrap().into_payload(), &[0x03, 0x02, 0x01][..]); assert!(sock.read(None).unwrap().is_none()); let (_, rest) = sock.into_inner(); @@ -299,7 +317,7 @@ mod tests { let mut sock = FrameSocket::from_partially_read(raw, vec![0x82, 0x07, 0x01]); assert_eq!( sock.read(None).unwrap().unwrap().into_payload(), - &[0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07] + &[0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07][..] ); } diff --git a/src/protocol/frame/payload.rs b/src/protocol/frame/payload.rs deleted file mode 100644 index a1b85ee3..00000000 --- a/src/protocol/frame/payload.rs +++ /dev/null @@ -1,314 +0,0 @@ -use bytes::{Bytes, BytesMut}; -use core::str; -use std::fmt::Display; - -/// Utf8 payload. -#[derive(Debug, Default, Clone, Eq, PartialEq)] -pub struct Utf8Payload(Payload); - -impl Utf8Payload { - /// Creates from a static str. - #[inline] - pub const fn from_static(str: &'static str) -> Self { - Self(Payload::Shared(Bytes::from_static(str.as_bytes()))) - } - - /// Returns a slice of the payload. - #[inline] - pub fn as_slice(&self) -> &[u8] { - self.0.as_slice() - } - - /// Returns as a string slice. - #[inline] - pub fn as_str(&self) -> &str { - // safety: is valid uft8 - unsafe { str::from_utf8_unchecked(self.as_slice()) } - } - - /// Returns length in bytes. - #[inline] - pub fn len(&self) -> usize { - self.as_slice().len() - } - - /// Returns true if the length is 0. - #[inline] - pub fn is_empty(&self) -> bool { - self.len() == 0 - } - - /// If owned converts into [`Bytes`] internals & then clones (cheaply). - #[inline] - pub fn share(&mut self) -> Self { - Self(self.0.share()) - } -} - -impl TryFrom for Utf8Payload { - type Error = str::Utf8Error; - - #[inline] - fn try_from(payload: Payload) -> Result { - str::from_utf8(payload.as_slice())?; - Ok(Self(payload)) - } -} - -impl TryFrom for Utf8Payload { - type Error = str::Utf8Error; - - #[inline] - fn try_from(bytes: Bytes) -> Result { - Payload::from(bytes).try_into() - } -} - -impl TryFrom for Utf8Payload { - type Error = str::Utf8Error; - - #[inline] - fn try_from(bytes: BytesMut) -> Result { - Payload::from(bytes).try_into() - } -} - -impl TryFrom> for Utf8Payload { - type Error = str::Utf8Error; - - #[inline] - fn try_from(bytes: Vec) -> Result { - Payload::from(bytes).try_into() - } -} - -impl From for Utf8Payload { - #[inline] - fn from(s: String) -> Self { - Self(s.into()) - } -} - -impl From<&str> for Utf8Payload { - #[inline] - fn from(s: &str) -> Self { - Self(s.as_bytes().into()) - } -} - -impl From<&String> for Utf8Payload { - #[inline] - fn from(s: &String) -> Self { - s.as_str().into() - } -} - -impl From for Payload { - #[inline] - fn from(Utf8Payload(payload): Utf8Payload) -> Self { - payload - } -} - -impl Display for Utf8Payload { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.write_str(self.as_str()) - } -} - -impl AsRef for Utf8Payload { - #[inline] - fn as_ref(&self) -> &str { - self.as_str() - } -} - -impl AsRef<[u8]> for Utf8Payload { - #[inline] - fn as_ref(&self) -> &[u8] { - self.as_slice() - } -} - -impl PartialEq for Utf8Payload -where - for<'a> &'a str: PartialEq, -{ - /// ``` - /// use tungstenite::protocol::frame::Utf8Payload; - /// let payload = Utf8Payload::from_static("foo123"); - /// assert_eq!(payload, "foo123"); - /// assert_eq!(payload, "foo123".to_string()); - /// assert_eq!(payload, &"foo123".to_string()); - /// assert_eq!(payload, std::borrow::Cow::from("foo123")); - /// ``` - #[inline] - fn eq(&self, other: &T) -> bool { - self.as_str() == *other - } -} - -/// A payload of a WebSocket frame. -#[derive(Debug, Clone)] -pub enum Payload { - // /// Owned data with unique ownership. - // Owned(BytesMut), - /// Shared data with shared ownership. - Shared(Bytes), - // /// Owned vec data. - // Vec(Vec), -} - -impl Payload { - /// Creates from static bytes. - #[inline] - pub const fn from_static(bytes: &'static [u8]) -> Self { - Self::Shared(Bytes::from_static(bytes)) - } - - #[inline] - pub(crate) fn mutate(&mut self, f: impl FnOnce(&mut [u8])) { - let Self::Shared(bytes) = self; - let mut bytes_mut = BytesMut::from(std::mem::take(bytes)); - f(&mut bytes_mut); - *bytes = bytes_mut.freeze(); - } - - /// Converts into [`Bytes`] internals & then clones (cheaply). - pub fn share(&mut self) -> Self { - // match self { - // Self::Owned(data) => { - // *self = Self::Shared(mem::take(data).freeze()); - // } - // // Self::Vec(data) => { - // // *self = Self::Shared(Bytes::from(mem::take(data))); - // // } - // Self::Shared(_) => {} - // } - self.clone() - } - - /// Returns a slice of the payload. - #[inline] - pub fn as_slice(&self) -> &[u8] { - match self { - // Payload::Owned(v) => v, - Payload::Shared(v) => v, - // Payload::Vec(v) => v, - } - } - - // /// Returns a mutable slice of the payload. - // /// - // /// Note that this will internally allocate if the payload is shared - // /// and there are other references to the same data. No allocation - // /// would happen if the payload is owned or if there is only one - // /// `Bytes` instance referencing the data. - // #[inline] - // pub fn as_mut_slice(&mut self) -> &mut [u8] { - // match self { - // Payload::Owned(v) => &mut *v, - // // Payload::Vec(v) => &mut *v, - // Payload::Shared(v) => { - // // Using `Bytes::to_vec()` or `Vec::from(bytes.as_ref())` would mean making a copy. - // // `Bytes::into()` would not make a copy if our `Bytes` instance is the only one. - // let data = mem::take(v).into(); - // *self = Payload::Owned(data); - // match self { - // Payload::Owned(v) => v, - // _ => unreachable!(), - // } - // } - // } - // } - - /// Returns the length of the payload. - #[inline] - pub fn len(&self) -> usize { - self.as_slice().len() - } - - /// Returns true if the payload has a length of 0. - #[inline] - pub fn is_empty(&self) -> bool { - self.len() == 0 - } - - /// Consumes the payload and returns the underlying data as a string. - #[inline] - pub fn into_text(self) -> Result { - self.try_into() - } -} - -impl Default for Payload { - #[inline] - fn default() -> Self { - Self::Shared(<_>::default()) - } -} - -impl From> for Payload { - #[inline] - fn from(v: Vec) -> Self { - Payload::Shared(Bytes::from(v)) - } -} - -impl From for Payload { - #[inline] - fn from(v: String) -> Self { - v.into_bytes().into() - } -} - -impl From for Payload { - #[inline] - fn from(v: Bytes) -> Self { - Payload::Shared(v) - } -} - -impl From for Payload { - #[inline] - fn from(v: BytesMut) -> Self { - Payload::Shared(v.freeze()) - } -} - -impl From<&[u8]> for Payload { - #[inline] - fn from(v: &[u8]) -> Self { - Self::Shared(Bytes::copy_from_slice(v)) - } -} - -impl PartialEq for Payload { - #[inline] - fn eq(&self, other: &Payload) -> bool { - self.as_slice() == other.as_slice() - } -} - -impl Eq for Payload {} - -impl PartialEq<[u8]> for Payload { - #[inline] - fn eq(&self, other: &[u8]) -> bool { - self.as_slice() == other - } -} - -impl PartialEq<&[u8; N]> for Payload { - #[inline] - fn eq(&self, other: &&[u8; N]) -> bool { - self.as_slice() == *other - } -} - -impl AsRef<[u8]> for Payload { - #[inline] - fn as_ref(&self) -> &[u8] { - self.as_slice() - } -} diff --git a/src/protocol/frame/utf8.rs b/src/protocol/frame/utf8.rs new file mode 100644 index 00000000..9166d979 --- /dev/null +++ b/src/protocol/frame/utf8.rs @@ -0,0 +1,94 @@ +use bytes::{Bytes, BytesMut}; +use core::str; +use std::fmt::Display; + +/// Utf8 payload. +#[derive(Debug, Default, Clone, Eq, PartialEq)] +pub struct Utf8Bytes(Bytes); + +impl Utf8Bytes { + /// Creates from a static str. + #[inline] + pub const fn from_static(str: &'static str) -> Self { + Self(Bytes::from_static(str.as_bytes())) + } + + /// Returns as a string slice. + #[inline] + pub fn as_str(&self) -> &str { + // safety: is valid uft8 + unsafe { str::from_utf8_unchecked(&self.0) } + } +} + +impl std::ops::Deref for Utf8Bytes { + type Target = str; + + #[inline] + fn deref(&self) -> &Self::Target { + self.as_str() + } +} + +impl Display for Utf8Bytes { + #[inline] + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str(self.as_str()) + } +} + +impl TryFrom for Utf8Bytes { + type Error = str::Utf8Error; + + #[inline] + fn try_from(bytes: Bytes) -> Result { + str::from_utf8(&bytes)?; + Ok(Self(bytes)) + } +} + +impl TryFrom for Utf8Bytes { + type Error = str::Utf8Error; + + #[inline] + fn try_from(bytes: BytesMut) -> Result { + bytes.freeze().try_into() + } +} + +impl TryFrom> for Utf8Bytes { + type Error = str::Utf8Error; + + #[inline] + fn try_from(v: Vec) -> Result { + Bytes::from(v).try_into() + } +} + +impl From for Utf8Bytes { + #[inline] + fn from(s: String) -> Self { + Self(s.into()) + } +} + +impl From<&str> for Utf8Bytes { + #[inline] + fn from(s: &str) -> Self { + Self(Bytes::copy_from_slice(s.as_bytes())) + } +} + +impl From<&String> for Utf8Bytes { + #[inline] + fn from(s: &String) -> Self { + s.as_str().into() + } +} + +impl From for Bytes { + #[inline] + fn from(Utf8Bytes(bytes): Utf8Bytes) -> Self { + bytes + } +} diff --git a/src/protocol/message.rs b/src/protocol/message.rs index ee098c09..f43e8c62 100644 --- a/src/protocol/message.rs +++ b/src/protocol/message.rs @@ -1,8 +1,10 @@ +use super::frame::{CloseFrame, Frame}; +use crate::{ + error::{CapacityError, Error, Result}, + protocol::frame::Utf8Bytes, +}; use std::{borrow::Cow, fmt, result::Result as StdResult, str}; -use super::frame::{CloseFrame, Frame, Payload, Utf8Payload}; -use crate::error::{CapacityError, Error, Result}; - mod string_collect { use utf8::DecodeError; @@ -74,6 +76,7 @@ mod string_collect { } use self::string_collect::StringCollector; +use bytes::Bytes; /// A struct representing the incomplete message. #[derive(Debug)] @@ -154,17 +157,17 @@ pub enum IncompleteMessageType { #[derive(Debug, Eq, PartialEq, Clone)] pub enum Message { /// A text WebSocket message - Text(Utf8Payload), + Text(Utf8Bytes), /// A binary WebSocket message - Binary(Payload), + Binary(Bytes), /// A ping message with the specified payload /// /// The payload here must have a length less than 125 bytes - Ping(Payload), + Ping(Bytes), /// A pong message with the specified payload /// /// The payload here must have a length less than 125 bytes - Pong(Payload), + Pong(Bytes), /// A close message with the optional close frame. Close(Option>), /// Raw frame. Note, that you're not going to get this value while reading the message. @@ -175,7 +178,7 @@ impl Message { /// Create a new text WebSocket message from a stringable. pub fn text(string: S) -> Message where - S: Into, + S: Into, { Message::Text(string.into()) } @@ -183,7 +186,7 @@ impl Message { /// Create a new binary WebSocket message by converting to `Vec`. pub fn binary(bin: B) -> Message where - B: Into, + B: Into, { Message::Binary(bin.into()) } @@ -232,13 +235,13 @@ impl Message { } /// Consume the WebSocket and return it as binary data. - pub fn into_data(self) -> Payload { + pub fn into_data(self) -> Bytes { match self { Message::Text(string) => string.into(), Message::Binary(data) | Message::Ping(data) | Message::Pong(data) => data, Message::Close(None) => <_>::default(), Message::Close(Some(frame)) => match frame.reason { - Cow::Borrowed(s) => Payload::from_static(s.as_bytes()), + Cow::Borrowed(s) => Bytes::from_static(s.as_bytes()), Cow::Owned(s) => s.into(), }, Message::Frame(frame) => frame.into_payload(), @@ -246,7 +249,7 @@ impl Message { } /// Attempt to consume the WebSocket message and convert it to a String. - pub fn into_text(self) -> Result { + pub fn into_text(self) -> Result { match self { Message::Text(txt) => Ok(txt), Message::Binary(data) | Message::Ping(data) | Message::Pong(data) => { @@ -254,7 +257,7 @@ impl Message { } Message::Close(None) => Ok(<_>::default()), Message::Close(Some(frame)) => Ok(match frame.reason { - Cow::Borrowed(s) => Utf8Payload::from_static(s), + Cow::Borrowed(s) => Utf8Bytes::from_static(s), Cow::Owned(s) => s.into(), }), Message::Frame(frame) => Ok(frame.into_text()?), @@ -267,7 +270,7 @@ impl Message { match *self { Message::Text(ref string) => Ok(string.as_str()), Message::Binary(ref data) | Message::Ping(ref data) | Message::Pong(ref data) => { - Ok(str::from_utf8(data.as_slice())?) + Ok(str::from_utf8(data)?) } Message::Close(None) => Ok(""), Message::Close(Some(ref frame)) => Ok(&frame.reason), @@ -293,7 +296,7 @@ impl<'s> From<&'s str> for Message { impl<'b> From<&'b [u8]> for Message { #[inline] fn from(data: &'b [u8]) -> Self { - Message::binary(data) + Message::binary(Bytes::copy_from_slice(data)) } } @@ -304,10 +307,10 @@ impl From> for Message { } } -impl From for Vec { +impl From for Bytes { #[inline] fn from(message: Message) -> Self { - message.into_data().as_slice().into() + message.into_data() } } @@ -351,11 +354,11 @@ mod tests { } #[test] - fn binary_convert_into_vec() { + fn binary_convert_into_bytes() { let bin = vec![6u8, 7, 8, 9, 10, 241]; let bin_copy = bin.clone(); let msg = Message::from(bin); - let serialized: Vec = msg.into(); + let serialized: Bytes = msg.into(); assert_eq!(bin_copy, serialized); } diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs index d057121e..28d9790d 100644 --- a/src/protocol/mod.rs +++ b/src/protocol/mod.rs @@ -586,13 +586,15 @@ impl WebSocketContext { } /// Try to decode one message frame. May return None. - fn read_message_frame(&mut self, stream: &mut Stream) -> Result> - where - Stream: Read + Write, - { - if let Some(mut frame) = self + fn read_message_frame(&mut self, stream: &mut impl Read) -> Result> { + if let Some(frame) = self .frame - .read_frame(stream, self.config.max_frame_size) + .read_frame( + stream, + self.config.max_frame_size, + matches!(self.role, Role::Server), + self.config.accept_unmasked_frames, + ) .check_connection_reset(self.state)? { if !self.state.can_read() { @@ -610,26 +612,9 @@ impl WebSocketContext { } } - match self.role { - Role::Server => { - if frame.is_masked() { - // A server MUST remove masking for data frames received from a client - // as described in Section 5.3. (RFC 6455) - frame.apply_mask(); - } else if !self.config.accept_unmasked_frames { - // The server MUST close the connection upon receiving a - // frame that is not masked. (RFC 6455) - // The only exception here is if the user explicitly accepts given - // stream by setting WebSocketConfig.accept_unmasked_frames to true - return Err(Error::Protocol(ProtocolError::UnmaskedFrameFromClient)); - } - } - Role::Client => { - if frame.is_masked() { - // A client MUST close a connection if it detects a masked frame. (RFC 6455) - return Err(Error::Protocol(ProtocolError::MaskedFrameFromServer)); - } - } + if self.role == Role::Client && frame.is_masked() { + // A client MUST close a connection if it detects a masked frame. (RFC 6455) + return Err(Error::Protocol(ProtocolError::MaskedFrameFromServer)); } match frame.header().opcode { @@ -648,10 +633,10 @@ impl WebSocketContext { Err(Error::Protocol(ProtocolError::UnknownControlFrameType(i))) } OpCtl::Ping => { - let mut data = frame.into_payload(); + let data = frame.into_payload(); // No ping processing after we sent a close frame. if self.state.is_active() { - self.set_additional(Frame::pong(data.share())); + self.set_additional(Frame::pong(data.clone())); } Ok(Some(Message::Ping(data))) } diff --git a/tests/client_headers.rs b/tests/client_headers.rs index f943f8ae..2f9693b4 100644 --- a/tests/client_headers.rs +++ b/tests/client_headers.rs @@ -80,7 +80,7 @@ fn test_headers() { // This read should succeed even though we already initiated a close let message = client_handler.read().unwrap(); - assert_eq!(message.into_data(), b"Hello WebSocket"); + assert_eq!(message.into_data(), b"Hello WebSocket"[..]); assert!(client_handler.read().unwrap().is_close()); // receive acknowledgement diff --git a/tests/connection_reset.rs b/tests/connection_reset.rs index 8c55d7e6..5357eedc 100644 --- a/tests/connection_reset.rs +++ b/tests/connection_reset.rs @@ -64,7 +64,7 @@ fn test_server_close() { }, |mut srv_sock| { let message = srv_sock.read().unwrap(); - assert_eq!(message.into_data(), b"Hello WebSocket"); + assert_eq!(message.into_data(), b"Hello WebSocket"[..]); srv_sock.close(None).unwrap(); // send close to client @@ -100,7 +100,7 @@ fn test_evil_server_close() { }, |mut srv_sock| { let message = srv_sock.read().unwrap(); - assert_eq!(message.into_data(), b"Hello WebSocket"); + assert_eq!(message.into_data(), b"Hello WebSocket"[..]); srv_sock.close(None).unwrap(); // send close to client @@ -121,7 +121,7 @@ fn test_client_close() { cli_sock.send(Message::Text("Hello WebSocket".into())).unwrap(); let message = cli_sock.read().unwrap(); // receive answer from server - assert_eq!(message.into_data(), b"From Server"); + assert_eq!(message.into_data(), b"From Server"[..]); cli_sock.close(None).unwrap(); // send close to server @@ -136,7 +136,7 @@ fn test_client_close() { }, |mut srv_sock| { let message = srv_sock.read().unwrap(); - assert_eq!(message.into_data(), b"Hello WebSocket"); + assert_eq!(message.into_data(), b"Hello WebSocket"[..]); srv_sock.send(Message::Text("From Server".into())).unwrap(); diff --git a/tests/receive_after_init_close.rs b/tests/receive_after_init_close.rs index 3dd0fb02..dfc1db4c 100644 --- a/tests/receive_after_init_close.rs +++ b/tests/receive_after_init_close.rs @@ -46,7 +46,7 @@ fn test_receive_after_init_close() { // This read should succeed even though we already initiated a close let message = client_handler.read().unwrap(); - assert_eq!(message.into_data(), b"Hello WebSocket"); + assert_eq!(message.into_data(), b"Hello WebSocket"[..]); assert!(client_handler.read().unwrap().is_close()); // receive acknowledgement From 487bd5c8ac957d2eb1279add7ce130f4265650d7 Mon Sep 17 00:00:00 2001 From: Alex Butler Date: Tue, 17 Dec 2024 15:01:34 +0000 Subject: [PATCH 03/10] apply_mask use vec as mut type --- src/protocol/frame/frame.rs | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/src/protocol/frame/frame.rs b/src/protocol/frame/frame.rs index 6875cf4b..71ab41d6 100644 --- a/src/protocol/frame/frame.rs +++ b/src/protocol/frame/frame.rs @@ -3,6 +3,7 @@ use std::{ default::Default, fmt, io::{Cursor, ErrorKind, Read, Write}, + mem, result::Result as StdResult, str::Utf8Error, string::String, @@ -248,12 +249,6 @@ impl Frame { &self.payload } - // /// Get a mutable reference to the frame's payload. - // #[inline] - // pub fn payload_mut(&mut self) -> &mut [u8] { - // self.payload.as_mut_slice() - // } - /// Test whether the frame is masked. #[inline] pub(crate) fn is_masked(&self) -> bool { @@ -274,9 +269,9 @@ impl Frame { #[inline] pub(crate) fn apply_mask(&mut self) { if let Some(mask) = self.header.mask.take() { - let mut bytes_mut = BytesMut::from(std::mem::take(&mut self.payload)); - apply_mask(&mut bytes_mut, mask); - self.payload = bytes_mut.freeze(); + let mut data = Vec::from(mem::take(&mut self.payload)); + apply_mask(&mut data, mask); + self.payload = data.into(); } } From 4cad6a798d038a7d1bccb9474b177f2aeff063fd Mon Sep 17 00:00:00 2001 From: Alex Butler Date: Tue, 17 Dec 2024 15:26:57 +0000 Subject: [PATCH 04/10] Add specialised format_into_buf --- src/protocol/frame/frame.rs | 43 ++++++++++++++++++++++++++----------- src/protocol/frame/mod.rs | 2 +- 2 files changed, 31 insertions(+), 14 deletions(-) diff --git a/src/protocol/frame/frame.rs b/src/protocol/frame/frame.rs index 71ab41d6..d35a4d3f 100644 --- a/src/protocol/frame/frame.rs +++ b/src/protocol/frame/frame.rs @@ -264,17 +264,6 @@ impl Frame { self.header.set_random_mask(); } - /// This method unmasks the payload and should only be called on frames that are actually - /// masked. In other words, those frames that have just been received from a client endpoint. - #[inline] - pub(crate) fn apply_mask(&mut self) { - if let Some(mask) = self.header.mask.take() { - let mut data = Vec::from(mem::take(&mut self.payload)); - apply_mask(&mut data, mask); - self.payload = data.into(); - } - } - /// Consume the frame into its payload as string. #[inline] pub fn into_text(self) -> StdResult { @@ -366,8 +355,28 @@ impl Frame { /// Write a frame out to a buffer pub fn format(mut self, output: &mut impl Write) -> Result<()> { self.header.format(self.payload.len() as u64, output)?; - self.apply_mask(); - output.write_all(self.payload())?; + + if let Some(mask) = self.header.mask.take() { + let mut data = Vec::from(mem::take(&mut self.payload)); + apply_mask(&mut data, mask); + output.write_all(&data)?; + } else { + output.write_all(&self.payload)?; + } + + Ok(()) + } + + pub(crate) fn format_into_buf(mut self, buf: &mut Vec) -> Result<()> { + self.header.format(self.payload.len() as u64, buf)?; + + let len = buf.len(); + buf.extend_from_slice(&self.payload); + + if let Some(mask) = self.header.mask.take() { + apply_mask(&mut buf[len..], mask); + } + Ok(()) } } @@ -481,6 +490,14 @@ mod tests { assert_eq!(buf, vec![0x89, 0x02, 0x01, 0x02]); } + #[test] + fn format_into_buf() { + let frame = Frame::ping(vec![0x01, 0x02]); + let mut buf = Vec::with_capacity(frame.len()); + frame.format_into_buf(&mut buf).unwrap(); + assert_eq!(buf, vec![0x89, 0x02, 0x01, 0x02]); + } + #[test] fn display() { let f = Frame::message(Bytes::from_static(b"hi there"), OpCode::Data(Data::Text), true); diff --git a/src/protocol/frame/mod.rs b/src/protocol/frame/mod.rs index d13161b2..2d76830d 100644 --- a/src/protocol/frame/mod.rs +++ b/src/protocol/frame/mod.rs @@ -248,7 +248,7 @@ impl FrameCodec { trace!("writing frame {frame}"); self.out_buffer.reserve(frame.len()); - frame.format(&mut self.out_buffer).expect("Bug: can't write to vector"); + frame.format_into_buf(&mut self.out_buffer).expect("Bug: can't write to vector"); if self.out_buffer.len() > self.out_buffer_write_len { self.write_out_buffer(stream) From 46d681e0349ac3e8454930dcf4bfbf56859eb64e Mon Sep 17 00:00:00 2001 From: Alex Butler Date: Tue, 17 Dec 2024 15:53:37 +0000 Subject: [PATCH 05/10] Change `CloseFrame` to use `Utf8Bytes` for `reason` --- CHANGELOG.md | 4 ++-- src/protocol/frame/frame.rs | 31 ++++++++++--------------------- src/protocol/message.rs | 16 +++++----------- src/protocol/mod.rs | 2 +- 4 files changed, 18 insertions(+), 35 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index a61ee91d..650c793c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,6 @@ # Unreleased -- Implement `AsRef<[u8]>` for `Payload`, `AsRef<[u8]>`, `AsRef` for `Utf8Payload`. -- Implement `&str`-like `PartialEq` for `Utf8Payload`. +- Simplify `Message` to use `Bytes` payload directly with simpler `Utf8Bytes` for text. +- Change `CloseFrame` to use `Utf8Bytes` for `reason`. # 0.25.0 diff --git a/src/protocol/frame/frame.rs b/src/protocol/frame/frame.rs index d35a4d3f..31a4c004 100644 --- a/src/protocol/frame/frame.rs +++ b/src/protocol/frame/frame.rs @@ -1,5 +1,6 @@ +use byteorder::{NetworkEndian, ReadBytesExt}; +use log::*; use std::{ - borrow::Cow, default::Default, fmt, io::{Cursor, ErrorKind, Read, Write}, @@ -9,9 +10,6 @@ use std::{ string::String, }; -use byteorder::{NetworkEndian, ReadBytesExt}; -use log::*; - use super::{ coding::{CloseCode, Control, Data, OpCode}, mask::{apply_mask, generate_mask}, @@ -20,25 +18,18 @@ use crate::{ error::{Error, ProtocolError, Result}, protocol::frame::Utf8Bytes, }; -use bytes::{Buf, Bytes, BytesMut}; +use bytes::{Bytes, BytesMut}; /// A struct representing the close command. #[derive(Debug, Clone, Eq, PartialEq)] -pub struct CloseFrame<'t> { +pub struct CloseFrame { /// The reason as a code. pub code: CloseCode, /// The reason as text string. - pub reason: Cow<'t, str>, -} - -impl CloseFrame<'_> { - /// Convert into a owned string. - pub fn into_owned(self) -> CloseFrame<'static> { - CloseFrame { code: self.code, reason: self.reason.into_owned().into() } - } + pub reason: Utf8Bytes, } -impl fmt::Display for CloseFrame<'_> { +impl fmt::Display for CloseFrame { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "{} ({})", self.reason, self.code) } @@ -284,16 +275,14 @@ impl Frame { /// Consume the frame into a closing frame. #[inline] - pub(crate) fn into_close(self) -> Result>> { + pub(crate) fn into_close(self) -> Result> { match self.payload.len() { 0 => Ok(None), 1 => Err(Error::Protocol(ProtocolError::InvalidCloseSequence)), _ => { - let mut data = BytesMut::from(self.payload); - let code = u16::from_be_bytes([data[0], data[1]]).into(); - data.advance(2); - let text = String::from_utf8(data.to_vec())?; - Ok(Some(CloseFrame { code, reason: text.into() })) + let code = u16::from_be_bytes([self.payload[0], self.payload[1]]).into(); + let reason = Utf8Bytes::try_from(self.payload.slice(2..))?; + Ok(Some(CloseFrame { code, reason })) } } } diff --git a/src/protocol/message.rs b/src/protocol/message.rs index f43e8c62..57539e09 100644 --- a/src/protocol/message.rs +++ b/src/protocol/message.rs @@ -3,7 +3,7 @@ use crate::{ error::{CapacityError, Error, Result}, protocol::frame::Utf8Bytes, }; -use std::{borrow::Cow, fmt, result::Result as StdResult, str}; +use std::{fmt, result::Result as StdResult, str}; mod string_collect { use utf8::DecodeError; @@ -169,7 +169,7 @@ pub enum Message { /// The payload here must have a length less than 125 bytes Pong(Bytes), /// A close message with the optional close frame. - Close(Option>), + Close(Option), /// Raw frame. Note, that you're not going to get this value while reading the message. Frame(Frame), } @@ -237,13 +237,10 @@ impl Message { /// Consume the WebSocket and return it as binary data. pub fn into_data(self) -> Bytes { match self { - Message::Text(string) => string.into(), + Message::Text(utf8) => utf8.into(), Message::Binary(data) | Message::Ping(data) | Message::Pong(data) => data, Message::Close(None) => <_>::default(), - Message::Close(Some(frame)) => match frame.reason { - Cow::Borrowed(s) => Bytes::from_static(s.as_bytes()), - Cow::Owned(s) => s.into(), - }, + Message::Close(Some(frame)) => frame.reason.into(), Message::Frame(frame) => frame.into_payload(), } } @@ -256,10 +253,7 @@ impl Message { Ok(data.try_into()?) } Message::Close(None) => Ok(<_>::default()), - Message::Close(Some(frame)) => Ok(match frame.reason { - Cow::Borrowed(s) => Utf8Bytes::from_static(s), - Cow::Owned(s) => s.into(), - }), + Message::Close(Some(frame)) => Ok(frame.reason), Message::Frame(frame) => Ok(frame.into_text()?), } } diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs index 28d9790d..bea86d41 100644 --- a/src/protocol/mod.rs +++ b/src/protocol/mod.rs @@ -703,7 +703,7 @@ impl WebSocketContext { /// Received a close frame. Tells if we need to return a close frame to the user. #[allow(clippy::option_option)] - fn do_close<'t>(&mut self, close: Option>) -> Option>> { + fn do_close(&mut self, close: Option) -> Option> { debug!("Received close frame: {close:?}"); match self.state { WebSocketState::Active => { From 5faacf9617e33522d171c186b3bc31023c74fd99 Mon Sep 17 00:00:00 2001 From: Alex Butler Date: Tue, 17 Dec 2024 15:58:39 +0000 Subject: [PATCH 06/10] Add Utf8Bytes str-like PartialEq impls --- src/protocol/frame/utf8.rs | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/src/protocol/frame/utf8.rs b/src/protocol/frame/utf8.rs index 9166d979..cd4de58f 100644 --- a/src/protocol/frame/utf8.rs +++ b/src/protocol/frame/utf8.rs @@ -24,12 +24,44 @@ impl Utf8Bytes { impl std::ops::Deref for Utf8Bytes { type Target = str; + /// ``` + /// use tungstenite::protocol::frame::Utf8Bytes; + /// + /// /// Example fn that takes a str slice + /// fn a(s: &str) {} + /// + /// let data = Utf8Bytes::from_static("foo123"); + /// + /// // auto-deref as arg + /// a(&data); + /// + /// // deref to str methods + /// assert_eq!(data.len(), 6); + /// ``` #[inline] fn deref(&self) -> &Self::Target { self.as_str() } } +impl PartialEq for Utf8Bytes +where + for<'a> &'a str: PartialEq, +{ + /// ``` + /// use tungstenite::protocol::frame::Utf8Bytes; + /// let payload = Utf8Bytes::from_static("foo123"); + /// assert_eq!(payload, "foo123"); + /// assert_eq!(payload, "foo123".to_string()); + /// assert_eq!(payload, &"foo123".to_string()); + /// assert_eq!(payload, std::borrow::Cow::from("foo123")); + /// ``` + #[inline] + fn eq(&self, other: &T) -> bool { + self.as_str() == *other + } +} + impl Display for Utf8Bytes { #[inline] fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { From 924bf30e4f5f43e867c54ed8f4975d09356d5920 Mon Sep 17 00:00:00 2001 From: Alex Butler Date: Tue, 17 Dec 2024 16:01:12 +0000 Subject: [PATCH 07/10] cargo fmt --- src/protocol/frame/utf8.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/protocol/frame/utf8.rs b/src/protocol/frame/utf8.rs index cd4de58f..c88f78ae 100644 --- a/src/protocol/frame/utf8.rs +++ b/src/protocol/frame/utf8.rs @@ -26,15 +26,15 @@ impl std::ops::Deref for Utf8Bytes { /// ``` /// use tungstenite::protocol::frame::Utf8Bytes; - /// + /// /// /// Example fn that takes a str slice /// fn a(s: &str) {} - /// + /// /// let data = Utf8Bytes::from_static("foo123"); - /// + /// /// // auto-deref as arg /// a(&data); - /// + /// /// // deref to str methods /// assert_eq!(data.len(), 6); /// ``` From 9f94ff10f79f350900f417fddac1d6209a1e1fe2 Mon Sep 17 00:00:00 2001 From: Alex Butler Date: Tue, 17 Dec 2024 16:01:51 +0000 Subject: [PATCH 08/10] always shout about SAFETY --- src/protocol/frame/utf8.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/protocol/frame/utf8.rs b/src/protocol/frame/utf8.rs index c88f78ae..8c0ff2d3 100644 --- a/src/protocol/frame/utf8.rs +++ b/src/protocol/frame/utf8.rs @@ -16,7 +16,7 @@ impl Utf8Bytes { /// Returns as a string slice. #[inline] pub fn as_str(&self) -> &str { - // safety: is valid uft8 + // SAFETY: is valid uft8 unsafe { str::from_utf8_unchecked(&self.0) } } } From 81f2b33a59d5241f745b2df441c746941c3251bc Mon Sep 17 00:00:00 2001 From: Alex Butler Date: Tue, 17 Dec 2024 16:17:42 +0000 Subject: [PATCH 09/10] re-export Utf8Bytes, Bytes at root --- src/lib.rs | 4 +++- src/protocol/mod.rs | 7 +++++-- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 6b79d5be..be169d40 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -37,8 +37,10 @@ type ReadBuffer = buffer::ReadBuffer; pub use crate::{ error::{Error, Result}, - protocol::{Message, WebSocket}, + protocol::{frame::Utf8Bytes, Message, WebSocket}, }; +// re-export bytes since used in `Message` API. +pub use bytes::Bytes; #[cfg(feature = "handshake")] pub use crate::{ diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs index bea86d41..6d316941 100644 --- a/src/protocol/mod.rs +++ b/src/protocol/mod.rs @@ -13,7 +13,10 @@ use self::{ }, message::{IncompleteMessage, IncompleteMessageType}, }; -use crate::error::{CapacityError, Error, ProtocolError, Result}; +use crate::{ + error::{CapacityError, Error, ProtocolError, Result}, + protocol::frame::Utf8Bytes, +}; use log::*; use std::{ io::{self, Read, Write}, @@ -713,7 +716,7 @@ impl WebSocketContext { if !frame.code.is_allowed() { CloseFrame { code: CloseCode::Protocol, - reason: "Protocol violation".into(), + reason: Utf8Bytes::from_static("Protocol violation"), } } else { frame From ef51ddd7a80296b8f74bc8c008a974caccc2711f Mon Sep 17 00:00:00 2001 From: Alex Butler Date: Tue, 17 Dec 2024 16:20:20 +0000 Subject: [PATCH 10/10] Use shorter import in doc tests --- CHANGELOG.md | 1 + src/protocol/frame/utf8.rs | 7 ++----- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 650c793c..3c17e4f2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,7 @@ # Unreleased - Simplify `Message` to use `Bytes` payload directly with simpler `Utf8Bytes` for text. - Change `CloseFrame` to use `Utf8Bytes` for `reason`. +- Re-export `Bytes`. # 0.25.0 diff --git a/src/protocol/frame/utf8.rs b/src/protocol/frame/utf8.rs index 8c0ff2d3..845c96bf 100644 --- a/src/protocol/frame/utf8.rs +++ b/src/protocol/frame/utf8.rs @@ -25,12 +25,10 @@ impl std::ops::Deref for Utf8Bytes { type Target = str; /// ``` - /// use tungstenite::protocol::frame::Utf8Bytes; - /// /// /// Example fn that takes a str slice /// fn a(s: &str) {} /// - /// let data = Utf8Bytes::from_static("foo123"); + /// let data = tungstenite::Utf8Bytes::from_static("foo123"); /// /// // auto-deref as arg /// a(&data); @@ -49,8 +47,7 @@ where for<'a> &'a str: PartialEq, { /// ``` - /// use tungstenite::protocol::frame::Utf8Bytes; - /// let payload = Utf8Bytes::from_static("foo123"); + /// let payload = tungstenite::Utf8Bytes::from_static("foo123"); /// assert_eq!(payload, "foo123"); /// assert_eq!(payload, "foo123".to_string()); /// assert_eq!(payload, &"foo123".to_string());