Skip to content

Commit

Permalink
Refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
raviqqe committed Oct 18, 2023
1 parent 32f26be commit 37e8db3
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 56 deletions.
22 changes: 10 additions & 12 deletions melior/src/dialect/memref.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ fn allocate<'c>(
alignment: Option<IntegerAttribute<'c>>,
location: Location<'c>,
) -> Operation<'c> {
let mut builder = OperationBuilder::new( name, location);
let mut builder = OperationBuilder::new(name, location);

builder = builder.add_attributes(&[(
Identifier::new(context, "operand_segment_sizes"),
Expand All @@ -84,12 +84,11 @@ fn allocate<'c>(

/// Create a `memref.cast` operation.
pub fn cast<'c>(
context: &'c Context,
value: Value<'c, '_>,
r#type: MemRefType<'c>,
location: Location<'c>,
) -> Operation<'c> {
OperationBuilder::new( "memref.cast", location)
OperationBuilder::new("memref.cast", location)
.add_operands(&[value])
.add_results(&[r#type.into()])
.build()
Expand All @@ -102,7 +101,7 @@ pub fn dealloc<'c>(
value: Value<'c, '_>,
location: Location<'c>,
) -> Operation<'c> {
OperationBuilder::new( "memref.dealloc", location)
OperationBuilder::new("memref.dealloc", location)
.add_operands(&[value])
.build()
.expect("valid operation")
Expand All @@ -115,7 +114,7 @@ pub fn dim<'c>(
index: Value<'c, '_>,
location: Location<'c>,
) -> Operation<'c> {
OperationBuilder::new( "memref.dim", location)
OperationBuilder::new("memref.dim", location)
.add_operands(&[value, index])
.enable_result_type_inference()
.build()
Expand All @@ -129,7 +128,7 @@ pub fn get_global<'c>(
r#type: MemRefType<'c>,
location: Location<'c>,
) -> Operation<'c> {
OperationBuilder::new( "memref.get_global", location)
OperationBuilder::new("memref.get_global", location)
.add_attributes(&[(
Identifier::new(context, "name"),
FlatSymbolRefAttribute::new(context, name).into(),
Expand All @@ -151,7 +150,7 @@ pub fn global<'c>(
alignment: Option<IntegerAttribute<'c>>,
location: Location<'c>,
) -> Operation<'c> {
let mut builder = OperationBuilder::new( "memref.global", location).add_attributes(&[
let mut builder = OperationBuilder::new("memref.global", location).add_attributes(&[
(
Identifier::new(context, "sym_name"),
StringAttribute::new(context, name).into(),
Expand Down Expand Up @@ -195,7 +194,7 @@ pub fn load<'c>(
indices: &[Value<'c, '_>],
location: Location<'c>,
) -> Operation<'c> {
OperationBuilder::new( "memref.load", location)
OperationBuilder::new("memref.load", location)
.add_operands(&[memref])
.add_operands(indices)
.enable_result_type_inference()
Expand All @@ -209,7 +208,7 @@ pub fn rank<'c>(
value: Value<'c, '_>,
location: Location<'c>,
) -> Operation<'c> {
OperationBuilder::new( "memref.rank", location)
OperationBuilder::new("memref.rank", location)
.add_operands(&[value])
.enable_result_type_inference()
.build()
Expand All @@ -224,7 +223,7 @@ pub fn store<'c>(
indices: &[Value<'c, '_>],
location: Location<'c>,
) -> Operation<'c> {
OperationBuilder::new( "memref.store", location)
OperationBuilder::new("memref.store", location)
.add_operands(&[value, memref])
.add_operands(indices)
.build()
Expand All @@ -240,7 +239,7 @@ pub fn realloc<'c>(
alignment: Option<IntegerAttribute<'c>>,
location: Location<'c>,
) -> Operation<'c> {
let mut builder = OperationBuilder::new( "memref.realloc", location)
let mut builder = OperationBuilder::new("memref.realloc", location)
.add_operands(&[value])
.add_results(&[r#type.into()]);

Expand Down Expand Up @@ -378,7 +377,6 @@ mod tests {
));

block.append_operation(cast(
&context,
memref.result(0).unwrap().into(),
Type::parse(&context, "memref<?xf64>")
.unwrap()
Expand Down
51 changes: 13 additions & 38 deletions melior/src/dialect/scf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,11 @@ use crate::{

/// Creates a `scf.condition` operation.
pub fn condition<'c>(
context: &'c Context,
condition: Value<'c, '_>,
values: &[Value<'c, '_>],
location: Location<'c>,
) -> Operation<'c> {
OperationBuilder::new( "scf.condition", location)
OperationBuilder::new("scf.condition", location)
.add_operands(&[condition])
.add_operands(values)
.build()
Expand All @@ -24,12 +23,11 @@ pub fn condition<'c>(

/// Creates a `scf.execute_region` operation.
pub fn execute_region<'c>(
context: &'c Context,
result_types: &[Type<'c>],
region: Region<'c>,
location: Location<'c>,
) -> Operation<'c> {
OperationBuilder::new( "scf.execute_region", location)
OperationBuilder::new("scf.execute_region", location)
.add_results(result_types)
.add_regions(vec![region])
.build()
Expand All @@ -38,14 +36,13 @@ pub fn execute_region<'c>(

/// Creates a `scf.for` operation.
pub fn r#for<'c>(
context: &'c Context,
start: Value<'c, '_>,
end: Value<'c, '_>,
step: Value<'c, '_>,
region: Region<'c>,
location: Location<'c>,
) -> Operation<'c> {
OperationBuilder::new( "scf.for", location)
OperationBuilder::new("scf.for", location)
.add_operands(&[start, end, step])
.add_regions(vec![region])
.build()
Expand All @@ -54,14 +51,13 @@ pub fn r#for<'c>(

/// Creates a `scf.if` operation.
pub fn r#if<'c>(
context: &'c Context,
condition: Value<'c, '_>,
result_types: &[Type<'c>],
then_region: Region<'c>,
else_region: Region<'c>,
location: Location<'c>,
) -> Operation<'c> {
OperationBuilder::new( "scf.if", location)
OperationBuilder::new("scf.if", location)
.add_operands(&[condition])
.add_results(result_types)
.add_regions(vec![then_region, else_region])
Expand All @@ -78,7 +74,7 @@ pub fn index_switch<'c>(
regions: Vec<Region<'c>>,
location: Location<'c>,
) -> Operation<'c> {
OperationBuilder::new( "scf.index_switch", location)
OperationBuilder::new("scf.index_switch", location)
.add_operands(&[condition])
.add_results(result_types)
.add_attributes(&[(Identifier::new(context, "cases"), cases.into())])
Expand All @@ -89,14 +85,13 @@ pub fn index_switch<'c>(

/// Creates a `scf.while` operation.
pub fn r#while<'c>(
context: &'c Context,
initial_values: &[Value<'c, '_>],
result_types: &[Type<'c>],
before_region: Region<'c>,
after_region: Region<'c>,
location: Location<'c>,
) -> Operation<'c> {
OperationBuilder::new( "scf.while", location)
OperationBuilder::new("scf.while", location)
.add_operands(initial_values)
.add_results(result_types)
.add_regions(vec![before_region, after_region])
Expand All @@ -105,12 +100,8 @@ pub fn r#while<'c>(
}

/// Creates a `scf.yield` operation.
pub fn r#yield<'c>(
context: &'c Context,
values: &[Value<'c, '_>],
location: Location<'c>,
) -> Operation<'c> {
OperationBuilder::new( "scf.yield", location)
pub fn r#yield<'c>(values: &[Value<'c, '_>], location: Location<'c>) -> Operation<'c> {
OperationBuilder::new("scf.yield", location)
.add_operands(values)
.build()
.expect("valid operation")
Expand Down Expand Up @@ -147,7 +138,6 @@ mod tests {
let block = Block::new(&[]);

block.append_operation(execute_region(
&context,
&[index_type],
{
let block = Block::new(&[]);
Expand All @@ -159,7 +149,6 @@ mod tests {
));

block.append_operation(r#yield(
&context,
&[value.result(0).unwrap().into()],
location,
));
Expand Down Expand Up @@ -219,13 +208,12 @@ mod tests {
));

block.append_operation(r#for(
&context,
start.result(0).unwrap().into(),
end.result(0).unwrap().into(),
step.result(0).unwrap().into(),
{
let block = Block::new(&[(Type::index(&context), location)]);
block.append_operation(r#yield(&context, &[], location));
block.append_operation(r#yield(&[], location));

let region = Region::new();
region.append_block(block);
Expand Down Expand Up @@ -274,7 +262,6 @@ mod tests {
));

let result = block.append_operation(r#if(
&context,
condition.result(0).unwrap().into(),
&[index_type],
{
Expand All @@ -287,7 +274,6 @@ mod tests {
));

block.append_operation(r#yield(
&context,
&[result.result(0).unwrap().into()],
location,
));
Expand All @@ -306,7 +292,6 @@ mod tests {
));

block.append_operation(r#yield(
&context,
&[result.result(0).unwrap().into()],
location,
));
Expand Down Expand Up @@ -358,13 +343,12 @@ mod tests {
));

block.append_operation(r#if(
&context,
condition.result(0).unwrap().into(),
&[],
{
let block = Block::new(&[]);

block.append_operation(r#yield(&context, &[], location));
block.append_operation(r#yield(&[], location));

let region = Region::new();
region.append_block(block);
Expand Down Expand Up @@ -419,7 +403,7 @@ mod tests {
{
let block = Block::new(&[]);

block.append_operation(r#yield(&context, &[], location));
block.append_operation(r#yield(&[], location));

let region = Region::new();
region.append_block(block);
Expand All @@ -428,7 +412,7 @@ mod tests {
{
let block = Block::new(&[]);

block.append_operation(r#yield(&context, &[], location));
block.append_operation(r#yield(&[], location));

let region = Region::new();
region.append_block(block);
Expand All @@ -437,7 +421,7 @@ mod tests {
{
let block = Block::new(&[]);

block.append_operation(r#yield(&context, &[], location));
block.append_operation(r#yield(&[], location));

let region = Region::new();
region.append_block(block);
Expand Down Expand Up @@ -487,7 +471,6 @@ mod tests {
));

block.append_operation(r#while(
&context,
&[initial.result(0).unwrap().into()],
&[index_type],
{
Expand All @@ -507,7 +490,6 @@ mod tests {
));

block.append_operation(super::condition(
&context,
condition.result(0).unwrap().into(),
&[result.result(0).unwrap().into()],
location,
Expand All @@ -527,7 +509,6 @@ mod tests {
));

block.append_operation(r#yield(
&context,
&[result.result(0).unwrap().into()],
location,
));
Expand Down Expand Up @@ -577,7 +558,6 @@ mod tests {
));

block.append_operation(r#while(
&context,
&[initial.result(0).unwrap().into()],
&[float_type],
{
Expand All @@ -597,7 +577,6 @@ mod tests {
));

block.append_operation(super::condition(
&context,
condition.result(0).unwrap().into(),
&[result.result(0).unwrap().into()],
location,
Expand All @@ -617,7 +596,6 @@ mod tests {
));

block.append_operation(r#yield(
&context,
&[result.result(0).unwrap().into()],
location,
));
Expand Down Expand Up @@ -666,7 +644,6 @@ mod tests {
));

block.append_operation(r#while(
&context,
&[
initial.result(0).unwrap().into(),
initial.result(0).unwrap().into(),
Expand All @@ -690,7 +667,6 @@ mod tests {
));

block.append_operation(super::condition(
&context,
condition.result(0).unwrap().into(),
&[
result.result(0).unwrap().into(),
Expand All @@ -714,7 +690,6 @@ mod tests {
));

block.append_operation(r#yield(
&context,
&[
result.result(0).unwrap().into(),
result.result(0).unwrap().into(),
Expand Down
Loading

0 comments on commit 37e8db3

Please sign in to comment.