diff --git a/.cspell.json b/.cspell.json index a52cf452c0..3b74c596ed 100644 --- a/.cspell.json +++ b/.cspell.json @@ -3,6 +3,7 @@ "addf", "addi", "femtomc", + "hasher", "indoc", "insta", "libm", diff --git a/src/attribute.rs b/src/attribute.rs index d190db58f2..dbd88ce3db 100644 --- a/src/attribute.rs +++ b/src/attribute.rs @@ -1,9 +1,23 @@ use crate::{ context::{Context, ContextRef}, + r#type::{self, Type}, string_ref::StringRef, }; -use mlir_sys::{mlirAttributeGetContext, mlirAttributeParseGet, MlirAttribute}; -use std::marker::PhantomData; +use mlir_sys::{ + mlirAttributeDump, mlirAttributeEqual, mlirAttributeGetContext, mlirAttributeGetNull, + mlirAttributeGetType, mlirAttributeGetTypeID, mlirAttributeIsAAffineMap, mlirAttributeIsAArray, + mlirAttributeIsABool, mlirAttributeIsADenseElements, mlirAttributeIsADenseFPElements, + mlirAttributeIsADenseIntElements, mlirAttributeIsADictionary, mlirAttributeIsAElements, + mlirAttributeIsAFloat, mlirAttributeIsAInteger, mlirAttributeIsAIntegerSet, + mlirAttributeIsAOpaque, mlirAttributeIsAOpaqueElements, mlirAttributeIsASparseElements, + mlirAttributeIsAString, mlirAttributeIsASymbolRef, mlirAttributeIsAType, mlirAttributeIsAUnit, + mlirAttributeParseGet, mlirAttributePrint, MlirAttribute, MlirStringRef, +}; +use std::{ + ffi::c_void, + fmt::{self, Display, Formatter}, + marker::PhantomData, +}; /// An attribute. // Attributes are always values but their internal storage is owned by contexts. @@ -15,7 +29,7 @@ pub struct Attribute<'c> { impl<'c> Attribute<'c> { /// Parses an attribute. - pub fn parse(context: &Context, source: &str) -> Option { + pub fn parse(context: &'c Context, source: &str) -> Option { unsafe { Self::from_option_raw(mlirAttributeParseGet( context.to_raw(), @@ -24,11 +38,134 @@ impl<'c> Attribute<'c> { } } + /// Creates a null attribute. + pub fn null() -> Self { + unsafe { Self::from_raw(mlirAttributeGetNull()) } + } + + /// Gets a type. + pub fn r#type(&self) -> Option> { + if self.is_null() { + None + } else { + unsafe { Some(Type::from_raw(mlirAttributeGetType(self.raw))) } + } + } + + /// Gets a type ID. + pub fn type_id(&self) -> Option { + if self.is_null() { + None + } else { + unsafe { Some(r#type::Id::from_raw(mlirAttributeGetTypeID(self.raw))) } + } + } + + /// Returns `true` if an attribute is null. + pub fn is_null(&self) -> bool { + self.raw.ptr.is_null() + } + + /// Returns `true` if an attribute is a affine map. + pub fn is_affine_map(&self) -> bool { + !self.is_null() && unsafe { mlirAttributeIsAAffineMap(self.raw) } + } + + /// Returns `true` if an attribute is a array. + pub fn is_array(&self) -> bool { + !self.is_null() && unsafe { mlirAttributeIsAArray(self.raw) } + } + + /// Returns `true` if an attribute is a bool. + pub fn is_bool(&self) -> bool { + !self.is_null() && unsafe { mlirAttributeIsABool(self.raw) } + } + + /// Returns `true` if an attribute is dense elements. + pub fn is_dense_elements(&self) -> bool { + !self.is_null() && unsafe { mlirAttributeIsADenseElements(self.raw) } + } + + /// Returns `true` if an attribute is dense integer elements. + pub fn is_dense_integer_elements(&self) -> bool { + !self.is_null() && unsafe { mlirAttributeIsADenseIntElements(self.raw) } + } + + /// Returns `true` if an attribute is dense float elements. + pub fn is_dense_float_elements(&self) -> bool { + !self.is_null() && unsafe { mlirAttributeIsADenseFPElements(self.raw) } + } + + /// Returns `true` if an attribute is a dictionary. + pub fn is_dictionary(&self) -> bool { + !self.is_null() && unsafe { mlirAttributeIsADictionary(self.raw) } + } + + /// Returns `true` if an attribute is elements. + pub fn is_elements(&self) -> bool { + !self.is_null() && unsafe { mlirAttributeIsAElements(self.raw) } + } + + /// Returns `true` if an attribute is a float. + pub fn is_float(&self) -> bool { + !self.is_null() && unsafe { mlirAttributeIsAFloat(self.raw) } + } + + /// Returns `true` if an attribute is an integer. + pub fn is_integer(&self) -> bool { + !self.is_null() && unsafe { mlirAttributeIsAInteger(self.raw) } + } + + /// Returns `true` if an attribute is an integer set. + pub fn is_integer_set(&self) -> bool { + !self.is_null() && unsafe { mlirAttributeIsAIntegerSet(self.raw) } + } + + /// Returns `true` if an attribute is opaque. + pub fn is_opaque(&self) -> bool { + !self.is_null() && unsafe { mlirAttributeIsAOpaque(self.raw) } + } + + /// Returns `true` if an attribute is opaque elements. + pub fn is_opaque_elements(&self) -> bool { + !self.is_null() && unsafe { mlirAttributeIsAOpaqueElements(self.raw) } + } + + /// Returns `true` if an attribute is sparse elements. + pub fn is_sparse_elements(&self) -> bool { + !self.is_null() && unsafe { mlirAttributeIsASparseElements(self.raw) } + } + + /// Returns `true` if an attribute is a string. + pub fn is_string(&self) -> bool { + !self.is_null() && unsafe { mlirAttributeIsAString(self.raw) } + } + + /// Returns `true` if an attribute is a symbol. + pub fn is_symbol(&self) -> bool { + !self.is_null() && unsafe { mlirAttributeIsASymbolRef(self.raw) } + } + + /// Returns `true` if an attribute is a type. + pub fn is_type(&self) -> bool { + !self.is_null() && unsafe { mlirAttributeIsAType(self.raw) } + } + + /// Returns `true` if an attribute is a unit. + pub fn is_unit(&self) -> bool { + !self.is_null() && unsafe { mlirAttributeIsAUnit(self.raw) } + } + /// Gets a context. pub fn context(&self) -> ContextRef<'c> { unsafe { ContextRef::from_raw(mlirAttributeGetContext(self.raw)) } } + /// Dumps an attribute. + pub fn dump(&self) { + unsafe { mlirAttributeDump(self.raw) } + } + unsafe fn from_raw(raw: MlirAttribute) -> Self { Self { raw, @@ -49,6 +186,35 @@ impl<'c> Attribute<'c> { } } +impl<'c> PartialEq for Attribute<'c> { + fn eq(&self, other: &Self) -> bool { + unsafe { mlirAttributeEqual(self.raw, other.raw) } + } +} + +impl<'c> Eq for Attribute<'c> {} + +impl<'c> Display for Attribute<'c> { + fn fmt(&self, formatter: &mut Formatter) -> fmt::Result { + let mut data = (formatter, Ok(())); + + unsafe extern "C" fn callback(string: MlirStringRef, data: *mut c_void) { + let data = &mut *(data as *mut (&mut Formatter, fmt::Result)); + let result = write!(data.0, "{}", StringRef::from_raw(string).as_str()); + + if data.1.is_ok() { + data.1 = result; + } + } + + unsafe { + mlirAttributePrint(self.raw, Some(callback), &mut data as *mut _ as *mut c_void); + } + + data.1 + } +} + #[cfg(test)] mod tests { use super::*; @@ -65,8 +231,188 @@ mod tests { assert!(Attribute::parse(&Context::new(), "z").is_none()); } + #[test] + fn null() { + assert_eq!(Attribute::null().to_string(), "<>"); + } + #[test] fn context() { Attribute::parse(&Context::new(), "unit").unwrap().context(); } + + #[test] + fn r#type() { + let context = Context::new(); + + assert_eq!( + Attribute::parse(&context, "unit").unwrap().r#type(), + Some(Type::none(&context)) + ); + } + + #[test] + fn type_none() { + assert_eq!(Attribute::null().r#type(), None); + } + + // TODO Fix this. + #[ignore] + #[test] + fn type_id() { + let context = Context::new(); + + assert_eq!( + Attribute::parse(&context, "42 : index").unwrap().type_id(), + Some(Type::index(&context).id()) + ); + } + + #[test] + fn is_null() { + assert!(Attribute::null().is_null()); + } + + #[test] + fn is_array() { + assert!(Attribute::parse(&Context::new(), "[]").unwrap().is_array()); + } + + #[test] + fn is_bool() { + assert!(Attribute::parse(&Context::new(), "false") + .unwrap() + .is_bool()); + } + + #[test] + fn is_dense_elements() { + assert!( + Attribute::parse(&Context::new(), "dense<10> : tensor<2xi8>") + .unwrap() + .is_dense_elements() + ); + } + + #[test] + fn is_dense_integer_elements() { + assert!( + Attribute::parse(&Context::new(), "dense<42> : tensor<42xi8>") + .unwrap() + .is_dense_integer_elements() + ); + } + + #[test] + fn is_dense_float_elements() { + assert!( + Attribute::parse(&Context::new(), "dense<42.0> : tensor<42xf32>") + .unwrap() + .is_dense_float_elements() + ); + } + + #[test] + fn is_elements() { + assert!(Attribute::parse( + &Context::new(), + "sparse<[[0, 0], [1, 2]], [1, 5]> : tensor<3x4xi32>" + ) + .unwrap() + .is_elements()); + } + + #[test] + fn is_integer() { + assert!(Attribute::parse(&Context::new(), "42") + .unwrap() + .is_integer()); + } + + #[test] + fn is_integer_set() { + assert!( + Attribute::parse(&Context::new(), "affine_set<(d0) : (d0 - 2 >= 0)>") + .unwrap() + .is_integer_set() + ); + } + + // TODO Fix this. + #[ignore] + #[test] + fn is_opaque() { + assert!(Attribute::parse(&Context::new(), "#foo<\"bar\">") + .unwrap() + .is_opaque()); + } + + #[test] + fn is_sparse_elements() { + assert!(Attribute::parse( + &Context::new(), + "sparse<[[0, 0], [1, 2]], [1, 5]> : tensor<3x4xi32>" + ) + .unwrap() + .is_sparse_elements()); + } + + #[test] + fn is_string() { + assert!(Attribute::parse(&Context::new(), "\"foo\"") + .unwrap() + .is_string()); + } + + #[test] + fn is_type() { + assert!(Attribute::parse(&Context::new(), "index") + .unwrap() + .is_type()); + } + + #[test] + fn is_unit() { + assert!(Attribute::parse(&Context::new(), "unit").unwrap().is_unit()); + } + + #[test] + fn is_not_unit() { + assert!(!Attribute::null().is_unit()); + } + + #[test] + fn is_symbol() { + assert!(Attribute::parse(&Context::new(), "@foo") + .unwrap() + .is_symbol()); + } + + #[test] + fn equal() { + let context = Context::new(); + let attribute = Attribute::parse(&context, "unit").unwrap(); + + assert_eq!(attribute, attribute); + } + + #[test] + fn not_equal() { + let context = Context::new(); + + assert_ne!( + Attribute::parse(&context, "unit").unwrap(), + Attribute::parse(&context, "42").unwrap() + ); + } + + #[test] + fn display() { + assert_eq!( + Attribute::parse(&Context::new(), "unit") + .unwrap() + .to_string(), + "unit" + ); + } } diff --git a/src/type.rs b/src/type.rs index f2e7bcfdcd..7f93a8f3b7 100644 --- a/src/type.rs +++ b/src/type.rs @@ -1,3 +1,6 @@ +pub mod id; + +pub use self::id::Id; use crate::{ context::{Context, ContextRef}, error::Error, @@ -10,9 +13,9 @@ use mlir_sys::{ mlirFunctionTypeGetNumResults, mlirFunctionTypeGetResult, mlirIndexTypeGet, mlirIntegerTypeGet, mlirIntegerTypeSignedGet, mlirIntegerTypeUnsignedGet, mlirLLVMArrayTypeGet, mlirLLVMFunctionTypeGet, mlirLLVMPointerTypeGet, mlirLLVMStructTypeLiteralGet, - mlirLLVMVoidTypeGet, mlirTypeDump, mlirTypeEqual, mlirTypeGetContext, mlirTypeIsAFunction, - mlirTypeParseGet, mlirTypePrint, mlirVectorTypeGet, mlirVectorTypeGetChecked, MlirStringRef, - MlirType, + mlirLLVMVoidTypeGet, mlirNoneTypeGet, mlirTypeDump, mlirTypeEqual, mlirTypeGetContext, + mlirTypeGetTypeID, mlirTypeIsAFunction, mlirTypeParseGet, mlirTypePrint, mlirVectorTypeGet, + mlirVectorTypeGetChecked, MlirStringRef, MlirType, }; use std::{ ffi::c_void, @@ -72,6 +75,11 @@ impl<'c> Type<'c> { unsafe { Self::from_raw(mlirIntegerTypeUnsignedGet(context.to_raw(), bits)) } } + /// Creates a none type. + pub fn none(context: &'c Context) -> Self { + unsafe { Self::from_raw(mlirNoneTypeGet(context.to_raw())) } + } + /// Creates a vector type. pub fn vector(dimensions: &[u64], r#type: Self) -> Self { unsafe { @@ -148,6 +156,11 @@ impl<'c> Type<'c> { unsafe { ContextRef::from_raw(mlirTypeGetContext(self.raw)) } } + /// Gets an ID. + pub fn id(&self) -> Id { + unsafe { Id::from_raw(mlirTypeGetTypeID(self.raw)) } + } + /// Gets an input of a function type. pub fn input(&self, position: usize) -> Result, Error> { unsafe { @@ -262,11 +275,6 @@ mod tests { Type::parse(&Context::new(), "f32"); } - #[test] - fn context() { - Type::parse(&Context::new(), "i8").unwrap().context(); - } - #[test] fn integer() { let context = Context::new(); @@ -351,6 +359,32 @@ mod tests { ); } + #[test] + fn context() { + Type::parse(&Context::new(), "i8").unwrap().context(); + } + + #[test] + fn id() { + let context = Context::new(); + + assert_eq!(Type::index(&context).id(), Type::index(&context).id()); + } + + #[test] + fn equal() { + let context = Context::new(); + + assert_eq!(Type::index(&context), Type::index(&context)); + } + + #[test] + fn not_equal() { + let context = Context::new(); + + assert_ne!(Type::index(&context), Type::integer(&context, 1)); + } + #[test] fn display() { let context = Context::new(); diff --git a/src/type/id.rs b/src/type/id.rs new file mode 100644 index 0000000000..12eadae7b5 --- /dev/null +++ b/src/type/id.rs @@ -0,0 +1,33 @@ +mod allocator; + +pub use allocator::Allocator; +use mlir_sys::{mlirTypeIDEqual, mlirTypeIDHashValue, MlirTypeID}; +use std::hash::{Hash, Hasher}; + +/// A type ID. +#[derive(Clone, Copy, Debug)] +pub struct Id { + raw: MlirTypeID, +} + +impl Id { + pub(crate) unsafe fn from_raw(raw: MlirTypeID) -> Self { + Self { raw } + } +} + +impl PartialEq for Id { + fn eq(&self, other: &Self) -> bool { + unsafe { mlirTypeIDEqual(self.raw, other.raw) } + } +} + +impl Eq for Id {} + +impl Hash for Id { + fn hash(&self, hasher: &mut H) { + unsafe { + mlirTypeIDHashValue(self.raw).hash(hasher); + } + } +} diff --git a/src/type/id/allocator.rs b/src/type/id/allocator.rs new file mode 100644 index 0000000000..b7f58a2c16 --- /dev/null +++ b/src/type/id/allocator.rs @@ -0,0 +1,50 @@ +use super::Id; +use mlir_sys::{ + mlirTypeIDAllocatorAllocateTypeID, mlirTypeIDAllocatorCreate, mlirTypeIDAllocatorDestroy, + MlirTypeIDAllocator, +}; + +/// A type ID allocator. +#[derive(Debug)] +pub struct Allocator { + raw: MlirTypeIDAllocator, +} + +impl Allocator { + pub fn new() -> Self { + Self { + raw: unsafe { mlirTypeIDAllocatorCreate() }, + } + } + + pub fn allocate_type_id(&mut self) -> Id { + unsafe { Id::from_raw(mlirTypeIDAllocatorAllocateTypeID(self.raw)) } + } +} + +impl Drop for Allocator { + fn drop(&mut self) { + unsafe { mlirTypeIDAllocatorDestroy(self.raw) } + } +} + +impl Default for Allocator { + fn default() -> Self { + Self::new() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn new() { + Allocator::new(); + } + + #[test] + fn allocate_type_id() { + Allocator::new().allocate_type_id(); + } +}