Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[#15] Request Parameter 자동 바인딩 구현 #132

Merged
merged 9 commits into from
Aug 26, 2024
4 changes: 4 additions & 0 deletions rupring/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -618,3 +618,7 @@ impl<T: IModule + Clone + Copy + Sync + Send + 'static> RupringFactory<T> {

#[cfg(test)]
mod test_proc_macro;

pub use anyhow;
pub use serde;
pub use serde_json;
222 changes: 222 additions & 0 deletions rupring/src/request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,12 @@ pub struct Request {
pub(crate) di_context: Arc<crate::DIContext>,
}

pub trait BindFromRequest {
fn bind(&mut self, request: Request) -> anyhow::Result<Self>
where
Self: Sized;
}

impl UnwindSafe for Request {}

impl Request {
Expand All @@ -21,6 +27,222 @@ impl Request {
}
}

#[derive(Debug, Clone)]
pub struct QueryString(pub Vec<String>);

pub trait QueryStringDeserializer<T>: Sized {
type Error;

fn deserialize_query_string(&self) -> Result<T, Self::Error>;
}

impl<T> QueryStringDeserializer<Option<T>> for QueryString
where
QueryString: QueryStringDeserializer<T>,
{
type Error = ();

fn deserialize_query_string(&self) -> Result<Option<T>, Self::Error> {
let result = Self::deserialize_query_string(self);
match result {
Ok(v) => Ok(Some(v)),
Err(_) => Ok(None),
}
}
}

impl QueryStringDeserializer<i8> for QueryString {
type Error = ();

fn deserialize_query_string(&self) -> Result<i8, Self::Error> {
if let Some(e) = self.0.get(0) {
e.parse::<i8>().map_err(|_| ())
} else {
Err(())
}
}
}

impl QueryStringDeserializer<i16> for QueryString {
type Error = ();

fn deserialize_query_string(&self) -> Result<i16, Self::Error> {
if let Some(e) = self.0.get(0) {
e.parse::<i16>().map_err(|_| ())
} else {
Err(())
}
}
}

impl QueryStringDeserializer<i32> for QueryString {
type Error = ();

fn deserialize_query_string(&self) -> Result<i32, Self::Error> {
if let Some(e) = self.0.get(0) {
e.parse::<i32>().map_err(|_| ())
} else {
Err(())
}
}
}

impl QueryStringDeserializer<i64> for QueryString {
type Error = ();

fn deserialize_query_string(&self) -> Result<i64, Self::Error> {
if let Some(e) = self.0.get(0) {
e.parse::<i64>().map_err(|_| ())
} else {
Err(())
}
}
}

impl QueryStringDeserializer<i128> for QueryString {
type Error = ();

fn deserialize_query_string(&self) -> Result<i128, Self::Error> {
if let Some(e) = self.0.get(0) {
e.parse::<i128>().map_err(|_| ())
} else {
Err(())
}
}
}

impl QueryStringDeserializer<isize> for QueryString {
type Error = ();

fn deserialize_query_string(&self) -> Result<isize, Self::Error> {
if let Some(e) = self.0.get(0) {
e.parse::<isize>().map_err(|_| ())
} else {
Err(())
}
}
}

impl QueryStringDeserializer<u8> for QueryString {
type Error = ();

fn deserialize_query_string(&self) -> Result<u8, Self::Error> {
if let Some(e) = self.0.get(0) {
e.parse::<u8>().map_err(|_| ())
} else {
Err(())
}
}
}

impl QueryStringDeserializer<u16> for QueryString {
type Error = ();

fn deserialize_query_string(&self) -> Result<u16, Self::Error> {
if let Some(e) = self.0.get(0) {
e.parse::<u16>().map_err(|_| ())
} else {
Err(())
}
}
}

impl QueryStringDeserializer<u32> for QueryString {
type Error = ();

fn deserialize_query_string(&self) -> Result<u32, Self::Error> {
if let Some(e) = self.0.get(0) {
e.parse::<u32>().map_err(|_| ())
} else {
Err(())
}
}
}

impl QueryStringDeserializer<u64> for QueryString {
type Error = ();

fn deserialize_query_string(&self) -> Result<u64, Self::Error> {
if let Some(e) = self.0.get(0) {
e.parse::<u64>().map_err(|_| ())
} else {
Err(())
}
}
}

impl QueryStringDeserializer<u128> for QueryString {
type Error = ();

fn deserialize_query_string(&self) -> Result<u128, Self::Error> {
if let Some(e) = self.0.get(0) {
e.parse::<u128>().map_err(|_| ())
} else {
Err(())
}
}
}

impl QueryStringDeserializer<usize> for QueryString {
type Error = ();

fn deserialize_query_string(&self) -> Result<usize, Self::Error> {
if let Some(e) = self.0.get(0) {
e.parse::<usize>().map_err(|_| ())
} else {
Err(())
}
}
}

impl QueryStringDeserializer<f32> for QueryString {
type Error = ();

fn deserialize_query_string(&self) -> Result<f32, Self::Error> {
if let Some(e) = self.0.get(0) {
e.parse::<f32>().map_err(|_| ())
} else {
Err(())
}
}
}

impl QueryStringDeserializer<f64> for QueryString {
type Error = ();

fn deserialize_query_string(&self) -> Result<f64, Self::Error> {
if let Some(e) = self.0.get(0) {
e.parse::<f64>().map_err(|_| ())
} else {
Err(())
}
}
}

impl QueryStringDeserializer<bool> for QueryString {
type Error = ();

fn deserialize_query_string(&self) -> Result<bool, Self::Error> {
if let Some(e) = self.0.get(0) {
e.parse::<bool>().map_err(|_| ())
} else {
Err(())
}
}
}

impl QueryStringDeserializer<String> for QueryString {
type Error = ();

fn deserialize_query_string(&self) -> Result<String, Self::Error> {
if let Some(e) = self.0.get(0) {
Ok(e.clone())
} else {
Err(())
}
}
}

#[derive(Debug, Clone)]
pub struct ParamString(pub String);

Expand Down
92 changes: 92 additions & 0 deletions rupring_macro/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -633,13 +633,24 @@ pub fn derive_rupring_doc(item: TokenStream) -> TokenStream {
code += format!(r#"query_parameters: vec![],"#).as_str();
code += "};";

let mut define_struct_for_json = "".to_string();
define_struct_for_json +=
format!(r#"#[derive(rupring::serde::Serialize, rupring::serde::Deserialize)]"#).as_str();
define_struct_for_json += format!(r#"pub struct {struct_name}__JSON {{"#).as_str();

let mut json_field_names = vec![];
let mut path_field_names = vec![];
let mut query_field_names = vec![];

for field in ast.fields.iter() {
let mut description = "".to_string();
let mut example = r#""""#.to_string();

let mut field_name = field.ident.as_ref().unwrap().to_string();
let mut field_type = field.ty.to_token_stream().to_string().replace(" ", "");

let original_field_name = field_name.clone();

let attributes = field.attrs.clone();

let mut is_required = true;
Expand Down Expand Up @@ -754,6 +765,8 @@ pub fn derive_rupring_doc(item: TokenStream) -> TokenStream {
}

if is_path_parameter {
path_field_names.push((original_field_name.clone(), field_type.clone()));

code += format!(
r#"swagger_definition.path_parameters.push(rupring::swagger::json::SwaggerParameter {{
name: "{field_name}".to_string(),
Expand All @@ -773,6 +786,8 @@ pub fn derive_rupring_doc(item: TokenStream) -> TokenStream {
}

if is_query_parameter {
query_field_names.push((original_field_name.clone(), field_type.clone()));

code += format!(
r#"swagger_definition.query_parameters.push(rupring::swagger::json::SwaggerParameter {{
name: "{field_name}".to_string(),
Expand All @@ -791,6 +806,15 @@ pub fn derive_rupring_doc(item: TokenStream) -> TokenStream {
continue;
}

json_field_names.push((original_field_name.clone(), field_type.clone()));

define_struct_for_json += format!(
r#"
pub {original_field_name}: {field_type},
"#
)
.as_str();

// Body 파라미터 생성 구현
code += format!(r#"let property_of_type = {field_type}::to_swagger_definition(context);"#)
.as_str();
Expand Down Expand Up @@ -831,11 +855,79 @@ pub fn derive_rupring_doc(item: TokenStream) -> TokenStream {
.as_str();
}

define_struct_for_json += format!(r#"}}"#).as_str();

code += "rupring::swagger::macros::SwaggerDefinitionNode::Object(swagger_definition)";

code += "}";

code += "}";

code += define_struct_for_json.as_str();

let mut request_bind_code = "".to_string();
request_bind_code +=
format!(r#"impl rupring::request::BindFromRequest for {struct_name} {{"#).as_str();

request_bind_code +=
"fn bind(&mut self, request: rupring::request::Request) -> rupring::anyhow::Result<Self> {";
request_bind_code += "use rupring::request::ParamStringDeserializer;";
request_bind_code += "use rupring::request::QueryStringDeserializer;";

request_bind_code += format!("let mut json_bound = rupring::serde_json::from_str::<{struct_name}__JSON>(request.body.as_str()).unwrap();").as_str();

request_bind_code += format!("let bound = {struct_name} {{").as_str();

for (field_name, _) in json_field_names {
request_bind_code += format!("{field_name}: json_bound.{field_name},").as_str();
}

for (field_name, field_type) in path_field_names {
request_bind_code += format!(
r#"{field_name}: {{
let param = rupring::request::ParamString(
request.path_parameters["{field_name}"].clone()
);
let deserialized: {field_type} = match param.deserialize() {{
Ok(v) => v,
Err(_) => return Err(rupring::anyhow::anyhow!("invalid parameter: {field_name}")),
}};
deserialized
}}
"#
)
.as_str();
}

for (field_name, field_type) in query_field_names {
request_bind_code += format!(
r#"{field_name}: {{
let query = rupring::request::QueryString(
request.query_parameters["{field_name}"].clone()
);
let deserialized: {field_type} = match query.deserialize_query_string() {{
Ok(v) => v,
Err(_) => return Err(rupring::anyhow::anyhow!("invalid parameter: {field_name}")),
}};
deserialized
}},
"#
)
.as_str();
}

request_bind_code += format!("}};").as_str();

request_bind_code += "Ok(bound)";
request_bind_code += "}";

request_bind_code += format!(r#"}}"#).as_str();

code += request_bind_code.as_str();

return TokenStream::from_str(code.as_str()).unwrap();
}
Loading