diff --git a/pilota-build/src/codegen/mod.rs b/pilota-build/src/codegen/mod.rs index abf30ff4..17e959e5 100644 --- a/pilota-build/src/codegen/mod.rs +++ b/pilota-build/src/codegen/mod.rs @@ -65,12 +65,12 @@ where if let Some(adjust) = adjust { if adjust.boxed() { - ty = quote::quote! { Box<#ty> } + ty = quote::quote! { ::std::boxed::Box<#ty> } } } if f.is_optional() { - ty = quote::quote! { Option<#ty> } + ty = quote::quote! { ::std::option::Option<#ty> } } let attrs = adjust.iter().flat_map(|a| a.attrs()); diff --git a/pilota-build/src/codegen/pkg_tree.rs b/pilota-build/src/codegen/pkg_tree.rs index a24a4116..f3343d65 100644 --- a/pilota-build/src/codegen/pkg_tree.rs +++ b/pilota-build/src/codegen/pkg_tree.rs @@ -11,32 +11,27 @@ pub struct PkgNode { } fn from_pkgs(base_path: &[Symbol], pkgs: &[ItemPath]) -> Arc<[PkgNode]> { - let groups = pkgs.iter().group_by(|p| p.first().unwrap()); - let groups = groups.into_iter(); + let groups = pkgs.iter().into_group_map_by(|p| p.first().unwrap()); - Arc::from( - groups + Arc::from_iter(groups.into_iter().map(|(k, v)| { + let path = base_path + .iter() + .chain(Some(k).into_iter()) + .cloned() + .collect::>(); + + let pkgs = v .into_iter() - .map(|(k, v)| { - let path = base_path - .iter() - .chain(Some(k).into_iter()) - .cloned() - .collect::>(); - - let pkgs = v - .filter(|p| p.len() > 1) - .map(|p| ItemPath::from(&p[1..])) - .collect::>(); - - let children = from_pkgs(&path, &pkgs); - PkgNode { - path: ItemPath::from(path), - children, - } - }) - .collect::>(), - ) + .filter(|p| p.len() > 1) + .map(|p| ItemPath::from(&p[1..])) + .collect::>(); + + let children = from_pkgs(&path, &pkgs); + PkgNode { + path: ItemPath::from(path), + children, + } + })) } impl PkgNode { diff --git a/pilota-build/src/codegen/thrift/mod.rs b/pilota-build/src/codegen/thrift/mod.rs index 81217dd9..b34b43d3 100644 --- a/pilota-build/src/codegen/thrift/mod.rs +++ b/pilota-build/src/codegen/thrift/mod.rs @@ -213,6 +213,7 @@ impl ThriftBackend { } } + #[inline] fn field_is_box(&self, f: &Field) -> bool { match self.adjust(f.did) { Some(a) => a.boxed(), @@ -241,7 +242,7 @@ impl ThriftBackend { let mut read_field = self.codegen_decode_ty(helper, &f.ty); let field_id = f.id as i16; if self.field_is_box(f) { - read_field = quote! {::std::boxed::Box::new(read_field) }; + read_field = quote! {::std::boxed::Box::new(#read_field) }; }; let skip = helper.codegen_skip_ttype(quote! { ttype }); diff --git a/pilota-build/src/middle/adjust.rs b/pilota-build/src/middle/adjust.rs index 0a09a757..b08873e1 100644 --- a/pilota-build/src/middle/adjust.rs +++ b/pilota-build/src/middle/adjust.rs @@ -6,22 +6,27 @@ pub struct Adjust { } impl Adjust { + #[inline] pub fn set_boxed(&mut self) { self.boxed = true } + #[inline] pub fn boxed(&self) -> bool { self.boxed } + #[inline] pub fn attrs(&self) -> &Vec { &self.attrs } + #[inline] pub fn add_attrs(&mut self, attrs: &[syn::Attribute]) { self.attrs.extend_from_slice(attrs) } + #[inline] pub fn add_lifetime(&mut self, lifetime: syn::Lifetime) { self.lifetimes.push(lifetime) } diff --git a/pilota-build/src/plugin/mod.rs b/pilota-build/src/plugin/mod.rs index c5e8b6cb..74a33f45 100644 --- a/pilota-build/src/plugin/mod.rs +++ b/pilota-build/src/plugin/mod.rs @@ -112,7 +112,7 @@ impl Plugin for BoxedPlugin { s.fields.iter().for_each(|f| { if let ty::Path(p) = &f.ty.kind { if cx.type_graph().is_nested(p.did, def_id) { - cx.with_adjust(def_id, |adj| adj.set_boxed()) + cx.with_adjust(f.did, |adj| adj.set_boxed()) } } }) diff --git a/pilota-build/test_data/protobuf/nested_message.rs b/pilota-build/test_data/protobuf/nested_message.rs index 0d35a199..424d142a 100644 --- a/pilota-build/test_data/protobuf/nested_message.rs +++ b/pilota-build/test_data/protobuf/nested_message.rs @@ -18,13 +18,13 @@ pub mod nested_message { #[derive(PartialOrd, Hash, Eq, Ord, :: prost :: Message, Clone, PartialEq)] pub struct T2 { #[prost(message, tag = "1", optional)] - pub t3: Option, + pub t3: ::std::option::Option, } } #[derive(PartialOrd, Hash, Eq, Ord, :: prost :: Message, Clone, PartialEq)] pub struct Tt1 { #[prost(message, tag = "1", optional)] - pub t2: Option, + pub t2: ::std::option::Option, } } } diff --git a/pilota-build/test_data/thrift/normal.rs b/pilota-build/test_data/thrift/normal.rs index 892a6f25..fe48834f 100644 --- a/pilota-build/test_data/thrift/normal.rs +++ b/pilota-build/test_data/thrift/normal.rs @@ -10,7 +10,7 @@ pub mod normal { pub mod normal { #[derive(PartialOrd, Hash, Eq, Ord, Debug, Default, Clone, PartialEq)] pub struct A { - pub a: Option, + pub a: ::std::option::Option, } #[::async_trait::async_trait] impl ::pilota::thrift::Message for A { @@ -114,7 +114,7 @@ pub mod normal { } #[derive(PartialOrd, Hash, Eq, Ord, Debug, Default, Clone, PartialEq)] pub struct B { - pub a: Option, + pub a: ::std::option::Option, } #[::async_trait::async_trait] impl ::pilota::thrift::Message for B { diff --git a/pilota-build/test_data/thrift/recursive_type.rs b/pilota-build/test_data/thrift/recursive_type.rs new file mode 100644 index 00000000..16e5185c --- /dev/null +++ b/pilota-build/test_data/thrift/recursive_type.rs @@ -0,0 +1,120 @@ +pub mod recursive_type { + #![allow( + unused_variables, + dead_code, + missing_docs, + clippy::unused_unit, + clippy::needless_borrow, + unused_mut + )] + pub mod recursive_type { + #[derive(PartialOrd, Hash, Eq, Ord, Debug, Default, Clone, PartialEq)] + pub struct A { + pub a: ::std::option::Option<::std::boxed::Box>, + } + #[::async_trait::async_trait] + impl ::pilota::thrift::Message for A { + fn encode( + &self, + protocol: &mut T, + ) -> ::std::result::Result<(), ::pilota::thrift::Error> { + let struct_ident = ::pilota::thrift::TStructIdentifier { name: "A" }; + protocol.write_struct_begin(&struct_ident)?; + if let Some(value) = self.a.as_ref() { + let field = ::pilota::thrift::TFieldIdentifier { + name: Some("a"), + field_type: ::pilota::thrift::TType::Struct, + id: Some(1i16), + }; + protocol.write_field_begin(&field)?; + ::pilota::thrift::Message::encode(value, protocol)?; + protocol.write_field_end()?; + }; + protocol.write_field_stop()?; + protocol.write_struct_end()?; + Ok(()) + } + fn decode( + protocol: &mut T, + ) -> ::std::result::Result { + let mut a = None; + protocol.read_struct_begin()?; + loop { + let field_ident = protocol.read_field_begin()?; + let ttype = field_ident.field_type; + if ttype == ::pilota::thrift::TType::Stop { + break; + } + let field_id = field_ident.id; + match field_id { + Some(1i16) => { + if ttype == ::pilota::thrift::TType::Struct { + a = Some(::std::boxed::Box::new( + ::pilota::thrift::Message::decode(protocol)?, + )); + } else { + protocol.skip(ttype)?; + } + } + _ => { + protocol.skip(ttype)?; + } + } + protocol.read_field_end()?; + } + protocol.read_struct_end()?; + let data = Self { a }; + Ok(data) + } + async fn decode_async( + protocol: &mut ::pilota::thrift::TAsyncBinaryProtocol, + ) -> ::std::result::Result { + let mut a = None; + protocol.read_struct_begin().await?; + loop { + let field_ident = protocol.read_field_begin().await?; + let ttype = field_ident.field_type; + if ttype == ::pilota::thrift::TType::Stop { + break; + } + let field_id = field_ident.id; + match field_id { + Some(1i16) => { + if ttype == ::pilota::thrift::TType::Struct { + a = Some(::std::boxed::Box::new( + ::pilota::thrift::Message::decode_async(protocol).await?, + )); + } else { + protocol.skip(ttype).await?; + } + } + _ => { + protocol.skip(ttype).await?; + } + } + protocol.read_field_end().await?; + } + protocol.read_struct_end().await?; + let data = Self { a }; + Ok(data) + } + } + impl ::pilota::thrift::Size for A { + fn size(&self, protocol: &T) -> usize { + protocol.write_struct_begin_len(&::pilota::thrift::TStructIdentifier { name: "A" }) + + if let Some(value) = self.a.as_ref() { + protocol.write_field_begin_len(&::pilota::thrift::TFieldIdentifier { + name: Some("a"), + field_type: ::pilota::thrift::TType::Struct, + id: Some(1i16), + }) + ::pilota::thrift::Size::size(value, protocol) + + protocol.write_field_end_len() + } else { + 0 + } + + protocol.write_field_stop_len() + + protocol.write_struct_end_len() + } + } + } +} diff --git a/pilota-build/test_data/thrift/recursive_type.thrift b/pilota-build/test_data/thrift/recursive_type.thrift new file mode 100644 index 00000000..10d74531 --- /dev/null +++ b/pilota-build/test_data/thrift/recursive_type.thrift @@ -0,0 +1,4 @@ + +struct A { + 1: optional A a, +} diff --git a/pilota/src/thrift/mod.rs b/pilota/src/thrift/mod.rs index 91f1937f..53f9d632 100644 --- a/pilota/src/thrift/mod.rs +++ b/pilota/src/thrift/mod.rs @@ -2,7 +2,7 @@ pub mod binary; pub mod error; pub mod rw_ext; -use std::sync::Arc; +use std::{ops::Deref, sync::Arc}; use bytes::{Buf, BufMut}; pub use error::*; @@ -27,6 +27,24 @@ pub trait Message: Sized + Send { R: AsyncRead + Unpin + Send; } +#[async_trait::async_trait] +impl Message for Box { + fn encode(&self, protocol: &mut T) -> Result<(), Error> { + self.deref().encode(protocol) + } + + fn decode(protocol: &mut T) -> Result { + Ok(Box::new(M::decode(protocol)?)) + } + + async fn decode_async(protocol: &mut TAsyncBinaryProtocol) -> Result + where + R: AsyncRead + Unpin + Send, + { + Ok(Box::new(M::decode_async(protocol).await?)) + } +} + #[async_trait::async_trait] pub trait EntryMessage: Sized + Send { fn encode(&self, protocol: &mut T) -> Result<(), Error>; @@ -498,6 +516,12 @@ pub trait Size { fn size(&self, protocol: &T) -> usize; } +impl Size for Box { + fn size(&self, protocol: &T) -> usize { + self.deref().size(protocol) + } +} + #[async_trait::async_trait] impl EntryMessage for Arc where