Skip to content

Commit

Permalink
Merge pull request #140 from UQ-PAC/indirect-calls-nondet
Browse files Browse the repository at this point in the history
Boogie-Style IR Control Flow
  • Loading branch information
l-kent authored Nov 14, 2023
2 parents ccb7669 + 818f24a commit 87b911e
Show file tree
Hide file tree
Showing 116 changed files with 2,822 additions and 1,402 deletions.
70 changes: 16 additions & 54 deletions src/main/scala/analysis/Cfg.scala
Original file line number Diff line number Diff line change
Expand Up @@ -689,10 +689,10 @@ class ProgramCfgFactory:
case i if i > 0 =>
// Block contains some statements
val endStmt: CfgCommandNode = visitStmts(block.statements, prevBlockEnd, cond)
visitJumps(block.jumps, endStmt, TrueLiteral, solitary = false)
visitJump(block.jump, endStmt, TrueLiteral, solitary = false)
case _ =>
// Only jumps in this block
visitJumps(block.jumps, prevBlockEnd, cond, solitary = true)
visitJump(block.jump, prevBlockEnd, cond, solitary = true)
}

/** If a block has statements, we add them to the CFG. Blocks in this case are basic blocks, so we know
Expand Down Expand Up @@ -744,9 +744,9 @@ class ProgramCfgFactory:
* @param solitary
* `True` if this block contains no statements, `False` otherwise
*/
def visitJumps(jmps: ArrayBuffer[Jump], prevNode: CfgNode, cond: Expr, solitary: Boolean): Unit = {
def visitJump(jmp: Jump, prevNode: CfgNode, cond: Expr, solitary: Boolean): Unit = {

val jmpNode: CfgJumpNode = CfgJumpNode(data = jmps.head, block = block, parent = funcEntryNode)
val jmpNode: CfgJumpNode = CfgJumpNode(data = jmp, block = block, parent = funcEntryNode)
var precNode: CfgNode = prevNode

if (solitary) {
Expand All @@ -758,7 +758,7 @@ class ProgramCfgFactory:
Currently we display these nodes in the DOT view of the CFG, however these could be hidden if desired.
*/
jmps.head match {
jmp match {
case jmp: GoTo =>
// `GoTo`s are just edges, so introduce a fake `start of block` that can be jmp'd to
val ghostNode = CfgGhostNode(block = block, parent = funcEntryNode, data = NOP(jmp.label))
Expand All @@ -773,40 +773,8 @@ class ProgramCfgFactory:

// TODO this is not a robust approach

jmps.head match {
case goto: GoTo =>
// Process first jump
var targetBlock: Block = goto.target
var targetCond: Expr = goto.condition match {
case Some(c) => c
case None => TrueLiteral
}

// Jump to target block
if (visitedBlocks.contains(targetBlock)) {
val targetBlockEntry: CfgCommandNode = visitedBlocks(targetBlock)
cfg.addEdge(precNode, targetBlockEntry, targetCond)
} else {
visitBlock(targetBlock, precNode, targetCond)
}

/* TODO it is not a safe assumption that there are a maximum of two jumps, or that a GoTo will follow a GoTo
*/
if (targetCond != TrueLiteral) {
val secondGoto: GoTo = jmps.tail.head.asInstanceOf[GoTo]
targetBlock = secondGoto.target
// IR doesn't store negation of condition, so we must do it manually
targetCond = negateConditional(targetCond)

// Jump to target block
if (visitedBlocks.contains(targetBlock)) {
val targetBlockEntry: CfgCommandNode = visitedBlocks(targetBlock)
cfg.addEdge(precNode, targetBlockEntry, targetCond)
} else {
visitBlock(targetBlock, precNode, targetCond)
}
}
case n: NonDetGoTo =>
jmp match {
case n: GoTo =>
for (targetBlock <- n.targets) {
if (visitedBlocks.contains(targetBlock)) {
val targetBlockEntry: CfgCommandNode = visitedBlocks(targetBlock)
Expand All @@ -818,16 +786,14 @@ class ProgramCfgFactory:
case dCall: DirectCall =>
val targetProc: Procedure = dCall.target

// Branch to this call
val calls = jmps.filter(_.isInstanceOf[DirectCall]).map(x => CfgJumpNode(data = x, block = block, parent = funcEntryNode))
val callNode = CfgJumpNode(data = dCall, block = block, parent = funcEntryNode)

calls.foreach(node => {
cfg.addEdge(precNode, node)
// Branch to this call
cfg.addEdge(precNode, callNode)

procToCalls(proc) += node
procToCallers(targetProc) += node
callToNodes(funcEntryNode) += node
})
procToCalls(proc) += callNode
procToCallers(targetProc) += callNode
callToNodes(funcEntryNode) += callNode

// Record call association

Expand All @@ -837,9 +803,7 @@ class ProgramCfgFactory:
// Add intermediary return node (split call into call and return)
val callRet = CfgCallReturnNode()

calls.foreach(node => {
cfg.addEdge(node, callRet)
})
cfg.addEdge(callNode, callRet)
if (visitedBlocks.contains(retBlock)) {
val retBlockEntry: CfgCommandNode = visitedBlocks(retBlock)
cfg.addEdge(callRet, retBlockEntry)
Expand All @@ -848,9 +812,7 @@ class ProgramCfgFactory:
}
case None =>
val noReturn = CfgCallNoReturnNode()
calls.foreach(node => {
cfg.addEdge(node, noReturn)
})
cfg.addEdge(callNode, noReturn)
cfg.addEdge(noReturn, funcExitNode)
}
case iCall: IndirectCall =>
Expand Down Expand Up @@ -889,7 +851,7 @@ class ProgramCfgFactory:
cfg.addEdge(jmpNode, noReturn)
cfg.addEdge(noReturn, funcExitNode)
}
case _ => assert(false, s"unexpected jump encountered, jumps: $jmps")
case _ => assert(false, s"unexpected jump encountered, jump: $jmp")
} // `jmps.head` match
} // `visitJumps` function
} // `visitBlocks` function
Expand Down
2 changes: 1 addition & 1 deletion src/main/scala/analysis/SteensgaardAnalysis.scala
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ class SteensgaardAnalysis(program: Program, constantPropResult: Map[CfgNode, Map

case block: Block =>
block.statements.foreach(visit(_, ()))
block.jumps.foreach(visit(_, ()))
visit(block.jump, ())

case _ => // ignore other kinds of nodes

Expand Down
15 changes: 9 additions & 6 deletions src/main/scala/bap/BAPStatement.scala
Original file line number Diff line number Diff line change
@@ -1,22 +1,25 @@
package bap

sealed trait BAPJump
sealed trait BAPJump {
val line: String
val instruction: String
}

case class BAPDirectCall(
target: String,
returnTarget: Option[String],
line: String,
instruction: String
override val line: String,
override val instruction: String
) extends BAPJump

case class BAPIndirectCall(
target: BAPVar,
returnTarget: Option[String],
line: String,
instruction: String
override val line: String,
override val instruction: String
) extends BAPJump

case class BAPGoTo(target: String, condition: BAPExpr, line: String, instruction: String) extends BAPJump
case class BAPGoTo(target: String, condition: BAPExpr, override val line: String, override val instruction: String) extends BAPJump

sealed trait BAPStatement

Expand Down
72 changes: 38 additions & 34 deletions src/main/scala/ir/Interpreter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -188,44 +188,44 @@ class Interpreter() {

// Block.Jump
breakable {
for ((jump, index) <- b.jumps.zipWithIndex) {
Logger.debug(s"jump[$index]:")
jump match {
case gt: GoTo =>
Logger.debug(s"$gt")
gt.condition match {
case Some(value) =>
eval(value, regs) match {
case TrueLiteral =>
nextBlock = Some(gt.target)
break
case FalseLiteral =>
}
Logger.debug(s"jump:")
b.jump match {
case gt: GoTo =>
Logger.debug(s"$gt")
for (g <- gt.targets) {
val condition: Option[Expr] = g.statements.headOption.collect { case a: Assume => a.body }
condition match {
case Some(e) => eval(e, regs) match {
case TrueLiteral =>
nextBlock = Some(g)
break
case _ =>
}
case None =>
nextBlock = Some(gt.target)
nextBlock = Some(g)
break
}
case dc: DirectCall =>
Logger.debug(s"$dc")
if (dc.returnTarget.isDefined) {
returnBlock.push(dc.returnTarget.get)
}
interpretProcedure(dc.target)
break
case ic: IndirectCall =>
Logger.debug(s"$ic")
if (ic.target == Register("R30", BitVecType(64)) && ic.returnTarget.isEmpty) {
if (returnBlock.nonEmpty) {
nextBlock = Some(returnBlock.pop())
} else {
//Exit Interpreter
nextBlock = None
}
break
}
case dc: DirectCall =>
Logger.debug(s"$dc")
if (dc.returnTarget.isDefined) {
returnBlock.push(dc.returnTarget.get)
}
interpretProcedure(dc.target)
break
case ic: IndirectCall =>
Logger.debug(s"$ic")
if (ic.target == Register("R30", BitVecType(64)) && ic.returnTarget.isEmpty) {
if (returnBlock.nonEmpty) {
nextBlock = Some(returnBlock.pop())
} else {
???
//Exit Interpreter
nextBlock = None
}
}
break
} else {
???
}
}
}
}
Expand Down Expand Up @@ -253,7 +253,11 @@ class Interpreter() {

case assert: Assert =>
Logger.debug(assert)
???
// TODO

case assume: Assume =>
Logger.debug(assume)
// TODO, but already taken into effect if it is a branch condition
}
}

Expand Down
20 changes: 8 additions & 12 deletions src/main/scala/ir/Program.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import analysis.BitVectorEval

class Program(var procedures: ArrayBuffer[Procedure], var mainProcedure: Procedure, var initialMemory: ArrayBuffer[MemorySection], var readOnlyMemory: ArrayBuffer[MemorySection]) {

// This shouldn't be run before indirect calls are resolved?
// This shouldn't be run before indirect calls are resolved
def stripUnreachableFunctions(): Unit = {
val functionToChildren = procedures.map(f => f.name -> f.calls.map(_.name)).toMap

Expand Down Expand Up @@ -170,13 +170,10 @@ class Procedure(
}
}
visitedBlocks.add(b)
for (j <- b.jumps) {
j match {
case g: GoTo => visitBlock(g.target)
case d: DirectCall => d.returnTarget.foreach(visitBlock)
case i: IndirectCall => i.returnTarget.foreach(visitBlock)
case n: NonDetGoTo => n.targets.foreach(visitBlock)
}
b.jump match {
case g: GoTo => g.targets.foreach(visitBlock)
case d: DirectCall => d.returnTarget.foreach(visitBlock)
case i: IndirectCall => i.returnTarget.foreach(visitBlock)
}
}
}
Expand All @@ -187,17 +184,16 @@ class Block(
var label: String,
var address: Option[Int],
var statements: ArrayBuffer[Statement],
var jumps: ArrayBuffer[Jump]
var jump: Jump
) {
def calls: Set[Procedure] = jumps.flatMap(_.calls).toSet
def calls: Set[Procedure] = jump.calls
def modifies: Set[Global] = statements.flatMap(_.modifies).toSet
//def locals: Set[Variable] = statements.flatMap(_.locals).toSet ++ jumps.flatMap(_.locals).toSet

override def toString: String = {
// display all statements and jumps
val statementsString = statements.map(_.toString).mkString("\n")
val jumpsString = jumps.map(_.toString).mkString("\n")
s"Block $label with $statementsString\n$jumpsString"
s"Block $label with $statementsString\n$jump"
}

override def equals(obj: scala.Any): Boolean =
Expand Down
20 changes: 10 additions & 10 deletions src/main/scala/ir/Statement.scala
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package ir

import scala.collection.mutable.ArrayBuffer

trait Command {
val label: Option[String]
def labelStr: String = label match {
Expand Down Expand Up @@ -52,13 +54,16 @@ class Assert(var body: Expr, var comment: Option[String] = None, override val la
object Assert:
def unapply(a: Assert): Option[(Expr, Option[String], Option[String])] = Some(a.body, a.comment, a.label)

class Assume(var body: Expr, var comment: Option[String] = None, override val label: Option[String] = None) extends Statement {
/**
* checkSecurity is true if this is a branch condition that we want to assert has a security level of low before branching
* */
class Assume(var body: Expr, var comment: Option[String] = None, override val label: Option[String] = None, var checkSecurity: Boolean = false) extends Statement {
override def toString: String = s"${labelStr}assume $body" + comment.map(" //" + _)
override def acceptVisit(visitor: Visitor): Statement = visitor.visitAssume(this)
}

object Assume:
def unapply(a: Assume): Option[(Expr, Option[String], Option[String])] = Some(a.body, a.comment, a.label)
def unapply(a: Assume): Option[(Expr, Option[String], Option[String], Boolean)] = Some(a.body, a.comment, a.label, a.checkSecurity)

trait Jump extends Command {
def modifies: Set[Global] = Set()
Expand All @@ -67,23 +72,18 @@ trait Jump extends Command {
def acceptVisit(visitor: Visitor): Jump = throw new Exception("visitor " + visitor + " unimplemented for: " + this)
}

class GoTo(var target: Block, var condition: Option[Expr], override val label: Option[String] = None) extends Jump {
class GoTo(var targets: ArrayBuffer[Block], override val label: Option[String] = None) extends Jump {
/* override def locals: Set[Variable] = condition match {
case Some(c) => c.locals
case None => Set()
} */
override def toString: String = s"${labelStr}GoTo(${target.label}, $condition)"
override def toString: String = s"${labelStr}GoTo(${targets.map(_.label).mkString(", ")})"

override def acceptVisit(visitor: Visitor): Jump = visitor.visitGoTo(this)
}

object GoTo:
def unapply(g: GoTo): Option[(Block, Option[Expr], Option[String])] = Some(g.target, g.condition, g.label)

class NonDetGoTo(var targets: Seq[Block], override val label: Option[String] = None) extends Jump {
override def toString: String = s"${labelStr}NonDetGoTo(${targets.map(_.label).mkString(", ")})"
override def acceptVisit(visitor: Visitor): Jump = visitor.visitNonDetGoTo(this)
}
def unapply(g: GoTo): Option[(ArrayBuffer[Block], Option[String])] = Some(g.targets, g.label)

class DirectCall(var target: Procedure, var returnTarget: Option[Block], override val label: Option[String] = None) extends Jump {
/* override def locals: Set[Variable] = condition match {
Expand Down
13 changes: 2 additions & 11 deletions src/main/scala/ir/Visitor.scala
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,6 @@ abstract class Visitor {
def visitJump(node: Jump): Jump = node.acceptVisit(this)

def visitGoTo(node: GoTo): Jump = {
node.condition = node.condition.map(visitExpr)
node
}

def visitNonDetGoTo(node: NonDetGoTo): Jump = {
node
}

Expand All @@ -55,9 +50,7 @@ abstract class Visitor {
for (i <- node.statements.indices) {
node.statements(i) = visitStatement(node.statements(i))
}
for (i <- node.jumps.indices) {
node.jumps(i) = visitJump(node.jumps(i))
}
node.jump = visitJump(node.jump)
node
}

Expand Down Expand Up @@ -231,9 +224,7 @@ abstract class ReadOnlyVisitor extends Visitor {
for (i <- node.statements) {
visitStatement(i)
}
for (i <- node.jumps) {
visitJump(i)
}
visitJump(node.jump)
node
}

Expand Down
Loading

0 comments on commit 87b911e

Please sign in to comment.