From 72447ba0524f5acdd53c5b3b15ff3c4589ecc831 Mon Sep 17 00:00:00 2001 From: Yota Toyama Date: Thu, 15 Sep 2022 19:18:11 -0700 Subject: [PATCH] Block API (#34) --- Cargo.toml | 2 +- src/block.rs | 268 +++++++++++++++++++++++++++++++++++++++++------ src/lib.rs | 84 +++++++++++++++ src/operation.rs | 12 +++ 4 files changed, 331 insertions(+), 35 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 4e818312d1..93d559de2e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,7 +3,7 @@ name = "melior" description = "The rustic MLIR bindings in Rust" version = "0.1.0" edition = "2021" -license-file = "LICENSE" +license = "Apache-2.0" repository = "https://github.com/raviqqe/melior" [dependencies] diff --git a/src/block.rs b/src/block.rs index 14ae77f35f..36f0ccf95e 100644 --- a/src/block.rs +++ b/src/block.rs @@ -4,15 +4,25 @@ use crate::{ operation::{Operation, OperationRef}, r#type::Type, region::RegionRef, + string_ref::StringRef, utility::into_raw_array, value::Value, }; use mlir_sys::{ mlirBlockAddArgument, mlirBlockAppendOwnedOperation, mlirBlockCreate, mlirBlockDestroy, - mlirBlockEqual, mlirBlockGetArgument, mlirBlockGetFirstOperation, mlirBlockGetNumArguments, - mlirBlockGetParentRegion, mlirBlockInsertOwnedOperation, MlirBlock, + mlirBlockDetach, mlirBlockEqual, mlirBlockGetArgument, mlirBlockGetFirstOperation, + mlirBlockGetNextInRegion, mlirBlockGetNumArguments, mlirBlockGetParentOperation, + mlirBlockGetParentRegion, mlirBlockGetTerminator, mlirBlockInsertOwnedOperation, + mlirBlockInsertOwnedOperationAfter, mlirBlockInsertOwnedOperationBefore, mlirBlockPrint, + MlirBlock, MlirStringRef, +}; +use std::{ + ffi::c_void, + fmt::{self, Display, Formatter}, + marker::PhantomData, + mem::forget, + ops::Deref, }; -use std::{marker::PhantomData, mem::forget, ops::Deref}; /// A block #[derive(Debug)] @@ -25,24 +35,28 @@ impl<'c> Block<'c> { /// Creates a block. pub fn new(arguments: &[(Type<'c>, Location<'c>)]) -> Self { unsafe { - Self { - r#ref: BlockRef::from_raw(mlirBlockCreate( - arguments.len() as isize, - into_raw_array( - arguments - .iter() - .map(|(argument, _)| argument.to_raw()) - .collect(), - ), - into_raw_array( - arguments - .iter() - .map(|(_, location)| location.to_raw()) - .collect(), - ), - )), - _context: Default::default(), - } + Self::from_raw(mlirBlockCreate( + arguments.len() as isize, + into_raw_array( + arguments + .iter() + .map(|(argument, _)| argument.to_raw()) + .collect(), + ), + into_raw_array( + arguments + .iter() + .map(|(_, location)| location.to_raw()) + .collect(), + ), + )) + } + } + + pub(crate) unsafe fn from_raw(raw: MlirBlock) -> Self { + Self { + r#ref: BlockRef::from_raw(raw), + _context: Default::default(), } } @@ -78,8 +92,6 @@ impl<'c> Deref for Block<'c> { } /// A reference of a block. -// TODO Should we split context lifetimes? Or, is it transitively proven that 'c -// > 'a? #[derive(Clone, Copy, Debug)] pub struct BlockRef<'a> { raw: MlirBlock, @@ -106,11 +118,6 @@ impl<'c> BlockRef<'c> { unsafe { mlirBlockGetNumArguments(self.raw) as usize } } - /// Gets a parent region. - pub fn parent_region(&self) -> Option { - unsafe { RegionRef::from_option_raw(mlirBlockGetParentRegion(self.raw)) } - } - /// Gets the first operation. pub fn first_operation(&self) -> Option { unsafe { @@ -124,6 +131,21 @@ impl<'c> BlockRef<'c> { } } + /// Gets a terminator operation. + pub fn terminator(&self) -> Option { + unsafe { OperationRef::from_option_raw(mlirBlockGetTerminator(self.raw)) } + } + + /// Gets a parent region. + pub fn parent_region(&self) -> Option { + unsafe { RegionRef::from_option_raw(mlirBlockGetParentRegion(self.raw)) } + } + + /// Gets a parent operation. + pub fn parent_operation(&self) -> Option { + unsafe { OperationRef::from_option_raw(mlirBlockGetParentOperation(self.raw)) } + } + /// Adds an argument. pub fn add_argument(&self, r#type: Type<'c>, location: Location<'c>) -> Value { unsafe { @@ -135,6 +157,17 @@ impl<'c> BlockRef<'c> { } } + /// Appends an operation. + pub fn append_operation(&self, operation: Operation) -> OperationRef { + unsafe { + let operation = operation.into_raw(); + + mlirBlockAppendOwnedOperation(self.raw, operation); + + OperationRef::from_raw(operation) + } + } + /// Inserts an operation. // TODO How can we make those update functions take `&mut self`? // TODO Use cells? @@ -148,17 +181,46 @@ impl<'c> BlockRef<'c> { } } - /// Appends an operation. - pub fn append_operation(&self, operation: Operation) -> OperationRef { + /// Inserts an operation after another. + pub fn insert_operation_after(&self, one: OperationRef, other: Operation) -> OperationRef { unsafe { - let operation = operation.into_raw(); + let other = other.into_raw(); - mlirBlockAppendOwnedOperation(self.raw, operation); + mlirBlockInsertOwnedOperationAfter(self.raw, one.to_raw(), other); - OperationRef::from_raw(operation) + OperationRef::from_raw(other) + } + } + + /// Inserts an operation before another. + pub fn insert_operation_before(&self, one: OperationRef, other: Operation) -> OperationRef { + unsafe { + let other = other.into_raw(); + + mlirBlockInsertOwnedOperationBefore(self.raw, one.to_raw(), other); + + OperationRef::from_raw(other) } } + /// Detaches a block from a region and assumes its ownership. + pub fn detach(&self) -> Option { + if self.parent_region().is_some() { + unsafe { + mlirBlockDetach(self.raw); + + Some(Block::from_raw(self.raw)) + } + } else { + None + } + } + + /// Gets a next block in a region. + pub fn next_in_region(&self) -> Option { + unsafe { BlockRef::from_option_raw(mlirBlockGetNextInRegion(self.raw)) } + } + pub(crate) unsafe fn from_raw(raw: MlirBlock) -> Self { Self { raw, @@ -187,10 +249,34 @@ impl<'a> PartialEq for BlockRef<'a> { impl<'a> Eq for BlockRef<'a> {} +impl<'a> Display for BlockRef<'a> { + fn fmt(&self, formatter: &mut Formatter) -> fmt::Result { + let mut data = (formatter, Ok(())); + + unsafe extern "C" fn callback(string: MlirStringRef, data: *mut c_void) { + let data = &mut *(data as *mut (&mut Formatter, fmt::Result)); + let result = write!(data.0, "{}", StringRef::from_raw(string).as_str()); + + if data.1.is_ok() { + data.1 = result; + } + } + + unsafe { + mlirBlockPrint(self.raw, Some(callback), &mut data as *mut _ as *mut c_void); + } + + data.1 + } +} + #[cfg(test)] mod tests { use super::*; - use crate::{operation_state::OperationState, region::Region}; + use crate::{ + dialect_registry::DialectRegistry, module::Module, operation_state::OperationState, + region::Region, utility::register_all_dialects, + }; #[test] fn new() { @@ -236,6 +322,48 @@ mod tests { assert_eq!(block.parent_region(), None); } + #[test] + fn parent_operation() { + let context = Context::new(); + let module = Module::new(Location::unknown(&context)); + + assert_eq!( + module.body().parent_operation(), + Some(module.as_operation()) + ); + } + + #[test] + fn parent_operation_none() { + let block = Block::new(&[]); + + assert_eq!(block.parent_operation(), None); + } + + #[test] + fn terminator() { + let registry = DialectRegistry::new(); + register_all_dialects(®istry); + + let context = Context::new(); + context.append_dialect_registry(®istry); + context.load_all_available_dialects(); + + let block = Block::new(&[]); + + let operation = block.append_operation(Operation::new(OperationState::new( + "func.return", + Location::unknown(&context), + ))); + + assert_eq!(block.terminator(), Some(operation)); + } + + #[test] + fn terminator_none() { + assert_eq!(Block::new(&[]).terminator(), None); + } + #[test] fn first_operation() { let context = Context::new(); @@ -277,4 +405,76 @@ mod tests { Operation::new(OperationState::new("foo", Location::unknown(&context))), ); } + + #[test] + fn insert_operation_after() { + let context = Context::new(); + let block = Block::new(&[]); + + let first_operation = block.append_operation(Operation::new(OperationState::new( + "foo", + Location::unknown(&context), + ))); + let second_operation = block.insert_operation_after( + first_operation, + Operation::new(OperationState::new("foo", Location::unknown(&context))), + ); + + assert_eq!(block.first_operation(), Some(first_operation)); + assert_eq!( + block.first_operation().unwrap().next_in_block(), + Some(second_operation) + ); + } + + #[test] + fn insert_operation_before() { + let context = Context::new(); + let block = Block::new(&[]); + + let second_operation = block.append_operation(Operation::new(OperationState::new( + "foo", + Location::unknown(&context), + ))); + let first_operation = block.insert_operation_before( + second_operation, + Operation::new(OperationState::new("foo", Location::unknown(&context))), + ); + + assert_eq!(block.first_operation(), Some(first_operation)); + assert_eq!( + block.first_operation().unwrap().next_in_block(), + Some(second_operation) + ); + } + + #[test] + fn next_in_region() { + let region = Region::new(); + + let first_block = region.append_block(Block::new(&[])); + let second_block = region.append_block(Block::new(&[])); + + assert_eq!(first_block.next_in_region(), Some(second_block)); + } + + #[test] + fn detach() { + let region = Region::new(); + let block = region.append_block(Block::new(&[])); + + assert_eq!(block.detach().unwrap().to_string(), "<>\n"); + } + + #[test] + fn detach_detached() { + let block = Block::new(&[]); + + assert!(block.detach().is_none()); + } + + #[test] + fn display() { + assert_eq!(Block::new(&[]).to_string(), "<>\n"); + } } diff --git a/src/lib.rs b/src/lib.rs index 86abc1a8e2..531fabc62e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,3 +1,87 @@ +//! Melior is the rustic MLIR bindings for Rust. It aims to provide a simple, +//! safe, and complete API for MLIR with a reasonably sane ownership model +//! represented by the type system in Rust. +//! +//! This crate is a wrapper of [the MLIR C API](https://mlir.llvm.org/docs/CAPI/). +//! +//! # Dependencies +//! +//! [LLVM/MLIR 15](https://llvm.org/) needs to be installed on your system. On Linux and macOS, you can install it via [Homebrew](https://brew.sh). +//! +//! ```sh +//! brew install llvm@15 +//! ``` +//! +//! # Examples +//! +//! ## Building a function to add integers +//! +//! ```rust +//! use melior::{ +//! attribute::Attribute, +//! block::Block, +//! context::Context, +//! dialect_registry::DialectRegistry, +//! identifier::Identifier, +//! location::Location, +//! module::Module, +//! operation::Operation, +//! operation_state::OperationState, +//! region::Region, +//! r#type::Type, +//! utility::register_all_dialects, +//! }; +//! +//! let registry = DialectRegistry::new(); +//! register_all_dialects(®istry); +//! +//! let context = Context::new(); +//! context.append_dialect_registry(®istry); +//! context.get_or_load_dialect("func"); +//! +//! let location = Location::unknown(&context); +//! let module = Module::new(location); +//! +//! let integer_type = Type::integer(&context, 64); +//! +//! let function = { +//! let region = Region::new(); +//! let block = Block::new(&[(integer_type, location), (integer_type, location)]); +//! +//! let sum = block.append_operation(Operation::new( +//! OperationState::new("arith.addi", location) +//! .add_operands(&[block.argument(0).unwrap(), block.argument(1).unwrap()]) +//! .add_results(&[integer_type]), +//! )); +//! +//! block.append_operation(Operation::new( +//! OperationState::new("func.return", Location::unknown(&context)) +//! .add_operands(&[sum.result(0).unwrap()]), +//! )); +//! +//! region.append_block(block); +//! +//! Operation::new( +//! OperationState::new("func.func", Location::unknown(&context)) +//! .add_attributes(&[ +//! ( +//! Identifier::new(&context, "function_type"), +//! Attribute::parse(&context, "(i64, i64) -> i64").unwrap(), +//! ), +//! ( +//! Identifier::new(&context, "sym_name"), +//! Attribute::parse(&context, "\"add\"").unwrap(), +//! ), +//! ]) +//! .add_regions(vec![region]), +//! ) +//! }; +//! +//! module.body().append_operation(function); +//! +//! assert!(module.as_operation().verify()); +//! ``` + pub mod attribute; pub mod block; pub mod context; diff --git a/src/operation.rs b/src/operation.rs index a7f8e61a94..961ad12148 100644 --- a/src/operation.rs +++ b/src/operation.rs @@ -156,12 +156,24 @@ impl<'a> OperationRef<'a> { unsafe { mlirOperationDump(self.raw) } } + pub(crate) unsafe fn to_raw(self) -> MlirOperation { + self.raw + } + pub(crate) unsafe fn from_raw(raw: MlirOperation) -> Self { Self { raw, _reference: Default::default(), } } + + pub(crate) unsafe fn from_option_raw(raw: MlirOperation) -> Option { + if raw.ptr.is_null() { + None + } else { + Some(Self::from_raw(raw)) + } + } } impl<'a> PartialEq for OperationRef<'a> {