Skip to content

Commit

Permalink
Block API (#34)
Browse files Browse the repository at this point in the history
  • Loading branch information
raviqqe authored Sep 16, 2022
1 parent 4d7db1d commit 72447ba
Show file tree
Hide file tree
Showing 4 changed files with 331 additions and 35 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ name = "melior"
description = "The rustic MLIR bindings in Rust"
version = "0.1.0"
edition = "2021"
license-file = "LICENSE"
license = "Apache-2.0"
repository = "https://github.com/raviqqe/melior"

[dependencies]
Expand Down
268 changes: 234 additions & 34 deletions src/block.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,25 @@ use crate::{
operation::{Operation, OperationRef},
r#type::Type,
region::RegionRef,
string_ref::StringRef,
utility::into_raw_array,
value::Value,
};
use mlir_sys::{
mlirBlockAddArgument, mlirBlockAppendOwnedOperation, mlirBlockCreate, mlirBlockDestroy,
mlirBlockEqual, mlirBlockGetArgument, mlirBlockGetFirstOperation, mlirBlockGetNumArguments,
mlirBlockGetParentRegion, mlirBlockInsertOwnedOperation, MlirBlock,
mlirBlockDetach, mlirBlockEqual, mlirBlockGetArgument, mlirBlockGetFirstOperation,
mlirBlockGetNextInRegion, mlirBlockGetNumArguments, mlirBlockGetParentOperation,
mlirBlockGetParentRegion, mlirBlockGetTerminator, mlirBlockInsertOwnedOperation,
mlirBlockInsertOwnedOperationAfter, mlirBlockInsertOwnedOperationBefore, mlirBlockPrint,
MlirBlock, MlirStringRef,
};
use std::{
ffi::c_void,
fmt::{self, Display, Formatter},
marker::PhantomData,
mem::forget,
ops::Deref,
};
use std::{marker::PhantomData, mem::forget, ops::Deref};

/// A block
#[derive(Debug)]
Expand All @@ -25,24 +35,28 @@ impl<'c> Block<'c> {
/// Creates a block.
pub fn new(arguments: &[(Type<'c>, Location<'c>)]) -> Self {
unsafe {
Self {
r#ref: BlockRef::from_raw(mlirBlockCreate(
arguments.len() as isize,
into_raw_array(
arguments
.iter()
.map(|(argument, _)| argument.to_raw())
.collect(),
),
into_raw_array(
arguments
.iter()
.map(|(_, location)| location.to_raw())
.collect(),
),
)),
_context: Default::default(),
}
Self::from_raw(mlirBlockCreate(
arguments.len() as isize,
into_raw_array(
arguments
.iter()
.map(|(argument, _)| argument.to_raw())
.collect(),
),
into_raw_array(
arguments
.iter()
.map(|(_, location)| location.to_raw())
.collect(),
),
))
}
}

pub(crate) unsafe fn from_raw(raw: MlirBlock) -> Self {
Self {
r#ref: BlockRef::from_raw(raw),
_context: Default::default(),
}
}

Expand Down Expand Up @@ -78,8 +92,6 @@ impl<'c> Deref for Block<'c> {
}

/// A reference of a block.
// TODO Should we split context lifetimes? Or, is it transitively proven that 'c
// > 'a?
#[derive(Clone, Copy, Debug)]
pub struct BlockRef<'a> {
raw: MlirBlock,
Expand All @@ -106,11 +118,6 @@ impl<'c> BlockRef<'c> {
unsafe { mlirBlockGetNumArguments(self.raw) as usize }
}

/// Gets a parent region.
pub fn parent_region(&self) -> Option<RegionRef> {
unsafe { RegionRef::from_option_raw(mlirBlockGetParentRegion(self.raw)) }
}

