Skip to content

Commit

Permalink
Fix bug in replace_some_uses_with, add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
vaivaswatha committed Jul 3, 2024
1 parent 5340695 commit aea315a
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 12 deletions.
2 changes: 1 addition & 1 deletion src/basic_block.rs
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ impl BasicBlock {
pred(ctx, pred_block)
};

self.preds.replace_some_uses_with(ctx, predicate, &other);
DefNode::replace_some_uses_with(ctx, predicate, &self.self_ptr, &other);
}

/// Get all successors of this block.
Expand Down
31 changes: 22 additions & 9 deletions src/use_def_lists.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,22 +81,36 @@ impl<T: DefUseParticipant> DefNode<T> {

/// Replace some uses of the underlying definition with `other`.
pub(crate) fn replace_some_uses_with<P: Fn(&Context, &Use<T>) -> bool>(
&mut self,
ctx: &Context,
predicate: P,
this: &T,
other: &T,
) where
T: DefTrait + UseTrait,
{
if std::ptr::eq(self, &*other.get_defnode_ref(ctx)) {
if std::ptr::eq(&*this.get_defnode_ref(ctx), &*other.get_defnode_ref(ctx)) {
return;
}
for r#use in self.uses.iter().filter(|r#use| predicate(ctx, r#use)) {
let mut use_mut = T::get_usenode_mut(r#use, ctx);
*use_mut = other.get_defnode_mut(ctx).add_use(*other, *r#use);

// We collect because we don't want to keep the defnode locked up.
let touched_uses: FxHashSet<_> = this
.get_defnode_ref(ctx)
.uses
.iter()
.filter(|r#use| predicate(ctx, r#use))
.cloned()
.collect();

// Add each [Use] as a use of `other`, replacing the [UseNode] subsequently.
for r#use in &touched_uses {
let new_use_node = other.get_defnode_mut(ctx).add_use(*other, *r#use);
*T::get_usenode_mut(r#use, ctx) = new_use_node;
}
// self will no longer have these uses.
self.uses.retain(|r#use| !predicate(ctx, r#use));

// `this` will no longer have the touched uses.
this.get_defnode_mut(ctx)
.uses
.retain(|r#use| !touched_uses.contains(r#use));
}
}

Expand Down Expand Up @@ -150,8 +164,7 @@ impl Value {
pred: P,
other: &Value,
) {
self.get_defnode_mut(ctx)
.replace_some_uses_with(ctx, pred, other);
DefNode::replace_some_uses_with(ctx, pred, self, other);
}
}

Expand Down
83 changes: 81 additions & 2 deletions tests/ir_construct.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
use common::ConstantOp;
use common::{ConstantOp, ReturnOp};
use expect_test::{expect, Expect};
use pliron::{
builtin::op_interfaces::OneResultInterface,
basic_block::BasicBlock,
builtin::{
op_interfaces::OneResultInterface,
types::{IntegerType, Signedness},
},
common_traits::Verify,
context::Context,
debug_info::set_operation_result_name,
impl_canonical_syntax, impl_verify_succ,
irfmt::parsers::spaced,
location,
op::Op,
Expand All @@ -19,6 +24,7 @@ use pliron::{
WALKCONFIG_PREORDER_FORWARD,
},
};
use pliron_derive::def_op;

use crate::common::{const_ret_in_mod, setup_context_dialects};
use combine::parser::Parser;
Expand Down Expand Up @@ -122,6 +128,79 @@ fn replace_c0_with_c1_operand() -> Result<()> {
Ok(())
}

#[def_op("test.dual_def")]
struct DualDefOp {}
impl_verify_succ!(DualDefOp);
impl_canonical_syntax!(DualDefOp);

/// If an Op has multiple results, or a block multiple args,
/// replacing all uses of one with the other should work.
/// (since our RefCell is at the Op or block level, we shouldn't
/// end up with a multiple borrow panic).
#[test]
fn test_replace_within_same_def_site() {
let ctx = &mut setup_context_dialects();
DualDefOp::register(ctx, DualDefOp::parser_fn);

let u64_ty = IntegerType::get(ctx, 64, Signedness::Signed).into();

let dual_def_op = Operation::new(
ctx,
DualDefOp::get_opid_static(),
vec![u64_ty, u64_ty],
vec![],
vec![],
0,
);
let (res1, res2) = (
dual_def_op.deref(ctx).get_result(0).unwrap(),
dual_def_op.deref(ctx).get_result(1).unwrap(),
);
let (module_op, func_op, const_op, ret_op) = const_ret_in_mod(ctx).unwrap();
dual_def_op.insert_before(ctx, ret_op.get_operation());
const_op
.get_result(ctx)
.replace_some_uses_with(ctx, |_, _| true, &res1);
res1.replace_some_uses_with(ctx, |_, _| true, &res2);
let printed = format!("{}", module_op.disp(ctx));
expect![[r#"
builtin.module @bar {
^block_1v1():
builtin.func @foo: builtin.function<() -> (builtin.int<si64>)> {
^entry_block_2v1():
c0_op_4v1_res0 = test.constant builtin.integer <0x0: builtin.int<si64>>;
op_1v1_res0, op_1v1_res1 = test.dual_def () [] []: <() -> (builtin.int<si64>, builtin.int<si64>)>;
test.return op_1v1_res1
}
}"#]]
.assert_eq(&printed);

let dual_arg_block = BasicBlock::new(ctx, None, vec![u64_ty, u64_ty]);
let (arg1, arg2) = (
dual_arg_block.deref(ctx).get_argument(0).unwrap(),
dual_arg_block.deref(ctx).get_argument(1).unwrap(),
);
dual_arg_block.insert_after(ctx, func_op.get_entry_block(ctx));
let ret_op = ReturnOp::new(ctx, arg1);
ret_op.get_operation().insert_at_back(dual_arg_block, ctx);
arg1.replace_some_uses_with(ctx, |_, _| true, &arg2);

let printed = format!("{}", module_op.disp(ctx));
expect![[r#"
builtin.module @bar {
^block_1v1():
builtin.func @foo: builtin.function<() -> (builtin.int<si64>)> {
^entry_block_2v1():
c0_op_4v1_res0 = test.constant builtin.integer <0x0: builtin.int<si64>>;
op_1v1_res0, op_1v1_res1 = test.dual_def () [] []: <() -> (builtin.int<si64>, builtin.int<si64>)>;
test.return op_1v1_res1
^block_3v1(block_3v1_arg0:builtin.int<si64>,block_3v1_arg1:builtin.int<si64>):
test.return block_3v1_arg1
}
}"#]]
.assert_eq(&printed);
}

#[test]
/// A test to just print a constructed IR to stdout.
fn print_simple() -> Result<()> {
Expand Down

0 comments on commit aea315a

Please sign in to comment.