diff --git a/src/main/scala/ir/dsl/DSL.scala b/src/main/scala/ir/dsl/DSL.scala index 161f0d19e..d19b362e6 100644 --- a/src/main/scala/ir/dsl/DSL.scala +++ b/src/main/scala/ir/dsl/DSL.scala @@ -47,7 +47,10 @@ case class EventuallyIndirectCall(target: Variable, fallthrough: Option[DelayNam case class EventuallyCall(target: DelayNameResolve, fallthrough: Option[DelayNameResolve]) extends EventuallyJump { override def resolve(p: Program): DirectCall = { - val t = target.resolveProc(p).get + val t = target.resolveProc(p) match { + case Some(x) => x + case None => throw Exception("can't resolve proc " + p) + } val ft = fallthrough.flatMap(_.resolveBlock(p)) DirectCall(t, ft) } @@ -70,11 +73,9 @@ def goto(targets: List[String]): EventuallyGoto = { EventuallyGoto(targets.map(p => DelayNameResolve(p))) } -def indirectCall(tgt: String, fallthrough: Option[String]): EventuallyCall = EventuallyCall(DelayNameResolve(tgt), fallthrough.map(x => DelayNameResolve(x))) +def directCall(tgt: String, fallthrough: Option[String]): EventuallyCall = EventuallyCall(DelayNameResolve(tgt), fallthrough.map(x => DelayNameResolve(x))) -def call(tgt: String, fallthrough: Option[String]): EventuallyCall = EventuallyCall(DelayNameResolve(tgt), fallthrough.map(x => DelayNameResolve(x))) - -def call(tgt: Variable, fallthrough: Option[String]): EventuallyIndirectCall = EventuallyIndirectCall(tgt, fallthrough.map(x => DelayNameResolve(x))) +def indirectCall(tgt: Variable, fallthrough: Option[String]): EventuallyIndirectCall = EventuallyIndirectCall(tgt, fallthrough.map(x => DelayNameResolve(x))) // def directcall(tgt: String) = EventuallyCall(DelayNameResolve(tgt), None) diff --git a/src/test/scala/LiveVarsAnalysisTests.scala b/src/test/scala/LiveVarsAnalysisTests.scala index e47c63679..2b3bf84c6 100644 --- a/src/test/scala/LiveVarsAnalysisTests.scala +++ b/src/test/scala/LiveVarsAnalysisTests.scala @@ -31,10 +31,10 @@ class LiveVarsAnalysisTests extends AnyFunSuite, TestUtil { block("first_call", r0ConstantAssign, r1ConstantAssign, - call("callee1", Some("second_call")) + directCall("callee1", Some("second_call")) ), block("second_call", - call("callee2", Some("returnBlock")) + directCall("callee2", Some("returnBlock")) ), block("returnBlock", ret @@ -69,10 +69,10 @@ class LiveVarsAnalysisTests extends AnyFunSuite, TestUtil { block("first_call", r0ConstantAssign, r1ConstantAssign, - call("callee1", Some("second_call")) + directCall("callee1", Some("second_call")) ), block("second_call", - call("callee2", Some("returnBlock")) + directCall("callee2", Some("returnBlock")) ), block("returnBlock", ret @@ -104,10 +104,10 @@ class LiveVarsAnalysisTests extends AnyFunSuite, TestUtil { var program = prog( proc("main", block("main_first_call", - call("wrapper1", Some("main_second_call")) + directCall("wrapper1", Some("main_second_call")) ), block("main_second_call", - call("wrapper2", Some("main_return")) + directCall("wrapper2", Some("main_return")) ), block("main_return", ret) ), @@ -117,19 +117,19 @@ class LiveVarsAnalysisTests extends AnyFunSuite, TestUtil { proc("wrapper1", block("wrapper1_first_call", LocalAssign(R1, constant1), - call("callee", Some("wrapper1_second_call")) + directCall("callee", Some("wrapper1_second_call")) ), block("wrapper1_second_call", - call("callee2", Some("wrapper1_return"))), + directCall("callee2", Some("wrapper1_return"))), block("wrapper1_return", ret) ), proc("wrapper2", block("wrapper2_first_call", LocalAssign(R2, constant1), - call("callee", Some("wrapper2_second_call")) + directCall("callee", Some("wrapper2_second_call")) ), block("wrapper2_second_call", - call("callee3", Some("wrapper2_return"))), + directCall("callee3", Some("wrapper2_return"))), block("wrapper2_return", ret) ) ) @@ -148,7 +148,7 @@ class LiveVarsAnalysisTests extends AnyFunSuite, TestUtil { var program = prog( proc("main", block("lmain", - call("killer", Some("aftercall")) + directCall("killer", Some("aftercall")) ), block("aftercall", LocalAssign(R0, R1), @@ -212,7 +212,7 @@ class LiveVarsAnalysisTests extends AnyFunSuite, TestUtil { block( "lmain", LocalAssign(R0, R1), - call("main", Some("return")) + directCall("main", Some("return")) ), block("return", LocalAssign(R0, R2), @@ -240,7 +240,7 @@ class LiveVarsAnalysisTests extends AnyFunSuite, TestUtil { ), block( "recursion", - call("main", Some("assign")) + directCall("main", Some("assign")) ), block("assign", LocalAssign(R0, R2), diff --git a/src/test/scala/PointsToTest.scala b/src/test/scala/PointsToTest.scala index afcb2d6bb..38ce1afc3 100644 --- a/src/test/scala/PointsToTest.scala +++ b/src/test/scala/PointsToTest.scala @@ -168,7 +168,7 @@ class PointsToTest extends AnyFunSuite with OneInstancePerTest with BeforeAndAft goto("0x1") ), block("0x1", - call("p2", Some("returntarget")) + directCall("p2", Some("returntarget")) ), block("returntarget", ret @@ -217,7 +217,7 @@ class PointsToTest extends AnyFunSuite with OneInstancePerTest with BeforeAndAft goto("0x1") ), block("0x1", - call("p2", Some("returntarget")) + directCall("p2", Some("returntarget")) ), block("returntarget", ret @@ -227,7 +227,7 @@ class PointsToTest extends AnyFunSuite with OneInstancePerTest with BeforeAndAft block("l_foo", LocalAssign(getRegister("R0"), MemoryLoad(mem, BinaryExpr(BVADD, getRegister("R31"), bv64(6)), LittleEndian, 64)), LocalAssign(getRegister("R1"), BinaryExpr(BVADD, getRegister("R31"), bv64(10))), - call("p2", Some("l_foo_1")) + directCall("p2", Some("l_foo_1")) ), block("l_foo_1", ret, diff --git a/src/test/scala/ir/IRTest.scala b/src/test/scala/ir/IRTest.scala index dd1e21081..5eff69ded 100644 --- a/src/test/scala/ir/IRTest.scala +++ b/src/test/scala/ir/IRTest.scala @@ -142,7 +142,7 @@ class IRTest extends AnyFunSuite { ), block("l_main_1", LocalAssign(R0, bv64(22)), - call("p2", Some("returntarget")) + directCall("p2", Some("returntarget")) ), block("returntarget", ret @@ -246,7 +246,7 @@ class IRTest extends AnyFunSuite { LocalAssign(R0, bv64(22)), LocalAssign(R0, bv64(22)), LocalAssign(R0, bv64(22)), - call("main", None) + directCall("main", None) ).resolve(p) val b2 = block("newblock1", LocalAssign(R0, bv64(22)), @@ -271,7 +271,7 @@ class IRTest extends AnyFunSuite { assert(called.incomingCalls().isEmpty) val b3 = block("newblock3", LocalAssign(R0, bv64(22)), - call("called", None) + directCall("called", None) ).resolve(p) assert(b3.calls.toSet == Set(p.procs("called"))) @@ -333,7 +333,7 @@ class IRTest extends AnyFunSuite { proc("main", block("l_main", LocalAssign(R0, bv64(10)), - call("p1", Some("returntarget")) + directCall("p1", Some("returntarget")) ), block("returntarget", ret @@ -365,5 +365,41 @@ class IRTest extends AnyFunSuite { } + test("replace jump") { + val p = prog( + proc("p1", + block("b1", + LocalAssign(R0, bv64(10)), + ret + ) + ), + proc("main", + block("l_main", + LocalAssign(R0, bv64(10)), + indirectCall(R1, Some("returntarget")) + ), + block("block2", + directCall("p1", Some("returntarget")) + ), + block("returntarget", + ret + ) + ), + ) + + val main = p.blocks("l_main") + val p1 = p.procs("p1") + val block2 = p.blocks("block2") + + val oldJump = main.jump + val newJump = block2.jump + + main.replaceJump(newJump) + + assert(newJump.parent == main) + assert(block2.jump.isInstanceOf[GoTo]) + assert(block2.jump.asInstanceOf[GoTo].targets.isEmpty) + } + }