Skip to content

Commit

Permalink
protocol: Check for illegal packet escape
Browse files Browse the repository at this point in the history
  • Loading branch information
ohsayan committed Apr 4, 2024
1 parent d3c0b90 commit afa3a66
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 35 deletions.
7 changes: 1 addition & 6 deletions src/io/aio.rs
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,6 @@ impl<C: AsyncWriteExt + AsyncReadExt + Unpin> TcpConnection<C> {
self.con.write_all(pipeline.buf()).await?;
self.buf.clear();
// read
let mut expected = Decoder::MIN_READBACK;
let mut cursor = 0;
let mut state = MRespState::default();
loop {
Expand All @@ -170,15 +169,11 @@ impl<C: AsyncWriteExt + AsyncReadExt + Unpin> TcpConnection<C> {
if n == 0 {
return Err(Error::IoError(std::io::ErrorKind::ConnectionReset.into()));
}
if n < expected {
continue;
}
self.buf.extend_from_slice(&buf[..n]);
let mut decoder = Decoder::new(&self.buf, cursor);
match decoder.validate_pipe(cursor == 0, state) {
match decoder.validate_pipe(pipeline.query_count(), state) {
PipelineResult::Completed(r) => return Ok(r),
PipelineResult::Pending(_state) => {
expected = 1;
cursor = decoder.position();
state = _state;
}
Expand Down
7 changes: 1 addition & 6 deletions src/io/sync.rs
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,6 @@ impl<C: Write + Read> TcpConnection<C> {
self.con.write_all(pipeline.buf())?;
self.buf.clear();
// read
let mut expected = Decoder::MIN_READBACK;
let mut cursor = 0;
let mut state = MRespState::default();
loop {
Expand All @@ -164,15 +163,11 @@ impl<C: Write + Read> TcpConnection<C> {
if n == 0 {
return Err(Error::IoError(std::io::ErrorKind::ConnectionReset.into()));
}
if n < expected {
continue;
}
self.buf.extend_from_slice(&buf[..n]);
let mut decoder = Decoder::new(&self.buf, cursor);
match decoder.validate_pipe(cursor == 0, state) {
match decoder.validate_pipe(pipeline.query_count(), state) {
PipelineResult::Completed(r) => return Ok(r),
PipelineResult::Pending(_state) => {
expected = 1;
cursor = decoder.position();
state = _state;
}
Expand Down
47 changes: 24 additions & 23 deletions src/protocol/pipe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,18 @@

use {
super::{
state::{DecodeState, MetaState, RState, ResponseState},
state::{DecodeState, RState, ResponseState},
Decoder, ProtocolError,
},
crate::response::Response,
};

const PIPELINE_EXCEPTION: u8 = 0xFF;

#[derive(Debug, PartialEq, Default)]
pub(crate) struct MRespState {
processed: Vec<Response>,
pending: Option<ResponseState>,
expected: MetaState,
}

#[derive(Debug, PartialEq)]
Expand All @@ -37,52 +38,52 @@ pub(crate) enum PipelineResult {
}

impl MRespState {
fn step(mut self, decoder: &mut Decoder) -> PipelineResult {
match self.expected.finished(decoder) {
Ok(true) => {}
Ok(false) => return PipelineResult::Pending(self),
Err(e) => return PipelineResult::Error(e),
}
#[cold]
fn except() -> PipelineResult {
PipelineResult::Error(ProtocolError::InvalidPacket)
}
fn step(mut self, decoder: &mut Decoder, expected: usize) -> PipelineResult {
loop {
if self.processed.len() as u64 == self.expected.val() {
return PipelineResult::Completed(self.processed);
}
if decoder._cursor_eof() {
return PipelineResult::Pending(self);
}
if decoder._cursor_value() == PIPELINE_EXCEPTION {
return Self::except();
}
match decoder.validate_response(RState(
self.pending.take().unwrap_or(ResponseState::Initial),
)) {
DecodeState::ChangeState(RState(s)) => {
self.pending = Some(s);
return PipelineResult::Pending(self);
}
DecodeState::Completed(c) => self.processed.push(c),
DecodeState::Completed(c) => {
self.processed.push(c);
if self.processed.len() == expected {
return PipelineResult::Completed(self.processed);
}
}
DecodeState::Error(e) => return PipelineResult::Error(e),
}
}
}
}

impl<'a> Decoder<'a> {
pub fn validate_pipe(&mut self, first: bool, state: MRespState) -> PipelineResult {
if first && self._cursor_next() != b'P' {
PipelineResult::Error(ProtocolError::InvalidPacket)
} else {
state.step(self)
}
pub fn validate_pipe(&mut self, expected: usize, state: MRespState) -> PipelineResult {
state.step(self, expected)
}
}

#[cfg(test)]
const QUERY: &[u8] = b"P5\n\x12\x10\xFF\xFF\x115\n\x00\x01\x01\x0D5\nsayan\x0220\n\x0E0\n\x115\n\x00\x01\x01\x0D5\nelana\x0221\n\x0E0\n\x115\n\x00\x01\x01\x0D5\nemily\x0222\n\x0E0\n";
const QUERY: &[u8] = b"\x12\x10\xFF\xFF\x115\n\x00\x01\x01\x0D5\nsayan\x0220\n\x0E0\n\x115\n\x00\x01\x01\x0D5\nelana\x0221\n\x0E0\n\x115\n\x00\x01\x01\x0D5\nemily\x0222\n\x0E0\n";

#[test]
fn t_pipe() {
use crate::response::{Response, Row, Value};
let mut decoder = Decoder::new(QUERY, 0);
assert_eq!(
decoder.validate_pipe(true, MRespState::default()),
decoder.validate_pipe(5, MRespState::default()),
PipelineResult::Completed(vec![
Response::Empty,
Response::Error(u16::MAX),
Expand Down Expand Up @@ -117,13 +118,13 @@ fn t_pipe_staged() {
let mut dec = Decoder::new(&QUERY[..i], 0);
if i < 3 {
assert!(matches!(
dec.validate_pipe(true, MRespState::default()),
dec.validate_pipe(5, MRespState::default()),
PipelineResult::Pending(_)
));
} else {
assert!(matches!(
dec.validate_pipe(true, MRespState::default()),
PipelineResult::Pending(p) if p.expected.val() == 5
dec.validate_pipe(5, MRespState::default()),
PipelineResult::Pending(_)
));
}
}
Expand Down

0 comments on commit afa3a66

Please sign in to comment.