From dbae825c22f5df38a0b7983bc177e2c5e011f1b5 Mon Sep 17 00:00:00 2001 From: Vaivaswatha Nagaraj Date: Mon, 23 Oct 2023 21:28:09 +0530 Subject: [PATCH] Attribute parsers --- src/attribute.rs | 95 ++++++++++++- src/dialect.rs | 20 +-- src/dialects/builtin/attributes.rs | 221 ++++++++++++++++++++++++----- src/dialects/builtin/ops.rs | 38 ++++- src/dialects/llvm/ops.rs | 16 ++- src/op.rs | 10 +- tests/interfaces.rs | 15 +- 7 files changed, 350 insertions(+), 65 deletions(-) diff --git a/src/attribute.rs b/src/attribute.rs index 38db3eb..99bd162 100644 --- a/src/attribute.rs +++ b/src/attribute.rs @@ -24,8 +24,9 @@ //! //! [AttrObj]s can be downcasted to their concrete types using /// [downcast_rs](https://docs.rs/downcast-rs/1.2.0/downcast_rs/index.html#example-without-generics). -use std::{hash::Hash, ops::Deref}; +use std::{fmt::Display, hash::Hash, ops::Deref}; +use combine::{easy, parser, ParseResult, Parser, Positioned}; use downcast_rs::{impl_downcast, Downcast}; use intertrait::{cast::CastRef, CastFrom}; @@ -34,6 +35,8 @@ use crate::{ context::Context, dialect::{Dialect, DialectName}, error::Result, + input_err, + parsable::{identifier, spaced, to_parse_result, Parsable, ParserFn, StateStream}, printable::{self, Printable}, }; @@ -59,11 +62,11 @@ pub trait Attribute: Printable + Verify + Downcast + CastFrom + Sync { /// Register this attribute's [AttrId] in the dialect it belongs to. /// **Warning**: No check is made as to whether this attr is already registered /// in `dialect`. - fn register_attr_in_dialect(dialect: &mut Dialect) + fn register_attr_in_dialect(dialect: &mut Dialect, attr_parser: ParserFn) where Self: Sized, { - dialect.add_attr(Self::get_attr_id_static()); + dialect.add_attr(Self::get_attr_id_static(), attr_parser); } } impl_downcast!(Attribute); @@ -133,10 +136,31 @@ impl Printable for AttrName { _state: &printable::State, f: &mut core::fmt::Formatter<'_>, ) -> core::fmt::Result { + ::fmt(self, f) + } +} + +impl Display for AttrName { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "{}", self.0) } } +impl Parsable for AttrName { + type Parsed = AttrName; + + fn parse<'a>( + state_stream: &mut crate::parsable::StateStream<'a>, + ) -> combine::ParseResult>> + where + Self: Sized, + { + identifier() + .map(|name| AttrName::new(&name)) + .parse_stream(&mut state_stream.stream) + } +} + impl Deref for AttrName { type Target = String; @@ -154,14 +178,75 @@ pub struct AttrId { impl Printable for AttrId { fn fmt( &self, - ctx: &Context, + _ctx: &Context, _state: &printable::State, f: &mut core::fmt::Formatter<'_>, ) -> core::fmt::Result { - write!(f, "{}.{}", self.dialect.disp(ctx), self.name.disp(ctx)) + ::fmt(self, f) + } +} + +impl Display for AttrId { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}.{}", self.dialect, self.name) + } +} + +impl Parsable for AttrId { + type Parsed = AttrId; + + // Parses (but does not validate) a TypeId. + fn parse<'a>( + state_stream: &mut StateStream<'a>, + ) -> ParseResult>> + where + Self: Sized, + { + let parser = DialectName::parser() + .skip(parser::char::char('.')) + .and(AttrName::parser()) + .map(|(dialect, name)| AttrId { dialect, name }); + spaced(parser).parse_stream(state_stream) } } +/// Parse an identified attribute, which is [AttrId] followed by its contents. +pub fn attr_parse<'a>( + state_stream: &mut StateStream<'a>, +) -> ParseResult>> { + let position = state_stream.stream.position(); + let attr_id_parser = AttrId::parser(); + + let attr_parser = attr_id_parser.then(|attr_id: AttrId| { + combine::parser(move |parsable_state: &mut StateStream<'a>| { + let state = &parsable_state.state; + let dialect = state + .ctx + .dialects + .get(&attr_id.dialect) + .expect("Dialect name parsed but dialect isn't registered"); + let Some(attr_parser) = dialect.attributes.get(&attr_id) else { + return to_parse_result( + input_err!("Unregistered attribute {}", attr_id.disp(state.ctx)), + position, + ) + .into_result(); + }; + attr_parser(&()).parse_stream(parsable_state).into_result() + }) + }); + + let mut attr_parser = spaced(attr_parser); + attr_parser.parse_stream(state_stream) +} + +/// A parser combinator to parse [AttrId] followed by the attribute's contents. +pub fn attr_parser<'a>( +) -> Box, Output = AttrObj, PartialState = ()> + 'a> { + combine::parser(|parsable_state: &mut StateStream<'a>| attr_parse(parsable_state).into_result()) + .boxed() +} + /// Every attribute interface must have a function named `verify` with this type. pub type AttrInterfaceVerifier = fn(&dyn Attribute, &Context) -> Result<()>; diff --git a/src/dialect.rs b/src/dialect.rs index 60ba638..6124490 100644 --- a/src/dialect.rs +++ b/src/dialect.rs @@ -6,9 +6,9 @@ use combine::{easy, ParseResult, Parser}; use rustc_hash::FxHashMap; use crate::{ - attribute::AttrId, + attribute::{AttrId, AttrObj}, context::{Context, Ptr}, - op::OpId, + op::{OpId, OpObj}, parsable::{self, Parsable, ParserFn, StateStream}, printable::{self, Printable}, r#type::{TypeId, TypeObj}, @@ -80,11 +80,11 @@ pub struct Dialect { /// Name of this dialect. pub name: DialectName, /// Ops that are part of this dialect. - pub ops: Vec, + pub ops: FxHashMap>, /// Types that are part of this dialect. pub types: FxHashMap>>, /// Attributes that are part of this dialect. - pub attributes: Vec, + pub attributes: FxHashMap>, } impl Printable for Dialect { @@ -103,9 +103,9 @@ impl Dialect { pub fn new(name: DialectName) -> Dialect { Dialect { name, - ops: vec![], + ops: FxHashMap::default(), types: FxHashMap::default(), - attributes: vec![], + attributes: FxHashMap::default(), } } @@ -115,9 +115,9 @@ impl Dialect { } /// Add an [Op](crate::op::Op) to this dialect. - pub fn add_op(&mut self, op: OpId) { + pub fn add_op(&mut self, op: OpId, op_parser: ParserFn) { assert!(op.dialect == self.name); - self.ops.push(op); + self.ops.insert(op, op_parser); } /// Add a [Type](crate::type::Type) to this dialect. @@ -127,9 +127,9 @@ impl Dialect { } /// Add an [Attribute](crate::attribute::Attribute) to this dialect. - pub fn add_attr(&mut self, attr: AttrId) { + pub fn add_attr(&mut self, attr: AttrId, attr_parser: ParserFn) { assert!(attr.dialect == self.name); - self.attributes.push(attr); + self.attributes.insert(attr, attr_parser); } /// This Dialect's name. diff --git a/src/dialects/builtin/attributes.rs b/src/dialects/builtin/attributes.rs index ce774c2..e90af69 100644 --- a/src/dialects/builtin/attributes.rs +++ b/src/dialects/builtin/attributes.rs @@ -1,5 +1,5 @@ use apint::ApInt; -use intertrait::cast_to; +use combine::{any, between, easy, many, none_of, token, ParseResult, Parser, Positioned}; use sorted_vector_map::SortedVectorMap; use thiserror::Error; @@ -9,9 +9,10 @@ use crate::{ context::{Context, Ptr}, dialect::Dialect, error::Result, - impl_attr, impl_attr_interface, + impl_attr, impl_attr_interface, input_err, + parsable::{spaced, to_parse_result, Parsable, StateStream}, printable::{self, Printable}, - r#type::TypeObj, + r#type::{type_parser, TypeObj}, verify_err, }; @@ -21,7 +22,7 @@ use super::{attr_interfaces::TypedAttrInterface, types::IntegerType}; /// Similar to MLIR's [StringAttr](https://mlir.llvm.org/docs/Dialects/Builtin/#stringattr). #[derive(PartialEq, Eq, Clone)] pub struct StringAttr(String); -impl_attr!(StringAttr, "String", "builtin"); +impl_attr!(StringAttr, "string", "builtin"); impl StringAttr { /// Create a new [StringAttr]. @@ -43,7 +44,7 @@ impl Printable for StringAttr { _state: &printable::State, f: &mut core::fmt::Formatter<'_>, ) -> core::fmt::Result { - write!(f, "{}", self.0) + write!(f, "{} {:?}", Self::get_attr_id_static(), self.0) } } @@ -53,6 +54,46 @@ impl Verify for StringAttr { } } +impl Parsable for StringAttr { + type Parsed = AttrObj; + + fn parse<'a>( + state_stream: &mut StateStream<'a>, + ) -> ParseResult>> { + // An escaped charater is one that is preceded by a backslash. + let escaped_char = combine::parser(move |parsable_state: &mut StateStream<'a>| { + // This combine::parser() is so that we can get a position before the parsing begins. + let position = parsable_state.position(); + let mut escaped_char = token('\\').with(any()).then(|c: char| { + // This combine::parser() is so that we can return an error of the right type. + // I can't get the right error type with `and_then` + combine::parser(move |_parsable_state: &mut StateStream<'a>| { + // Filter out the escaped characters that we handle. + let result = match c { + '\\' => Ok('\\'), + '\"' => Ok('\"'), + _ => input_err!("Unexpected escaped character \\{}", c), + }; + to_parse_result(result, position).into_result() + }) + }); + escaped_char.parse_stream(parsable_state).into_result() + }); + + // We want to scan a double quote deliminted string with possibly escaped characters in between. + let mut quoted_string = between( + token('"'), + token('"'), + many(escaped_char.or(none_of("\"".chars()))), + ) + .map(|str: Vec<_>| -> Box { + Box::new(StringAttr(str.into_iter().collect())) + }); + + quoted_string.parse_stream(state_stream) + } +} + /// An attribute containing an integer. /// Similar to MLIR's [IntegerAttr](https://mlir.llvm.org/docs/Dialects/Builtin/#integerattr). #[derive(PartialEq, Eq, Clone)] @@ -60,7 +101,7 @@ pub struct IntegerAttr { ty: Ptr, val: ApInt, } -impl_attr!(IntegerAttr, "Integer", "builtin"); +impl_attr!(IntegerAttr, "integer", "builtin"); impl Printable for IntegerAttr { fn fmt( @@ -99,6 +140,16 @@ impl From for ApInt { } } +impl Parsable for IntegerAttr { + type Parsed = AttrObj; + + fn parse<'a>( + _state_stream: &mut StateStream<'a>, + ) -> ParseResult>> { + todo!() + } +} + impl_attr_interface!(TypedAttrInterface for IntegerAttr { fn get_type(&self) -> Ptr { self.ty @@ -114,7 +165,7 @@ pub struct APFloat(); /// TODO: Use rustc's APFloat. #[derive(PartialEq, Clone)] pub struct FloatAttr(APFloat); -impl_attr!(FloatAttr, "Float", "builtin"); +impl_attr!(FloatAttr, "float", "builtin"); impl Printable for FloatAttr { fn fmt( @@ -146,9 +197,20 @@ impl From for APFloat { } } -#[cast_to] -impl TypedAttrInterface for FloatAttr { - fn get_type(&self) -> Ptr { +impl_attr_interface!( + TypedAttrInterface for FloatAttr { + fn get_type(&self) -> Ptr { + todo!() + } + } +); + +impl Parsable for FloatAttr { + type Parsed = AttrObj; + + fn parse<'a>( + _state_stream: &mut StateStream<'a>, + ) -> ParseResult>> { todo!() } } @@ -159,7 +221,7 @@ impl TypedAttrInterface for FloatAttr { /// Similar to MLIR's [DictionaryAttr](https://mlir.llvm.org/docs/Dialects/Builtin/#dictionaryattr), #[derive(PartialEq, Eq)] pub struct SmallDictAttr(SortedVectorMap<&'static str, AttrObj>); -impl_attr!(SmallDictAttr, "SmallDict", "builtin"); +impl_attr!(SmallDictAttr, "small_dict", "builtin"); impl Printable for SmallDictAttr { fn fmt( @@ -191,6 +253,16 @@ impl Verify for SmallDictAttr { } } +impl Parsable for SmallDictAttr { + type Parsed = AttrObj; + + fn parse<'a>( + _state_stream: &mut StateStream<'a>, + ) -> ParseResult>> { + todo!() + } +} + impl SmallDictAttr { /// Create a new [SmallDictAttr]. pub fn create(value: Vec<(&'static str, AttrObj)>) -> AttrObj { @@ -224,7 +296,7 @@ impl SmallDictAttr { #[derive(PartialEq, Eq)] pub struct VecAttr(pub Vec); -impl_attr!(VecAttr, "Vec", "builtin"); +impl_attr!(VecAttr, "vec", "builtin"); impl VecAttr { pub fn create(value: Vec) -> AttrObj { @@ -249,11 +321,21 @@ impl Verify for VecAttr { } } +impl Parsable for VecAttr { + type Parsed = AttrObj; + + fn parse<'a>( + _state_stream: &mut StateStream<'a>, + ) -> ParseResult>> { + todo!() + } +} + /// Represent attributes that only have meaning from their existence. /// See [UnitAttr](https://mlir.llvm.org/docs/Dialects/Builtin/#unitattr) in MLIR. #[derive(PartialEq, Eq, Clone, Copy)] pub struct UnitAttr(); -impl_attr!(UnitAttr, "Unit", "builtin"); +impl_attr!(UnitAttr, "unit", "builtin"); impl UnitAttr { pub fn create() -> AttrObj { @@ -278,11 +360,21 @@ impl Verify for UnitAttr { } } +impl Parsable for UnitAttr { + type Parsed = AttrObj; + + fn parse<'a>( + state_stream: &mut StateStream<'a>, + ) -> ParseResult>> { + to_parse_result(Ok(UnitAttr::create()), state_stream.position()) + } +} + /// An attribute that does nothing but hold a Type. /// Same as MLIR's [TypeAttr](https://mlir.llvm.org/docs/Dialects/Builtin/#typeattr). #[derive(PartialEq, Eq, Clone)] pub struct TypeAttr(Ptr); -impl_attr!(TypeAttr, "Type", "builtin"); +impl_attr!(TypeAttr, "type", "builtin"); impl TypeAttr { pub fn create(ty: Ptr) -> AttrObj { @@ -297,7 +389,19 @@ impl Printable for TypeAttr { _state: &printable::State, f: &mut core::fmt::Formatter<'_>, ) -> core::fmt::Result { - write!(f, "{}<{}>", self.get_attr_id().disp(ctx), self.0.disp(ctx)) + write!(f, "{} <{}>", self.get_attr_id().disp(ctx), self.0.disp(ctx)) + } +} + +impl Parsable for TypeAttr { + type Parsed = AttrObj; + + fn parse<'a>( + state_stream: &mut StateStream<'a>, + ) -> ParseResult>> { + between(spaced(token('<')), spaced(token('>')), type_parser()) + .map(TypeAttr::create) + .parse_stream(state_stream) } } @@ -307,34 +411,40 @@ impl Verify for TypeAttr { } } -#[cast_to] -impl TypedAttrInterface for TypeAttr { - fn get_type(&self) -> Ptr { - self.0 +impl_attr_interface!( + TypedAttrInterface for TypeAttr { + fn get_type(&self) -> Ptr { + self.0 + } } -} +); pub fn register(dialect: &mut Dialect) { - StringAttr::register_attr_in_dialect(dialect); - IntegerAttr::register_attr_in_dialect(dialect); - SmallDictAttr::register_attr_in_dialect(dialect); - VecAttr::register_attr_in_dialect(dialect); - UnitAttr::register_attr_in_dialect(dialect); - TypeAttr::register_attr_in_dialect(dialect); + StringAttr::register_attr_in_dialect(dialect, StringAttr::parser_fn); + IntegerAttr::register_attr_in_dialect(dialect, IntegerAttr::parser_fn); + SmallDictAttr::register_attr_in_dialect(dialect, SmallDictAttr::parser_fn); + VecAttr::register_attr_in_dialect(dialect, VecAttr::parser_fn); + UnitAttr::register_attr_in_dialect(dialect, UnitAttr::parser_fn); + TypeAttr::register_attr_in_dialect(dialect, TypeAttr::parser_fn); } #[cfg(test)] mod tests { use apint::ApInt; + use expect_test::expect; use crate::{ - attribute::{self, attr_cast}, + attribute::{self, attr_cast, attr_parser}, context::Context, - dialects::builtin::{ - attr_interfaces::TypedAttrInterface, - attributes::{IntegerAttr, StringAttr}, - types::{IntegerType, Signedness}, + dialects::{ + self, + builtin::{ + attr_interfaces::TypedAttrInterface, + attributes::{IntegerAttr, StringAttr}, + types::{IntegerType, Signedness}, + }, }, + parsable::{self, state_stream_from_iterator}, printable::Printable, }; @@ -367,21 +477,49 @@ mod tests { #[test] fn test_string_attributes() { - let ctx = Context::new(); + let mut ctx = Context::new(); + dialects::builtin::register(&mut ctx); let str_0_ptr = StringAttr::create("hello".to_string()); let str_1_ptr = StringAttr::create("world".to_string()); assert!(str_0_ptr.is::() && &str_0_ptr != &str_1_ptr); let str_0_ptr2 = StringAttr::create("hello".to_string()); assert!(str_0_ptr == str_0_ptr2); - assert!( - str_0_ptr.disp(&ctx).to_string() == "hello" - && str_1_ptr.disp(&ctx).to_string() == "world" + assert_eq!(str_0_ptr.disp(&ctx).to_string(), "builtin.string \"hello\""); + assert_eq!(str_1_ptr.disp(&ctx).to_string(), "builtin.string \"world\""); + assert_eq!( + String::from(str_0_ptr.downcast_ref::().unwrap().clone()), + "hello", ); - assert!( - String::from(str_0_ptr.downcast_ref::().unwrap().clone()) == "hello" - && String::from(str_1_ptr.downcast_ref::().unwrap().clone()) == "world", + assert_eq!( + String::from(str_1_ptr.downcast_ref::().unwrap().clone()), + "world" ); + + let attr_input = "builtin.string \"hello\""; + let state_stream = + state_stream_from_iterator(attr_input.chars(), parsable::State { ctx: &mut ctx }); + let attr = attr_parser().parse(state_stream).unwrap().0; + assert_eq!(attr.disp(&ctx).to_string(), attr_input); + + let attr_input = "builtin.string \"hello \\\"world\\\"\""; + let state_stream = + state_stream_from_iterator(attr_input.chars(), parsable::State { ctx: &mut ctx }); + let attr_parsed = attr_parser().parse(state_stream).unwrap().0; + assert_eq!(attr_parsed.disp(&ctx).to_string(), attr_input,); + + // Unsupported escaped character. + let state_stream = state_stream_from_iterator( + "builtin.string \"hello \\k \"".chars(), + parsable::State { ctx: &mut ctx }, + ); + let res = attr_parser().parse(state_stream); + let err_msg = format!("{}", res.err().unwrap()); + let expected_err_msg = expect![[r#" + Parse error at line: 1, column: 23 + Unexpected escaped character \k + "#]]; + expected_err_msg.assert_eq(&err_msg); } #[test] @@ -431,11 +569,18 @@ mod tests { #[test] fn test_type_attributes() { let mut ctx = Context::new(); + dialects::builtin::register(&mut ctx); let ty = IntegerType::get(&mut ctx, 64, Signedness::Signed); let ty_attr = TypeAttr::create(ty); let ty_interface = attr_cast::(&*ty_attr).unwrap(); assert!(ty_interface.get_type() == ty); + + let ty_attr = ty_attr.disp(&ctx).to_string(); + let state_stream = + state_stream_from_iterator(ty_attr.chars(), parsable::State { ctx: &mut ctx }); + let ty_attr_parsed = attr_parser().parse(state_stream).unwrap().0; + assert_eq!(ty_attr_parsed.disp(&ctx).to_string(), ty_attr); } } diff --git a/src/dialects/builtin/ops.rs b/src/dialects/builtin/ops.rs index 275ff64..02538e0 100644 --- a/src/dialects/builtin/ops.rs +++ b/src/dialects/builtin/ops.rs @@ -1,3 +1,4 @@ +use combine::{easy::ParseError, ParseResult}; use thiserror::Error; use crate::{ @@ -10,8 +11,9 @@ use crate::{ error::Result, impl_op_interface, linked_list::ContainsLinkedList, - op::Op, + op::{Op, OpObj}, operation::Operation, + parsable::{Parsable, StateStream}, printable::{self, Printable}, r#type::TypeObj, verify_err, @@ -92,6 +94,15 @@ impl ModuleOp { } } +impl Parsable for ModuleOp { + type Parsed = OpObj; + fn parse<'a>( + _state_stream: &mut StateStream<'a>, + ) -> ParseResult>> { + todo!() + } +} + impl_op_interface!(OneRegionInterface for ModuleOp {}); impl_op_interface!(SingleBlockRegionInterface for ModuleOp {}); impl_op_interface!(SymbolOpInterface for ModuleOp {}); @@ -204,6 +215,15 @@ impl Verify for FuncOp { } } +impl Parsable for FuncOp { + type Parsed = OpObj; + fn parse<'a>( + _state_stream: &mut StateStream<'a>, + ) -> ParseResult>> { + todo!() + } +} + declare_op!( /// Numeric constant. /// See MLIR's [arith.constant](https://mlir.llvm.org/docs/Dialects/ArithOps/#arithconstant-mlirarithconstantop). @@ -268,6 +288,16 @@ impl Printable for ConstantOp { } } +impl Parsable for ConstantOp { + type Parsed = OpObj; + + fn parse<'a>( + _state_stream: &mut StateStream<'a>, + ) -> ParseResult>> { + todo!() + } +} + #[derive(Error, Debug)] #[error("{}: Unexpected type", ConstantOp::get_opid_static())] pub struct ConstantOpVerifyErr; @@ -286,7 +316,7 @@ impl_op_interface! (ZeroOpdInterface for ConstantOp {}); impl_op_interface! (OneResultInterface for ConstantOp {}); pub fn register(ctx: &mut Context, dialect: &mut Dialect) { - ModuleOp::register(ctx, dialect); - FuncOp::register(ctx, dialect); - ConstantOp::register(ctx, dialect); + ModuleOp::register(ctx, dialect, ModuleOp::parser_fn); + FuncOp::register(ctx, dialect, FuncOp::parser_fn); + ConstantOp::register(ctx, dialect, ConstantOp::parser_fn); } diff --git a/src/dialects/llvm/ops.rs b/src/dialects/llvm/ops.rs index 0febf23..b6d94aa 100644 --- a/src/dialects/llvm/ops.rs +++ b/src/dialects/llvm/ops.rs @@ -1,3 +1,5 @@ +use combine::{easy::ParseError, ParseResult}; + use crate::{ common_traits::Verify, context::Context, @@ -6,8 +8,9 @@ use crate::{ dialects::builtin::op_interfaces::IsTerminatorInterface, error::Result, impl_op_interface, - op::Op, + op::{Op, OpObj}, operation::Operation, + parsable::{Parsable, StateStream}, printable::{self, Printable}, use_def_lists::Value, }; @@ -58,8 +61,17 @@ impl Verify for ReturnOp { } } +impl Parsable for ReturnOp { + type Parsed = OpObj; + fn parse<'a>( + _state_stream: &mut crate::parsable::StateStream<'a>, + ) -> ParseResult>> { + todo!() + } +} + impl_op_interface!(IsTerminatorInterface for ReturnOp {}); pub fn register(ctx: &mut Context, dialect: &mut Dialect) { - ReturnOp::register(ctx, dialect); + ReturnOp::register(ctx, dialect, ReturnOp::parser_fn); } diff --git a/src/op.rs b/src/op.rs index 07129a8..d7c86b4 100644 --- a/src/op.rs +++ b/src/op.rs @@ -33,6 +33,7 @@ use crate::{ dialect::{Dialect, DialectName}, error::Result, operation::Operation, + parsable::ParserFn, printable::{self, Printable}, }; @@ -99,7 +100,8 @@ impl Display for OpId { pub(crate) type OpCreator = fn(Ptr) -> OpObj; /// A wrapper around [Operation] for Op(code) specific work. -/// All per-instance data must be in the underyling Operation. +/// All per-instance data must be in the underyling Operation, +/// which means that [OpObj]s are light-weight. /// /// See [module](crate::op) documentation for more information. pub trait Op: Downcast + Verify + Printable + CastFrom { @@ -120,7 +122,7 @@ pub trait Op: Downcast + Verify + Printable + CastFrom { fn verify_interfaces(&self, ctx: &Context) -> Result<()>; /// Register Op in Context and add it to dialect. - fn register(ctx: &mut Context, dialect: &mut Dialect) + fn register(ctx: &mut Context, dialect: &mut Dialect, op_parser: ParserFn) where Self: Sized, { @@ -128,7 +130,7 @@ pub trait Op: Downcast + Verify + Printable + CastFrom { std::collections::hash_map::Entry::Occupied(_) => (), std::collections::hash_map::Entry::Vacant(v) => { v.insert(Self::wrap_operation); - dialect.add_op(Self::get_opid_static()); + dialect.add_op(Self::get_opid_static(), op_parser); } } } @@ -136,7 +138,7 @@ pub trait Op: Downcast + Verify + Printable + CastFrom { impl_downcast!(Op); /// Create [OpObj] from [`Ptr`](Operation) -pub fn from_operation(ctx: &Context, op: Ptr) -> OpObj { +pub(crate) fn from_operation(ctx: &Context, op: Ptr) -> OpObj { let opid = op.deref(ctx).get_opid(); (ctx.ops .get(&opid) diff --git a/tests/interfaces.rs b/tests/interfaces.rs index 21ed1db..2871615 100644 --- a/tests/interfaces.rs +++ b/tests/interfaces.rs @@ -1,5 +1,6 @@ mod common; +use combine::{easy::ParseError, ParseResult}; use pliron::{ attribute::Attribute, common_traits::Verify, @@ -16,8 +17,9 @@ use pliron::{ }, error::{Error, ErrorKind, Result}, impl_attr, impl_attr_interface, impl_op_interface, - op::Op, + op::{Op, OpObj}, operation::Operation, + parsable::{Parsable, StateStream}, printable::{self, Printable}, r#type::TypeObj, }; @@ -45,6 +47,15 @@ impl Verify for ZeroResultOp { } } +impl Parsable for ZeroResultOp { + type Parsed = OpObj; + fn parse<'a>( + _state_stream: &mut StateStream<'a>, + ) -> ParseResult>> { + todo!() + } +} + impl ZeroResultOp { fn new(ctx: &mut Context) -> ZeroResultOp { *Operation::new(ctx, Self::get_opid_static(), vec![], vec![], 1) @@ -60,7 +71,7 @@ impl ZeroResultOp { fn check_intrf_verfiy_errs() { let ctx = &mut setup_context_dialects(); let mut dialect = Dialect::new(DialectName::new("test")); - ZeroResultOp::register(ctx, &mut dialect); + ZeroResultOp::register(ctx, &mut dialect, ZeroResultOp::parser_fn); let zero_res_op = ZeroResultOp::new(ctx).get_operation(); let (module_op, _, _, ret_op) = const_ret_in_mod(ctx).unwrap();