Skip to content

Commit

Permalink
Merge pull request #21 from lukas0008/master
Browse files Browse the repository at this point in the history
Add Deserializer for server bound packets
  • Loading branch information
Snowiiii authored Aug 7, 2024
2 parents 0cec9bf + 96c9a57 commit 4f06496
Show file tree
Hide file tree
Showing 27 changed files with 741 additions and 355 deletions.
321 changes: 321 additions & 0 deletions pumpkin-protocol/src/bytebuf/deserializer.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,321 @@
use std::fmt::Display;

use serde::de::{self, DeserializeSeed, SeqAccess};
use thiserror::Error;

use super::ByteBuffer;

pub struct Deserializer<'a> {
inner: &'a mut ByteBuffer,
}

#[derive(Debug, Error)]
pub enum DeserializerError {
#[error("serializer error {0}")]
Message(String),
#[error("Stdio error")]
Stdio(std::io::Error),
}

impl de::Error for DeserializerError {
fn custom<T: Display>(msg: T) -> Self {
Self::Message(msg.to_string())
}
}

impl<'a> Deserializer<'a> {
pub fn new(buf: &'a mut ByteBuffer) -> Self {
Self { inner: buf }
}
}

impl<'a, 'de> de::Deserializer<'de> for Deserializer<'a> {
type Error = DeserializerError;

fn deserialize_any<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
where
V: de::Visitor<'de>,
{
unimplemented!(
"This is impossible to do, since you cannot infer the data structure from the packet"
)
}

fn deserialize_bool<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: de::Visitor<'de>,
{
visitor.visit_bool(self.inner.get_bool())
}

fn deserialize_i8<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: de::Visitor<'de>,
{
visitor.visit_i8(self.inner.get_i8())
}

fn deserialize_i16<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: de::Visitor<'de>,
{
visitor.visit_i16(self.inner.get_i16())
}

fn deserialize_i32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: de::Visitor<'de>,
{
visitor.visit_i32(self.inner.get_i32())
}

fn deserialize_i64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: de::Visitor<'de>,
{
visitor.visit_i64(self.inner.get_i64())
}

fn deserialize_u8<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: de::Visitor<'de>,
{
visitor.visit_u8(self.inner.get_u8())
}

fn deserialize_u16<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: de::Visitor<'de>,
{
visitor.visit_u16(self.inner.get_u16())
}

fn deserialize_u32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: de::Visitor<'de>,
{
visitor.visit_u32(self.inner.get_u32())
}

fn deserialize_u64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: de::Visitor<'de>,
{
visitor.visit_u64(self.inner.get_u64())
}

fn deserialize_f32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: de::Visitor<'de>,
{
visitor.visit_f32(self.inner.get_f32())
}

fn deserialize_f64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: de::Visitor<'de>,
{
visitor.visit_f64(self.inner.get_f64())
}

fn deserialize_char<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
where
V: de::Visitor<'de>,
{
unimplemented!()
}

fn deserialize_str<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: de::Visitor<'de>,
{
let string = self.inner.get_string().map_err(DeserializerError::Stdio)?;
visitor.visit_str(&string)
}

fn deserialize_string<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: de::Visitor<'de>,
{
let string = self.inner.get_string().map_err(DeserializerError::Stdio)?;
visitor.visit_str(&string)
}

fn deserialize_bytes<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
where
V: de::Visitor<'de>,
{
unimplemented!()
}

fn deserialize_byte_buf<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
where
V: de::Visitor<'de>,
{
unimplemented!()
}

fn deserialize_option<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
where
V: de::Visitor<'de>,
{
unimplemented!()
}

fn deserialize_unit<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
where
V: de::Visitor<'de>,
{
unimplemented!()
}

fn deserialize_unit_struct<V>(
self,
_name: &'static str,
_visitor: V,
) -> Result<V::Value, Self::Error>
where
V: de::Visitor<'de>,
{
unimplemented!()
}

fn deserialize_newtype_struct<V>(
self,
_name: &'static str,
_visitor: V,
) -> Result<V::Value, Self::Error>
where
V: de::Visitor<'de>,
{
unimplemented!()
}

fn deserialize_seq<V>(mut self, visitor: V) -> Result<V::Value, Self::Error>
where
V: de::Visitor<'de>,
{
struct Access<'a, 'b> {
deserializer: &'a mut Deserializer<'b>,
}

impl<'de, 'a, 'b: 'a> SeqAccess<'de> for Access<'a, 'b> {
type Error = DeserializerError;

fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>, Self::Error>
where
T: de::DeserializeSeed<'de>,
{
let value = DeserializeSeed::deserialize(
seed,
Deserializer {
inner: self.deserializer.inner,
},
)?;
Ok(Some(value))
}
}

let value = visitor.visit_seq(Access {
deserializer: &mut self,
});

value
}

