diff --git a/codegen/src/main.rs b/codegen/src/main.rs index 4dfd218..d8987fa 100644 --- a/codegen/src/main.rs +++ b/codegen/src/main.rs @@ -15,13 +15,13 @@ struct Args { long, help = "root folder for all spec, will scan all specs in the folder recursively" )] - input: Option, + input: Option, #[arg( long, help = "root folder for all spec, will scan all specs in the folder recursively, deprecated, use --input instead" )] - spec_folder: Option, + spec_folder: Option, #[arg(short, long, default_value = "rs_serde")] codegen: String, @@ -32,22 +32,23 @@ struct Args { default_value = "examples/spec/", help = "output path, if input is folder, then output must be folder" )] - output: std::path::PathBuf, + output: PathBuf, } fn main() -> anyhow::Result<()> { let args = Args::parse(); + let input = args.input.or(args.spec_folder).unwrap(); + let codegen: Box = match args.codegen.as_str() { - "rs_serde" => Box::new(RsSerde::default()), - "java_jackson" => Box::new(JavaJackson::default()), - "swift_codable" => Box::new(SwiftCodable::default()), - "py_dataclass" => Box::new(PyDataclass::default()), - "swagger" => Box::new(Swagger::default()), + "rs_serde" => Box::new(RsSerde::load_from_folder(&input)?), + "java_jackson" => Box::new(JavaJackson::load_from_folder(&input)?), + "swift_codable" => Box::new(SwiftCodable::load_from_folder(&input)?), + "py_dataclass" => Box::new(PyDataclass::load_from_folder(&input)?), + "swagger" => Box::new(Swagger::load_from_folder(&input)?), _ => anyhow::bail!("unknown codegen name"), }; - let input = args.input.or(args.spec_folder).unwrap(); let output = absolute(&args.output); // create output folder diff --git a/tot_spec/src/codegen/context.rs b/tot_spec/src/codegen/context.rs index 3a30e9b..c333c38 100644 --- a/tot_spec/src/codegen/context.rs +++ b/tot_spec/src/codegen/context.rs @@ -101,6 +101,35 @@ impl Context { &self.folder_tree } + /// helper to load codegen specific config from spec_config.yaml + pub fn load_codegen_config( + &self, + key_name: &str, + ) -> anyhow::Result> { + let config_file = self.root_folder().join("spec_config.yaml"); + if !config_file.exists() { + return Ok(None); + } + + let config_content = std::fs::read_to_string(config_file) + .map_err(|_| anyhow::anyhow!("not able to read spec_config.yaml from folder"))?; + let config_value = + serde_yaml::from_str::>(&config_content)?; + let Some(codegen_value) = config_value.get("codegen") else { + return Ok(None) + }; + + assert!(codegen_value.is_object()); + + let Some(config_value) = codegen_value.as_object().unwrap().get(key_name) else { + return Ok(None) + }; + + let config = serde_json::from_value::(config_value.to_owned())?; + + Ok(Some(config)) + } + /// get a ref to definition for spec path, the spec should already loaded /// panic if path not loaded pub fn get_definition(&self, path: impl AsRef) -> anyhow::Result<&Definition> { @@ -261,7 +290,7 @@ impl Context { // skip example validate for virtual model } crate::ModelType::NewType { inner_type } => { - return self.validate_value_for_type(&value, &inner_type.0, true, spec) + return self.validate_value_for_type(&value, &inner_type.as_ref().0, true, spec) } crate::ModelType::Const { .. } => { // skip validate for const diff --git a/tot_spec/src/codegen/fixtures/specs/spec_config.yaml b/tot_spec/src/codegen/fixtures/specs/spec_config.yaml index 2d0d46a..601b5db 100644 --- a/tot_spec/src/codegen/fixtures/specs/spec_config.yaml +++ b/tot_spec/src/codegen/fixtures/specs/spec_config.yaml @@ -5,6 +5,11 @@ style: - ignore_styles/**/*.yaml codegen: + # globally overwrite type + rs_serde: + type_overwrites: + bigint: tot_spec_util::big_int::BigInt + swagger: title: "swagger test" description: "testing" diff --git a/tot_spec/src/codegen/java_jackson.rs b/tot_spec/src/codegen/java_jackson.rs index a66b960..1975b21 100644 --- a/tot_spec/src/codegen/java_jackson.rs +++ b/tot_spec/src/codegen/java_jackson.rs @@ -10,6 +10,10 @@ use std::{borrow::Cow, fmt::Write, path::PathBuf}; pub struct JavaJackson {} impl super::Codegen for JavaJackson { + fn load_from_folder(_folder: &PathBuf) -> anyhow::Result { + Ok(Self::default()) + } + fn generate_for_folder(&self, folder: &PathBuf, output: &PathBuf) -> anyhow::Result<()> { let context = Context::new_from_folder(folder)?; diff --git a/tot_spec/src/codegen/mod.rs b/tot_spec/src/codegen/mod.rs index 2f7e992..857a34b 100644 --- a/tot_spec/src/codegen/mod.rs +++ b/tot_spec/src/codegen/mod.rs @@ -10,5 +10,9 @@ pub mod swift_codable; pub mod utils; pub trait Codegen { + fn load_from_folder(folder: &PathBuf) -> anyhow::Result + where + Self: Sized; + fn generate_for_folder(&self, folder: &PathBuf, output: &PathBuf) -> anyhow::Result<()>; } diff --git a/tot_spec/src/codegen/py_dataclass.rs b/tot_spec/src/codegen/py_dataclass.rs index f27c5a8..3ad4e3f 100644 --- a/tot_spec/src/codegen/py_dataclass.rs +++ b/tot_spec/src/codegen/py_dataclass.rs @@ -9,6 +9,10 @@ use super::utils::{indent, multiline_prefix_with}; pub struct PyDataclass {} impl super::Codegen for PyDataclass { + fn load_from_folder(_folder: &PathBuf) -> anyhow::Result { + Ok(Self::default()) + } + fn generate_for_folder(&self, folder: &PathBuf, output: &PathBuf) -> anyhow::Result<()> { let context = Context::new_from_folder(folder)?; diff --git a/tot_spec/src/codegen/rs_serde.rs b/tot_spec/src/codegen/rs_serde.rs index e499351..d858c0a 100644 --- a/tot_spec/src/codegen/rs_serde.rs +++ b/tot_spec/src/codegen/rs_serde.rs @@ -1,22 +1,50 @@ +use crate::codegen::style::Case; use crate::{ codegen::utils::indent, models::Definition, ConstType, ConstValueDef, FieldDef, ModelDef, - StringOrInteger, StructDef, Type, VariantDef, + StringOrInteger, StructDef, Type, TypeReference, VariantDef, }; +use serde::{Deserialize, Serialize}; use std::path::Path; use std::{borrow::Cow, fmt::Write, path::PathBuf}; use super::context::Context; use super::{utils::folder_tree::Entry, utils::multiline_prefix_with}; -#[derive(Default)] -pub struct RsSerde {} +pub struct RsSerde { + context: Context, + config: CodegenConfig, +} + +#[derive(Default, Debug, Serialize, Deserialize)] +pub struct CodegenConfig { + type_overwrites: Option, +} + +/// overwrite types +#[derive(Debug, Serialize, Deserialize)] +pub struct TypeOverwrites { + bigint: Option, + decimal: Option, +} impl super::Codegen for RsSerde { - fn generate_for_folder(&self, folder: &PathBuf, output: &PathBuf) -> anyhow::Result<()> { + fn load_from_folder(folder: &PathBuf) -> anyhow::Result { let context = Context::new_from_folder(folder)?; + // load codegen config from spec_config.yaml file + let config = context.load_codegen_config::("rs_serde")?; + + Ok(Self { + context, + config: config.unwrap_or_default(), + }) + } + + fn generate_for_folder(&self, _folder: &PathBuf, output: &PathBuf) -> anyhow::Result<()> { + let context = &self.context; + context.folder_tree().foreach_entry_recursively(|entry| { - let outputs = render_folder(entry).unwrap(); + let outputs = self.render_folder(entry).unwrap(); for (file_relative_path, content) in outputs { let file_path = output.join(file_relative_path); println!("write output to {:?}", file_path); @@ -40,7 +68,7 @@ impl super::Codegen for RsSerde { let parent_folder = output.parent().unwrap(); std::fs::create_dir_all(parent_folder)?; - let code = render(&spec, &context)?; + let code = self.render(&spec)?; std::fs::write(&output, code).unwrap(); println!("write output to {:?}", output); @@ -50,198 +78,247 @@ impl super::Codegen for RsSerde { } } -fn render_folder(entry: &Entry) -> anyhow::Result> { - if entry.is_empty() { - // this is a leaf node, continue - return Ok(vec![]); - } +impl RsSerde { + fn render_folder(&self, entry: &Entry) -> anyhow::Result> { + if entry.is_empty() { + // this is a leaf node, continue + return Ok(vec![]); + } - let mut outputs = vec![]; - let children = entry.iter_child().collect::>(); + let mut outputs = vec![]; + let children = entry.iter_child().collect::>(); - let mut code = "".to_string(); - for child in children { - writeln!( - code, - "pub mod {};", - child.path().file_stem().unwrap().to_str().unwrap() - ) - .unwrap(); - } + let mut code = "".to_string(); + for child in children { + writeln!( + code, + "pub mod {};", + child.path().file_stem().unwrap().to_str().unwrap() + ) + .unwrap(); + } - outputs.push((entry.path().join("mod.rs"), code)); + outputs.push((entry.path().join("mod.rs"), code)); - Ok(outputs) -} + Ok(outputs) + } -fn render(spec_path: &Path, context: &Context) -> anyhow::Result { - let def = context.get_definition(spec_path)?; - let def = &def; - let mut result = String::new(); + fn render(&self, spec_path: &Path) -> anyhow::Result { + let context = &self.context; + let def = context.get_definition(spec_path)?; + let def = &def; + let mut result = String::new(); - for include in def.includes.iter() { - let include_path = context.get_include_path(&include.namespace, spec_path)?; - let relative_path = pathdiff::diff_paths(&include_path, spec_path).unwrap(); + for include in def.includes.iter() { + let include_path = context.get_include_path(&include.namespace, spec_path)?; + let relative_path = pathdiff::diff_paths(&include_path, spec_path).unwrap(); - let include_name = relative_path.file_stem().unwrap().to_str().unwrap(); + let include_name = relative_path.file_stem().unwrap().to_str().unwrap(); - let mut mod_path = "".to_string(); + let mut mod_path = "".to_string(); - let relative_path_components = relative_path.components().collect::>(); - for (idx, component) in relative_path_components.iter().enumerate() { - match component { - std::path::Component::ParentDir => { - mod_path.push_str("super::"); - } - std::path::Component::CurDir => { - // do nothing - } - std::path::Component::Normal(name) => { - if idx + 1 < relative_path_components.len() { - mod_path.push_str(&format!("{}::", name.to_str().unwrap())); - } else { - break; + let relative_path_components = relative_path.components().collect::>(); + for (idx, component) in relative_path_components.iter().enumerate() { + match component { + std::path::Component::ParentDir => { + mod_path.push_str("super::"); + } + std::path::Component::CurDir => { + // do nothing + } + std::path::Component::Normal(name) => { + if idx + 1 < relative_path_components.len() { + mod_path.push_str(&format!("{}::", name.to_str().unwrap())); + } else { + break; + } + } + std::path::Component::Prefix(_) | std::path::Component::RootDir => { + unimplemented!() } } - std::path::Component::Prefix(_) | std::path::Component::RootDir => unimplemented!(), } - } - mod_path.push_str(include_name); + mod_path.push_str(include_name); + + if let Some(rs_mod) = include.attributes.get("rs_mod") { + // rs-mod overwrite + mod_path = rs_mod.to_string() + } - if let Some(rs_mod) = include.attributes.get("rs_mod") { - // rs-mod overwrite - mod_path = rs_mod.to_string() + if mod_path.eq(&include.namespace) { + writeln!(result, "use {};", mod_path)?; + } else { + writeln!(result, "use {} as {};", mod_path, include.namespace)?; + } } - if mod_path.eq(&include.namespace) { - writeln!(result, "use {};", mod_path)?; - } else { - writeln!(result, "use {} as {};", mod_path, include.namespace)?; + if !def.includes.is_empty() { + writeln!(result, "")?; } - } - if !def.includes.is_empty() { - writeln!(result, "")?; - } + let mut model_codes = vec![]; - let mut model_codes = vec![]; + for model in def.models.iter() { + model_codes.push("".to_string()); + let model_code = model_codes.last_mut().unwrap(); - for model in def.models.iter() { - model_codes.push("".to_string()); - let model_code = model_codes.last_mut().unwrap(); + let model_name = &model.name; - let model_name = &model.name; + writeln!(model_code, "")?; + writeln!( + model_code, + "{}", + multiline_prefix_with( + model + .desc + .as_ref() + .map(|s| s.as_str()) + .unwrap_or(model_name.as_str()), + "/// " + ) + )?; - writeln!(model_code, "")?; - writeln!( - model_code, - "{}", - multiline_prefix_with( - model - .desc - .as_ref() - .map(|s| s.as_str()) - .unwrap_or(model_name.as_str()), - "/// " - ) - )?; + let mut derived = vec!["Debug", "serde::Serialize", "serde::Deserialize"]; - let mut derived = vec!["Debug", "serde::Serialize", "serde::Deserialize"]; + if let Some(extra_derived) = model.attribute("rs_extra_derive") { + derived.extend(extra_derived.split(",").map(|d| d.trim())); + } - if let Some(extra_derived) = model.attribute("rs_extra_derive") { - derived.extend(extra_derived.split(",").map(|d| d.trim())); - } + match &model.type_ { + crate::ModelType::Enum { variants } => { + let code = self.render_enum(model, &derived, variants, def)?; + writeln!(model_code, "{}", code.trim())?; + } + crate::ModelType::Struct(struct_def) => { + let code = self.render_struct(&model_name, &derived, struct_def, def)?; + writeln!(model_code, "{}", code.trim())?; + } - match &model.type_ { - crate::ModelType::Enum { variants } => { - let code = render_enum(model, &derived, variants, def)?; - writeln!(model_code, "{}", code.trim())?; - } - crate::ModelType::Struct(struct_def) => { - let code = render_struct(&model_name, &derived, struct_def, &def)?; - writeln!(model_code, "{}", code.trim())?; - } + crate::ModelType::Virtual(struct_def) => { + writeln!(model_code, "pub trait {} {{", &model.name)?; + for field in struct_def.fields.iter() { + let field_name = &field.name; + let (field_name_rs, _) = to_identifier(field_name); - crate::ModelType::Virtual(struct_def) => { - writeln!(model_code, "pub trait {} {{", &model.name)?; - for field in struct_def.fields.iter() { - let field_name = &field.name; - let (field_name_rs, _) = to_identifier(field_name); + if let Some(desc) = &field.desc { + let comment = indent(multiline_prefix_with(desc, "/// "), 1); + writeln!(model_code, "{comment}",)?; + } - if let Some(desc) = &field.desc { - let comment = indent(multiline_prefix_with(desc, "/// "), 1); - writeln!(model_code, "{comment}",)?; - } + let field_type = self.rs_type_for_field(field); - let field_type = field.rs_type(); + writeln!( + model_code, + " fn {field_name_rs}(&self) -> &{field_type};", + )?; - writeln!( - model_code, - " fn {field_name_rs}(&self) -> &{field_type};", - )?; + writeln!( + model_code, + " fn set_{field_name_rs}(&mut self, value: {field_type}) -> {field_type};", + )?; + } + writeln!(model_code, "}}")?; + } - writeln!( - model_code, - " fn set_{field_name_rs}(&mut self, value: {field_type}) -> {field_type};", - )?; + crate::ModelType::NewType { inner_type } => { + let code = self.render_new_type(model_name, &derived, inner_type)?; + writeln!(model_code, "{}", code.trim())?; + } + crate::ModelType::Const { value_type, values } => { + let code = self.render_const(&model_name, &derived, value_type, &values)?; + writeln!(model_code, "{}", code.trim())?; } - writeln!(model_code, "}}")?; } - crate::ModelType::NewType { inner_type } => { - let code = render_new_type(model_name, &derived, inner_type)?; - writeln!(model_code, "{}", code.trim())?; - } - crate::ModelType::Const { value_type, values } => { - let code = render_const(&model_name, &derived, value_type, &values)?; - writeln!(model_code, "{}", code.trim())?; + *model_code = super::utils::format_rust_code(model_code.as_str())?; + } + + for (idx, model_code) in model_codes.into_iter().enumerate() { + // prepend a new line + if idx != 0 { + writeln!(result, "")?; } + writeln!(result, "{}", model_code.trim())?; } - *model_code = super::utils::format_rust_code(model_code.as_str())?; + Ok(result) } - for (idx, model_code) in model_codes.into_iter().enumerate() { - // prepend a new line - if idx != 0 { - writeln!(result, "")?; - } - writeln!(result, "{}", model_code.trim())?; + fn render_derived(&self, derived: &[&str]) -> String { + format!( + "#[derive({})]", + derived + .iter() + .map(|d| format!("{},", d)) + .collect::>() + .join("") + ) } - Ok(result) -} + fn render_struct( + &self, + model_name: &str, + derived: &[&str], + struct_def: &StructDef, + def: &Definition, + ) -> anyhow::Result { + let mut result = "".to_string(); + let model_code = &mut result; + + { + writeln!(model_code, "{}", self.render_derived(&derived))?; + writeln!(model_code, "pub struct {model_name} {{")?; + + let mut fields = vec![]; + if let Some(virtual_name) = &struct_def.extend { + match def.get_model(&virtual_name) { + Some(model) => match &model.type_ { + crate::ModelType::Virtual(struct_def) => { + fields.extend(struct_def.fields.clone()); + } + _ => { + anyhow::bail!("model is not virtual: {}", virtual_name); + } + }, + None => anyhow::bail!("not able to find virtual model: {}", virtual_name), + } + } -fn render_derived(derived: &[&str]) -> String { - format!( - "#[derive({})]", - derived - .iter() - .map(|d| format!("{},", d)) - .collect::>() - .join("") - ) -} + fields.extend(struct_def.fields.clone()); -fn render_struct( - model_name: &str, - derived: &[&str], - struct_def: &StructDef, - def: &Definition, -) -> anyhow::Result { - let mut result = "".to_string(); - let model_code = &mut result; + let fields_def_code = self.render_fields_def(&fields)?; + writeln!(model_code, "{}", indent(fields_def_code, 1))?; - { - writeln!(model_code, "{}", render_derived(&derived))?; - writeln!(model_code, "pub struct {model_name} {{")?; + writeln!(model_code, "}}")?; + } - let mut fields = vec![]; if let Some(virtual_name) = &struct_def.extend { + writeln!(model_code, "")?; + writeln!(model_code, "impl {virtual_name} for {model_name} {{")?; match def.get_model(&virtual_name) { Some(model) => match &model.type_ { crate::ModelType::Virtual(struct_def) => { - fields.extend(struct_def.fields.clone()); + for field in struct_def.fields.iter() { + let field_name = &field.name; + let (field_name_rs, _) = to_identifier(field_name); + let field_type = self.rs_type_for_field(&field); + writeln!( + model_code, + " fn {field_name_rs}(&self) -> &{field_type} {{", + )?; + writeln!(model_code, " &self.{field_name_rs}")?; + writeln!(model_code, " }}",)?; + + writeln!( + model_code, + " fn set_{field_name_rs}(&mut self, value: {field_type}) -> {field_type} {{", + )?; + writeln!( + model_code, + " std::mem::replace(&mut self.{field_name_rs}, value)" + )?; + writeln!(model_code, " }}",)?; + } } _ => { anyhow::bail!("model is not virtual: {}", virtual_name); @@ -249,197 +326,324 @@ fn render_struct( }, None => anyhow::bail!("not able to find virtual model: {}", virtual_name), } + writeln!(model_code, "}}")?; } - fields.extend(struct_def.fields.clone()); + Ok(result) + } - let fields_def_code = render_fields_def(&fields)?; - writeln!(model_code, "{}", indent(fields_def_code, 1))?; + fn render_fields_def(&self, fields: &[FieldDef]) -> anyhow::Result { + let mut result = "".to_string(); + let code = &mut result; + for field in fields.iter() { + if let Some(desc) = &field.desc { + let comment = multiline_prefix_with(desc, "/// "); + writeln!(code, "{}", comment)?; + } - writeln!(model_code, "}}")?; - } + for attr in field.rs_attributes() { + writeln!(code, "#[{attr}]")?; + } - if let Some(virtual_name) = &struct_def.extend { - writeln!(model_code, "")?; - writeln!(model_code, "impl {virtual_name} for {model_name} {{")?; - match def.get_model(&virtual_name) { - Some(model) => match &model.type_ { - crate::ModelType::Virtual(struct_def) => { - for field in struct_def.fields.iter() { - let field_name = &field.name; - let (field_name_rs, _) = to_identifier(field_name); - let field_type = field.rs_type(); - writeln!( - model_code, - " fn {field_name_rs}(&self) -> &{field_type} {{", - )?; - writeln!(model_code, " &self.{field_name_rs}")?; - writeln!(model_code, " }}",)?; + let field_name = &field.name; + let (field_name_rs, modified) = to_identifier(field_name); - writeln!( - model_code, - " fn set_{field_name_rs}(&mut self, value: {field_type}) -> {field_type} {{", - )?; - writeln!( - model_code, - " std::mem::replace(&mut self.{field_name_rs}, value)" - )?; - writeln!(model_code, " }}",)?; - } - } - _ => { - anyhow::bail!("model is not virtual: {}", virtual_name); - } - }, - None => anyhow::bail!("not able to find virtual model: {}", virtual_name), + if modified { + writeln!(code, "#[serde(rename = \"{field_name}\")]")?; + } + writeln!( + code, + "pub {}: {},", + field_name_rs, + self.rs_type_for_field(&field) + )?; } - writeln!(model_code, "}}")?; + Ok(result) } - Ok(result) -} - -fn render_fields_def(fields: &[FieldDef]) -> anyhow::Result { - let mut result = "".to_string(); - let code = &mut result; - for field in fields.iter() { - if let Some(desc) = &field.desc { - let comment = multiline_prefix_with(desc, "/// "); - writeln!(code, "{}", comment)?; - } + fn render_enum( + &self, + model: &ModelDef, + derived: &[&str], + variants: &[VariantDef], + def: &Definition, + ) -> anyhow::Result { + let model_name = &model.name; - for attr in field.rs_attributes() { - writeln!(code, "#[{attr}]")?; - } + let mut result = "".to_string(); + let model_code = &mut result; + match model.attribute("rs_enum_variant_type").map(String::as_str) { + Some("true") => { + // create separate type for each variant + writeln!(model_code, "{}", self.render_derived(&derived))?; + writeln!( + model_code, + "#[serde(tag = \"type\", content = \"payload\")]" + )?; + writeln!(model_code, "pub enum {} {{", &model.name)?; + + for variant in variants { + let variant_name = &variant.name; + let variant_type_name = format!("{model_name}{variant_name}"); + + if let Some(desc) = &variant.desc { + let comment = multiline_prefix_with(desc, "/// "); + writeln!(model_code, "{}", indent(&comment, 1))?; + } + writeln!(model_code, " {variant_name}({variant_type_name}),",)?; + } + writeln!(model_code, "}}")?; - let field_name = &field.name; - let (field_name_rs, modified) = to_identifier(field_name); + for variant in variants { + let variant_name = &variant.name; + let variant_type_name = format!("{model_name}{variant_name}"); + + let mut code = "".to_string(); + let code = &mut code; + + if let Some(payload_type) = &variant.payload_type { + let payload_type = self.rs_type(&payload_type); + writeln!(code, "{}", self.render_derived(&derived))?; + writeln!(code, "pub struct {variant_type_name}({payload_type});")?; + } else if let Some(fields) = &variant.payload_fields { + let struct_def = StructDef { + extend: None, + fields: fields.clone(), + }; + let struct_code = + self.render_struct(&variant_type_name, &derived, &struct_def, def)?; + writeln!(code, "{}", struct_code)?; + } else { + writeln!(code, "{}", self.render_derived(&derived))?; + writeln!(code, "pub struct {variant_type_name};")?; + } - if modified { - writeln!(code, "#[serde(rename = \"{field_name}\")]")?; - } - writeln!(code, "pub {}: {},", field_name_rs, field.rs_type())?; - } - Ok(result) -} + writeln!(code, "impl Into<{model_name}> for {variant_type_name} {{")?; + writeln!(code, " fn into(self) -> {model_name} {{")?; + writeln!(code, " {model_name}::{variant_name}(self)")?; + writeln!(code, " }}")?; + writeln!(code, "}}")?; -fn render_enum( - model: &ModelDef, - derived: &[&str], - variants: &[VariantDef], - def: &Definition, -) -> anyhow::Result { - let model_name = &model.name; - - let mut result = "".to_string(); - let model_code = &mut result; - match model.attribute("rs_enum_variant_type").map(String::as_str) { - Some("true") => { - // create separate type for each variant - writeln!(model_code, "{}", render_derived(&derived))?; - writeln!( - model_code, - "#[serde(tag = \"type\", content = \"payload\")]" - )?; - writeln!(model_code, "pub enum {} {{", &model.name)?; + writeln!(model_code, "{}", code)?; + } + } + _ => { + // create separate type for each variant + writeln!(model_code, "{}", self.render_derived(&derived))?; + writeln!( + model_code, + "#[serde(tag = \"type\", content = \"payload\")]" + )?; + writeln!(model_code, "pub enum {} {{", &model.name)?; + + for variant in variants { + if let Some(desc) = &variant.desc { + let comment = multiline_prefix_with(desc, "/// "); + writeln!(model_code, "{}", indent(&comment, 1))?; + } - for variant in variants { - let variant_name = &variant.name; - let variant_type_name = format!("{model_name}{variant_name}"); + if let Some(payload_type) = &variant.payload_type { + writeln!( + model_code, + " {}({}),", + variant.name, + self.rs_type(&payload_type) + )?; + } else if let Some(fields) = &variant.payload_fields { + let fields_def_code = self.render_fields_def(&fields)?; - if let Some(desc) = &variant.desc { - let comment = multiline_prefix_with(desc, "/// "); - writeln!(model_code, "{}", indent(&comment, 1))?; + writeln!(model_code, " {} {{", variant.name,)?; + writeln!(model_code, "{}", indent(&fields_def_code, 2))?; + writeln!(model_code, " }},")?; + } else { + writeln!(model_code, " {},", variant.name,)?; + } } - writeln!(model_code, " {variant_name}({variant_type_name}),",)?; + + writeln!(model_code, "}}")?; } - writeln!(model_code, "}}")?; + } + Ok(result) + } + + fn render_new_type( + &self, + model_name: &str, + derived: &[&str], + inner_type: &Type, + ) -> anyhow::Result { + let mut result = "".to_string(); + writeln!(result, "{}", self.render_derived(derived))?; + writeln!( + result, + "pub struct {model_name}(pub {});", + self.rs_type(inner_type) + )?; + Ok(result) + } - for variant in variants { - let variant_name = &variant.name; - let variant_type_name = format!("{model_name}{variant_name}"); + fn render_const( + &self, + model_name: &str, + derived: &[&str], + value_type: &ConstType, + values: &[ConstValueDef], + ) -> anyhow::Result { + let mut code = "".to_string(); + let value_type_in_struct = rs_type_for_const(value_type); + let value_type_in_to_value = value_type_in_struct; + let value_type_in_from_value = match value_type { + ConstType::I8 => "i8", + ConstType::I16 => "i16", + ConstType::I32 => "i32", + ConstType::I64 => "i64", + // from_value able to accept &str for all lifetime + ConstType::String => "&str", + }; + // for const, we should always derive, "Copy", "Clone", "Hash", "Ord" like + let derived = extend_derived( + derived, + &[ + "Copy", + "Clone", + "Hash", + "PartialEq", + "Eq", + "PartialOrd", + "Ord", + ], + ); + + writeln!(code, "{}", self.render_derived(&derived))?; + writeln!(code, "pub struct {model_name}(pub {value_type_in_struct});")?; + + { + // generate from_value and to_value + writeln!(code, "")?; + writeln!(code, "impl {model_name} {{")?; + + let from_value = { + // from_value let mut code = "".to_string(); - let code = &mut code; - - if let Some(payload_type) = &variant.payload_type { - let payload_type = payload_type.rs_type(); - writeln!(code, "{}", render_derived(&derived))?; - writeln!(code, "pub struct {variant_type_name}({payload_type});")?; - } else if let Some(fields) = &variant.payload_fields { - let struct_def = StructDef { - extend: None, - fields: fields.clone(), - }; - let struct_code = - render_struct(&variant_type_name, &derived, &struct_def, def)?; - writeln!(code, "{}", struct_code)?; - } else { - writeln!(code, "{}", render_derived(&derived))?; - writeln!(code, "pub struct {variant_type_name};")?; + writeln!( + code, + "pub fn from_value(val: {value_type_in_from_value}) -> Option {{" + )?; + writeln!(code, " match val {{")?; + for value in values.iter() { + let value_name = rs_const_name(&value.name); + let value_literal = rs_const_literal(&value.value); + writeln!(code, " {value_literal} => Some(Self::{value_name}),")?; } + writeln!(code, " _ => None,")?; - writeln!(code, "impl Into<{model_name}> for {variant_type_name} {{")?; - writeln!(code, " fn into(self) -> {model_name} {{")?; - writeln!(code, " {model_name}::{variant_name}(self)")?; writeln!(code, " }}")?; writeln!(code, "}}")?; + code + }; - writeln!(model_code, "{}", code)?; - } + writeln!(code, "{}", indent(&from_value.trim(), 1))?; + + let to_value = { + // from_value + let mut code = "".to_string(); + writeln!(code, "pub fn to_value(self) -> {value_type_in_to_value} {{")?; + writeln!(code, " self.0")?; + writeln!(code, "}}")?; + code + }; + + writeln!(code, "{}", indent(&to_value.trim(), 1))?; + + writeln!(code, "}}")?; } - _ => { - // create separate type for each variant - writeln!(model_code, "{}", render_derived(&derived))?; + + writeln!(code, "")?; + writeln!(code, "impl {model_name} {{")?; + + for value in values.iter() { + let value_name = rs_const_name(&value.name); + let value_literal = rs_const_literal(&value.value); + if let Some(desc) = &value.desc { + let comment = indent(multiline_prefix_with(desc, "/// "), 1); + writeln!(code, "{comment}")?; + } + writeln!( - model_code, - "#[serde(tag = \"type\", content = \"payload\")]" + code, + " pub const {value_name}: {model_name} = {model_name}({value_literal});" )?; - writeln!(model_code, "pub enum {} {{", &model.name)?; + } - for variant in variants { - if let Some(desc) = &variant.desc { - let comment = multiline_prefix_with(desc, "/// "); - writeln!(model_code, "{}", indent(&comment, 1))?; - } + writeln!(code, "}}")?; + Ok(code) + } - if let Some(payload_type) = &variant.payload_type { - writeln!( - model_code, - " {}({}),", - variant.name, - payload_type.rs_type() - )?; - } else if let Some(fields) = &variant.payload_fields { - let fields_def_code = render_fields_def(&fields)?; - - writeln!(model_code, " {} {{", variant.name,)?; - writeln!(model_code, "{}", indent(&fields_def_code, 2))?; - writeln!(model_code, " }},")?; - } else { - writeln!(model_code, " {},", variant.name,)?; - } - } + fn rs_type_for_field(&self, field: &FieldDef) -> String { + let ty = field + .attribute("rs_type") + .map(|s| s.to_string()) + .unwrap_or(self.rs_type(&field.type_)); + if field.required { + ty + } else { + format!("std::option::Option<{}>", ty) + } + } - writeln!(model_code, "}}")?; + fn rs_type(&self, ty_: &Type) -> String { + match ty_ { + Type::Bool => "bool".into(), + Type::I8 => "i8".into(), + Type::I16 => "i16".into(), + Type::I32 => "i32".into(), + Type::I64 => "i64".into(), + Type::F64 => "f64".into(), + Type::Bytes => "std::vec::Vec".into(), + Type::String => "std::string::String".into(), + Type::List { item_type } => { + format!("std::vec::Vec<{}>", self.rs_type(item_type)) + } + Type::Map { value_type } => { + format!( + "std::collections::HashMap", + self.rs_type(value_type) + ) + } + Type::Reference(TypeReference { + namespace: None, + target, + }) => target.clone(), + Type::Reference(TypeReference { + namespace: Some(namespace), + target, + }) => format!("{namespace}::{target}"), + Type::Json => "serde_json::Value".to_string(), + Type::Decimal => self.decimal_type(), + Type::BigInt => self.bigint_type(), } } - Ok(result) -} -fn render_new_type( - model_name: &str, - derived: &[&str], - inner_type: &Type, -) -> anyhow::Result { - let mut result = "".to_string(); - writeln!(result, "{}", render_derived(derived))?; - writeln!( - result, - "pub struct {model_name}(pub {});", - inner_type.rs_type() - )?; - Ok(result) + fn decimal_type(&self) -> String { + self.config + .type_overwrites + .as_ref() + .map(|t| t.decimal.as_ref()) + .flatten() + .cloned() + .unwrap_or_else(|| "rust_decimal::Decimal".to_string()) + } + + fn bigint_type(&self) -> String { + self.config + .type_overwrites + .as_ref() + .map(|t| t.bigint.as_ref()) + .flatten() + .cloned() + .unwrap_or_else(|| "tot_spec_util::big_int::BigInt".to_string()) + } } fn extend_derived<'a>(derived: &[&'a str], more: &[&'a str]) -> Vec<&'a str> { @@ -454,106 +658,19 @@ fn extend_derived<'a>(derived: &[&'a str], more: &[&'a str]) -> Vec<&'a str> { derived } -fn render_const( - model_name: &str, - derived: &[&str], - value_type: &ConstType, - values: &[ConstValueDef], -) -> anyhow::Result { - let mut code = "".to_string(); - let value_type_in_struct = value_type.rs_type(); - let value_type_in_to_value = value_type_in_struct; - let value_type_in_from_value = match value_type { +fn rs_const_name(name: &str) -> String { + use convert_case::{Case, Casing}; + name.to_case(Case::UpperSnake) +} + +fn rs_type_for_const(const_type: &ConstType) -> &'static str { + match const_type { ConstType::I8 => "i8", ConstType::I16 => "i16", ConstType::I32 => "i32", ConstType::I64 => "i64", - // from_value able to accept &str for all lifetime - ConstType::String => "&str", - }; - - // for const, we should always derive, "Copy", "Clone", "Hash", "Ord" like - let derived = extend_derived( - derived, - &[ - "Copy", - "Clone", - "Hash", - "PartialEq", - "Eq", - "PartialOrd", - "Ord", - ], - ); - - writeln!(code, "{}", render_derived(&derived))?; - writeln!(code, "pub struct {model_name}(pub {value_type_in_struct});")?; - - { - // generate from_value and to_value - writeln!(code, "")?; - writeln!(code, "impl {model_name} {{")?; - - let from_value = { - // from_value - let mut code = "".to_string(); - writeln!( - code, - "pub fn from_value(val: {value_type_in_from_value}) -> Option {{" - )?; - writeln!(code, " match val {{")?; - for value in values.iter() { - let value_name = rs_const_name(&value.name); - let value_literal = rs_const_literal(&value.value); - writeln!(code, " {value_literal} => Some(Self::{value_name}),")?; - } - writeln!(code, " _ => None,")?; - - writeln!(code, " }}")?; - writeln!(code, "}}")?; - code - }; - - writeln!(code, "{}", indent(&from_value.trim(), 1))?; - - let to_value = { - // from_value - let mut code = "".to_string(); - writeln!(code, "pub fn to_value(self) -> {value_type_in_to_value} {{")?; - writeln!(code, " self.0")?; - writeln!(code, "}}")?; - code - }; - - writeln!(code, "{}", indent(&to_value.trim(), 1))?; - - writeln!(code, "}}")?; - } - - writeln!(code, "")?; - writeln!(code, "impl {model_name} {{")?; - - for value in values.iter() { - let value_name = rs_const_name(&value.name); - let value_literal = rs_const_literal(&value.value); - if let Some(desc) = &value.desc { - let comment = indent(multiline_prefix_with(desc, "/// "), 1); - writeln!(code, "{comment}")?; - } - - writeln!( - code, - " pub const {value_name}: {model_name} = {model_name}({value_literal});" - )?; + ConstType::String => "&'static str", } - - writeln!(code, "}}")?; - Ok(code) -} - -fn rs_const_name(name: &str) -> String { - use convert_case::{Case, Casing}; - name.to_case(Case::UpperSnake) } fn rs_const_literal(val: &StringOrInteger) -> String { @@ -564,30 +681,37 @@ fn rs_const_literal(val: &StringOrInteger) -> String { } fn to_identifier(name: &str) -> (Cow, bool) { - match name { + let name_snake_case = Case::Snake.convert(name); + + let result: Cow = match name_snake_case.as_ref() { "as" | "use" | "extern crate" | "break" | "const" | "continue" | "crate" | "else" | "if" | "if let" | "enum" | "extern" | "false" | "fn" | "for" | "impl" | "in" | "let" | "loop" | "match" | "mod" | "move" | "mut" | "pub" | "ref" | "return" | "Self" | "self" | "static" | "struct" | "super" | "trait" | "true" | "type" | "unsafe" | "where" | "while" | "abstract" | "alignof" | "become" | "box" | "do" | "final" | "macro" | "offsetof" | "override" | "priv" | "proc" | "pure" | "sizeof" | "typeof" - | "unsized" | "virtual" | "yield" => (format!("{name}_").into(), true), - _ => (name.into(), false), - } + | "unsized" | "virtual" | "yield" => format!("{name}_").into(), + _ => name_snake_case, + }; + + let modified = result.ne(name); + (result, modified) } #[cfg(test)] mod tests { use super::*; + use crate::codegen::Codegen; #[test] fn test_render() { fn test_def(spec: &Path, code_path: &str) { let spec = spec.strip_prefix("src/codegen/fixtures/specs/").unwrap(); - let context = - Context::new_from_folder(&PathBuf::from("src/codegen/fixtures/specs/")).unwrap(); - let rendered = super::render(spec, &context).unwrap(); + let codegen = + RsSerde::load_from_folder(&PathBuf::from("src/codegen/fixtures/specs/")).unwrap(); + + let rendered = codegen.render(spec).unwrap(); let rendered_ast = syn::parse_file(&mut rendered.clone()).unwrap(); let code = std::fs::read_to_string(code_path).unwrap(); diff --git a/tot_spec/src/codegen/style.rs b/tot_spec/src/codegen/style.rs index f8fdd13..d833461 100644 --- a/tot_spec/src/codegen/style.rs +++ b/tot_spec/src/codegen/style.rs @@ -1,6 +1,7 @@ use crate::{FieldDef, ModelDef, ModelType}; use convert_case::{Boundary, Case as ConvertCase, Casing}; use serde::{Deserialize, Serialize}; +use std::borrow::Cow; use std::path::Path; #[derive(Debug, Serialize, Deserialize)] @@ -23,15 +24,15 @@ pub enum Case { } impl Case { - fn is_case(&self, name: &str) -> bool { + pub fn is_case(&self, name: &str) -> bool { self.convert(name).eq(name) } - fn convert(&self, name: &str) -> String { + pub fn convert<'a>(&self, name: &'a str) -> Cow<'a, str> { let convert_case = match self { Case::Snake => ConvertCase::Snake, Case::Camel => ConvertCase::Camel, - Case::Unspecified => return name.to_string(), + Case::Unspecified => return Cow::Borrowed(name), }; name.with_boundaries(&[ @@ -43,6 +44,7 @@ impl Case { Boundary::LowerUpper, ]) .to_case(convert_case) + .into() } } diff --git a/tot_spec/src/codegen/swagger.rs b/tot_spec/src/codegen/swagger.rs index 98a8a3b..3c71cd8 100644 --- a/tot_spec/src/codegen/swagger.rs +++ b/tot_spec/src/codegen/swagger.rs @@ -121,6 +121,10 @@ impl Swagger { } impl Codegen for Swagger { + fn load_from_folder(_folder: &PathBuf) -> anyhow::Result { + Ok(Self::default()) + } + fn generate_for_folder(&self, folder: &PathBuf, output: &PathBuf) -> anyhow::Result<()> { // load codegen config from spec_config.yaml file let config = Swagger::load_config(&folder.join("spec_config.yaml"))?.unwrap_or_default(); diff --git a/tot_spec/src/codegen/swift_codable.rs b/tot_spec/src/codegen/swift_codable.rs index 71e2294..cd8a580 100644 --- a/tot_spec/src/codegen/swift_codable.rs +++ b/tot_spec/src/codegen/swift_codable.rs @@ -11,6 +11,10 @@ use super::utils::{indent, multiline_prefix_with}; pub struct SwiftCodable {} impl super::Codegen for SwiftCodable { + fn load_from_folder(_folder: &PathBuf) -> anyhow::Result { + Ok(Self::default()) + } + fn generate_for_folder(&self, folder: &PathBuf, output: &PathBuf) -> anyhow::Result<()> { use walkdir::WalkDir; diff --git a/tot_spec/src/models.rs b/tot_spec/src/models.rs index 30fd059..293ec5c 100644 --- a/tot_spec/src/models.rs +++ b/tot_spec/src/models.rs @@ -227,18 +227,6 @@ impl FieldDef { self.attributes.get(name) } - pub fn rs_type(&self) -> String { - let ty = self - .attribute("rs_type") - .map(|s| s.to_string()) - .unwrap_or(self.type_.rs_type()); - if self.required { - ty - } else { - format!("std::option::Option<{}>", ty) - } - } - /// returns attributes for this field pub fn rs_attributes(&self) -> Vec { vec![] @@ -262,18 +250,6 @@ pub enum ConstType { String, } -impl ConstType { - pub fn rs_type(&self) -> &'static str { - match self { - ConstType::I8 => "i8", - ConstType::I16 => "i16", - ConstType::I32 => "i32", - ConstType::I64 => "i64", - ConstType::String => "&'static str", - } - } -} - #[derive(Debug, Clone, Serialize, Deserialize)] pub struct TypeReference { pub namespace: Option, @@ -352,39 +328,6 @@ impl Type { } } -impl Type { - pub fn rs_type(&self) -> String { - match self { - Type::Bool => "bool".into(), - Type::I8 => "i8".into(), - Type::I16 => "i16".into(), - Type::I32 => "i32".into(), - Type::I64 => "i64".into(), - Type::F64 => "f64".into(), - Type::Bytes => "std::vec::Vec".into(), - Type::String => "std::string::String".into(), - Type::List { item_type } => format!("std::vec::Vec<{}>", item_type.rs_type()), - Type::Map { value_type } => { - format!( - "std::collections::HashMap", - value_type.rs_type() - ) - } - Type::Reference(TypeReference { - namespace: None, - target, - }) => target.clone(), - Type::Reference(TypeReference { - namespace: Some(namespace), - target, - }) => format!("{namespace}::{target}"), - Type::Json => "serde_json::Value".to_string(), - Type::Decimal => "rust_decimal::Decimal".into(), - Type::BigInt => "tot_spec_util::big_int::BigInt".into(), - } - } -} - #[derive(Debug, Clone, Serialize, Deserialize)] pub struct VariantDef { pub name: String,