Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rework as Bytes-only payload #473

Merged
merged 10 commits into from
Dec 17, 2024
2 changes: 1 addition & 1 deletion benches/read.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ fn benchmark(c: &mut Criterion) {
while sum != expected_sum {
match ws.read().unwrap() {
Message::Binary(v) => {
let a: &[u8; 8] = v.as_slice().try_into().unwrap();
let a: &[u8; 8] = v.as_ref().try_into().unwrap();
sum += u64::from_le_bytes(*a);
}
Message::Text(msg) => {
Expand Down
83 changes: 48 additions & 35 deletions src/protocol/frame/frame.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use std::{
default::Default,
fmt,
io::{Cursor, ErrorKind, Read, Write},
mem,
result::Result as StdResult,
str::Utf8Error,
string::String,
Expand All @@ -14,13 +15,12 @@ use log::*;
use super::{
coding::{CloseCode, Control, Data, OpCode},
mask::{apply_mask, generate_mask},
Payload,
};
use crate::{
error::{Error, ProtocolError, Result},
protocol::frame::Utf8Payload,
protocol::frame::Utf8Bytes,
};
use bytes::{Buf, BytesMut};
use bytes::{Buf, Bytes, BytesMut};

/// A struct representing the close command.
#[derive(Debug, Clone, Eq, PartialEq)]
Expand Down Expand Up @@ -213,7 +213,7 @@ impl FrameHeader {
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct Frame {
header: FrameHeader,
payload: Payload,
payload: Bytes,
}

impl Frame {
Expand Down Expand Up @@ -246,13 +246,7 @@ impl Frame {
/// Get a reference to the frame's payload.
#[inline]
pub fn payload(&self) -> &[u8] {
self.payload.as_slice()
}

/// Get a mutable reference to the frame's payload.
#[inline]
pub fn payload_mut(&mut self) -> &mut [u8] {
self.payload.as_mut_slice()
&self.payload
}

/// Test whether the frame is masked.
Expand All @@ -270,31 +264,22 @@ impl Frame {
self.header.set_random_mask();
}

/// This method unmasks the payload and should only be called on frames that are actually
/// masked. In other words, those frames that have just been received from a client endpoint.
#[inline]
pub(crate) fn apply_mask(&mut self) {
if let Some(mask) = self.header.mask.take() {
apply_mask(self.payload.as_mut_slice(), mask);
}
}

/// Consume the frame into its payload as string.
#[inline]
pub fn into_text(self) -> StdResult<Utf8Payload, Utf8Error> {
self.payload.into_text()
pub fn into_text(self) -> StdResult<Utf8Bytes, Utf8Error> {
self.payload.try_into()
}

/// Consume the frame into its payload.
#[inline]
pub fn into_payload(self) -> Payload {
pub fn into_payload(self) -> Bytes {
self.payload
}

/// Get frame payload as `&str`.
#[inline]
pub fn to_text(&self) -> Result<&str, Utf8Error> {
std::str::from_utf8(self.payload.as_slice())
std::str::from_utf8(&self.payload)
}

/// Consume the frame into a closing frame.
Expand All @@ -304,7 +289,7 @@ impl Frame {
0 => Ok(None),
1 => Err(Error::Protocol(ProtocolError::InvalidCloseSequence)),
_ => {
let mut data = self.payload.as_slice();
let mut data = BytesMut::from(self.payload);
alexheretic marked this conversation as resolved.
Show resolved Hide resolved
let code = u16::from_be_bytes([data[0], data[1]]).into();
data.advance(2);
let text = String::from_utf8(data.to_vec())?;
Expand All @@ -315,7 +300,7 @@ impl Frame {

/// Create a new data frame.
#[inline]
pub fn message(data: impl Into<Payload>, opcode: OpCode, is_final: bool) -> Frame {
pub fn message(data: impl Into<Bytes>, opcode: OpCode, is_final: bool) -> Frame {
debug_assert!(matches!(opcode, OpCode::Data(_)), "Invalid opcode for data frame.");
Frame {
header: FrameHeader { is_final, opcode, ..FrameHeader::default() },
Expand All @@ -325,7 +310,7 @@ impl Frame {

/// Create a new Pong control frame.
#[inline]
pub fn pong(data: impl Into<Payload>) -> Frame {
pub fn pong(data: impl Into<Bytes>) -> Frame {
Frame {
header: FrameHeader {
opcode: OpCode::Control(Control::Pong),
Expand All @@ -337,7 +322,7 @@ impl Frame {

/// Create a new Ping control frame.
#[inline]
pub fn ping(data: impl Into<Payload>) -> Frame {
pub fn ping(data: impl Into<Bytes>) -> Frame {
Frame {
header: FrameHeader {
opcode: OpCode::Control(Control::Ping),
Expand All @@ -359,19 +344,39 @@ impl Frame {
<_>::default()
};

Frame { header: FrameHeader::default(), payload: Payload::Owned(payload) }
Frame { header: FrameHeader::default(), payload: payload.into() }
}

/// Create a frame from given header and data.
pub fn from_payload(header: FrameHeader, payload: Payload) -> Self {
pub fn from_payload(header: FrameHeader, payload: Bytes) -> Self {
Frame { header, payload }
}

/// Write a frame out to a buffer
pub fn format(mut self, output: &mut impl Write) -> Result<()> {
self.header.format(self.payload.len() as u64, output)?;
self.apply_mask();
output.write_all(self.payload())?;

if let Some(mask) = self.header.mask.take() {
let mut data = Vec::from(mem::take(&mut self.payload));
apply_mask(&mut data, mask);
output.write_all(&data)?;
} else {
output.write_all(&self.payload)?;
}

Ok(())
}

pub(crate) fn format_into_buf(mut self, buf: &mut Vec<u8>) -> Result<()> {
self.header.format(self.payload.len() as u64, buf)?;

let len = buf.len();
buf.extend_from_slice(&self.payload);

if let Some(mask) = self.header.mask.take() {
apply_mask(&mut buf[len..], mask);
}

Ok(())
}
}
Expand Down Expand Up @@ -399,7 +404,7 @@ payload: 0x{}
// self.mask.map(|mask| format!("{:?}", mask)).unwrap_or("NONE".into()),
self.len(),
self.payload.len(),
self.payload.as_slice().iter().fold(String::new(), |mut output, byte| {
self.payload.iter().fold(String::new(), |mut output, byte| {
_ = write!(output, "{byte:02x}");
output
})
Expand Down Expand Up @@ -474,7 +479,7 @@ mod tests {
let mut payload = Vec::new();
raw.read_to_end(&mut payload).unwrap();
let frame = Frame::from_payload(header, payload.into());
assert_eq!(frame.into_payload(), &[0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]);
assert_eq!(frame.into_payload(), &[0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07][..]);
}

#[test]
Expand All @@ -485,9 +490,17 @@ mod tests {
assert_eq!(buf, vec![0x89, 0x02, 0x01, 0x02]);
}

#[test]
fn format_into_buf() {
let frame = Frame::ping(vec![0x01, 0x02]);
let mut buf = Vec::with_capacity(frame.len());
frame.format_into_buf(&mut buf).unwrap();
assert_eq!(buf, vec![0x89, 0x02, 0x01, 0x02]);
}

#[test]
fn display() {
let f = Frame::message(Payload::from_static(b"hi there"), OpCode::Data(Data::Text), true);
let f = Frame::message(Bytes::from_static(b"hi there"), OpCode::Data(Data::Text), true);
let view = format!("{f}");
assert!(view.contains("payload:"));
}
Expand Down
40 changes: 29 additions & 11 deletions src/protocol/frame/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,16 @@ pub mod coding;
#[allow(clippy::module_inception)]
mod frame;
mod mask;
mod payload;
mod utf8;

pub use self::{
frame::{CloseFrame, Frame, FrameHeader},
payload::{Payload, Utf8Payload},
utf8::Utf8Bytes,
};

use crate::{
error::{CapacityError, Error, Result},
error::{CapacityError, Error, ProtocolError, Result},
protocol::frame::mask::apply_mask,
Message,
};
use bytes::BytesMut;
Expand Down Expand Up @@ -65,7 +66,7 @@ where
{
/// Read a frame from stream.
pub fn read(&mut self, max_size: Option<usize>) -> Result<Option<Frame>> {
self.codec.read_frame(&mut self.stream, max_size)
self.codec.read_frame(&mut self.stream, max_size, false, true)
}
}

Expand Down Expand Up @@ -158,13 +159,15 @@ impl FrameCodec {
&mut self,
stream: &mut Stream,
max_size: Option<usize>,
unmask: bool,
accept_unmasked: bool,
) -> Result<Option<Frame>>
where
Stream: Read,
{
let max_size = max_size.unwrap_or_else(usize::max_value);

let payload = loop {
let mut payload = loop {
{
if self.header.is_none() {
let mut cursor = Cursor::new(&mut self.in_buffer);
Expand Down Expand Up @@ -205,9 +208,24 @@ impl FrameCodec {
}
};

let (header, length) = self.header.take().expect("Bug: no frame header");
let (mut header, length) = self.header.take().expect("Bug: no frame header");
debug_assert_eq!(payload.len() as u64, length);
let frame = Frame::from_payload(header, Payload::Owned(payload));

if unmask {
if let Some(mask) = header.mask.take() {
// A server MUST remove masking for data frames received from a client
// as described in Section 5.3. (RFC 6455)
apply_mask(&mut payload, mask);
} else if !accept_unmasked {
// The server MUST close the connection upon receiving a
// frame that is not masked. (RFC 6455)
// The only exception here is if the user explicitly accepts given
// stream by setting WebSocketConfig.accept_unmasked_frames to true
return Err(Error::Protocol(ProtocolError::UnmaskedFrameFromClient));
}
}

let frame = Frame::from_payload(header, payload.freeze());
trace!("received frame {frame}");
Ok(Some(frame))
}
Expand All @@ -230,7 +248,7 @@ impl FrameCodec {
trace!("writing frame {frame}");

self.out_buffer.reserve(frame.len());
frame.format(&mut self.out_buffer).expect("Bug: can't write to vector");
frame.format_into_buf(&mut self.out_buffer).expect("Bug: can't write to vector");

if self.out_buffer.len() > self.out_buffer_write_len {
self.write_out_buffer(stream)
Expand Down Expand Up @@ -284,9 +302,9 @@ mod tests {

assert_eq!(
sock.read(None).unwrap().unwrap().into_payload(),
&[0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]
&[0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07][..]
);
assert_eq!(sock.read(None).unwrap().unwrap().into_payload(), &[0x03, 0x02, 0x01]);
assert_eq!(sock.read(None).unwrap().unwrap().into_payload(), &[0x03, 0x02, 0x01][..]);
assert!(sock.read(None).unwrap().is_none());

let (_, rest) = sock.into_inner();
Expand All @@ -299,7 +317,7 @@ mod tests {
let mut sock = FrameSocket::from_partially_read(raw, vec![0x82, 0x07, 0x01]);
assert_eq!(
sock.read(None).unwrap().unwrap().into_payload(),
&[0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]
&[0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07][..]
);
}

Expand Down
Loading
Loading