Skip to content

Commit

Permalink
fix: replace group_by with into_group_map_by and impl message for Box (
Browse files Browse the repository at this point in the history
  • Loading branch information
LYF1999 authored Aug 29, 2022
1 parent 94af82b commit cf8220c
Show file tree
Hide file tree
Showing 10 changed files with 182 additions and 33 deletions.
4 changes: 2 additions & 2 deletions pilota-build/src/codegen/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,12 +65,12 @@ where

if let Some(adjust) = adjust {
if adjust.boxed() {
ty = quote::quote! { Box<#ty> }
ty = quote::quote! { ::std::boxed::Box<#ty> }
}
}

if f.is_optional() {
ty = quote::quote! { Option<#ty> }
ty = quote::quote! { ::std::option::Option<#ty> }
}

let attrs = adjust.iter().flat_map(|a| a.attrs());
Expand Down
43 changes: 19 additions & 24 deletions pilota-build/src/codegen/pkg_tree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,32 +11,27 @@ pub struct PkgNode {
}

fn from_pkgs(base_path: &[Symbol], pkgs: &[ItemPath]) -> Arc<[PkgNode]> {
let groups = pkgs.iter().group_by(|p| p.first().unwrap());
let groups = groups.into_iter();
let groups = pkgs.iter().into_group_map_by(|p| p.first().unwrap());

Arc::from(
groups
Arc::from_iter(groups.into_iter().map(|(k, v)| {
let path = base_path
.iter()
.chain(Some(k).into_iter())
.cloned()
.collect::<Vec<_>>();

let pkgs = v
.into_iter()
.map(|(k, v)| {
let path = base_path
.iter()
.chain(Some(k).into_iter())
.cloned()
.collect::<Vec<_>>();

let pkgs = v
.filter(|p| p.len() > 1)
.map(|p| ItemPath::from(&p[1..]))
.collect::<Vec<_>>();

let children = from_pkgs(&path, &pkgs);
PkgNode {
path: ItemPath::from(path),
children,
}
})
.collect::<Vec<_>>(),
)
.filter(|p| p.len() > 1)
.map(|p| ItemPath::from(&p[1..]))
.collect::<Vec<_>>();

let children = from_pkgs(&path, &pkgs);
PkgNode {
path: ItemPath::from(path),
children,
}
}))
}

impl PkgNode {
Expand Down
3 changes: 2 additions & 1 deletion pilota-build/src/codegen/thrift/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,7 @@ impl ThriftBackend {
}
}

#[inline]
fn field_is_box(&self, f: &Field) -> bool {
match self.adjust(f.did) {
Some(a) => a.boxed(),
Expand Down Expand Up @@ -241,7 +242,7 @@ impl ThriftBackend {
let mut read_field = self.codegen_decode_ty(helper, &f.ty);
let field_id = f.id as i16;
if self.field_is_box(f) {
read_field = quote! {::std::boxed::Box::new(read_field) };
read_field = quote! {::std::boxed::Box::new(#read_field) };
};
let skip = helper.codegen_skip_ttype(quote! { ttype });

Expand Down
5 changes: 5 additions & 0 deletions pilota-build/src/middle/adjust.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,22 +6,27 @@ pub struct Adjust {
}

impl Adjust {
#[inline]
pub fn set_boxed(&mut self) {
self.boxed = true
}

#[inline]
pub fn boxed(&self) -> bool {
self.boxed
}

#[inline]
pub fn attrs(&self) -> &Vec<syn::Attribute> {
&self.attrs
}

#[inline]
pub fn add_attrs(&mut self, attrs: &[syn::Attribute]) {
self.attrs.extend_from_slice(attrs)
}

#[inline]
pub fn add_lifetime(&mut self, lifetime: syn::Lifetime) {
self.lifetimes.push(lifetime)
}
Expand Down
2 changes: 1 addition & 1 deletion pilota-build/src/plugin/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ impl Plugin for BoxedPlugin {
s.fields.iter().for_each(|f| {
if let ty::Path(p) = &f.ty.kind {
if cx.type_graph().is_nested(p.did, def_id) {
cx.with_adjust(def_id, |adj| adj.set_boxed())
cx.with_adjust(f.did, |adj| adj.set_boxed())
}
}
})
Expand Down
4 changes: 2 additions & 2 deletions pilota-build/test_data/protobuf/nested_message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,13 @@ pub mod nested_message {
#[derive(PartialOrd, Hash, Eq, Ord, :: prost :: Message, Clone, PartialEq)]
pub struct T2 {
#[prost(message, tag = "1", optional)]
pub t3: Option<Tt3>,
pub t3: ::std::option::Option<Tt3>,
}
}
#[derive(PartialOrd, Hash, Eq, Ord, :: prost :: Message, Clone, PartialEq)]
pub struct Tt1 {
#[prost(message, tag = "1", optional)]
pub t2: Option<T2::T2>,
pub t2: ::std::option::Option<T2::T2>,
}
}
}
Expand Down
4 changes: 2 additions & 2 deletions pilota-build/test_data/thrift/normal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ pub mod normal {
pub mod normal {
#[derive(PartialOrd, Hash, Eq, Ord, Debug, Default, Clone, PartialEq)]
pub struct A {
pub a: Option<i32>,
pub a: ::std::option::Option<i32>,
}
#[::async_trait::async_trait]
impl ::pilota::thrift::Message for A {
Expand Down Expand Up @@ -114,7 +114,7 @@ pub mod normal {
}
#[derive(PartialOrd, Hash, Eq, Ord, Debug, Default, Clone, PartialEq)]
pub struct B {
pub a: Option<A>,
pub a: ::std::option::Option<A>,
}
#[::async_trait::async_trait]
impl ::pilota::thrift::Message for B {
Expand Down
120 changes: 120 additions & 0 deletions pilota-build/test_data/thrift/recursive_type.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
pub mod recursive_type {
#![allow(
unused_variables,
dead_code,
missing_docs,
clippy::unused_unit,
clippy::needless_borrow,
unused_mut
)]
pub mod recursive_type {
#[derive(PartialOrd, Hash, Eq, Ord, Debug, Default, Clone, PartialEq)]
pub struct A {
pub a: ::std::option::Option<::std::boxed::Box<A>>,
}
#[::async_trait::async_trait]
impl ::pilota::thrift::Message for A {
fn encode<T: ::pilota::thrift::TOutputProtocol>(
&self,
protocol: &mut T,
) -> ::std::result::Result<(), ::pilota::thrift::Error> {
let struct_ident = ::pilota::thrift::TStructIdentifier { name: "A" };
protocol.write_struct_begin(&struct_ident)?;
if let Some(value) = self.a.as_ref() {
let field = ::pilota::thrift::TFieldIdentifier {
name: Some("a"),
field_type: ::pilota::thrift::TType::Struct,
id: Some(1i16),
};
protocol.write_field_begin(&field)?;
::pilota::thrift::Message::encode(value, protocol)?;
protocol.write_field_end()?;
};
protocol.write_field_stop()?;
protocol.write_struct_end()?;
Ok(())
}
fn decode<T: ::pilota::thrift::TInputProtocol>(
protocol: &mut T,
) -> ::std::result::Result<Self, ::pilota::thrift::Error> {
let mut a = None;
protocol.read_struct_begin()?;
loop {
let field_ident = protocol.read_field_begin()?;
let ttype = field_ident.field_type;
if ttype == ::pilota::thrift::TType::Stop {
break;
}
let field_id = field_ident.id;
match field_id {
Some(1i16) => {
if ttype == ::pilota::thrift::TType::Struct {
a = Some(::std::boxed::Box::new(
::pilota::thrift::Message::decode(protocol)?,
));
} else {
protocol.skip(ttype)?;
}
}
_ => {
protocol.skip(ttype)?;
}
}
protocol.read_field_end()?;
}
protocol.read_struct_end()?;
let data = Self { a };
Ok(data)
}
async fn decode_async<C: ::tokio::io::AsyncRead + Unpin + Send>(
protocol: &mut ::pilota::thrift::TAsyncBinaryProtocol<C>,
) -> ::std::result::Result<Self, ::pilota::thrift::Error> {
let mut a = None;
protocol.read_struct_begin().await?;
loop {
let field_ident = protocol.read_field_begin().await?;
let ttype = field_ident.field_type;
if ttype == ::pilota::thrift::TType::Stop {
break;
}
let field_id = field_ident.id;
match field_id {
Some(1i16) => {
if ttype == ::pilota::thrift::TType::Struct {
a = Some(::std::boxed::Box::new(
::pilota::thrift::Message::decode_async(protocol).await?,
));
} else {
protocol.skip(ttype).await?;
}
}
_ => {
protocol.skip(ttype).await?;
}
}
protocol.read_field_end().await?;
}
protocol.read_struct_end().await?;
let data = Self { a };
Ok(data)
}
}
impl ::pilota::thrift::Size for A {
fn size<T: ::pilota::thrift::TLengthProtocol>(&self, protocol: &T) -> usize {
protocol.write_struct_begin_len(&::pilota::thrift::TStructIdentifier { name: "A" })
+ if let Some(value) = self.a.as_ref() {
protocol.write_field_begin_len(&::pilota::thrift::TFieldIdentifier {
name: Some("a"),
field_type: ::pilota::thrift::TType::Struct,
id: Some(1i16),
}) + ::pilota::thrift::Size::size(value, protocol)
+ protocol.write_field_end_len()
} else {
0
}
+ protocol.write_field_stop_len()
+ protocol.write_struct_end_len()
}
}
}
}
4 changes: 4 additions & 0 deletions pilota-build/test_data/thrift/recursive_type.thrift
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@

struct A {
1: optional A a,
}
26 changes: 25 additions & 1 deletion pilota/src/thrift/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ pub mod binary;
pub mod error;
pub mod rw_ext;

use std::sync::Arc;
use std::{ops::Deref, sync::Arc};

use bytes::{Buf, BufMut};
pub use error::*;
Expand All @@ -27,6 +27,24 @@ pub trait Message: Sized + Send {
R: AsyncRead + Unpin + Send;
}

#[async_trait::async_trait]
impl<M: Message> Message for Box<M> {
fn encode<T: TOutputProtocol>(&self, protocol: &mut T) -> Result<(), Error> {
self.deref().encode(protocol)
}

fn decode<T: TInputProtocol>(protocol: &mut T) -> Result<Self, Error> {
Ok(Box::new(M::decode(protocol)?))
}

async fn decode_async<R>(protocol: &mut TAsyncBinaryProtocol<R>) -> Result<Self, Error>
where
R: AsyncRead + Unpin + Send,
{
Ok(Box::new(M::decode_async(protocol).await?))
}
}

#[async_trait::async_trait]
pub trait EntryMessage: Sized + Send {
fn encode<T: TOutputProtocol>(&self, protocol: &mut T) -> Result<(), Error>;
Expand Down Expand Up @@ -498,6 +516,12 @@ pub trait Size {
fn size<T: TLengthProtocol>(&self, protocol: &T) -> usize;
}

impl<M: Size> Size for Box<M> {
fn size<T: TLengthProtocol>(&self, protocol: &T) -> usize {
self.deref().size(protocol)
}
}

#[async_trait::async_trait]
impl<Message> EntryMessage for Arc<Message>
where
Expand Down

0 comments on commit cf8220c

Please sign in to comment.