From afa3a667ac6639dc8535ccffb6551991cc5dcb6f Mon Sep 17 00:00:00 2001 From: Sayan Nandan Date: Fri, 5 Apr 2024 00:19:39 +0530 Subject: [PATCH] protocol: Check for illegal packet escape --- src/io/aio.rs | 7 +------ src/io/sync.rs | 7 +------ src/protocol/pipe.rs | 47 ++++++++++++++++++++++---------------------- 3 files changed, 26 insertions(+), 35 deletions(-) diff --git a/src/io/aio.rs b/src/io/aio.rs index efdb216..9491e24 100644 --- a/src/io/aio.rs +++ b/src/io/aio.rs @@ -161,7 +161,6 @@ impl TcpConnection { self.con.write_all(pipeline.buf()).await?; self.buf.clear(); // read - let mut expected = Decoder::MIN_READBACK; let mut cursor = 0; let mut state = MRespState::default(); loop { @@ -170,15 +169,11 @@ impl TcpConnection { if n == 0 { return Err(Error::IoError(std::io::ErrorKind::ConnectionReset.into())); } - if n < expected { - continue; - } self.buf.extend_from_slice(&buf[..n]); let mut decoder = Decoder::new(&self.buf, cursor); - match decoder.validate_pipe(cursor == 0, state) { + match decoder.validate_pipe(pipeline.query_count(), state) { PipelineResult::Completed(r) => return Ok(r), PipelineResult::Pending(_state) => { - expected = 1; cursor = decoder.position(); state = _state; } diff --git a/src/io/sync.rs b/src/io/sync.rs index e47d4f4..68603c2 100644 --- a/src/io/sync.rs +++ b/src/io/sync.rs @@ -155,7 +155,6 @@ impl TcpConnection { self.con.write_all(pipeline.buf())?; self.buf.clear(); // read - let mut expected = Decoder::MIN_READBACK; let mut cursor = 0; let mut state = MRespState::default(); loop { @@ -164,15 +163,11 @@ impl TcpConnection { if n == 0 { return Err(Error::IoError(std::io::ErrorKind::ConnectionReset.into())); } - if n < expected { - continue; - } self.buf.extend_from_slice(&buf[..n]); let mut decoder = Decoder::new(&self.buf, cursor); - match decoder.validate_pipe(cursor == 0, state) { + match decoder.validate_pipe(pipeline.query_count(), state) { PipelineResult::Completed(r) => return Ok(r), PipelineResult::Pending(_state) => { - expected = 1; cursor = decoder.position(); state = _state; } diff --git a/src/protocol/pipe.rs b/src/protocol/pipe.rs index 646f706..7c44961 100644 --- a/src/protocol/pipe.rs +++ b/src/protocol/pipe.rs @@ -16,17 +16,18 @@ use { super::{ - state::{DecodeState, MetaState, RState, ResponseState}, + state::{DecodeState, RState, ResponseState}, Decoder, ProtocolError, }, crate::response::Response, }; +const PIPELINE_EXCEPTION: u8 = 0xFF; + #[derive(Debug, PartialEq, Default)] pub(crate) struct MRespState { processed: Vec, pending: Option, - expected: MetaState, } #[derive(Debug, PartialEq)] @@ -37,19 +38,18 @@ pub(crate) enum PipelineResult { } impl MRespState { - fn step(mut self, decoder: &mut Decoder) -> PipelineResult { - match self.expected.finished(decoder) { - Ok(true) => {} - Ok(false) => return PipelineResult::Pending(self), - Err(e) => return PipelineResult::Error(e), - } + #[cold] + fn except() -> PipelineResult { + PipelineResult::Error(ProtocolError::InvalidPacket) + } + fn step(mut self, decoder: &mut Decoder, expected: usize) -> PipelineResult { loop { - if self.processed.len() as u64 == self.expected.val() { - return PipelineResult::Completed(self.processed); - } if decoder._cursor_eof() { return PipelineResult::Pending(self); } + if decoder._cursor_value() == PIPELINE_EXCEPTION { + return Self::except(); + } match decoder.validate_response(RState( self.pending.take().unwrap_or(ResponseState::Initial), )) { @@ -57,7 +57,12 @@ impl MRespState { self.pending = Some(s); return PipelineResult::Pending(self); } - DecodeState::Completed(c) => self.processed.push(c), + DecodeState::Completed(c) => { + self.processed.push(c); + if self.processed.len() == expected { + return PipelineResult::Completed(self.processed); + } + } DecodeState::Error(e) => return PipelineResult::Error(e), } } @@ -65,24 +70,20 @@ impl MRespState { } impl<'a> Decoder<'a> { - pub fn validate_pipe(&mut self, first: bool, state: MRespState) -> PipelineResult { - if first && self._cursor_next() != b'P' { - PipelineResult::Error(ProtocolError::InvalidPacket) - } else { - state.step(self) - } + pub fn validate_pipe(&mut self, expected: usize, state: MRespState) -> PipelineResult { + state.step(self, expected) } } #[cfg(test)] -const QUERY: &[u8] = b"P5\n\x12\x10\xFF\xFF\x115\n\x00\x01\x01\x0D5\nsayan\x0220\n\x0E0\n\x115\n\x00\x01\x01\x0D5\nelana\x0221\n\x0E0\n\x115\n\x00\x01\x01\x0D5\nemily\x0222\n\x0E0\n"; +const QUERY: &[u8] = b"\x12\x10\xFF\xFF\x115\n\x00\x01\x01\x0D5\nsayan\x0220\n\x0E0\n\x115\n\x00\x01\x01\x0D5\nelana\x0221\n\x0E0\n\x115\n\x00\x01\x01\x0D5\nemily\x0222\n\x0E0\n"; #[test] fn t_pipe() { use crate::response::{Response, Row, Value}; let mut decoder = Decoder::new(QUERY, 0); assert_eq!( - decoder.validate_pipe(true, MRespState::default()), + decoder.validate_pipe(5, MRespState::default()), PipelineResult::Completed(vec![ Response::Empty, Response::Error(u16::MAX), @@ -117,13 +118,13 @@ fn t_pipe_staged() { let mut dec = Decoder::new(&QUERY[..i], 0); if i < 3 { assert!(matches!( - dec.validate_pipe(true, MRespState::default()), + dec.validate_pipe(5, MRespState::default()), PipelineResult::Pending(_) )); } else { assert!(matches!( - dec.validate_pipe(true, MRespState::default()), - PipelineResult::Pending(p) if p.expected.val() == 5 + dec.validate_pipe(5, MRespState::default()), + PipelineResult::Pending(_) )); } }