From c04a1f19e19042a3256c8f33129bfd1f9141e91a Mon Sep 17 00:00:00 2001 From: Alistair Michael Date: Tue, 28 Nov 2023 18:48:21 +1000 Subject: [PATCH] fix broken tests & cleanup --- src/main/scala/ir/IRCursor.scala | 2 +- src/main/scala/ir/Program.scala | 159 +++++++++++--------- src/main/scala/ir/Statement.scala | 22 +-- src/main/scala/translating/BAPToIR.scala | 10 +- src/main/scala/translating/IRToBoogie.scala | 6 +- src/main/scala/util/RunUtils.scala | 2 +- 6 files changed, 104 insertions(+), 97 deletions(-) diff --git a/src/main/scala/ir/IRCursor.scala b/src/main/scala/ir/IRCursor.scala index f0352dbd5..6dcd3dd0d 100644 --- a/src/main/scala/ir/IRCursor.scala +++ b/src/main/scala/ir/IRCursor.scala @@ -51,7 +51,7 @@ object IntraProcIRCursor { if proc.entryBlock.isEmpty then Set(proc.returnBlock) else Set(proc.entryBlock.get) case b: Block => if b.statements.isEmpty - then Set.from(b.jumpSet) + then Set(b.jump) else Set[CFGPosition](b.statements.head()) case s: Statement => if (s.parent.statements.hasNext(s)) { diff --git a/src/main/scala/ir/Program.scala b/src/main/scala/ir/Program.scala index b9843bbfc..8dc8a027c 100644 --- a/src/main/scala/ir/Program.scala +++ b/src/main/scala/ir/Program.scala @@ -16,11 +16,51 @@ trait HasParent[T]: private var _parent: Option[T] = None def parent: T = _parent.get - protected final def setParentValue(p: T): Unit = _parent = Some(p) - def deParent(): Unit = {} + /** + * Update any IL control-flow links implied by this relation. + * NOT necessarily idempotent. + * For example; + * - Registering calls with their target procedure + * - Registering jumps with their target block + * + * TODO: consider making abstract to force implementers to consider the linking. + * @param p The new parent + */ + protected[this] def linkParent(p: T): Unit = () - def setParent(p: T): Unit = setParentValue(p) + /** + * Update any IL control-flow links implied by this relation. + * NOT necessarily idempotent. + */ + protected[this] def unlinkParent(): Unit = () + + +/** + * Remove this element's parent and update any IL control-flow links implied by this relation. + * Is idempotent. + */ + final def deParent(): Unit = if _parent.isDefined then { + unlinkParent() + _parent = None + } + + /** + * Set this element's parent and update any IL control-flow links implied by this relation. + * If another parent is already set then it will be de-parented and unlinked from that first. + * Is idempotent. + */ + final def setParent(p: T): Unit = { + _parent match + case Some(existing) if existing == p => () + case None => + _parent = Some(p) + linkParent(p) + case Some(_) => + deParent() + _parent = Some(p) + linkParent(p) + } class Program(var procedures: ArrayBuffer[Procedure], var mainProcedure: Procedure, var initialMemory: ArrayBuffer[MemorySection], @@ -133,39 +173,41 @@ class Program(var procedures: ArrayBuffer[Procedure], var mainProcedure: Procedu class Procedure private ( var name: String, var address: Option[Int], + var entryBlock: Option[Block], private val _blocks: mutable.HashSet[Block], var in: ArrayBuffer[Parameter], var out: ArrayBuffer[Parameter], ) { - - def this(name: String, address: Option[Int] = None , blocks: IterableOnce[Block] = ArrayBuffer(), in: IterableOnce[Parameter] = ArrayBuffer(), out: IterableOnce[Parameter] = ArrayBuffer()) = { - this(name, address, mutable.HashSet.from(blocks), ArrayBuffer.from(in), ArrayBuffer.from(out)) - } - - private var _entryBlock: Option[Block] = None + /* First block executed when the procedure begins */ + /* returnBlock: Single return point for all returns from the procedure to its caller; always defined but only included + * in `blocks` if the procedure contains any real blocks.*/ val returnBlock: Block = new Block(name + "_return", None, List(), new IndirectCall(Register("R30", BitVecType(64)), None, Some(name))) - addBlocks(returnBlock) + returnBlock.setParent(this) - private var _callers = new mutable.HashMap[Procedure, mutable.Set[Call]] with mutable.MultiMap[Procedure, Call] + private var _callers = new mutable.HashSet[Call] - def entryBlock: Option[Block] = { - _entryBlock match { - case None => - _entryBlock = blocks.find(p => p.address == address) - _entryBlock - case Some(b) => Some(b) - } + def this(name: String, address: Option[Int] = None , entryBlock: Option[Block] = None, blocks: IterableOnce[Block] = ArrayBuffer(), in: IterableOnce[Parameter] = ArrayBuffer(), out: IterableOnce[Parameter] = ArrayBuffer()) = { + this(name, address, entryBlock, mutable.HashSet.from(blocks), ArrayBuffer.from(in), ArrayBuffer.from(out)) } - def calls: Set[Procedure] = blocks.flatMap(_.calls).toSet override def toString: String = { s"Procedure $name at ${address.getOrElse("None")} with ${blocks.size} blocks and ${in.size} in and ${out.size} out parameters" } - def blocks: Iterator[Block] = _blocks.iterator + def calls: Set[Procedure] = blocks.iterator.flatMap(_.calls).toSet + + /** + * Horrible, compensating for not storing the blocks in-order and storing initialBlock and returnBlock separately. + * @return + */ + def blocks: Seq[Block] = + (entryBlock match + case Some(b) if _blocks.nonEmpty => Seq(b) ++ _blocks.filter(x => x ne b).toSeq + case _ => _blocks.toSeq) + ++ (if _blocks.nonEmpty then Seq(returnBlock) else Seq()) def removeCaller(c: Call): Unit = { - _callers.removeBinding(c.parent.parent, c) + _callers.add(c) } def addBlocks(block: Block): Block = { @@ -182,8 +224,8 @@ class Procedure private ( } def replaceBlock(oldBlock: Block, block: Block): Block = { - require(_blocks.contains(oldBlock) || block == returnBlock) if (oldBlock ne block) { + require(_blocks.contains(oldBlock) || block == returnBlock) removeBlocks(oldBlock) addBlocks(block) } @@ -196,13 +238,11 @@ class Procedure private ( * @param newBlocks the new set of blocks * @return an iterator to the new block set */ - def replaceBlocks(newBlocks: Iterable[Block]): Iterator[Block] = { - _blocks.clear + def replaceBlocks(newBlocks: Iterable[Block]): Unit = { + removeBlocks(_blocks) addBlocks(newBlocks) - blocks } - def removeBlocks(block: Block): Block = { block.deParent() _blocks.remove(block) @@ -215,10 +255,10 @@ class Procedure private ( } def addCaller(c: Call): Unit = { - _callers.addBinding(c.parent.parent, c) + _callers.remove(c) } - def callers(): Iterable[Procedure] = _callers.keySet + def callers(): Iterable[Procedure] = _callers.map(_.parent.parent).toSet[Procedure] var modifies: mutable.Set[Global] = mutable.Set() @@ -287,81 +327,47 @@ class Parameter(var name: String, var size: Int, var value: Register) { class Block private (var label: String, var address: Option[Int], val statements: IntrusiveList[Statement], - // invariant: all Goto targets are disjoint - private var _jump: Option[Jump], - //private val _calls: IntrusiveList[Call], + private var _jump: Jump, val incomingJumps: mutable.HashSet[Block], ) extends IntrusiveListElement, HasParent[Procedure] { statements.foreach(_.setParent(this)) - _jump.foreach(_.setParent(this)) + _jump.setParent(this) + statements.onInsert = x => x.setParent(this) statements.onRemove = x => x.deParent() - statements.onInsert = parenter - - def parenter(c: Statement): Statement = { - c.setParent(this) - c - } def this(label: String, address: Option[Int], statements: IterableOnce[Statement], jump: Jump) = { - this(label, address, IntrusiveList.from(statements), Some(jump), mutable.HashSet.empty) - _jump.foreach(_.setParent(this)) + this(label, address, IntrusiveList.from(statements), jump, mutable.HashSet.empty) } def this(label: String, address: Option[Int], statements: IterableOnce[Statement]) = { - this(label, address, IntrusiveList.from(statements), None, mutable.HashSet.empty) + this(label, address, IntrusiveList.from(statements), GoTo(Seq(), Some(label + "_unknown")), mutable.HashSet.empty) } def this(label: String, address: Option[Int] = None) = { - this(label, address, IntrusiveList(), None, mutable.HashSet.empty) + this(label, address, IntrusiveList(), GoTo(Seq(), Some(label + "_unknown")), mutable.HashSet.empty) } - def jump: Jump = _jump.get - def jumpSet: Set[Jump] = _jump.toSet + def jump: Jump = _jump def addGoToTargets(targets: mutable.Set[Block]): this.type = { - require(_jump.isDefined && _jump.get.isInstanceOf[GoTo]) - _jump.foreach(_.asInstanceOf[GoTo].addAllTargets(targets)) + require(_jump.parent == this && _jump.isInstanceOf[GoTo]) + _jump.asInstanceOf[GoTo].addAllTargets(targets) this } def replaceJump(j: Jump): this.type = { - _jump.foreach(_.deParent()) + _jump.deParent() j.setParent(this) - _jump = Some(j) + _jump = j this } def isReturn: Boolean = this == parent.returnBlock -// def replaceGoTo(targets: Iterable[Block], label: Option[String] = None): this.type = { -// _jump.foreach(_.deParent) -// _jump = Some(GoTo(targets, label).setParent(this)) -// this -// } -// -// def replaceDirectCall(target: Procedure, returnTarget: Option[Block] = None, label: Option[String] = None): this.type = { -// _jump.foreach(_.deParent) -// _jump = Some(DirectCall(target, returnTarget, label).setParent(this)) -// this -// } -// -// def addIndirectCall(target: Variable, returnTarget: Option[Block] = None, label: Option[String] = None): this.type = { -// _jump.foreach(_.deParent) -// _jump = Some(IndirectCall(target, returnTarget, label).setParent(this)) -// this -// } - - def removeJump(): this.type = { - _jump.foreach(_.deParent()) - _jump = None - this - } - - def predecessors: immutable.Set[Block] = incomingJumps to immutable.Set - def calls: Set[Procedure] = _jump.toSet.flatMap(_.calls) + def calls: Set[Procedure] = _jump.calls def modifies: Set[Global] = statements.flatMap(_.modifies).toSet //def locals: Set[Variable] = statements.flatMap(_.locals).toSet ++ jumps.flatMap(_.locals).toSet @@ -382,7 +388,10 @@ class Block private (var label: String, case _ => false override def hashCode(): Int = label.hashCode() -} + + override def linkParent(p: Procedure): Unit = () // TODO; cannot support moving blocks around + override def unlinkParent(): Unit = () // TODO + } /** diff --git a/src/main/scala/ir/Statement.scala b/src/main/scala/ir/Statement.scala index a3b5d546e..9a29d963b 100644 --- a/src/main/scala/ir/Statement.scala +++ b/src/main/scala/ir/Statement.scala @@ -1,7 +1,6 @@ package ir import intrusiveList.IntrusiveListElement -import scala.collection.mutable.ArrayBuffer import collection.mutable /* @@ -99,12 +98,13 @@ class GoTo private (private var _targets: mutable.Set[Block], override val label } } - override def setParent(b: Block): Unit = { - setParentValue(b) + override def linkParent(b: Block): Unit = { _targets.foreach(_.incomingJumps.add(parent)) } - override def deParent(): Unit = targets.foreach(_.incomingJumps.remove(parent)) + override def unlinkParent(): Unit = { + targets.foreach(_.incomingJumps.remove(parent)) + } def removeTarget(t: Block): Unit = { @@ -136,13 +136,11 @@ class DirectCall(val target: Procedure, var returnTarget: Option[Block], overri override def toString: String = s"${labelStr}DirectCall(${target.name}, ${returnTarget.map(_.label)})" override def acceptVisit(visitor: Visitor): Jump = visitor.visitDirectCall(this) - override def setParent(p: Block): Unit = { - super.setParent(p) + override def linkParent(p: Block): Unit = { target.addCaller(this) } - override def deParent(): Unit = target.removeCaller(this) - + override def unlinkParent(): Unit = target.removeCaller(this) } object DirectCall: @@ -155,14 +153,6 @@ class IndirectCall(var target: Variable, var returnTarget: Option[Block], overri } */ override def toString: String = s"${labelStr}IndirectCall($target, ${returnTarget.map(_.label)})" override def acceptVisit(visitor: Visitor): Jump = visitor.visitIndirectCall(this) - - override def equals(obj: Any): Boolean = { - obj match - case c: IndirectCall => c.parent == parent && c.target == target && c.returnTarget == c.returnTarget && c.label == label - case o: Any => false - } - - override def hashCode(): Int = toString.hashCode } object IndirectCall: diff --git a/src/main/scala/translating/BAPToIR.scala b/src/main/scala/translating/BAPToIR.scala index ff25e2ec3..b27824d56 100644 --- a/src/main/scala/translating/BAPToIR.scala +++ b/src/main/scala/translating/BAPToIR.scala @@ -20,9 +20,12 @@ class BAPToIR(var program: BAPProgram, mainAddress: Int) { val procedures: ArrayBuffer[Procedure] = ArrayBuffer() for (s <- program.subroutines) { val procedure = Procedure(s.name, Some(s.address)) + val t = Block("terminate", None) + t.replaceJump(GoTo(Seq(t))) + procedure.addBlocks(t) for (b <- s.blocks) { - val block = Block(b.label, b.address, ArrayBuffer()) + val block = Block(b.label, b.address, ArrayBuffer(), GoTo(Seq(t))) procedure.addBlocks(block) labelToBlock.addOne(b.label, block) } @@ -50,6 +53,11 @@ class BAPToIR(var program: BAPProgram, mainAddress: Int) { procedure.addBlocks(newBlocks) block.replaceJump(jump) } + + // Set entry block to the block with the same address as the procedure or the first in sequence + procedure.entryBlock = procedure.blocks.find(b => b.address == procedure.address) + if procedure.entryBlock.isEmpty then procedure.entryBlock = procedure.blocks.headOption + } val memorySections: ArrayBuffer[MemorySection] = ArrayBuffer() diff --git a/src/main/scala/translating/IRToBoogie.scala b/src/main/scala/translating/IRToBoogie.scala index 7bc7c0109..94b945af9 100644 --- a/src/main/scala/translating/IRToBoogie.scala +++ b/src/main/scala/translating/IRToBoogie.scala @@ -323,9 +323,9 @@ class IRToBoogie(var program: Program, var spec: Specification) { def translateProcedure(p: Procedure, readOnlyMemory: List[BExpr]): BProcedure = { - val body = p.blocks.map(b => translateBlock(b)) + val body = p.blocks.map(translateBlock).toList - val callsRely: Boolean = body.flatten(_.body).exists(_ match + val callsRely: Boolean = body.flatMap(_.body).exists(_ match case BProcedureCall("rely", lhs, params, comment) => true case _ => false) @@ -387,7 +387,7 @@ class IRToBoogie(var program: Program, var spec: Specification) { def translateBlock(b: Block): BBlock = { val captureState = captureStateStatement(s"${b.label}") - val cmds = List(captureState) ++ (b.statements.flatMap(s => translate(s)) ++ translate(b.jump)) + val cmds = List(captureState) ++ b.statements.flatMap(s => translate(s)) ++ translate(b.jump) BBlock(b.label, cmds) } diff --git a/src/main/scala/util/RunUtils.scala b/src/main/scala/util/RunUtils.scala index 8784893b9..0ad3f9968 100644 --- a/src/main/scala/util/RunUtils.scala +++ b/src/main/scala/util/RunUtils.scala @@ -456,7 +456,7 @@ object RunUtils { for (t <- targets) { val assume = Assume(BinaryExpr(BVEQ, indirectCall.target, BitVecLiteral(t.address.get, 64)), null) val newLabel: String = block.label + t.name - val bl = Block(newLabel, None, ArrayBuffer(assume)).replaceJump(DirectCall(t, indirectCall.returnTarget, None)) + val bl = Block(newLabel, None, ArrayBuffer(assume), DirectCall(t, indirectCall.returnTarget, None)) //val directCall = DirectCall(t, indirectCall.returnTarget, null) newBlocks.append(bl) }