From 30ebc1763db6530cc5ae9b1ba9f864b8059d5276 Mon Sep 17 00:00:00 2001 From: Vaivaswatha N Date: Tue, 25 Jun 2024 11:35:57 +0530 Subject: [PATCH] Introduce interfaces for Types (#36) --- pliron-derive/src/derive_type.rs | 24 +++ src/attribute.rs | 5 +- src/builtin/attributes.rs | 19 +- src/type.rs | 318 ++++++++++++++++++++++++++++++- tests/interfaces.rs | 84 +++++++- 5 files changed, 424 insertions(+), 26 deletions(-) diff --git a/pliron-derive/src/derive_type.rs b/pliron-derive/src/derive_type.rs index d831bf5..46253fb 100644 --- a/pliron-derive/src/derive_type.rs +++ b/pliron-derive/src/derive_type.rs @@ -110,6 +110,17 @@ impl ToTokens for ImplType { dialect: ::pliron::dialect::DialectName::new(#dialect), } } + + fn verify_interfaces(&self, ctx: &::pliron::context::Context) -> ::pliron::result::Result<()> { + if let Some(interface_verifiers) = + ::pliron::r#type::TYPE_INTERFACE_VERIFIERS_MAP.get(&Self::get_type_id_static()) + { + for (_, verifier) in interface_verifiers { + verifier(self, ctx)?; + } + } + Ok(()) + } } }); } @@ -152,6 +163,19 @@ mod tests { dialect: ::pliron::dialect::DialectName::new("testing"), } } + fn verify_interfaces( + &self, + ctx: &::pliron::context::Context, + ) -> ::pliron::result::Result<()> { + if let Some(interface_verifiers) = ::pliron::r#type::TYPE_INTERFACE_VERIFIERS_MAP + .get(&Self::get_type_id_static()) + { + for (_, verifier) in interface_verifiers { + verifier(self, ctx)?; + } + } + Ok(()) + } } "##]] .assert_eq(&got); diff --git a/src/attribute.rs b/src/attribute.rs index ea48374..04eb4d1 100644 --- a/src/attribute.rs +++ b/src/attribute.rs @@ -16,7 +16,7 @@ //! and [Interfaces](https://mlir.llvm.org/docs/Interfaces/). //! Interfaces must all implement an associated function named `verify` with //! the type [AttrInterfaceVerifier]. -//! New attributes must be specified via [decl_attr_interface](pliron::decl_attr_interface) +//! New interfaces must be specified via [decl_attr_interface](pliron::decl_attr_interface) //! for proper verification. //! //! [Attribute]s that implement an interface must do so using the @@ -27,7 +27,8 @@ //! is [implemented](attr_impls)) with ease. //! //! [AttrObj]s can be downcasted to their concrete types using -/// [downcast_rs](https://docs.rs/downcast-rs/1.2.0/downcast_rs/index.html#example-without-generics). +//! [downcast_rs](https://docs.rs/downcast-rs/1.2.0/downcast_rs/index.html#example-without-generics). + use std::{ fmt::{Debug, Display}, hash::Hash, diff --git a/src/builtin/attributes.rs b/src/builtin/attributes.rs index cb7992c..26672c7 100644 --- a/src/builtin/attributes.rs +++ b/src/builtin/attributes.rs @@ -118,11 +118,7 @@ impl Printable for IntegerAttr { } } -impl Verify for IntegerAttr { - fn verify(&self, _ctx: &Context) -> Result<()> { - Ok(()) - } -} +impl_verify_succ!(IntegerAttr); impl IntegerAttr { /// Create a new [IntegerAttr]. @@ -371,12 +367,7 @@ impl Printable for UnitAttr { write!(f, "()") } } - -impl Verify for UnitAttr { - fn verify(&self, _ctx: &Context) -> Result<()> { - Ok(()) - } -} +impl_verify_succ!(UnitAttr); impl Parsable for UnitAttr { type Arg = (); @@ -428,11 +419,7 @@ impl Parsable for TypeAttr { } } -impl Verify for TypeAttr { - fn verify(&self, _ctx: &Context) -> Result<()> { - Ok(()) - } -} +impl_verify_succ!(TypeAttr); impl Typed for TypeAttr { fn get_type(&self, _ctx: &Context) -> Ptr { diff --git a/src/type.rs b/src/type.rs index 5f6c661..ac09570 100644 --- a/src/type.rs +++ b/src/type.rs @@ -4,10 +4,29 @@ //! The type system is open, with no fixed list of types, //! and there are no restrictions on the abstractions they represent. //! -//! See [MLIR Types](https://mlir.llvm.org/docs/DefiningDialects/AttributesAndTypes/#types) +//! See [MLIR Types](https://mlir.llvm.org/docs/DefiningDialects/TypesAndTypes/#types) //! //! The [def_type](pliron_derive::def_type) proc macro from [pliron-derive] //! can be used to implement [Type] for a rust type. +//! +//! Common semantics, API and behaviour of [Type]s are +//! abstracted into interfaces. Interfaces in pliron capture MLIR +//! functionality of both [Traits](https://mlir.llvm.org/docs/Traits/) +//! and [Interfaces](https://mlir.llvm.org/docs/Interfaces/). +//! Interfaces must all implement an associated function named `verify` with +//! the type [TypeInterfaceVerifier]. +//! New interfaces must be specified via [decl_type_interface](pliron::decl_type_interface) +//! for proper verification. +//! +//! [Type]s that implement an interface must do so using the +//! [impl_type_interface](crate::impl_type_interface) macro. +//! This ensures that the interface verifier is automatically called, +//! and that a `&dyn Type` object can be [cast](type_cast) into an +//! interface object, (or that it can be checked if the interface +//! is [implemented](type_impls)) with ease. +//! +//! [TypeObj]s can be downcasted to their concrete types using +//! [downcast_rs](https://docs.rs/downcast-rs/1.2.0/downcast_rs/index.html#example-without-generics). use crate::common_traits::Verify; use crate::context::{private::ArenaObj, ArenaCell, Context, Ptr}; @@ -23,8 +42,11 @@ use crate::{arg_err_noloc, input_err}; use combine::{parser, Parser}; use downcast_rs::{impl_downcast, Downcast}; +use linkme::distributed_slice; +use rustc_hash::FxHashMap; use std::cell::Ref; use std::fmt::Debug; +use std::fmt::Display; use std::hash::{Hash, Hasher}; use std::marker::PhantomData; use std::ops::Deref; @@ -130,6 +152,9 @@ pub trait Type: Printable + Verify + Downcast + Sync + Send + Debug { where Self: Sized; + /// Verify all interfaces implemented by this Type. + fn verify_interfaces(&self, ctx: &Context) -> Result<()>; + /// Register this Type's [TypeId] in the dialect it belongs to. fn register_type_in_dialect(ctx: &mut Context, parser: ParserFn<(), TypePtr>) where @@ -237,6 +262,12 @@ impl Printable for TypeName { _state: &printable::State, f: &mut core::fmt::Formatter<'_>, ) -> core::fmt::Result { + ::fmt(self, f) + } +} + +impl Display for TypeName { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "{}", self.0) } } @@ -289,11 +320,17 @@ impl Parsable for TypeId { impl Printable for TypeId { fn fmt( &self, - ctx: &Context, + _ctx: &Context, _state: &printable::State, f: &mut core::fmt::Formatter<'_>, ) -> core::fmt::Result { - write!(f, "{}.{}", self.dialect.disp(ctx), self.name.disp(ctx)) + ::fmt(self, f) + } +} + +impl Display for TypeId { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}.{}", self.dialect, self.name) } } @@ -340,7 +377,7 @@ impl Printable for TypeObj { state: &printable::State, f: &mut std::fmt::Formatter<'_>, ) -> std::fmt::Result { - self.get_type_id().fmt(ctx, state, f)?; + write!(f, "{}", self.get_type_id())?; Printable::fmt(self.deref(), ctx, state, f) } } @@ -496,3 +533,276 @@ impl Verify for TypePtr { self.0.verify(ctx) } } + +/// Cast reference to a [Type] object to an interface reference. +pub fn type_cast(ty: &dyn Type) -> Option<&T> { + crate::trait_cast::any_to_trait::(ty.as_any()) +} + +/// Does this [Type] object implement interface T? +pub fn type_impls(ty: &dyn Type) -> bool { + type_cast::(ty).is_some() +} + +/// Every type interface must have a function named `verify` with this type. +pub type TypeInterfaceVerifier = fn(&dyn Type, &Context) -> Result<()>; + +/// Implement a Type Interface for a Type. +/// The interface trait must define a `verify` function with type [TypeInterfaceVerifier]. +/// +/// Usage: +/// ``` +/// #[def_type("dialect.name")] +/// #[derive(PartialEq, Eq, Clone, Debug, Hash)] +/// struct MyType { } +/// +/// decl_type_interface! { +/// /// My first type interface. +/// MyTypeInterface { +/// fn monu(&self); +/// fn verify(r#type: &dyn Type, ctx: &Context) -> Result<()> +/// where Self: Sized, +/// { +/// Ok(()) +/// } +/// } +/// } +/// impl_type_interface!( +/// MyTypeInterface for MyType +/// { +/// fn monu(&self) { println!("monu"); } +/// } +/// ); +/// # use pliron::{ +/// # decl_type_interface, +/// # printable::{self, Printable}, +/// # context::Context, result::Result, common_traits::Verify, +/// # r#type::Type, impl_type_interface +/// # }; +/// # use pliron_derive::def_type; +/// # +/// # impl Printable for MyType { +/// # fn fmt(&self, _ctx: &Context, _state: &printable::State, _f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { +/// # unimplemented!() +/// # } +/// # } +/// # pliron::impl_verify_succ!(MyType); +#[macro_export] +macro_rules! impl_type_interface { + ($intr_name:ident for $type_name:ident { $($tt:tt)* }) => { + $crate::type_to_trait!($type_name, $intr_name); + impl $intr_name for $type_name { + $($tt)* + } + const _: () = { + #[linkme::distributed_slice(pliron::r#type::TYPE_INTERFACE_VERIFIERS)] + static INTERFACE_VERIFIER: $crate::Lazy< + (pliron::r#type::TypeId, (std::any::TypeId, pliron::r#type::TypeInterfaceVerifier)) + > = + $crate::Lazy::new(|| + ($type_name::get_type_id_static(), (std::any::TypeId::of::(), + <$type_name as $intr_name>::verify)) + ); + }; + }; +} + +/// [Type]s paired with every interface it implements (and the verifier for that interface). +#[distributed_slice] +pub static TYPE_INTERFACE_VERIFIERS: [crate::Lazy<( + TypeId, + (std::any::TypeId, TypeInterfaceVerifier), +)>]; + +/// All interfaces mapped to their super-interfaces +#[distributed_slice] +pub static TYPE_INTERFACE_DEPS: [crate::Lazy<(std::any::TypeId, Vec)>]; + +/// A map from every [Type] to its ordered (as per interface deps) list of interface verifiers. +/// An interface's super-interfaces are to be verified before it itself is. +pub static TYPE_INTERFACE_VERIFIERS_MAP: crate::Lazy< + FxHashMap>, +> = crate::Lazy::new(|| { + use std::any::TypeId; + // Collect TYPE_INTERFACE_VERIFIERS into a [TypeId] indexed map. + let mut type_intr_verifiers = FxHashMap::default(); + for lazy in TYPE_INTERFACE_VERIFIERS { + let (ty_id, (type_id, verifier)) = (**lazy).clone(); + type_intr_verifiers + .entry(ty_id) + .and_modify(|verifiers: &mut Vec<(TypeId, TypeInterfaceVerifier)>| { + verifiers.push((type_id, verifier)) + }) + .or_insert(vec![(type_id, verifier)]); + } + + // Collect interface deps into a map. + let interface_deps: FxHashMap<_, _> = TYPE_INTERFACE_DEPS + .iter() + .map(|lazy| (**lazy).clone()) + .collect(); + + // Assign an integer to each interface, such that if y depends on x + // i.e., x is a super-interface of y, then dep_sort_idx[x] < dep_sort_idx[y] + let mut dep_sort_idx = FxHashMap::::default(); + let mut sort_idx = 0; + fn assign_idx_to_intr( + interface_deps: &FxHashMap>, + dep_sort_idx: &mut FxHashMap, + sort_idx: &mut u32, + intr: &TypeId, + ) { + if dep_sort_idx.contains_key(intr) { + return; + } + + // Assign index to every dependent first. We don't bother to check for cyclic + // dependences since super interfaces are also super traits in Rust. + let deps = interface_deps + .get(intr) + .expect("Expect every interface to have a (possibly empty) list of dependences"); + for dep in deps { + assign_idx_to_intr(interface_deps, dep_sort_idx, sort_idx, dep); + } + + // Assign an index to the current interface. + dep_sort_idx.insert(*intr, *sort_idx); + *sort_idx += 1; + } + + // Assign dep_sort_idx to every interface. + for lazy in TYPE_INTERFACE_DEPS.iter() { + let (intr, _deps) = &**lazy; + assign_idx_to_intr(&interface_deps, &mut dep_sort_idx, &mut sort_idx, intr); + } + + for verifiers in type_intr_verifiers.values_mut() { + // sort verifiers based on its dep_sort_idx. + verifiers.sort_by(|(a, _), (b, _)| dep_sort_idx[a].cmp(&dep_sort_idx[b])); + } + + type_intr_verifiers +}); + +/// Declare a [Type] interface, which can be implemented by any [Type]. +/// +/// If the interface requires any other interface to be already implemented, +/// they can be specified. The trait to which this interface is expanded will +/// have the dependent interfaces as super-traits, in addition to the [Type] trait +/// itself, which is always automatically added as a super-trait. +/// +/// When a [Type] is verified, its interfaces are also automatically verified, +/// with guarantee that a super-interface is verified before an interface itself is. +/// +/// Example: Here `Super1` and `Super2` are super interfaces for the interface `MyTypeIntr`. +/// ``` +/// # use pliron::{decl_type_interface, r#type::Type, context::Context, result::Result}; +/// decl_type_interface!( +/// Super1 { +/// fn verify(_type: &dyn Type, _ctx: &Context) -> Result<()> +/// where +/// Self: Sized, +/// { +/// Ok(()) +/// } +/// } +/// ); +/// decl_type_interface!( +/// Super2 { +/// fn verify(_type: &dyn Type, _ctx: &Context) -> Result<()> +/// where +/// Self: Sized, +/// { +/// Ok(()) +/// } +/// } +/// ); +/// decl_type_interface!( +/// /// MyTypeIntr is my best type interface. +/// MyTypeIntr: Super1, Super2 { +/// fn verify(_type: &dyn Type, _ctx: &Context) -> Result<()> +/// where +/// Self: Sized, +/// { +/// Ok(()) +/// } +/// } +/// ); +/// ``` +#[macro_export] +macro_rules! decl_type_interface { + // No deps case + ($(#[$docs:meta])* + $intr_name:ident { $($tt:tt)* }) => { + decl_type_interface!( + $(#[$docs])* + $intr_name: { $($tt)* } + ); + }; + // Zero or more deps + ($(#[$docs:meta])* + $intr_name:ident: $($dep:path),* { $($tt:tt)* }) => { + $(#[$docs])* + pub trait $intr_name: pliron::r#type::Type $( + $dep )* { + $($tt)* + } + const _: () = { + #[linkme::distributed_slice(pliron::r#type::TYPE_INTERFACE_DEPS)] + static TYPE_INTERFACE_DEP: $crate::Lazy<(std::any::TypeId, Vec)> + = $crate::Lazy::new(|| { + (std::any::TypeId::of::(), vec![$(std::any::TypeId::of::(),)*]) + }); + }; + }; +} + +#[cfg(test)] +mod tests { + + use pliron::result::Result; + use rustc_hash::{FxHashMap, FxHashSet}; + use std::any::TypeId; + + use crate::verify_err_noloc; + + use super::{TYPE_INTERFACE_DEPS, TYPE_INTERFACE_VERIFIERS_MAP}; + + #[test] + /// For every interface that a [Type] implements, ensure that the interface verifiers + /// get called in the right order, with super-interface verifiers called before their + /// sub-interface verifier. + fn check_verifiers_deps() -> Result<()> { + // Collect interface deps into a map. + let interface_deps: FxHashMap<_, _> = TYPE_INTERFACE_DEPS + .iter() + .map(|lazy| (**lazy).clone()) + .collect(); + + for (ty, intrs) in TYPE_INTERFACE_VERIFIERS_MAP.iter() { + let mut satisfied_deps = FxHashSet::::default(); + for (intr, _) in intrs { + let deps = interface_deps.get(intr).ok_or_else(|| { + let err: Result<()> = verify_err_noloc!( + "Missing deps list for TypeId {:?} when checking verifier dependences for {}", + intr, + ty + ); + err.unwrap_err() + })?; + for dep in deps { + if !satisfied_deps.contains(dep) { + return verify_err_noloc!( + "For {}, depencence {:?} not satisfied for {:?}", + ty, + dep, + intr + ); + } + } + satisfied_deps.insert(*intr); + } + } + + Ok(()) + } +} diff --git a/tests/interfaces.rs b/tests/interfaces.rs index 3fb45a1..bf670c7 100644 --- a/tests/interfaces.rs +++ b/tests/interfaces.rs @@ -11,22 +11,24 @@ use pliron::{ attributes::{IntegerAttr, StringAttr}, op_interfaces::{OneResultInterface, OneResultVerifyErr}, ops::ModuleOp, + types::{IntegerType, UnitType}, }, common_traits::Verify, context::{Context, Ptr}, - decl_attr_interface, decl_op_interface, + decl_attr_interface, decl_op_interface, decl_type_interface, identifier::Identifier, - impl_attr_interface, impl_canonical_syntax, impl_op_interface, impl_verify_succ, + impl_attr_interface, impl_canonical_syntax, impl_op_interface, impl_type_interface, + impl_verify_succ, location::Location, op::{op_cast, Op, OpObj}, operation::Operation, parsable::{Parsable, ParseResult, StateStream}, printable::{self, Printable}, - r#type::TypeObj, + r#type::{Type, TypeObj}, result::{Error, ErrorKind, Result}, trait_cast::any_to_trait, }; -use pliron_derive::{def_attribute, def_op}; +use pliron_derive::{def_attribute, def_op, def_type}; use crate::common::{const_ret_in_mod, setup_context_dialects}; @@ -262,3 +264,77 @@ fn test_attr_intr_verify_order() -> Result<()> { .assert_eq(&TEST_ATTR_VERIFIERS_OUTPUT.lock().unwrap()); Ok(()) } + +decl_type_interface! { + TestTypeInterfaceX { + fn verify(_op: &dyn Type, _ctx: &Context) -> Result<()> + where + Self: Sized, + { + Ok(()) + } + } +} + +impl_type_interface!(TestTypeInterfaceX for UnitType {}); +impl_type_interface!(TestTypeInterfaceX for IntegerType {}); + +static TEST_TYPE_VERIFIERS_OUTPUT: Lazy> = Lazy::new(|| Mutex::new("".into())); + +#[def_type("test.verify_intr_type")] +#[derive(PartialEq, Clone, Debug, Hash)] +struct VerifyIntrType {} +impl_verify_succ!(VerifyIntrType); +impl_type_interface!(TestTypeInterface for VerifyIntrType {}); +impl_type_interface!(TestTypeInterface2 for VerifyIntrType {}); + +impl Printable for VerifyIntrType { + fn fmt( + &self, + _ctx: &Context, + _state: &printable::State, + f: &mut core::fmt::Formatter<'_>, + ) -> core::fmt::Result { + write!(f, "VerifyIntType") + } +} + +decl_type_interface! { + TestTypeInterface { + fn verify(_op: &dyn Type, _ctx: &Context) -> Result<()> + where + Self: Sized, + { + *TEST_TYPE_VERIFIERS_OUTPUT.lock().unwrap() += "TestTypeInterface verified\n"; + Ok(()) + } + } +} + +decl_type_interface! { + TestTypeInterface2: TestTypeInterface { + fn verify(_op: &dyn Type, _ctx: &Context) -> Result<()> + where + Self: Sized, + { + *TEST_TYPE_VERIFIERS_OUTPUT.lock().unwrap() += "TestTypeInterface2 verified\n"; + Ok(()) + } + } +} + +#[test] +fn test_type_intr_verify_order() -> Result<()> { + let ctx = &mut setup_context_dialects(); + VerifyIntrOp::register(ctx, VerifyIntrOp::parser_fn); + + let vio = VerifyIntrType {}; + vio.verify_interfaces(ctx)?; + + expect![[r#" + TestTypeInterface verified + TestTypeInterface2 verified + "#]] + .assert_eq(&TEST_TYPE_VERIFIERS_OUTPUT.lock().unwrap()); + Ok(()) +}