From 4c7e3b488b185c92131600dfe524c498bbb101ae Mon Sep 17 00:00:00 2001 From: Vaivaswatha Nagaraj Date: Sat, 7 Oct 2023 22:07:23 +0530 Subject: [PATCH] impl Parsable for StructType --- Cargo.toml | 1 + src/dialect.rs | 2 +- src/dialects/llvm/types.rs | 315 +++++++++++++++++++++++++++---------- src/parsable.rs | 19 ++- src/printable.rs | 7 + src/type.rs | 22 +-- 6 files changed, 271 insertions(+), 95 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 1ff2431..9160ebd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,6 +15,7 @@ keywords = ["pliron", "llvm", "mlir", "compiler"] generational-arena = "0.2" downcast-rs = "1.2.0" rustc-hash = "1.1.0" +# anyerror = "1.0.75" # thiserror = "1.0.24" # clap = "4.1.6" apint = "0.2.0" diff --git a/src/dialect.rs b/src/dialect.rs index 85eb8c9..5e11547 100644 --- a/src/dialect.rs +++ b/src/dialect.rs @@ -45,7 +45,7 @@ impl Parsable for DialectName { where Self: Sized, { - let id = parsable::parse_id(); + let id = parsable::identifier(); let mut parser = id.and_then(|dialect_name| { let dialect_name = DialectName::new(&dialect_name); if state_stream.state.ctx.dialects.contains_key(&dialect_name) { diff --git a/src/dialects/llvm/types.rs b/src/dialects/llvm/types.rs index ade65bf..7ff103d 100644 --- a/src/dialects/llvm/types.rs +++ b/src/dialects/llvm/types.rs @@ -1,19 +1,57 @@ -use combine::{easy, token, ParseResult, Parser}; - use crate::{ common_traits::Verify, context::{Context, Ptr}, dialect::Dialect, error::CompilerError, impl_type, - parsable::{spaced, Parsable, StateStream}, - printable::{self, Printable}, + parsable::{identifier, spaced, to_parse_result, Parsable, StateStream}, + printable::{self, Printable, PrintableIter}, r#type::{type_parser, Type, TypeObj}, storage_uniquer::TypeValueHash, }; +use combine::{between, easy, optional, sep_by, token, ParseResult, Parser}; use std::hash::Hash; +/// A field in a [StructType]. +#[derive(Clone, PartialEq, Eq)] +pub struct StructField { + pub field_name: String, + pub field_type: Ptr, +} + +impl Printable for StructField { + fn fmt( + &self, + ctx: &Context, + state: &printable::State, + f: &mut std::fmt::Formatter<'_>, + ) -> std::fmt::Result { + write!( + f, + "{}: {}", + self.field_name, + self.field_type.print(ctx, state) + ) + } +} + +impl Parsable for StructField { + type Parsed = StructField; + + fn parse<'a>( + state_stream: &mut StateStream<'a>, + ) -> ParseResult>> { + // Parse a single type annotated field. + (spaced(identifier()), token(':'), spaced(type_parser())) + .parse_stream(state_stream) + .map(|(field_name, _, field_type)| StructField { + field_name, + field_type, + }) + } +} + /// Represents a c-like struct type. /// Limitations and warnings on its usage are similar to that in MLIR. /// `` @@ -24,21 +62,22 @@ use std::hash::Hash; /// named structs as identified structs. pub struct StructType { name: Option, - fields: Vec<(String, Ptr)>, + fields: Vec, finalized: bool, } impl_type!(StructType, "struct", "llvm"); impl StructType { - /// Create a new named StructType. + /// Get or create a new named StructType. /// If fields is None, it indicates an opaque (i.e., not finalized) struct. - /// Opaque structs must be finalized for verify() to succeed. + /// Opaque structs must be finalized (by passing non-none `fields`) for verify() to succeed. /// Opaque structs are an intermediary in creating recursive types. - pub fn create_named( + /// Returns an error when the name is already registered but the fields don't match. + pub fn get_named( ctx: &mut Context, name: &str, - fields: Option)>>, - ) -> Ptr { + fields: Option>, + ) -> Result, CompilerError> { let self_ptr = Type::register_instance( StructType { name: Some(name.to_string()), @@ -48,25 +87,25 @@ impl StructType { ctx, ); // Verify that we created a new or equivalent existing type. - let self_ref = self_ptr.deref(ctx); - let self_ref = self_ref.downcast_ref::().unwrap(); + let mut self_ref = self_ptr.deref_mut(ctx); + let self_ref = self_ref.downcast_mut::().unwrap(); assert!(self_ref.name.as_ref().unwrap() == name); - assert!( - self_ref.finalized == fields.is_some(), - "Struct already exists, or is being finalized via new creation" - ); if let Some(fields) = fields { - assert!( - self_ref.fields == fields, - "Struct {name} already exists and is different" - ); - }; - self_ptr + if !self_ref.finalized { + self_ref.fields = fields; + self_ref.finalized = true; + } else if self_ref.fields != fields { + return Err(CompilerError::BadInput { + msg: format!("Struct {name} already exists and is different"), + }); + } + } + Ok(self_ptr) } /// Get or create a new unnamed (anonymous) struct. /// These are finalized upon creation, and uniqued based on the fields. - pub fn get_unnamed(ctx: &mut Context, fields: Vec<(String, Ptr)>) -> Ptr { + pub fn get_unnamed(ctx: &mut Context, fields: Vec) -> Ptr { Type::register_instance( StructType { name: None, @@ -77,22 +116,20 @@ impl StructType { ) } - /// Finalize this structure. It is an error to call if already finalized. - pub fn finalize(&mut self, fields: Vec<(String, Ptr)>) { - assert!( - !self.finalized, - "Attempt to finalize an already finalized struct" - ); - self.fields = fields; - self.finalized = true; + /// Is this struct finalized? Returns false for non [StructType]s. + pub fn is_finalized(ctx: &Context, ty: Ptr) -> bool { + ty.deref(ctx) + .downcast_ref::() + .filter(|s| s.finalized) + .is_some() } /// If a named struct already exists, get a pointer to it. - /// Note that named structs are uniqued only on the name. pub fn get_existing_named(ctx: &Context, name: &str) -> Option> { Type::get_instance( StructType { name: Some(name.to_string()), + /// Named structs are uniqued only on the name. fields: vec![], finalized: false, }, @@ -100,11 +137,8 @@ impl StructType { ) } - /// Get, if it already exists, a [Ptr] to an unnamed struct. - pub fn get_existing_unnamed( - ctx: &Context, - fields: Vec<(String, Ptr)>, - ) -> Option> { + /// If an unnamed struct already exists, get a pointer to it. + pub fn get_existing_unnamed(ctx: &Context, fields: Vec) -> Option> { Type::get_instance( StructType { name: None, @@ -131,7 +165,7 @@ impl Printable for StructType { fn fmt( &self, ctx: &Context, - _state: &printable::State, + state: &printable::State, f: &mut core::fmt::Formatter<'_>, ) -> core::fmt::Result { write!(f, "{} <", Self::get_type_id_static().disp(ctx))?; @@ -143,35 +177,29 @@ impl Printable for StructType { // going to be large, in which case vec would be faster. static IN_PRINTING: RefCell> = RefCell::new(vec![]); } - let mut s; if let Some(name) = &self.name { let in_printing = IN_PRINTING.with(|f| f.borrow().contains(name)); if in_printing { return write!(f, "{}>", name.clone()); } IN_PRINTING.with(|f| f.borrow_mut().push(name.clone())); - s = format!("{name} {{ "); - } else { - s = "{{ ".to_string(); + write!(f, "{name} ")?; } - for field in &self.fields { - s += [ - field.0.clone(), - ": ".to_string(), - field.1.deref(ctx).disp(ctx).to_string(), - ", ".to_string(), - ] - .concat() - .as_str(); - } - s += "}"; + write!( + f, + "{{ {} }}", + self.fields + .iter() + .iprint(ctx, state, printable::ListSeparator::SpacedChar(',')) + )?; + // Done processing this struct. Remove it from the stack. if let Some(name) = &self.name { debug_assert!(IN_PRINTING.with(|f| f.borrow().last().unwrap() == name)); IN_PRINTING.with(|f| f.borrow_mut().pop()); } - write!(f, "{s}>") + write!(f, ">") } } @@ -179,10 +207,15 @@ impl Hash for StructType { fn hash(&self, state: &mut H) { match &self.name { Some(name) => name.hash(state), - None => self.fields.iter().for_each(|(name, ty)| { - name.hash(state); - ty.hash(state); - }), + None => self.fields.iter().for_each( + |StructField { + field_name, + field_type, + }| { + field_name.hash(state); + field_type.hash(state); + }, + ), } } } @@ -193,11 +226,9 @@ impl PartialEq for StructType { (Some(name), Some(other_name)) => name == other_name, (None, None) => { self.fields.len() == other.fields.len() - && self - .fields - .iter() - .zip(other.fields.iter()) - .all(|(f1, f2)| f1.0 == f2.0 && f1.1 == f2.1) + && self.fields.iter().zip(other.fields.iter()).all(|(f1, f2)| { + f1.field_name == f2.field_name && f1.field_type == f2.field_type + }) } _ => false, } @@ -208,12 +239,52 @@ impl Parsable for StructType { type Parsed = Ptr; fn parse<'a>( - _state_stream: &mut StateStream<'a>, + state_stream: &mut StateStream<'a>, ) -> ParseResult>> where Self: Sized, { - todo!() + let body_parser = || { + combine::parser(|parsable_state: &mut StateStream<'a>| { + // Parse multiple type annotated fields separated by ','. + let fields_parser = sep_by::, _, _, _>(StructField::parser(), token(',')); + + // The body is multiple type annotated fields surrounded by '{' and '}'. + let mut body = between(spaced(token('{')), spaced(token('}')), fields_parser); + + // Finally parse the whole thing. + body.parse_stream(parsable_state).into_result() + }) + }; + + let named = spaced(identifier()) + .and(optional(body_parser())) + .map(|(name, body_opt)| (Some(name), body_opt)); + let anonymous = body_parser().map(|body| (None::, Some(body))); + + // A struct type is named or anonymous. + let mut struct_parser = between( + spaced(token('<')), + spaced(token('>')), + (combine::position(), named.or(anonymous)), + ); + + struct_parser + .parse_stream(state_stream) + .and_then(|(position, (name_opt, body_opt))| { + let ctx = &mut state_stream.state.ctx; + if let Some(name) = name_opt { + to_parse_result(StructType::get_named(ctx, &name, body_opt), position) + } else { + to_parse_result( + Ok(StructType::get_unnamed( + ctx, + body_opt.expect("Without a name, a struct type must have a body."), + )), + position, + ) + } + }) } } @@ -266,7 +337,9 @@ impl Parsable for PointerType { where Self: Sized, { - spaced(combine::between(token('<'), token('>'), type_parser())).parse_stream(state_stream) + spaced(combine::between(token('<'), token('>'), type_parser())) + .parse_stream(state_stream) + .map(|pointee_ty| PointerType::get(state_stream.state.ctx, pointee_ty)) } } @@ -284,36 +357,44 @@ pub fn register(dialect: &mut Dialect) { #[cfg(test)] mod tests { + use expect_test::expect; + use crate::{ + common_traits::Verify, context::Context, dialects::{ self, builtin::types::{IntegerType, Signedness}, - llvm::types::{PointerType, StructType}, + llvm::types::{PointerType, StructField, StructType}, }, + error::CompilerError, parsable::{self, state_stream_from_iterator}, printable::Printable, r#type::{type_parser, Type}, }; #[test] - fn test_struct() { + fn test_struct() -> Result<(), CompilerError> { let mut ctx = Context::new(); let int64_ptr = IntegerType::get(&mut ctx, 64, Signedness::Signless); // Create an opaque struct since we want a recursive type. - let list_struct = StructType::create_named(&mut ctx, "LinkedList", None); + let list_struct = StructType::get_named(&mut ctx, "LinkedList", None)?; + assert!(!StructType::is_finalized(&ctx, list_struct)); let list_struct_ptr = PointerType::get(&mut ctx, list_struct); let fields = vec![ - ("data".to_string(), int64_ptr), - ("next".to_string(), list_struct_ptr), + StructField { + field_name: "data".to_string(), + field_type: int64_ptr, + }, + StructField { + field_name: "next".to_string(), + field_type: list_struct_ptr, + }, ]; // Finalize the type now. - list_struct - .deref_mut(&mut ctx) - .downcast_mut::() - .unwrap() - .finalize(fields); + StructType::get_named(&mut ctx, "LinkedList", Some(fields))?; + assert!(StructType::is_finalized(&ctx, list_struct)); let list_struct_2 = StructType::get_existing_named(&ctx, "LinkedList").unwrap(); assert!(list_struct == list_struct_2); @@ -326,12 +407,18 @@ mod tests { .unwrap() .disp(&ctx) .to_string(), - "llvm.struct , next: llvm.ptr >, }>" + "llvm.struct , next: llvm.ptr > }>" ); let head_fields = vec![ - ("len".to_string(), int64_ptr), - ("first".to_string(), list_struct_ptr), + StructField { + field_name: "len".to_string(), + field_type: int64_ptr, + }, + StructField { + field_name: "first".to_string(), + field_type: list_struct_ptr, + }, ]; let head_struct = StructType::get_unnamed(&mut ctx, head_fields.clone()); let head_struct2 = StructType::get_existing_unnamed(&ctx, head_fields).unwrap(); @@ -339,12 +426,20 @@ mod tests { assert!(StructType::get_existing_unnamed( &ctx, vec![ - ("len".to_string(), int64_ptr), + StructField { + field_name: "len".to_string(), + field_type: int64_ptr + }, // The actual field is a LinkedList here, rather than a pointer type to it. - ("first".to_string(), list_struct), + StructField { + field_name: "first".to_string(), + field_type: list_struct + }, ] ) .is_none()); + + Ok(()) } #[test] @@ -394,6 +489,62 @@ mod tests { ); let res = type_parser().parse(state_stream).unwrap().0; - assert_eq!(&res.disp(&ctx).to_string(), "builtin.integer "); + assert_eq!( + &res.disp(&ctx).to_string(), + "llvm.ptr >" + ); + } + + #[test] + fn test_struct_type_parsing() { + let mut ctx = Context::new(); + dialects::builtin::register(&mut ctx); + dialects::llvm::register(&mut ctx); + + let state_stream = state_stream_from_iterator( + "llvm.struct , next: llvm.ptr > }>".chars(), + parsable::State { ctx: &mut ctx }, + ); + + let res = type_parser().parse(state_stream).unwrap().0; + assert_eq!(&res.disp(&ctx).to_string(), "llvm.struct , next: llvm.ptr > }>"); + } + + #[test] + fn test_struct_type_errs() { + let mut ctx = Context::new(); + dialects::builtin::register(&mut ctx); + dialects::llvm::register(&mut ctx); + + let state_stream = state_stream_from_iterator( + "llvm.struct < My1 { f1: builtin.integer } >".chars(), + parsable::State { ctx: &mut ctx }, + ); + let _ = type_parser().parse(state_stream).unwrap().0; + + let state_stream = state_stream_from_iterator( + "llvm.struct < My1 { f1: builtin.integer } >".chars(), + parsable::State { ctx: &mut ctx }, + ); + + let res = type_parser().parse(state_stream); + let err_msg = format!("{}", res.err().unwrap()); + + let expected_err_msg = expect![[r#" + Parse error at line: 1, column: 15 + Compilation failed. + Struct My1 already exists and is different + "#]]; + expected_err_msg.assert_eq(&err_msg); + + let state_stream = state_stream_from_iterator( + "llvm.struct < My2 >".chars(), + parsable::State { ctx: &mut ctx }, + ); + let res = type_parser().parse(state_stream).unwrap().0; + let expected_err_msg = expect![[r#" + Internal compiler error. Verification failed. + Struct not finalized"#]]; + expected_err_msg.assert_eq(&res.verify(&ctx).unwrap_err().to_string()) } } diff --git a/src/parsable.rs b/src/parsable.rs index 759cd54..827339f 100644 --- a/src/parsable.rs +++ b/src/parsable.rs @@ -1,6 +1,6 @@ //! IR objects that can be parsed from their text representation. -use crate::context::Context; +use crate::{context::Context, error::CompilerError}; use combine::{ easy, parser::char::spaces, @@ -129,7 +129,7 @@ pub type ParserFn = ) -> Box, Output = Parsed, PartialState = ()> + 'a>; /// Parse an identifier. -pub fn parse_id>() -> impl Parser { +pub fn identifier>() -> impl Parser { use combine::{many, parser::char}; char::letter() .and(many::(char::alpha_num().or(char::char('_')))) @@ -142,3 +142,18 @@ pub fn spaced, Output>( ) -> impl Parser { combine::between(spaces(), spaces(), parser) } + +/// Convert `Result<_, CompilerError>` into [ParseResult], +/// Helps in returning errors when writing a parser. +pub fn to_parse_result<'a, T>( + result: Result, + position: SourcePosition, +) -> ParseResult>> { + match result { + Ok(t) => ParseResult::CommitOk(t), + Err(e) => ParseResult::CommitErr(easy::Errors::from_errors( + position, + vec![easy::Error::Message(e.to_string().into())], + )), + } +} diff --git a/src/printable.rs b/src/printable.rs index 1370365..14f0e7c 100644 --- a/src/printable.rs +++ b/src/printable.rs @@ -165,8 +165,12 @@ impl<'a, T: Printable> Printable for &'a T { #[derive(Clone, Copy)] /// When printing lists, how must they be separated pub enum ListSeparator { + /// Newline Newline, + /// Single character Char(char), + /// Single character followed by a space + SpacedChar(char), } /// Iterate over [Item](Iterator::Item)s in an [Iterator] and print them. @@ -192,6 +196,9 @@ where ListSeparator::Char(c) => { write!(f, "{}", c)?; } + ListSeparator::SpacedChar(c) => { + write!(f, "{} ", c)?; + } } item.fmt(ctx, state, f)?; } diff --git a/src/type.rs b/src/type.rs index c1ebfb1..4886a8a 100644 --- a/src/type.rs +++ b/src/type.rs @@ -11,7 +11,8 @@ use crate::common_traits::Verify; use crate::context::{private::ArenaObj, ArenaCell, Context, Ptr}; use crate::dialect::{Dialect, DialectName}; -use crate::parsable::{parse_id, spaced, Parsable, ParserFn, StateStream}; +use crate::error::CompilerError; +use crate::parsable::{identifier, spaced, to_parse_result, Parsable, ParserFn, StateStream}; use crate::printable::{self, Printable}; use crate::storage_uniquer::TypeValueHash; @@ -155,7 +156,7 @@ impl Parsable for TypeName { where Self: Sized, { - parse_id() + identifier() .map(|name| TypeName::new(&name)) .parse_stream(&mut state_stream.stream) } @@ -332,12 +333,12 @@ pub fn type_parse<'a>( .get(&type_id.dialect) .expect("Dialect name parsed but dialect isn't registered"); let Some(type_parser) = dialect.types.get(&type_id) else { - return ParseResult::CommitErr(easy::Errors::from_errors( + return to_parse_result( + Err(CompilerError::BadInput { + msg: format!("Unregistered type {}.", type_id.disp(state.ctx)), + }), position, - vec![easy::Error::Message( - format!("Unregistered type {}.", type_id.disp(state.ctx)).into(), - )], - )) + ) .into_result(); }; type_parser(&()).parse_stream(parsable_state).into_result() @@ -379,9 +380,10 @@ mod test { let err_msg = format!("{}", res.err().unwrap()); let expected_err_msg = expect![[r#" - Parse error at line: 1, column: 1 - Unregistered type builtin.some. - "#]]; + Parse error at line: 1, column: 1 + Compilation failed. + Unregistered type builtin.some. + "#]]; expected_err_msg.assert_eq(&err_msg); let state_stream = state_stream_from_iterator(