/// Gets the first operation.
pub fn first_operation(&self) -> Option<OperationRef> {
unsafe {
Expand All @@ -124,6 +131,21 @@ impl<'c> BlockRef<'c> {
}
}

/// Gets a terminator operation.
pub fn terminator(&self) -> Option<OperationRef> {
unsafe { OperationRef::from_option_raw(mlirBlockGetTerminator(self.raw)) }
}

/// Gets a parent region.
pub fn parent_region(&self) -> Option<RegionRef> {
unsafe { RegionRef::from_option_raw(mlirBlockGetParentRegion(self.raw)) }
}

/// Gets a parent operation.
pub fn parent_operation(&self) -> Option<OperationRef> {
unsafe { OperationRef::from_option_raw(mlirBlockGetParentOperation(self.raw)) }
}

/// Adds an argument.
pub fn add_argument(&self, r#type: Type<'c>, location: Location<'c>) -> Value {
unsafe {
Expand All @@ -135,6 +157,17 @@ impl<'c> BlockRef<'c> {
}
}

/// Appends an operation.
pub fn append_operation(&self, operation: Operation) -> OperationRef {
unsafe {
let operation = operation.into_raw();

mlirBlockAppendOwnedOperation(self.raw, operation);

OperationRef::from_raw(operation)
}
}

/// Inserts an operation.
// TODO How can we make those update functions take `&mut self`?
// TODO Use cells?
Expand All @@ -148,17 +181,46 @@ impl<'c> BlockRef<'c> {
}
}

/// Appends an operation.
pub fn append_operation(&self, operation: Operation) -> OperationRef {
/// Inserts an operation after another.
pub fn insert_operation_after(&self, one: OperationRef, other: Operation) -> OperationRef {
unsafe {
let operation = operation.into_raw();
let other = other.into_raw();

mlirBlockAppendOwnedOperation(self.raw, operation);
mlirBlockInsertOwnedOperationAfter(self.raw, one.to_raw(), other);

OperationRef::from_raw(operation)
OperationRef::from_raw(other)
}
}

/// Inserts an operation before another.
pub fn insert_operation_before(&self, one: OperationRef, other: Operation) -> OperationRef {
unsafe {
let other = other.into_raw();

mlirBlockInsertOwnedOperationBefore(self.raw, one.to_raw(), other);

OperationRef::from_raw(other)
}
}

/// Detaches a block from a region and assumes its ownership.
pub fn detach(&self) -> Option<Block> {
if self.parent_region().is_some() {
unsafe {
mlirBlockDetach(self.raw);

Some(Block::from_raw(self.raw))
}
} else {
None
}
}

/// Gets a next block in a region.
pub fn next_in_region(&self) -> Option<BlockRef> {
unsafe { BlockRef::from_option_raw(mlirBlockGetNextInRegion(self.raw)) }
}

pub(crate) unsafe fn from_raw(raw: MlirBlock) -> Self {
Self {
raw,
Expand Down Expand Up @@ -187,10 +249,34 @@ impl<'a> PartialEq for BlockRef<'a> {

impl<'a> Eq for BlockRef<'a> {}

impl<'a> Display for BlockRef<'a> {
fn fmt(&self, formatter: &mut Formatter) -> fmt::Result {
let mut data = (formatter, Ok(()));

unsafe extern "C" fn callback(string: MlirStringRef, data: *mut c_void) {
let data = &mut *(data as *mut (&mut Formatter, fmt::Result));
let result = write!(data.0, "{}", StringRef::from_raw(string).as_str());

if data.1.is_ok() {
data.1 = result;
}
}

unsafe {
mlirBlockPrint(self.raw, Some(callback), &mut data as *mut _ as *mut c_void);
}

data.1
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::{operation_state::OperationState, region::Region};
use crate::{
dialect_registry::DialectRegistry, module::Module, operation_state::OperationState,
region::Region, utility::register_all_dialects,
};

#[test]
fn new() {
Expand Down Expand Up @@ -236,6 +322,48 @@ mod tests {
assert_eq!(block.parent_region(), None);
}

#[test]
fn parent_operation() {
let context = Context::new();
let module = Module::new(Location::unknown(&context));

assert_eq!(
module.body().parent_operation(),
Some(module.as_operation())
);
}

#[test]
fn parent_operation_none() {
let block = Block::new(&[]);

assert_eq!(block.parent_operation(), None);
}

#[test]
fn terminator() {
let registry = DialectRegistry::new();
register_all_dialects(&registry);

let context = Context::new();
context.append_dialect_registry(&registry);
context.load_all_available_dialects();

let block = Block::new(&[]);

let operation = block.append_operation(Operation::new(OperationState::new(
"func.return",
Location::unknown(&context),
)));

assert_eq!(block.terminator(), Some(operation));
}

#[test]
fn terminator_none() {
assert_eq!(Block::new(&[]).terminator(), None);
}

#[test]
fn first_operation() {
let context = Context::new();
Expand Down Expand Up @@ -277,4 +405,76 @@ mod tests {
Operation::new(OperationState::new("foo", Location::unknown(&context))),
);
}

#[test]
fn insert_operation_after() {
let context = Context::new();
let block = Block::new(&[]);

let first_operation = block.append_operation(Operation::new(OperationState::new(
"foo",
Location::unknown(&context),
)));
let second_operation = block.insert_operation_after(
first_operation,
Operation::new(OperationState::new("foo", Location::unknown(&context))),
);

assert_eq!(block.first_operation(), Some(first_operation));
assert_eq!(
block.first_operation().unwrap().next_in_block(),
Some(second_operation)
);
}

#[test]
fn insert_operation_before() {
let context = Context::new();
let block = Block::new(&[]);

let second_operation = block.append_operation(Operation::new(OperationState::new(
"foo",
Location::unknown(&context),
)));
let first_operation = block.insert_operation_before(
second_operation,
Operation::new(OperationState::new("foo", Location::unknown(&context))),
);

assert_eq!(block.first_operation(), Some(first_operation));
assert_eq!(
block.first_operation().unwrap().next_in_block(),
Some(second_operation)
);
}

#[test]
fn next_in_region() {
let region = Region::new();

let first_block = region.append_block(Block::new(&[]));
let second_block = region.append_block(Block::new(&[]));

assert_eq!(first_block.next_in_region(), Some(second_block));
}

#[test]
fn detach() {
let region = Region::new();
let block = region.append_block(Block::new(&[]));

assert_eq!(block.detach().unwrap().to_string(), "<<UNLINKED BLOCK>>\n");
}

#[test]
fn detach_detached() {
let block = Block::new(&[]);

assert!(block.detach().is_none());
}

#[test]
fn display() {
assert_eq!(Block::new(&[]).to_string(), "<<UNLINKED BLOCK>>\n");
}
}
Loading

0 comments on commit 72447ba

Please sign in to comment.