Skip to content

Commit

Permalink
support more Rust built-in types (#156)
Browse files Browse the repository at this point in the history
* support hashmap

* template impl for id()

* hashset

* i128
  • Loading branch information
chenyan-dfinity authored Jan 7, 2021
1 parent d48bc04 commit e54d3e4
Show file tree
Hide file tree
Showing 4 changed files with 168 additions and 38 deletions.
68 changes: 62 additions & 6 deletions rust/candid/src/de.rs
Original file line number Diff line number Diff line change
Expand Up @@ -572,6 +572,30 @@ impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> {
primitive_impl!(f32, Opcode::Float32, read_f32::<LittleEndian>);
primitive_impl!(f64, Opcode::Float64, read_f64::<LittleEndian>);

fn deserialize_i128<V>(self, visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
{
use std::convert::TryInto;
self.record_nesting_depth = 0;
self.check_type(Opcode::Int)?;
let v = Int::decode(&mut self.input).map_err(Error::msg)?;
let value: i128 = v.0.try_into().map_err(Error::msg)?;
visitor.visit_i128(value)
}

fn deserialize_u128<V>(self, visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
{
use std::convert::TryInto;
self.record_nesting_depth = 0;
self.check_type(Opcode::Nat)?;
let v = Nat::decode(&mut self.input).map_err(Error::msg)?;
let value: u128 = v.0.try_into().map_err(Error::msg)?;
visitor.visit_u128(value)
}

fn deserialize_bool<V>(self, visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
Expand Down Expand Up @@ -639,7 +663,6 @@ impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> {
where
V: Visitor<'de>,
{
self.record_nesting_depth = 0;
self.deserialize_unit(visitor)
}
fn deserialize_newtype_struct<V>(self, _name: &'static str, visitor: V) -> Result<V::Value>
Expand Down Expand Up @@ -668,11 +691,22 @@ impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> {
_ => Err(Error::msg("seq only takes vector or tuple")),
}
}
fn deserialize_tuple<V>(self, _len: usize, visitor: V) -> Result<V::Value>
fn deserialize_map<V>(mut self, visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
{
self.record_nesting_depth = 0;
self.check_type(Opcode::Vec)?;
let len = self.leb128_read()?;
let ty = self.peek_current_type()?.clone();
let value = visitor.visit_map(Compound::new(&mut self, Style::Map { len, ty }));
self.pop_current_type()?;
value
}
fn deserialize_tuple<V>(self, _len: usize, visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
{
self.deserialize_seq(visitor)
}
fn deserialize_tuple_struct<V>(
Expand All @@ -684,7 +718,6 @@ impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> {
where
V: Visitor<'de>,
{
self.record_nesting_depth = 0;
self.deserialize_seq(visitor)
}
fn deserialize_struct<V>(
Expand Down Expand Up @@ -755,7 +788,7 @@ impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> {
}

serde::forward_to_deserialize_any! {
char bytes byte_buf map
char bytes byte_buf
}
}

Expand All @@ -776,6 +809,10 @@ enum Style {
len: u32,
fs: BTreeMap<u32, Option<&'static str>>,
},
Map {
len: u64,
ty: RawValue,
},
}

struct Compound<'a, 'de> {
Expand Down Expand Up @@ -855,14 +892,33 @@ impl<'de, 'a> de::MapAccess<'de> for Compound<'a, 'de> {
}
seed.deserialize(&mut *self.de).map(Some)
}
_ => Err(Error::msg("expect struct")),
Style::Map { ref mut len, .. } => {
// This only comes from deserialize_map
if *len == 0 {
return Ok(None);
}
self.de.check_type(Opcode::Record)?;
assert_eq!(2, self.de.pop_current_type()?.get_u32()?);
assert_eq!(0, self.de.pop_current_type()?.get_u32()?);
*len -= 1;
seed.deserialize(&mut *self.de).map(Some)
}
_ => Err(Error::msg("expect struct or map")),
}
}
fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value>
where
V: de::DeserializeSeed<'de>,
{
seed.deserialize(&mut *self.de)
match self.style {
Style::Map { ref ty, .. } => {
assert_eq!(1, self.de.pop_current_type()?.get_u32()?);
let res = seed.deserialize(&mut *self.de);
self.de.current_type.push_front(ty.clone());
res
}
_ => seed.deserialize(&mut *self.de),
}
}
}

Expand Down
118 changes: 87 additions & 31 deletions rust/candid/src/types/impls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ use super::{CandidType, Compound, Serializer};
macro_rules! primitive_impl {
($t:ty, $id:tt, $method:ident $($cast:tt)*) => {
impl CandidType for $t {
fn id() -> TypeId { TypeId::of::<$t>() }
fn _ty() -> Type { Type::$id }
fn idl_serialize<S>(&self, serializer: S) -> Result<(), S::Error> where S: Serializer {
serializer.$method(*self $($cast)*)
Expand Down Expand Up @@ -35,10 +34,30 @@ primitive_impl!(f64, Float64, serialize_float64);
primitive_impl!(isize, Int64, serialize_int64 as i64);
primitive_impl!(usize, Nat64, serialize_nat64 as u64);

impl CandidType for String {
fn id() -> TypeId {
TypeId::of::<String>()
impl CandidType for i128 {
fn _ty() -> Type {
Type::Int
}
fn idl_serialize<S>(&self, serializer: S) -> Result<(), S::Error>
where
S: Serializer,
{
serializer.serialize_int(&crate::Int::from(*self))
}
}
impl CandidType for u128 {
fn _ty() -> Type {
Type::Nat
}
fn idl_serialize<S>(&self, serializer: S) -> Result<(), S::Error>
where
S: Serializer,
{
serializer.serialize_nat(&crate::Nat::from(*self))
}
}

impl CandidType for String {
fn _ty() -> Type {
Type::Text
}
Expand All @@ -54,9 +73,6 @@ impl<T: Sized> CandidType for Option<T>
where
T: CandidType,
{
fn id() -> TypeId {
TypeId::of::<Option<T>>()
}
fn _ty() -> Type {
Type::Opt(Box::new(T::ty()))
}
Expand All @@ -72,9 +88,6 @@ impl<T> CandidType for Vec<T>
where
T: CandidType,
{
fn id() -> TypeId {
TypeId::of::<Vec<T>>()
}
fn _ty() -> Type {
Type::Vec(Box::new(T::ty()))
}
Expand All @@ -94,9 +107,6 @@ impl<T> CandidType for [T]
where
T: CandidType,
{
fn id() -> TypeId {
TypeId::of::<[T]>()
}
fn _ty() -> Type {
Type::Vec(Box::new(T::ty()))
}
Expand All @@ -112,13 +122,76 @@ where
}
}

macro_rules! map_impl {
($ty:ident < K $(: $kbound1:ident $(+ $kbound2:ident)*)*, V $(, $typaram:ident : $bound:ident)* >) => {
impl<K, V $(, $typaram)*> CandidType for $ty<K, V $(, $typaram)*>
where
K: CandidType $(+ $kbound1 $(+ $kbound2)*)*,
V: CandidType,
$($typaram: $bound,)*
{
fn _ty() -> Type {
let tuple = Type::Record(vec![
Field {
id: Label::Id(0),
ty: K::ty(),
},
Field {
id: Label::Id(1),
ty: V::ty(),
},
]);
Type::Vec(Box::new(tuple))
}
fn idl_serialize<S>(&self, serializer: S) -> Result<(), S::Error>
where
S: Serializer,
{
let mut ser = serializer.serialize_vec(self.len())?;
for e in self.iter() {
Compound::serialize_element(&mut ser, &e)?;
}
Ok(())
}
}
}
}
macro_rules! set_impl {
($ty:ident < K $(: $kbound1:ident $(+ $kbound2:ident)*)* $(, $typaram:ident : $bound:ident)* >) => {
impl<K $(, $typaram)*> CandidType for $ty<K $(, $typaram)*>
where
K: CandidType $(+ $kbound1 $(+ $kbound2)*)*,
$($typaram: $bound,)*
{
fn _ty() -> Type {
Type::Vec(Box::new(K::ty()))
}
fn idl_serialize<S>(&self, serializer: S) -> Result<(), S::Error>
where
S: Serializer,
{
let mut ser = serializer.serialize_vec(self.len())?;
for e in self.iter() {
Compound::serialize_element(&mut ser, &e)?;
}
Ok(())
}
}
}
}
use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet};
use std::hash::{BuildHasher, Hash};
map_impl!(BTreeMap<K: Ord, V>);
set_impl!(BTreeSet<K: Ord>);
map_impl!(HashMap<K: Eq + Hash, V, H: BuildHasher>);
set_impl!(HashSet<K: Eq + Hash, H: BuildHasher>);

macro_rules! array_impls {
($($len:tt)+) => {
$(
impl<T> CandidType for [T; $len]
where T: CandidType,
{
fn id() -> TypeId { TypeId::of::<[T; $len]>() }
fn _ty() -> Type { Type::Vec(Box::new(T::ty())) }
fn idl_serialize<S>(&self, serializer: S) -> Result<(), S::Error>
where S: Serializer,
Expand Down Expand Up @@ -146,9 +219,6 @@ where
T: CandidType,
E: CandidType,
{
fn id() -> TypeId {
TypeId::of::<Result<T, E>>()
}
fn _ty() -> Type {
Type::Variant(vec![
// Make sure the field id is sorted by idl_hash
Expand Down Expand Up @@ -183,9 +253,6 @@ impl<T> CandidType for Box<T>
where
T: ?Sized + CandidType,
{
fn id() -> TypeId {
TypeId::of::<Box<T>>()
}
fn _ty() -> Type {
T::ty()
}
Expand Down Expand Up @@ -222,7 +289,6 @@ macro_rules! tuple_impls {
where
$($name: CandidType,)+
{
fn id() -> TypeId { TypeId::of::<($($name,)+)>() }
fn _ty() -> Type {
Type::Record(vec![
$(Field{ id: Label::Id($n), ty: $name::ty() },)+
Expand Down Expand Up @@ -262,10 +328,6 @@ tuple_impls! {
}

impl CandidType for std::time::SystemTime {
fn id() -> TypeId {
TypeId::of::<std::time::SystemTime>()
}

fn _ty() -> Type {
Type::Record(vec![
Field {
Expand All @@ -278,7 +340,6 @@ impl CandidType for std::time::SystemTime {
},
])
}

fn idl_serialize<S>(&self, serializer: S) -> Result<(), S::Error>
where
S: Serializer,
Expand All @@ -301,10 +362,6 @@ impl CandidType for std::time::SystemTime {
}

impl CandidType for std::time::Duration {
fn id() -> TypeId {
TypeId::of::<std::time::Duration>()
}

fn _ty() -> Type {
Type::Record(vec![
Field {
Expand All @@ -317,7 +374,6 @@ impl CandidType for std::time::Duration {
},
])
}

fn idl_serialize<S>(&self, serializer: S) -> Result<(), S::Error>
where
S: Serializer,
Expand Down
4 changes: 3 additions & 1 deletion rust/candid/src/types/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@ pub trait CandidType {
t
}
}
fn id() -> TypeId;
fn id() -> TypeId {
TypeId::of::<Self>()
}
fn _ty() -> Type;
// only serialize the value encoding
fn idl_serialize<S>(&self, serializer: S) -> Result<(), S::Error>
Expand Down
16 changes: 16 additions & 0 deletions rust/candid/tests/serde.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,12 +75,14 @@ fn test_fixed_number() {
all_check(42u32, "4449444c0001792a000000");
all_check(42u64, "4449444c0001782a00000000000000");
all_check(42usize, "4449444c0001782a00000000000000");
all_check(42u128, "4449444c00017d2a");
all_check(42i8, "4449444c0001772a");
all_check(42i16, "4449444c0001762a00");
all_check(42i32, "4449444c0001752a000000");
all_check(42i64, "4449444c0001742a00000000000000");
all_check(-42i64, "4449444c000174d6ffffffffffffff");
all_check(-42isize, "4449444c000174d6ffffffffffffff");
all_check(42i128, "4449444c00017c2a");
}

#[test]
Expand Down Expand Up @@ -349,6 +351,20 @@ fn test_vector() {
all_check(vec![(); 1000], "4449444c016d7f0100e807");
}

#[test]
fn test_collection() {
use std::collections::{BTreeMap, BTreeSet, HashMap};
let map: HashMap<_, _> = vec![("a".to_string(), 1)].into_iter().collect();
all_check(map, "4449444c026d016c0200710175010001016101000000");
let bmap: BTreeMap<_, _> = vec![(1, 101), (2, 102), (3, 103)].into_iter().collect();
all_check(
bmap,
"4449444c026d016c0200750175010003010000006500000002000000660000000300000067000000",
);
let bset: BTreeSet<_> = vec![1, 2, 3].into_iter().collect();
all_check(bset, "4449444c016d75010003010000000200000003000000");
}

#[test]
fn test_tuple() {
all_check(
Expand Down

0 comments on commit e54d3e4

Please sign in to comment.