Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Boogie-Style IR Control Flow #140

Merged
merged 6 commits into from
Nov 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 16 additions & 54 deletions src/main/scala/analysis/Cfg.scala
Original file line number Diff line number Diff line change
Expand Up @@ -689,10 +689,10 @@ class ProgramCfgFactory:
case i if i > 0 =>
// Block contains some statements
val endStmt: CfgCommandNode = visitStmts(block.statements, prevBlockEnd, cond)
visitJumps(block.jumps, endStmt, TrueLiteral, solitary = false)
visitJump(block.jump, endStmt, TrueLiteral, solitary = false)
case _ =>
// Only jumps in this block
visitJumps(block.jumps, prevBlockEnd, cond, solitary = true)
visitJump(block.jump, prevBlockEnd, cond, solitary = true)
}

/** If a block has statements, we add them to the CFG. Blocks in this case are basic blocks, so we know
Expand Down Expand Up @@ -744,9 +744,9 @@ class ProgramCfgFactory:
* @param solitary
* `True` if this block contains no statements, `False` otherwise
*/
def visitJumps(jmps: ArrayBuffer[Jump], prevNode: CfgNode, cond: Expr, solitary: Boolean): Unit = {
def visitJump(jmp: Jump, prevNode: CfgNode, cond: Expr, solitary: Boolean): Unit = {

val jmpNode: CfgJumpNode = CfgJumpNode(data = jmps.head, block = block, parent = funcEntryNode)
val jmpNode: CfgJumpNode = CfgJumpNode(data = jmp, block = block, parent = funcEntryNode)
var precNode: CfgNode = prevNode

if (solitary) {
Expand All @@ -758,7 +758,7 @@ class ProgramCfgFactory:

Currently we display these nodes in the DOT view of the CFG, however these could be hidden if desired.
*/
jmps.head match {
jmp match {
case jmp: GoTo =>
// `GoTo`s are just edges, so introduce a fake `start of block` that can be jmp'd to
val ghostNode = CfgGhostNode(block = block, parent = funcEntryNode, data = NOP(jmp.label))
Expand All @@ -773,40 +773,8 @@ class ProgramCfgFactory:

// TODO this is not a robust approach

jmps.head match {
case goto: GoTo =>
// Process first jump
var targetBlock: Block = goto.target
var targetCond: Expr = goto.condition match {
case Some(c) => c
case None => TrueLiteral
}

// Jump to target block
if (visitedBlocks.contains(targetBlock)) {
val targetBlockEntry: CfgCommandNode = visitedBlocks(targetBlock)
cfg.addEdge(precNode, targetBlockEntry, targetCond)
} else {
visitBlock(targetBlock, precNode, targetCond)
}

/* TODO it is not a safe assumption that there are a maximum of two jumps, or that a GoTo will follow a GoTo
*/
if (targetCond != TrueLiteral) {
val secondGoto: GoTo = jmps.tail.head.asInstanceOf[GoTo]
targetBlock = secondGoto.target
// IR doesn't store negation of condition, so we must do it manually
targetCond = negateConditional(targetCond)

// Jump to target block
if (visitedBlocks.contains(targetBlock)) {
val targetBlockEntry: CfgCommandNode = visitedBlocks(targetBlock)
cfg.addEdge(precNode, targetBlockEntry, targetCond)
} else {
visitBlock(targetBlock, precNode, targetCond)
}
}
case n: NonDetGoTo =>
jmp match {
case n: GoTo =>
for (targetBlock <- n.targets) {
if (visitedBlocks.contains(targetBlock)) {
val targetBlockEntry: CfgCommandNode = visitedBlocks(targetBlock)
Expand All @@ -818,16 +786,14 @@ class ProgramCfgFactory:
case dCall: DirectCall =>
val targetProc: Procedure = dCall.target

// Branch to this call
val calls = jmps.filter(_.isInstanceOf[DirectCall]).map(x => CfgJumpNode(data = x, block = block, parent = funcEntryNode))
val callNode = CfgJumpNode(data = dCall, block = block, parent = funcEntryNode)

calls.foreach(node => {
cfg.addEdge(precNode, node)
// Branch to this call
cfg.addEdge(precNode, callNode)

procToCalls(proc) += node
procToCallers(targetProc) += node
callToNodes(funcEntryNode) += node
})
procToCalls(proc) += callNode
procToCallers(targetProc) += callNode
callToNodes(funcEntryNode) += callNode

// Record call association

Expand All @@ -837,9 +803,7 @@ class ProgramCfgFactory:
// Add intermediary return node (split call into call and return)
val callRet = CfgCallReturnNode()

calls.foreach(node => {
cfg.addEdge(node, callRet)
})
cfg.addEdge(callNode, callRet)
if (visitedBlocks.contains(retBlock)) {
val retBlockEntry: CfgCommandNode = visitedBlocks(retBlock)
cfg.addEdge(callRet, retBlockEntry)
Expand All @@ -848,9 +812,7 @@ class ProgramCfgFactory:
}
case None =>
val noReturn = CfgCallNoReturnNode()
calls.foreach(node => {
cfg.addEdge(node, noReturn)
})
cfg.addEdge(callNode, noReturn)
cfg.addEdge(noReturn, funcExitNode)
}
case iCall: IndirectCall =>
Expand Down Expand Up @@ -889,7 +851,7 @@ class ProgramCfgFactory:
cfg.addEdge(jmpNode, noReturn)
cfg.addEdge(noReturn, funcExitNode)
}
case _ => assert(false, s"unexpected jump encountered, jumps: $jmps")
case _ => assert(false, s"unexpected jump encountered, jump: $jmp")
} // `jmps.head` match
} // `visitJumps` function
} // `visitBlocks` function
Expand Down
2 changes: 1 addition & 1 deletion src/main/scala/analysis/SteensgaardAnalysis.scala
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ class SteensgaardAnalysis(program: Program, constantPropResult: Map[CfgNode, Map

case block: Block =>
block.statements.foreach(visit(_, ()))
block.jumps.foreach(visit(_, ()))
visit(block.jump, ())

case _ => // ignore other kinds of nodes

Expand Down
15 changes: 9 additions & 6 deletions src/main/scala/bap/BAPStatement.scala
Original file line number Diff line number Diff line change
@@ -1,22 +1,25 @@
package bap

sealed trait BAPJump
sealed trait BAPJump {
val line: String
val instruction: String
}

case class BAPDirectCall(
target: String,
returnTarget: Option[String],
line: String,
instruction: String
override val line: String,
override val instruction: String
) extends BAPJump

case class BAPIndirectCall(
target: BAPVar,
returnTarget: Option[String],
line: String,
instruction: String
override val line: String,
override val instruction: String
) extends BAPJump

case class BAPGoTo(target: String, condition: BAPExpr, line: String, instruction: String) extends BAPJump
case class BAPGoTo(target: String, condition: BAPExpr, override val line: String, override val instruction: String) extends BAPJump

sealed trait BAPStatement

Expand Down
72 changes: 38 additions & 34 deletions src/main/scala/ir/Interpreter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -188,44 +188,44 @@ class Interpreter() {

// Block.Jump
breakable {
for ((jump, index) <- b.jumps.zipWithIndex) {
Logger.debug(s"jump[$index]:")
jump match {
case gt: GoTo =>
Logger.debug(s"$gt")
gt.condition match {
case Some(value) =>
eval(value, regs) match {
case TrueLiteral =>
nextBlock = Some(gt.target)
break
case FalseLiteral =>
}
Logger.debug(s"jump:")
b.jump match {
case gt: GoTo =>
Logger.debug(s"$gt")
for (g <- gt.targets) {
val condition: Option[Expr] = g.statements.headOption.collect { case a: Assume => a.body }
condition match {
case Some(e) => eval(e, regs) match {
case TrueLiteral =>
nextBlock = Some(g)
break
case _ =>
}
case None =>
nextBlock = Some(gt.target)
nextBlock = Some(g)
break
}
case dc: DirectCall =>
Logger.debug(s"$dc")
if (dc.returnTarget.isDefined) {
returnBlock.push(dc.returnTarget.get)
}
interpretProcedure(dc.target)
break
case ic: IndirectCall =>
Logger.debug(s"$ic")
if (ic.target == Register("R30", BitVecType(64)) && ic.returnTarget.isEmpty) {
if (returnBlock.nonEmpty) {
nextBlock = Some(returnBlock.pop())
} else {
//Exit Interpreter
nextBlock = None
}
break
}
case dc: DirectCall =>
Logger.debug(s"$dc")
if (dc.returnTarget.isDefined) {
returnBlock.push(dc.returnTarget.get)
}
interpretProcedure(dc.target)
break
case ic: IndirectCall =>
Logger.debug(s"$ic")
if (ic.target == Register("R30", BitVecType(64)) && ic.returnTarget.isEmpty) {
if (returnBlock.nonEmpty) {
nextBlock = Some(returnBlock.pop())
} else {
???
//Exit Interpreter
nextBlock = None
}
}
break
} else {
???
}
}
}
}
Expand Down Expand Up @@ -253,7 +253,11 @@ class Interpreter() {

case assert: Assert =>
Logger.debug(assert)
???
// TODO

case assume: Assume =>
Logger.debug(assume)
// TODO, but already taken into effect if it is a branch condition
}
}

Expand Down
20 changes: 8 additions & 12 deletions src/main/scala/ir/Program.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import analysis.BitVectorEval

class Program(var procedures: ArrayBuffer[Procedure], var mainProcedure: Procedure, var initialMemory: ArrayBuffer[MemorySection], var readOnlyMemory: ArrayBuffer[MemorySection]) {

// This shouldn't be run before indirect calls are resolved?
// This shouldn't be run before indirect calls are resolved
def stripUnreachableFunctions(): Unit = {
val functionToChildren = procedures.map(f => f.name -> f.calls.map(_.name)).toMap

Expand Down Expand Up @@ -170,13 +170,10 @@ class Procedure(
}
}
visitedBlocks.add(b)
for (j <- b.jumps) {
j match {
case g: GoTo => visitBlock(g.target)
case d: DirectCall => d.returnTarget.foreach(visitBlock)
case i: IndirectCall => i.returnTarget.foreach(visitBlock)
case n: NonDetGoTo => n.targets.foreach(visitBlock)
}
b.jump match {
case g: GoTo => g.targets.foreach(visitBlock)
case d: DirectCall => d.returnTarget.foreach(visitBlock)
case i: IndirectCall => i.returnTarget.foreach(visitBlock)
}
}
}
Expand All @@ -187,17 +184,16 @@ class Block(
var label: String,
var address: Option[Int],
var statements: ArrayBuffer[Statement],
var jumps: ArrayBuffer[Jump]
var jump: Jump
) {
def calls: Set[Procedure] = jumps.flatMap(_.calls).toSet
def calls: Set[Procedure] = jump.calls
def modifies: Set[Global] = statements.flatMap(_.modifies).toSet
//def locals: Set[Variable] = statements.flatMap(_.locals).toSet ++ jumps.flatMap(_.locals).toSet

override def toString: String = {
// display all statements and jumps
val statementsString = statements.map(_.toString).mkString("\n")
val jumpsString = jumps.map(_.toString).mkString("\n")
s"Block $label with $statementsString\n$jumpsString"
s"Block $label with $statementsString\n$jump"
}

override def equals(obj: scala.Any): Boolean =
Expand Down
20 changes: 10 additions & 10 deletions src/main/scala/ir/Statement.scala
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package ir

import scala.collection.mutable.ArrayBuffer

trait Command {
val label: Option[String]
def labelStr: String = label match {
Expand Down Expand Up @@ -52,13 +54,16 @@ class Assert(var body: Expr, var comment: Option[String] = None, override val la
object Assert:
def unapply(a: Assert): Option[(Expr, Option[String], Option[String])] = Some(a.body, a.comment, a.label)

class Assume(var body: Expr, var comment: Option[String] = None, override val label: Option[String] = None) extends Statement {
/**
* checkSecurity is true if this is a branch condition that we want to assert has a security level of low before branching
* */
class Assume(var body: Expr, var comment: Option[String] = None, override val label: Option[String] = None, var checkSecurity: Boolean = false) extends Statement {
override def toString: String = s"${labelStr}assume $body" + comment.map(" //" + _)
override def acceptVisit(visitor: Visitor): Statement = visitor.visitAssume(this)
}

object Assume:
def unapply(a: Assume): Option[(Expr, Option[String], Option[String])] = Some(a.body, a.comment, a.label)
def unapply(a: Assume): Option[(Expr, Option[String], Option[String], Boolean)] = Some(a.body, a.comment, a.label, a.checkSecurity)

trait Jump extends Command {
def modifies: Set[Global] = Set()
Expand All @@ -67,23 +72,18 @@ trait Jump extends Command {
def acceptVisit(visitor: Visitor): Jump = throw new Exception("visitor " + visitor + " unimplemented for: " + this)
}

class GoTo(var target: Block, var condition: Option[Expr], override val label: Option[String] = None) extends Jump {
class GoTo(var targets: ArrayBuffer[Block], override val label: Option[String] = None) extends Jump {
/* override def locals: Set[Variable] = condition match {
case Some(c) => c.locals
case None => Set()
} */
override def toString: String = s"${labelStr}GoTo(${target.label}, $condition)"
override def toString: String = s"${labelStr}GoTo(${targets.map(_.label).mkString(", ")})"

override def acceptVisit(visitor: Visitor): Jump = visitor.visitGoTo(this)
}

object GoTo:
def unapply(g: GoTo): Option[(Block, Option[Expr], Option[String])] = Some(g.target, g.condition, g.label)

class NonDetGoTo(var targets: Seq[Block], override val label: Option[String] = None) extends Jump {
override def toString: String = s"${labelStr}NonDetGoTo(${targets.map(_.label).mkString(", ")})"
override def acceptVisit(visitor: Visitor): Jump = visitor.visitNonDetGoTo(this)
}
def unapply(g: GoTo): Option[(ArrayBuffer[Block], Option[String])] = Some(g.targets, g.label)

class DirectCall(var target: Procedure, var returnTarget: Option[Block], override val label: Option[String] = None) extends Jump {
/* override def locals: Set[Variable] = condition match {
Expand Down
13 changes: 2 additions & 11 deletions src/main/scala/ir/Visitor.scala
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,6 @@ abstract class Visitor {
def visitJump(node: Jump): Jump = node.acceptVisit(this)

def visitGoTo(node: GoTo): Jump = {
node.condition = node.condition.map(visitExpr)
node
}

def visitNonDetGoTo(node: NonDetGoTo): Jump = {
node
}

Expand All @@ -55,9 +50,7 @@ abstract class Visitor {
for (i <- node.statements.indices) {
node.statements(i) = visitStatement(node.statements(i))
}
for (i <- node.jumps.indices) {
node.jumps(i) = visitJump(node.jumps(i))
}
node.jump = visitJump(node.jump)
node
}

Expand Down Expand Up @@ -231,9 +224,7 @@ abstract class ReadOnlyVisitor extends Visitor {
for (i <- node.statements) {
visitStatement(i)
}
for (i <- node.jumps) {
visitJump(i)
}
visitJump(node.jump)
node
}

Expand Down
Loading