diff --git a/prost-build/src/code_generator.rs b/prost-build/src/code_generator.rs index f8d341445..46131e361 100644 --- a/prost-build/src/code_generator.rs +++ b/prost-build/src/code_generator.rs @@ -16,6 +16,7 @@ use prost_types::{ use crate::ast::{Comments, Method, Service}; use crate::extern_paths::ExternPaths; +use crate::fully_qualified_name::FullyQualifiedName; use crate::ident::{strip_enum_prefix, to_snake, to_upper_camel}; use crate::message_graph::MessageGraph; use crate::Config; @@ -159,7 +160,8 @@ impl CodeGenerator<'_> { debug!(" message: {:?}", message.name()); let message_name = message.name().to_string(); - let fq_message_name = self.fq_name(&message_name); + let fq_message_name = + FullyQualifiedName::new(&self.package, &self.type_path, &message_name); // Skip external types. if self.extern_paths.resolve_ident(&fq_message_name).is_some() { @@ -170,7 +172,7 @@ impl CodeGenerator<'_> { // of the map field entry types. The path index of the nested message types is preserved so // that comments can be retrieved. type NestedTypes = Vec<(DescriptorProto, usize)>; - type MapTypes = HashMap; + type MapTypes = HashMap; let (nested_types, map_types): (NestedTypes, MapTypes) = message .nested_type .into_iter() @@ -187,7 +189,7 @@ impl CodeGenerator<'_> { assert_eq!("key", key.name()); assert_eq!("value", value.name()); - let name = format!("{}.{}", &fq_message_name, nested_type.name()); + let name = fq_message_name.join(nested_type.name()); Either::Right((name, (key, value))) } else { Either::Left((nested_type, idx)) @@ -246,12 +248,10 @@ impl CodeGenerator<'_> { self.depth += 1; self.path.push(2); for field in &fields { + let type_name = field.descriptor.type_name.as_ref(); self.path.push(field.path_index); - match field - .descriptor - .type_name - .as_ref() - .and_then(|type_name| map_types.get(type_name)) + match type_name + .and_then(|type_name| map_types.get(&FullyQualifiedName::from_type_name(type_name))) { Some((key, value)) => self.append_map_field(&fq_message_name, field, key, value), None => self.append_field(&fq_message_name, field), @@ -302,7 +302,7 @@ impl CodeGenerator<'_> { } } - fn append_type_name(&mut self, message_name: &str, fq_message_name: &str) { + fn append_type_name(&mut self, message_name: &str, fq_message_name: &FullyQualifiedName) { self.buf.push_str(&format!( "impl {}::Name for {} {{\n", self.config.prost_path.as_deref().unwrap_or("::prost"), @@ -322,33 +322,29 @@ impl CodeGenerator<'_> { let prost_path = self.config.prost_path.as_deref().unwrap_or("::prost"); let string_path = format!("{prost_path}::alloc::string::String"); - let full_name = format!( - "{}{}{}{}{message_name}", - self.package.trim_matches('.'), - if self.package.is_empty() { "" } else { "." }, - self.type_path.join("."), - if self.type_path.is_empty() { "" } else { "." }, - ); + let full_name = FullyQualifiedName::new(&self.package, &self.type_path, message_name); let domain_name = self .config .type_name_domains .get_first(fq_message_name) .map_or("", |name| name.as_str()); + let full_name_str = full_name.as_ref().trim_start_matches('.'); self.buf.push_str(&format!( - r#"fn full_name() -> {string_path} {{ "{full_name}".into() }}"#, + r#"fn full_name() -> {string_path} {{ "{}".into() }}"#, + full_name_str, )); self.buf.push_str(&format!( - r#"fn type_url() -> {string_path} {{ "{domain_name}/{full_name}".into() }}"#, + r#"fn type_url() -> {string_path} {{ "{domain_name}/{}".into() }}"#, + full_name_str )); self.depth -= 1; self.buf.push_str("}\n"); } - fn append_type_attributes(&mut self, fq_message_name: &str) { - assert_eq!(b'.', fq_message_name.as_bytes()[0]); + fn append_type_attributes(&mut self, fq_message_name: &FullyQualifiedName) { for attribute in self.config.type_attributes.get(fq_message_name) { push_indent(self.buf, self.depth); self.buf.push_str(attribute); @@ -356,8 +352,7 @@ impl CodeGenerator<'_> { } } - fn append_message_attributes(&mut self, fq_message_name: &str) { - assert_eq!(b'.', fq_message_name.as_bytes()[0]); + fn append_message_attributes(&mut self, fq_message_name: &FullyQualifiedName) { for attribute in self.config.message_attributes.get(fq_message_name) { push_indent(self.buf, self.depth); self.buf.push_str(attribute); @@ -365,12 +360,11 @@ impl CodeGenerator<'_> { } } - fn should_skip_debug(&self, fq_message_name: &str) -> bool { - assert_eq!(b'.', fq_message_name.as_bytes()[0]); + fn should_skip_debug(&self, fq_message_name: &FullyQualifiedName) -> bool { self.config.skip_debug.get(fq_message_name).next().is_some() } - fn append_skip_debug(&mut self, fq_message_name: &str) { + fn append_skip_debug(&mut self, fq_message_name: &FullyQualifiedName) { if self.should_skip_debug(fq_message_name) { push_indent(self.buf, self.depth); self.buf.push_str("#[prost(skip_debug)]"); @@ -378,8 +372,7 @@ impl CodeGenerator<'_> { } } - fn append_enum_attributes(&mut self, fq_message_name: &str) { - assert_eq!(b'.', fq_message_name.as_bytes()[0]); + fn append_enum_attributes(&mut self, fq_message_name: &FullyQualifiedName) { for attribute in self.config.enum_attributes.get(fq_message_name) { push_indent(self.buf, self.depth); self.buf.push_str(attribute); @@ -387,8 +380,7 @@ impl CodeGenerator<'_> { } } - fn append_field_attributes(&mut self, fq_message_name: &str, field_name: &str) { - assert_eq!(b'.', fq_message_name.as_bytes()[0]); + fn append_field_attributes(&mut self, fq_message_name: &FullyQualifiedName, field_name: &str) { for attribute in self .config .field_attributes @@ -400,7 +392,7 @@ impl CodeGenerator<'_> { } } - fn append_field(&mut self, fq_message_name: &str, field: &Field) { + fn append_field(&mut self, fq_message_name: &FullyQualifiedName, field: &Field) { let type_ = field.descriptor.r#type(); let repeated = field.descriptor.label == Some(Label::Repeated as i32); let deprecated = self.deprecated(&field.descriptor); @@ -527,7 +519,7 @@ impl CodeGenerator<'_> { fn append_map_field( &mut self, - fq_message_name: &str, + fq_message_name: &FullyQualifiedName, field: &Field, key: &FieldDescriptorProto, value: &FieldDescriptorProto, @@ -575,7 +567,7 @@ impl CodeGenerator<'_> { fn append_oneof_field( &mut self, message_name: &str, - fq_message_name: &str, + fq_message_name: &FullyQualifiedName, oneof: &OneofField, ) { let type_name = format!( @@ -603,14 +595,14 @@ impl CodeGenerator<'_> { )); } - fn append_oneof(&mut self, fq_message_name: &str, oneof: &OneofField) { + fn append_oneof(&mut self, fq_message_name: &FullyQualifiedName, oneof: &OneofField) { self.path.push(8); self.path.push(oneof.path_index); self.append_doc(fq_message_name, None); self.path.pop(); self.path.pop(); - let oneof_name = format!("{}.{}", fq_message_name, oneof.descriptor.name()); + let oneof_name = fq_message_name.join(oneof.descriptor.name()); self.append_type_attributes(&oneof_name); self.append_enum_attributes(&oneof_name); self.push_indent(); @@ -692,7 +684,7 @@ impl CodeGenerator<'_> { Some(&source_info.location[idx]) } - fn append_doc(&mut self, fq_name: &str, field_name: Option<&str>) { + fn append_doc(&mut self, fq_name: &FullyQualifiedName, field_name: Option<&str>) { let append_doc = if let Some(field_name) = field_name { self.config .disable_comments @@ -715,7 +707,8 @@ impl CodeGenerator<'_> { let enum_name = to_upper_camel(proto_enum_name); let enum_values = &desc.value; - let fq_proto_enum_name = self.fq_name(proto_enum_name); + let fq_proto_enum_name = + FullyQualifiedName::new(&self.package, &self.type_path, proto_enum_name); if self .extern_paths @@ -883,8 +876,10 @@ impl CodeGenerator<'_> { let name = method.name.take().unwrap(); let input_proto_type = method.input_type.take().unwrap(); let output_proto_type = method.output_type.take().unwrap(); - let input_type = self.resolve_ident(&input_proto_type); - let output_type = self.resolve_ident(&output_proto_type); + let input_type = + self.resolve_ident(&FullyQualifiedName::from_type_name(&input_proto_type)); + let output_type = + self.resolve_ident(&FullyQualifiedName::from_type_name(&output_proto_type)); let client_streaming = method.client_streaming(); let server_streaming = method.server_streaming(); @@ -947,7 +942,11 @@ impl CodeGenerator<'_> { self.buf.push_str("}\n"); } - fn resolve_type(&self, field: &FieldDescriptorProto, fq_message_name: &str) -> String { + fn resolve_type( + &self, + field: &FieldDescriptorProto, + fq_message_name: &FullyQualifiedName, + ) -> String { match field.r#type() { Type::Float => String::from("f32"), Type::Double => String::from("f64"), @@ -965,14 +964,13 @@ impl CodeGenerator<'_> { .unwrap_or_default() .rust_type() .to_owned(), - Type::Group | Type::Message => self.resolve_ident(field.type_name()), + Type::Group | Type::Message => { + self.resolve_ident(&FullyQualifiedName::from_type_name(field.type_name())) + } } } - fn resolve_ident(&self, pb_ident: &str) -> String { - // protoc should always give fully qualified identifiers. - assert_eq!(".", &pb_ident[..1]); - + fn resolve_ident(&self, pb_ident: &FullyQualifiedName) -> String { if let Some(proto_ident) = self.extern_paths.resolve_ident(pb_ident) { return proto_ident; } @@ -990,7 +988,7 @@ impl CodeGenerator<'_> { local_path.next(); } - let mut ident_path = pb_ident[1..].split('.'); + let mut ident_path = pb_ident.path_iterator(); let ident_type = ident_path.next_back().unwrap(); let mut ident_path = ident_path.peekable(); @@ -1028,7 +1026,7 @@ impl CodeGenerator<'_> { Type::Message => Cow::Borrowed("message"), Type::Enum => Cow::Owned(format!( "enumeration={:?}", - self.resolve_ident(field.type_name()) + self.resolve_ident(&FullyQualifiedName::from_type_name(field.type_name())) )), } } @@ -1037,7 +1035,7 @@ impl CodeGenerator<'_> { match field.r#type() { Type::Enum => Cow::Owned(format!( "enumeration({})", - self.resolve_ident(field.type_name()) + self.resolve_ident(&FullyQualifiedName::from_type_name(field.type_name())) )), _ => self.field_type_tag(field), } @@ -1066,7 +1064,7 @@ impl CodeGenerator<'_> { fn boxed( &self, field: &FieldDescriptorProto, - fq_message_name: &str, + fq_message_name: &FullyQualifiedName, oneof: Option<&str>, ) -> bool { let repeated = field.label == Some(Label::Repeated as i32); @@ -1075,13 +1073,13 @@ impl CodeGenerator<'_> { && (fd_type == Type::Message || fd_type == Type::Group) && self .message_graph - .is_nested(field.type_name(), fq_message_name) + .is_nested(field.type_name(), fq_message_name.as_ref()) { return true; } let config_path = match oneof { None => Cow::Borrowed(fq_message_name), - Some(ooname) => Cow::Owned(format!("{fq_message_name}.{ooname}")), + Some(ooname) => Cow::Owned(fq_message_name.join(ooname)), }; if self .config @@ -1108,18 +1106,6 @@ impl CodeGenerator<'_> { .as_ref() .map_or(false, FieldOptions::deprecated) } - - /// Returns the fully-qualified name, starting with a dot - fn fq_name(&self, message_name: &str) -> String { - format!( - "{}{}{}{}.{}", - if self.package.is_empty() { "" } else { "." }, - self.package.trim_matches('.'), - if self.type_path.is_empty() { "" } else { "." }, - self.type_path.join("."), - message_name, - ) - } } /// Returns `true` if the repeated field type can be packed. diff --git a/prost-build/src/config.rs b/prost-build/src/config.rs index 896726b16..098534fa4 100644 --- a/prost-build/src/config.rs +++ b/prost-build/src/config.rs @@ -1079,8 +1079,13 @@ impl Config { let mut packages = HashMap::new(); let message_graph = MessageGraph::new(requests.iter().map(|x| &x.1), self.boxed.clone()); - let extern_paths = ExternPaths::new(&self.extern_paths, self.prost_types) - .map_err(|error| Error::new(ErrorKind::InvalidInput, error))?; + let extern_paths = ExternPaths::new( + self.extern_paths + .iter() + .map(|(a, b)| (a.as_str(), b.as_str())), + self.prost_types, + ) + .map_err(|error| Error::new(ErrorKind::InvalidInput, error))?; for (request_module, request_fd) in requests { // Only record packages that have services diff --git a/prost-build/src/extern_paths.rs b/prost-build/src/extern_paths.rs index 8f6bee784..14dc1de49 100644 --- a/prost-build/src/extern_paths.rs +++ b/prost-build/src/extern_paths.rs @@ -2,7 +2,10 @@ use std::collections::{hash_map, HashMap}; use itertools::Itertools; -use crate::ident::{to_snake, to_upper_camel}; +use crate::{ + fully_qualified_name::FullyQualifiedName, + ident::{to_snake, to_upper_camel}, +}; fn validate_proto_path(path: &str) -> Result<(), String> { if path.chars().next().map(|c| c != '.').unwrap_or(true) { @@ -19,52 +22,54 @@ fn validate_proto_path(path: &str) -> Result<(), String> { #[derive(Debug)] pub struct ExternPaths { + // IMPROVEMENT: store as FullyQualifiedName and syn::Path extern_paths: HashMap, } impl ExternPaths { - pub fn new(paths: &[(String, String)], prost_types: bool) -> Result { + pub fn new<'a>( + paths: impl IntoIterator + 'a, + prost_types: bool, + ) -> Result { let mut extern_paths = ExternPaths { extern_paths: HashMap::new(), }; for (proto_path, rust_path) in paths { - extern_paths.insert(proto_path.clone(), rust_path.clone())?; + extern_paths.insert(proto_path, rust_path)?; } if prost_types { - extern_paths.insert(".google.protobuf".to_string(), "::prost_types".to_string())?; - extern_paths.insert(".google.protobuf.BoolValue".to_string(), "bool".to_string())?; + extern_paths.insert(".google.protobuf", "::prost_types")?; + extern_paths.insert(".google.protobuf.BoolValue", "bool")?; extern_paths.insert( - ".google.protobuf.BytesValue".to_string(), - "::prost::alloc::vec::Vec".to_string(), + ".google.protobuf.BytesValue", + "::prost::alloc::vec::Vec", )?; + extern_paths.insert(".google.protobuf.DoubleValue", "f64")?; + extern_paths.insert(".google.protobuf.Empty", "()")?; + extern_paths.insert(".google.protobuf.FloatValue", "f32")?; + extern_paths.insert(".google.protobuf.Int32Value", "i32")?; + extern_paths.insert(".google.protobuf.Int64Value", "i64")?; extern_paths.insert( - ".google.protobuf.DoubleValue".to_string(), - "f64".to_string(), - )?; - extern_paths.insert(".google.protobuf.Empty".to_string(), "()".to_string())?; - extern_paths.insert(".google.protobuf.FloatValue".to_string(), "f32".to_string())?; - extern_paths.insert(".google.protobuf.Int32Value".to_string(), "i32".to_string())?; - extern_paths.insert(".google.protobuf.Int64Value".to_string(), "i64".to_string())?; - extern_paths.insert( - ".google.protobuf.StringValue".to_string(), - "::prost::alloc::string::String".to_string(), - )?; - extern_paths.insert( - ".google.protobuf.UInt32Value".to_string(), - "u32".to_string(), - )?; - extern_paths.insert( - ".google.protobuf.UInt64Value".to_string(), - "u64".to_string(), + ".google.protobuf.StringValue", + "::prost::alloc::string::String", )?; + extern_paths.insert(".google.protobuf.UInt32Value", "u32")?; + extern_paths.insert(".google.protobuf.UInt64Value", "u64")?; } Ok(extern_paths) } - fn insert(&mut self, proto_path: String, rust_path: String) -> Result<(), String> { + fn insert( + &mut self, + proto_path: impl Into, + rust_path: impl Into, + ) -> Result<(), String> { + let proto_path = proto_path.into(); + let rust_path = rust_path.into(); + validate_proto_path(&proto_path)?; match self.extern_paths.entry(proto_path) { hash_map::Entry::Occupied(occupied) => { @@ -78,10 +83,8 @@ impl ExternPaths { Ok(()) } - pub fn resolve_ident(&self, pb_ident: &str) -> Option { - // protoc should always give fully qualified identifiers. - assert_eq!(".", &pb_ident[..1]); - + pub fn resolve_ident(&self, pb_ident: &FullyQualifiedName) -> Option { + let pb_ident = pb_ident.as_ref(); if let Some(rust_path) = self.extern_paths.get(pb_ident) { return Some(rust_path.clone()); } @@ -124,19 +127,22 @@ mod tests { #[test] fn test_extern_paths() { let paths = ExternPaths::new( - &[ - (".foo".to_string(), "::foo1".to_string()), - (".foo.bar".to_string(), "::foo2".to_string()), - (".foo.baz".to_string(), "::foo3".to_string()), - (".foo.Fuzz".to_string(), "::foo4::Fuzz".to_string()), - (".a.b.c.d.e.f".to_string(), "::abc::def".to_string()), + [ + (".foo", "::foo1"), + (".foo.bar", "::foo2"), + (".foo.baz", "::foo3"), + (".foo.Fuzz", "::foo4::Fuzz"), + (".a.b.c.d.e.f", "::abc::def"), ], false, ) .unwrap(); let case = |proto_ident: &str, resolved_ident: &str| { - assert_eq!(paths.resolve_ident(proto_ident).unwrap(), resolved_ident); + assert_eq!( + paths.resolve_ident(&proto_ident.into()).unwrap(), + resolved_ident + ); }; case(".foo", "::foo1"); @@ -150,17 +156,20 @@ mod tests { case(".a.b.c.d.e.f", "::abc::def"); case(".a.b.c.d.e.f.g.FooBar.Baz", "::abc::def::g::foo_bar::Baz"); - assert!(paths.resolve_ident(".a").is_none()); - assert!(paths.resolve_ident(".a.b").is_none()); - assert!(paths.resolve_ident(".a.c").is_none()); + assert!(paths.resolve_ident(&".a".into()).is_none()); + assert!(paths.resolve_ident(&".a.b".into()).is_none()); + assert!(paths.resolve_ident(&".a.c".into()).is_none()); } #[test] fn test_well_known_types() { - let paths = ExternPaths::new(&[], true).unwrap(); + let paths = ExternPaths::new([], true).unwrap(); let case = |proto_ident: &str, resolved_ident: &str| { - assert_eq!(paths.resolve_ident(proto_ident).unwrap(), resolved_ident); + assert_eq!( + paths.resolve_ident(&proto_ident.into()).unwrap(), + resolved_ident + ); }; case(".google.protobuf.Value", "::prost_types::Value"); @@ -170,8 +179,8 @@ mod tests { #[test] fn test_error_fully_qualified() { - let paths = [("foo".to_string(), "bar".to_string())]; - let err = ExternPaths::new(&paths, false).unwrap_err(); + let paths = [("foo", "bar")]; + let err = ExternPaths::new(paths, false).unwrap_err(); assert_eq!( err.to_string(), "Protobuf paths must be fully qualified (begin with a leading '.'): foo" @@ -180,8 +189,8 @@ mod tests { #[test] fn test_error_invalid_path() { - let paths = [(".foo.".to_string(), "bar".to_string())]; - let err = ExternPaths::new(&paths, false).unwrap_err(); + let paths = [(".foo.", "bar")]; + let err = ExternPaths::new(paths, false).unwrap_err(); assert_eq!( err.to_string(), "invalid fully-qualified Protobuf path: .foo." @@ -190,11 +199,8 @@ mod tests { #[test] fn test_error_duplicate() { - let paths = [ - (".foo".to_string(), "bar".to_string()), - (".foo".to_string(), "bar".to_string()), - ]; - let err = ExternPaths::new(&paths, false).unwrap_err(); + let paths = [(".foo", "bar"), (".foo", "bar")]; + let err = ExternPaths::new(paths, false).unwrap_err(); assert_eq!(err.to_string(), "duplicate extern Protobuf path: .foo") } } diff --git a/prost-build/src/fully_qualified_name.rs b/prost-build/src/fully_qualified_name.rs new file mode 100644 index 000000000..4f636122a --- /dev/null +++ b/prost-build/src/fully_qualified_name.rs @@ -0,0 +1,52 @@ +use itertools::Itertools; + +// Invariant: should always begin with a '.' (dot) +#[derive(Debug, PartialEq, Eq, Hash, Clone)] +pub struct FullyQualifiedName(String); + +impl FullyQualifiedName { + pub fn new(package_string: &str, type_path: &[impl AsRef], message_name: &str) -> Self { + Self(format!( + "{}{}{}{}{}{}", + if package_string.is_empty() { "" } else { "." }, + package_string.trim_matches('.'), + if type_path.is_empty() { "" } else { "." }, + type_path + .iter() + .map(AsRef::as_ref) + .map(|type_path_str| type_path_str.trim_start_matches('.')) + .join("."), + if message_name.is_empty() { "" } else { "." }, + message_name, + )) + } + + pub fn from_type_name(type_name: &str) -> Self { + Self::new("", &[type_name], "") + } + + pub fn path_iterator(&self) -> impl DoubleEndedIterator { + self.0[1..].split('.') + } + + pub fn join(&self, path: &str) -> Self { + Self(format!("{}.{}", self.0, path)) + } +} + +impl AsRef for FullyQualifiedName { + fn as_ref(&self) -> &str { + &self.0 + } +} + +#[cfg(test)] +mod test_helpers { + use super::*; + + impl From<&str> for FullyQualifiedName { + fn from(str: &str) -> Self { + Self(str.to_string()) + } + } +} diff --git a/prost-build/src/lib.rs b/prost-build/src/lib.rs index 14324f9cb..c31052746 100644 --- a/prost-build/src/lib.rs +++ b/prost-build/src/lib.rs @@ -148,6 +148,7 @@ pub(crate) use collections::{BytesType, MapType}; mod code_generator; mod extern_paths; +mod fully_qualified_name; mod ident; mod message_graph; mod path; diff --git a/prost-build/src/message_graph.rs b/prost-build/src/message_graph.rs index e2bcad918..a05c83e33 100644 --- a/prost-build/src/message_graph.rs +++ b/prost-build/src/message_graph.rs @@ -9,7 +9,7 @@ use prost_types::{ DescriptorProto, FieldDescriptorProto, FileDescriptorProto, }; -use crate::path::PathMap; +use crate::{fully_qualified_name::FullyQualifiedName, path::PathMap}; /// `MessageGraph` builds a graph of messages whose edges correspond to nesting. /// The goal is to recognize when message types are recursively nested, so @@ -101,9 +101,8 @@ impl MessageGraph { } /// Returns `true` if this message can automatically derive Copy trait. - pub fn can_message_derive_copy(&self, fq_message_name: &str) -> bool { - assert_eq!(".", &fq_message_name[..1]); - self.get_message(fq_message_name) + pub fn can_message_derive_copy(&self, fq_message_name: &FullyQualifiedName) -> bool { + self.get_message(fq_message_name.as_ref()) .unwrap() .field .iter() @@ -113,17 +112,15 @@ impl MessageGraph { /// Returns `true` if the type of this field allows deriving the Copy trait. pub fn can_field_derive_copy( &self, - fq_message_name: &str, + fq_message_name: &FullyQualifiedName, field: &FieldDescriptorProto, ) -> bool { - assert_eq!(".", &fq_message_name[..1]); - // repeated field cannot derive Copy if field.label() == Label::Repeated { false } else if field.r#type() == Type::Message { // nested and boxed messages cannot derive Copy - if self.is_nested(field.type_name(), fq_message_name) + if self.is_nested(field.type_name(), fq_message_name.as_ref()) || self .boxed .get_first_field(fq_message_name, field.name()) @@ -131,7 +128,7 @@ impl MessageGraph { { false } else { - self.can_message_derive_copy(field.type_name()) + self.can_message_derive_copy(&FullyQualifiedName::from_type_name(field.type_name())) } } else { matches!( diff --git a/prost-build/src/path.rs b/prost-build/src/path.rs index bc33fc440..22fd4343b 100644 --- a/prost-build/src/path.rs +++ b/prost-build/src/path.rs @@ -2,6 +2,8 @@ use std::iter; +use crate::fully_qualified_name::FullyQualifiedName; + /// Maps a fully-qualified Protobuf path to a value using path matchers. #[derive(Clone, Debug, Default)] pub(crate) struct PathMap { @@ -17,26 +19,29 @@ impl PathMap { } /// Returns a iterator over all the value matching the given fd_path and associated suffix/prefix path - pub(crate) fn get(&self, fq_path: &str) -> Iter<'_, T> { - Iter::new(self, fq_path.to_string()) + pub(crate) fn get(&self, fq_path: &FullyQualifiedName) -> Iter<'_, T> { + Iter::new(self, fq_path.clone()) } /// Returns a iterator over all the value matching the path `fq_path.field` and associated suffix/prefix path - pub(crate) fn get_field(&self, fq_path: &str, field: &str) -> Iter<'_, T> { - Iter::new(self, format!("{}.{}", fq_path, field)) + pub(crate) fn get_field(&self, fq_path: &FullyQualifiedName, field: &str) -> Iter<'_, T> { + Iter::new(self, fq_path.join(field)) } /// Returns the first value found matching the given path /// If nothing matches the path, suffix paths will be tried, then prefix paths, then the global path - #[allow(unused)] - pub(crate) fn get_first<'a>(&'a self, fq_path: &'_ str) -> Option<&'a T> { + pub(crate) fn get_first<'a>(&'a self, fq_path: &'_ FullyQualifiedName) -> Option<&'a T> { self.find_best_matching(fq_path) } /// Returns the first value found matching the path `fq_path.field` /// If nothing matches the path, suffix paths will be tried, then prefix paths, then the global path - pub(crate) fn get_first_field<'a>(&'a self, fq_path: &'_ str, field: &'_ str) -> Option<&'a T> { - self.find_best_matching(&format!("{}.{}", fq_path, field)) + pub(crate) fn get_first_field<'a>( + &'a self, + fq_path: &'_ FullyQualifiedName, + field: &'_ str, + ) -> Option<&'a T> { + self.find_best_matching(&fq_path.join(field)) } /// Removes all matchers from the path map. @@ -46,7 +51,7 @@ impl PathMap { /// Returns the first value found best matching the path /// See [sub_path_iter()] for paths test order - fn find_best_matching(&self, full_path: &str) -> Option<&T> { + fn find_best_matching(&self, full_path: &FullyQualifiedName) -> Option<&T> { sub_path_iter(full_path).find_map(|path| { self.matchers .iter() @@ -59,11 +64,11 @@ impl PathMap { /// Iterator inside a PathMap that only returns values that matches a given path pub(crate) struct Iter<'a, T> { iter: std::slice::Iter<'a, (String, T)>, - path: String, + path: FullyQualifiedName, } impl<'a, T> Iter<'a, T> { - fn new(map: &'a PathMap, path: String) -> Self { + fn new(map: &'a PathMap, path: FullyQualifiedName) -> Self { Self { iter: map.matchers.iter(), path, @@ -71,7 +76,7 @@ impl<'a, T> Iter<'a, T> { } fn is_match(&self, path: &str) -> bool { - sub_path_iter(self.path.as_str()).any(|p| p == path) + sub_path_iter(&self.path).any(|p| p == path) } } @@ -101,7 +106,8 @@ impl std::iter::FusedIterator for Iter<'_, T> {} /// - the global path /// /// Example: sub_path_iter(".a.b.c") -> [".a.b.c", "a.b.c", "b.c", "c", ".a.b", ".a", "."] -fn sub_path_iter(full_path: &str) -> impl Iterator { +fn sub_path_iter(full_path: &FullyQualifiedName) -> impl Iterator { + let full_path = full_path.as_ref(); // First, try matching the path. iter::once(full_path) // Then, try matching path suffixes. @@ -165,67 +171,77 @@ mod tests { fn test_get_matches_sub_path() { let mut path_map = PathMap::default(); + let abcd_fqn = FullyQualifiedName::from(".a.b.c.d"); + let abc_fqn = FullyQualifiedName::from(".a.b.c"); + let bcd_fqn = FullyQualifiedName::from("b.c.d"); + // full path path_map.insert(".a.b.c.d".to_owned(), 1); - assert_eq!(Some(&1), path_map.get(".a.b.c.d").next()); - assert_eq!(Some(&1), path_map.get_field(".a.b.c", "d").next()); + assert_eq!(Some(&1), path_map.get(&abcd_fqn).next()); + assert_eq!(Some(&1), path_map.get_field(&abc_fqn, "d").next()); // suffix path_map.clear(); path_map.insert("c.d".to_owned(), 1); - assert_eq!(Some(&1), path_map.get(".a.b.c.d").next()); - assert_eq!(Some(&1), path_map.get("b.c.d").next()); - assert_eq!(Some(&1), path_map.get_field(".a.b.c", "d").next()); + assert_eq!(Some(&1), path_map.get(&abcd_fqn).next()); + assert_eq!(Some(&1), path_map.get(&bcd_fqn).next()); + assert_eq!(Some(&1), path_map.get_field(&abc_fqn, "d").next()); // prefix path_map.clear(); path_map.insert(".a.b".to_owned(), 1); - assert_eq!(Some(&1), path_map.get(".a.b.c.d").next()); - assert_eq!(Some(&1), path_map.get_field(".a.b.c", "d").next()); + assert_eq!(Some(&1), path_map.get(&abcd_fqn).next()); + assert_eq!(Some(&1), path_map.get_field(&abc_fqn, "d").next()); // global path_map.clear(); path_map.insert(".".to_owned(), 1); - assert_eq!(Some(&1), path_map.get(".a.b.c.d").next()); - assert_eq!(Some(&1), path_map.get("b.c.d").next()); - assert_eq!(Some(&1), path_map.get_field(".a.b.c", "d").next()); + assert_eq!(Some(&1), path_map.get(&abcd_fqn).next()); + assert_eq!(Some(&1), path_map.get(&bcd_fqn).next()); + assert_eq!(Some(&1), path_map.get_field(&abc_fqn, "d").next()); } #[test] fn test_get_best() { let mut path_map = PathMap::default(); + let abcd_fqn = FullyQualifiedName::from(".a.b.c.d"); + let abc_fqn = FullyQualifiedName::from(".a.b.c"); + let bcd_fqn = FullyQualifiedName::from("b.c.d"); + // worst is global path_map.insert(".".to_owned(), 1); - assert_eq!(Some(&1), path_map.get_first(".a.b.c.d")); - assert_eq!(Some(&1), path_map.get_first("b.c.d")); - assert_eq!(Some(&1), path_map.get_first_field(".a.b.c", "d")); + assert_eq!(Some(&1), path_map.get_first(&abcd_fqn)); + assert_eq!(Some(&1), path_map.get_first(&bcd_fqn)); + assert_eq!(Some(&1), path_map.get_first_field(&abc_fqn, "d")); // then prefix path_map.insert(".a.b".to_owned(), 2); - assert_eq!(Some(&2), path_map.get_first(".a.b.c.d")); - assert_eq!(Some(&2), path_map.get_first_field(".a.b.c", "d")); + assert_eq!(Some(&2), path_map.get_first(&abcd_fqn)); + assert_eq!(Some(&2), path_map.get_first_field(&abc_fqn, "d")); // then suffix path_map.insert("c.d".to_owned(), 3); - assert_eq!(Some(&3), path_map.get_first(".a.b.c.d")); - assert_eq!(Some(&3), path_map.get_first("b.c.d")); - assert_eq!(Some(&3), path_map.get_first_field(".a.b.c", "d")); + assert_eq!(Some(&3), path_map.get_first(&abcd_fqn)); + assert_eq!(Some(&3), path_map.get_first(&bcd_fqn)); + assert_eq!(Some(&3), path_map.get_first_field(&abc_fqn, "d")); // best is full path path_map.insert(".a.b.c.d".to_owned(), 4); - assert_eq!(Some(&4), path_map.get_first(".a.b.c.d")); - assert_eq!(Some(&4), path_map.get_first_field(".a.b.c", "d")); + assert_eq!(Some(&4), path_map.get_first(&abcd_fqn)); + assert_eq!(Some(&4), path_map.get_first_field(&abc_fqn, "d")); } #[test] fn test_get_keep_order() { + let abcd_fqn = FullyQualifiedName::from(".a.b.c.d"); + let mut path_map = PathMap::default(); path_map.insert(".".to_owned(), 1); path_map.insert(".a.b".to_owned(), 2); path_map.insert(".a.b.c.d".to_owned(), 3); - let mut iter = path_map.get(".a.b.c.d"); + let mut iter = path_map.get(&abcd_fqn); assert_eq!(Some(&1), iter.next()); assert_eq!(Some(&2), iter.next()); assert_eq!(Some(&3), iter.next()); @@ -237,7 +253,7 @@ mod tests { path_map.insert(".a.b".to_owned(), 2); path_map.insert(".".to_owned(), 3); - let mut iter = path_map.get(".a.b.c.d"); + let mut iter = path_map.get(&abcd_fqn); assert_eq!(Some(&1), iter.next()); assert_eq!(Some(&2), iter.next()); assert_eq!(Some(&3), iter.next());