diff --git a/src/main/scala/ir/IRCursor.scala b/src/main/scala/ir/IRCursor.scala index 7acd268a3..2b5d3123e 100644 --- a/src/main/scala/ir/IRCursor.scala +++ b/src/main/scala/ir/IRCursor.scala @@ -46,8 +46,8 @@ object IRWalk: def commandBegin(pos: CFGPosition) : Option[Command] = { pos match { - case p: Procedure => p.entryBlock.map(b => b.statements.headOption().getOrElse(b.jump)) - case b: Block => Some(b.statements.headOption().getOrElse(b.jump)) + case p: Procedure => p.entryBlock.map(b => b.statements.headOption.getOrElse(b.jump)) + case b: Block => Some(b.statements.headOption.getOrElse(b.jump)) case c: Command => Some(c) } } @@ -81,7 +81,7 @@ trait IntraProcIRCursor extends IRWalk[CFGPosition, CFGPosition] { def succ(pos: CFGPosition): Set[CFGPosition] = { pos match { case proc: Procedure => proc.entryBlock.toSet - case b: Block => Set(b.statements.headOption().getOrElse(b.jump)) + case b: Block => Set(b.statements.headOption.getOrElse(b.jump)) case s: Statement => Set(s.succ().getOrElse(s.parent.jump)) case n: GoTo => n.targets.asInstanceOf[Set[CFGPosition]] case c: Call => c.parent.fallthrough.toSet diff --git a/src/main/scala/ir/Program.scala b/src/main/scala/ir/Program.scala index 4e8f18b00..a8337d6ed 100644 --- a/src/main/scala/ir/Program.scala +++ b/src/main/scala/ir/Program.scala @@ -358,9 +358,8 @@ class Block private ( statements.onInsert = x => x.setParent(this) statements.onRemove = x => x.deParent() - def this(label: String, address: Option[Int] = None, statements: IterableOnce[Statement] = Set.empty, jump: Jump = GoTo(Set.empty)) = { - this(label, address, IntrusiveList.from(statements), jump, mutable.HashSet.empty, None) + this(label, address, IntrusiveList().addAll(statements), jump, mutable.HashSet.empty, None) } def jump: Jump = _jump @@ -377,7 +376,8 @@ class Block private ( _fallthrough = g } - def jump_=(j: Jump): Unit = { + private def jump_=(j: Jump): Unit = { + require(!j.hasParent) if (j ne _jump) { _jump.deParent() _jump = j @@ -386,6 +386,11 @@ class Block private ( } def replaceJump(j: Jump): Block = { + if (j.hasParent) { + val parent = j.parent + j.deParent() + parent.jump = GoTo(Set.empty) + } jump = j this } diff --git a/src/main/scala/ir/Statement.scala b/src/main/scala/ir/Statement.scala index 8cbd55952..5ad925637 100644 --- a/src/main/scala/ir/Statement.scala +++ b/src/main/scala/ir/Statement.scala @@ -127,21 +127,7 @@ object GoTo: sealed trait Call extends Jump { - private var _returnTarget: Option[Block] = None - - // replacing the return target of a call - def returnTarget_=(b: Block): Unit = { - require(b.hasParent) - - if (hasParent) { - // if we don't have a parent now, delay adding the fallthrough block until linking - parent.fallthrough = Some(GoTo(Set(b))) - } - - _returnTarget = Some(b) - } - - def returnTarget: Option[Block] = _returnTarget + val returnTarget: Option[Block] // moving a call between blocks override def linkParent(p: Block): Unit = { @@ -154,10 +140,9 @@ sealed trait Call extends Jump { } class DirectCall(val target: Procedure, - private val _returnTarget: Option[Block] = None, + override val returnTarget: Option[Block] = None, override val label: Option[String] = None ) extends Call { - _returnTarget.foreach(x => returnTarget = x) /* override def locals: Set[Variable] = condition match { case Some(c) => c.locals case None => Set() @@ -182,10 +167,9 @@ object DirectCall: def unapply(i: DirectCall): Option[(Procedure, Option[Block], Option[String])] = Some(i.target, i.returnTarget, i.label) class IndirectCall(var target: Variable, - private val _returnTarget: Option[Block] = None, + override val returnTarget: Option[Block] = None, override val label: Option[String] = None ) extends Call { - _returnTarget.foreach(x => returnTarget = x) /* override def locals: Set[Variable] = condition match { case Some(c) => c.locals + target case None => Set(target) diff --git a/src/main/scala/ir/dsl/DSL.scala b/src/main/scala/ir/dsl/DSL.scala index 161f0d19e..d19b362e6 100644 --- a/src/main/scala/ir/dsl/DSL.scala +++ b/src/main/scala/ir/dsl/DSL.scala @@ -47,7 +47,10 @@ case class EventuallyIndirectCall(target: Variable, fallthrough: Option[DelayNam case class EventuallyCall(target: DelayNameResolve, fallthrough: Option[DelayNameResolve]) extends EventuallyJump { override def resolve(p: Program): DirectCall = { - val t = target.resolveProc(p).get + val t = target.resolveProc(p) match { + case Some(x) => x + case None => throw Exception("can't resolve proc " + p) + } val ft = fallthrough.flatMap(_.resolveBlock(p)) DirectCall(t, ft) } @@ -70,11 +73,9 @@ def goto(targets: List[String]): EventuallyGoto = { EventuallyGoto(targets.map(p => DelayNameResolve(p))) } -def indirectCall(tgt: String, fallthrough: Option[String]): EventuallyCall = EventuallyCall(DelayNameResolve(tgt), fallthrough.map(x => DelayNameResolve(x))) +def directCall(tgt: String, fallthrough: Option[String]): EventuallyCall = EventuallyCall(DelayNameResolve(tgt), fallthrough.map(x => DelayNameResolve(x))) -def call(tgt: String, fallthrough: Option[String]): EventuallyCall = EventuallyCall(DelayNameResolve(tgt), fallthrough.map(x => DelayNameResolve(x))) - -def call(tgt: Variable, fallthrough: Option[String]): EventuallyIndirectCall = EventuallyIndirectCall(tgt, fallthrough.map(x => DelayNameResolve(x))) +def indirectCall(tgt: Variable, fallthrough: Option[String]): EventuallyIndirectCall = EventuallyIndirectCall(tgt, fallthrough.map(x => DelayNameResolve(x))) // def directcall(tgt: String) = EventuallyCall(DelayNameResolve(tgt), None) diff --git a/src/main/scala/translating/GTIRBToIR.scala b/src/main/scala/translating/GTIRBToIR.scala index b4b6f5e31..fa33bb147 100644 --- a/src/main/scala/translating/GTIRBToIR.scala +++ b/src/main/scala/translating/GTIRBToIR.scala @@ -292,7 +292,7 @@ class GTIRBToIR(mods: Seq[Module], parserMap: immutable.Map[String, Array[Array[ private def cleanUpIfPCAssign(block: Block, procedure: Procedure): Unit = { var newBlockCount = 0 var currentBlock = block - var currentStatement = currentBlock.statements.head() + var currentStatement = currentBlock.statements.head var breakLoop = false val queue = mutable.Queue[Block]() while (!breakLoop) { @@ -306,7 +306,7 @@ class GTIRBToIR(mods: Seq[Module], parserMap: immutable.Map[String, Array[Array[ if (queue.nonEmpty) { currentBlock = queue.dequeue() - currentStatement = currentBlock.statements.head() + currentStatement = currentBlock.statements.head } else { breakLoop = true } @@ -326,7 +326,7 @@ class GTIRBToIR(mods: Seq[Module], parserMap: immutable.Map[String, Array[Array[ if (queue.nonEmpty) { currentBlock = queue.dequeue() - currentStatement = currentBlock.statements.head() + currentStatement = currentBlock.statements.head } else { breakLoop = true } @@ -335,7 +335,7 @@ class GTIRBToIR(mods: Seq[Module], parserMap: immutable.Map[String, Array[Array[ currentStatement = currentBlock.statements.getNext(currentStatement) } else if (queue.nonEmpty) { currentBlock = queue.dequeue() - currentStatement = currentBlock.statements.head() + currentStatement = currentBlock.statements.head } else { breakLoop = true } @@ -374,7 +374,7 @@ class GTIRBToIR(mods: Seq[Module], parserMap: immutable.Map[String, Array[Array[ g.targets.head } else { // case where goto has multiple targets: create an extra block and point to that - val afterBlock = Block(parentLabel + "$__" + newBlockCount, None) + val afterBlock = Block(parentLabel + "$__" + newBlockCount) newBlockCount += 1 newBlocks.append(afterBlock) afterBlock.replaceJump(currentBlock.jump) diff --git a/src/main/scala/translating/IRToBoogie.scala b/src/main/scala/translating/IRToBoogie.scala index 3a88a1f64..51a820273 100644 --- a/src/main/scala/translating/IRToBoogie.scala +++ b/src/main/scala/translating/IRToBoogie.scala @@ -637,7 +637,7 @@ class IRToBoogie(var program: Program, var spec: Specification) { case g: GoTo => // collects all targets of the goto with a branch condition that we need to check the security level for // and collects the variables for that - val conditions = g.targets.flatMap(_.statements.headOption()).collect { case a: Assume if a.checkSecurity => a } + val conditions = g.targets.flatMap(_.statements.headOption.collect { case a: Assume if a.checkSecurity => a }) val conditionVariables = conditions.flatMap(_.body.variables) val gammas = conditionVariables.map(_.toGamma).toList.sorted val conditionAssert: List[BCmd] = if (gammas.size > 1) { diff --git a/src/main/scala/util/intrusive_list/IntrusiveList.scala b/src/main/scala/util/intrusive_list/IntrusiveList.scala index 2803cf931..a6c5f0b93 100644 --- a/src/main/scala/util/intrusive_list/IntrusiveList.scala +++ b/src/main/scala/util/intrusive_list/IntrusiveList.scala @@ -1,5 +1,6 @@ package util.intrusive_list import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer // TODO: implement IterableOps // So need iterablefactory https://docs.scala-lang.org/overviews/core/custom-collections.html @@ -90,7 +91,7 @@ final class IntrusiveList[T <: IntrusiveListElement[T]] private ( elem } - class IntrusiveListIterator(var elem: Option[T], forward: Boolean) extends Iterator[T] { + private class IntrusiveListIterator(var elem: Option[T], forward: Boolean) extends Iterator[T] { override def hasNext: Boolean = elem.isDefined override def next: T = { val t = elem.get @@ -133,14 +134,14 @@ final class IntrusiveList[T <: IntrusiveListElement[T]] private ( /** * Unsafely return the first element of the list. */ - override def head(): T = firstElem.get + override def head: T = firstElem.get - override def headOption(): Option[T] = firstElem + override def headOption: Option[T] = firstElem /** * Unsafely return the first element of the list. */ - def begin(): T = firstElem.get + def begin: T = firstElem.get /** * Check whether the list contains the given element (by reference) by linear scan. @@ -169,7 +170,7 @@ final class IntrusiveList[T <: IntrusiveListElement[T]] private ( /** * Unsafely return the last element of the list. */ - def back(): T = lastElem.get + def back: T = lastElem.get /** * Add an element to the beginning of the list. @@ -232,19 +233,19 @@ final class IntrusiveList[T <: IntrusiveListElement[T]] private ( } /** - * Split the list into two lists, the first retains all elements up to to and including the provided element, - * and and returns the second list from the element until the end. + * Removes all elements after the provided element n and returns an ArrayBuffer containing the removed elements, + * maintaining the ordering. * * @param n The element to split on, remains in the first list. - * @return A list containing all elements after n. + * @return An ArrayBuffer containing all elements after n. */ - def splitOn(n: T): IntrusiveList[T] = { + def splitOn(n: T): ArrayBuffer[T] = { require(!lastElem.contains(n)) require(containsRef(n)) val ne = n.next - val newlist = new IntrusiveList[T]() + val newlist = ArrayBuffer[T]() var next = n.next while (next.isDefined) { remove(next.get) @@ -355,12 +356,7 @@ final class IntrusiveList[T <: IntrusiveListElement[T]] private ( } object IntrusiveList { - - def from[T <: IntrusiveListElement[T]](it: IntrusiveList[T]): IntrusiveList[T] = it - - def from[T <: IntrusiveListElement[T]](it: IterableOnce[T]): IntrusiveList[T] = IntrusiveList[T]().addAll(it) - - def empty[T <: IntrusiveListElement[T]]: IntrusiveList[T] = new IntrusiveList[T]() + def empty[T <: IntrusiveListElement[T]]: IntrusiveList[T] = IntrusiveList[T]() } /** @@ -382,8 +378,6 @@ trait IntrusiveListElement[T <: IntrusiveListElement[T]]: elem } - - private[intrusive_list] final def unitary: Boolean = next.isEmpty && prev.isEmpty private[intrusive_list] final def insertAfter(elem: T): T = { diff --git a/src/test/scala/LiveVarsAnalysisTests.scala b/src/test/scala/LiveVarsAnalysisTests.scala index e47c63679..2b3bf84c6 100644 --- a/src/test/scala/LiveVarsAnalysisTests.scala +++ b/src/test/scala/LiveVarsAnalysisTests.scala @@ -31,10 +31,10 @@ class LiveVarsAnalysisTests extends AnyFunSuite, TestUtil { block("first_call", r0ConstantAssign, r1ConstantAssign, - call("callee1", Some("second_call")) + directCall("callee1", Some("second_call")) ), block("second_call", - call("callee2", Some("returnBlock")) + directCall("callee2", Some("returnBlock")) ), block("returnBlock", ret @@ -69,10 +69,10 @@ class LiveVarsAnalysisTests extends AnyFunSuite, TestUtil { block("first_call", r0ConstantAssign, r1ConstantAssign, - call("callee1", Some("second_call")) + directCall("callee1", Some("second_call")) ), block("second_call", - call("callee2", Some("returnBlock")) + directCall("callee2", Some("returnBlock")) ), block("returnBlock", ret @@ -104,10 +104,10 @@ class LiveVarsAnalysisTests extends AnyFunSuite, TestUtil { var program = prog( proc("main", block("main_first_call", - call("wrapper1", Some("main_second_call")) + directCall("wrapper1", Some("main_second_call")) ), block("main_second_call", - call("wrapper2", Some("main_return")) + directCall("wrapper2", Some("main_return")) ), block("main_return", ret) ), @@ -117,19 +117,19 @@ class LiveVarsAnalysisTests extends AnyFunSuite, TestUtil { proc("wrapper1", block("wrapper1_first_call", LocalAssign(R1, constant1), - call("callee", Some("wrapper1_second_call")) + directCall("callee", Some("wrapper1_second_call")) ), block("wrapper1_second_call", - call("callee2", Some("wrapper1_return"))), + directCall("callee2", Some("wrapper1_return"))), block("wrapper1_return", ret) ), proc("wrapper2", block("wrapper2_first_call", LocalAssign(R2, constant1), - call("callee", Some("wrapper2_second_call")) + directCall("callee", Some("wrapper2_second_call")) ), block("wrapper2_second_call", - call("callee3", Some("wrapper2_return"))), + directCall("callee3", Some("wrapper2_return"))), block("wrapper2_return", ret) ) ) @@ -148,7 +148,7 @@ class LiveVarsAnalysisTests extends AnyFunSuite, TestUtil { var program = prog( proc("main", block("lmain", - call("killer", Some("aftercall")) + directCall("killer", Some("aftercall")) ), block("aftercall", LocalAssign(R0, R1), @@ -212,7 +212,7 @@ class LiveVarsAnalysisTests extends AnyFunSuite, TestUtil { block( "lmain", LocalAssign(R0, R1), - call("main", Some("return")) + directCall("main", Some("return")) ), block("return", LocalAssign(R0, R2), @@ -240,7 +240,7 @@ class LiveVarsAnalysisTests extends AnyFunSuite, TestUtil { ), block( "recursion", - call("main", Some("assign")) + directCall("main", Some("assign")) ), block("assign", LocalAssign(R0, R2), diff --git a/src/test/scala/PointsToTest.scala b/src/test/scala/PointsToTest.scala index afcb2d6bb..38ce1afc3 100644 --- a/src/test/scala/PointsToTest.scala +++ b/src/test/scala/PointsToTest.scala @@ -168,7 +168,7 @@ class PointsToTest extends AnyFunSuite with OneInstancePerTest with BeforeAndAft goto("0x1") ), block("0x1", - call("p2", Some("returntarget")) + directCall("p2", Some("returntarget")) ), block("returntarget", ret @@ -217,7 +217,7 @@ class PointsToTest extends AnyFunSuite with OneInstancePerTest with BeforeAndAft goto("0x1") ), block("0x1", - call("p2", Some("returntarget")) + directCall("p2", Some("returntarget")) ), block("returntarget", ret @@ -227,7 +227,7 @@ class PointsToTest extends AnyFunSuite with OneInstancePerTest with BeforeAndAft block("l_foo", LocalAssign(getRegister("R0"), MemoryLoad(mem, BinaryExpr(BVADD, getRegister("R31"), bv64(6)), LittleEndian, 64)), LocalAssign(getRegister("R1"), BinaryExpr(BVADD, getRegister("R31"), bv64(10))), - call("p2", Some("l_foo_1")) + directCall("p2", Some("l_foo_1")) ), block("l_foo_1", ret, diff --git a/src/test/scala/ir/IRTest.scala b/src/test/scala/ir/IRTest.scala index dd1e21081..5eff69ded 100644 --- a/src/test/scala/ir/IRTest.scala +++ b/src/test/scala/ir/IRTest.scala @@ -142,7 +142,7 @@ class IRTest extends AnyFunSuite { ), block("l_main_1", LocalAssign(R0, bv64(22)), - call("p2", Some("returntarget")) + directCall("p2", Some("returntarget")) ), block("returntarget", ret @@ -246,7 +246,7 @@ class IRTest extends AnyFunSuite { LocalAssign(R0, bv64(22)), LocalAssign(R0, bv64(22)), LocalAssign(R0, bv64(22)), - call("main", None) + directCall("main", None) ).resolve(p) val b2 = block("newblock1", LocalAssign(R0, bv64(22)), @@ -271,7 +271,7 @@ class IRTest extends AnyFunSuite { assert(called.incomingCalls().isEmpty) val b3 = block("newblock3", LocalAssign(R0, bv64(22)), - call("called", None) + directCall("called", None) ).resolve(p) assert(b3.calls.toSet == Set(p.procs("called"))) @@ -333,7 +333,7 @@ class IRTest extends AnyFunSuite { proc("main", block("l_main", LocalAssign(R0, bv64(10)), - call("p1", Some("returntarget")) + directCall("p1", Some("returntarget")) ), block("returntarget", ret @@ -365,5 +365,41 @@ class IRTest extends AnyFunSuite { } + test("replace jump") { + val p = prog( + proc("p1", + block("b1", + LocalAssign(R0, bv64(10)), + ret + ) + ), + proc("main", + block("l_main", + LocalAssign(R0, bv64(10)), + indirectCall(R1, Some("returntarget")) + ), + block("block2", + directCall("p1", Some("returntarget")) + ), + block("returntarget", + ret + ) + ), + ) + + val main = p.blocks("l_main") + val p1 = p.procs("p1") + val block2 = p.blocks("block2") + + val oldJump = main.jump + val newJump = block2.jump + + main.replaceJump(newJump) + + assert(newJump.parent == main) + assert(block2.jump.isInstanceOf[GoTo]) + assert(block2.jump.asInstanceOf[GoTo].targets.isEmpty) + } + } diff --git a/src/test/scala/util/intrusive_list/IntrusiveListPublicInterfaceTest.scala b/src/test/scala/util/intrusive_list/IntrusiveListPublicInterfaceTest.scala index 06c3ffcb2..f9f89867f 100644 --- a/src/test/scala/util/intrusive_list/IntrusiveListPublicInterfaceTest.scala +++ b/src/test/scala/util/intrusive_list/IntrusiveListPublicInterfaceTest.scala @@ -50,9 +50,9 @@ class IntrusiveListPublicInterfaceTest extends AnyFunSuite { // x.foreach(println(_)) - val y = x.head() + val y = x.head assert(y.t == 10) - assert(x.back().t == 14) + assert(x.back.t == 14) } test("Clear") { @@ -183,19 +183,9 @@ class IntrusiveListPublicInterfaceTest extends AnyFunSuite { } test("construct") { - val l = getSequentialList(3) - - - val l2 = IntrusiveList.from(l) - assert(l2.size == 3) - - assert(l.forall(x => l2.contains(x))) - - assert (l eq l2) - val l3 = mutable.ArrayBuffer(Elem(1), Elem(2), Elem(3)) - val l4 = IntrusiveList.from(l3) + val l4 = IntrusiveList().addAll(l3) assert(l3 ne l4) assert(l3.forall(x => l4.contains(x)))