From db29582a1523def67ae783ac39c5493ce1c1e7ff Mon Sep 17 00:00:00 2001 From: Yota Toyama Date: Tue, 13 Sep 2022 23:40:59 -0700 Subject: [PATCH] String ref and value APIs (#27) Part of #24. --- src/string_ref.rs | 28 +++++++- src/type.rs | 166 +++++++++++++++++++++++++++++++--------------- src/value.rs | 140 +++++++++++++++++++++++++++++++++++--- 3 files changed, 269 insertions(+), 65 deletions(-) diff --git a/src/string_ref.rs b/src/string_ref.rs index b491bebb80..417af06329 100644 --- a/src/string_ref.rs +++ b/src/string_ref.rs @@ -1,4 +1,4 @@ -use mlir_sys::{mlirStringRefCreateFromCString, MlirStringRef}; +use mlir_sys::{mlirStringRefCreateFromCString, mlirStringRefEqual, MlirStringRef}; use once_cell::sync::Lazy; use std::{collections::HashMap, ffi::CString, marker::PhantomData, slice, str, sync::RwLock}; @@ -10,6 +10,7 @@ static STRING_CACHE: Lazy>> = Lazy::new(Default: // // TODO The documentation says string refs do not have to be null-terminated. // But it looks like some functions do not handle strings not null-terminated? +#[derive(Clone, Copy, Debug)] pub struct StringRef<'a> { raw: MlirStringRef, _parent: PhantomData<&'a ()>, @@ -29,7 +30,7 @@ impl<'a> StringRef<'a> { } } - pub(crate) unsafe fn to_raw(&self) -> MlirStringRef { + pub(crate) unsafe fn to_raw(self) -> MlirStringRef { self.raw } @@ -41,6 +42,14 @@ impl<'a> StringRef<'a> { } } +impl<'a> PartialEq for StringRef<'a> { + fn eq(&self, other: &Self) -> bool { + unsafe { mlirStringRefEqual(self.raw, other.raw) } + } +} + +impl<'a> Eq for StringRef<'a> {} + impl From<&str> for StringRef<'static> { fn from(string: &str) -> Self { if !STRING_CACHE.read().unwrap().contains_key(string) { @@ -56,3 +65,18 @@ impl From<&str> for StringRef<'static> { unsafe { Self::from_raw(mlirStringRefCreateFromCString(string.as_ptr())) } } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn equal() { + assert_eq!(StringRef::from("foo"), StringRef::from("foo")); + } + + #[test] + fn not_equal() { + assert_ne!(StringRef::from("foo"), StringRef::from("bar")); + } +} diff --git a/src/type.rs b/src/type.rs index 41a5715320..a047308282 100644 --- a/src/type.rs +++ b/src/type.rs @@ -192,7 +192,7 @@ impl<'c> PartialEq for Type<'c> { impl<'c> Eq for Type<'c> {} -impl<'c> Display for &Type<'c> { +impl<'c> Display for Type<'c> { fn fmt(&self, formatter: &mut Formatter) -> fmt::Result { let mut data = (formatter, Ok(())); @@ -228,6 +228,40 @@ mod tests { Type::parse(&Context::new(), "i8").context(); } + #[test] + fn integer() { + let context = Context::new(); + + assert_eq!(Type::integer(&context, 42), Type::parse(&context, "i42")); + } + + #[test] + fn signed_integer() { + let context = Context::new(); + + assert_eq!( + Type::signed_integer(&context, 42), + Type::parse(&context, "si42") + ); + } + + #[test] + fn unsigned_integer() { + let context = Context::new(); + + assert_eq!( + Type::unsigned_integer(&context, 42), + Type::parse(&context, "ui42") + ); + } + + #[test] + fn display() { + let context = Context::new(); + + assert_eq!(Type::integer(&context, 42).to_string(), "i42"); + } + mod function { use super::*; @@ -330,72 +364,96 @@ mod tests { } } - #[test] - fn integer() { - let context = Context::new(); + mod llvm { + use super::*; - assert_eq!(Type::integer(&context, 42), Type::parse(&context, "i42")); - } + fn create_context() -> Context { + let context = Context::new(); - #[test] - fn signed_integer() { - let context = Context::new(); + DialectHandle::llvm().register_dialect(&context); + context.get_or_load_dialect("llvm"); - assert_eq!( - Type::signed_integer(&context, 42), - Type::parse(&context, "si42") - ); - } + context + } - #[test] - fn unsigned_integer() { - let context = Context::new(); + #[test] + fn pointer() { + let context = create_context(); + let i32 = Type::integer(&context, 32); - assert_eq!( - Type::unsigned_integer(&context, 42), - Type::parse(&context, "ui42") - ); - } + assert_eq!( + Type::llvm_pointer(i32, 0), + Type::parse(&context, "!llvm.ptr") + ); + } - #[test] - fn create_llvm_types() { - let context = Context::new(); + #[test] + fn pointer_with_address_space() { + let context = create_context(); + let i32 = Type::integer(&context, 32); - DialectHandle::llvm().register_dialect(&context); - context.get_or_load_dialect("llvm"); + assert_eq!( + Type::llvm_pointer(i32, 4), + Type::parse(&context, "!llvm.ptr") + ); + } - let i8 = Type::integer(&context, 8); - let i32 = Type::integer(&context, 32); - let i64 = Type::integer(&context, 64); + #[test] + fn void() { + let context = create_context(); - assert_eq!( - Type::llvm_pointer(i32, 0), - Type::parse(&context, "!llvm.ptr") - ); + assert_eq!( + Type::llvm_void(&context), + Type::parse(&context, "!llvm.void") + ); + } - assert_eq!( - Type::llvm_pointer(i32, 4), - Type::parse(&context, "!llvm.ptr") - ); + #[test] + fn array() { + let context = create_context(); + let i32 = Type::integer(&context, 32); - assert_eq!( - Type::llvm_void(&context), - Type::parse(&context, "!llvm.void") - ); + assert_eq!( + Type::llvm_array(i32, 4), + Type::parse(&context, "!llvm.array<4xi32>") + ); + } - assert_eq!( - Type::llvm_array(i32, 4), - Type::parse(&context, "!llvm.array<4xi32>") - ); + #[test] + fn function() { + let context = create_context(); + let i8 = Type::integer(&context, 8); + let i32 = Type::integer(&context, 32); + let i64 = Type::integer(&context, 64); - assert_eq!( - Type::llvm_function(i8, &[i32, i64], false), - Type::parse(&context, "!llvm.func") - ); + assert_eq!( + Type::llvm_function(i8, &[i32, i64], false), + Type::parse(&context, "!llvm.func") + ); + } - assert_eq!( - Type::llvm_struct(&context, &[i32, i64], false), - Type::parse(&context, "!llvm.struct<(i32, i64)>") - ); + #[test] + fn r#struct() { + let context = create_context(); + let i32 = Type::integer(&context, 32); + let i64 = Type::integer(&context, 64); + + assert_eq!( + Type::llvm_struct(&context, &[i32, i64], false), + Type::parse(&context, "!llvm.struct<(i32, i64)>") + ); + } + + #[test] + fn packed_struct() { + let context = create_context(); + let i32 = Type::integer(&context, 32); + let i64 = Type::integer(&context, 64); + + assert_eq!( + Type::llvm_struct(&context, &[i32, i64], true), + Type::parse(&context, "!llvm.struct") + ); + } } } diff --git a/src/value.rs b/src/value.rs index 34200c3c77..51013a9b9f 100644 --- a/src/value.rs +++ b/src/value.rs @@ -1,8 +1,13 @@ -use crate::r#type::Type; +use crate::{r#type::Type, string_ref::StringRef}; use mlir_sys::{ - mlirValueDump, mlirValueGetType, mlirValueIsABlockArgument, mlirValueIsAOpResult, MlirValue, + mlirValueDump, mlirValueEqual, mlirValueGetType, mlirValueIsABlockArgument, + mlirValueIsAOpResult, mlirValuePrint, MlirStringRef, MlirValue, +}; +use std::{ + ffi::c_void, + fmt::{self, Display, Formatter}, + marker::PhantomData, }; -use std::marker::PhantomData; /// A value. // Values are always non-owning references to their parents, such as operations @@ -46,11 +51,41 @@ impl<'a> Value<'a> { } } +impl<'a> PartialEq for Value<'a> { + fn eq(&self, other: &Self) -> bool { + unsafe { mlirValueEqual(self.raw, other.raw) } + } +} + +impl<'a> Eq for Value<'a> {} + +impl<'c> Display for Value<'c> { + fn fmt(&self, formatter: &mut Formatter) -> fmt::Result { + let mut data = (formatter, Ok(())); + + unsafe extern "C" fn callback(string: MlirStringRef, data: *mut c_void) { + let data = &mut *(data as *mut (&mut Formatter, fmt::Result)); + let result = write!(data.0, "{}", StringRef::from_raw(string).as_str()); + + if data.1.is_ok() { + data.1 = result; + } + } + + unsafe { + mlirValuePrint(self.raw, Some(callback), &mut data as *mut _ as *mut c_void); + } + + data.1 + } +} + #[cfg(test)] mod tests { use crate::{ - attribute::Attribute, block::Block, context::Context, identifier::Identifier, - location::Location, operation::Operation, operation_state::OperationState, r#type::Type, + attribute::Attribute, block::Block, context::Context, dialect_registry::DialectRegistry, + identifier::Identifier, location::Location, operation::Operation, + operation_state::OperationState, r#type::Type, utility::register_all_dialects, }; #[test] @@ -59,7 +94,7 @@ mod tests { let location = Location::unknown(&context); let index_type = Type::parse(&context, "index"); - let value = Operation::new( + let operation = Operation::new( OperationState::new("arith.constant", location) .add_results(&[index_type]) .add_attributes(&[( @@ -68,7 +103,7 @@ mod tests { )]), ); - assert_eq!(value.result(0).unwrap().r#type(), index_type); + assert_eq!(operation.result(0).unwrap().r#type(), index_type); } #[test] @@ -77,7 +112,7 @@ mod tests { let location = Location::unknown(&context); let r#type = Type::parse(&context, "index"); - let value = Operation::new( + let operation = Operation::new( OperationState::new("arith.constant", location) .add_results(&[r#type]) .add_attributes(&[( @@ -86,7 +121,7 @@ mod tests { )]), ); - assert!(value.result(0).unwrap().is_operation_result()); + assert!(operation.result(0).unwrap().is_operation_result()); } #[test] @@ -115,4 +150,91 @@ mod tests { value.result(0).unwrap().dump(); } + + #[test] + fn equal() { + let context = Context::new(); + let location = Location::unknown(&context); + let index_type = Type::parse(&context, "index"); + + let operation = Operation::new( + OperationState::new("arith.constant", location) + .add_results(&[index_type]) + .add_attributes(&[( + Identifier::new(&context, "value"), + Attribute::parse(&context, "0 : index"), + )]), + ); + + assert_eq!(operation.result(0), operation.result(0)); + } + + #[test] + fn not_equal() { + let context = Context::new(); + let location = Location::unknown(&context); + let index_type = Type::parse(&context, "index"); + + let operation = || { + Operation::new( + OperationState::new("arith.constant", location) + .add_results(&[index_type]) + .add_attributes(&[( + Identifier::new(&context, "value"), + Attribute::parse(&context, "0 : index"), + )]), + ) + }; + + assert_ne!(operation().result(0), operation().result(0)); + } + + #[test] + fn display() { + let context = Context::new(); + context.load_all_available_dialects(); + let location = Location::unknown(&context); + let index_type = Type::parse(&context, "index"); + + let operation = Operation::new( + OperationState::new("arith.constant", location) + .add_results(&[index_type]) + .add_attributes(&[( + Identifier::new(&context, "value"), + Attribute::parse(&context, "0 : index"), + )]), + ); + + assert_eq!( + operation.result(0).unwrap().to_string(), + "%0 = \"arith.constant\"() {value = 0 : index} : () -> index\n" + ); + } + + #[test] + fn display_with_dialect_loaded() { + let registry = DialectRegistry::new(); + register_all_dialects(®istry); + + let context = Context::new(); + context.append_dialect_registry(®istry); + context.load_all_available_dialects(); + + let location = Location::unknown(&context); + let index_type = Type::parse(&context, "index"); + + let operation = Operation::new( + OperationState::new("arith.constant", location) + .add_results(&[index_type]) + .add_attributes(&[( + Identifier::new(&context, "value"), + Attribute::parse(&context, "0 : index"), + )]), + ); + + assert_eq!( + operation.result(0).unwrap().to_string(), + "%c0 = arith.constant 0 : index\n" + ); + } }