fn deserialize_tuple<V>(mut self, len: usize, visitor: V) -> Result<V::Value, Self::Error>
where
V: de::Visitor<'de>,
{
struct Access<'a, 'b> {
deserializer: &'a mut Deserializer<'b>,
len: usize,
}

impl<'de, 'a, 'b: 'a> SeqAccess<'de> for Access<'a, 'b> {
type Error = DeserializerError;

fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>, Self::Error>
where
T: de::DeserializeSeed<'de>,
{
if self.len > 0 {
self.len -= 1;
let value = DeserializeSeed::deserialize(
seed,
Deserializer {
inner: self.deserializer.inner,
},
)?;
Ok(Some(value))
} else {
Ok(None)
}
}
}

let value = visitor.visit_seq(Access {
deserializer: &mut self,
len,
});

value
}

fn deserialize_tuple_struct<V>(
self,
_name: &'static str,
_len: usize,
_visitor: V,
) -> Result<V::Value, Self::Error>
where
V: de::Visitor<'de>,
{
unimplemented!()
}

fn deserialize_map<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
where
V: de::Visitor<'de>,
{
unimplemented!()
}

fn deserialize_struct<V>(
self,
_name: &'static str,
fields: &'static [&'static str],
visitor: V,
) -> Result<V::Value, Self::Error>
where
V: de::Visitor<'de>,
{
self.deserialize_tuple(fields.len(), visitor)
}

fn deserialize_enum<V>(
self,
_name: &'static str,
_variants: &'static [&'static str],
_visitor: V,
) -> Result<V::Value, Self::Error>
where
V: de::Visitor<'de>,
{
unimplemented!()
}

fn deserialize_identifier<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
where
V: de::Visitor<'de>,
{
unimplemented!()
}

fn deserialize_ignored_any<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
where
V: de::Visitor<'de>,
{
unimplemented!()
}
}
46 changes: 46 additions & 0 deletions pumpkin-protocol/src/bytebuf/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ use bytes::{Buf, BufMut, BytesMut};
use core::str;
use std::io::{self, Error, ErrorKind};

mod deserializer;
pub use deserializer::DeserializerError;
pub mod packet_id;
mod serializer;

Expand Down Expand Up @@ -312,3 +314,47 @@ impl ByteBuffer {
self.buffer.split()
}
}

#[cfg(test)]
mod test {
use serde::{Deserialize, Serialize};

use crate::{
bytebuf::{deserializer, serializer, ByteBuffer},
VarInt,
};

#[test]
fn test_i32_reserialize() {
#[derive(serde::Serialize, serde::Deserialize, PartialEq, Eq, Debug)]
struct Foo {
bar: i32,
}
let foo = Foo { bar: 69 };
let mut serializer = serializer::Serializer::new(ByteBuffer::empty());
foo.serialize(&mut serializer).unwrap();

let mut serialized: ByteBuffer = serializer.into();
let deserialized: Foo =
Foo::deserialize(deserializer::Deserializer::new(&mut serialized)).unwrap();

assert_eq!(foo, deserialized);
}

#[test]
fn test_varint_reserialize() {
#[derive(serde::Serialize, serde::Deserialize, PartialEq, Eq, Debug)]
struct Foo {
bar: VarInt,
}
let foo = Foo { bar: 69.into() };
let mut serializer = serializer::Serializer::new(ByteBuffer::empty());
foo.serialize(&mut serializer).unwrap();

let mut serialized: ByteBuffer = serializer.into();
let deserialized: Foo =
Foo::deserialize(deserializer::Deserializer::new(&mut serialized)).unwrap();

assert_eq!(foo, deserialized);
}
}
Loading

0 comments on commit 4f06496

Please sign in to comment.