From fac5436a2be0fedb6572c462cbe77f8b0ec8d6f1 Mon Sep 17 00:00:00 2001 From: kralverde Date: Thu, 5 Sep 2024 14:28:09 -0400 Subject: [PATCH 1/3] get_string_len extra validity check --- pumpkin-protocol/src/bytebuf/mod.rs | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/pumpkin-protocol/src/bytebuf/mod.rs b/pumpkin-protocol/src/bytebuf/mod.rs index d6be909de..9a2cad4c6 100644 --- a/pumpkin-protocol/src/bytebuf/mod.rs +++ b/pumpkin-protocol/src/bytebuf/mod.rs @@ -84,6 +84,14 @@ impl ByteBuffer { "String length is bigger than max size", )); } + + if size as usize > self.buffer.len() { + return Err(Error::new( + ErrorKind::InvalidData, + "String length is bigger than packet", + )); + } + let data = self.buffer.copy_to_bytes(size as usize); if data.len() > max_size { return Err(Error::new( From 2993d6bfb66873046c8595d545c929771c035e64 Mon Sep 17 00:00:00 2001 From: kralverde Date: Thu, 5 Sep 2024 15:21:33 -0400 Subject: [PATCH 2/3] add checks for everything else --- pumpkin-protocol/src/bytebuf/deserializer.rs | 26 +-- pumpkin-protocol/src/bytebuf/mod.rs | 200 ++++++++++++------ pumpkin-protocol/src/server/handshake/mod.rs | 8 +- .../src/server/login/s_encryption_response.rs | 8 +- .../src/server/login/s_login_start.rs | 4 +- .../src/server/login/s_plugin_response.rs | 6 +- .../src/server/play/s_chat_message.rs | 12 +- .../src/server/play/s_interact.rs | 12 +- .../src/server/play/s_player_command.rs | 6 +- pumpkin/src/proxy/velocity.rs | 2 +- 10 files changed, 174 insertions(+), 110 deletions(-) diff --git a/pumpkin-protocol/src/bytebuf/deserializer.rs b/pumpkin-protocol/src/bytebuf/deserializer.rs index c2d96839f..1593d4bd2 100644 --- a/pumpkin-protocol/src/bytebuf/deserializer.rs +++ b/pumpkin-protocol/src/bytebuf/deserializer.rs @@ -45,77 +45,77 @@ impl<'a, 'de> de::Deserializer<'de> for Deserializer<'a> { where V: de::Visitor<'de>, { - visitor.visit_bool(self.inner.get_bool()) + visitor.visit_bool(self.inner.get_bool()?) } fn deserialize_i8(self, visitor: V) -> Result where V: de::Visitor<'de>, { - visitor.visit_i8(self.inner.get_i8()) + visitor.visit_i8(self.inner.get_i8()?) } fn deserialize_i16(self, visitor: V) -> Result where V: de::Visitor<'de>, { - visitor.visit_i16(self.inner.get_i16()) + visitor.visit_i16(self.inner.get_i16()?) } fn deserialize_i32(self, visitor: V) -> Result where V: de::Visitor<'de>, { - visitor.visit_i32(self.inner.get_i32()) + visitor.visit_i32(self.inner.get_i32()?) } fn deserialize_i64(self, visitor: V) -> Result where V: de::Visitor<'de>, { - visitor.visit_i64(self.inner.get_i64()) + visitor.visit_i64(self.inner.get_i64()?) } fn deserialize_u8(self, visitor: V) -> Result where V: de::Visitor<'de>, { - visitor.visit_u8(self.inner.get_u8()) + visitor.visit_u8(self.inner.get_u8()?) } fn deserialize_u16(self, visitor: V) -> Result where V: de::Visitor<'de>, { - visitor.visit_u16(self.inner.get_u16()) + visitor.visit_u16(self.inner.get_u16()?) } fn deserialize_u32(self, visitor: V) -> Result where V: de::Visitor<'de>, { - visitor.visit_u32(self.inner.get_u32()) + visitor.visit_u32(self.inner.get_u32()?) } fn deserialize_u64(self, visitor: V) -> Result where V: de::Visitor<'de>, { - visitor.visit_u64(self.inner.get_u64()) + visitor.visit_u64(self.inner.get_u64()?) } fn deserialize_f32(self, visitor: V) -> Result where V: de::Visitor<'de>, { - visitor.visit_f32(self.inner.get_f32()) + visitor.visit_f32(self.inner.get_f32()?) } fn deserialize_f64(self, visitor: V) -> Result where V: de::Visitor<'de>, { - visitor.visit_f64(self.inner.get_f64()) + visitor.visit_f64(self.inner.get_f64()?) } fn deserialize_char(self, _visitor: V) -> Result @@ -129,7 +129,7 @@ impl<'a, 'de> de::Deserializer<'de> for Deserializer<'a> { where V: de::Visitor<'de>, { - let string = self.inner.get_string().map_err(DeserializerError::Stdio)?; + let string = self.inner.get_string()?; visitor.visit_str(&string) } @@ -137,7 +137,7 @@ impl<'a, 'de> de::Deserializer<'de> for Deserializer<'a> { where V: de::Visitor<'de>, { - let string = self.inner.get_string().map_err(DeserializerError::Stdio)?; + let string = self.inner.get_string()?; visitor.visit_str(&string) } diff --git a/pumpkin-protocol/src/bytebuf/mod.rs b/pumpkin-protocol/src/bytebuf/mod.rs index 9a2cad4c6..7e3c025d5 100644 --- a/pumpkin-protocol/src/bytebuf/mod.rs +++ b/pumpkin-protocol/src/bytebuf/mod.rs @@ -1,7 +1,6 @@ use crate::{BitSet, FixedBitSet, VarInt, VarLongType}; use bytes::{Buf, BufMut, BytesMut}; use core::str; -use std::io::{self, Error, ErrorKind}; mod deserializer; pub use deserializer::DeserializerError; @@ -26,12 +25,12 @@ impl ByteBuffer { Self { buffer } } - pub fn get_var_int(&mut self) -> VarInt { + pub fn get_var_int(&mut self) -> Result { let mut value: i32 = 0; let mut position: i32 = 0; loop { - let read = self.buffer.get_u8(); + let read = self.get_u8()?; value |= ((read & SEGMENT_BITS) as i32) << position; @@ -42,19 +41,19 @@ impl ByteBuffer { position += 7; if position >= 32 { - panic!("VarInt is too big"); + return Err(DeserializerError::Message("VarInt is too big".to_string())); } } - VarInt(value) + Ok(VarInt(value)) } - pub fn get_var_long(&mut self) -> VarLongType { + pub fn get_var_long(&mut self) -> Result { let mut value: i64 = 0; let mut position: i64 = 0; loop { - let read = self.buffer.get_u8(); + let read = self.get_u8()?; value |= ((read & SEGMENT_BITS) as i64) << position; @@ -65,57 +64,48 @@ impl ByteBuffer { position += 7; if position >= 64 { - panic!("VarInt is too big"); + return Err(DeserializerError::Message("VarLong is too big".to_string())); } } - value + Ok(value) } - pub fn get_string(&mut self) -> Result { + pub fn get_string(&mut self) -> Result { self.get_string_len(32767) } - pub fn get_string_len(&mut self, max_size: usize) -> Result { - let size = self.get_var_int().0; + pub fn get_string_len(&mut self, max_size: usize) -> Result { + let size = self.get_var_int()?.0; if size as usize > max_size { - return Err(Error::new( - ErrorKind::InvalidData, - "String length is bigger than max size", + return Err(DeserializerError::Message( + "String length is bigger than max size".to_string(), )); } - if size as usize > self.buffer.len() { - return Err(Error::new( - ErrorKind::InvalidData, - "String length is bigger than packet", - )); - } - - let data = self.buffer.copy_to_bytes(size as usize); + let data = self.copy_to_bytes(size as usize)?; if data.len() > max_size { - return Err(Error::new( - ErrorKind::InvalidData, - "String is bigger than max size", + return Err(DeserializerError::Message( + "String is bigger than max size".to_string(), )); } match str::from_utf8(&data) { Ok(string_result) => Ok(string_result.to_string()), - Err(e) => Err(Error::new(ErrorKind::InvalidData, e)), + Err(e) => Err(DeserializerError::Message(e.to_string())), } } - pub fn get_bool(&mut self) -> bool { - self.buffer.get_u8() != 0 + pub fn get_bool(&mut self) -> Result { + Ok(self.get_u8()? != 0) } - pub fn get_uuid(&mut self) -> uuid::Uuid { + pub fn get_uuid(&mut self) -> Result { let mut bytes = [0u8; 16]; - self.buffer.copy_to_slice(&mut bytes); - uuid::Uuid::from_slice(&bytes).expect("Failed to parse UUID") + self.copy_to_slice(&mut bytes)?; + Ok(uuid::Uuid::from_slice(&bytes).expect("Failed to parse UUID")) } - pub fn get_fixed_bitset(&mut self, bits: usize) -> FixedBitSet { + pub fn get_fixed_bitset(&mut self, bits: usize) -> Result { self.copy_to_bytes(bits.div_ceil(8)) } @@ -169,11 +159,14 @@ impl ByteBuffer { /// Reads a boolean. If true, the closure is called, and the returned value is /// wrapped in Some. Otherwise, this returns None. - pub fn get_option(&mut self, val: impl FnOnce(&mut Self) -> T) -> Option { - if self.get_bool() { - Some(val(self)) + pub fn get_option( + &mut self, + val: impl FnOnce(&mut Self) -> Result, + ) -> Result, DeserializerError> { + if self.get_bool()? { + Ok(Some(val(self)?)) } else { - None + Ok(None) } } /// Writes `true` if the option is Some, or `false` if None. If the option is @@ -185,13 +178,16 @@ impl ByteBuffer { } } - pub fn get_list(&mut self, val: impl Fn(&mut Self) -> T) -> Vec { - let len = self.get_var_int().0 as usize; + pub fn get_list( + &mut self, + val: impl Fn(&mut Self) -> Result, + ) -> Result, DeserializerError> { + let len = self.get_var_int()?.0 as usize; let mut list = Vec::with_capacity(len); for _ in 0..len { - list.push(val(self)); + list.push(val(self)?); } - list + Ok(list) } /// Writes a list to the buffer. pub fn put_list(&mut self, list: &[T], write: impl Fn(&mut Self, &T)) { @@ -219,50 +215,109 @@ impl ByteBuffer { pub fn buf(&mut self) -> &mut BytesMut { &mut self.buffer } -} -// trait -impl ByteBuffer { - pub fn get_u8(&mut self) -> u8 { - self.buffer.get_u8() + // Trait equivalents + pub fn get_u8(&mut self) -> Result { + if self.buffer.has_remaining() { + Ok(self.buffer.get_u8()) + } else { + Err(DeserializerError::Message( + "No bytes left to consume".to_string(), + )) + } } - pub fn get_i8(&mut self) -> i8 { - self.buffer.get_i8() + pub fn get_i8(&mut self) -> Result { + if self.buffer.has_remaining() { + Ok(self.buffer.get_i8()) + } else { + Err(DeserializerError::Message( + "No bytes left to consume".to_string(), + )) + } } - pub fn get_u16(&mut self) -> u16 { - self.buffer.get_u16() + pub fn get_u16(&mut self) -> Result { + if self.buffer.remaining() >= 2 { + Ok(self.buffer.get_u16()) + } else { + Err(DeserializerError::Message( + "Less than 2 bytes left to consume".to_string(), + )) + } } - pub fn get_i16(&mut self) -> i16 { - self.buffer.get_i16() + pub fn get_i16(&mut self) -> Result { + if self.buffer.remaining() >= 2 { + Ok(self.buffer.get_i16()) + } else { + Err(DeserializerError::Message( + "Less than 2 bytes left to consume".to_string(), + )) + } } - pub fn get_u32(&mut self) -> u32 { - self.buffer.get_u32() + pub fn get_u32(&mut self) -> Result { + if self.buffer.remaining() >= 4 { + Ok(self.buffer.get_u32()) + } else { + Err(DeserializerError::Message( + "Less than 4 bytes left to consume".to_string(), + )) + } } - pub fn get_i32(&mut self) -> i32 { - self.buffer.get_i32() + pub fn get_i32(&mut self) -> Result { + if self.buffer.remaining() >= 4 { + Ok(self.buffer.get_i32()) + } else { + Err(DeserializerError::Message( + "Less than 4 bytes left to consume".to_string(), + )) + } } - pub fn get_u64(&mut self) -> u64 { - self.buffer.get_u64() + pub fn get_u64(&mut self) -> Result { + if self.buffer.remaining() >= 8 { + Ok(self.buffer.get_u64()) + } else { + Err(DeserializerError::Message( + "Less than 8 bytes left to consume".to_string(), + )) + } } - pub fn get_i64(&mut self) -> i64 { - self.buffer.get_i64() + pub fn get_i64(&mut self) -> Result { + if self.buffer.remaining() >= 8 { + Ok(self.buffer.get_i64()) + } else { + Err(DeserializerError::Message( + "Less than 8 bytes left to consume".to_string(), + )) + } } - pub fn get_f32(&mut self) -> f32 { - self.buffer.get_f32() + pub fn get_f32(&mut self) -> Result { + if self.buffer.remaining() >= 4 { + Ok(self.buffer.get_f32()) + } else { + Err(DeserializerError::Message( + "Less than 4 bytes left to consume".to_string(), + )) + } } - pub fn get_f64(&mut self) -> f64 { - self.buffer.get_f64() + pub fn get_f64(&mut self) -> Result { + if self.buffer.remaining() >= 8 { + Ok(self.buffer.get_f64()) + } else { + Err(DeserializerError::Message( + "Less than 8 bytes left to consume".to_string(), + )) + } } + // TODO: SerializerError? pub fn put_u8(&mut self, n: u8) { self.buffer.put_u8(n) } @@ -303,12 +358,21 @@ impl ByteBuffer { self.buffer.put_f64(n) } - pub fn copy_to_bytes(&mut self, len: usize) -> bytes::Bytes { - self.buffer.copy_to_bytes(len) + pub fn copy_to_bytes(&mut self, len: usize) -> Result { + if self.buffer.remaining() >= len { + Ok(self.buffer.copy_to_bytes(len)) + } else { + Err(DeserializerError::Message("Unable to copy".to_string())) + } } - pub fn copy_to_slice(&mut self, dst: &mut [u8]) { - self.buffer.copy_to_slice(dst) + pub fn copy_to_slice(&mut self, dst: &mut [u8]) -> Result<(), DeserializerError> { + if self.buffer.remaining() > dst.len() { + self.buffer.copy_to_slice(dst); + Ok(()) + } else { + Err(DeserializerError::Message("Unable to copy".to_string())) + } } pub fn put_slice(&mut self, src: &[u8]) { diff --git a/pumpkin-protocol/src/server/handshake/mod.rs b/pumpkin-protocol/src/server/handshake/mod.rs index 100e67fd0..4935e5d00 100644 --- a/pumpkin-protocol/src/server/handshake/mod.rs +++ b/pumpkin-protocol/src/server/handshake/mod.rs @@ -16,10 +16,10 @@ pub struct SHandShake { impl ServerPacket for SHandShake { fn read(bytebuf: &mut ByteBuffer) -> Result { Ok(Self { - protocol_version: bytebuf.get_var_int(), - server_address: bytebuf.get_string_len(255).unwrap(), - server_port: bytebuf.get_u16(), - next_state: bytebuf.get_var_int().into(), + protocol_version: bytebuf.get_var_int()?, + server_address: bytebuf.get_string_len(255)?, + server_port: bytebuf.get_u16()?, + next_state: bytebuf.get_var_int()?.into(), }) } } diff --git a/pumpkin-protocol/src/server/login/s_encryption_response.rs b/pumpkin-protocol/src/server/login/s_encryption_response.rs index 980e54b20..16d40289e 100644 --- a/pumpkin-protocol/src/server/login/s_encryption_response.rs +++ b/pumpkin-protocol/src/server/login/s_encryption_response.rs @@ -15,10 +15,10 @@ pub struct SEncryptionResponse { impl ServerPacket for SEncryptionResponse { fn read(bytebuf: &mut ByteBuffer) -> Result { - let shared_secret_length = bytebuf.get_var_int(); - let shared_secret = bytebuf.copy_to_bytes(shared_secret_length.0 as usize); - let verify_token_length = bytebuf.get_var_int(); - let verify_token = bytebuf.copy_to_bytes(shared_secret_length.0 as usize); + let shared_secret_length = bytebuf.get_var_int()?; + let shared_secret = bytebuf.copy_to_bytes(shared_secret_length.0 as usize)?; + let verify_token_length = bytebuf.get_var_int()?; + let verify_token = bytebuf.copy_to_bytes(shared_secret_length.0 as usize)?; Ok(Self { shared_secret_length, shared_secret: shared_secret.to_vec(), diff --git a/pumpkin-protocol/src/server/login/s_login_start.rs b/pumpkin-protocol/src/server/login/s_login_start.rs index 96bbbac70..f9f0cb7ab 100644 --- a/pumpkin-protocol/src/server/login/s_login_start.rs +++ b/pumpkin-protocol/src/server/login/s_login_start.rs @@ -14,8 +14,8 @@ pub struct SLoginStart { impl ServerPacket for SLoginStart { fn read(bytebuf: &mut ByteBuffer) -> Result { Ok(Self { - name: bytebuf.get_string_len(16).unwrap(), - uuid: bytebuf.get_uuid(), + name: bytebuf.get_string_len(16)?, + uuid: bytebuf.get_uuid()?, }) } } diff --git a/pumpkin-protocol/src/server/login/s_plugin_response.rs b/pumpkin-protocol/src/server/login/s_plugin_response.rs index 15fca41f1..44360bde2 100644 --- a/pumpkin-protocol/src/server/login/s_plugin_response.rs +++ b/pumpkin-protocol/src/server/login/s_plugin_response.rs @@ -16,9 +16,9 @@ pub struct SLoginPluginResponse { impl ServerPacket for SLoginPluginResponse { fn read(bytebuf: &mut ByteBuffer) -> Result { Ok(Self { - message_id: bytebuf.get_var_int(), - successful: bytebuf.get_bool(), - data: bytebuf.get_option(|v| v.get_slice()), + message_id: bytebuf.get_var_int()?, + successful: bytebuf.get_bool()?, + data: bytebuf.get_option(|v| Ok(v.get_slice()))?, }) } } diff --git a/pumpkin-protocol/src/server/play/s_chat_message.rs b/pumpkin-protocol/src/server/play/s_chat_message.rs index 9332a28f2..76bd707cc 100644 --- a/pumpkin-protocol/src/server/play/s_chat_message.rs +++ b/pumpkin-protocol/src/server/play/s_chat_message.rs @@ -21,12 +21,12 @@ pub struct SChatMessage { impl ServerPacket for SChatMessage { fn read(bytebuf: &mut ByteBuffer) -> Result { Ok(Self { - message: bytebuf.get_string().unwrap(), - timestamp: bytebuf.get_i64(), - salt: bytebuf.get_i64(), - signature: bytebuf.get_option(|v| v.copy_to_bytes(256)), - message_count: bytebuf.get_var_int(), - acknowledged: bytebuf.get_fixed_bitset(20), + message: bytebuf.get_string()?, + timestamp: bytebuf.get_i64()?, + salt: bytebuf.get_i64()?, + signature: bytebuf.get_option(|v| v.copy_to_bytes(256))?, + message_count: bytebuf.get_var_int()?, + acknowledged: bytebuf.get_fixed_bitset(20)?, }) } } diff --git a/pumpkin-protocol/src/server/play/s_interact.rs b/pumpkin-protocol/src/server/play/s_interact.rs index 18e651163..36c968645 100644 --- a/pumpkin-protocol/src/server/play/s_interact.rs +++ b/pumpkin-protocol/src/server/play/s_interact.rs @@ -18,8 +18,8 @@ impl ServerPacket for SInteract { fn read( bytebuf: &mut crate::bytebuf::ByteBuffer, ) -> Result { - let entity_id = bytebuf.get_var_int(); - let typ = bytebuf.get_var_int(); + let entity_id = bytebuf.get_var_int()?; + let typ = bytebuf.get_var_int()?; let action = ActionType::from_i32(typ.0).ok_or(DeserializerError::Message( "invalid action type".to_string(), ))?; @@ -27,13 +27,13 @@ impl ServerPacket for SInteract { ActionType::Interact => None, ActionType::Attack => None, ActionType::InteractAt => { - Some((bytebuf.get_f32(), bytebuf.get_f32(), bytebuf.get_f32())) + Some((bytebuf.get_f32()?, bytebuf.get_f32()?, bytebuf.get_f32()?)) } }; let hand = match action { - ActionType::Interact => Some(bytebuf.get_var_int()), + ActionType::Interact => Some(bytebuf.get_var_int()?), ActionType::Attack => None, - ActionType::InteractAt => Some(bytebuf.get_var_int()), + ActionType::InteractAt => Some(bytebuf.get_var_int()?), }; Ok(Self { @@ -41,7 +41,7 @@ impl ServerPacket for SInteract { typ, target_position, hand, - sneaking: bytebuf.get_bool(), + sneaking: bytebuf.get_bool()?, }) } } diff --git a/pumpkin-protocol/src/server/play/s_player_command.rs b/pumpkin-protocol/src/server/play/s_player_command.rs index f6365cc02..c042f4769 100644 --- a/pumpkin-protocol/src/server/play/s_player_command.rs +++ b/pumpkin-protocol/src/server/play/s_player_command.rs @@ -25,9 +25,9 @@ pub enum Action { impl ServerPacket for SPlayerCommand { fn read(bytebuf: &mut crate::bytebuf::ByteBuffer) -> Result { Ok(Self { - entity_id: bytebuf.get_var_int(), - action: bytebuf.get_var_int(), - jump_boost: bytebuf.get_var_int(), + entity_id: bytebuf.get_var_int()?, + action: bytebuf.get_var_int()?, + jump_boost: bytebuf.get_var_int()?, }) } } diff --git a/pumpkin/src/proxy/velocity.rs b/pumpkin/src/proxy/velocity.rs index b2f753506..b23e17082 100644 --- a/pumpkin/src/proxy/velocity.rs +++ b/pumpkin/src/proxy/velocity.rs @@ -52,7 +52,7 @@ pub fn receive_plugin_response( buf.put_slice(data_without_signature); // check velocity version - let version = buf.get_var_int(); + let version = buf.get_var_int().unwrap(); let version = version.0; if version > MAX_SUPPORTED_FORWARDING_VERSION { client.kick(&format!( From af878ff549cc8b0201bf9bd9a8935f4b6205e3c5 Mon Sep 17 00:00:00 2001 From: kralverde Date: Thu, 5 Sep 2024 15:29:27 -0400 Subject: [PATCH 3/3] fix check --- pumpkin-protocol/src/bytebuf/mod.rs | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/pumpkin-protocol/src/bytebuf/mod.rs b/pumpkin-protocol/src/bytebuf/mod.rs index 7e3c025d5..4ec1dc5a6 100644 --- a/pumpkin-protocol/src/bytebuf/mod.rs +++ b/pumpkin-protocol/src/bytebuf/mod.rs @@ -359,19 +359,23 @@ impl ByteBuffer { } pub fn copy_to_bytes(&mut self, len: usize) -> Result { - if self.buffer.remaining() >= len { + if self.buffer.len() >= len { Ok(self.buffer.copy_to_bytes(len)) } else { - Err(DeserializerError::Message("Unable to copy".to_string())) + Err(DeserializerError::Message( + "Unable to copy bytes".to_string(), + )) } } pub fn copy_to_slice(&mut self, dst: &mut [u8]) -> Result<(), DeserializerError> { - if self.buffer.remaining() > dst.len() { + if self.buffer.remaining() >= dst.len() { self.buffer.copy_to_slice(dst); Ok(()) } else { - Err(DeserializerError::Message("Unable to copy".to_string())) + Err(DeserializerError::Message( + "Unable to copy slice".to_string(), + )) } }