From aea315a69316b9590746570b619cbe1c2f4c61c2 Mon Sep 17 00:00:00 2001 From: Vaivaswatha Nagaraj Date: Wed, 3 Jul 2024 13:05:45 +0530 Subject: [PATCH] Fix bug in `replace_some_uses_with`, add tests --- src/basic_block.rs | 2 +- src/use_def_lists.rs | 31 +++++++++++----- tests/ir_construct.rs | 83 +++++++++++++++++++++++++++++++++++++++++-- 3 files changed, 104 insertions(+), 12 deletions(-) diff --git a/src/basic_block.rs b/src/basic_block.rs index 9778dcd..739267d 100644 --- a/src/basic_block.rs +++ b/src/basic_block.rs @@ -241,7 +241,7 @@ impl BasicBlock { pred(ctx, pred_block) }; - self.preds.replace_some_uses_with(ctx, predicate, &other); + DefNode::replace_some_uses_with(ctx, predicate, &self.self_ptr, &other); } /// Get all successors of this block. diff --git a/src/use_def_lists.rs b/src/use_def_lists.rs index 7bb7b2d..ecc5bd5 100644 --- a/src/use_def_lists.rs +++ b/src/use_def_lists.rs @@ -81,22 +81,36 @@ impl DefNode { /// Replace some uses of the underlying definition with `other`. pub(crate) fn replace_some_uses_with) -> bool>( - &mut self, ctx: &Context, predicate: P, + this: &T, other: &T, ) where T: DefTrait + UseTrait, { - if std::ptr::eq(self, &*other.get_defnode_ref(ctx)) { + if std::ptr::eq(&*this.get_defnode_ref(ctx), &*other.get_defnode_ref(ctx)) { return; } - for r#use in self.uses.iter().filter(|r#use| predicate(ctx, r#use)) { - let mut use_mut = T::get_usenode_mut(r#use, ctx); - *use_mut = other.get_defnode_mut(ctx).add_use(*other, *r#use); + + // We collect because we don't want to keep the defnode locked up. + let touched_uses: FxHashSet<_> = this + .get_defnode_ref(ctx) + .uses + .iter() + .filter(|r#use| predicate(ctx, r#use)) + .cloned() + .collect(); + + // Add each [Use] as a use of `other`, replacing the [UseNode] subsequently. + for r#use in &touched_uses { + let new_use_node = other.get_defnode_mut(ctx).add_use(*other, *r#use); + *T::get_usenode_mut(r#use, ctx) = new_use_node; } - // self will no longer have these uses. - self.uses.retain(|r#use| !predicate(ctx, r#use)); + + // `this` will no longer have the touched uses. + this.get_defnode_mut(ctx) + .uses + .retain(|r#use| !touched_uses.contains(r#use)); } } @@ -150,8 +164,7 @@ impl Value { pred: P, other: &Value, ) { - self.get_defnode_mut(ctx) - .replace_some_uses_with(ctx, pred, other); + DefNode::replace_some_uses_with(ctx, pred, self, other); } } diff --git a/tests/ir_construct.rs b/tests/ir_construct.rs index 83d8339..f33ee08 100644 --- a/tests/ir_construct.rs +++ b/tests/ir_construct.rs @@ -1,10 +1,15 @@ -use common::ConstantOp; +use common::{ConstantOp, ReturnOp}; use expect_test::{expect, Expect}; use pliron::{ - builtin::op_interfaces::OneResultInterface, + basic_block::BasicBlock, + builtin::{ + op_interfaces::OneResultInterface, + types::{IntegerType, Signedness}, + }, common_traits::Verify, context::Context, debug_info::set_operation_result_name, + impl_canonical_syntax, impl_verify_succ, irfmt::parsers::spaced, location, op::Op, @@ -19,6 +24,7 @@ use pliron::{ WALKCONFIG_PREORDER_FORWARD, }, }; +use pliron_derive::def_op; use crate::common::{const_ret_in_mod, setup_context_dialects}; use combine::parser::Parser; @@ -122,6 +128,79 @@ fn replace_c0_with_c1_operand() -> Result<()> { Ok(()) } +#[def_op("test.dual_def")] +struct DualDefOp {} +impl_verify_succ!(DualDefOp); +impl_canonical_syntax!(DualDefOp); + +/// If an Op has multiple results, or a block multiple args, +/// replacing all uses of one with the other should work. +/// (since our RefCell is at the Op or block level, we shouldn't +/// end up with a multiple borrow panic). +#[test] +fn test_replace_within_same_def_site() { + let ctx = &mut setup_context_dialects(); + DualDefOp::register(ctx, DualDefOp::parser_fn); + + let u64_ty = IntegerType::get(ctx, 64, Signedness::Signed).into(); + + let dual_def_op = Operation::new( + ctx, + DualDefOp::get_opid_static(), + vec![u64_ty, u64_ty], + vec![], + vec![], + 0, + ); + let (res1, res2) = ( + dual_def_op.deref(ctx).get_result(0).unwrap(), + dual_def_op.deref(ctx).get_result(1).unwrap(), + ); + let (module_op, func_op, const_op, ret_op) = const_ret_in_mod(ctx).unwrap(); + dual_def_op.insert_before(ctx, ret_op.get_operation()); + const_op + .get_result(ctx) + .replace_some_uses_with(ctx, |_, _| true, &res1); + res1.replace_some_uses_with(ctx, |_, _| true, &res2); + let printed = format!("{}", module_op.disp(ctx)); + expect![[r#" + builtin.module @bar { + ^block_1v1(): + builtin.func @foo: builtin.function<() -> (builtin.int)> { + ^entry_block_2v1(): + c0_op_4v1_res0 = test.constant builtin.integer <0x0: builtin.int>; + op_1v1_res0, op_1v1_res1 = test.dual_def () [] []: <() -> (builtin.int, builtin.int)>; + test.return op_1v1_res1 + } + }"#]] + .assert_eq(&printed); + + let dual_arg_block = BasicBlock::new(ctx, None, vec![u64_ty, u64_ty]); + let (arg1, arg2) = ( + dual_arg_block.deref(ctx).get_argument(0).unwrap(), + dual_arg_block.deref(ctx).get_argument(1).unwrap(), + ); + dual_arg_block.insert_after(ctx, func_op.get_entry_block(ctx)); + let ret_op = ReturnOp::new(ctx, arg1); + ret_op.get_operation().insert_at_back(dual_arg_block, ctx); + arg1.replace_some_uses_with(ctx, |_, _| true, &arg2); + + let printed = format!("{}", module_op.disp(ctx)); + expect![[r#" + builtin.module @bar { + ^block_1v1(): + builtin.func @foo: builtin.function<() -> (builtin.int)> { + ^entry_block_2v1(): + c0_op_4v1_res0 = test.constant builtin.integer <0x0: builtin.int>; + op_1v1_res0, op_1v1_res1 = test.dual_def () [] []: <() -> (builtin.int, builtin.int)>; + test.return op_1v1_res1 + ^block_3v1(block_3v1_arg0:builtin.int,block_3v1_arg1:builtin.int): + test.return block_3v1_arg1 + } + }"#]] + .assert_eq(&printed); +} + #[test] /// A test to just print a constructed IR to stdout. fn print_simple() -> Result<()> {