From 182c9e8a82e963cd5bf3e1ca6b72f11b3d5df12c Mon Sep 17 00:00:00 2001 From: Alistair Michael Date: Mon, 19 Aug 2024 11:55:23 +1000 Subject: [PATCH 01/62] refactor call to a statement & add unreachable and return jumps --- src/main/scala/analysis/Cfg.scala | 137 ++++++++++-------- src/main/scala/analysis/IDEAnalysis.scala | 4 +- .../analysis/InterLiveVarsAnalysis.scala | 6 +- .../analysis/IntraLiveVarsAnalysis.scala | 4 +- src/main/scala/analysis/VSA.scala | 10 +- .../scala/analysis/solvers/IDESolver.scala | 36 ++--- src/main/scala/ir/IRCursor.scala | 63 ++++---- src/main/scala/ir/Interpreter.scala | 99 +++++++------ src/main/scala/ir/Program.scala | 74 ++++------ src/main/scala/ir/Statement.scala | 42 +++--- src/main/scala/ir/Visitor.scala | 31 +--- src/main/scala/ir/cilvisitor/CILVisitor.scala | 9 +- src/main/scala/ir/dsl/DSL.scala | 29 ++-- .../scala/ir/transforms/ReplaceReturn.scala | 52 +++++++ src/main/scala/translating/BAPToIR.scala | 21 +-- src/main/scala/translating/GTIRBToIR.scala | 54 ++++--- src/main/scala/translating/ILtoIL.scala | 4 +- src/main/scala/translating/IRToBoogie.scala | 55 +++---- src/main/scala/util/RunUtils.scala | 98 +++++++------ 19 files changed, 441 insertions(+), 387 deletions(-) create mode 100644 src/main/scala/ir/transforms/ReplaceReturn.scala diff --git a/src/main/scala/analysis/Cfg.scala b/src/main/scala/analysis/Cfg.scala index d508353c1..59d512cd2 100644 --- a/src/main/scala/analysis/Cfg.scala +++ b/src/main/scala/analysis/Cfg.scala @@ -174,7 +174,7 @@ class CfgStatementNode( /** CFG's representation of a jump. This is used as a general jump node, for both indirect and direct calls. */ class CfgJumpNode( - val data: Jump, + val data: Jump | DirectCall | IndirectCall, val block: Block, val parent: CfgFunctionEntryNode ) extends CfgCommandNode: @@ -486,6 +486,72 @@ class ProgramCfgFactory: cfg.addEdge(prevNode, firstNode) visitedBlocks += (block -> firstNode) // This is guaranteed to be entrance to block if we are here + val statements = List.from(stmts).map(s => s match { + case d: DirectCall => CfgJumpNode(d, block, funcEntryNode) + case d: IndirectCall => CfgJumpNode(d, block, funcEntryNode) + case o => CfgStatementNode(o, block, funcEntryNode) + }) + val succs = if (statements.nonEmpty) then statements.zip(statements.tail ++ List(CfgJumpNode(statements.head.data.parent.jump, block, funcEntryNode))) else List() + + for ((s,nexts) <- succs) { + s.data match { + case dCall: DirectCall => + + var precNode = prevNode + + val targetProc: Procedure = dCall.target + funcEntryNode.callers.add(procToCfg(targetProc)._1) + + val callNode = CfgJumpNode(dCall, block, funcEntryNode) + + // Branch to this call + cfg.addEdge(precNode, callNode) + + procToCalls(proc) += callNode + procToCallers(targetProc) += callNode + callToNodes(funcEntryNode) += callNode + + // Record call association + + // Jump to return location + val returnTarget = nexts + // Add intermediary return node (split call into call and return) + val callRet = CfgCallReturnNode() + cfg.addEdge(callNode, callRet) + cfg.addEdge(callRet, returnTarget) + case iCall: IndirectCall => + Logger.debug(s"Indirect call found: $iCall in ${proc.name}") + var precNode = prevNode + + val jmpNode = CfgJumpNode(iCall, block, funcEntryNode) + // Branch to this call + cfg.addEdge(precNode, jmpNode) + + // Record call association + procToCalls(proc) += jmpNode + callToNodes(funcEntryNode) += jmpNode + + // R30 is the link register - this stores the address to return to. + // For now just add a node expressing that we are to return to the previous context. + if (iCall.target == Register("R30", 64)) { + val returnNode = CfgProcedureReturnNode() + cfg.addEdge(jmpNode, returnNode) + cfg.addEdge(returnNode, funcExitNode) + } + + val callRet = CfgCallReturnNode() + cfg.addEdge(jmpNode, callRet) + val returnTarget = nexts + cfg.addEdge(callRet, jmpNode) + case h: Halt => { + assert(false); + // not possible since s is only Statement. + } + case _ => () + } + } + + if (stmts.size == 1) { return firstNode } @@ -548,42 +614,10 @@ class ProgramCfgFactory: visitBlock(targetBlock, precNode) } } - case dCall: DirectCall => - val targetProc: Procedure = dCall.target - funcEntryNode.callers.add(procToCfg(targetProc)._1) - - val callNode = CfgJumpNode(dCall, block, funcEntryNode) - - // Branch to this call - cfg.addEdge(precNode, callNode) - - procToCalls(proc) += callNode - procToCallers(targetProc) += callNode - callToNodes(funcEntryNode) += callNode - - // Record call association - - // Jump to return location - dCall.returnTarget match { - case Some(retBlock) => - // Add intermediary return node (split call into call and return) - val callRet = CfgCallReturnNode() - - cfg.addEdge(callNode, callRet) - if (visitedBlocks.contains(retBlock)) { - val retBlockEntry: CfgCommandNode = visitedBlocks(retBlock) - cfg.addEdge(callRet, retBlockEntry) - } else { - visitBlock(retBlock, callRet) - } - case None => - val noReturn = CfgCallNoReturnNode() - cfg.addEdge(callNode, noReturn) - cfg.addEdge(noReturn, funcExitNode) - } - case iCall: IndirectCall => - Logger.debug(s"Indirect call found: $iCall in ${proc.name}") - + case h: Halt => { + cfg.addEdge(jmpNode, funcExitNode) + } + case r: Return => // Branch to this call cfg.addEdge(precNode, jmpNode) @@ -591,32 +625,9 @@ class ProgramCfgFactory: procToCalls(proc) += jmpNode callToNodes(funcEntryNode) += jmpNode - // R30 is the link register - this stores the address to return to. - // For now just add a node expressing that we are to return to the previous context. - if (iCall.target == Register("R30", 64)) { - val returnNode = CfgProcedureReturnNode() - cfg.addEdge(jmpNode, returnNode) - cfg.addEdge(returnNode, funcExitNode) - return - } - - // Jump to return location - iCall.returnTarget match { - case Some(retBlock) => // Add intermediary return node (split call into call and return) - val callRet = CfgCallReturnNode() - cfg.addEdge(jmpNode, callRet) - - if (visitedBlocks.contains(retBlock)) { - val retBlockEntry = visitedBlocks(retBlock) - cfg.addEdge(callRet, retBlockEntry) - } else { - visitBlock(retBlock, callRet) - } - case None => - val noReturn = CfgCallNoReturnNode() - cfg.addEdge(jmpNode, noReturn) - cfg.addEdge(noReturn, funcExitNode) - } + val returnNode = CfgProcedureReturnNode() + cfg.addEdge(jmpNode, returnNode) + cfg.addEdge(returnNode, funcExitNode) } // `jmps.head` match } // `visitJumps` function } // `visitBlocks` function diff --git a/src/main/scala/analysis/IDEAnalysis.scala b/src/main/scala/analysis/IDEAnalysis.scala index 8f2a7001c..e77627717 100644 --- a/src/main/scala/analysis/IDEAnalysis.scala +++ b/src/main/scala/analysis/IDEAnalysis.scala @@ -55,6 +55,6 @@ trait IDEAnalysis[E, EE, C, R, D, T, L <: Lattice[T]] { } // IndirectCall in these is because they are returns so that can be further tightened in future -trait ForwardIDEAnalysis[D, T, L <: Lattice[T]] extends IDEAnalysis[Procedure, IndirectCall, DirectCall, GoTo, D, T, L] +trait ForwardIDEAnalysis[D, T, L <: Lattice[T]] extends IDEAnalysis[Procedure, IndirectCall, DirectCall, Command, D, T, L] -trait BackwardIDEAnalysis[D, T, L <: Lattice[T]] extends IDEAnalysis[IndirectCall, Procedure, GoTo, DirectCall, D, T, L] +trait BackwardIDEAnalysis[D, T, L <: Lattice[T]] extends IDEAnalysis[IndirectCall, Procedure, Command, DirectCall, D, T, L] diff --git a/src/main/scala/analysis/InterLiveVarsAnalysis.scala b/src/main/scala/analysis/InterLiveVarsAnalysis.scala index 93a34d076..b20edf478 100644 --- a/src/main/scala/analysis/InterLiveVarsAnalysis.scala +++ b/src/main/scala/analysis/InterLiveVarsAnalysis.scala @@ -19,7 +19,7 @@ trait LiveVarsAnalysisFunctions extends BackwardIDEAnalysis[Variable, TwoElement val edgelattice: EdgeFunctionLattice[TwoElement, TwoElementLattice] = EdgeFunctionLattice(valuelattice) import edgelattice.{IdEdge, ConstEdge} - def edgesCallToEntry(call: GoTo, entry: IndirectCall)(d: DL): Map[DL, EdgeFunction[TwoElement]] = { + def edgesCallToEntry(call: Command, entry: IndirectCall)(d: DL): Map[DL, EdgeFunction[TwoElement]] = { Map(d -> IdEdge()) } @@ -27,7 +27,7 @@ trait LiveVarsAnalysisFunctions extends BackwardIDEAnalysis[Variable, TwoElement Map(d -> IdEdge()) } - def edgesCallToAfterCall(call: GoTo, aftercall: DirectCall)(d: DL): Map[DL, EdgeFunction[TwoElement]] = { + def edgesCallToAfterCall(call: Command, aftercall: DirectCall)(d: DL): Map[DL, EdgeFunction[TwoElement]] = { d match case Left(value) => Map() // maps all variables before the call to bottom case Right(_) => Map(d -> IdEdge()) @@ -70,7 +70,7 @@ trait LiveVarsAnalysisFunctions extends BackwardIDEAnalysis[Variable, TwoElement expr.variables.foldLeft(Map[DL, EdgeFunction[TwoElement]](d -> IdEdge())) { (mp, expVar) => mp + (Left(expVar) -> ConstEdge(TwoElementTop)) } - case IndirectCall(variable, _, _) => + case IndirectCall(variable, _) => d match case Left(value) => if value != variable then Map(d -> IdEdge()) else Map() case Right(_) => Map(d -> IdEdge(), Left(variable) -> ConstEdge(TwoElementTop)) diff --git a/src/main/scala/analysis/IntraLiveVarsAnalysis.scala b/src/main/scala/analysis/IntraLiveVarsAnalysis.scala index 1624dcdfb..f3f322321 100644 --- a/src/main/scala/analysis/IntraLiveVarsAnalysis.scala +++ b/src/main/scala/analysis/IntraLiveVarsAnalysis.scala @@ -15,7 +15,7 @@ abstract class LivenessAnalysis(program: Program) extends Analysis[Any]: case MemoryAssign(_, index, value, _, _, _) => s ++ index.variables ++ value.variables case Assume(expr, _, _, _) => s ++ expr.variables case Assert(expr, _, _) => s ++ expr.variables - case IndirectCall(variable, _, _) => s + variable + case IndirectCall(variable, _) => s + variable case c: DirectCall => s case g: GoTo => s case _ => ??? @@ -25,4 +25,4 @@ abstract class LivenessAnalysis(program: Program) extends Analysis[Any]: class IntraLiveVarsAnalysis(program: Program) extends LivenessAnalysis(program) with SimpleWorklistFixpointSolver[CFGPosition, Set[Variable], PowersetLattice[Variable]] - with IRIntraproceduralBackwardDependencies \ No newline at end of file + with IRIntraproceduralBackwardDependencies diff --git a/src/main/scala/analysis/VSA.scala b/src/main/scala/analysis/VSA.scala index 03ef8ff60..f7d9a55e7 100644 --- a/src/main/scala/analysis/VSA.scala +++ b/src/main/scala/analysis/VSA.scala @@ -121,7 +121,7 @@ trait ValueSetAnalysis(program: Program, m = m + (localAssign.lhs -> m(r)) m case None => - Logger.warn("could not find region for " + localAssign) + Logger.debug("could not find region for " + localAssign) m case e: Expr => evaluateExpression(e, constantProp(n)) match { @@ -129,7 +129,7 @@ trait ValueSetAnalysis(program: Program, m = m + (localAssign.lhs -> Set(getValueType(bv))) m case None => - Logger.warn("could not evaluate expression" + e) + Logger.debug("could not evaluate expression" + e) m } case memAssign: MemoryAssign => @@ -154,11 +154,11 @@ trait ValueSetAnalysis(program: Program, m = m + (r -> m(v)) m case _ => - Logger.warn(s"Too Complex: $storeValue") // do nothing + Logger.debug(s"Too Complex: $storeValue") // do nothing m } case None => - Logger.warn("could not find region for " + memAssign) + Logger.debug("could not find region for " + memAssign) m case _ => m @@ -207,4 +207,4 @@ class ValueSetAnalysisSolver( case _ => super.funsub(n, x) } } -} \ No newline at end of file +} diff --git a/src/main/scala/analysis/solvers/IDESolver.scala b/src/main/scala/analysis/solvers/IDESolver.scala index 5773ba047..0dd981b30 100644 --- a/src/main/scala/analysis/solvers/IDESolver.scala +++ b/src/main/scala/analysis/solvers/IDESolver.scala @@ -1,7 +1,7 @@ package analysis.solvers import analysis.{BackwardIDEAnalysis, Dependencies, EdgeFunction, EdgeFunctionLattice, ForwardIDEAnalysis, IDEAnalysis, IRInterproceduralBackwardDependencies, IRInterproceduralForwardDependencies, Lambda, Lattice, MapLattice} -import ir.{CFGPosition, Command, DirectCall, GoTo, IRWalk, IndirectCall, InterProcIRCursor, Procedure, Program, end, isAfterCall} +import ir.{CFGPosition, Command, DirectCall, GoTo, IRWalk, IndirectCall, InterProcIRCursor, Procedure, Program, end, isAfterCall, Halt, Statement, Jump} import util.Logger import scala.collection.immutable.Map @@ -12,7 +12,7 @@ import scala.collection.mutable * Adapted from Tip * https://github.com/cs-au-dk/TIP/blob/master/src/tip/solvers/IDESolver.scala */ -abstract class IDESolver[E <: Procedure | Command, EE <: Procedure | Command, C <: DirectCall | GoTo, R <: DirectCall | GoTo, D, T, L <: Lattice[T]](val program: Program, val startNode: CFGPosition) +abstract class IDESolver[E <: Procedure | Command, EE <: Procedure | Command, C <: Command, R <: Command, D, T, L <: Lattice[T]](val program: Program, val startNode: CFGPosition) extends IDEAnalysis[E, EE, C, R, D, T, L], Dependencies[CFGPosition] { protected def entryToExit(entry: E): EE @@ -204,22 +204,25 @@ abstract class IDESolver[E <: Procedure | Command, EE <: Procedure | Command, C abstract class ForwardIDESolver[D, T, L <: Lattice[T]](program: Program) - extends IDESolver[Procedure, IndirectCall, DirectCall, GoTo, D, T, L](program, program.mainProcedure), + extends IDESolver[Procedure, IndirectCall, DirectCall, Command, D, T, L](program, program.mainProcedure), ForwardIDEAnalysis[D, T, L], IRInterproceduralForwardDependencies { protected def entryToExit(entry: Procedure): IndirectCall = entry.end.asInstanceOf[IndirectCall] protected def exitToEntry(exit: IndirectCall): Procedure = IRWalk.procedure(exit) - protected def callToReturn(call: DirectCall): GoTo = call.parent.fallthrough.get + protected def callToReturn(call: DirectCall): Command = call.successor - protected def returnToCall(ret: GoTo): DirectCall = ret.parent.jump.asInstanceOf[DirectCall] + protected def returnToCall(ret: Command): DirectCall = ret match { + case ret: Statement => ret.parent.statements.getPrev(ret).asInstanceOf[DirectCall] + case r: Jump => ret.parent.statements.last.asInstanceOf[DirectCall] + } protected def getCallee(call: DirectCall): Procedure = call.target protected def isCall(call: CFGPosition): Boolean = call match - case directCall: DirectCall if directCall.returnTarget.isDefined && directCall.target.returnBlock.isDefined => true + case directCall: DirectCall if (!directCall.successor.isInstanceOf[Halt]) => true case _ => false protected def isExit(exit: CFGPosition): Boolean = @@ -228,33 +231,32 @@ abstract class ForwardIDESolver[D, T, L <: Lattice[T]](program: Program) case command: Command => IRWalk.procedure(command).end == command case _ => false - protected def getAfterCalls(exit: IndirectCall): Set[GoTo] = - InterProcIRCursor.succ(exit).foreach(s => assert(s.isInstanceOf[GoTo])) - InterProcIRCursor.succ(exit).filter(_.isInstanceOf[GoTo]).map(_.asInstanceOf[GoTo]) + protected def getAfterCalls(exit: IndirectCall): Set[Command] = + InterProcIRCursor.succ(exit).filter(_.isInstanceOf[Command]).map(_.asInstanceOf[Command]) } abstract class BackwardIDESolver[D, T, L <: Lattice[T]](program: Program) - extends IDESolver[IndirectCall, Procedure, GoTo, DirectCall, D, T, L](program, program.mainProcedure.end), + extends IDESolver[IndirectCall, Procedure, Command, DirectCall, D, T, L](program, program.mainProcedure.end), BackwardIDEAnalysis[D, T, L], IRInterproceduralBackwardDependencies { protected def entryToExit(entry: IndirectCall): Procedure = IRWalk.procedure(entry) protected def exitToEntry(exit: Procedure): IndirectCall = exit.end.asInstanceOf[IndirectCall] - protected def callToReturn(call: GoTo): DirectCall = call.parent.jump.asInstanceOf[DirectCall] + protected def callToReturn(call: Command): DirectCall = call match { + case ret: Statement => ret.parent.statements.getPrev(ret).asInstanceOf[DirectCall] + case r: Jump => r.parent.statements.last.asInstanceOf[DirectCall] + } - protected def returnToCall(ret: DirectCall): GoTo = ret.parent.fallthrough.get + protected def returnToCall(ret: DirectCall): Command = ret.successor - protected def getCallee(call: GoTo): IndirectCall = callToReturn(call).target.end.asInstanceOf[IndirectCall] + protected def getCallee(call: Command): IndirectCall = callToReturn(call: Command).target.end.asInstanceOf[IndirectCall] protected def isCall(call: CFGPosition): Boolean = call match - case goto: GoTo if goto.isAfterCall => - goto.parent.jump match - case directCall: DirectCall => directCall.returnTarget.isDefined && directCall.target.returnBlock.isDefined - case _ => false + case directCall: DirectCall => (!directCall.successor.isInstanceOf[Halt]) case _ => false protected def isExit(exit: CFGPosition): Boolean = diff --git a/src/main/scala/ir/IRCursor.scala b/src/main/scala/ir/IRCursor.scala index 2b5d3123e..0e1dc9d93 100644 --- a/src/main/scala/ir/IRCursor.scala +++ b/src/main/scala/ir/IRCursor.scala @@ -52,11 +52,11 @@ object IRWalk: } } -extension (p: Jump) +extension (p: Command) def isAfterCall : Boolean = { p match { - case g: GoTo => g.parent.fallthrough.contains(g) - case _ => false + case g: Jump => g.parent.statements.lastOption.map(_.isInstanceOf[Call]).getOrElse(false) + case g: Statement => g.parent.statements.prevOption(g).map(_.isInstanceOf[Call]).getOrElse(false) } } @@ -82,9 +82,10 @@ trait IntraProcIRCursor extends IRWalk[CFGPosition, CFGPosition] { pos match { case proc: Procedure => proc.entryBlock.toSet case b: Block => Set(b.statements.headOption.getOrElse(b.jump)) - case s: Statement => Set(s.succ().getOrElse(s.parent.jump)) + case s: Statement => Set(s.successor) case n: GoTo => n.targets.asInstanceOf[Set[CFGPosition]] - case c: Call => c.parent.fallthrough.toSet + case h: Halt => Set() + case h: Return => Set() } } @@ -143,43 +144,43 @@ trait InterProcIRCursor extends IRWalk[CFGPosition, CFGPosition] { IntraProcIRCursor.succ(pos) ++ (pos match case c: DirectCall if c.target.blocks.nonEmpty => Set(c.target) - case c: IndirectCall if c.parent.isProcReturn => c.parent.parent.incomingCalls().flatMap(_.parent.fallthrough.toSet).toSet + case c: IndirectCall if c.parent.isProcReturn => c.parent.parent.incomingCalls().map(_.successor).toSet case _ => Set.empty) } final def pred(pos: CFGPosition): Set[CFGPosition] = { IntraProcIRCursor.pred(pos) ++ (pos match + case d: DirectCall if d.target.blocks.nonEmpty => d.target.returnBlock.toSet case c: Procedure => c.incomingCalls().toSet.asInstanceOf[Set[CFGPosition]] - case b: GoTo if b.isAfterCall => b.parent.jump match { - case DirectCall(t,_, _) if t.blocks.nonEmpty => t.returnBlock.toSet - case _ => Set(b) - } case _ => Set.empty) } } -trait InterProcBlockIRCursor extends IRWalk[CFGPosition, Block] { +// less meaningful with call statements + +// trait InterProcBlockIRCursor extends IRWalk[CFGPosition, Block] { +// +// final def succ(pos: CFGPosition): Set[Block] = { +// IntraProcBlockIRCursor.succ(pos) ++ +// (pos match { +// case s: DirectCall if s.target.blocks.nonEmpty => s.target.entryBlock.toSet +// case b: Block if b.isProcReturn => b.parent.incomingCalls().map(_.parent).toSet +// case _ => Set.empty +// }) +// } +// +// final def pred(pos: CFGPosition): Set[Block] = { +// IntraProcBlockIRCursor.pred(pos) ++ +// (pos match { +// case b: Block if b.isAfterCall => b.incomingJumps.collect {_.parent.jump match +// case d: DirectCall => d.target }.flatMap(_.returnBlock).toSet +// case b: Block if b.isProcEntry => b.parent.incomingCalls().map(_.parent).toSet +// case _ => Set.empty +// }) +// } +// } - final def succ(pos: CFGPosition): Set[Block] = { - IntraProcBlockIRCursor.succ(pos) ++ - (pos match { - case s: DirectCall if s.target.blocks.nonEmpty => s.target.entryBlock.toSet - case b: Block if b.isProcReturn => b.parent.incomingCalls().map(_.parent).toSet - case _ => Set.empty - }) - } - - final def pred(pos: CFGPosition): Set[Block] = { - IntraProcBlockIRCursor.pred(pos) ++ - (pos match { - case b: Block if b.isAfterCall => b.incomingJumps.collect {_.parent.jump match - case d: DirectCall => d.target }.flatMap(_.returnBlock).toSet - case b: Block if b.isProcEntry => b.parent.incomingCalls().map(_.parent).toSet - case _ => Set.empty - }) - } -} object InterProcIRCursor extends InterProcIRCursor trait CallGraph extends IRWalk[Procedure, Procedure] { @@ -190,7 +191,7 @@ trait CallGraph extends IRWalk[Procedure, Procedure] { object CallGraph extends CallGraph -object InterProcBlockIRCursor extends InterProcBlockIRCursor +// object InterProcBlockIRCursor extends InterProcBlockIRCursor /** Computes the reachability transitive closure of the CFGPositions in initial under the successor relation defined by * walker. diff --git a/src/main/scala/ir/Interpreter.scala b/src/main/scala/ir/Interpreter.scala index f5437015b..de470d5d9 100644 --- a/src/main/scala/ir/Interpreter.scala +++ b/src/main/scala/ir/Interpreter.scala @@ -11,8 +11,8 @@ class Interpreter() { private val SP: BitVecLiteral = BitVecLiteral(4096 - 16, 64) private val FP: BitVecLiteral = BitVecLiteral(4096 - 16, 64) private val LR: BitVecLiteral = BitVecLiteral(BigInt("FF", 16), 64) - private var nextBlock: Option[Block] = None - private val returnBlock: mutable.Stack[Block] = mutable.Stack() + private var nextCmd: Option[Command] = None + private val returnCmd: mutable.Stack[Command] = mutable.Stack() def eval(exp: Expr, env: mutable.Map[Variable, BitVecLiteral]): BitVecLiteral = { exp match { @@ -220,23 +220,15 @@ class Interpreter() { // Procedure.Block p.entryBlock match { - case Some(block) => nextBlock = Some(block) - case None => nextBlock = Some(returnBlock.pop()) + case Some(block) => nextCmd = Some(block.statements.headOption.getOrElse(block.jump)) + case None => nextCmd = Some(returnCmd.pop()) } } - private def interpretBlock(b: Block): Unit = { - Logger.debug(s"Block:${b.label} ${b.address}") - // Block.Statement - for ((statement, index) <- b.statements.zipWithIndex) { - Logger.debug(s"statement[$index]:") - interpretStatement(statement) - } - - // Block.Jump + private def interpretJump(j: Jump) : Unit = { + Logger.debug(s"jump:") breakable { - Logger.debug(s"jump:") - b.jump match { + j match { case gt: GoTo => Logger.debug(s"$gt") for (g <- gt.targets) { @@ -244,40 +236,28 @@ class Interpreter() { condition match { case Some(e) => evalBool(e, regs) match { case TrueLiteral => - nextBlock = Some(g) + nextCmd = Some(g.statements.headOption.getOrElse(g.jump)) break case _ => } case None => - nextBlock = Some(g) + nextCmd = Some(g.statements.headOption.getOrElse(g.jump)) break } } - case dc: DirectCall => - Logger.debug(s"$dc") - if (dc.returnTarget.isDefined) { - returnBlock.push(dc.returnTarget.get) - } - interpretProcedure(dc.target) - break - case ic: IndirectCall => - Logger.debug(s"$ic") - if (ic.target == Register("R30", 64) && ic.returnTarget.isEmpty) { - if (returnBlock.nonEmpty) { - nextBlock = Some(returnBlock.pop()) - } else { - //Exit Interpreter - nextBlock = None - } - break - } else { - ??? - } + case r: Return => { + nextCmd = Some(returnCmd.pop()) + } + case h: Halt => { + Logger.debug("Halt") + nextCmd = None + } } } } private def interpretStatement(s: Statement): Unit = { + Logger.debug(s"statement[$s]:") s match { case assign: Assign => Logger.debug(s"LocalAssign ${assign.lhs} = ${assign.rhs}") @@ -300,14 +280,42 @@ class Interpreter() { case BitVecLiteral(value, size) => Logger.debug(s"MemoryAssign ${assign.mem} := 0x${value.toString(16)}[u$size]\n") } - case _ : NOP => + case _ : NOP => () case assert: Assert => - Logger.debug(assert) // TODO - + Logger.debug(assert) + evalBool(assert.body, regs) match { + case TrueLiteral => () + case FalseLiteral => throw Exception(s"Assertion failed ${assert}") + } case assume: Assume => - Logger.debug(assume) // TODO, but already taken into effect if it is a branch condition + Logger.debug(assume) + evalBool(assume.body, regs) match { + case TrueLiteral => () + case FalseLiteral => { + nextCmd = None + Logger.debug(s"Assumption not satisfied: $assume") + } + } + case dc: DirectCall => + Logger.debug(s"$dc") + returnCmd.push(dc.successor) + interpretProcedure(dc.target) + break + case ic: IndirectCall => + Logger.debug(s"$ic") + if (ic.target == Register("R30", 64)) { + if (returnCmd.nonEmpty) { + nextCmd = Some(returnCmd.pop()) + } else { + //Exit Interpreter + nextCmd = None + } + break + } else { + ??? + } } } @@ -334,10 +342,13 @@ class Interpreter() { // Program.Procedure interpretProcedure(IRProgram.mainProcedure) - while (nextBlock.isDefined) { - interpretBlock(nextBlock.get) + while (nextCmd.isDefined) { + nextCmd.get match { + case c: Statement => interpretStatement(c) + case c: Jump => interpretJump(c) + } } regs } -} \ No newline at end of file +} diff --git a/src/main/scala/ir/Program.scala b/src/main/scala/ir/Program.scala index 5668607aa..577666e9f 100644 --- a/src/main/scala/ir/Program.scala +++ b/src/main/scala/ir/Program.scala @@ -141,7 +141,7 @@ class Program(var procedures: ArrayBuffer[Procedure], stack.pushAll(n match { case p: Procedure => p.blocks - case b: Block => Seq() ++ b.statements ++ Seq(b.jump) ++ b.fallthrough.toSet + case b: Block => Seq() ++ b.statements ++ Seq(b.jump) case s: Command => Seq() }) n @@ -289,30 +289,28 @@ class Procedure private ( block } - /** - * Remove blocks with the semantics of replacing them with a noop. The incoming jumps to this are replaced - * with a jump(s) to this blocks jump target(s). If this block ends in a call then only its statements are removed. - * @param blocks the block/blocks to remove - */ - def removeBlocksInline(blocks: Iterable[Block]): Unit = { - for (elem <- blocks) { - elem.jump match { - case g: GoTo => - // rewrite all the jumps to include our jump targets - elem.incomingJumps.foreach(_.removeTarget(elem)) - elem.incomingJumps.foreach(_.addAllTargets(g.targets)) - removeBlocks(elem) - case c: Call => - // just remove statements, keep call - elem.statements.clear() - } - } - } - - - def removeBlocksInline(blocks: Block*): Unit = { - removeBlocksInline(blocks.toSeq) - } +// unused +// /** +// * Remove blocks with the semantics of replacing them with a noop. The incoming jumps to this are replaced +// * with a jump(s) to this blocks jump target(s). If this block ends in a call then only its statements are removed. +// * @param blocks the block/blocks to remove +// */ +// def removeBlocksInline(blocks: Iterable[Block]): Unit = { +// for (elem <- blocks) { +// elem.jump match { +// case g: GoTo => +// // rewrite all the jumps to include our jump targets +// elem.incomingJumps.foreach(_.removeTarget(elem)) +// elem.incomingJumps.foreach(_.addAllTargets(g.targets)) +// removeBlocks(elem) +// } +// } +// } +// +// +// def removeBlocksInline(blocks: Block*): Unit = { +// removeBlocksInline(blocks.toSeq) +// } /** * Remove block(s) and all jumps that target it @@ -380,7 +378,6 @@ class Block private ( val statements: IntrusiveList[Statement], private var _jump: Jump, private val _incomingJumps: mutable.HashSet[GoTo], - var _fallthrough: Option[GoTo], ) extends HasParent[Procedure] { _jump.setParent(this) statements.foreach(_.setParent(this)) @@ -389,23 +386,11 @@ class Block private ( statements.onRemove = x => x.deParent() def this(label: String, address: Option[Int] = None, statements: IterableOnce[Statement] = Set.empty, jump: Jump = GoTo(Set.empty)) = { - this(label, address, IntrusiveList().addAll(statements), jump, mutable.HashSet.empty, None) + this(label, address, IntrusiveList().addAll(statements), jump, mutable.HashSet.empty) } def jump: Jump = _jump - def fallthrough: Option[GoTo] = _fallthrough - - def fallthrough_=(g: Option[GoTo]): Unit = { - /* - * Fallthrough is only set if Jump is a call, this is maintained maintained at the - * linkParent implementation on FallThrough of Call. - */ - _fallthrough.foreach(_.deParent()) - g.foreach(x => x.parent = this) - _fallthrough = g - } - private def jump_=(j: Jump): Unit = { require(!j.hasParent) if (j ne _jump) { @@ -434,7 +419,9 @@ class Block private ( assert(!incomingJumps.contains(g)) } - def calls: Set[Procedure] = _jump.calls + def calls: Set[Procedure] = statements.toSet.collect { + case d: DirectCall => d.target + } def modifies: Set[Global] = statements.flatMap(_.modifies).toSet //def locals: Set[Variable] = statements.flatMap(_.locals).toSet ++ jumps.flatMap(_.locals).toSet @@ -454,10 +441,7 @@ class Block private ( def nextBlocks: Iterable[Block] = { jump match { case c: GoTo => c.targets - case c: Call => fallthrough match { - case Some(x) => x.targets - case _ => Seq() - } + case _ => Seq() } } @@ -504,7 +488,7 @@ class Block private ( object Block { def procedureReturn(from: Procedure): Block = { - Block(from.name + "_basil_return", None, List(), IndirectCall(Register("R30", 64))) + Block(from.name + "_basil_return", None, List(), Return()) } } diff --git a/src/main/scala/ir/Statement.scala b/src/main/scala/ir/Statement.scala index 8d89a9726..74aaa18b5 100644 --- a/src/main/scala/ir/Statement.scala +++ b/src/main/scala/ir/Statement.scala @@ -23,6 +23,9 @@ sealed trait Statement extends Command, IntrusiveListElement[Statement] { def acceptVisit(visitor: Visitor): Statement = throw new Exception( "visitor " + visitor + " unimplemented for: " + this ) + + def successor: Command = parent.statements.nextOption(this).getOrElse(parent.jump) + } // invariant: rhs contains at most one MemoryLoad @@ -76,10 +79,18 @@ object Assume: sealed trait Jump extends Command { def modifies: Set[Global] = Set() //def locals: Set[Variable] = Set() - def calls: Set[Procedure] = Set() def acceptVisit(visitor: Visitor): Jump = throw new Exception("visitor " + visitor + " unimplemented for: " + this) } +class Halt(override val label: Option[String] = None) extends Jump { + override def acceptVisit(visitor: Visitor): Jump = this +} + +class Return(override val label: Option[String] = None) extends Jump { + override def acceptVisit(visitor: Visitor): Jump = this +} + + class GoTo private (private val _targets: mutable.LinkedHashSet[Block], override val label: Option[String]) extends Jump { def this(targets: Iterable[Block], label: Option[String] = None) = this(mutable.LinkedHashSet.from(targets), label) @@ -125,30 +136,18 @@ object GoTo: def unapply(g: GoTo): Option[(Set[Block], Option[String])] = Some(g.targets, g.label) -sealed trait Call extends Jump { - val returnTarget: Option[Block] - - // moving a call between blocks - override def linkParent(p: Block): Unit = { - returnTarget.foreach(t => parent.fallthrough = Some(GoTo(Set(t)))) - } - - override def unlinkParent(): Unit = { - parent.fallthrough = None - } -} +sealed trait Call extends Statement class DirectCall(val target: Procedure, - override val returnTarget: Option[Block] = None, override val label: Option[String] = None ) extends Call { /* override def locals: Set[Variable] = condition match { case Some(c) => c.locals case None => Set() } */ - override def calls: Set[Procedure] = Set(target) - override def toString: String = s"${labelStr}DirectCall(${target.name}, ${returnTarget.map(_.label)})" - override def acceptVisit(visitor: Visitor): Jump = visitor.visitDirectCall(this) + def calls: Set[Procedure] = Set(target) + override def toString: String = s"${labelStr}DirectCall(${target.name})" + override def acceptVisit(visitor: Visitor): Statement = visitor.visitDirectCall(this) override def linkParent(p: Block): Unit = { super.linkParent(p) @@ -163,19 +162,18 @@ class DirectCall(val target: Procedure, } object DirectCall: - def unapply(i: DirectCall): Option[(Procedure, Option[Block], Option[String])] = Some(i.target, i.returnTarget, i.label) + def unapply(i: DirectCall): Option[(Procedure, Option[String])] = Some(i.target, i.label) class IndirectCall(var target: Variable, - override val returnTarget: Option[Block] = None, override val label: Option[String] = None ) extends Call { /* override def locals: Set[Variable] = condition match { case Some(c) => c.locals + target case None => Set(target) } */ - override def toString: String = s"${labelStr}IndirectCall($target, ${returnTarget.map(_.label)})" - override def acceptVisit(visitor: Visitor): Jump = visitor.visitIndirectCall(this) + override def toString: String = s"${labelStr}IndirectCall($target)" + override def acceptVisit(visitor: Visitor): Statement = visitor.visitIndirectCall(this) } object IndirectCall: - def unapply(i: IndirectCall): Option[(Variable, Option[Block], Option[String])] = Some(i.target, i.returnTarget, i.label) \ No newline at end of file + def unapply(i: IndirectCall): Option[(Variable, Option[String])] = Some(i.target, i.label) diff --git a/src/main/scala/ir/Visitor.scala b/src/main/scala/ir/Visitor.scala index b9fd91b3e..1cc9c1b40 100644 --- a/src/main/scala/ir/Visitor.scala +++ b/src/main/scala/ir/Visitor.scala @@ -39,11 +39,11 @@ abstract class Visitor { node } - def visitDirectCall(node: DirectCall): Jump = { + def visitDirectCall(node: DirectCall): Statement = { node } - def visitIndirectCall(node: IndirectCall): Jump = { + def visitIndirectCall(node: IndirectCall): Statement = { node.target = visitVariable(node.target) node } @@ -199,11 +199,11 @@ abstract class ReadOnlyVisitor extends Visitor { node } - override def visitDirectCall(node: DirectCall): Jump = { + override def visitDirectCall(node: DirectCall): Statement = { node } - override def visitIndirectCall(node: IndirectCall): Jump = { + override def visitIndirectCall(node: IndirectCall): Statement = { visitVariable(node.target) node } @@ -281,14 +281,12 @@ abstract class IntraproceduralControlFlowVisitor extends Visitor { node } - override def visitDirectCall(node: DirectCall): Jump = { - node.returnTarget.foreach(visitBlock) + override def visitDirectCall(node: DirectCall): Statement = { node } - override def visitIndirectCall(node: IndirectCall): Jump = { + override def visitIndirectCall(node: IndirectCall): Statement = { node.target = visitVariable(node.target) - node.returnTarget.foreach(visitBlock) node } } @@ -431,20 +429,3 @@ class VariablesWithoutStoresLoads extends ReadOnlyVisitor { } } - -class ConvertToSingleProcedureReturn extends Visitor { - override def visitJump(node: Jump): Jump = { - node match - case c: IndirectCall => - val returnBlock = node.parent.parent.returnBlock match { - case Some(b) => b - case None => - val b = Block.procedureReturn(node.parent.parent) - node.parent.parent.returnBlock = b - b - } - // if we are return outside the return block then replace with a goto to the return block - if c.target.name == "R30" && c.returnTarget.isEmpty && !c.parent.isProcReturn then GoTo(Seq(returnBlock)) else node - case _ => node - } -} diff --git a/src/main/scala/ir/cilvisitor/CILVisitor.scala b/src/main/scala/ir/cilvisitor/CILVisitor.scala index a405d4fa1..5583b12da 100644 --- a/src/main/scala/ir/cilvisitor/CILVisitor.scala +++ b/src/main/scala/ir/cilvisitor/CILVisitor.scala @@ -95,6 +95,11 @@ class CILVisitorImpl(val v: CILVisitor) { def visit_stmt(s: Statement): List[Statement] = { def continue(n: Statement) = n match { + case d: DirectCall => d + case i: IndirectCall => { + i.target = visit_var(i.target) + i + } case m: MemoryAssign => { m.mem = visit_mem(m.mem) m.index = visit_expr(m.index) @@ -131,7 +136,6 @@ class CILVisitorImpl(val v: CILVisitor) { } }) b.replaceJump(visit_jump(b.jump)) - b.fallthrough = visit_fallthrough(b.fallthrough) b } @@ -153,7 +157,7 @@ class CILVisitorImpl(val v: CILVisitor) { doVisitList(v, v.vproc(p), p, continue) } - def visit_proc(p: Program): Program = { + def visit_prog(p: Program): Program = { def continue(p: Program) = { p.procedures = p.procedures.flatMap(visit_proc) p @@ -164,6 +168,7 @@ class CILVisitorImpl(val v: CILVisitor) { def visit_block(v: CILVisitor, b: Block): Block = CILVisitorImpl(v).visit_block(b) def visit_proc(v: CILVisitor, b: Procedure): List[Procedure] = CILVisitorImpl(v).visit_proc(b) +def visit_prog(v: CILVisitor, b: Program): Program = CILVisitorImpl(v).visit_prog(b) def visit_stmt(v: CILVisitor, e: Statement): List[Statement] = CILVisitorImpl(v).visit_stmt(e) def visit_jump(v: CILVisitor, e: Jump): Jump = CILVisitorImpl(v).visit_jump(e) def visit_expr(v: CILVisitor, e: Expr): Expr = CILVisitorImpl(v).visit_expr(e) diff --git a/src/main/scala/ir/dsl/DSL.scala b/src/main/scala/ir/dsl/DSL.scala index 7fdd8bdaa..3c55e2dfc 100644 --- a/src/main/scala/ir/dsl/DSL.scala +++ b/src/main/scala/ir/dsl/DSL.scala @@ -35,24 +35,31 @@ case class DelayNameResolve(ident: String) { } } +trait EventuallyStatement { + def resolve(p: Program): Statement +} + +case class ResolvableStatement(s: Statement) extends EventuallyStatement { + override def resolve(p: Program) = s +} + trait EventuallyJump { def resolve(p: Program): Jump } -case class EventuallyIndirectCall(target: Variable, fallthrough: Option[DelayNameResolve]) extends EventuallyJump { +case class EventuallyIndirectCall(target: Variable, fallthrough: Option[DelayNameResolve]) extends EventuallyStatement { override def resolve(p: Program): IndirectCall = { - IndirectCall(target, fallthrough.flatMap(_.resolveBlock(p))) + IndirectCall(target) } } -case class EventuallyCall(target: DelayNameResolve, fallthrough: Option[DelayNameResolve]) extends EventuallyJump { +case class EventuallyCall(target: DelayNameResolve, fallthrough: Option[DelayNameResolve]) extends EventuallyStatement { override def resolve(p: Program): DirectCall = { val t = target.resolveProc(p) match { case Some(x) => x case None => throw Exception("can't resolve proc " + p) } - val ft = fallthrough.flatMap(_.resolveBlock(p)) - DirectCall(t, ft) + DirectCall(t) } } @@ -79,18 +86,20 @@ def indirectCall(tgt: Variable, fallthrough: Option[String]): EventuallyIndirect // def directcall(tgt: String) = EventuallyCall(DelayNameResolve(tgt), None) -case class EventuallyBlock(label: String, sl: Seq[Statement], j: EventuallyJump) { - val tempBlock: Block = Block(label, None, sl, GoTo(List.empty)) +case class EventuallyBlock(label: String, sl: Seq[EventuallyStatement], j: EventuallyJump) { + val tempBlock: Block = Block(label, None, List(), GoTo(List.empty)) def resolve(prog: Program): Block = { + tempBlock.statements.addAll(sl.map(_.resolve(prog))) tempBlock.replaceJump(j.resolve(prog)) tempBlock } } -def block(label: String, sl: (Statement | EventuallyJump)*): EventuallyBlock = { - val statements = sl.collect { - case s: Statement => s +def block(label: String, sl: (Statement | EventuallyStatement | EventuallyJump)*): EventuallyBlock = { + val statements : Seq[EventuallyStatement] = sl.collect { + case s: Statement => ResolvableStatement(s) + case o: EventuallyStatement => o } val jump = sl.collectFirst { case j: EventuallyJump => j diff --git a/src/main/scala/ir/transforms/ReplaceReturn.scala b/src/main/scala/ir/transforms/ReplaceReturn.scala new file mode 100644 index 000000000..61b276833 --- /dev/null +++ b/src/main/scala/ir/transforms/ReplaceReturn.scala @@ -0,0 +1,52 @@ +package ir.transforms + +import util.Logger +import ir.cilvisitor._ +import ir._ + + +class ReplaceReturns extends CILVisitor { + /** + * Assumes IR with 1 call per block which appears as the last statement. + */ + override def vstmt(j: Statement): VisitAction[List[Statement]] = { + j match { + case IndirectCall(Register("R30", _), _) => { + assert(j.parent.statements.lastOption.contains(j)) + if (j.parent.jump.isInstanceOf[Halt | Return]) { + j.parent.replaceJump(Return()) + ChangeTo(List()) + } else { + SkipChildren() + } + } + case _ => SkipChildren() + } + } + + override def vjump(j: Jump) = SkipChildren() +} + + +def addReturnBlocks(p: Program) = { + p.procedures.foreach(p => { + val containsReturn = p.blocks.map(_.jump).find(_.isInstanceOf[Return]).isDefined + if (containsReturn) { + p.returnBlock = p.addBlocks(Block(label=p.name + "_return",jump=Return())) + } + }) +} + + +class ConvertSingleReturn extends CILVisitor { + /** + * Assumes procedures have defined return blocks if they contain a return statement. + */ + override def vjump(j: Jump) = j match { + case r: Return if !(j.parent.parent.returnBlock.contains(j.parent)) => ChangeTo(GoTo(Seq(j.parent.parent.returnBlock.get))) + case _ => SkipChildren() + } + + override def vstmt(s: Statement) = SkipChildren() +} + diff --git a/src/main/scala/translating/BAPToIR.scala b/src/main/scala/translating/BAPToIR.scala index 85ed82f21..90ba61335 100644 --- a/src/main/scala/translating/BAPToIR.scala +++ b/src/main/scala/translating/BAPToIR.scala @@ -48,8 +48,9 @@ class BAPToIR(var program: BAPProgram, mainAddress: Int) { for (st <- b.statements) { block.statements.append(translate(st)) } - val (jump, newBlocks) = translate(b.jumps, block) + val (call, jump, newBlocks) = translate(b.jumps, block) procedure.addBlocks(newBlocks) + call.foreach(c => block.statements.append(c)) block.replaceJump(jump) assert(jump.hasParent) } @@ -85,7 +86,7 @@ class BAPToIR(var program: BAPProgram, mainAddress: Int) { * Translates a list of jumps from BAP into a single Jump at the IR level by moving any conditions on jumps to * Assume statements in new blocks * */ - private def translate(jumps: List[BAPJump], block: Block): (Jump, ArrayBuffer[Block]) = { + private def translate(jumps: List[BAPJump], block: Block): (Option[Call], Jump, ArrayBuffer[Block]) = { if (jumps.size > 1) { val targets = ArrayBuffer[Block]() val conditions = ArrayBuffer[BAPExpr]() @@ -130,26 +131,28 @@ class BAPToIR(var program: BAPProgram, mainAddress: Int) { case _ => throw Exception("translation error, call where not expected: " + jumps.mkString(", ")) } } - (GoTo(targets, Some(line)), newBlocks) + (None, GoTo(targets, Some(line)), newBlocks) } else { jumps.head match { case b: BAPDirectCall => - val call = DirectCall(nameToProcedure(b.target), b.returnTarget.map(t => labelToBlock(t)), Some(b.line)) - (call, ArrayBuffer()) + val call = Some(DirectCall(nameToProcedure(b.target),Some(b.line))) + val ft = (b.returnTarget.map(t => labelToBlock(t))).map(x => GoTo(Set(x))).getOrElse(Halt()) + (call, ft, ArrayBuffer()) case b: BAPIndirectCall => - val call = IndirectCall(b.target.toIR, b.returnTarget.map(t => labelToBlock(t)), Some(b.line)) - (call, ArrayBuffer()) + val call = IndirectCall(b.target.toIR, Some(b.line)) + val ft = (b.returnTarget.map(t => labelToBlock(t))).map(x => GoTo(Set(x))).getOrElse(Halt()) + (Some(call), ft, ArrayBuffer()) case b: BAPGoTo => val target = labelToBlock(b.target) b.condition match { // condition is true case l: BAPLiteral if l.value > BigInt(0) => - (GoTo(ArrayBuffer(target), Some(b.line)), ArrayBuffer()) + (None, GoTo(ArrayBuffer(target), Some(b.line)), ArrayBuffer()) // non-true condition case _ => val condition = convertConditionBool(b.condition, false) val newBlock = newBlockCondition(block, target, condition) - (GoTo(ArrayBuffer(newBlock), Some(b.line)), ArrayBuffer(newBlock)) + (None, GoTo(ArrayBuffer(newBlock), Some(b.line)), ArrayBuffer(newBlock)) } } } diff --git a/src/main/scala/translating/GTIRBToIR.scala b/src/main/scala/translating/GTIRBToIR.scala index fceb07692..f7589905c 100644 --- a/src/main/scala/translating/GTIRBToIR.scala +++ b/src/main/scala/translating/GTIRBToIR.scala @@ -182,12 +182,13 @@ class GTIRBToIR(mods: Seq[Module], parserMap: immutable.Map[String, Array[Array[ throw Exception(s"block ${block.label} in subroutine ${procedure.name} has no outgoing edges") } - val jump = if (outgoingEdges.size == 1) { + val (calls, jump) = if (outgoingEdges.size == 1) { val edge = outgoingEdges.head handleSingleEdge(block, edge, procedure, procedures) } else { handleMultipleEdges(block, outgoingEdges, procedure) } + calls.foreach(c => block.statements.append(c)) block.replaceJump(jump) if (block.statements.nonEmpty) { @@ -363,8 +364,6 @@ class GTIRBToIR(mods: Seq[Module], parserMap: immutable.Map[String, Array[Array[ // need to copy jump as it can't have multiple parents val jumpCopy = currentBlock.jump match { case GoTo(targets, label) => GoTo(targets, label) - case IndirectCall(target, returnTarget, label) => IndirectCall(target, returnTarget, label) - case DirectCall(target, returnTarget, label) => DirectCall(target, returnTarget, label) case _ => throw Exception("this shouldn't be reachable") } trueBlock.replaceJump(currentBlock.jump) @@ -377,7 +376,7 @@ class GTIRBToIR(mods: Seq[Module], parserMap: immutable.Map[String, Array[Array[ } // Handles the case where a block has one outgoing edge using gtirb cfg labelling - private def handleSingleEdge(block: Block, edge: Edge, procedure: Procedure, procedures: ArrayBuffer[Procedure]): Jump = { + private def handleSingleEdge(block: Block, edge: Edge, procedure: Procedure, procedures: ArrayBuffer[Procedure]): (Option[Call], Jump) = { edge.getLabel match { case EdgeLabel(false, false, Type_Branch, _) => // indirect jump, possibly to external subroutine, possibly to another block in procedure @@ -391,7 +390,7 @@ class GTIRBToIR(mods: Seq[Module], parserMap: immutable.Map[String, Array[Array[ case _ => throw Exception(s"no assignment to program counter found before indirect call in block ${block.label}") } block.statements.remove(block.statements.last) // remove _PC assignment - IndirectCall(target, None) + (Some(IndirectCall(target)), Halt()) } else if (proxySymbols.size > 1) { // TODO requires further consideration once encountered throw Exception(s"multiple uuidToSymbol ${proxySymbols.map(_.name).mkString(", ")} associated with proxy block ${byteStringToString(edge.targetUuid)}, target of indirect call from block ${block.label}") @@ -407,14 +406,14 @@ class GTIRBToIR(mods: Seq[Module], parserMap: immutable.Map[String, Array[Array[ proc } removePCAssign(block) - DirectCall(target, None) + (Some(DirectCall(target)), Halt()) } } else if (uuidToBlock.contains(edge.targetUuid)) { // resolved indirect jump // TODO consider possibility this can go to another procedure? val target = uuidToBlock(edge.targetUuid) removePCAssign(block) - GoTo(mutable.Set(target)) + (None, GoTo(mutable.Set(target))) } else { throw Exception(s"edge from ${block.label} to ${byteStringToString(edge.targetUuid)} does not point to a known block or proxy block") } @@ -425,23 +424,23 @@ class GTIRBToIR(mods: Seq[Module], parserMap: immutable.Map[String, Array[Array[ // direct jump to start of own subroutine is treated as GoTo, not DirectCall // should probably investigate recursive cases to determine if this happens/is correct val jump = if (procedure == targetProc) { - GoTo(mutable.Set(uuidToBlock(edge.targetUuid))) + (None, GoTo(mutable.Set(uuidToBlock(edge.targetUuid)))) } else { - DirectCall(targetProc, None) + (Some(DirectCall(targetProc)), Halt()) } removePCAssign(block) jump } else if (uuidToBlock.contains(edge.targetUuid)) { val target = uuidToBlock(edge.targetUuid) removePCAssign(block) - GoTo(mutable.Set(target)) + (None, GoTo(mutable.Set(target))) } else { throw Exception(s"edge from ${block.label} to ${byteStringToString(edge.targetUuid)} does not point to a known block") } case EdgeLabel(false, _, Type_Return, _) => // return statement, value of 'direct' is just whether DDisasm has resolved the return target removePCAssign(block) - IndirectCall(Register("R30", 64), None) + (Some(IndirectCall(Register("R30", 64), None)), Halt()) case EdgeLabel(false, true, Type_Fallthrough, _) => // end of block that doesn't end in a control flow instruction and falls through to next if (entranceUUIDtoProcedure.contains(edge.targetUuid)) { @@ -449,10 +448,10 @@ class GTIRBToIR(mods: Seq[Module], parserMap: immutable.Map[String, Array[Array[ // probably doesn't actually happen in practice since it seems to be after brk instructions? val targetProc = entranceUUIDtoProcedure(edge.targetUuid) // assuming fallthrough won't fall through to start of own procedure - DirectCall(targetProc, None) + (Some(DirectCall(targetProc)), Halt()) } else if (uuidToBlock.contains(edge.targetUuid)) { val target = uuidToBlock(edge.targetUuid) - GoTo(mutable.Set(target)) + (None, GoTo(mutable.Set(target))) } else { throw Exception(s"edge from ${block.label} to ${byteStringToString(edge.targetUuid)} does not point to a known block") } @@ -462,7 +461,7 @@ class GTIRBToIR(mods: Seq[Module], parserMap: immutable.Map[String, Array[Array[ if (entranceUUIDtoProcedure.contains(edge.targetUuid)) { val target = entranceUUIDtoProcedure(edge.targetUuid) removePCAssign(block) - DirectCall(target, None) + (Some(DirectCall(target)), Halt()) } else { throw Exception(s"edge from ${block.label} to ${byteStringToString(edge.targetUuid)} does not point to a known procedure entrance") } @@ -473,14 +472,13 @@ class GTIRBToIR(mods: Seq[Module], parserMap: immutable.Map[String, Array[Array[ } } - def handleMultipleEdges(block: Block, outgoingEdges: mutable.Set[Edge], procedure: Procedure): Jump = { + def handleMultipleEdges(block: Block, outgoingEdges: mutable.Set[Edge], procedure: Procedure): (Option[Call], Jump) = { val edgeLabels = outgoingEdges.map(_.getLabel) if (edgeLabels.forall { (e: EdgeLabel) => !e.conditional && e.direct && e.`type` == Type_Return }) { // multiple resolved returns, translate as single return removePCAssign(block) - IndirectCall(Register("R30", 64), None) - + (None, Return()) } else if (edgeLabels.forall { (e: EdgeLabel) => !e.conditional && !e.direct && e.`type` == Type_Branch }) { // resolved indirect call with multiple blocks as targets val targets = mutable.Set[Block]() @@ -495,7 +493,7 @@ class GTIRBToIR(mods: Seq[Module], parserMap: immutable.Map[String, Array[Array[ } // TODO add assertion that target register is low removePCAssign(block) - GoTo(targets) + (None, GoTo(targets)) // TODO possibility not yet encountered: resolved indirect call that goes to multiple procedures? } else if (outgoingEdges.size == 2) { @@ -519,9 +517,9 @@ class GTIRBToIR(mods: Seq[Module], parserMap: immutable.Map[String, Array[Array[ handleIndirectCallWithReturn(edge1, edge0, block) // conditional branch case (EdgeLabel(true, true, Type_Fallthrough, _), EdgeLabel(true, true, Type_Branch, _)) => - handleConditionalBranch(edge0, edge1, block, procedure) + (None, handleConditionalBranch(edge0, edge1, block, procedure)) case (EdgeLabel(true, true, Type_Branch, _), EdgeLabel(true, true, Type_Fallthrough, _)) => - handleConditionalBranch(edge1, edge0, block, procedure) + (None, handleConditionalBranch(edge1, edge0, block, procedure)) case _ => throw Exception(s"cannot resolve outgoing edges from block ${block.label}") } @@ -542,7 +540,7 @@ class GTIRBToIR(mods: Seq[Module], parserMap: immutable.Map[String, Array[Array[ if (fallthroughs.size != 1 || indirectCallTargets.isEmpty) { throw Exception(s"cannot resolve outgoing edges from block ${block.label}") } - handleIndirectCallMultipleResolvedTargets(fallthroughs.head, indirectCallTargets, block, procedure) + (None, handleIndirectCallMultipleResolvedTargets(fallthroughs.head, indirectCallTargets, block, procedure)) } else { throw Exception(s"cannot resolve outgoing edges from block ${block.label}") } @@ -564,18 +562,18 @@ class GTIRBToIR(mods: Seq[Module], parserMap: immutable.Map[String, Array[Array[ } val target = entranceUUIDtoProcedure(call.targetUuid) - val resolvedCall = DirectCall(target, Some(returnTarget)) + val resolvedCall = DirectCall(target) val assume = Assume(BinaryExpr(BVEQ, targetRegister, BitVecLiteral(target.address.get, 64))) val label = block.label + "$" + target.name - newBlocks.append(Block(label, None, ArrayBuffer(assume), resolvedCall)) + newBlocks.append(Block(label, None, ArrayBuffer(assume, resolvedCall), GoTo(returnTarget))) } removePCAssign(block) procedure.addBlocks(newBlocks) GoTo(newBlocks) } - private def handleIndirectCallWithReturn(fallthrough: Edge, call: Edge, block: Block): Call = { + private def handleIndirectCallWithReturn(fallthrough: Edge, call: Edge, block: Block): (Option[Call], GoTo) = { if (!uuidToBlock.contains(fallthrough.targetUuid)) { throw Exception(s"block ${block.label} has fallthrough edge to ${byteStringToString(fallthrough.targetUuid)} that does not point to a known block") } @@ -586,16 +584,16 @@ class GTIRBToIR(mods: Seq[Module], parserMap: immutable.Map[String, Array[Array[ val target = getPCTarget(block) removePCAssign(block) - IndirectCall(target, Some(returnTarget)) + (Some(IndirectCall(target)), GoTo(Set(returnTarget))) } else { // resolved indirect call val target = entranceUUIDtoProcedure(call.targetUuid) removePCAssign(block) - DirectCall(target, Some(returnTarget)) + (Some(DirectCall(target)), GoTo(Set(returnTarget))) } } - private def handleDirectCallWithReturn(fallthrough: Edge, call: Edge, block: Block): DirectCall = { + private def handleDirectCallWithReturn(fallthrough: Edge, call: Edge, block: Block): (Option[Call], GoTo) = { if (!entranceUUIDtoProcedure.contains(call.targetUuid)) { throw Exception(s"block ${block.label} has direct call edge to ${byteStringToString(call.targetUuid)} that does not point to a known procedure") } @@ -607,7 +605,7 @@ class GTIRBToIR(mods: Seq[Module], parserMap: immutable.Map[String, Array[Array[ val target = entranceUUIDtoProcedure(call.targetUuid) val returnTarget = uuidToBlock(fallthrough.targetUuid) removePCAssign(block) - DirectCall(target, Some(returnTarget)) + (Some(DirectCall(target)), GoTo(Set(returnTarget))) } private def handleConditionalBranch(fallthrough: Edge, branch: Edge, block: Block, procedure: Procedure): GoTo = { diff --git a/src/main/scala/translating/ILtoIL.scala b/src/main/scala/translating/ILtoIL.scala index b34100704..99b64bc6e 100644 --- a/src/main/scala/translating/ILtoIL.scala +++ b/src/main/scala/translating/ILtoIL.scala @@ -74,7 +74,7 @@ private class ILSerialiser extends ReadOnlyVisitor { } - override def visitDirectCall(node: DirectCall): Jump = { + override def visitDirectCall(node: DirectCall): Statement = { program ++= "DirectCall(" program ++= procedureIdentifier(node.target) program ++= ", " @@ -82,7 +82,7 @@ private class ILSerialiser extends ReadOnlyVisitor { node } - override def visitIndirectCall(node: IndirectCall): Jump = { + override def visitIndirectCall(node: IndirectCall): Statement = { program ++= "IndirectCall(" visitVariable(node.target) program ++= ", " diff --git a/src/main/scala/translating/IRToBoogie.scala b/src/main/scala/translating/IRToBoogie.scala index 9f45918a4..b654fa031 100644 --- a/src/main/scala/translating/IRToBoogie.scala +++ b/src/main/scala/translating/IRToBoogie.scala @@ -632,39 +632,7 @@ class IRToBoogie(var program: Program, var spec: Specification, var thread: Opti } ) } - def translate(j: Jump): List[BCmd] = j match { - case d: DirectCall => - val call = BProcedureCall(d.target.name) - val returnTarget = d.returnTarget match { - case Some(r) => GoToCmd(Seq(r.label)) - case None => BAssume(FalseBLiteral, Some("no return target")) - } - - (config.procedureRely match { - case Some(ProcRelyVersion.Function) => - if (libRelies.contains(d.target.name) && libGuarantees.contains(d.target.name) && libRelies(d.target.name).nonEmpty && libGuarantees(d.target.name).nonEmpty) { - val invCall1 = BProcedureCall(d.target.name + "$inv", List(mem_inv1, Gamma_mem_inv1), List(mem, Gamma_mem)) - val invCall2 = BProcedureCall("rely$inv", List(mem_inv2, Gamma_mem_inv2), List(mem_inv1, Gamma_mem_inv1)) - val libRGAssert = libRelies(d.target.name).map(r => BAssert(r.resolveSpecInv)) - List(invCall1, invCall2) ++ libRGAssert - } else { - List() - } - case Some(ProcRelyVersion.IfCommandContradiction) => relyfun(d.target.name).toList - case None => List() - }) ++ List(call, returnTarget) - case i: IndirectCall => - // TODO put this elsewhere - if (i.target.name == "R30") { - List(ReturnCmd) - } else { - val unresolved: List[BCmd] = List(Comment(s"UNRESOLVED: call ${i.target.name}"), BAssert(FalseBLiteral)) - i.returnTarget match { - case Some(r) => unresolved :+ GoToCmd(Seq(r.label)) - case None => unresolved ++ List(Comment("no return target"), BAssume(FalseBLiteral)) - } - } case g: GoTo => // collects all targets of the goto with a branch condition that we need to check the security level for // and collects the variables for that @@ -681,9 +649,32 @@ class IRToBoogie(var program: Program, var spec: Specification, var thread: Opti } val jump = GoToCmd(g.targets.map(_.label).toSeq) conditionAssert :+ jump + case r: Return => List(ReturnCmd) + case r: Halt => List(BAssert(FalseBLiteral)) + } + + def translate(j: Call): List[BCmd] = j match { + case d: DirectCall => + val call = BProcedureCall(d.target.name) + + (config.procedureRely match { + case Some(ProcRelyVersion.Function) => + if (libRelies.contains(d.target.name) && libGuarantees.contains(d.target.name) && libRelies(d.target.name).nonEmpty && libGuarantees(d.target.name).nonEmpty) { + val invCall1 = BProcedureCall(d.target.name + "$inv", List(mem_inv1, Gamma_mem_inv1), List(mem, Gamma_mem)) + val invCall2 = BProcedureCall("rely$inv", List(mem_inv2, Gamma_mem_inv2), List(mem_inv1, Gamma_mem_inv1)) + val libRGAssert = libRelies(d.target.name).map(r => BAssert(r.resolveSpecInv)) + List(invCall1, invCall2) ++ libRGAssert + } else { + List() + } + case Some(ProcRelyVersion.IfCommandContradiction) => relyfun(d.target.name).toList + case None => List() + }) ++ List(call) + case i: IndirectCall => List(Comment(s"UNRESOLVED: call ${i.target.name}"), BAssert(FalseBLiteral)) } def translate(s: Statement): List[BCmd] = s match { + case d: Call => translate(d) case m: NOP => List.empty case m: MemoryAssign => val lhs = m.mem.toBoogie diff --git a/src/main/scala/util/RunUtils.scala b/src/main/scala/util/RunUtils.scala index 6badc3cfe..b7fb1d7b1 100644 --- a/src/main/scala/util/RunUtils.scala +++ b/src/main/scala/util/RunUtils.scala @@ -29,6 +29,7 @@ import java.util.Base64 import spray.json.DefaultJsonProtocol.* import util.intrusive_list.IntrusiveList import analysis.CfgCommandNode +import cilvisitor._ import scala.annotation.tailrec import scala.collection.mutable @@ -197,11 +198,13 @@ object IRTransform { } val externalRemover = ExternalRemover(externalNamesLibRemoved.toSet) val renamer = Renamer(boogieReserved) - val returnUnifier = ConvertToSingleProcedureReturn() + + cilvisitor.visit_prog(transforms.ReplaceReturns(), ctx.program) + transforms.addReturnBlocks(ctx.program) + cilvisitor.visit_prog(transforms.ConvertSingleReturn(), ctx.program) externalRemover.visitProgram(ctx.program) renamer.visitProgram(ctx.program) - returnUnifier.visitProgram(ctx.program) ctx } @@ -274,8 +277,8 @@ object IRTransform { modified = true // indirectCall.parent.parent.removeBlocks(indirectCall.returnTarget) - val newCall = DirectCall(targets.head, indirectCall.returnTarget, indirectCall.label) - block.replaceJump(newCall) + val newCall = DirectCall(targets.head, indirectCall.label) + block.statements.replace(indirectCall, newCall) } else if (targets.size > 1) { modified = true val procedure = c.parent.data @@ -283,10 +286,14 @@ object IRTransform { for (t <- targets) { val assume = Assume(BinaryExpr(BVEQ, indirectCall.target, BitVecLiteral(t.address.get, 64))) val newLabel: String = block.label + t.name - val directCall = DirectCall(t, indirectCall.returnTarget) + val directCall = DirectCall(t) directCall.parent = indirectCall.parent - newBlocks.append(Block(newLabel, None, ArrayBuffer(assume), directCall)) + // assume indircall is the last statement in block + assert(indirectCall.parent.statements.lastOption.contains(indirectCall)) + val fallthrough = indirectCall.parent.jump + + newBlocks.append(Block(newLabel, None, ArrayBuffer(assume, directCall), fallthrough)) } procedure.addBlocks(newBlocks) val newCall = GoTo(newBlocks, indirectCall.label) @@ -429,8 +436,8 @@ object IRTransform { modified = true // indirectCall.parent.parent.removeBlocks(indirectCall.returnTarget) - val newCall = DirectCall(targets.head, indirectCall.returnTarget, indirectCall.label) - block.replaceJump(newCall) + val newCall = DirectCall(targets.head, indirectCall.label) + block.statements.replace(indirectCall, newCall) } else if (targets.size > 1) { modified = true val procedure = c.parent.data @@ -448,17 +455,20 @@ object IRTransform { addressExprs ::= addressExpr val assume = Assume(addressExpr) val newLabel: String = block.label + t.name - val directCall = DirectCall(t, indirectCall.returnTarget) + val directCall = DirectCall(t) directCall.parent = indirectCall.parent - newBlocks.append(Block(newLabel, None, ArrayBuffer(assume), directCall)) + + assert(indirectCall.parent.statements.lastOption.contains(indirectCall)) + val fallthrough = indirectCall.parent.jump + newBlocks.append(Block(newLabel, None, ArrayBuffer(assume, directCall), fallthrough)) } procedure.addBlocks(newBlocks) val newCall = GoTo(newBlocks, indirectCall.label) val addressExprOr = addressExprs.tail.foldLeft(addressExprs.head) { (a: BinaryExpr, b: BinaryExpr) => BinaryExpr(BoolOR, a, b) } - val assert = Assert(addressExprOr, Some("check indirect call underapproximation")) - block.statements.append(assert) + val assertion = Assert(addressExprOr, Some("check indirect call underapproximation")) + block.statements.append(assertion) block.replaceJump(newCall) } case _ => @@ -500,42 +510,40 @@ object IRTransform { ): Unit = { // iterate over all commands - if call is to pthread_create, look up? - for (p <- program.procedures) { - for (b <- p.blocks) { - b.jump match { - case d: DirectCall if d.target.name == "pthread_create" => - - // R2 should hold the function pointer of the function that begins the thread - // look up R2 value using points to results - val R2 = Register("R2", 64) - val b = reachingDefs(d) - val R2Wrapper = RegisterVariableWrapper(R2, getDefinition(R2, d, reachingDefs)) - val threadTargets = pointsTo(R2Wrapper) - - if (threadTargets.size > 1) { - // currently can't handle case where the thread created is ambiguous - throw Exception("can't handle thread creation with more than one possible target") - } + program.foreach(c => + c match { + case d: DirectCall if d.target.name == "pthread_create" => + + // R2 should hold the function pointer of the function that begins the thread + // look up R2 value using points to results + val R2 = Register("R2", 64) + val b = reachingDefs(d) + val R2Wrapper = RegisterVariableWrapper(R2, getDefinition(R2, d, reachingDefs)) + val threadTargets = pointsTo(R2Wrapper) + + if (threadTargets.size > 1) { + // currently can't handle case where the thread created is ambiguous + throw Exception("can't handle thread creation with more than one possible target") + } - if (threadTargets.size == 1) { + if (threadTargets.size == 1) { - // not trying to untangle the very messy region resolution at present, just dealing with simplest case - threadTargets.head match { - case data: DataRegion => - val threadEntrance = program.procedures.find(_.name == data.regionIdentifier) match { - case Some(proc) => proc - case None => throw Exception("could not find procedure with name " + data.regionIdentifier) - } - val thread = ProgramThread(threadEntrance, mutable.LinkedHashSet(threadEntrance), Some(d)) - program.threads.addOne(thread) - case _ => - throw Exception("unexpected non-data region " + threadTargets.head + " as PointsTo result for R2 at " + d) - } + // not trying to untangle the very messy region resolution at present, just dealing with simplest case + threadTargets.head match { + case data: DataRegion => + val threadEntrance = program.procedures.find(_.name == data.regionIdentifier) match { + case Some(proc) => proc + case None => throw Exception("could not find procedure with name " + data.regionIdentifier) + } + val thread = ProgramThread(threadEntrance, mutable.LinkedHashSet(threadEntrance), Some(d)) + program.threads.addOne(thread) + case _ => + throw Exception("unexpected non-data region " + threadTargets.head + " as PointsTo result for R2 at " + d) } - case _ => - } - } - } + } + case _ => + }) + if (program.threads.nonEmpty) { val mainThread = ProgramThread(program.mainProcedure, mutable.LinkedHashSet(program.mainProcedure), None) From 3711e4d4de9c6f43227ba984990f26ae82a7bfb9 Mon Sep 17 00:00:00 2001 From: Alistair Michael Date: Mon, 19 Aug 2024 11:56:49 +1000 Subject: [PATCH 02/62] move transforms out of RunUtils.scala --- src/main/scala/analysis/Cfg.scala | 4 +- src/main/scala/analysis/IDEAnalysis.scala | 6 +- .../analysis/InterLiveVarsAnalysis.scala | 10 +- .../analysis/IntraLiveVarsAnalysis.scala | 4 +- src/main/scala/analysis/VSA.scala | 2 +- .../scala/analysis/solvers/IDESolver.scala | 40 +- src/main/scala/ir/IRCursor.scala | 171 ++++----- src/main/scala/ir/Program.scala | 38 +- src/main/scala/ir/Statement.scala | 8 +- src/main/scala/ir/dsl/DSL.scala | 39 +- .../ir/invariant/EarlyCallStatement.scala | 15 + .../transforms/IndirectCallResolution.scala | 295 +++++++++++++++ .../scala/ir/transforms/ReplaceReturn.scala | 10 +- .../scala/ir/transforms/SplitThreads.scala | 84 +++++ src/main/scala/translating/GTIRBToIR.scala | 4 +- src/main/scala/translating/IRToBoogie.scala | 2 +- src/main/scala/util/RunUtils.scala | 356 +----------------- src/test/scala/IndirectCallsTests.scala | 87 +++-- src/test/scala/LiveVarsAnalysisTests.scala | 129 ++++--- src/test/scala/PointsToTest.scala | 24 +- src/test/scala/ir/IRTest.scala | 90 ++--- src/test/scala/ir/SingleCallInvariant.scala | 83 ++++ 22 files changed, 829 insertions(+), 672 deletions(-) create mode 100644 src/main/scala/ir/invariant/EarlyCallStatement.scala create mode 100644 src/main/scala/ir/transforms/IndirectCallResolution.scala create mode 100644 src/main/scala/ir/transforms/SplitThreads.scala create mode 100644 src/test/scala/ir/SingleCallInvariant.scala diff --git a/src/main/scala/analysis/Cfg.scala b/src/main/scala/analysis/Cfg.scala index 59d512cd2..8807f2d2a 100644 --- a/src/main/scala/analysis/Cfg.scala +++ b/src/main/scala/analysis/Cfg.scala @@ -502,7 +502,7 @@ class ProgramCfgFactory: val targetProc: Procedure = dCall.target funcEntryNode.callers.add(procToCfg(targetProc)._1) - val callNode = CfgJumpNode(dCall, block, funcEntryNode) + val callNode : CfgJumpNode = s.asInstanceOf[CfgJumpNode] // Branch to this call cfg.addEdge(precNode, callNode) @@ -523,7 +523,7 @@ class ProgramCfgFactory: Logger.debug(s"Indirect call found: $iCall in ${proc.name}") var precNode = prevNode - val jmpNode = CfgJumpNode(iCall, block, funcEntryNode) + val jmpNode = s.asInstanceOf[CfgJumpNode] // Branch to this call cfg.addEdge(precNode, jmpNode) diff --git a/src/main/scala/analysis/IDEAnalysis.scala b/src/main/scala/analysis/IDEAnalysis.scala index e77627717..c7ce74559 100644 --- a/src/main/scala/analysis/IDEAnalysis.scala +++ b/src/main/scala/analysis/IDEAnalysis.scala @@ -1,6 +1,6 @@ package analysis -import ir.{CFGPosition, Command, DirectCall, GoTo, IndirectCall, Procedure, Program} +import ir.{CFGPosition, Command, DirectCall, GoTo, Return, IndirectCall, Procedure, Program} final case class Lambda() @@ -55,6 +55,6 @@ trait IDEAnalysis[E, EE, C, R, D, T, L <: Lattice[T]] { } // IndirectCall in these is because they are returns so that can be further tightened in future -trait ForwardIDEAnalysis[D, T, L <: Lattice[T]] extends IDEAnalysis[Procedure, IndirectCall, DirectCall, Command, D, T, L] +trait ForwardIDEAnalysis[D, T, L <: Lattice[T]] extends IDEAnalysis[Procedure, Return, DirectCall, Command, D, T, L] -trait BackwardIDEAnalysis[D, T, L <: Lattice[T]] extends IDEAnalysis[IndirectCall, Procedure, Command, DirectCall, D, T, L] +trait BackwardIDEAnalysis[D, T, L <: Lattice[T]] extends IDEAnalysis[Return, Procedure, Command, DirectCall, D, T, L] diff --git a/src/main/scala/analysis/InterLiveVarsAnalysis.scala b/src/main/scala/analysis/InterLiveVarsAnalysis.scala index b20edf478..7a93266e8 100644 --- a/src/main/scala/analysis/InterLiveVarsAnalysis.scala +++ b/src/main/scala/analysis/InterLiveVarsAnalysis.scala @@ -1,7 +1,7 @@ package analysis import analysis.solvers.BackwardIDESolver -import ir.{Assert, Assume, GoTo, CFGPosition, Command, DirectCall, IndirectCall, Assign, MemoryAssign, Procedure, Program, Variable, toShortString} +import ir.{Assert, Assume, Block, GoTo, CFGPosition, Command, DirectCall, IndirectCall, Assign, MemoryAssign, Halt, Return, Procedure, Program, Variable, toShortString} /** * Micro-transfer-functions for LiveVar analysis @@ -19,7 +19,7 @@ trait LiveVarsAnalysisFunctions extends BackwardIDEAnalysis[Variable, TwoElement val edgelattice: EdgeFunctionLattice[TwoElement, TwoElementLattice] = EdgeFunctionLattice(valuelattice) import edgelattice.{IdEdge, ConstEdge} - def edgesCallToEntry(call: Command, entry: IndirectCall)(d: DL): Map[DL, EdgeFunction[TwoElement]] = { + def edgesCallToEntry(call: Command, entry: Return)(d: DL): Map[DL, EdgeFunction[TwoElement]] = { Map(d -> IdEdge()) } @@ -74,7 +74,11 @@ trait LiveVarsAnalysisFunctions extends BackwardIDEAnalysis[Variable, TwoElement d match case Left(value) => if value != variable then Map(d -> IdEdge()) else Map() case Right(_) => Map(d -> IdEdge(), Left(variable) -> ConstEdge(TwoElementTop)) - case _ => Map(d -> IdEdge()) + case r: Return => Map(d -> IdEdge()) + case h: Halt => Map(d -> IdEdge()) + case c: DirectCall => Map(d -> IdEdge()) + case c: Block => Map(d -> IdEdge()) + case c: GoTo => Map(d -> IdEdge()) } } diff --git a/src/main/scala/analysis/IntraLiveVarsAnalysis.scala b/src/main/scala/analysis/IntraLiveVarsAnalysis.scala index f3f322321..75fa1dbb0 100644 --- a/src/main/scala/analysis/IntraLiveVarsAnalysis.scala +++ b/src/main/scala/analysis/IntraLiveVarsAnalysis.scala @@ -1,7 +1,7 @@ package analysis import analysis.solvers.SimpleWorklistFixpointSolver -import ir.{Assert, Assume, Block, CFGPosition, Call, DirectCall, GoTo, IndirectCall, Jump, Assign, MemoryAssign, NOP, Procedure, Program, Statement, Variable} +import ir.{Assert, Assume, Block, CFGPosition, Call, DirectCall, GoTo, IndirectCall, Jump, Assign, MemoryAssign, NOP, Procedure, Program, Statement, Variable, Return, Halt} abstract class LivenessAnalysis(program: Program) extends Analysis[Any]: val lattice: MapLattice[CFGPosition, Set[Variable], PowersetLattice[Variable]] = MapLattice(PowersetLattice()) @@ -18,6 +18,8 @@ abstract class LivenessAnalysis(program: Program) extends Analysis[Any]: case IndirectCall(variable, _) => s + variable case c: DirectCall => s case g: GoTo => s + case r: Return => s + case r: Halt => s case _ => ??? } } diff --git a/src/main/scala/analysis/VSA.scala b/src/main/scala/analysis/VSA.scala index f7d9a55e7..734f9bf11 100644 --- a/src/main/scala/analysis/VSA.scala +++ b/src/main/scala/analysis/VSA.scala @@ -172,7 +172,7 @@ trait ValueSetAnalysis(program: Program, if (IRWalk.procedure(n) == n) { mmm.pushContext(n.asInstanceOf[Procedure].name) s - } else if (IRWalk.procedure(n).end == n) { + } else if (IRWalk.lastInProc(IRWalk.procedure(n)) == n) { mmm.popContext() s } else n match diff --git a/src/main/scala/analysis/solvers/IDESolver.scala b/src/main/scala/analysis/solvers/IDESolver.scala index 0dd981b30..7a581dbe4 100644 --- a/src/main/scala/analysis/solvers/IDESolver.scala +++ b/src/main/scala/analysis/solvers/IDESolver.scala @@ -1,7 +1,7 @@ package analysis.solvers import analysis.{BackwardIDEAnalysis, Dependencies, EdgeFunction, EdgeFunctionLattice, ForwardIDEAnalysis, IDEAnalysis, IRInterproceduralBackwardDependencies, IRInterproceduralForwardDependencies, Lambda, Lattice, MapLattice} -import ir.{CFGPosition, Command, DirectCall, GoTo, IRWalk, IndirectCall, InterProcIRCursor, Procedure, Program, end, isAfterCall, Halt, Statement, Jump} +import ir.{CFGPosition, Command, DirectCall, GoTo, IRWalk, IndirectCall, Return, InterProcIRCursor, Procedure, Program, isAfterCall, Halt, Statement, Jump} import util.Logger import scala.collection.immutable.Map @@ -204,10 +204,10 @@ abstract class IDESolver[E <: Procedure | Command, EE <: Procedure | Command, C abstract class ForwardIDESolver[D, T, L <: Lattice[T]](program: Program) - extends IDESolver[Procedure, IndirectCall, DirectCall, Command, D, T, L](program, program.mainProcedure), + extends IDESolver[Procedure, Return, DirectCall, Command, D, T, L](program, program.mainProcedure), ForwardIDEAnalysis[D, T, L], IRInterproceduralForwardDependencies { - protected def entryToExit(entry: Procedure): IndirectCall = entry.end.asInstanceOf[IndirectCall] + protected def entryToExit(entry: Procedure): Return = IRWalk.lastInProc(entry).asInstanceOf[Return] protected def exitToEntry(exit: IndirectCall): Procedure = IRWalk.procedure(exit) @@ -218,7 +218,10 @@ abstract class ForwardIDESolver[D, T, L <: Lattice[T]](program: Program) case r: Jump => ret.parent.statements.last.asInstanceOf[DirectCall] } - protected def getCallee(call: DirectCall): Procedure = call.target + protected def getCallee(call: DirectCall): Procedure = { + require(isCall(call)) + call.target + } protected def isCall(call: CFGPosition): Boolean = call match @@ -228,41 +231,46 @@ abstract class ForwardIDESolver[D, T, L <: Lattice[T]](program: Program) protected def isExit(exit: CFGPosition): Boolean = exit match // only looking at functions with statements - case command: Command => IRWalk.procedure(command).end == command + case command: Command => IRWalk.lastInProc(IRWalk.procedure(command)) == command case _ => false protected def getAfterCalls(exit: IndirectCall): Set[Command] = InterProcIRCursor.succ(exit).filter(_.isInstanceOf[Command]).map(_.asInstanceOf[Command]) - } abstract class BackwardIDESolver[D, T, L <: Lattice[T]](program: Program) - extends IDESolver[IndirectCall, Procedure, Command, DirectCall, D, T, L](program, program.mainProcedure.end), + extends IDESolver[Return, Procedure, Command, DirectCall, D, T, L](program, IRWalk.lastInProc(program.mainProcedure)), BackwardIDEAnalysis[D, T, L], IRInterproceduralBackwardDependencies { - protected def entryToExit(entry: IndirectCall): Procedure = IRWalk.procedure(entry) + protected def entryToExit(entry: Return): Procedure = IRWalk.procedure(entry) - protected def exitToEntry(exit: Procedure): IndirectCall = exit.end.asInstanceOf[IndirectCall] + protected def exitToEntry(exit: Procedure): Return = exit.returnBlock.get.jump.asInstanceOf[Return] - protected def callToReturn(call: Command): DirectCall = call match { - case ret: Statement => ret.parent.statements.getPrev(ret).asInstanceOf[DirectCall] - case r: Jump => r.parent.statements.last.asInstanceOf[DirectCall] + protected def callToReturn(call: Command): DirectCall = { + IRWalk.prevCommandInBlock(call) match { + case Some(x : DirectCall) => x + case p => throw Exception(s"Not a return/aftercall node $call .... prev = $p") + } } protected def returnToCall(ret: DirectCall): Command = ret.successor - protected def getCallee(call: Command): IndirectCall = callToReturn(call: Command).target.end.asInstanceOf[IndirectCall] + protected def getCallee(call: Command): Return = { + require(isCall(call)) + val procCalled = callToReturn(call).target + procCalled.returnBlock.getOrElse(throw Exception(s"No return node for procedure ${procCalled}")).jump.asInstanceOf[Return] + } protected def isCall(call: CFGPosition): Boolean = call match - case directCall: DirectCall => (!directCall.successor.isInstanceOf[Halt]) + case c : Command => isAfterCall(c) && IRWalk.prevCommandInBlock(c).map(_.isInstanceOf[DirectCall]).getOrElse(false) case _ => false protected def isExit(exit: CFGPosition): Boolean = exit match - case procedure: Procedure => procedure.blocks.nonEmpty + case procedure: Procedure => true case _ => false - protected def getAfterCalls(exit: Procedure): Set[DirectCall] = InterProcIRCursor.pred(exit).filter(_.isInstanceOf[DirectCall]).map(_.asInstanceOf[DirectCall]) + protected def getAfterCalls(exit: Procedure): Set[DirectCall] = exit.incomingCalls().toSet } diff --git a/src/main/scala/ir/IRCursor.scala b/src/main/scala/ir/IRCursor.scala index 0e1dc9d93..690b57dea 100644 --- a/src/main/scala/ir/IRCursor.scala +++ b/src/main/scala/ir/IRCursor.scala @@ -14,12 +14,19 @@ import scala.annotation.tailrec */ type CFGPosition = Procedure | Block | Command +def isAfterCall(c: Command) = { + (IRWalk.prevCommandInBlock(c)) match { + case Some(c: Call) => true + case _ => false + } +} + extension (p: CFGPosition) def toShortString: String = p match case procedure: Procedure => procedure.toString - case block: Block => s"Block ${block.label}" - case command: Command => command.toString + case block: Block => s"Block ${block.label}" + case command: Command => command.toString // todo: we could just use the dependencies trait directly instead to avoid the instantiation issue trait IRWalk[IN <: CFGPosition, NT <: CFGPosition & IN] { @@ -28,75 +35,79 @@ trait IRWalk[IN <: CFGPosition, NT <: CFGPosition & IN] { } object IRWalk: - def procedure(pos: CFGPosition) : Procedure = { + + def prevCommandInBlock(c: Command): Option[Command] = c match { + case s: Statement => c.parent.statements.prevOption(s) + case j: Jump => c.parent.statements.lastOption + } + + def nextCommandInBlock(c: Command): Option[Command] = c match { + case s: Statement => Some(s.successor) + case j: Jump => None + } + + def procedure(pos: CFGPosition): Procedure = { pos match { case p: Procedure => p - case b: Block => b.parent - case c: Command => c.parent.parent + case b: Block => b.parent + case c: Command => c.parent.parent } } - def blockBegin(pos: CFGPosition) : Option[Block] = { + def blockBegin(pos: CFGPosition): Option[Block] = { pos match { case p: Procedure => p.entryBlock - case b: Block => Some(b) - case c: Command => Some(c.parent) + case b: Block => Some(b) + case c: Command => Some(c.parent) } } - def commandBegin(pos: CFGPosition) : Option[Command] = { + def commandBegin(pos: CFGPosition): Option[Command] = { pos match { case p: Procedure => p.entryBlock.map(b => b.statements.headOption.getOrElse(b.jump)) - case b: Block => Some(b.statements.headOption.getOrElse(b.jump)) - case c: Command => Some(c) - } - } - -extension (p: Command) - def isAfterCall : Boolean = { - p match { - case g: Jump => g.parent.statements.lastOption.map(_.isInstanceOf[Call]).getOrElse(false) - case g: Statement => g.parent.statements.prevOption(g).map(_.isInstanceOf[Call]).getOrElse(false) + case b: Block => Some(b.statements.headOption.getOrElse(b.jump)) + case c: Command => Some(c) } } -extension (p: Block) - def isProcEntry : Boolean = p.parent.entryBlock.contains(p) - def isProcReturn : Boolean = p.parent.returnBlock.contains(p) - // TODO: this method doesn't require aftercall blocks only have 1 incoming jump - def isAfterCall : Boolean = p.incomingJumps.nonEmpty && p.incomingJumps.forall(_.isAfterCall) + def lastInBlock(p: Block): Command = p.jump + def firstInBlock(p: Block): Command = p.statements.headOption.getOrElse(p.jump) - def begin: CFGPosition = p - def end: CFGPosition = p.jump + def firstInProc(p: Procedure): Command = firstInBlock(p.entryBlock.get) + def lastInProc(p: Procedure): Command = lastInBlock(p.returnBlock.get) -extension (p: Procedure) - def begin: CFGPosition = p - def end: CFGPosition = p.returnBlock.map(_.end).getOrElse(p) +// extension (p: Block) +// def isProcEntry: Boolean = p.parent.entryBlock.contains(p) +// def isProcReturn: Boolean = p.parent.returnBlock.contains(p) +// +// def begin: CFGPosition = p +// def end: CFGPosition = p.jump +// +// extension (p: Procedure) +// def begin: CFGPosition = p +// def end: CFGPosition = p.returnBlock.map(_.end).getOrElse(p) -/** - * Does not include edges between procedures. - */ +/** Does not include edges between procedures. + */ trait IntraProcIRCursor extends IRWalk[CFGPosition, CFGPosition] { def succ(pos: CFGPosition): Set[CFGPosition] = { pos match { case proc: Procedure => proc.entryBlock.toSet - case b: Block => Set(b.statements.headOption.getOrElse(b.jump)) - case s: Statement => Set(s.successor) + case b: Block => b.statements.headOption.orElse(Some(b.jump)).toSet case n: GoTo => n.targets.asInstanceOf[Set[CFGPosition]] case h: Halt => Set() case h: Return => Set() + case c: Statement => IRWalk.nextCommandInBlock(c).toSet } } def pred(pos: CFGPosition): Set[CFGPosition] = { pos match { - case s: Statement => Set(s.pred().getOrElse(s.parent)) - case j: GoTo if j.isAfterCall => Set(j.parent.jump) - case j: Jump => Set(j.parent.statements.lastOption.getOrElse(j.parent)) - case b: Block if b.isProcEntry => Set(b.parent) - case b: Block => b.incomingJumps.asInstanceOf[Set[CFGPosition]] - case proc: Procedure => Set() // intraproc + case c: Command => Set(IRWalk.prevCommandInBlock(c).getOrElse(c.parent)) + case b: Block if b.isEntry => Set(b.parent) + case b: Block => b.incomingJumps.asInstanceOf[Set[CFGPosition]] + case proc: Procedure => Set() // intraproc } } } @@ -117,70 +128,44 @@ trait IntraProcBlockIRCursor extends IRWalk[CFGPosition, Block] { @tailrec final def pred(pos: CFGPosition): Set[Block] = { pos match { - case b: Block if b.isProcEntry => Set.empty - case b: Block => b.incomingJumps.map(_.parent).toSet - case j: Command => pred(j.parent) - case s: Procedure => Set.empty + case b: Block if b.isEntry => Set.empty + case b: Block => b.incomingJumps.map(_.parent).toSet + case j: Command => pred(j.parent) + case s: Procedure => Set.empty } } } object IntraProcBlockIRCursor extends IntraProcBlockIRCursor -/** - * Includes all intraproc edges as well as edges between procedures. - * - * forwards: - * Direct call -> target - * return indirect Call -> the procedure return block for all possible direct-call sites - * - * backwards: - * Procedure -> all possible direct-call sites - * Call-return block -> return Call of the procedure called - * - */ +/** Includes all intraproc edges as well as edges between procedures. + * + * forwards: Direct call -> target return indirect Call -> the procedure return block for all possible direct-call + * sites + * + * backwards: Procedure -> all possible direct-call sites Call-return block -> return Call of the procedure called + */ trait InterProcIRCursor extends IRWalk[CFGPosition, CFGPosition] { final def succ(pos: CFGPosition): Set[CFGPosition] = { - IntraProcIRCursor.succ(pos) ++ - (pos match - case c: DirectCall if c.target.blocks.nonEmpty => Set(c.target) - case c: IndirectCall if c.parent.isProcReturn => c.parent.parent.incomingCalls().map(_.successor).toSet - case _ => Set.empty) + IntraProcIRCursor.succ(pos) ++ + (pos match + case c: DirectCall if c.target.blocks.nonEmpty => Set(c.target) + // case c: IndirectCall if c.parent.isProcReturn => c.parent.parent.incomingCalls().map(_.successor).toSet + case c: Return => c.parent.parent.incomingCalls().map(_.successor).toSet + case _ => Set.empty + ) } final def pred(pos: CFGPosition): Set[CFGPosition] = { IntraProcIRCursor.pred(pos) ++ - (pos match - case d: DirectCall if d.target.blocks.nonEmpty => d.target.returnBlock.toSet - case c: Procedure => c.incomingCalls().toSet.asInstanceOf[Set[CFGPosition]] - case _ => Set.empty) + (pos match + case d: DirectCall if d.target.blocks.nonEmpty => d.target.returnBlock.toSet + case c: Procedure => c.incomingCalls().toSet.asInstanceOf[Set[CFGPosition]] + case _ => Set.empty + ) } } -// less meaningful with call statements - -// trait InterProcBlockIRCursor extends IRWalk[CFGPosition, Block] { -// -// final def succ(pos: CFGPosition): Set[Block] = { -// IntraProcBlockIRCursor.succ(pos) ++ -// (pos match { -// case s: DirectCall if s.target.blocks.nonEmpty => s.target.entryBlock.toSet -// case b: Block if b.isProcReturn => b.parent.incomingCalls().map(_.parent).toSet -// case _ => Set.empty -// }) -// } -// -// final def pred(pos: CFGPosition): Set[Block] = { -// IntraProcBlockIRCursor.pred(pos) ++ -// (pos match { -// case b: Block if b.isAfterCall => b.incomingJumps.collect {_.parent.jump match -// case d: DirectCall => d.target }.flatMap(_.returnBlock).toSet -// case b: Block if b.isProcEntry => b.parent.incomingCalls().map(_.parent).toSet -// case _ => Set.empty -// }) -// } -// } - object InterProcIRCursor extends InterProcIRCursor trait CallGraph extends IRWalk[Procedure, Procedure] { @@ -268,17 +253,17 @@ def toDot[T <: CFGPosition]( def getArrow(s: CFGPosition, n: CFGPosition) = { if (IRWalk.procedure(n) eq IRWalk.procedure(s)) { - DotRegularArrow(dotNodes(s),dotNodes(n)) + DotRegularArrow(dotNodes(s), dotNodes(n)) } else { - DotInterArrow(dotNodes(s),dotNodes(n)) + DotInterArrow(dotNodes(s), dotNodes(n)) } } for (node <- domain) { node match { case s => - iterator.succ(s).foreach(n => dotArrows.addOne(getArrow(s,n))) - // iterator.pred(s).foreach(n => dotArrows.addOne(getArrow(s,n))) + iterator.succ(s).foreach(n => dotArrows.addOne(getArrow(s, n))) + // iterator.pred(s).foreach(n => dotArrows.addOne(getArrow(s,n))) } } diff --git a/src/main/scala/ir/Program.scala b/src/main/scala/ir/Program.scala index 577666e9f..ed61cba4f 100644 --- a/src/main/scala/ir/Program.scala +++ b/src/main/scala/ir/Program.scala @@ -5,6 +5,7 @@ import scala.collection.{IterableOnceExtensionMethods, View, immutable, mutable} import boogie.* import analysis.BitVectorEval import util.intrusive_list.* +import translating.serialiseIL class Program(var procedures: ArrayBuffer[Procedure], var mainProcedure: Procedure, @@ -13,6 +14,10 @@ class Program(var procedures: ArrayBuffer[Procedure], val threads: ArrayBuffer[ProgramThread] = ArrayBuffer() + override def toString(): String = { + serialiseIL(this) + } + // This shouldn't be run before indirect calls are resolved def stripUnreachableFunctions(depth: Int = Int.MaxValue): Unit = { val procedureCalleeNames = procedures.map(f => f.name -> f.calls.map(_.name)).toMap @@ -141,7 +146,7 @@ class Program(var procedures: ArrayBuffer[Procedure], stack.pushAll(n match { case p: Procedure => p.blocks - case b: Block => Seq() ++ b.statements ++ Seq(b.jump) + case b: Block => Seq() ++ b.statements.toSeq ++ Seq(b.jump) case s: Command => Seq() }) n @@ -209,7 +214,7 @@ class Procedure private ( def returnBlock_=(value: Block): Unit = { if (!returnBlock.contains(value)) { - removeBlocks(_returnBlock) + _returnBlock.foreach(removeBlocks(_)) _returnBlock = Some(addBlocks(value)) } } @@ -218,7 +223,7 @@ class Procedure private ( def entryBlock_=(value: Block): Unit = { if (!entryBlock.contains(value)) { - removeBlocks(_entryBlock) + _entryBlock.foreach(removeBlocks(_)) _entryBlock = Some(addBlocks(value)) } } @@ -228,9 +233,6 @@ class Procedure private ( if (!_blocks.contains(block)) { block.parent = this _blocks.add(block) - if (entryBlock.isEmpty) { - entryBlock = block - } } block } @@ -289,28 +291,6 @@ class Procedure private ( block } -// unused -// /** -// * Remove blocks with the semantics of replacing them with a noop. The incoming jumps to this are replaced -// * with a jump(s) to this blocks jump target(s). If this block ends in a call then only its statements are removed. -// * @param blocks the block/blocks to remove -// */ -// def removeBlocksInline(blocks: Iterable[Block]): Unit = { -// for (elem <- blocks) { -// elem.jump match { -// case g: GoTo => -// // rewrite all the jumps to include our jump targets -// elem.incomingJumps.foreach(_.removeTarget(elem)) -// elem.incomingJumps.foreach(_.addAllTargets(g.targets)) -// removeBlocks(elem) -// } -// } -// } -// -// -// def removeBlocksInline(blocks: Block*): Unit = { -// removeBlocksInline(blocks.toSeq) -// } /** * Remove block(s) and all jumps that target it @@ -389,6 +369,8 @@ class Block private ( this(label, address, IntrusiveList().addAll(statements), jump, mutable.HashSet.empty) } + def isEntry: Boolean = parent.entryBlock.contains(this) + def jump: Jump = _jump private def jump_=(j: Jump): Unit = { diff --git a/src/main/scala/ir/Statement.scala b/src/main/scala/ir/Statement.scala index 74aaa18b5..2dea68f46 100644 --- a/src/main/scala/ir/Statement.scala +++ b/src/main/scala/ir/Statement.scala @@ -83,6 +83,7 @@ sealed trait Jump extends Command { } class Halt(override val label: Option[String] = None) extends Jump { + /* Terminate / No successors / assume false */ override def acceptVisit(visitor: Visitor): Jump = this } @@ -136,7 +137,12 @@ object GoTo: def unapply(g: GoTo): Option[(Set[Block], Option[String])] = Some(g.targets, g.label) -sealed trait Call extends Statement +sealed trait Call extends Statement { + def returnTarget: Option[Command] = successor match { + case h: Halt => None + case o => Some(o) + } +} class DirectCall(val target: Procedure, override val label: Option[String] = None diff --git a/src/main/scala/ir/dsl/DSL.scala b/src/main/scala/ir/dsl/DSL.scala index 3c55e2dfc..6a1b96742 100644 --- a/src/main/scala/ir/dsl/DSL.scala +++ b/src/main/scala/ir/dsl/DSL.scala @@ -14,7 +14,7 @@ val R7: Register = Register("R7", 64) val R29: Register = Register("R29", 64) val R30: Register = Register("R30", 64) val R31: Register = Register("R31", 64) -val ret: EventuallyIndirectCall = EventuallyIndirectCall(Register("R30", 64), None) + def bv32(i: Int): BitVecLiteral = BitVecLiteral(i, 32) @@ -40,21 +40,21 @@ trait EventuallyStatement { } case class ResolvableStatement(s: Statement) extends EventuallyStatement { - override def resolve(p: Program) = s + override def resolve(p: Program) : Statement = s } trait EventuallyJump { def resolve(p: Program): Jump } -case class EventuallyIndirectCall(target: Variable, fallthrough: Option[DelayNameResolve]) extends EventuallyStatement { - override def resolve(p: Program): IndirectCall = { +case class EventuallyIndirectCall(target: Variable) extends EventuallyStatement { + override def resolve(p: Program): Statement = { IndirectCall(target) } } -case class EventuallyCall(target: DelayNameResolve, fallthrough: Option[DelayNameResolve]) extends EventuallyStatement { - override def resolve(p: Program): DirectCall = { +case class EventuallyCall(target: DelayNameResolve) extends EventuallyStatement { + override def resolve(p: Program): Statement = { val t = target.resolveProc(p) match { case Some(x) => x case None => throw Exception("can't resolve proc " + p) @@ -63,12 +63,19 @@ case class EventuallyCall(target: DelayNameResolve, fallthrough: Option[DelayNam } } + case class EventuallyGoto(targets: List[DelayNameResolve]) extends EventuallyJump { override def resolve(p: Program): GoTo = { val tgs = targets.flatMap(tn => tn.resolveBlock(p)) GoTo(tgs) } } +case class EventuallyReturn() extends EventuallyJump { + override def resolve(p: Program) = Return() +} +case class EventuallyHalt() extends EventuallyJump { + override def resolve(p: Program) = Halt() +} def goto(): EventuallyGoto = EventuallyGoto(List.empty) @@ -76,13 +83,16 @@ def goto(targets: String*): EventuallyGoto = { EventuallyGoto(targets.map(p => DelayNameResolve(p)).toList) } +def ret: EventuallyReturn = EventuallyReturn() +def halt: EventuallyHalt= EventuallyHalt() + def goto(targets: List[String]): EventuallyGoto = { EventuallyGoto(targets.map(p => DelayNameResolve(p))) } -def directCall(tgt: String, fallthrough: Option[String]): EventuallyCall = EventuallyCall(DelayNameResolve(tgt), fallthrough.map(x => DelayNameResolve(x))) +def directCall(tgt: String): EventuallyCall = EventuallyCall(DelayNameResolve(tgt)) -def indirectCall(tgt: Variable, fallthrough: Option[String]): EventuallyIndirectCall = EventuallyIndirectCall(tgt, fallthrough.map(x => DelayNameResolve(x))) +def indirectCall(tgt: Variable): EventuallyIndirectCall = EventuallyIndirectCall(tgt) // def directcall(tgt: String) = EventuallyCall(DelayNameResolve(tgt), None) @@ -90,16 +100,20 @@ case class EventuallyBlock(label: String, sl: Seq[EventuallyStatement], j: Event val tempBlock: Block = Block(label, None, List(), GoTo(List.empty)) def resolve(prog: Program): Block = { - tempBlock.statements.addAll(sl.map(_.resolve(prog))) + val resolved = sl.map(_.resolve(prog)) + tempBlock.statements.addAll(resolved) tempBlock.replaceJump(j.resolve(prog)) tempBlock } } def block(label: String, sl: (Statement | EventuallyStatement | EventuallyJump)*): EventuallyBlock = { - val statements : Seq[EventuallyStatement] = sl.collect { - case s: Statement => ResolvableStatement(s) - case o: EventuallyStatement => o + val statements : Seq[EventuallyStatement] = sl.flatMap { + case s: Statement => Some(ResolvableStatement(s)) + case o: EventuallyStatement => Some(o) + case o: EventuallyCall => Some(o) + case o: EventuallyIndirectCall => Some(o) + case g: EventuallyJump => None } val jump = sl.collectFirst { case j: EventuallyJump => j @@ -113,6 +127,7 @@ case class EventuallyProcedure(label: String, blocks: Seq[EventuallyBlock]) { val jumps: Map[Block, EventuallyJump] = blocks.map(b => b.tempBlock -> b.j).toMap def resolve(prog: Program): Procedure = { + blocks.foreach(b => b.resolve(prog)) jumps.map((b, j) => b.replaceJump(j.resolve(prog))) tempProc } diff --git a/src/main/scala/ir/invariant/EarlyCallStatement.scala b/src/main/scala/ir/invariant/EarlyCallStatement.scala new file mode 100644 index 000000000..bb4504343 --- /dev/null +++ b/src/main/scala/ir/invariant/EarlyCallStatement.scala @@ -0,0 +1,15 @@ +package ir.invariant +import ir._ + + +def singleCallBlockEnd(p: Program) : Boolean = { + p.forall { + case b: Block => { + val calls = (b.statements.collect { + case c: Call => b.statements.lastOption.contains(c) + }) + (calls.size <= 1) && calls.headOption.getOrElse(true) + } + case _ => true + } +} diff --git a/src/main/scala/ir/transforms/IndirectCallResolution.scala b/src/main/scala/ir/transforms/IndirectCallResolution.scala new file mode 100644 index 000000000..e9c9fddf7 --- /dev/null +++ b/src/main/scala/ir/transforms/IndirectCallResolution.scala @@ -0,0 +1,295 @@ +package ir.transforms + + + +import scala.collection.mutable.ListBuffer +import scala.collection.mutable.ArrayBuffer +import analysis.solvers.* +import analysis.* +import bap.* +import ir.* +import translating.* +import util.Logger +import util.intrusive_list.IntrusiveList +import analysis.CfgCommandNode +import scala.collection.mutable +import cilvisitor._ + + +/** Resolve indirect calls to an address-conditional choice between direct calls using the Value Set Analysis results. + * Dead code, and currently broken by statement calls + * +def resolveIndirectCalls( + cfg: ProgramCfg, + valueSets: Map[CfgNode, LiftedElement[Map[Variable | MemoryRegion, Set[Value]]]], + IRProgram: Program +): Boolean = { + var modified: Boolean = false + val worklist = ListBuffer[CfgNode]() + cfg.startNode.succIntra.union(cfg.startNode.succInter).foreach(node => worklist.addOne(node)) + + val visited = mutable.Set[CfgNode]() + while (worklist.nonEmpty) { + val node = worklist.remove(0) + if (!visited.contains(node)) { + process(node) + node.succIntra.union(node.succInter).foreach(node => worklist.addOne(node)) + visited.add(node) + } + } + + def process(n: CfgNode): Unit = n match { + /* + case c: CfgStatementNode => + c.data match + + //We do not want to insert the VSA results into the IR like this + case localAssign: Assign => + localAssign.rhs match + case _: MemoryLoad => + if (valueSets(n).contains(localAssign.lhs) && valueSets(n).get(localAssign.lhs).head.size == 1) { + val extractedValue = extractExprFromValue(valueSets(n).get(localAssign.lhs).head.head) + localAssign.rhs = extractedValue + Logger.info(s"RESOLVED: Memory load ${localAssign.lhs} resolved to ${extractedValue}") + } else if (valueSets(n).contains(localAssign.lhs) && valueSets(n).get(localAssign.lhs).head.size > 1) { + Logger.info(s"RESOLVED: WARN Memory load ${localAssign.lhs} resolved to multiple values, cannot replace") + + /* + // must merge into a single memory variable to represent the possible values + // Make a binary OR of all the possible values takes two at a time (incorrect to do BVOR) + val values = valueSets(n).get(localAssign.lhs).head + val exprValues = values.map(extractExprFromValue) + val result = exprValues.reduce((a, b) => BinaryExpr(BVOR, a, b)) // need to express nondeterministic + // choice between these specific options + localAssign.rhs = result + */ + } + case _ => + */ + case c: CfgJumpNode => + val block = c.block + c.data match + case indirectCall: IndirectCall => + if (block.jump != indirectCall) { + // We only replace the calls with DirectCalls in the IR, and don't replace the CommandNode.data + // Hence if we have already processed this CFG node there will be no corresponding IndirectCall in the IR + // to replace. + // We want to replace all possible indirect calls based on this CFG, before regenerating it from the IR + return + } + valueSets(n) match { + case Lift(valueSet) => + val targetNames = resolveAddresses(valueSet(indirectCall.target)).map(_.name).toList.sorted + val targets = targetNames.map(name => IRProgram.procedures.filter(_.name.equals(name)).head) + + if (targets.size == 1) { + modified = true + + // indirectCall.parent.parent.removeBlocks(indirectCall.returnTarget) + val newCall = DirectCall(targets.head, indirectCall.label) + block.statements.replace(indirectCall, newCall) + } else if (targets.size > 1) { + modified = true + val procedure = c.parent.data + val newBlocks = ArrayBuffer[Block]() + for (t <- targets) { + val assume = Assume(BinaryExpr(BVEQ, indirectCall.target, BitVecLiteral(t.address.get, 64))) + val newLabel: String = block.label + t.name + val directCall = DirectCall(t) + directCall.parent = indirectCall.parent + + // assume indircall is the last statement in block + assert(indirectCall.parent.statements.lastOption.contains(indirectCall)) + val fallthrough = indirectCall.parent.jump + + newBlocks.append(Block(newLabel, None, ArrayBuffer(assume, directCall), fallthrough)) + } + procedure.addBlocks(newBlocks) + val newCall = GoTo(newBlocks, indirectCall.label) + block.replaceJump(newCall) + } + case LiftedBottom => + } + case _ => + case _ => + } + + def nameExists(name: String): Boolean = { + IRProgram.procedures.exists(_.name.equals(name)) + } + + def addFakeProcedure(name: String): Unit = { + IRProgram.procedures += Procedure(name) + } + + def resolveAddresses(valueSet: Set[Value]): Set[AddressValue] = { + var functionNames: Set[AddressValue] = Set() + valueSet.foreach { + case globalAddress: GlobalAddress => + if (nameExists(globalAddress.name)) { + functionNames += globalAddress + Logger.info(s"RESOLVED: Call to Global address ${globalAddress.name} rt statuesolved.") + } else { + addFakeProcedure(globalAddress.name) + functionNames += globalAddress + Logger.info(s"Global address ${globalAddress.name} does not exist in the program. Added a fake function.") + } + case localAddress: LocalAddress => + if (nameExists(localAddress.name)) { + functionNames += localAddress + Logger.info(s"RESOLVED: Call to Local address ${localAddress.name}") + } else { + addFakeProcedure(localAddress.name) + functionNames += localAddress + Logger.info(s"Local address ${localAddress.name} does not exist in the program. Added a fake function.") + } + case _ => + } + functionNames + } + + modified +} + + */ + +def resolveIndirectCallsUsingPointsTo( + cfg: ProgramCfg, + pointsTos: Map[RegisterVariableWrapper, Set[RegisterVariableWrapper | MemoryRegion]], + regionContents: Map[MemoryRegion, Set[BitVecLiteral | MemoryRegion]], + reachingDefs: Map[CFGPosition, (Map[Variable, Set[Assign]], Map[Variable, Set[Assign]])], + IRProgram: Program + ): Boolean = { + var modified: Boolean = false + val worklist = ListBuffer[CfgNode]() + cfg.startNode.succIntra.union(cfg.startNode.succInter).foreach(node => worklist.addOne(node)) + + val visited = mutable.Set[CfgNode]() + while (worklist.nonEmpty) { + val node = worklist.remove(0) + if (!visited.contains(node)) { + process(node) + node.succIntra.union(node.succInter).foreach(node => worklist.addOne(node)) + visited.add(node) + } + } + + def searchRegion(region: MemoryRegion): mutable.Set[String] = { + val result = mutable.Set[String]() + region match { + case stackRegion: StackRegion => + if (regionContents.contains(stackRegion)) { + for (c <- regionContents(stackRegion)) { + c match { + case bitVecLiteral: BitVecLiteral => Logger.debug("hi: " + bitVecLiteral)//??? + case memoryRegion: MemoryRegion => + result.addAll(searchRegion(memoryRegion)) + } + } + } + result + case dataRegion: DataRegion => + if (!regionContents.contains(dataRegion) || regionContents(dataRegion).isEmpty) { + result.add(dataRegion.regionIdentifier) + } else { + result.add(dataRegion.regionIdentifier) // TODO: may need to investigate if we should add the parent region + for (c <- regionContents(dataRegion)) { + c match { + case bitVecLiteral: BitVecLiteral => Logger.debug("hi: " + bitVecLiteral)//??? + case memoryRegion: MemoryRegion => + result.addAll(searchRegion(memoryRegion)) + } + } + } + result + } + } + + def addFakeProcedure(name: String): Procedure = { + val newProcedure = Procedure(name) + IRProgram.procedures += newProcedure + newProcedure + } + + def resolveAddresses(variable: Variable, i: IndirectCall): mutable.Set[String] = { + val names = mutable.Set[String]() + val variableWrapper = RegisterVariableWrapper(variable, getUse(variable, i, reachingDefs)) + pointsTos.get(variableWrapper) match { + case Some(value) => + value.map { + case v: RegisterVariableWrapper => names.addAll(resolveAddresses(v.variable, i)) + case m: MemoryRegion => names.addAll(searchRegion(m)) + } + names + case None => names + } + } + + def process(n: CfgNode): Unit = n match { + case c: CfgJumpNode => + val block = c.block + c.data match + // don't try to resolve returns + case indirectCall: IndirectCall if indirectCall.target != Register("R30", 64) => + if (!indirectCall.hasParent) { + // We only replace the calls with DirectCalls in the IR, and don't replace the CommandNode.data + // Hence if we have already processed this CFG node there will be no corresponding IndirectCall in the IR + // to replace. + // We want to replace all possible indirect calls based on this CFG, before regenerating it from the IR + return + } + assert(indirectCall.parent.statements.lastOption.contains(indirectCall)) + + val targetNames = resolveAddresses(indirectCall.target, indirectCall) + Logger.debug(s"Points-To approximated call ${indirectCall.target} with $targetNames") + Logger.debug(IRProgram.procedures) + val targets: mutable.Set[Procedure] = targetNames.map(name => IRProgram.procedures.find(_.name == name).getOrElse(addFakeProcedure(name))) + + if (targets.size > 1) { + Logger.info(s"Resolved indirect call $indirectCall") + } + + + if (targets.size == 1) { + modified = true + + // indirectCall.parent.parent.removeBlocks(indirectCall.returnTarget) + val newCall = DirectCall(targets.head, indirectCall.label) + block.statements.replace(indirectCall, newCall) + } else if (targets.size > 1) { + + val oft = indirectCall.parent.jump + + modified = true + val procedure = c.parent.data + val newBlocks = ArrayBuffer[Block]() + // indirectCall.parent.parent.removeBlocks(indirectCall.returnTarget) + for (t <- targets) { + Logger.debug(targets) + val address = t.address.match { + case Some(a) => a + case None => throw Exception(s"resolved indirect call $indirectCall to procedure which does not have address: $t") + } + val assume = Assume(BinaryExpr(BVEQ, indirectCall.target, BitVecLiteral(address, 64))) + val newLabel: String = block.label + t.name + val directCall = DirectCall(t) + + /* copy the goto node resulting */ + val fallthrough = oft match { + case g: GoTo => GoTo(g.targets, g.label) + case h: Halt => Halt() + case r: Return => Return() + } + newBlocks.append(Block(newLabel, None, ArrayBuffer(assume, directCall), fallthrough)) + } + block.statements.remove(indirectCall) + procedure.addBlocks(newBlocks) + val newCall = GoTo(newBlocks, indirectCall.label) + block.replaceJump(newCall) + } + case _ => + case _ => + } + + modified +} diff --git a/src/main/scala/ir/transforms/ReplaceReturn.scala b/src/main/scala/ir/transforms/ReplaceReturn.scala index 61b276833..91681ab8e 100644 --- a/src/main/scala/ir/transforms/ReplaceReturn.scala +++ b/src/main/scala/ir/transforms/ReplaceReturn.scala @@ -28,11 +28,15 @@ class ReplaceReturns extends CILVisitor { } -def addReturnBlocks(p: Program) = { +def addReturnBlocks(p: Program, toAll: Boolean = false) = { p.procedures.foreach(p => { val containsReturn = p.blocks.map(_.jump).find(_.isInstanceOf[Return]).isDefined - if (containsReturn) { - p.returnBlock = p.addBlocks(Block(label=p.name + "_return",jump=Return())) + if (toAll && p.blocks.isEmpty && p.entryBlock.isEmpty && p.returnBlock.isEmpty) { + Logger.info(s"proc ${p.name} ${p.entryBlock}, ${p.returnBlock}") + p.returnBlock = (Block(label=p.name + "_basil_return",jump=Return())) + p.entryBlock = (Block(label=p.name + "_basil_entry",jump=GoTo(p.returnBlock.get))) + } else if (p.returnBlock.isEmpty && (toAll || containsReturn)) { + p.returnBlock = p.addBlocks(Block(label=p.name + "_basil_return",jump=Return())) } }) } diff --git a/src/main/scala/ir/transforms/SplitThreads.scala b/src/main/scala/ir/transforms/SplitThreads.scala new file mode 100644 index 000000000..7f36fdb3c --- /dev/null +++ b/src/main/scala/ir/transforms/SplitThreads.scala @@ -0,0 +1,84 @@ +package ir.transforms + +import scala.collection.mutable.ListBuffer +import scala.collection.mutable.ArrayBuffer +import analysis.solvers.* +import analysis.* +import bap.* +import ir.* +import translating.* +import util.Logger +import java.util.Base64 +import spray.json.DefaultJsonProtocol.* +import util.intrusive_list.IntrusiveList +import analysis.CfgCommandNode +import scala.collection.mutable +import cilvisitor._ + +// identify calls to pthread_create +// use analysis result to determine the third parameter's value (the function pointer) +// split off that procedure into new thread +// do reachability analysis +// also need a bit in the IR where it creates separate files +def splitThreads(program: Program, + pointsTo: Map[RegisterVariableWrapper, Set[RegisterVariableWrapper | MemoryRegion]], + regionContents: Map[MemoryRegion, Set[BitVecLiteral | MemoryRegion]], + reachingDefs: Map[CFGPosition, (Map[Variable, Set[Assign]], Map[Variable, Set[Assign]])] + ): Unit = { + + // iterate over all commands - if call is to pthread_create, look up? + program.foreach(c => + c match { + case d: DirectCall if d.target.name == "pthread_create" => + + // R2 should hold the function pointer of the function that begins the thread + // look up R2 value using points to results + val R2 = Register("R2", 64) + val b = reachingDefs(d) + val R2Wrapper = RegisterVariableWrapper(R2, getDefinition(R2, d, reachingDefs)) + val threadTargets = pointsTo(R2Wrapper) + + if (threadTargets.size > 1) { + // currently can't handle case where the thread created is ambiguous + throw Exception("can't handle thread creation with more than one possible target") + } + + if (threadTargets.size == 1) { + + // not trying to untangle the very messy region resolution at present, just dealing with simplest case + threadTargets.head match { + case data: DataRegion => + val threadEntrance = program.procedures.find(_.name == data.regionIdentifier) match { + case Some(proc) => proc + case None => throw Exception("could not find procedure with name " + data.regionIdentifier) + } + val thread = ProgramThread(threadEntrance, mutable.LinkedHashSet(threadEntrance), Some(d)) + program.threads.addOne(thread) + case _ => + throw Exception("unexpected non-data region " + threadTargets.head + " as PointsTo result for R2 at " + d) + } + } + case _ => + }) + + + if (program.threads.nonEmpty) { + val mainThread = ProgramThread(program.mainProcedure, mutable.LinkedHashSet(program.mainProcedure), None) + program.threads.addOne(mainThread) + + val programProcs = program.procedures + + // do reachability for all threads + for (thread <- program.threads) { + val reachable = thread.entry.reachableFrom + + // add procedures to thread in way that maintains original ordering + for (p <- programProcs) { + if (reachable.contains(p)) { + thread.procedures.add(p) + } + } + + } + } +} diff --git a/src/main/scala/translating/GTIRBToIR.scala b/src/main/scala/translating/GTIRBToIR.scala index f7589905c..ad090eb90 100644 --- a/src/main/scala/translating/GTIRBToIR.scala +++ b/src/main/scala/translating/GTIRBToIR.scala @@ -364,6 +364,8 @@ class GTIRBToIR(mods: Seq[Module], parserMap: immutable.Map[String, Array[Array[ // need to copy jump as it can't have multiple parents val jumpCopy = currentBlock.jump match { case GoTo(targets, label) => GoTo(targets, label) + case h: Halt => Halt() + case r: Return => Return() case _ => throw Exception("this shouldn't be reachable") } trueBlock.replaceJump(currentBlock.jump) @@ -440,7 +442,7 @@ class GTIRBToIR(mods: Seq[Module], parserMap: immutable.Map[String, Array[Array[ case EdgeLabel(false, _, Type_Return, _) => // return statement, value of 'direct' is just whether DDisasm has resolved the return target removePCAssign(block) - (Some(IndirectCall(Register("R30", 64), None)), Halt()) + (None, Return()) case EdgeLabel(false, true, Type_Fallthrough, _) => // end of block that doesn't end in a control flow instruction and falls through to next if (entranceUUIDtoProcedure.contains(edge.targetUuid)) { diff --git a/src/main/scala/translating/IRToBoogie.scala b/src/main/scala/translating/IRToBoogie.scala index b654fa031..e37f9ce92 100644 --- a/src/main/scala/translating/IRToBoogie.scala +++ b/src/main/scala/translating/IRToBoogie.scala @@ -650,7 +650,7 @@ class IRToBoogie(var program: Program, var spec: Specification, var thread: Opti val jump = GoToCmd(g.targets.map(_.label).toSeq) conditionAssert :+ jump case r: Return => List(ReturnCmd) - case r: Halt => List(BAssert(FalseBLiteral)) + case r: Halt => List(BAssume(FalseBLiteral)) } def translate(j: Call): List[BCmd] = j match { diff --git a/src/main/scala/util/RunUtils.scala b/src/main/scala/util/RunUtils.scala index b7fb1d7b1..059701ab2 100644 --- a/src/main/scala/util/RunUtils.scala +++ b/src/main/scala/util/RunUtils.scala @@ -200,7 +200,7 @@ object IRTransform { val renamer = Renamer(boogieReserved) cilvisitor.visit_prog(transforms.ReplaceReturns(), ctx.program) - transforms.addReturnBlocks(ctx.program) + transforms.addReturnBlocks(ctx.program, true) // add return to all blocks because IDE solver expects it cilvisitor.visit_prog(transforms.ConvertSingleReturn(), ctx.program) externalRemover.visitProgram(ctx.program) @@ -208,276 +208,6 @@ object IRTransform { ctx } - /** Resolve indirect calls to an address-conditional choice between direct calls using the Value Set Analysis results. - */ - def resolveIndirectCalls( - cfg: ProgramCfg, - valueSets: Map[CfgNode, LiftedElement[Map[Variable | MemoryRegion, Set[Value]]]], - IRProgram: Program - ): Boolean = { - var modified: Boolean = false - val worklist = ListBuffer[CfgNode]() - cfg.startNode.succIntra.union(cfg.startNode.succInter).foreach(node => worklist.addOne(node)) - - val visited = mutable.Set[CfgNode]() - while (worklist.nonEmpty) { - val node = worklist.remove(0) - if (!visited.contains(node)) { - process(node) - node.succIntra.union(node.succInter).foreach(node => worklist.addOne(node)) - visited.add(node) - } - } - - def process(n: CfgNode): Unit = n match { - /* - case c: CfgStatementNode => - c.data match - - //We do not want to insert the VSA results into the IR like this - case localAssign: Assign => - localAssign.rhs match - case _: MemoryLoad => - if (valueSets(n).contains(localAssign.lhs) && valueSets(n).get(localAssign.lhs).head.size == 1) { - val extractedValue = extractExprFromValue(valueSets(n).get(localAssign.lhs).head.head) - localAssign.rhs = extractedValue - Logger.info(s"RESOLVED: Memory load ${localAssign.lhs} resolved to ${extractedValue}") - } else if (valueSets(n).contains(localAssign.lhs) && valueSets(n).get(localAssign.lhs).head.size > 1) { - Logger.info(s"RESOLVED: WARN Memory load ${localAssign.lhs} resolved to multiple values, cannot replace") - - /* - // must merge into a single memory variable to represent the possible values - // Make a binary OR of all the possible values takes two at a time (incorrect to do BVOR) - val values = valueSets(n).get(localAssign.lhs).head - val exprValues = values.map(extractExprFromValue) - val result = exprValues.reduce((a, b) => BinaryExpr(BVOR, a, b)) // need to express nondeterministic - // choice between these specific options - localAssign.rhs = result - */ - } - case _ => - */ - case c: CfgJumpNode => - val block = c.block - c.data match - case indirectCall: IndirectCall => - if (block.jump != indirectCall) { - // We only replace the calls with DirectCalls in the IR, and don't replace the CommandNode.data - // Hence if we have already processed this CFG node there will be no corresponding IndirectCall in the IR - // to replace. - // We want to replace all possible indirect calls based on this CFG, before regenerating it from the IR - return - } - valueSets(n) match { - case Lift(valueSet) => - val targetNames = resolveAddresses(valueSet(indirectCall.target)).map(_.name).toList.sorted - val targets = targetNames.map(name => IRProgram.procedures.filter(_.name.equals(name)).head) - - if (targets.size == 1) { - modified = true - - // indirectCall.parent.parent.removeBlocks(indirectCall.returnTarget) - val newCall = DirectCall(targets.head, indirectCall.label) - block.statements.replace(indirectCall, newCall) - } else if (targets.size > 1) { - modified = true - val procedure = c.parent.data - val newBlocks = ArrayBuffer[Block]() - for (t <- targets) { - val assume = Assume(BinaryExpr(BVEQ, indirectCall.target, BitVecLiteral(t.address.get, 64))) - val newLabel: String = block.label + t.name - val directCall = DirectCall(t) - directCall.parent = indirectCall.parent - - // assume indircall is the last statement in block - assert(indirectCall.parent.statements.lastOption.contains(indirectCall)) - val fallthrough = indirectCall.parent.jump - - newBlocks.append(Block(newLabel, None, ArrayBuffer(assume, directCall), fallthrough)) - } - procedure.addBlocks(newBlocks) - val newCall = GoTo(newBlocks, indirectCall.label) - block.replaceJump(newCall) - } - case LiftedBottom => - } - case _ => - case _ => - } - - def nameExists(name: String): Boolean = { - IRProgram.procedures.exists(_.name.equals(name)) - } - - def addFakeProcedure(name: String): Unit = { - IRProgram.procedures += Procedure(name) - } - - def resolveAddresses(valueSet: Set[Value]): Set[AddressValue] = { - var functionNames: Set[AddressValue] = Set() - valueSet.foreach { - case globalAddress: GlobalAddress => - if (nameExists(globalAddress.name)) { - functionNames += globalAddress - Logger.info(s"RESOLVED: Call to Global address ${globalAddress.name} rt statuesolved.") - } else { - addFakeProcedure(globalAddress.name) - functionNames += globalAddress - Logger.info(s"Global address ${globalAddress.name} does not exist in the program. Added a fake function.") - } - case localAddress: LocalAddress => - if (nameExists(localAddress.name)) { - functionNames += localAddress - Logger.info(s"RESOLVED: Call to Local address ${localAddress.name}") - } else { - addFakeProcedure(localAddress.name) - functionNames += localAddress - Logger.info(s"Local address ${localAddress.name} does not exist in the program. Added a fake function.") - } - case _ => - } - functionNames - } - - modified - } - - def resolveIndirectCallsUsingPointsTo( - cfg: ProgramCfg, - pointsTos: Map[RegisterVariableWrapper, Set[RegisterVariableWrapper | MemoryRegion]], - regionContents: Map[MemoryRegion, Set[BitVecLiteral | MemoryRegion]], - reachingDefs: Map[CFGPosition, (Map[Variable, Set[Assign]], Map[Variable, Set[Assign]])], - IRProgram: Program - ): Boolean = { - var modified: Boolean = false - val worklist = ListBuffer[CfgNode]() - cfg.startNode.succIntra.union(cfg.startNode.succInter).foreach(node => worklist.addOne(node)) - - val visited = mutable.Set[CfgNode]() - while (worklist.nonEmpty) { - val node = worklist.remove(0) - if (!visited.contains(node)) { - process(node) - node.succIntra.union(node.succInter).foreach(node => worklist.addOne(node)) - visited.add(node) - } - } - - def searchRegion(region: MemoryRegion): mutable.Set[String] = { - val result = mutable.Set[String]() - region match { - case stackRegion: StackRegion => - if (regionContents.contains(stackRegion)) { - for (c <- regionContents(stackRegion)) { - c match { - case bitVecLiteral: BitVecLiteral => Logger.debug("hi: " + bitVecLiteral)//??? - case memoryRegion: MemoryRegion => - result.addAll(searchRegion(memoryRegion)) - } - } - } - result - case dataRegion: DataRegion => - if (!regionContents.contains(dataRegion) || regionContents(dataRegion).isEmpty) { - result.add(dataRegion.regionIdentifier) - } else { - result.add(dataRegion.regionIdentifier) // TODO: may need to investigate if we should add the parent region - for (c <- regionContents(dataRegion)) { - c match { - case bitVecLiteral: BitVecLiteral => Logger.debug("hi: " + bitVecLiteral)//??? - case memoryRegion: MemoryRegion => - result.addAll(searchRegion(memoryRegion)) - } - } - } - result - } - } - - def addFakeProcedure(name: String): Procedure = { - val newProcedure = Procedure(name) - IRProgram.procedures += newProcedure - newProcedure - } - - def resolveAddresses(variable: Variable, i: IndirectCall): mutable.Set[String] = { - val names = mutable.Set[String]() - val variableWrapper = RegisterVariableWrapper(variable, getUse(variable, i, reachingDefs)) - pointsTos.get(variableWrapper) match { - case Some(value) => - value.map { - case v: RegisterVariableWrapper => names.addAll(resolveAddresses(v.variable, i)) - case m: MemoryRegion => names.addAll(searchRegion(m)) - } - names - case None => names - } - } - - def process(n: CfgNode): Unit = n match { - case c: CfgJumpNode => - val block = c.block - c.data match - // don't try to resolve returns - case indirectCall: IndirectCall if indirectCall.target != Register("R30", 64) => - if (block.jump != indirectCall) { - // We only replace the calls with DirectCalls in the IR, and don't replace the CommandNode.data - // Hence if we have already processed this CFG node there will be no corresponding IndirectCall in the IR - // to replace. - // We want to replace all possible indirect calls based on this CFG, before regenerating it from the IR - return - } - val targetNames = resolveAddresses(indirectCall.target, indirectCall) - Logger.debug(s"Points-To approximated call ${indirectCall.target} with $targetNames") - Logger.debug(IRProgram.procedures) - val targets: mutable.Set[Procedure] = targetNames.map(name => IRProgram.procedures.find(_.name == name).getOrElse(addFakeProcedure(name))) - - if (targets.size == 1) { - modified = true - - // indirectCall.parent.parent.removeBlocks(indirectCall.returnTarget) - val newCall = DirectCall(targets.head, indirectCall.label) - block.statements.replace(indirectCall, newCall) - } else if (targets.size > 1) { - modified = true - val procedure = c.parent.data - val newBlocks = ArrayBuffer[Block]() - // indirectCall.parent.parent.removeBlocks(indirectCall.returnTarget) - var addressExprs = List[BinaryExpr]() - for (t <- targets) { - Logger.debug(targets) - // TODO handle external procedures without a set address but this requires more information than the analysis gives at present - val address = t.address.match { - case Some(a) => a - case None => throw Exception(s"resolved indirect call $indirectCall to procedure which does not have address: $t") - } - val addressExpr = BinaryExpr(BVEQ, indirectCall.target, BitVecLiteral(address, 64)) - addressExprs ::= addressExpr - val assume = Assume(addressExpr) - val newLabel: String = block.label + t.name - val directCall = DirectCall(t) - directCall.parent = indirectCall.parent - - assert(indirectCall.parent.statements.lastOption.contains(indirectCall)) - val fallthrough = indirectCall.parent.jump - newBlocks.append(Block(newLabel, None, ArrayBuffer(assume, directCall), fallthrough)) - } - procedure.addBlocks(newBlocks) - val newCall = GoTo(newBlocks, indirectCall.label) - val addressExprOr = addressExprs.tail.foldLeft(addressExprs.head) { - (a: BinaryExpr, b: BinaryExpr) => BinaryExpr(BoolOR, a, b) - } - val assertion = Assert(addressExprOr, Some("check indirect call underapproximation")) - block.statements.append(assertion) - block.replaceJump(newCall) - } - case _ => - case _ => - } - - modified - } - /** Cull unneccessary information that does not need to be included in the translation, and infer stack regions, and * add in modifies from the spec. */ @@ -496,75 +226,9 @@ object IRTransform { val specModifies = ctx.specification.subroutines.map(s => s.name -> s.modifies).toMap ctx.program.setModifies(specModifies) + assert(invariant.singleCallBlockEnd(ctx.program)) } - // identify calls to pthread_create - // use analysis result to determine the third parameter's value (the function pointer) - // split off that procedure into new thread - // do reachability analysis - // also need a bit in the IR where it creates separate files - def splitThreads(program: Program, - pointsTo: Map[RegisterVariableWrapper, Set[RegisterVariableWrapper | MemoryRegion]], - regionContents: Map[MemoryRegion, Set[BitVecLiteral | MemoryRegion]], - reachingDefs: Map[CFGPosition, (Map[Variable, Set[Assign]], Map[Variable, Set[Assign]])] - ): Unit = { - - // iterate over all commands - if call is to pthread_create, look up? - program.foreach(c => - c match { - case d: DirectCall if d.target.name == "pthread_create" => - - // R2 should hold the function pointer of the function that begins the thread - // look up R2 value using points to results - val R2 = Register("R2", 64) - val b = reachingDefs(d) - val R2Wrapper = RegisterVariableWrapper(R2, getDefinition(R2, d, reachingDefs)) - val threadTargets = pointsTo(R2Wrapper) - - if (threadTargets.size > 1) { - // currently can't handle case where the thread created is ambiguous - throw Exception("can't handle thread creation with more than one possible target") - } - - if (threadTargets.size == 1) { - - // not trying to untangle the very messy region resolution at present, just dealing with simplest case - threadTargets.head match { - case data: DataRegion => - val threadEntrance = program.procedures.find(_.name == data.regionIdentifier) match { - case Some(proc) => proc - case None => throw Exception("could not find procedure with name " + data.regionIdentifier) - } - val thread = ProgramThread(threadEntrance, mutable.LinkedHashSet(threadEntrance), Some(d)) - program.threads.addOne(thread) - case _ => - throw Exception("unexpected non-data region " + threadTargets.head + " as PointsTo result for R2 at " + d) - } - } - case _ => - }) - - - if (program.threads.nonEmpty) { - val mainThread = ProgramThread(program.mainProcedure, mutable.LinkedHashSet(program.mainProcedure), None) - program.threads.addOne(mainThread) - - val programProcs = program.procedures - - // do reachability for all threads - for (thread <- program.threads) { - val reachable = thread.entry.reachableFrom - - // add procedures to thread in way that maintains original ordering - for (p <- programProcs) { - if (reachable.contains(p)) { - thread.procedures.add(p) - } - } - - } - } - } } @@ -700,18 +364,20 @@ object StaticAnalysis { val memoryRegionContents = steensgaardSolver.getMemoryRegionContents mmm.logRegions(memoryRegionContents) + // turn fake procedures into diamonds + transforms.addReturnBlocks(ctx.program, true) // add return to all blocks because IDE solver expects it Logger.info("[!] Running VSA") val vsaSolver = ValueSetAnalysisSolver(IRProgram, globalAddresses, externalAddresses, globalOffsets, subroutines, mmm, constPropResult) val vsaResult: Map[CFGPosition, LiftedElement[Map[Variable | MemoryRegion, Set[Value]]]] = vsaSolver.analyze() Logger.info("[!] Running Interprocedural Live Variables Analysis") - //val interLiveVarsResults = InterLiveVarsAnalysis(IRProgram).analyze() - val interLiveVarsResults = Map[CFGPosition, Map[Variable, TwoElement]]() + val interLiveVarsResults = InterLiveVarsAnalysis(IRProgram).analyze() + // val interLiveVarsResults = Map[CFGPosition, Map[Variable, TwoElement]]() Logger.info("[!] Running Parameter Analysis") - //val paramResults = ParamAnalysis(IRProgram).analyze() - val paramResults = Map[Procedure, Set[Variable]]() + val paramResults = ParamAnalysis(IRProgram).analyze() + // val paramResults = Map[Procedure, Set[Variable]]() StaticAnalysisContext( cfg = cfg, @@ -926,6 +592,7 @@ object RunUtils { val boogieTranslator = IRToBoogie(ctx.program, ctx.specification, None, q.outputPrefix) ArrayBuffer(boogieTranslator.translate(q.boogieTranslation)) } + assert(invariant.singleCallBlockEnd(ctx.program)) BASILResult(ctx, analysis, boogiePrograms) } @@ -941,7 +608,7 @@ object RunUtils { val result = StaticAnalysis.analyse(ctx, config, iteration) analysisResult.append(result) Logger.info("[!] Replacing Indirect Calls") - modified = IRTransform.resolveIndirectCallsUsingPointsTo(result.cfg, + modified = transforms.resolveIndirectCallsUsingPointsTo(result.cfg, result.steensgaardResults, result.memoryRegionContents, result.reachingDefs, @@ -956,7 +623,7 @@ object RunUtils { // should later move this to be inside while (modified) loop and have splitting threads cause further iterations if (config.threadSplit) { - IRTransform.splitThreads(ctx.program, analysisResult.last.steensgaardResults, analysisResult.last.memoryRegionContents, analysisResult.last.reachingDefs) + transforms.splitThreads(ctx.program, analysisResult.last.steensgaardResults, analysisResult.last.memoryRegionContents, analysisResult.last.reachingDefs) } config.analysisDotPath.foreach { s => @@ -964,6 +631,7 @@ object RunUtils { writeToFile(newCFG.toDot(x => x.toString, Output.dotIder), s"${s}_resolvedCFG.dot") } + assert(invariant.singleCallBlockEnd(ctx.program)) Logger.info(s"[!] Finished indirect call resolution after $iteration iterations") analysisResult.last } diff --git a/src/test/scala/IndirectCallsTests.scala b/src/test/scala/IndirectCallsTests.scala index 60f7fdd67..15c635fc9 100644 --- a/src/test/scala/IndirectCallsTests.scala +++ b/src/test/scala/IndirectCallsTests.scala @@ -3,7 +3,7 @@ import ir.Endian.LittleEndian import org.scalatest.* import org.scalatest.funsuite.* import specification.* -import util.{BASILConfig, ILLoadingConfig, IRContext, RunUtils, StaticAnalysis, StaticAnalysisConfig, StaticAnalysisContext} +import util.{BASILConfig, ILLoadingConfig, IRContext, RunUtils, StaticAnalysis, StaticAnalysisConfig, StaticAnalysisContext, BASILResult} import java.io.IOException import java.nio.file.* @@ -76,14 +76,15 @@ class IndirectCallsTests extends AnyFunSuite with OneInstancePerTest with Before // Traverse the statements in the main function result.ir.program.mainProcedure.blocks.foreach { block => - block.jump match { - case directCall: DirectCall if expectedCallTransform.contains(directCall.label.getOrElse("")) => + block.statements.lastOption match { + case Some(directCall: DirectCall) if expectedCallTransform.contains(directCall.label.getOrElse("")) => val callTransform = expectedCallTransform(directCall.label.getOrElse("")) assert(callTransform._1 == directCall.target.name) expectedCallTransform.remove(directCall.label.getOrElse("")) - case _ => + case _ => } } + println(result.ir.program) assert(expectedCallTransform.isEmpty) } @@ -113,14 +114,15 @@ class IndirectCallsTests extends AnyFunSuite with OneInstancePerTest with Before // Traverse the statements in the main function result.ir.program.mainProcedure.blocks.foreach { block => - block.jump match { - case directCall: DirectCall if expectedCallTransform.contains(directCall.label.getOrElse("")) => + block.statements.lastOption match { + case Some(directCall: DirectCall) if expectedCallTransform.contains(directCall.label.getOrElse("")) => val callTransform = expectedCallTransform(directCall.label.getOrElse("")) assert(callTransform._1 == directCall.target.name) expectedCallTransform.remove(directCall.label.getOrElse("")) case _ => } } + println(result.ir.program) assert(expectedCallTransform.isEmpty) } @@ -150,14 +152,15 @@ class IndirectCallsTests extends AnyFunSuite with OneInstancePerTest with Before // Traverse the statements in the main function result.ir.program.mainProcedure.blocks.foreach { block => - block.jump match { - case directCall: DirectCall if expectedCallTransform.contains(directCall.label.getOrElse("")) => + block.statements.lastOption match { + case Some(directCall: DirectCall) if expectedCallTransform.contains(directCall.label.getOrElse("")) => val callTransform = expectedCallTransform(directCall.label.getOrElse("")) assert(callTransform._1 == directCall.target.name) expectedCallTransform.remove(directCall.label.getOrElse("")) case _ => } } + println(result.ir.program) assert(expectedCallTransform.isEmpty) } @@ -193,14 +196,15 @@ class IndirectCallsTests extends AnyFunSuite with OneInstancePerTest with Before // Traverse the statements in the main function result.ir.program.mainProcedure.blocks.foreach { block => - block.jump match { - case directCall: DirectCall if expectedCallTransform.contains(directCall.label.getOrElse("")) => + block.statements.lastOption match { + case Some(directCall: DirectCall) if expectedCallTransform.contains(directCall.label.getOrElse("")) => val callTransform = expectedCallTransform(directCall.label.getOrElse("")) assert(callTransform._1 == directCall.target.name) expectedCallTransform.remove(directCall.label.getOrElse("")) case _ => } } + println(result.ir.program) assert(expectedCallTransform.isEmpty) } @@ -236,14 +240,15 @@ class IndirectCallsTests extends AnyFunSuite with OneInstancePerTest with Before // Traverse the statements in the main function result.ir.program.mainProcedure.blocks.foreach { block => - block.jump match { - case directCall: DirectCall if expectedCallTransform.contains(directCall.label.getOrElse("")) => + block.statements.lastOption match { + case Some(directCall: DirectCall) if expectedCallTransform.contains(directCall.label.getOrElse("")) => val callTransform = expectedCallTransform(directCall.label.getOrElse("")) assert(callTransform._1 == directCall.target.name) expectedCallTransform.remove(directCall.label.getOrElse("")) case _ => } } + println(result.ir.program) assert(expectedCallTransform.isEmpty) } @@ -278,14 +283,15 @@ class IndirectCallsTests extends AnyFunSuite with OneInstancePerTest with Before // Traverse the statements in the main function result.ir.program.mainProcedure.blocks.foreach { block => - block.jump match { - case directCall: DirectCall if expectedCallTransform.contains(directCall.label.getOrElse("")) => + block.statements.lastOption match { + case Some(directCall: DirectCall) if expectedCallTransform.contains(directCall.label.getOrElse("")) => val callTransform = expectedCallTransform(directCall.label.getOrElse("")) assert(callTransform._1 == directCall.target.name) expectedCallTransform.remove(directCall.label.getOrElse("")) case _ => } } + println(result.ir.program) assert(expectedCallTransform.isEmpty) } @@ -321,14 +327,15 @@ class IndirectCallsTests extends AnyFunSuite with OneInstancePerTest with Before // Traverse the statements in the main function result.ir.program.mainProcedure.blocks.foreach { block => - block.jump match { - case directCall: DirectCall if expectedCallTransform.contains(directCall.label.getOrElse("")) => + block.statements.lastOption match { + case Some(directCall: DirectCall) if expectedCallTransform.contains(directCall.label.getOrElse("")) => val callTransform = expectedCallTransform(directCall.label.getOrElse("")) assert(callTransform._1 == directCall.target.name) expectedCallTransform.remove(directCall.label.getOrElse("")) case _ => } } + println(result.ir.program) assert(expectedCallTransform.isEmpty) } @@ -341,7 +348,7 @@ class IndirectCallsTests extends AnyFunSuite with OneInstancePerTest with Before dumpIL = Some(tempPath + testName), ), outputPrefix = tempPath + testName, - staticAnalysis = Some(StaticAnalysisConfig(None, None, None)), + staticAnalysis = Some(StaticAnalysisConfig(Some("functionpointer"), None, None)), ) val result = loadAndTranslate(basilConfig) /* in this example we must find: @@ -356,17 +363,19 @@ class IndirectCallsTests extends AnyFunSuite with OneInstancePerTest with Before "l000004f3set_seven" -> ("set_seven", "R0") ) + println("prev " + result.ir.program) // Traverse the statements in the main function result.ir.program.mainProcedure.blocks.foreach { block => - block.jump match { - case directCall: DirectCall if expectedCallTransform.contains(block.label) => + block.statements.lastOption match { + case Some(directCall: DirectCall) if expectedCallTransform.contains(block.label) => val callTransform = expectedCallTransform(block.label) assert(callTransform._1 == directCall.target.name) expectedCallTransform.remove(block.label) case _ => } } + println(result.ir.program) assert(expectedCallTransform.isEmpty) } @@ -397,14 +406,15 @@ class IndirectCallsTests extends AnyFunSuite with OneInstancePerTest with Before // Traverse the statements in the main function result.ir.program.mainProcedure.blocks.foreach { block => - block.jump match { - case directCall: DirectCall if expectedCallTransform.contains(block.label) => + block.statements.lastOption match { + case Some(directCall: DirectCall) if expectedCallTransform.contains(block.label) => val callTransform = expectedCallTransform(block.label) assert(callTransform._1 == directCall.target.name) expectedCallTransform.remove(block.label) case _ => } } + println(result.ir.program) assert(expectedCallTransform.isEmpty) } @@ -431,18 +441,19 @@ class IndirectCallsTests extends AnyFunSuite with OneInstancePerTest with Before "l0000044dset_two" -> ("set_two", "R0"), "l0000044dset_seven" -> ("set_seven", "R0") ) + result.ir.program.mainProcedure.blocks.foreach { + block => + block.statements.lastOption match { + case Some(directCall: DirectCall) if expectedCallTransform.contains(block.label) => + val callTransform = expectedCallTransform(block.label) + assert(callTransform._1 == directCall.target.name) + expectedCallTransform.remove(block.label) + case _ => + } + } // Traverse the statements in the main function - result.ir.program.mainProcedure.blocks.foreach { - block => - block.jump match { - case directCall: DirectCall if expectedCallTransform.contains(block.label) => - val callTransform = expectedCallTransform(block.label) - assert(callTransform._1 == directCall.target.name) - expectedCallTransform.remove(block.label) - case _ => - } - } + println(result.ir.program) assert(expectedCallTransform.isEmpty) } @@ -470,8 +481,8 @@ class IndirectCallsTests extends AnyFunSuite with OneInstancePerTest with Before // Traverse the statements in the main function result.ir.program.mainProcedure.blocks.foreach { block => - block.jump match { - case directCall: DirectCall if expectedCallTransform.contains(directCall.label.getOrElse("")) => + block.statements.lastOption match { + case Some(directCall: DirectCall) if expectedCallTransform.contains(directCall.label.getOrElse("")) => val callTransform = expectedCallTransform(directCall.label.getOrElse("")) assert(callTransform._1 == directCall.target.name) expectedCallTransform.remove(directCall.label.getOrElse("")) @@ -505,8 +516,8 @@ class IndirectCallsTests extends AnyFunSuite with OneInstancePerTest with Before // Traverse the statements in the main function result.ir.program.mainProcedure.blocks.foreach { block => - block.jump match { - case directCall: DirectCall if expectedCallTransform.contains(directCall.label.getOrElse("")) => + block.statements.lastOption match { + case Some(directCall: DirectCall) if expectedCallTransform.contains(directCall.label.getOrElse("")) => val callTransform = expectedCallTransform(directCall.label.getOrElse("")) assert(callTransform._1 == directCall.target.name) expectedCallTransform.remove(directCall.label.getOrElse("")) @@ -540,8 +551,8 @@ class IndirectCallsTests extends AnyFunSuite with OneInstancePerTest with Before // Traverse the statements in the main function result.ir.program.mainProcedure.blocks.foreach { block => - block.jump match { - case directCall: DirectCall if expectedCallTransform.contains(directCall.label.getOrElse("")) => + block.statements.lastOption match { + case Some(directCall: DirectCall) if expectedCallTransform.contains(directCall.label.getOrElse("")) => val callTransform = expectedCallTransform(directCall.label.getOrElse("")) assert(callTransform._1 == directCall.target.name) expectedCallTransform.remove(directCall.label.getOrElse("")) @@ -550,4 +561,4 @@ class IndirectCallsTests extends AnyFunSuite with OneInstancePerTest with Before } assert(expectedCallTransform.isEmpty) } -} \ No newline at end of file +} diff --git a/src/test/scala/LiveVarsAnalysisTests.scala b/src/test/scala/LiveVarsAnalysisTests.scala index 443dc011f..e1b142001 100644 --- a/src/test/scala/LiveVarsAnalysisTests.scala +++ b/src/test/scala/LiveVarsAnalysisTests.scala @@ -1,12 +1,14 @@ import analysis.{InterLiveVarsAnalysis, TwoElementTop} import ir.dsl.* -import ir.{BitVecLiteral, BitVecType, ConvertToSingleProcedureReturn, dsl, Assign, LocalVar, Program, Register, Statement, Variable} +import ir.{BitVecLiteral, BitVecType, dsl, Assign, LocalVar, Program, Register, Statement, Variable, transforms, cilvisitor, Procedure} +import util.{Logger, LogLevel} import org.scalatest.funsuite.AnyFunSuite import test_util.TestUtil import util.BASILResult class LiveVarsAnalysisTests extends AnyFunSuite, TestUtil { + Logger.setLevel(LogLevel.ERROR) def createSimpleProc(name: String, statements: Seq[Statement | EventuallyJump]): EventuallyProcedure = { proc(name, @@ -31,10 +33,12 @@ class LiveVarsAnalysisTests extends AnyFunSuite, TestUtil { block("first_call", r0ConstantAssign, r1ConstantAssign, - directCall("callee1", Some("second_call")) + directCall("callee1"), + goto("second_call") ), block("second_call", - directCall("callee2", Some("returnBlock")) + directCall("callee2"), + goto("returnBlock") ), block("returnBlock", ret @@ -44,15 +48,21 @@ class LiveVarsAnalysisTests extends AnyFunSuite, TestUtil { createSimpleProc("callee2", Seq(r2r1Assign)) ) - val returnUnifier = ConvertToSingleProcedureReturn() - program = returnUnifier.visitProgram(program) + cilvisitor.visit_prog(transforms.ReplaceReturns(), program) + transforms.addReturnBlocks(program) + cilvisitor.visit_prog(transforms.ConvertSingleReturn(), program) val liveVarAnalysisResults = InterLiveVarsAnalysis(program).analyze() + // fix for DSA pairs of results? val procs = program.procs - assert(liveVarAnalysisResults(procs("main")) == Map(R30 -> TwoElementTop)) - assert(liveVarAnalysisResults(procs("callee1")) == Map(R0 -> TwoElementTop, R1 -> TwoElementTop, R30 -> TwoElementTop)) - assert(liveVarAnalysisResults(procs("callee2")) == Map(R1 -> TwoElementTop, R30 -> TwoElementTop)) + println(liveVarAnalysisResults.filter((k,n) => k match { + case p => true + case _ => false + })) + // assert(liveVarAnalysisResults(procs("main")) == Map(R30 -> TwoElementTop)) + assert(liveVarAnalysisResults(procs("callee1")) == Map(R0 -> TwoElementTop, R1 -> TwoElementTop)) + assert(liveVarAnalysisResults(procs("callee2")) == Map(R1 -> TwoElementTop)) } @@ -69,10 +79,10 @@ class LiveVarsAnalysisTests extends AnyFunSuite, TestUtil { block("first_call", r0ConstantAssign, r1ConstantAssign, - directCall("callee1", Some("second_call")) + directCall("callee1"), goto("second_call") ), block("second_call", - directCall("callee2", Some("returnBlock")) + directCall("callee2"), goto("returnBlock") ), block("returnBlock", ret @@ -82,15 +92,16 @@ class LiveVarsAnalysisTests extends AnyFunSuite, TestUtil { createSimpleProc("callee2", Seq(r2r1Assign)) ) - val returnUnifier = ConvertToSingleProcedureReturn() - program = returnUnifier.visitProgram(program) + cilvisitor.visit_prog(transforms.ReplaceReturns(), program) + transforms.addReturnBlocks(program) + cilvisitor.visit_prog(transforms.ConvertSingleReturn(), program) val liveVarAnalysisResults = InterLiveVarsAnalysis(program).analyze() val procs = program.procs - assert(liveVarAnalysisResults(procs("main")) == Map(R30 -> TwoElementTop)) - assert(liveVarAnalysisResults(procs("callee1")) == Map(R0 -> TwoElementTop, R30 -> TwoElementTop)) - assert(liveVarAnalysisResults(procs("callee2")) == Map(R1 -> TwoElementTop, R30 -> TwoElementTop)) + // assert(liveVarAnalysisResults(procs("main")) == Map()) + assert(liveVarAnalysisResults(procs("callee1")) == Map(R0 -> TwoElementTop)) + assert(liveVarAnalysisResults(procs("callee2")) == Map(R1 -> TwoElementTop)) } def twoCallers(): Unit = { @@ -104,10 +115,10 @@ class LiveVarsAnalysisTests extends AnyFunSuite, TestUtil { var program = prog( proc("main", block("main_first_call", - directCall("wrapper1", Some("main_second_call")) + directCall("wrapper1"), goto("main_second_call") ), block("main_second_call", - directCall("wrapper2", Some("main_return")) + directCall("wrapper2"), goto("main_return") ), block("main_return", ret) ), @@ -117,30 +128,31 @@ class LiveVarsAnalysisTests extends AnyFunSuite, TestUtil { proc("wrapper1", block("wrapper1_first_call", Assign(R1, constant1), - directCall("callee", Some("wrapper1_second_call")) + directCall("callee"), goto("wrapper1_second_call") ), block("wrapper1_second_call", - directCall("callee2", Some("wrapper1_return"))), + directCall("callee2"), goto("wrapper1_return")), block("wrapper1_return", ret) ), proc("wrapper2", block("wrapper2_first_call", Assign(R2, constant1), - directCall("callee", Some("wrapper2_second_call")) + directCall("callee"), goto("wrapper2_second_call") ), block("wrapper2_second_call", - directCall("callee3", Some("wrapper2_return"))), + directCall("callee3"), goto("wrapper2_return")), block("wrapper2_return", ret) ) ) - val returnUnifier = ConvertToSingleProcedureReturn() - program = returnUnifier.visitProgram(program) + cilvisitor.visit_prog(transforms.ReplaceReturns(), program) + transforms.addReturnBlocks(program) + cilvisitor.visit_prog(transforms.ConvertSingleReturn(), program) val liveVarAnalysisResults = InterLiveVarsAnalysis(program).analyze() val blocks = program.blocks - assert(liveVarAnalysisResults(blocks("wrapper1_first_call").jump) == Map(R1 -> TwoElementTop, R30 -> TwoElementTop)) - assert(liveVarAnalysisResults(blocks("wrapper2_first_call").jump) == Map(R2 -> TwoElementTop, R30 -> TwoElementTop)) + assert(liveVarAnalysisResults(blocks("wrapper1_first_call").jump) == Map(R1 -> TwoElementTop)) + assert(liveVarAnalysisResults(blocks("wrapper2_first_call").jump) == Map(R2 -> TwoElementTop)) } @@ -148,7 +160,7 @@ class LiveVarsAnalysisTests extends AnyFunSuite, TestUtil { var program = prog( proc("main", block("lmain", - directCall("killer", Some("aftercall")) + directCall("killer"), goto("aftercall") ), block("aftercall", Assign(R0, R1), @@ -158,14 +170,15 @@ class LiveVarsAnalysisTests extends AnyFunSuite, TestUtil { createSimpleProc("killer", Seq(Assign(R1, bv64(1)))) ) - val returnUnifier = ConvertToSingleProcedureReturn() - program = returnUnifier.visitProgram(program) + cilvisitor.visit_prog(transforms.ReplaceReturns(), program) + transforms.addReturnBlocks(program) + cilvisitor.visit_prog(transforms.ConvertSingleReturn(), program) val liveVarAnalysisResults = InterLiveVarsAnalysis(program).analyze() val blocks = program.blocks - assert(liveVarAnalysisResults(blocks("aftercall")) == Map(R1 -> TwoElementTop, R30 -> TwoElementTop)) - assert(liveVarAnalysisResults(blocks("lmain")) == Map(R30 -> TwoElementTop)) + assert(liveVarAnalysisResults(blocks("aftercall")) == Map(R1 -> TwoElementTop)) + // assert(liveVarAnalysisResults(blocks("lmain")) == Map()) } def simpleBranch(): Unit = { @@ -193,15 +206,16 @@ class LiveVarsAnalysisTests extends AnyFunSuite, TestUtil { ) ) - val returnUnifier = ConvertToSingleProcedureReturn() - program = returnUnifier.visitProgram(program) + cilvisitor.visit_prog(transforms.ReplaceReturns(), program) + transforms.addReturnBlocks(program) + cilvisitor.visit_prog(transforms.ConvertSingleReturn(), program) val blocks = program.blocks val liveVarAnalysisResults = InterLiveVarsAnalysis(program).analyze() - assert(liveVarAnalysisResults(blocks("branch1")) == Map(R1 -> TwoElementTop, R30 -> TwoElementTop)) - assert(liveVarAnalysisResults(blocks("branch2")) == Map(R2 -> TwoElementTop, R30 -> TwoElementTop)) - assert(liveVarAnalysisResults(blocks("lmain")) == Map(R1 -> TwoElementTop, R2 -> TwoElementTop, R30 -> TwoElementTop)) + assert(liveVarAnalysisResults(blocks("branch1")) == Map(R1 -> TwoElementTop)) + assert(liveVarAnalysisResults(blocks("branch2")) == Map(R2 -> TwoElementTop)) + assert(liveVarAnalysisResults(blocks("lmain")) == Map(R1 -> TwoElementTop, R2 -> TwoElementTop)) } @@ -212,7 +226,7 @@ class LiveVarsAnalysisTests extends AnyFunSuite, TestUtil { block( "lmain", Assign(R0, R1), - directCall("main", Some("return")) + directCall("main"), goto("return") ), block("return", Assign(R0, R2), @@ -221,13 +235,14 @@ class LiveVarsAnalysisTests extends AnyFunSuite, TestUtil { ) ) - val returnUnifier = ConvertToSingleProcedureReturn() - program = returnUnifier.visitProgram(program) + cilvisitor.visit_prog(transforms.ReplaceReturns(), program) + transforms.addReturnBlocks(program) + cilvisitor.visit_prog(transforms.ConvertSingleReturn(), program) val liveVarAnalysisResults = InterLiveVarsAnalysis(program).analyze() val blocks = program.blocks - assert(liveVarAnalysisResults(program.mainProcedure) == Map(R1 -> TwoElementTop, R2 -> TwoElementTop, R30 -> TwoElementTop)) + assert(liveVarAnalysisResults(program.mainProcedure) == Map(R1 -> TwoElementTop, R2 -> TwoElementTop)) } def recursionBaseCase(): Unit = { @@ -240,7 +255,7 @@ class LiveVarsAnalysisTests extends AnyFunSuite, TestUtil { ), block( "recursion", - directCall("main", Some("assign")) + directCall("main"), goto("assign") ), block("assign", Assign(R0, R2), @@ -256,13 +271,14 @@ class LiveVarsAnalysisTests extends AnyFunSuite, TestUtil { ) ) - val returnUnifier = ConvertToSingleProcedureReturn() - program = returnUnifier.visitProgram(program) + cilvisitor.visit_prog(transforms.ReplaceReturns(), program) + transforms.addReturnBlocks(program) + cilvisitor.visit_prog(transforms.ConvertSingleReturn(), program) val liveVarAnalysisResults = InterLiveVarsAnalysis(program).analyze() val blocks = program.blocks - assert(liveVarAnalysisResults(program.mainProcedure) == Map(R1 -> TwoElementTop, R2 -> TwoElementTop, R30 -> TwoElementTop)) + assert(liveVarAnalysisResults(program.mainProcedure) == Map(R1 -> TwoElementTop, R2 -> TwoElementTop)) } test("differentCalleesBothAlive") { @@ -299,7 +315,7 @@ class LiveVarsAnalysisTests extends AnyFunSuite, TestUtil { val blocks = result.ir.program.blocks // main has a parameter, R0 should be alive - assert(analysisResults(blocks("lmain")) == Map(R0 -> TwoElementTop, R30 -> TwoElementTop, R31 -> TwoElementTop)) + assert(analysisResults(blocks("lmain")) == Map(R0 -> TwoElementTop, R31 -> TwoElementTop)) } test("function") { @@ -309,9 +325,8 @@ class LiveVarsAnalysisTests extends AnyFunSuite, TestUtil { // checks function call blocks assert(analysisResults(blocks("lmain")) == Map(R29 -> TwoElementTop, R30 -> TwoElementTop, R31 -> TwoElementTop)) - assert(analysisResults(blocks("lget_two")) == Map(R30 -> TwoElementTop, R31 -> TwoElementTop)) + assert(analysisResults(blocks("lget_two")) == Map(R31 -> TwoElementTop)) assert(analysisResults(blocks("l00000946")) == Map(R0 -> TwoElementTop, R31 -> TwoElementTop)) // aftercall block - assert(analysisResults(blocks("main_basil_return")) == Map(R30 -> TwoElementTop)) } @@ -323,9 +338,9 @@ class LiveVarsAnalysisTests extends AnyFunSuite, TestUtil { // main has parameter, callee (zero) has return and no parameter assert(analysisResults(blocks("lmain")) == Map(R0 -> TwoElementTop, R29 -> TwoElementTop, R30 -> TwoElementTop, R31 -> TwoElementTop)) - assert(analysisResults(blocks("lzero")) == Map(R30 -> TwoElementTop, R31 -> TwoElementTop)) + assert(analysisResults(blocks("lzero")) == Map(R31 -> TwoElementTop)) assert(analysisResults(blocks("l00000323")) == Map(R0 -> TwoElementTop, R31 -> TwoElementTop)) // aftercall block - assert(analysisResults(blocks("zero_basil_return")) == Map(R0 -> TwoElementTop, R30 -> TwoElementTop, R31 -> TwoElementTop)) + assert(analysisResults(blocks("zero_basil_return")) == Map(R0 -> TwoElementTop, R31 -> TwoElementTop)) } test("function1") { @@ -334,12 +349,12 @@ class LiveVarsAnalysisTests extends AnyFunSuite, TestUtil { val blocks = result.ir.program.blocks // main has no parameters, get_two has three and a return - assert(analysisResults(blocks("lmain")) == Map(R29 -> TwoElementTop, R30 -> TwoElementTop, R31 -> TwoElementTop)) - assert(analysisResults(blocks("l000003ec")) == Map(R0 -> TwoElementTop, R31 -> TwoElementTop)) // get_two aftercall - assert(analysisResults(blocks("l00000430")) == Map(R31 -> TwoElementTop)) // printf aftercall - assert(analysisResults(blocks("main_basil_return")) == Map(R30 -> TwoElementTop)) - assert(analysisResults(blocks("lget_two")) == Map(R0 -> TwoElementTop, R1 -> TwoElementTop, R2 -> TwoElementTop, R30 -> TwoElementTop, R31 -> TwoElementTop)) - assert(analysisResults(blocks("get_two_basil_return")) == Map(R0 -> TwoElementTop, R30 -> TwoElementTop, R31 -> TwoElementTop)) + assert(analysisResults(blocks("lmain").jump) == Map(R29 -> TwoElementTop, R31 -> TwoElementTop)) + assert(analysisResults(blocks("l000003ec").jump) == Map(R0 -> TwoElementTop, R31 -> TwoElementTop)) // get_two aftercall + assert(analysisResults(blocks("l00000430").jump) == Map(R31 -> TwoElementTop)) // printf aftercall + assert(analysisResults(blocks("main_basil_return").jump) == Map(R30 -> TwoElementTop)) + assert(analysisResults(blocks("lget_two").jump) == Map(R0 -> TwoElementTop, R1 -> TwoElementTop, R2 -> TwoElementTop, R30 -> TwoElementTop, R31 -> TwoElementTop)) + assert(analysisResults(blocks("get_two_basil_return").jump) == Map(R0 -> TwoElementTop, R30 -> TwoElementTop, R31 -> TwoElementTop)) } test("ifbranches") { @@ -348,11 +363,11 @@ class LiveVarsAnalysisTests extends AnyFunSuite, TestUtil { val blocks = result.ir.program.blocks // block after branch - assert(analysisResults(blocks("l00000342")) == Map(R30 -> TwoElementTop, R31 -> TwoElementTop)) + assert(analysisResults(blocks("l00000342")) == Map(R31 -> TwoElementTop)) // branch blocks assert(analysisResults(blocks("lmain_goto_l00000330")) == Map(Register("ZF", 1) -> TwoElementTop, - R30 -> TwoElementTop, R31 -> TwoElementTop)) + R31 -> TwoElementTop)) assert(analysisResults(blocks("lmain_goto_l00000369")) == Map(Register("ZF", 1) -> TwoElementTop, - R30 -> TwoElementTop, R31 -> TwoElementTop)) + R31 -> TwoElementTop)) } } diff --git a/src/test/scala/PointsToTest.scala b/src/test/scala/PointsToTest.scala index 32131ed46..9534053a3 100644 --- a/src/test/scala/PointsToTest.scala +++ b/src/test/scala/PointsToTest.scala @@ -70,8 +70,8 @@ class PointsToTest extends AnyFunSuite with OneInstancePerTest with BeforeAndAft ) ) - val returnUnifier = ConvertToSingleProcedureReturn() - program = returnUnifier.visitProgram(program) + transforms.addReturnBlocks(program) + cilvisitor.visit_prog(transforms.ConvertSingleReturn(), program) val results = runAnalyses(program) results.mmmResults.pushContext("main") @@ -99,8 +99,8 @@ class PointsToTest extends AnyFunSuite with OneInstancePerTest with BeforeAndAft ) ) ) - val returnUnifier = ConvertToSingleProcedureReturn() - program = returnUnifier.visitProgram(program) + transforms.addReturnBlocks(program) + cilvisitor.visit_prog(transforms.ConvertSingleReturn(), program) val results = runAnalyses(program) results.mmmResults.pushContext("main") @@ -168,7 +168,7 @@ class PointsToTest extends AnyFunSuite with OneInstancePerTest with BeforeAndAft goto("0x1") ), block("0x1", - directCall("p2", Some("returntarget")) + directCall("p2"), goto("returntarget") ), block("returntarget", ret @@ -186,8 +186,8 @@ class PointsToTest extends AnyFunSuite with OneInstancePerTest with BeforeAndAft ) ) - val returnUnifier = ConvertToSingleProcedureReturn() - program = returnUnifier.visitProgram(program) + transforms.addReturnBlocks(program) + cilvisitor.visit_prog(transforms.ConvertSingleReturn(), program) val results = runAnalyses(program) results.mmmResults.pushContext("main") @@ -217,7 +217,7 @@ class PointsToTest extends AnyFunSuite with OneInstancePerTest with BeforeAndAft goto("0x1") ), block("0x1", - directCall("p2", Some("returntarget")) + directCall("p2"), goto("returntarget") ), block("returntarget", ret @@ -227,7 +227,7 @@ class PointsToTest extends AnyFunSuite with OneInstancePerTest with BeforeAndAft block("l_foo", Assign(getRegister("R0"), MemoryLoad(mem, BinaryExpr(BVADD, getRegister("R31"), bv64(6)), LittleEndian, 64)), Assign(getRegister("R1"), BinaryExpr(BVADD, getRegister("R31"), bv64(10))), - directCall("p2", Some("l_foo_1")) + directCall("p2"), goto("l_foo_1") ), block("l_foo_1", ret, @@ -245,8 +245,8 @@ class PointsToTest extends AnyFunSuite with OneInstancePerTest with BeforeAndAft ) ) - val returnUnifier = ConvertToSingleProcedureReturn() - program = returnUnifier.visitProgram(program) + transforms.addReturnBlocks(program) + cilvisitor.visit_prog(transforms.ConvertSingleReturn(), program) val results = runAnalyses(program) results.mmmResults.pushContext("main") @@ -303,4 +303,4 @@ class PointsToTest extends AnyFunSuite with OneInstancePerTest with BeforeAndAft // // runSteensgaardAnalysis(program, globals = globals, globalOffsets = globalOffsets) // } -} \ No newline at end of file +} diff --git a/src/test/scala/ir/IRTest.scala b/src/test/scala/ir/IRTest.scala index 7c06315d3..7421c2a28 100644 --- a/src/test/scala/ir/IRTest.scala +++ b/src/test/scala/ir/IRTest.scala @@ -4,7 +4,9 @@ import scala.collection.mutable import scala.collection.immutable.* import org.scalatest.funsuite.AnyFunSuite import util.intrusive_list.* +import translating.serialiseIL import ir.dsl.* +import ir._ class IRTest extends AnyFunSuite { @@ -57,31 +59,6 @@ class IRTest extends AnyFunSuite { } - test("removeblockinline") { - - val p = prog( - proc("main", - block("lmain", - goto("lmain1") - ), - block("lmain1", - goto("lmain2")), - block("lmain2", - ret) - ) - ) - - val blocks = p.collect { - case b: Block => b.label -> b - }.toMap - - p.procedures.head.removeBlocksInline(blocks("lmain1")) - - blocks("lmain").singleSuccessor.contains(blocks("lmain2")) - blocks("lmain2").singlePredecessor.contains(blocks("lmain")) - - } - test("simple replace jump") { val p = prog( @@ -142,7 +119,8 @@ class IRTest extends AnyFunSuite { ), block("l_main_1", Assign(R0, bv64(22)), - directCall("p2", Some("returntarget")) + directCall("p2"), + goto("returntarget") ), block("returntarget", ret @@ -154,34 +132,32 @@ class IRTest extends AnyFunSuite { ) ) + val blocks = p.collect { case b: Block => b.label -> b }.toMap - val directcalls = p.collect { case c: DirectCall => c } - assert(blocks("l_main_1").fallthrough.nonEmpty) - assert(p.toSet.contains(blocks("l_main_1").fallthrough.get)) - assert(directcalls.forall(c => IntraProcIRCursor.succ(c).count(_.asInstanceOf[GoTo].isAfterCall) == 1)) - assert(directcalls.forall(c => IntraProcBlockIRCursor.succ(c).count(_.isAfterCall) == 1)) + assert(p.toSet.contains(blocks("l_main_1").jump)) + assert(directcalls.forall(c => IntraProcIRCursor.succ(c).count(c => isAfterCall(c.asInstanceOf[Command])) == 1)) val afterCalls = p.collect { - case b: Block if b.isAfterCall => b + case b: Command if isAfterCall(b) => b }.toSet - assert(afterCalls.toSet == Set(blocks("returntarget"))) + assert(afterCalls.toSet == Set(blocks("l_main_1").jump)) val aftercallGotos = p.collect { - case c: Jump if c.isAfterCall => c + case c: Command if isAfterCall(c) => c }.toSet - assert(aftercallGotos == Set(blocks("l_main_1").fallthrough.get)) + // assert(aftercallGotos == Set(blocks("l_main_1").fallthrough.get)) assert(1 == aftercallGotos.count(b => IntraProcIRCursor.pred(b).contains(blocks("l_main_1").jump))) - assert(1 == aftercallGotos.count(b => IntraProcIRCursor.succ(b).contains(blocks("l_main_1").fallthrough.map(_.targets.head).head))) - - assert(afterCalls.forall(b => IntraProcBlockIRCursor.pred(b).contains(blocks("l_main_1")))) + assert(1 == aftercallGotos.count(b => IntraProcIRCursor.succ(b).contains(blocks("l_main_1").jump match { + case GoTo(targets, _) => targets.head + }))) } @@ -246,7 +222,8 @@ class IRTest extends AnyFunSuite { Assign(R0, bv64(22)), Assign(R0, bv64(22)), Assign(R0, bv64(22)), - directCall("main", None) + directCall("main"), + halt ).resolve(p) val b2 = block("newblock1", Assign(R0, bv64(22)), @@ -271,7 +248,8 @@ class IRTest extends AnyFunSuite { assert(called.incomingCalls().isEmpty) val b3 = block("newblock3", Assign(R0, bv64(22)), - directCall("called", None) + directCall("called"), + halt ).resolve(p) assert(b3.calls.toSet == Set(p.procs("called"))) @@ -283,11 +261,11 @@ class IRTest extends AnyFunSuite { assert(!oldb.hasParent) assert(oldb.incomingJumps.isEmpty) assert(!blocks("lmain").jump.asInstanceOf[GoTo].targets.contains(oldb)) - assert(called.incomingCalls().toSet == Set(b3.jump)) + assert(called.incomingCalls().toSet == Set(b3.statements.last)) assert(called.incomingCalls().map(_.parent.parent).toSet == called.callers().toSet) val olds = blocks.size p.mainProcedure.replaceBlock(b3, b3) - assert(called.incomingCalls().toSet == Set(b3.jump)) + assert(called.incomingCalls().toSet == Set(b3.statements.last)) assert(olds == blocks.size) p.mainProcedure.addBlocks(block("test", ret).resolve(p)) assert(olds != blocks.size) @@ -333,35 +311,35 @@ class IRTest extends AnyFunSuite { proc("main", block("l_main", Assign(R0, bv64(10)), - directCall("p1", Some("returntarget")) + directCall("p1"), goto("returntarget") ), block("returntarget", ret ) ), ) - val returnUnifier = ConvertToSingleProcedureReturn() - returnUnifier.visitProgram(p) + + cilvisitor.visit_prog(transforms.ReplaceReturns(), p) + transforms.addReturnBlocks(p) + cilvisitor.visit_prog(transforms.ConvertSingleReturn(), p) val next = InterProcIRCursor.succ(p.blocks("l_main").jump) val prev = InterProcIRCursor.pred(p.blocks("returntarget")) assert(prev.size == 1 && prev.collect { - case c : GoTo => (c.parent == p.blocks("l_main")) && c.isAfterCall + case c : GoTo => (c.parent == p.blocks("l_main")) }.contains(true)) - assert(next == Set(p.procs("p1"), p.blocks("l_main").fallthrough.get)) + // assert(next == Set(p.procs("p1"), p.blocks("l_main").fallthrough.get)) - val prevB: Block = (p.blocks("l_main").jump match - case c: IndirectCall => c.returnTarget - case c: DirectCall => c.returnTarget - case _ => None + val prevB: Command = (p.blocks("l_main").statements.lastOption match + case Some(c: IndirectCall) => c.returnTarget + case Some(c: DirectCall) => c.returnTarget + case o => None ).get - assert(prevB.isAfterCall) + assert(isAfterCall(prevB)) assert(InterProcIRCursor.pred(prevB).size == 1) - assert(InterProcIRCursor.pred(prevB).head == p.blocks("l_main").fallthrough.get) - assert(InterProcBlockIRCursor.pred(prevB).head == p.blocks("l_main"), p.procs("p1").returnBlock.get) } @@ -374,10 +352,10 @@ class IRTest extends AnyFunSuite { ), proc("main", block("l_main", - indirectCall(R1, Some("returntarget")) + indirectCall(R1), goto("returntarget") ), block("block2", - directCall("p1", Some("returntarget")) + directCall("p1"), goto("returntarget") ), block("returntarget", ret diff --git a/src/test/scala/ir/SingleCallInvariant.scala b/src/test/scala/ir/SingleCallInvariant.scala new file mode 100644 index 000000000..d8efb6fc2 --- /dev/null +++ b/src/test/scala/ir/SingleCallInvariant.scala @@ -0,0 +1,83 @@ +package ir + + +import ir.dsl._ + +import org.scalatest.funsuite.AnyFunSuite +class InvariantTest extends AnyFunSuite { + + test("sat singleCallBlockEnd case") { + var program: Program = prog( + proc("main", + block("first_call", + Assign(R0, bv64(10)), + Assign(R1, bv64(10)), + directCall("callee1"), + ret + ), + block("second_call", + Assign(R0, bv64(10)), + directCall("callee2"), + ret + ), + block("returnBlock", + ret + ) + ), + proc("callee1", block("bye1", ret)), + proc("callee2", block("bye2", ret)), + ) + + assert(invariant.singleCallBlockEnd(program)) + } + + test("unsat singleCallBlockEnd 1 (two calls)") { + var program: Program = prog( + proc("main", + block("first_call", + Assign(R0, bv64(10)), + directCall("callee2"), + Assign(R1, bv64(10)), + directCall("callee1"), + ret + ), + block("second_call", + Assign(R0, bv64(10)), + ret + ), + block("returnBlock", + ret + ) + ), + proc("callee1", block("bye1", ret)), + proc("callee2", block("bye2", ret)), + ) + + assert(!invariant.singleCallBlockEnd(program)) + } + + test("unsat singleCallBlockEnd 2 (not at end)") { + var program: Program = prog( + proc("main", + block("first_call", + Assign(R0, bv64(10)), + Assign(R1, bv64(10)), + ret + ), + block("second_call", + directCall("callee2"), + Assign(R0, bv64(10)), + ret + ), + block("returnBlock", + ret + ) + ), + proc("callee1", block("bye1", ret)), + proc("callee2", block("bye2", ret)), + ) + + assert(!invariant.singleCallBlockEnd(program)) + } + +} From fd56c9a1a7485eb039861a48ba7986ae3bc20d1a Mon Sep 17 00:00:00 2001 From: Alistair Michael Date: Fri, 9 Aug 2024 16:43:58 +1000 Subject: [PATCH 03/62] remove old cfg --- src/main/scala/analysis/Cfg.scala | 829 ------------------ src/main/scala/analysis/Dependencies.scala | 20 - .../analysis/InterLiveVarsAnalysis.scala | 8 +- .../InterprocSteensgaardAnalysis.scala | 4 +- .../analysis/IntraLiveVarsAnalysis.scala | 4 +- .../scala/analysis/MemoryRegionAnalysis.scala | 6 +- .../scala/analysis/RegToMemAnalysis.scala | 38 +- .../scala/analysis/solvers/IDESolver.scala | 19 +- src/main/scala/cfg_visualiser/Output.scala | 37 - src/main/scala/ir/IRCursor.scala | 11 +- src/main/scala/ir/Interpreter.scala | 4 +- src/main/scala/ir/Program.scala | 39 +- src/main/scala/ir/Statement.scala | 4 +- src/main/scala/ir/dsl/DSL.scala | 8 +- .../transforms/IndirectCallResolution.scala | 280 ++---- .../scala/ir/transforms/ReplaceReturn.scala | 3 +- .../scala/ir/transforms/SplitThreads.scala | 1 - .../StripUnreachableFunctions.scala | 38 + src/main/scala/translating/BAPToIR.scala | 4 +- src/main/scala/translating/GTIRBToIR.scala | 12 +- src/main/scala/translating/ILtoIL.scala | 16 +- src/main/scala/translating/IRToBoogie.scala | 2 +- src/main/scala/util/RunUtils.scala | 142 +-- src/test/scala/IndirectCallsTests.scala | 11 - src/test/scala/LiveVarsAnalysisTests.scala | 23 +- .../scala/MemoryRegionAnalysisMiscTest.scala | 2 +- src/test/scala/ir/IRTest.scala | 4 +- src/test/scala/ir/InterpreterTests.scala | 2 +- 28 files changed, 215 insertions(+), 1356 deletions(-) delete mode 100644 src/main/scala/analysis/Cfg.scala delete mode 100644 src/main/scala/cfg_visualiser/Output.scala create mode 100644 src/main/scala/ir/transforms/StripUnreachableFunctions.scala diff --git a/src/main/scala/analysis/Cfg.scala b/src/main/scala/analysis/Cfg.scala deleted file mode 100644 index 8807f2d2a..000000000 --- a/src/main/scala/analysis/Cfg.scala +++ /dev/null @@ -1,829 +0,0 @@ -package analysis - -import scala.collection.mutable -import ir.* -import cfg_visualiser.{DotArrow, DotGraph, DotInlineArrow, DotInterArrow, DotIntraArrow, DotNode, DotRegularArrow} - -import scala.collection.mutable.{ArrayBuffer, ListBuffer} -import scala.util.control.Breaks.break -import util.Logger - -import scala.annotation.tailrec - -/** Node in the control-flow graph. - */ -object CfgNode: - - var id: Int = 0 - - def nextId(): Int = - id += 1 - id - -/** Node in the control-flow graph. Each node has a (simple incremental) unique identifier used to distinguish it from - * other nodes in the cfg - this is mainly used for copying procedure cfgs when inlining them. - * - * Each node will store four separate sets: ingoing/outgoing of both inter-/intra-procedural CFG edges. Both intra and - * inter will also store regular edges in the cfg. This is duplication of storage, however is done so knowingly. - * - * By separating sets into inter/intra we are able to return these directly without doing any processing. Alternative - * means of achieving this same behaviour would involve some form of set operations, or a filter operation, both of - * which can be expensive, especially as the successors/predecessors will be accessed frequently by analyses. - * Additionally, inspecting the space complexity of these sets, we note that their sizes should be relatively limited: - * a. #(outgoing edges) <= 2 b. #(incoming edges) ~ N Thus in `succIntra` + `succInter` we have at most 4 elements. - * For `predIntra` + `predInter` we have a maximum of 2N, resulting from the case that this node is a block entry - * that is jumped to by N other nodes. It should be noted that in the majority of cases (statements which are - * neither the start of blocks nor jumps), both sets will be of size 1, making the storage complexity negligible. - * This is a point which can be optimised upon however. - * - * A node can have three main types of connected edges: - * a. A regular edge A regular edge connects two statements which only relate to the current procedure's context. b. - * Intra-procedural edge Intra-procedural edges connect a call node with the subsequent cfg node in a way that - * bypasses dealing with the semantics of the callee - it is up to analyses to determine how to treat such a call. - * c. Inter-procedural edge These are split into two cases: - * i. Inline edge These connect call nodes with an inlined copy of the target's procedure body. The exit of the - * target procedure's clone is also linked back to the caller via an inline edge. For an inline limit of `n`, - * these are the inter-procedural edges for depth 0 <= i < n ii. Call edge These connect leaf call nodes (calls - * which are not inlined) to the start of the independent cfg of the call's target. For an inline limit of `n`, - * these are the inter-procedural edges at depth i == n. - */ -trait CfgNode: - - /** Edges to this node from regular statements or ignored procedure calls. - * - */ - val predIntra: mutable.Set[CfgNode] = mutable.Set() - - /** Edges to this node from procedure calls. Likely empty unless this node is a [[CfgFunctionEntryNode]] - * - */ - val predInter: mutable.Set[CfgNode] = mutable.Set() - - /** Edges to successor nodes, either regular or ignored procedure calls - * - */ - val succIntra: mutable.Set[CfgNode] = mutable.Set() - - /** Edges to successor procedure calls. Used when walking inter-proc cfg. - * - */ - val succInter: mutable.Set[CfgNode] = mutable.Set() - - /** Unique identifier. */ - val id: Int = CfgNode.nextId() - def copyNode(): CfgNode - - override def equals(obj: scala.Any): Boolean = - obj match - case o: CfgNode => o.id == this.id - case _ => false - - override def hashCode(): Int = id.hashCode() - -/** Control-flow graph node that additionally stores an AST node. - */ -trait CfgNodeWithData[T] extends CfgNode { - val data: T -} - -/** Control-flow graph node for the entry of a function. - */ -class CfgFunctionEntryNode(val data: Procedure) extends CfgNodeWithData[Procedure]: - val callers: mutable.Set[CfgFunctionEntryNode] = mutable.Set[CfgFunctionEntryNode]() - override def toString: String = s"[FunctionEntry] $data" - - /** Copy this node, but give unique ID and reset edges */ - override def copyNode(): CfgFunctionEntryNode = CfgFunctionEntryNode(data) - -/** Control-flow graph node for the exit of a function. - */ -class CfgFunctionExitNode(val data: Procedure) extends CfgNodeWithData[Procedure]: - override def toString: String = s"[FunctionExit] $data" - - /** Copy this node, but give unique ID and reset edges */ - override def copyNode(): CfgFunctionExitNode = CfgFunctionExitNode(data) - -/** CFG node immediately proceeding a indirect call. This signifies that the call is a return from the current context - * (i.e., likely an indirect call to R30). Its purpose is to provide a way for analyses to identify whether they should - * return to the previous function context, if it is a context dependent analyses, and otherwise can be ignored. - * - * In the cfg we treat this as a stepping stone to `CfgFunctionExitNode`, as a way to emphasise that the current - * procedure has no functionality past this point. - */ -class CfgProcedureReturnNode() extends CfgNode: - override def toString: String = s"[ProcedureReturn]" - - /** Copy this node, but give unique ID and reset edges */ - override def copyNode(): CfgProcedureReturnNode = CfgProcedureReturnNode() - -/** CFG node immediately proceeding a direct/indirect call, if that call has no specified return block. There are a few - * reasons this can occur: - * a. It is not expected that the program will return from the callee b. The lifter has erroneously labelled a call - * as a jump / mislabelled a function name, which was then not associated with a block in some later stage. This - * happened with a call to `__gmon_start_`, which was optimised from a call to a jump which was incorrectly - * interpreted by the lifter. c. The indirect call is some other form of return-to-caller (which does not use - * R30). These are currently unhandled, and could potentially be integrated - e.g., sometimes R17 and R16 can be - * used in a similar way to R30. - * https://blog.tomzhao.me/wp-content/uploads/2021/08/Procedure_Call_Standard_in_Armv8_54f88cbfe905409aaff956ac2d1ad059.pdf - * - * In the cfg this is similarly used as a stepping stone to `CfgFunctionExitNode`. - */ -class CfgCallNoReturnNode() extends CfgNode: - override def toString: String = s"[Call NoReturn]" - - /** Copy this node, but give unique ID and reset edges */ - override def copyNode(): CfgCallNoReturnNode = CfgCallNoReturnNode() - -/** CFG node immediately proceeding a direct/indirect call, if that call has a return location specified. This serves as - * a point for analysis to stop and process update their states after handling a procedure call before continuing - * within the current contex. For example, a context sensitive analysis will return to this node after reaching a - * procedure return within the caller. It will then restore and update its context from before the call, before - * continuing on within the original procedure. - * - * Effectively, this just splits a procedure call from a single `Jmp` node into two - the call, and the return point. - * Incoming edges to the `Jmp` are then incoming edges to the respective `CfgJumpnode`, and outgoing edges from the - * `Jmp` are then outgoing edges of the `CfgCallReturnNode`. It is functionally in the same spirit as - * `CfgCallNoReturnNode`, though handles the case that this procedure still has functionality to be explored. - */ -class CfgCallReturnNode() extends CfgNode: - override def toString: String = s"[Call Return]" - - /** Copy this node, but give unique ID and reset edges */ - override def copyNode(): CfgCallReturnNode = CfgCallReturnNode() - -/** Control-flow graph node for a command (statement or jump). - */ -trait CfgCommandNode extends CfgNodeWithData[Command] { - override def copyNode(): CfgCommandNode - val block: Block - val parent: CfgFunctionEntryNode -} - -/** CFG's representation of a single statement. - */ -class CfgStatementNode( - val data: Statement, - val block: Block, - val parent: CfgFunctionEntryNode -) extends CfgCommandNode: - override def toString: String = s"[Stmt] $data" - - /** Copy this node, but give unique ID and reset edges */ - override def copyNode(): CfgStatementNode = CfgStatementNode(data, block, parent) - -/** CFG's representation of a jump. This is used as a general jump node, for both indirect and direct calls. - */ -class CfgJumpNode( - val data: Jump | DirectCall | IndirectCall, - val block: Block, - val parent: CfgFunctionEntryNode -) extends CfgCommandNode: - override def toString: String = s"[Jmp] $data" - - /** Copy this node, but give unique ID and reset edges */ - override def copyNode(): CfgJumpNode = CfgJumpNode(data, block, parent) - -/** A general purpose node which in terms of the IR has no functionality, but can have purpose in the CFG. As example, - * this is used as a "block" start node for the case that a block contains no statements, but has a `GoTo` as its jump. - * In this case we introduce a ghost node as the start of the block for the case that some part of the program jumps - * back to this conditional jump (e.g. in the case of loops). - */ -class CfgGhostNode( - val block: Block, - val parent: CfgFunctionEntryNode, - val data: NOP -) extends CfgCommandNode: - override def toString: String = s"[NOP] $data" - - /** Copy this node, but give unique ID and reset edges */ - override def copyNode(): CfgGhostNode = CfgGhostNode(block, parent, data) - -/** A control-flow graph. Nodes provide the ability to walk it as both an intra and inter procedural CFG. - */ -class ProgramCfg: - - var startNode: CfgFunctionEntryNode = _ - var nodes: mutable.Set[CfgNode] = mutable.Set() - var funEntries: mutable.Set[CfgFunctionEntryNode] = mutable.Set() - - /** Inline edges are for connecting an intraprocedural cfg with a copy of another procedure's intraprocedural cfg - * which is placed inside this one. They are considered interprocedural edges, and will not be followed if the caller - * requests an intraprocedural cfg. - */ - def addInlineEdge(from: CfgNode, to: CfgNode): Unit = { - from.succInter += to - to.predInter += from - } - - /** Interprocedural call edges connect an intraprocedural cfg with another procedure's intraprocedural cfg that it is - * calling. - */ - def addInterprocCallEdge(from: CfgNode, to: CfgNode): Unit = { - from.succInter += to - to.predInter += from - } - - /** Intraprocedural edges are for connecting call nodes to the call's return node, without following the call itself - * (stepping over the call). - */ - def addIntraprocEdge(from: CfgNode, to: CfgNode): Unit = { - from.succIntra += to - to.predIntra += from - } - - /** Regular edges are normal control flow - used in both inter-/intra-procedural cfgs. - */ - def addRegularEdge(from: CfgNode, to: CfgNode): Unit = { - from.succInter += to - from.succIntra += to - to.predInter += from - to.predIntra += from - } - - /** Add an outgoing edge from the current node, taking into account any conditionals on this jump. Note that we have - * some duplication of storage here - this is a performance consideration. We don't expect too many edges for any - * given node, and so the increased storage is relatively minimal. This saves having to filter / union sets when - * trying to retrieve only an intra/inter cfg, hopefully improving computation time. - * - * NOTE: this function attempts to "smartly" identify how to connect two edges. Perhaps as the CFG changes however - * different requirements will be made of nodes, and so the conditions on edges below may change. In that case, - * either update the below, or explicitly specify the edge to be added between two nodes. - * - * @param from - * The originating node - * @param to - * The destination node - */ - def addEdge(from: CfgNode, to: CfgNode): Unit = { - - (from, to) match { - // Ignored procedure (e.g. library calls such as @printf) - case (from: CfgFunctionEntryNode, to: CfgFunctionExitNode) => addRegularEdge(from, to) - // Calling procedure (follow as inline) - // This to be used if inlining skips the call node and links the most recent statement to the first statement of the target - case (from: CfgCommandNode, to: CfgFunctionEntryNode) => addInlineEdge(from, to) - // Returning from procedure (follow as inline - see above) - case (from: CfgFunctionExitNode, to: CfgNode) => addInlineEdge(from, to) - // First instruction of procedure - case (from: CfgFunctionEntryNode, to: CfgNode) => addRegularEdge(from, to) - // Function call which returns to the previous context - case (from: CfgJumpNode, to: CfgProcedureReturnNode) => addRegularEdge(from, to) - // Edge to intermediary return node (no semantic meaning, a cfg convenience edge) - case (from: CfgJumpNode, to: (CfgCallReturnNode | CfgCallNoReturnNode)) => addIntraprocEdge(from, to) - // Pre-exit nodes - case (from: (CfgProcedureReturnNode | CfgCallNoReturnNode | CfgCallReturnNode), to: CfgFunctionExitNode) => - addRegularEdge(from, to) - // Regular continuation of execution - case (from: CfgCallReturnNode, to: CfgCommandNode) => addRegularEdge(from, to) - // Regular flow of instructions - case (from: CfgCommandNode, to: (CfgCommandNode | CfgFunctionExitNode)) => addRegularEdge(from, to) - case _ => throw new Exception(s"[!] Unexpected edge combination when adding cfg edge between $from -> $to.") - } - - nodes += from - nodes += to - } - - /** Returns a Graphviz dot representation of the CFG. Each node is labeled using the given function labeler. - */ - def toDot(labeler: CfgNode => String, idGen: (CfgNode, Int) => String): String = { - val dotNodes = mutable.Map[CfgNode, DotNode]() - var dotArrows = mutable.ListBuffer[DotArrow]() - var uniqueId = 0 - nodes.foreach { n => - dotNodes += (n -> DotNode(s"${idGen(n, uniqueId)}", labeler(n))) - uniqueId += 1 - } - nodes.foreach { n => - - val successors = n.succIntra.toSet.union(n.succInter) - - successors.foreach { s => - (n, s) match { - case (from: CfgFunctionEntryNode, to: CfgNode) => - dotArrows += DotRegularArrow(dotNodes(n), dotNodes(to)) - case (from: CfgJumpNode, to: CfgProcedureReturnNode) => - dotArrows += DotRegularArrow(dotNodes(n), dotNodes(to)) - case (from: (CfgProcedureReturnNode | CfgCallNoReturnNode | CfgCallReturnNode), to: CfgFunctionExitNode) => - dotArrows += DotRegularArrow(dotNodes(n), dotNodes(to)) - case (from: CfgCallReturnNode, to: CfgCommandNode) => - dotArrows += DotRegularArrow(dotNodes(n), dotNodes(to)) - case (from: CfgCommandNode, to: (CfgCommandNode | CfgFunctionExitNode)) => - dotArrows += DotRegularArrow(dotNodes(n), dotNodes(to)) - case (from: CfgCommandNode, to: CfgFunctionEntryNode) => - DotInlineArrow(dotNodes(n), dotNodes(to)) - case (from: CfgFunctionExitNode, to: CfgNode) => - DotInlineArrow(dotNodes(n), dotNodes(to)) - case (from: CfgJumpNode, to: (CfgCallReturnNode | CfgCallNoReturnNode)) => - dotArrows += DotIntraArrow(dotNodes(n), dotNodes(to)) - /* - Displaying the below in the CFG is mostly for debugging purposes. With it included the CFG becomes a little unreadable, but - will emphasise that the leaf-call nodes are linked to the start of the procedures they're calling (as green inter-procedural edges). - To verify this is still happening, simply uncomment the below and it will add these edges. - case (from: CfgCommandNode, to: CfgFunctionEntry) => - dotArrows += DotInterArrow(dotNodes(n), dotNodes(to)) - */ - - case _ => - } - } - } - dotArrows = dotArrows.sortBy(arr => arr.fromNode.id + "-" + arr.toNode.id) - val allNodes = dotNodes.values.toList.sortBy(n => n.id) - DotGraph("CFG", allNodes, dotArrows).toDotString - } - - override def toString: String = { - val sb = StringBuilder() - sb.append("CFG {") - sb.append(" nodes: ") - sb.append(nodes) - sb.append("}") - sb.toString() - } - -/** Control-flow graph for an entire program. We have a more granular approach, storing commands as nodes instead of - * basic blocks. - */ -class ProgramCfgFactory: - val cfg: ProgramCfg = ProgramCfg() - - // Mapping from procedures to the start of their individual (intra) cfgs - val procToCfg: mutable.Map[Procedure, (CfgFunctionEntryNode, CfgFunctionExitNode)] = mutable.Map() - // Mapping from procedures to procedure call nodes (all the calls made within this procedure, including inlined functions) - val procToCalls: mutable.Map[Procedure, mutable.Set[CfgJumpNode]] = mutable.Map() - // Mapping from procedure entry instances to procedure call nodes within that procedure's instance (`CfgCommandNode.data <: DirectCall`) - // Updated on first creation of a new procedure (e.g. in initial creation, or in cloning of a procedure's cfg) - val callToNodes: mutable.Map[CfgFunctionEntryNode, mutable.Set[CfgJumpNode]] = mutable.Map() - // Mapping from procedures to nodes in any node in the cfg which has a call to that procedure - val procToCallers: mutable.Map[Procedure, mutable.Set[CfgJumpNode]] = mutable.Map() - - /** Generate the cfg for each function of the program. NOTE: is this functionally different to a constructor? Do we - * ever expect to generate a CFG from any other data structure? If not then the `class` could probably be absorbed - * into this object. - * - * @param program - * Basil IR of the program - * @param inlineLimit - * How many levels deep to inline function calls. Default is 3 - */ - def fromIR(program: Program, unify: Boolean = true, inlineLimit: Int = 0): ProgramCfg = { - CfgNode.id = 0 - require(inlineLimit >= 0, "Can't inline procedures to negative depth...") - Logger.info("[+] Generating CFG...") - - // Have to initialise the map entries manually. Scala maps have a `.withDefaulValue`, but this is buggy and doesn't - // behave as you would expect: https://github.com/scala/bug/issues/8099 - thus the manual approach. - // We don't initialise `procToCfg` here, because it will never be accessed before `cfgForProcedure`, - // and because it relies on the entry/exit nodes be initialised. It is initialised in `cfgForProcedure`. - program.procedures.foreach(proc => - procToCalls += (proc -> mutable.Set()) - procToCallers += (proc -> mutable.Set()) - ) - - // Create CFG for individual procedures - program.procedures.foreach( - proc => { - val funcEntryNode: CfgFunctionEntryNode = CfgFunctionEntryNode(proc) - val funcExitNode: CfgFunctionExitNode = CfgFunctionExitNode(proc) - cfg.nodes += funcEntryNode - cfg.nodes += funcExitNode - cfg.funEntries += funcEntryNode - - procToCfg += (proc -> (funcEntryNode, funcExitNode)) - callToNodes += (funcEntryNode -> mutable.Set()) - } - ) - program.procedures.foreach(proc => cfgForProcedure(proc)) - - // Inline functions up to `inlineLimit` level - // EXTENSION; one way to improve this would be to specify inline depths for specific functions / situations. - // i.e. we may not want to inline self-recursive functions too much. - // Of note is whether we want this at all or note. If not, then we can simply remove the below and pass `procCallNodes` to - // `addInterprocEdges`. - val procCallNodes: Set[CfgJumpNode] = procToCalls.values.flatten.toSet - val leafCallNodes: Set[CfgJumpNode] = - if !unify then inlineProcedureCalls(procCallNodes, inlineLimit) else procCallNodes - - // Add inter-proc edges to leaf call nodes - if (leafCallNodes.nonEmpty) { - addInterprocEdges(leafCallNodes) - } - - cfg.startNode = procToCfg(program.mainProcedure)._1 - - cfg - } - - /** Create an intraprocedural CFG for the given IR procedure. The start of the CFG for a procedure is identified by - * its `CfgFunctionEntryNode`, and its closure is identified by the `CfgFunctionExitNode`. - * - * @param proc - * Procedure for which to generate the intraprocedural cfg - */ - private def cfgForProcedure(proc: Procedure): Unit = { - val funcEntryNode: CfgFunctionEntryNode = procToCfg(proc)._1 - val funcExitNode: CfgFunctionExitNode = procToCfg(proc)._2 - - // Track blocks we've already processed so we don't double up - val visitedBlocks: mutable.Map[Block, CfgCommandNode] = mutable.Map() - - // Procedure has no content (in our case this probably means it's an ignored procedure, e.g., an external function such as @printf) - if (proc.blocks.isEmpty) { - cfg.addEdge(funcEntryNode, funcExitNode) - } else { - // Recurse through blocks - visitBlock(proc.entryBlock.get, funcEntryNode) - } - - /** Add a block to the CFG. A block in this case is a basic block, so it contains a list of consecutive statements - * followed by a jump at the end to another block. We process statements in this block (if they exist), and then - * follow the jump to recurse through all other blocks. - * - * This recursive approach is effectively a "reaches" approach, and will miss cases that we encounter a jump we - * can't resolve, or cases where the lifter has not identified a section of code. In each case: - * a. The only jumps we can't resolve are indirect calls. It's the intent of the tool to attempt to resolve these - * through analysis however. The CFG can then be updated as these are resolved to incorporate their jumps. In - * construction we do a simple check for register R30 to identify if an indirect call is a return, but - * otherwise consider it as unresolved. b. If the lifter has failed to identify a region of code, then the - * problem exists at the lifter level. In that case we need a way to coerce the lifter into identifying it, or - * to use a new lifter. - * - * These visitations will also only produce the intra-procedural CFG - the burden of "creating" the - * inter-procedural CFG is left to processes later during CFG construction. The benefit of doing this is that we - * can completely resolve a procedure's CFG without jumping to other procedures mid-way through processing, which - * assures we don't have any issues with referencing nodes before they exist. Essentially this is a depth-first - * approach to CFG construction, as opposed to a breadth-first. - * - * @param block - * The block being added to the CFG. - * @param prevBlockEnd - * Preceding block's end node (jump) - */ - def visitBlock(block: Block, prevBlockEnd: CfgNode): Unit = { - - if (block.statements.nonEmpty) { - val endStmt = visitStmts(block.statements, prevBlockEnd) - visitJump(block.jump, endStmt, false) - } else { - // Only jumps in this block - visitJump(block.jump, prevBlockEnd, true) - } - - /** If a block has statements, we add them to the CFG. Blocks in this case are basic blocks, so we know - * consecutive statements will be linked by an unconditional, regular edge. - * - * @param stmts - * Statements in this block - * @param prevNode - * Preceding block's end node (jump) - * @return - * The last statement's CFG node - */ - def visitStmts(stmts: Iterable[Statement], prevNode: CfgNode): CfgCommandNode = { - - val firstNode = CfgStatementNode(stmts.head, block, funcEntryNode) - cfg.addEdge(prevNode, firstNode) - visitedBlocks += (block -> firstNode) // This is guaranteed to be entrance to block if we are here - - val statements = List.from(stmts).map(s => s match { - case d: DirectCall => CfgJumpNode(d, block, funcEntryNode) - case d: IndirectCall => CfgJumpNode(d, block, funcEntryNode) - case o => CfgStatementNode(o, block, funcEntryNode) - }) - val succs = if (statements.nonEmpty) then statements.zip(statements.tail ++ List(CfgJumpNode(statements.head.data.parent.jump, block, funcEntryNode))) else List() - - for ((s,nexts) <- succs) { - s.data match { - case dCall: DirectCall => - - var precNode = prevNode - - val targetProc: Procedure = dCall.target - funcEntryNode.callers.add(procToCfg(targetProc)._1) - - val callNode : CfgJumpNode = s.asInstanceOf[CfgJumpNode] - - // Branch to this call - cfg.addEdge(precNode, callNode) - - procToCalls(proc) += callNode - procToCallers(targetProc) += callNode - callToNodes(funcEntryNode) += callNode - - // Record call association - - // Jump to return location - val returnTarget = nexts - // Add intermediary return node (split call into call and return) - val callRet = CfgCallReturnNode() - cfg.addEdge(callNode, callRet) - cfg.addEdge(callRet, returnTarget) - case iCall: IndirectCall => - Logger.debug(s"Indirect call found: $iCall in ${proc.name}") - var precNode = prevNode - - val jmpNode = s.asInstanceOf[CfgJumpNode] - // Branch to this call - cfg.addEdge(precNode, jmpNode) - - // Record call association - procToCalls(proc) += jmpNode - callToNodes(funcEntryNode) += jmpNode - - // R30 is the link register - this stores the address to return to. - // For now just add a node expressing that we are to return to the previous context. - if (iCall.target == Register("R30", 64)) { - val returnNode = CfgProcedureReturnNode() - cfg.addEdge(jmpNode, returnNode) - cfg.addEdge(returnNode, funcExitNode) - } - - val callRet = CfgCallReturnNode() - cfg.addEdge(jmpNode, callRet) - val returnTarget = nexts - cfg.addEdge(callRet, jmpNode) - case h: Halt => { - assert(false); - // not possible since s is only Statement. - } - case _ => () - } - } - - - if (stmts.size == 1) { - return firstNode - } - - var prevStmtNode: CfgStatementNode = firstNode - - stmts.tail.foreach(stmt => - val stmtNode = CfgStatementNode(stmt, block, funcEntryNode) - cfg.addEdge(prevStmtNode, stmtNode) - prevStmtNode = stmtNode - ) - - prevStmtNode - } - - /** All blocks end with jump(s), whereas some also start with a jump (in the case of no statements). Add these to - * the CFG and visit their target blocks for processing. - * - * @param jmps - * Jumps in the current block being processed - * @param prevNode - * Either the previous statement in the block, or the previous block's end node (in the case that this block - * contains no statements) - * @param solitary - * `True` if this block contains no statements, `False` otherwise - */ - def visitJump(jmp: Jump, prevNode: CfgNode, solitary: Boolean): Unit = { - val jmpNode = CfgJumpNode(jmp, block, funcEntryNode) - var precNode = prevNode - - if (solitary) { - /* If the block contains only jumps (no statements), then the "start" of the block is a jump. - If this is a direct call, then we simply use that call node as the start of the block. - However, GoTos in the CFG are resolved as edges, and so there doesn't exist a node to use as - the start. Thus we introduce a "ghost" node to act as that jump point - it has no functionality - and will simply be skipped by analyses. - - Currently we display these nodes in the DOT view of the CFG, however these could be hidden if desired. - */ - jmp match { - case jmp: GoTo => - // `GoTo`s are just edges, so introduce a fake `start of block` that can be jmp'd to - val ghostNode = CfgGhostNode(block, funcEntryNode, NOP(jmp.label)) - cfg.addEdge(prevNode, ghostNode) - precNode = ghostNode - visitedBlocks += (block -> ghostNode) - case _ => - // (In)direct call - use this as entrance to block - visitedBlocks += (block -> jmpNode) - } - } - - jmp match { - case n: GoTo => - for (targetBlock <- n.targets) { - if (visitedBlocks.contains(targetBlock)) { - val targetBlockEntry: CfgCommandNode = visitedBlocks(targetBlock) - cfg.addEdge(precNode, targetBlockEntry) - } else { - visitBlock(targetBlock, precNode) - } - } - case h: Halt => { - cfg.addEdge(jmpNode, funcExitNode) - } - case r: Return => - // Branch to this call - cfg.addEdge(precNode, jmpNode) - - // Record call association - procToCalls(proc) += jmpNode - callToNodes(funcEntryNode) += jmpNode - - val returnNode = CfgProcedureReturnNode() - cfg.addEdge(jmpNode, returnNode) - cfg.addEdge(returnNode, funcExitNode) - } // `jmps.head` match - } // `visitJumps` function - } // `visitBlocks` function - } // `cfgForProcedure` function - - /** This takes an expression used in a conditional (jump) and tries to negate it in a (hopefully) nice way. Most - * conditional jumps are just bitvector comparisons. - * - * @param expr - * The expression to negate - * @return - * The negated expression - */ - private def negateConditional(expr: Expr): Expr = expr match { - case binop: BinaryExpr => - binop.op match { - case BVNEQ => - BinaryExpr( - BVEQ, - binop.arg1, - binop.arg2 - ) - case BVEQ => - BinaryExpr( - BVNEQ, - binop.arg1, - binop.arg2 - ) - case _ => - // Worst case scenario we just take the logical not of everything - UnaryExpr( - BoolNOT, - binop - ) - } - case unop: UnaryExpr => - unop.op match { - case BVNOT | BoolNOT => - unop.arg - case _ => - UnaryExpr( - BoolNOT, - unop - ) - } - case _ => - UnaryExpr( - BoolNOT, - expr - ) - } - - /** Recursively inline procedures. This has a dumb/flat approach - we simply continue inlining each all direct calls - * until we either run out of direct calls, or we are at our max inline depth. - * - * For each direct call to be inlined we make a copy of the target's intraprocedural cfg, which is then linked to the - * calling procedure's cfg. We keep track of newly found direct calls that come from inlined functions, which is what - * we pass to the next recursive call. At the end of recursion this set stores the leaf nodes of the cfg - this is - * then used later to link interprocedural calls. - * - * @param procNodes - * The call nodes to inline - * @param inlineAmount - * Maximum amount of inlining from this depth allowed - * @return - * Tthe next leaf call nodes - */ - @tailrec - private def inlineProcedureCalls(procNodes: Set[CfgJumpNode], inlineAmount: Int): Set[CfgJumpNode] = { - assert(inlineAmount >= 0) - Logger.info(s"[+] Inlining ${procNodes.size} leaf call nodes with $inlineAmount level(s) left") - - if (inlineAmount == 0 || procNodes.isEmpty) { - return procNodes - } - - // Set of procedure calls to be discovered by inlining the ones in `procNodes` - val nextProcNodes: mutable.Set[CfgJumpNode] = mutable.Set() - - procNodes.foreach { procNode => - procNode.data match { - case targetCall: DirectCall => - // Retrieve information about the call to the target procedure - val targetProc = targetCall.target - val (procEntry, procExit) = cloneProcedureCFG(targetProc) - - // Add link between call node and the procedure's `Entry`. - cfg.addInlineEdge(procNode, procEntry) - - // Link the procedure's `Exit` to the return point. There should only be one. - assert( - procNode.succIntra.size == 1, - s"More than 1 return node... $procNode has ${procNode.succIntra}" - ) - val returnNode = procNode.succIntra.head - cfg.addInlineEdge(procExit, returnNode) - - // Add new (un-inlined) function calls to be inlined - nextProcNodes ++= callToNodes(procEntry) - case _ => - } - } - - inlineProcedureCalls(nextProcNodes.toSet, inlineAmount - 1) - } - - /** Clones the intraproc-cfg of the given procedure, with unique CfgNode ids. Adds the new nodes to the cfg, and - * returns the start/end nodes of the new procedure cfg. - * - * @param proc - * The procedure to clone (used to index the pre-computed cfgs) - * @return - * (CfgFunctionEntryNode, CfgFunctionExitNode) of the cloned cfg - */ - private def cloneProcedureCFG(proc: Procedure): (CfgFunctionEntryNode, CfgFunctionExitNode) = { - - val (entryNode: CfgFunctionEntryNode, exitNode: CfgFunctionExitNode) = procToCfg(proc) - val (newEntry: CfgFunctionEntryNode, newExit: CfgFunctionExitNode) = (entryNode.copyNode(), exitNode.copyNode()) - - callToNodes += (newEntry -> mutable.Set()) - - // Entry is guaranteed to only have one successor (by our cfg design) - val currNode: CfgNode = entryNode.succIntra.head - visitNode(currNode, newEntry) - - /** Walk this proc's cfg until we reach the exit node on each branch. We do this recursively, tracking the previous - * node, to account for branches and loops. - * - * We can't represent the parameters as an edge as one node comes from the old cfg, and the other from the new cfg. - * - * @param node - * Node in the original procedure's cfg we're up to cloning - * @param prevNewNode - * The originating node in the new clone's cfg - */ - def visitNode(node: CfgNode, prevNewNode: CfgNode): Unit = { - - if (node == exitNode) { - cfg.addEdge(prevNewNode, newExit) - return - } - - node match { - case n: CfgJumpNode => - val newNode = n.copyNode() - - // Link this node with predecessor in the new cfg - cfg.addEdge(prevNewNode, newNode) - - n.data match { - case d: DirectCall => - procToCalls(proc) += newNode - callToNodes(newEntry) += newNode - procToCallers(d.target) += newNode - case i: IndirectCall => - procToCalls(proc) += newNode - callToNodes(newEntry) += newNode - case _ => - } - - // Get intra-cfg successors - val outNodes = node.succIntra - outNodes.foreach(node => visitNode(node, newNode)) - - // For other node types, link with predecessor and continue traversal - case _ => - val newNode = node.copyNode() - cfg.addEdge(prevNewNode, newNode) - - val outNodes = node.succIntra - outNodes.foreach(node => visitNode(node, newNode)) - } - } - - (newEntry, newExit) - } - - /** After inlining has been done, we link all residual direct calls (leaf nodes) to the start of the intraprocedural - * that are the target of the call. - * - * @param leaves - * The call nodes at edge of intraprocedural cfgs to be linked to their targets - */ - private def addInterprocEdges(leaves: Set[CfgJumpNode]): Unit = { - - leaves.foreach { callNode => - callNode.data match { - case targetCall: DirectCall => - val targetProc: Procedure = targetCall.target - - // this does not add returns for any of the calls, so the interprocedural analysis will not work if any - // calls are not in-lined - val (targetEntry: CfgFunctionEntryNode, _) = procToCfg(targetProc) - - cfg.addInterprocCallEdge(callNode, targetEntry) - case _ => - } - } - } diff --git a/src/main/scala/analysis/Dependencies.scala b/src/main/scala/analysis/Dependencies.scala index 040861803..4b5e15106 100644 --- a/src/main/scala/analysis/Dependencies.scala +++ b/src/main/scala/analysis/Dependencies.scala @@ -22,26 +22,6 @@ trait Dependencies[N]: */ def indep(n: N): Set[N] -trait InterproceduralForwardDependencies extends Dependencies[CfgNode] { - override def outdep(n: CfgNode): Set[CfgNode] = n.succInter.toSet - override def indep(n: CfgNode): Set[CfgNode] = n.predInter.toSet -} - -trait IntraproceduralForwardDependencies extends Dependencies[CfgNode] { - override def outdep(n: CfgNode): Set[CfgNode] = n.succIntra.toSet - override def indep(n: CfgNode): Set[CfgNode] = n.predIntra.toSet -} - -trait InterproceduralBackwardDependencies extends Dependencies[CfgNode] { - override def outdep(n: CfgNode): Set[CfgNode] = n.predInter.toSet - override def indep(n: CfgNode): Set[CfgNode] = n.succInter.toSet -} - -trait IntraproceduralBackwardDependencies extends Dependencies[CfgNode] { - override def outdep(n: CfgNode): Set[CfgNode] = n.predIntra.toSet - override def indep(n: CfgNode): Set[CfgNode] = n.succIntra.toSet -} - trait IRInterproceduralForwardDependencies extends Dependencies[CFGPosition] { override def outdep(n: CFGPosition): Set[CFGPosition] = InterProcIRCursor.succ(n) override def indep(n: CFGPosition): Set[CFGPosition] = InterProcIRCursor.pred(n) diff --git a/src/main/scala/analysis/InterLiveVarsAnalysis.scala b/src/main/scala/analysis/InterLiveVarsAnalysis.scala index 7a93266e8..cbe3076c6 100644 --- a/src/main/scala/analysis/InterLiveVarsAnalysis.scala +++ b/src/main/scala/analysis/InterLiveVarsAnalysis.scala @@ -1,7 +1,7 @@ package analysis import analysis.solvers.BackwardIDESolver -import ir.{Assert, Assume, Block, GoTo, CFGPosition, Command, DirectCall, IndirectCall, Assign, MemoryAssign, Halt, Return, Procedure, Program, Variable, toShortString} +import ir.{Assert, Assume, Block, GoTo, CFGPosition, Command, DirectCall, IndirectCall, Assign, MemoryAssign, Unreachable, Return, Procedure, Program, Variable, toShortString} /** * Micro-transfer-functions for LiveVar analysis @@ -74,11 +74,7 @@ trait LiveVarsAnalysisFunctions extends BackwardIDEAnalysis[Variable, TwoElement d match case Left(value) => if value != variable then Map(d -> IdEdge()) else Map() case Right(_) => Map(d -> IdEdge(), Left(variable) -> ConstEdge(TwoElementTop)) - case r: Return => Map(d -> IdEdge()) - case h: Halt => Map(d -> IdEdge()) - case c: DirectCall => Map(d -> IdEdge()) - case c: Block => Map(d -> IdEdge()) - case c: GoTo => Map(d -> IdEdge()) + case _ => Map(d -> IdEdge()) } } diff --git a/src/main/scala/analysis/InterprocSteensgaardAnalysis.scala b/src/main/scala/analysis/InterprocSteensgaardAnalysis.scala index 38153e277..4e7ba81f1 100644 --- a/src/main/scala/analysis/InterprocSteensgaardAnalysis.scala +++ b/src/main/scala/analysis/InterprocSteensgaardAnalysis.scala @@ -39,7 +39,7 @@ case class RegisterWrapperEqualSets(variable: Variable, assigns: Set[Assign]) { class InterprocSteensgaardAnalysis( program: Program, constantProp: Map[CFGPosition, Map[RegisterWrapperEqualSets, Set[BitVecLiteral]]], - regionAccesses: Map[CfgNode, Map[RegisterVariableWrapper, FlatElement[Expr]]], + regionAccesses: Map[CFGPosition, Map[RegisterVariableWrapper, FlatElement[Expr]]], mmm: MemoryModelMap, reachingDefs: Map[CFGPosition, (Map[Variable, Set[Assign]], Map[Variable, Set[Assign]])], globalOffsets: Map[BigInt, BigInt]) extends Analysis[Any] { @@ -431,4 +431,4 @@ object Fresh { n += 1 n } -} \ No newline at end of file +} diff --git a/src/main/scala/analysis/IntraLiveVarsAnalysis.scala b/src/main/scala/analysis/IntraLiveVarsAnalysis.scala index 75fa1dbb0..a576b27fb 100644 --- a/src/main/scala/analysis/IntraLiveVarsAnalysis.scala +++ b/src/main/scala/analysis/IntraLiveVarsAnalysis.scala @@ -1,7 +1,7 @@ package analysis import analysis.solvers.SimpleWorklistFixpointSolver -import ir.{Assert, Assume, Block, CFGPosition, Call, DirectCall, GoTo, IndirectCall, Jump, Assign, MemoryAssign, NOP, Procedure, Program, Statement, Variable, Return, Halt} +import ir.{Assert, Assume, Block, CFGPosition, Call, DirectCall, GoTo, IndirectCall, Jump, Assign, MemoryAssign, NOP, Procedure, Program, Statement, Variable, Return, Unreachable} abstract class LivenessAnalysis(program: Program) extends Analysis[Any]: val lattice: MapLattice[CFGPosition, Set[Variable], PowersetLattice[Variable]] = MapLattice(PowersetLattice()) @@ -19,7 +19,7 @@ abstract class LivenessAnalysis(program: Program) extends Analysis[Any]: case c: DirectCall => s case g: GoTo => s case r: Return => s - case r: Halt => s + case r: Unreachable => s case _ => ??? } } diff --git a/src/main/scala/analysis/MemoryRegionAnalysis.scala b/src/main/scala/analysis/MemoryRegionAnalysis.scala index 65aa1cc20..c7d888a43 100644 --- a/src/main/scala/analysis/MemoryRegionAnalysis.scala +++ b/src/main/scala/analysis/MemoryRegionAnalysis.scala @@ -14,7 +14,7 @@ trait MemoryRegionAnalysis(val program: Program, val constantProp: Map[CFGPosition, Map[Variable, FlatElement[BitVecLiteral]]], val ANRResult: Map[CFGPosition, Set[Variable]], val RNAResult: Map[CFGPosition, Set[Variable]], - val regionAccesses: Map[CfgNode, Map[RegisterVariableWrapper, FlatElement[Expr]]], + val regionAccesses: Map[CFGPosition, Map[RegisterVariableWrapper, FlatElement[Expr]]], reachingDefs: Map[CFGPosition, (Map[Variable, Set[Assign]], Map[Variable, Set[Assign]])]) { var mallocCount: Int = 0 @@ -234,7 +234,7 @@ class MemoryRegionAnalysisSolver( constantProp: Map[CFGPosition, Map[Variable, FlatElement[BitVecLiteral]]], ANRResult: Map[CFGPosition, Set[Variable]], RNAResult: Map[CFGPosition, Set[Variable]], - regionAccesses: Map[CfgNode, Map[RegisterVariableWrapper, FlatElement[Expr]]], + regionAccesses: Map[CFGPosition, Map[RegisterVariableWrapper, FlatElement[Expr]]], reachingDefs: Map[CFGPosition, (Map[Variable, Set[Assign]], Map[Variable, Set[Assign]])] ) extends MemoryRegionAnalysis(program, globals, globalOffsets, subroutines, constantProp, ANRResult, RNAResult, regionAccesses, reachingDefs) with IRIntraproceduralForwardDependencies @@ -249,4 +249,4 @@ class MemoryRegionAnalysisSolver( case _ => super.funsub(n, x) } } -} \ No newline at end of file +} diff --git a/src/main/scala/analysis/RegToMemAnalysis.scala b/src/main/scala/analysis/RegToMemAnalysis.scala index df7217f75..1afde6f4d 100644 --- a/src/main/scala/analysis/RegToMemAnalysis.scala +++ b/src/main/scala/analysis/RegToMemAnalysis.scala @@ -15,29 +15,29 @@ import scala.collection.immutable * * Both in which constant propagation mark as TOP which is not useful. */ -trait RegionAccessesAnalysis(cfg: ProgramCfg, constantProp: Map[CFGPosition, Map[Variable, FlatElement[BitVecLiteral]]], reachingDefs: Map[CFGPosition, (Map[Variable, Set[Assign]], Map[Variable, Set[Assign]])]) { +trait RegionAccessesAnalysis(program: Program, constantProp: Map[CFGPosition, Map[Variable, FlatElement[BitVecLiteral]]], reachingDefs: Map[CFGPosition, (Map[Variable, Set[Assign]], Map[Variable, Set[Assign]])]) { val mapLattice: MapLattice[RegisterVariableWrapper, FlatElement[Expr], FlatLattice[Expr]] = MapLattice(FlatLattice[_root_.ir.Expr]()) - val lattice: MapLattice[CfgNode, Map[RegisterVariableWrapper, FlatElement[Expr]], MapLattice[RegisterVariableWrapper, FlatElement[Expr], FlatLattice[Expr]]] = MapLattice(mapLattice) + val lattice: MapLattice[CFGPosition, Map[RegisterVariableWrapper, FlatElement[Expr]], MapLattice[RegisterVariableWrapper, FlatElement[Expr], FlatLattice[Expr]]] = MapLattice(mapLattice) - val domain: Set[CfgNode] = cfg.nodes.toSet + val domain: Set[CFGPosition] = program.toSet - val first: Set[CfgNode] = Set(cfg.startNode) + val first: Set[CFGPosition] = program.procedures.toSet /** Default implementation of eval. */ - def eval(cmd: CfgCommandNode, constants: Map[Variable, FlatElement[BitVecLiteral]], s: Map[RegisterVariableWrapper, FlatElement[Expr]]): Map[RegisterVariableWrapper, FlatElement[Expr]] = { - cmd.data match { + def eval(cmd: Statement, constants: Map[Variable, FlatElement[BitVecLiteral]], s: Map[RegisterVariableWrapper, FlatElement[Expr]]): Map[RegisterVariableWrapper, FlatElement[Expr]] = { + cmd match { case assign: Assign => assign.rhs match { case memoryLoad: MemoryLoad => - s + (RegisterVariableWrapper(assign.lhs, getDefinition(assign.lhs, cmd.data, reachingDefs)) -> FlatEl(memoryLoad)) + s + (RegisterVariableWrapper(assign.lhs, getDefinition(assign.lhs, cmd, reachingDefs)) -> FlatEl(memoryLoad)) case binaryExpr: BinaryExpr => if (evaluateExpression(binaryExpr.arg1, constants).isEmpty) { // approximates Base + Offset Logger.debug(s"Approximating $assign in $binaryExpr") - Logger.debug(s"Reaching defs: ${reachingDefs(cmd.data)}") - s + (RegisterVariableWrapper(assign.lhs, getDefinition(assign.lhs, cmd.data, reachingDefs)) -> FlatEl(binaryExpr)) + Logger.debug(s"Reaching defs: ${reachingDefs(cmd)}") + s + (RegisterVariableWrapper(assign.lhs, getDefinition(assign.lhs, cmd, reachingDefs)) -> FlatEl(binaryExpr)) } else { s } @@ -50,23 +50,23 @@ trait RegionAccessesAnalysis(cfg: ProgramCfg, constantProp: Map[CFGPosition, Map /** Transfer function for state lattice elements. */ - def localTransfer(n: CfgNode, s: Map[RegisterVariableWrapper, FlatElement[Expr]]): Map[RegisterVariableWrapper, FlatElement[Expr]] = n match { - case cmd: CfgCommandNode => - eval(cmd, constantProp(cmd.data), s) + def localTransfer(n: CFGPosition, s: Map[RegisterVariableWrapper, FlatElement[Expr]]): Map[RegisterVariableWrapper, FlatElement[Expr]] = n match { + case cmd: Statement => + eval(cmd, constantProp(cmd), s) case _ => s // ignore other kinds of nodes } /** Transfer function for state lattice elements. */ - def transfer(n: CfgNode, s: Map[RegisterVariableWrapper, FlatElement[Expr]]): Map[RegisterVariableWrapper, FlatElement[Expr]] = localTransfer(n, s) + def transfer(n: CFGPosition, s: Map[RegisterVariableWrapper, FlatElement[Expr]]): Map[RegisterVariableWrapper, FlatElement[Expr]] = localTransfer(n, s) } class RegionAccessesAnalysisSolver( - cfg: ProgramCfg, + program: Program, constantProp: Map[CFGPosition, Map[Variable, FlatElement[BitVecLiteral]]], reachingDefs: Map[CFGPosition, (Map[Variable, Set[Assign]], Map[Variable, Set[Assign]])], - ) extends RegionAccessesAnalysis(cfg, constantProp, reachingDefs) - with InterproceduralForwardDependencies - with Analysis[Map[CfgNode, Map[RegisterVariableWrapper, FlatElement[Expr]]]] - with SimpleWorklistFixpointSolver[CfgNode, Map[RegisterVariableWrapper, FlatElement[Expr]], MapLattice[RegisterVariableWrapper, FlatElement[Expr], FlatLattice[Expr]]] { -} \ No newline at end of file + ) extends RegionAccessesAnalysis(program, constantProp, reachingDefs) + with IRInterproceduralForwardDependencies + with Analysis[Map[CFGPosition, Map[RegisterVariableWrapper, FlatElement[Expr]]]] + with SimpleWorklistFixpointSolver[CFGPosition, Map[RegisterVariableWrapper, FlatElement[Expr]], MapLattice[RegisterVariableWrapper, FlatElement[Expr], FlatLattice[Expr]]] { +} diff --git a/src/main/scala/analysis/solvers/IDESolver.scala b/src/main/scala/analysis/solvers/IDESolver.scala index 7a581dbe4..030017434 100644 --- a/src/main/scala/analysis/solvers/IDESolver.scala +++ b/src/main/scala/analysis/solvers/IDESolver.scala @@ -1,7 +1,7 @@ package analysis.solvers import analysis.{BackwardIDEAnalysis, Dependencies, EdgeFunction, EdgeFunctionLattice, ForwardIDEAnalysis, IDEAnalysis, IRInterproceduralBackwardDependencies, IRInterproceduralForwardDependencies, Lambda, Lattice, MapLattice} -import ir.{CFGPosition, Command, DirectCall, GoTo, IRWalk, IndirectCall, Return, InterProcIRCursor, Procedure, Program, isAfterCall, Halt, Statement, Jump} +import ir.{CFGPosition, Command, DirectCall, GoTo, IRWalk, IndirectCall, Return, InterProcIRCursor, Procedure, Program, isAfterCall, Unreachable, Statement, Jump} import util.Logger import scala.collection.immutable.Map @@ -209,7 +209,7 @@ abstract class ForwardIDESolver[D, T, L <: Lattice[T]](program: Program) protected def entryToExit(entry: Procedure): Return = IRWalk.lastInProc(entry).asInstanceOf[Return] - protected def exitToEntry(exit: IndirectCall): Procedure = IRWalk.procedure(exit) + protected def exitToEntry(exit: Return): Procedure = IRWalk.procedure(exit) protected def callToReturn(call: DirectCall): Command = call.successor @@ -225,13 +225,13 @@ abstract class ForwardIDESolver[D, T, L <: Lattice[T]](program: Program) protected def isCall(call: CFGPosition): Boolean = call match - case directCall: DirectCall if (!directCall.successor.isInstanceOf[Halt]) => true + case directCall: DirectCall if (!directCall.successor.isInstanceOf[Unreachable]) => true case _ => false protected def isExit(exit: CFGPosition): Boolean = exit match // only looking at functions with statements - case command: Command => IRWalk.lastInProc(IRWalk.procedure(command)) == command + case command: Return => true case _ => false protected def getAfterCalls(exit: IndirectCall): Set[Command] = @@ -264,12 +264,19 @@ abstract class BackwardIDESolver[D, T, L <: Lattice[T]](program: Program) protected def isCall(call: CFGPosition): Boolean = call match - case c : Command => isAfterCall(c) && IRWalk.prevCommandInBlock(c).map(_.isInstanceOf[DirectCall]).getOrElse(false) + case c: Unreachable => false /* don't process non-returning calls */ + case c : Command => { + val call = IRWalk.prevCommandInBlock(c) + call match { + case Some(d: DirectCall) if d.target.returnBlock.isDefined => true + case _ => false + } + } case _ => false protected def isExit(exit: CFGPosition): Boolean = exit match - case procedure: Procedure => true + case procedure: Procedure => procedure.blocks.nonEmpty case _ => false protected def getAfterCalls(exit: Procedure): Set[DirectCall] = exit.incomingCalls().toSet diff --git a/src/main/scala/cfg_visualiser/Output.scala b/src/main/scala/cfg_visualiser/Output.scala deleted file mode 100644 index cc730d62e..000000000 --- a/src/main/scala/cfg_visualiser/Output.scala +++ /dev/null @@ -1,37 +0,0 @@ -package cfg_visualiser - -import java.io.{File, PrintWriter} -import analysis._ - -/** Basic outputting functionality. - */ -object Output { - - /** Helper function for producing string output for a control-flow graph node after an analysis. - * @param res - * map from control-flow graph nodes to strings, as produced by the analysis - */ - def labeler(res: Map[CfgNode, _], stateAfterNode: Boolean)(n: CfgNode): String = { - val r = res.getOrElse(n, "-") - val desc = n match { - case entry: CfgFunctionEntryNode => s"Function ${entry.data.name} entry" - case exit: CfgFunctionExitNode => s"Function ${exit.data.name} exit" - case _ => n.toString + s" ${n.id}" - } - if (stateAfterNode) s"$desc\n$r" - else s"$r\n$desc" - } - - /** Generate an unique ID string for the given AST node. - */ - def dotIder(n: CfgNode, uniqueId: Int): String = - n match { - case real: CfgCommandNode => s"real${real.data}_$uniqueId" - case entry: CfgFunctionEntryNode => s"entry${entry.data}_$uniqueId" - case exit: CfgFunctionExitNode => s"exit${exit.data}_$uniqueId" - case ret: CfgProcedureReturnNode => s"return_$uniqueId" - case noCallRet: CfgCallNoReturnNode => s"callnoreturn_$uniqueId" - case callRet: CfgCallReturnNode => s"callreturn_$uniqueId" - case _ => ??? - } -} \ No newline at end of file diff --git a/src/main/scala/ir/IRCursor.scala b/src/main/scala/ir/IRCursor.scala index 690b57dea..27b558ccf 100644 --- a/src/main/scala/ir/IRCursor.scala +++ b/src/main/scala/ir/IRCursor.scala @@ -96,7 +96,7 @@ trait IntraProcIRCursor extends IRWalk[CFGPosition, CFGPosition] { case proc: Procedure => proc.entryBlock.toSet case b: Block => b.statements.headOption.orElse(Some(b.jump)).toSet case n: GoTo => n.targets.asInstanceOf[Set[CFGPosition]] - case h: Halt => Set() + case h: Unreachable => Set() case h: Return => Set() case c: Statement => IRWalk.nextCommandInBlock(c).toSet } @@ -150,7 +150,6 @@ trait InterProcIRCursor extends IRWalk[CFGPosition, CFGPosition] { IntraProcIRCursor.succ(pos) ++ (pos match case c: DirectCall if c.target.blocks.nonEmpty => Set(c.target) - // case c: IndirectCall if c.parent.isProcReturn => c.parent.parent.incomingCalls().map(_.successor).toSet case c: Return => c.parent.parent.incomingCalls().map(_.successor).toSet case _ => Set.empty ) @@ -159,7 +158,13 @@ trait InterProcIRCursor extends IRWalk[CFGPosition, CFGPosition] { final def pred(pos: CFGPosition): Set[CFGPosition] = { IntraProcIRCursor.pred(pos) ++ (pos match - case d: DirectCall if d.target.blocks.nonEmpty => d.target.returnBlock.toSet + case c: Command => { + IRWalk.prevCommandInBlock(c) match { + case Some(d: DirectCall) if d.target.blocks.nonEmpty => d.target.returnBlock.toSet + case o => o.toSet + } + + } case c: Procedure => c.incomingCalls().toSet.asInstanceOf[Set[CFGPosition]] case _ => Set.empty ) diff --git a/src/main/scala/ir/Interpreter.scala b/src/main/scala/ir/Interpreter.scala index de470d5d9..0430ef66a 100644 --- a/src/main/scala/ir/Interpreter.scala +++ b/src/main/scala/ir/Interpreter.scala @@ -248,8 +248,8 @@ class Interpreter() { case r: Return => { nextCmd = Some(returnCmd.pop()) } - case h: Halt => { - Logger.debug("Halt") + case h: Unreachable => { + Logger.debug("Unreachable") nextCmd = None } } diff --git a/src/main/scala/ir/Program.scala b/src/main/scala/ir/Program.scala index ed61cba4f..54916b27c 100644 --- a/src/main/scala/ir/Program.scala +++ b/src/main/scala/ir/Program.scala @@ -18,40 +18,6 @@ class Program(var procedures: ArrayBuffer[Procedure], serialiseIL(this) } - // This shouldn't be run before indirect calls are resolved - def stripUnreachableFunctions(depth: Int = Int.MaxValue): Unit = { - val procedureCalleeNames = procedures.map(f => f.name -> f.calls.map(_.name)).toMap - - val toVisit: mutable.LinkedHashSet[(Int, String)] = mutable.LinkedHashSet((0, mainProcedure.name)) - var reachableFound = true - val reachableNames = mutable.HashMap[String, Int]() - while (toVisit.nonEmpty) { - val next = toVisit.head - toVisit.remove(next) - - if (next._1 <= depth) { - - def addName(depth: Int, name: String): Unit = { - val oldDepth = reachableNames.getOrElse(name, Integer.MAX_VALUE) - reachableNames.put(next._2, if depth < oldDepth then depth else oldDepth) - } - addName(next._1, next._2) - - val callees = procedureCalleeNames(next._2) - - toVisit.addAll(callees.diff(reachableNames.keySet).map(c => (next._1 + 1, c))) - callees.foreach(c => addName(next._1 + 1, c)) - } - } - procedures = procedures.filter(f => reachableNames.keySet.contains(f.name)) - - for (elem <- procedures.filter(c => c.calls.exists(s => !procedures.contains(s)))) { - // last layer is analysed only as specifications so we remove the body for anything that calls - // a function we have removed - - elem.clearBlocks() - } - } def setModifies(specModifies: Map[String, List[String]]): Unit = { val procToCalls: mutable.Map[Procedure, Set[Procedure]] = mutable.Map() @@ -229,7 +195,6 @@ class Procedure private ( } def addBlocks(block: Block): Block = { - block.parent = this if (!_blocks.contains(block)) { block.parent = this _blocks.add(block) @@ -318,7 +283,8 @@ class Procedure private ( def clearBlocks(): Unit = { // O(n) because we are careful to unlink the parents etc. - removeBlocks(_blocks) + // .toList to avoid modifying our own iterator + removeBlocksDisconnect(_blocks.toList) } def callers(): Iterable[Procedure] = _callers.map(_.parent.parent).toSet[Procedure] @@ -369,6 +335,7 @@ class Block private ( this(label, address, IntrusiveList().addAll(statements), jump, mutable.HashSet.empty) } + def isReturn: Boolean = parent.returnBlock.contains(this) def isEntry: Boolean = parent.entryBlock.contains(this) def jump: Jump = _jump diff --git a/src/main/scala/ir/Statement.scala b/src/main/scala/ir/Statement.scala index 2dea68f46..ce49bc82e 100644 --- a/src/main/scala/ir/Statement.scala +++ b/src/main/scala/ir/Statement.scala @@ -82,7 +82,7 @@ sealed trait Jump extends Command { def acceptVisit(visitor: Visitor): Jump = throw new Exception("visitor " + visitor + " unimplemented for: " + this) } -class Halt(override val label: Option[String] = None) extends Jump { +class Unreachable(override val label: Option[String] = None) extends Jump { /* Terminate / No successors / assume false */ override def acceptVisit(visitor: Visitor): Jump = this } @@ -139,7 +139,7 @@ object GoTo: sealed trait Call extends Statement { def returnTarget: Option[Command] = successor match { - case h: Halt => None + case h: Unreachable => None case o => Some(o) } } diff --git a/src/main/scala/ir/dsl/DSL.scala b/src/main/scala/ir/dsl/DSL.scala index 6a1b96742..3ebeefbc4 100644 --- a/src/main/scala/ir/dsl/DSL.scala +++ b/src/main/scala/ir/dsl/DSL.scala @@ -73,8 +73,8 @@ case class EventuallyGoto(targets: List[DelayNameResolve]) extends EventuallyJum case class EventuallyReturn() extends EventuallyJump { override def resolve(p: Program) = Return() } -case class EventuallyHalt() extends EventuallyJump { - override def resolve(p: Program) = Halt() +case class EventuallyUnreachable() extends EventuallyJump { + override def resolve(p: Program) = Unreachable() } def goto(): EventuallyGoto = EventuallyGoto(List.empty) @@ -84,7 +84,7 @@ def goto(targets: String*): EventuallyGoto = { } def ret: EventuallyReturn = EventuallyReturn() -def halt: EventuallyHalt= EventuallyHalt() +def unreachable: EventuallyUnreachable= EventuallyUnreachable() def goto(targets: List[String]): EventuallyGoto = { EventuallyGoto(targets.map(p => DelayNameResolve(p))) @@ -111,8 +111,6 @@ def block(label: String, sl: (Statement | EventuallyStatement | EventuallyJump)* val statements : Seq[EventuallyStatement] = sl.flatMap { case s: Statement => Some(ResolvableStatement(s)) case o: EventuallyStatement => Some(o) - case o: EventuallyCall => Some(o) - case o: EventuallyIndirectCall => Some(o) case g: EventuallyJump => None } val jump = sl.collectFirst { diff --git a/src/main/scala/ir/transforms/IndirectCallResolution.scala b/src/main/scala/ir/transforms/IndirectCallResolution.scala index e9c9fddf7..4345dcaa1 100644 --- a/src/main/scala/ir/transforms/IndirectCallResolution.scala +++ b/src/main/scala/ir/transforms/IndirectCallResolution.scala @@ -1,7 +1,5 @@ package ir.transforms - - import scala.collection.mutable.ListBuffer import scala.collection.mutable.ArrayBuffer import analysis.solvers.* @@ -11,165 +9,27 @@ import ir.* import translating.* import util.Logger import util.intrusive_list.IntrusiveList -import analysis.CfgCommandNode import scala.collection.mutable import cilvisitor._ - -/** Resolve indirect calls to an address-conditional choice between direct calls using the Value Set Analysis results. - * Dead code, and currently broken by statement calls - * -def resolveIndirectCalls( - cfg: ProgramCfg, - valueSets: Map[CfgNode, LiftedElement[Map[Variable | MemoryRegion, Set[Value]]]], +def resolveIndirectCallsUsingPointsTo( + pointsTos: Map[RegisterVariableWrapper, Set[RegisterVariableWrapper | MemoryRegion]], + regionContents: Map[MemoryRegion, Set[BitVecLiteral | MemoryRegion]], + reachingDefs: Map[CFGPosition, (Map[Variable, Set[Assign]], Map[Variable, Set[Assign]])], IRProgram: Program ): Boolean = { var modified: Boolean = false - val worklist = ListBuffer[CfgNode]() - cfg.startNode.succIntra.union(cfg.startNode.succInter).foreach(node => worklist.addOne(node)) - - val visited = mutable.Set[CfgNode]() - while (worklist.nonEmpty) { - val node = worklist.remove(0) - if (!visited.contains(node)) { - process(node) - node.succIntra.union(node.succInter).foreach(node => worklist.addOne(node)) - visited.add(node) - } - } - - def process(n: CfgNode): Unit = n match { - /* - case c: CfgStatementNode => - c.data match + val worklist = ListBuffer[CFGPosition]() - //We do not want to insert the VSA results into the IR like this - case localAssign: Assign => - localAssign.rhs match - case _: MemoryLoad => - if (valueSets(n).contains(localAssign.lhs) && valueSets(n).get(localAssign.lhs).head.size == 1) { - val extractedValue = extractExprFromValue(valueSets(n).get(localAssign.lhs).head.head) - localAssign.rhs = extractedValue - Logger.info(s"RESOLVED: Memory load ${localAssign.lhs} resolved to ${extractedValue}") - } else if (valueSets(n).contains(localAssign.lhs) && valueSets(n).get(localAssign.lhs).head.size > 1) { - Logger.info(s"RESOLVED: WARN Memory load ${localAssign.lhs} resolved to multiple values, cannot replace") - - /* - // must merge into a single memory variable to represent the possible values - // Make a binary OR of all the possible values takes two at a time (incorrect to do BVOR) - val values = valueSets(n).get(localAssign.lhs).head - val exprValues = values.map(extractExprFromValue) - val result = exprValues.reduce((a, b) => BinaryExpr(BVOR, a, b)) // need to express nondeterministic - // choice between these specific options - localAssign.rhs = result - */ - } - case _ => - */ - case c: CfgJumpNode => - val block = c.block - c.data match - case indirectCall: IndirectCall => - if (block.jump != indirectCall) { - // We only replace the calls with DirectCalls in the IR, and don't replace the CommandNode.data - // Hence if we have already processed this CFG node there will be no corresponding IndirectCall in the IR - // to replace. - // We want to replace all possible indirect calls based on this CFG, before regenerating it from the IR - return - } - valueSets(n) match { - case Lift(valueSet) => - val targetNames = resolveAddresses(valueSet(indirectCall.target)).map(_.name).toList.sorted - val targets = targetNames.map(name => IRProgram.procedures.filter(_.name.equals(name)).head) - - if (targets.size == 1) { - modified = true - - // indirectCall.parent.parent.removeBlocks(indirectCall.returnTarget) - val newCall = DirectCall(targets.head, indirectCall.label) - block.statements.replace(indirectCall, newCall) - } else if (targets.size > 1) { - modified = true - val procedure = c.parent.data - val newBlocks = ArrayBuffer[Block]() - for (t <- targets) { - val assume = Assume(BinaryExpr(BVEQ, indirectCall.target, BitVecLiteral(t.address.get, 64))) - val newLabel: String = block.label + t.name - val directCall = DirectCall(t) - directCall.parent = indirectCall.parent - - // assume indircall is the last statement in block - assert(indirectCall.parent.statements.lastOption.contains(indirectCall)) - val fallthrough = indirectCall.parent.jump - - newBlocks.append(Block(newLabel, None, ArrayBuffer(assume, directCall), fallthrough)) - } - procedure.addBlocks(newBlocks) - val newCall = GoTo(newBlocks, indirectCall.label) - block.replaceJump(newCall) - } - case LiftedBottom => - } - case _ => - case _ => - } - - def nameExists(name: String): Boolean = { - IRProgram.procedures.exists(_.name.equals(name)) - } - - def addFakeProcedure(name: String): Unit = { - IRProgram.procedures += Procedure(name) - } + worklist.addAll(IRProgram) - def resolveAddresses(valueSet: Set[Value]): Set[AddressValue] = { - var functionNames: Set[AddressValue] = Set() - valueSet.foreach { - case globalAddress: GlobalAddress => - if (nameExists(globalAddress.name)) { - functionNames += globalAddress - Logger.info(s"RESOLVED: Call to Global address ${globalAddress.name} rt statuesolved.") - } else { - addFakeProcedure(globalAddress.name) - functionNames += globalAddress - Logger.info(s"Global address ${globalAddress.name} does not exist in the program. Added a fake function.") - } - case localAddress: LocalAddress => - if (nameExists(localAddress.name)) { - functionNames += localAddress - Logger.info(s"RESOLVED: Call to Local address ${localAddress.name}") - } else { - addFakeProcedure(localAddress.name) - functionNames += localAddress - Logger.info(s"Local address ${localAddress.name} does not exist in the program. Added a fake function.") - } - case _ => - } - functionNames - } - - modified -} - - */ - -def resolveIndirectCallsUsingPointsTo( - cfg: ProgramCfg, - pointsTos: Map[RegisterVariableWrapper, Set[RegisterVariableWrapper | MemoryRegion]], - regionContents: Map[MemoryRegion, Set[BitVecLiteral | MemoryRegion]], - reachingDefs: Map[CFGPosition, (Map[Variable, Set[Assign]], Map[Variable, Set[Assign]])], - IRProgram: Program - ): Boolean = { - var modified: Boolean = false - val worklist = ListBuffer[CfgNode]() - cfg.startNode.succIntra.union(cfg.startNode.succInter).foreach(node => worklist.addOne(node)) - - val visited = mutable.Set[CfgNode]() + val visited = mutable.Set[CFGPosition]() while (worklist.nonEmpty) { val node = worklist.remove(0) if (!visited.contains(node)) { + // add to worklist before we delete the node and can no longer find its successors + InterProcIRCursor.succ(node).foreach(node => worklist.addOne(node)) process(node) - node.succIntra.union(node.succInter).foreach(node => worklist.addOne(node)) visited.add(node) } } @@ -181,7 +41,7 @@ def resolveIndirectCallsUsingPointsTo( if (regionContents.contains(stackRegion)) { for (c <- regionContents(stackRegion)) { c match { - case bitVecLiteral: BitVecLiteral => Logger.debug("hi: " + bitVecLiteral)//??? + case bitVecLiteral: BitVecLiteral => Logger.debug("hi: " + bitVecLiteral) //??? case memoryRegion: MemoryRegion => result.addAll(searchRegion(memoryRegion)) } @@ -195,7 +55,7 @@ def resolveIndirectCallsUsingPointsTo( result.add(dataRegion.regionIdentifier) // TODO: may need to investigate if we should add the parent region for (c <- regionContents(dataRegion)) { c match { - case bitVecLiteral: BitVecLiteral => Logger.debug("hi: " + bitVecLiteral)//??? + case bitVecLiteral: BitVecLiteral => Logger.debug("hi: " + bitVecLiteral) //??? case memoryRegion: MemoryRegion => result.addAll(searchRegion(memoryRegion)) } @@ -218,76 +78,70 @@ def resolveIndirectCallsUsingPointsTo( case Some(value) => value.map { case v: RegisterVariableWrapper => names.addAll(resolveAddresses(v.variable, i)) - case m: MemoryRegion => names.addAll(searchRegion(m)) + case m: MemoryRegion => names.addAll(searchRegion(m)) } names case None => names } } - def process(n: CfgNode): Unit = n match { - case c: CfgJumpNode => - val block = c.block - c.data match - // don't try to resolve returns - case indirectCall: IndirectCall if indirectCall.target != Register("R30", 64) => - if (!indirectCall.hasParent) { - // We only replace the calls with DirectCalls in the IR, and don't replace the CommandNode.data - // Hence if we have already processed this CFG node there will be no corresponding IndirectCall in the IR - // to replace. - // We want to replace all possible indirect calls based on this CFG, before regenerating it from the IR - return - } - assert(indirectCall.parent.statements.lastOption.contains(indirectCall)) - - val targetNames = resolveAddresses(indirectCall.target, indirectCall) - Logger.debug(s"Points-To approximated call ${indirectCall.target} with $targetNames") - Logger.debug(IRProgram.procedures) - val targets: mutable.Set[Procedure] = targetNames.map(name => IRProgram.procedures.find(_.name == name).getOrElse(addFakeProcedure(name))) - - if (targets.size > 1) { - Logger.info(s"Resolved indirect call $indirectCall") + def process(n: CFGPosition): Unit = n match { + case indirectCall: IndirectCall if indirectCall.target != Register("R30", 64) => + if (!indirectCall.hasParent) { + // skip if we have already processesd this call + return + } + // we need the single-call-at-end-of-block invariant + assert(indirectCall.parent.statements.lastOption.contains(indirectCall)) + + val block = indirectCall.parent + val procedure = block.parent + + val targetNames = resolveAddresses(indirectCall.target, indirectCall) + Logger.debug(s"Points-To approximated call ${indirectCall.target} with $targetNames") + Logger.debug(IRProgram.procedures) + val targets: mutable.Set[Procedure] = + targetNames.map(name => IRProgram.procedures.find(_.name == name).getOrElse(addFakeProcedure(name))) + + if (targets.size > 1) { + Logger.info(s"Resolved indirect call $indirectCall") + } + + if (targets.size == 1) { + modified = true + + val newCall = DirectCall(targets.head, indirectCall.label) + block.statements.replace(indirectCall, newCall) + } else if (targets.size > 1) { + + val oft = indirectCall.parent.jump + + modified = true + val newBlocks = ArrayBuffer[Block]() + for (t <- targets) { + Logger.debug(targets) + val address = t.address.match { + case Some(a) => a + case None => + throw Exception(s"resolved indirect call $indirectCall to procedure which does not have address: $t") } - - - if (targets.size == 1) { - modified = true - - // indirectCall.parent.parent.removeBlocks(indirectCall.returnTarget) - val newCall = DirectCall(targets.head, indirectCall.label) - block.statements.replace(indirectCall, newCall) - } else if (targets.size > 1) { - - val oft = indirectCall.parent.jump - - modified = true - val procedure = c.parent.data - val newBlocks = ArrayBuffer[Block]() - // indirectCall.parent.parent.removeBlocks(indirectCall.returnTarget) - for (t <- targets) { - Logger.debug(targets) - val address = t.address.match { - case Some(a) => a - case None => throw Exception(s"resolved indirect call $indirectCall to procedure which does not have address: $t") - } - val assume = Assume(BinaryExpr(BVEQ, indirectCall.target, BitVecLiteral(address, 64))) - val newLabel: String = block.label + t.name - val directCall = DirectCall(t) - - /* copy the goto node resulting */ - val fallthrough = oft match { - case g: GoTo => GoTo(g.targets, g.label) - case h: Halt => Halt() - case r: Return => Return() - } - newBlocks.append(Block(newLabel, None, ArrayBuffer(assume, directCall), fallthrough)) - } - block.statements.remove(indirectCall) - procedure.addBlocks(newBlocks) - val newCall = GoTo(newBlocks, indirectCall.label) - block.replaceJump(newCall) + val assume = Assume(BinaryExpr(BVEQ, indirectCall.target, BitVecLiteral(address, 64))) + val newLabel: String = block.label + t.name + val directCall = DirectCall(t) + + /* copy the goto node resulting */ + val fallthrough = oft match { + case g: GoTo => GoTo(g.targets, g.label) + case h: Unreachable => Unreachable() + case r: Return => Return() } - case _ => + newBlocks.append(Block(newLabel, None, ArrayBuffer(assume, directCall), fallthrough)) + } + block.statements.remove(indirectCall) + procedure.addBlocks(newBlocks) + val newCall = GoTo(newBlocks, indirectCall.label) + block.replaceJump(newCall) + } case _ => } diff --git a/src/main/scala/ir/transforms/ReplaceReturn.scala b/src/main/scala/ir/transforms/ReplaceReturn.scala index 91681ab8e..f41c25e3d 100644 --- a/src/main/scala/ir/transforms/ReplaceReturn.scala +++ b/src/main/scala/ir/transforms/ReplaceReturn.scala @@ -13,7 +13,7 @@ class ReplaceReturns extends CILVisitor { j match { case IndirectCall(Register("R30", _), _) => { assert(j.parent.statements.lastOption.contains(j)) - if (j.parent.jump.isInstanceOf[Halt | Return]) { + if (j.parent.jump.isInstanceOf[Unreachable | Return]) { j.parent.replaceJump(Return()) ChangeTo(List()) } else { @@ -32,7 +32,6 @@ def addReturnBlocks(p: Program, toAll: Boolean = false) = { p.procedures.foreach(p => { val containsReturn = p.blocks.map(_.jump).find(_.isInstanceOf[Return]).isDefined if (toAll && p.blocks.isEmpty && p.entryBlock.isEmpty && p.returnBlock.isEmpty) { - Logger.info(s"proc ${p.name} ${p.entryBlock}, ${p.returnBlock}") p.returnBlock = (Block(label=p.name + "_basil_return",jump=Return())) p.entryBlock = (Block(label=p.name + "_basil_entry",jump=GoTo(p.returnBlock.get))) } else if (p.returnBlock.isEmpty && (toAll || containsReturn)) { diff --git a/src/main/scala/ir/transforms/SplitThreads.scala b/src/main/scala/ir/transforms/SplitThreads.scala index 7f36fdb3c..8678720e7 100644 --- a/src/main/scala/ir/transforms/SplitThreads.scala +++ b/src/main/scala/ir/transforms/SplitThreads.scala @@ -11,7 +11,6 @@ import util.Logger import java.util.Base64 import spray.json.DefaultJsonProtocol.* import util.intrusive_list.IntrusiveList -import analysis.CfgCommandNode import scala.collection.mutable import cilvisitor._ diff --git a/src/main/scala/ir/transforms/StripUnreachableFunctions.scala b/src/main/scala/ir/transforms/StripUnreachableFunctions.scala new file mode 100644 index 000000000..96784cc28 --- /dev/null +++ b/src/main/scala/ir/transforms/StripUnreachableFunctions.scala @@ -0,0 +1,38 @@ +package ir.transforms +import ir._ +import collection.mutable + +// This shouldn't be run before indirect calls are resolved +def stripUnreachableFunctions(p: Program, depth: Int = Int.MaxValue): Unit = { + val procedureCalleeNames = p.procedures.map(f => f.name -> f.calls.map(_.name)).toMap + + val toVisit: mutable.LinkedHashSet[(Int, String)] = mutable.LinkedHashSet((0, p.mainProcedure.name)) + var reachableFound = true + val reachableNames = mutable.HashMap[String, Int]() + while (toVisit.nonEmpty) { + val next = toVisit.head + toVisit.remove(next) + + if (next._1 <= depth) { + + def addName(depth: Int, name: String): Unit = { + val oldDepth = reachableNames.getOrElse(name, Integer.MAX_VALUE) + reachableNames.put(next._2, if depth < oldDepth then depth else oldDepth) + } + addName(next._1, next._2) + + val callees = procedureCalleeNames(next._2) + + toVisit.addAll(callees.diff(reachableNames.keySet).map(c => (next._1 + 1, c))) + callees.foreach(c => addName(next._1 + 1, c)) + } + } + p.procedures = p.procedures.filter(f => reachableNames.keySet.contains(f.name)) + + for (elem <- p.procedures.filter(c => c.calls.exists(s => !p.procedures.contains(s)))) { + // last layer is analysed only as specifications so we remove the body for anything that calls + // a function we have removed + + elem.clearBlocks() + } +} diff --git a/src/main/scala/translating/BAPToIR.scala b/src/main/scala/translating/BAPToIR.scala index 90ba61335..93a5010cf 100644 --- a/src/main/scala/translating/BAPToIR.scala +++ b/src/main/scala/translating/BAPToIR.scala @@ -136,11 +136,11 @@ class BAPToIR(var program: BAPProgram, mainAddress: Int) { jumps.head match { case b: BAPDirectCall => val call = Some(DirectCall(nameToProcedure(b.target),Some(b.line))) - val ft = (b.returnTarget.map(t => labelToBlock(t))).map(x => GoTo(Set(x))).getOrElse(Halt()) + val ft = (b.returnTarget.map(t => labelToBlock(t))).map(x => GoTo(Set(x))).getOrElse(Unreachable()) (call, ft, ArrayBuffer()) case b: BAPIndirectCall => val call = IndirectCall(b.target.toIR, Some(b.line)) - val ft = (b.returnTarget.map(t => labelToBlock(t))).map(x => GoTo(Set(x))).getOrElse(Halt()) + val ft = (b.returnTarget.map(t => labelToBlock(t))).map(x => GoTo(Set(x))).getOrElse(Unreachable()) (Some(call), ft, ArrayBuffer()) case b: BAPGoTo => val target = labelToBlock(b.target) diff --git a/src/main/scala/translating/GTIRBToIR.scala b/src/main/scala/translating/GTIRBToIR.scala index ad090eb90..b1483343d 100644 --- a/src/main/scala/translating/GTIRBToIR.scala +++ b/src/main/scala/translating/GTIRBToIR.scala @@ -364,7 +364,7 @@ class GTIRBToIR(mods: Seq[Module], parserMap: immutable.Map[String, Array[Array[ // need to copy jump as it can't have multiple parents val jumpCopy = currentBlock.jump match { case GoTo(targets, label) => GoTo(targets, label) - case h: Halt => Halt() + case h: Unreachable => Unreachable() case r: Return => Return() case _ => throw Exception("this shouldn't be reachable") } @@ -392,7 +392,7 @@ class GTIRBToIR(mods: Seq[Module], parserMap: immutable.Map[String, Array[Array[ case _ => throw Exception(s"no assignment to program counter found before indirect call in block ${block.label}") } block.statements.remove(block.statements.last) // remove _PC assignment - (Some(IndirectCall(target)), Halt()) + (Some(IndirectCall(target)), Unreachable()) } else if (proxySymbols.size > 1) { // TODO requires further consideration once encountered throw Exception(s"multiple uuidToSymbol ${proxySymbols.map(_.name).mkString(", ")} associated with proxy block ${byteStringToString(edge.targetUuid)}, target of indirect call from block ${block.label}") @@ -408,7 +408,7 @@ class GTIRBToIR(mods: Seq[Module], parserMap: immutable.Map[String, Array[Array[ proc } removePCAssign(block) - (Some(DirectCall(target)), Halt()) + (Some(DirectCall(target)), Unreachable()) } } else if (uuidToBlock.contains(edge.targetUuid)) { // resolved indirect jump @@ -428,7 +428,7 @@ class GTIRBToIR(mods: Seq[Module], parserMap: immutable.Map[String, Array[Array[ val jump = if (procedure == targetProc) { (None, GoTo(mutable.Set(uuidToBlock(edge.targetUuid)))) } else { - (Some(DirectCall(targetProc)), Halt()) + (Some(DirectCall(targetProc)), Unreachable()) } removePCAssign(block) jump @@ -450,7 +450,7 @@ class GTIRBToIR(mods: Seq[Module], parserMap: immutable.Map[String, Array[Array[ // probably doesn't actually happen in practice since it seems to be after brk instructions? val targetProc = entranceUUIDtoProcedure(edge.targetUuid) // assuming fallthrough won't fall through to start of own procedure - (Some(DirectCall(targetProc)), Halt()) + (Some(DirectCall(targetProc)), Unreachable()) } else if (uuidToBlock.contains(edge.targetUuid)) { val target = uuidToBlock(edge.targetUuid) (None, GoTo(mutable.Set(target))) @@ -463,7 +463,7 @@ class GTIRBToIR(mods: Seq[Module], parserMap: immutable.Map[String, Array[Array[ if (entranceUUIDtoProcedure.contains(edge.targetUuid)) { val target = entranceUUIDtoProcedure(edge.targetUuid) removePCAssign(block) - (Some(DirectCall(target)), Halt()) + (Some(DirectCall(target)), Unreachable()) } else { throw Exception(s"edge from ${block.label} to ${byteStringToString(edge.targetUuid)} does not point to a known procedure entrance") } diff --git a/src/main/scala/translating/ILtoIL.scala b/src/main/scala/translating/ILtoIL.scala index 99b64bc6e..856b18934 100644 --- a/src/main/scala/translating/ILtoIL.scala +++ b/src/main/scala/translating/ILtoIL.scala @@ -61,7 +61,12 @@ private class ILSerialiser extends ReadOnlyVisitor { } override def visitJump(node: Jump): Jump = { - node.acceptVisit(this) + node match { + case j: GoTo => program ++= s"goTo(${j.targets.map(_.label).mkString(", ")})" + case h: Unreachable => program ++= "halt" + case h: Return => program ++= "return" + } + node } @@ -77,7 +82,6 @@ private class ILSerialiser extends ReadOnlyVisitor { override def visitDirectCall(node: DirectCall): Statement = { program ++= "DirectCall(" program ++= procedureIdentifier(node.target) - program ++= ", " program ++= ")" // DirectCall node } @@ -95,7 +99,10 @@ private class ILSerialiser extends ReadOnlyVisitor { program ++= "Block(" + blockIdentifier(node) + ",\n" indentLevel += 1 program ++= getIndent() - program ++= "statements(\n" + program ++= "statements(" + if (node.statements.size > 0) { + program ++= "\n" + } indentLevel += 1 for (s <- node.statements) { @@ -105,8 +112,7 @@ private class ILSerialiser extends ReadOnlyVisitor { } indentLevel -= 1 program ++= getIndent() + "),\n" - program ++= getIndent() + "jumps(\n" - program ++= getIndent() + program ++= getIndent() + "jump(" visitJump(node.jump) program ++= ")\n" indentLevel -= 1 diff --git a/src/main/scala/translating/IRToBoogie.scala b/src/main/scala/translating/IRToBoogie.scala index e37f9ce92..6cf271de5 100644 --- a/src/main/scala/translating/IRToBoogie.scala +++ b/src/main/scala/translating/IRToBoogie.scala @@ -650,7 +650,7 @@ class IRToBoogie(var program: Program, var spec: Specification, var thread: Opti val jump = GoToCmd(g.targets.map(_.label).toSeq) conditionAssert :+ jump case r: Return => List(ReturnCmd) - case r: Halt => List(BAssume(FalseBLiteral)) + case r: Unreachable => List(BAssume(FalseBLiteral)) } def translate(j: Call): List[BCmd] = j match { diff --git a/src/main/scala/util/RunUtils.scala b/src/main/scala/util/RunUtils.scala index 059701ab2..aeaaa6916 100644 --- a/src/main/scala/util/RunUtils.scala +++ b/src/main/scala/util/RunUtils.scala @@ -13,7 +13,6 @@ import java.io.{BufferedWriter, FileWriter, IOException} import scala.jdk.CollectionConverters.* import analysis.solvers.* import analysis.* -import cfg_visualiser.Output import bap.* import ir.* import boogie.* @@ -28,7 +27,6 @@ import util.Logger import java.util.Base64 import spray.json.DefaultJsonProtocol.* import util.intrusive_list.IntrusiveList -import analysis.CfgCommandNode import cilvisitor._ import scala.annotation.tailrec @@ -51,7 +49,6 @@ case class IRContext( /** Stores the results of the static analyses. */ case class StaticAnalysisContext( - cfg: ProgramCfg, constPropResult: Map[CFGPosition, Map[Variable, FlatElement[BitVecLiteral]]], IRconstPropResult: Map[CFGPosition, Map[Variable, FlatElement[BitVecLiteral]]], memoryRegionResult: Map[CFGPosition, LiftedElement[Set[MemoryRegion]]], @@ -216,7 +213,7 @@ object IRTransform { Logger.info("[!] Stripping unreachable") val before = ctx.program.procedures.size - ctx.program.stripUnreachableFunctions(config.procedureTrimDepth) + transforms.stripUnreachableFunctions(ctx.program, config.procedureTrimDepth) Logger.info( s"[!] Removed ${before - ctx.program.procedures.size} functions (${ctx.program.procedures.size} remaining)" ) @@ -273,15 +270,12 @@ object StaticAnalysis { newLoops.foreach(l => Logger.info(s"Loop found: ${l.name}")) config.analysisDotPath.foreach { s => - val newCFG = ProgramCfgFactory().fromIR(IRProgram) - writeToFile(newCFG.toDot(x => x.toString, Output.dotIder), s"${s}_resolvedCFG-reducible.dot") + writeToFile(dotBlockGraph(IRProgram, IRProgram.map(b => b -> b.toString).toMap), s"${s}_graph-after-reduce-$iteration.dot") writeToFile(dotBlockGraph(IRProgram, IRProgram.filter(_.isInstanceOf[Block]).map(b => b -> b.toString).toMap), s"${s}_blockgraph-after-reduce-$iteration.dot") } val mergedSubroutines = subroutines ++ externalAddresses - val cfg = ProgramCfgFactory().fromIR(IRProgram) - val domain = computeDomain(IntraProcIRCursor, IRProgram.procedures) Logger.info("[!] Running ANR") @@ -320,11 +314,17 @@ object StaticAnalysis { Logger.info("[!] Running RegToMemAnalysisSolver") - val regionAccessesAnalysisSolver = RegionAccessesAnalysisSolver(cfg, constPropResult, reachingDefinitionsAnalysisResults) + val regionAccessesAnalysisSolver = RegionAccessesAnalysisSolver(IRProgram, constPropResult, reachingDefinitionsAnalysisResults) val regionAccessesAnalysisResults = regionAccessesAnalysisSolver.analyze() - config.analysisDotPath.foreach(s => writeToFile(cfg.toDot(Output.labeler(regionAccessesAnalysisResults, true), Output.dotIder), s"${s}_RegTo$iteration.dot")) - config.analysisResultsPath.foreach(s => writeToFile(printAnalysisResults(cfg, regionAccessesAnalysisResults, iteration), s"${s}_RegTo$iteration.txt")) +// config.analysisDotPath.foreach(s => writeToFile(cfg.toDot(Output.labeler(regionAccessesAnalysisResults, true), Output.dotIder), s"${s}_RegTo$iteration.dot")) + config.analysisResultsPath.foreach(s => writeToFile(printAnalysisResults(IRProgram, regionAccessesAnalysisResults), s"${s}_RegTo$iteration.txt")) + config.analysisDotPath.foreach(s => { + writeToFile( + toDot(IRProgram, IRProgram.filter(_.isInstanceOf[Command]).map(b => b -> regionAccessesAnalysisResults(b).toString).toMap), + s"${s}_RegTo$iteration.dot" + ) + }) Logger.info("[!] Running Constant Propagation with SSA") val constPropSolverWithSSA = ConstantPropagationSolverWithSSA(IRProgram, reachingDefinitionsAnalysisResults) @@ -365,10 +365,8 @@ object StaticAnalysis { mmm.logRegions(memoryRegionContents) // turn fake procedures into diamonds - transforms.addReturnBlocks(ctx.program, true) // add return to all blocks because IDE solver expects it Logger.info("[!] Running VSA") - val vsaSolver = - ValueSetAnalysisSolver(IRProgram, globalAddresses, externalAddresses, globalOffsets, subroutines, mmm, constPropResult) + val vsaSolver = ValueSetAnalysisSolver(IRProgram, globalAddresses, externalAddresses, globalOffsets, subroutines, mmm, constPropResult) val vsaResult: Map[CFGPosition, LiftedElement[Map[Variable | MemoryRegion, Set[Value]]]] = vsaSolver.analyze() Logger.info("[!] Running Interprocedural Live Variables Analysis") @@ -380,7 +378,6 @@ object StaticAnalysis { // val paramResults = Map[Procedure, Set[Variable]]() StaticAnalysisContext( - cfg = cfg, constPropResult = constPropResult, IRconstPropResult = newCPResult, memoryRegionResult = mraResult, @@ -394,34 +391,6 @@ object StaticAnalysis { ) } - /** Converts MapLattice of CfgNodes to a MapLattice from IRPosition. - * @param cfg - * The CFG - * @param result - * The analysis result MapLattice - * @tparam T - * The analysis result type. - * @return - * The new map analysis result. - */ - def convertAnalysisResults[T](cfg: ProgramCfg, result: Map[CfgNode, T]): Map[CFGPosition, T] = { - val results = mutable.HashMap[CFGPosition, T]() - result.foreach((node, res) => - node match { - case s: CfgStatementNode => results.addOne(s.data -> res) - case s: CfgFunctionEntryNode => results.addOne(s.data -> res) - case s: CfgJumpNode => results.addOne(s.data -> res) - case s: CfgCommandNode => results.addOne(s.data -> res) - case _ => () - } - ) - - results.toMap - } - - def printAnalysisResults[T](program: Program, cfg: ProgramCfg, result: Map[CfgNode, T]): String = { - printAnalysisResults(program, convertAnalysisResults(cfg, result)) - } def printAnalysisResults(prog: Program, result: Map[CFGPosition, _]): String = { val results = mutable.ArrayBuffer[String]() @@ -465,86 +434,6 @@ object StaticAnalysis { results.mkString(System.lineSeparator()) } - def printAnalysisResults(cfg: ProgramCfg, result: Map[CfgNode, _], iteration: Int): String = { - val functionEntries = cfg.nodes.collect { case n: CfgFunctionEntryNode => n }.toSeq.sortBy(_.data.name) - val s = StringBuilder() - s.append(System.lineSeparator()) - for (f <- functionEntries) { - val stack: mutable.Stack[CfgNode] = mutable.Stack() - val visited: mutable.Set[CfgNode] = mutable.Set() - stack.push(f) - var previousBlock: String = "" - var isEntryNode = false - while (stack.nonEmpty) { - val next = stack.pop() - if (!visited.contains(next)) { - visited.add(next) - next.match { - case c: CfgCommandNode => - if (c.block.label != previousBlock) { - printBlock(c) - } - c match { - case _: CfgStatementNode => s.append(" ") - case _ => () - } - printNode(c) - previousBlock = c.block.label - isEntryNode = false - case c: CfgFunctionEntryNode => - printNode(c) - isEntryNode = true - case c: CfgCallNoReturnNode => - s.append(System.lineSeparator()) - isEntryNode = false - case _ => isEntryNode = false - } - val successors = next.succIntra - if (successors.size > 1) { - val successorsCmd = successors.collect { case c: CfgCommandNode => c }.toSeq.sortBy(_.data.toString) - printGoTo(successorsCmd) - for (s <- successorsCmd) { - if (!visited.contains(s)) { - stack.push(s) - } - } - } else if (successors.size == 1) { - val successor = successors.head - if (!visited.contains(successor)) { - stack.push(successor) - } - successor.match { - case c: CfgCommandNode if (c.block.label != previousBlock) && (!isEntryNode) => printGoTo(Seq(c)) - case _ => - } - } - } - } - s.append(System.lineSeparator()) - } - - def printNode(node: CfgNode): Unit = { - s.append(node) - s.append(" :: ") - s.append(result(node)) - s.append(System.lineSeparator()) - } - - def printGoTo(nodes: Seq[CfgCommandNode]): Unit = { - s.append("[GoTo] ") - s.append(nodes.map(_.block.label).mkString(", ")) - s.append(System.lineSeparator()) - s.append(System.lineSeparator()) - } - - def printBlock(node: CfgCommandNode): Unit = { - s.append("[Block] ") - s.append(node.block.label) - s.append(System.lineSeparator()) - } - - s.toString - } } @@ -608,7 +497,7 @@ object RunUtils { val result = StaticAnalysis.analyse(ctx, config, iteration) analysisResult.append(result) Logger.info("[!] Replacing Indirect Calls") - modified = transforms.resolveIndirectCallsUsingPointsTo(result.cfg, + modified = transforms.resolveIndirectCallsUsingPointsTo( result.steensgaardResults, result.memoryRegionContents, result.reachingDefs, @@ -626,11 +515,6 @@ object RunUtils { transforms.splitThreads(ctx.program, analysisResult.last.steensgaardResults, analysisResult.last.memoryRegionContents, analysisResult.last.reachingDefs) } - config.analysisDotPath.foreach { s => - val newCFG = analysisResult.last.cfg - writeToFile(newCFG.toDot(x => x.toString, Output.dotIder), s"${s}_resolvedCFG.dot") - } - assert(invariant.singleCallBlockEnd(ctx.program)) Logger.info(s"[!] Finished indirect call resolution after $iteration iterations") analysisResult.last diff --git a/src/test/scala/IndirectCallsTests.scala b/src/test/scala/IndirectCallsTests.scala index 15c635fc9..fbc5a4362 100644 --- a/src/test/scala/IndirectCallsTests.scala +++ b/src/test/scala/IndirectCallsTests.scala @@ -84,7 +84,6 @@ class IndirectCallsTests extends AnyFunSuite with OneInstancePerTest with Before case _ => } } - println(result.ir.program) assert(expectedCallTransform.isEmpty) } @@ -122,7 +121,6 @@ class IndirectCallsTests extends AnyFunSuite with OneInstancePerTest with Before case _ => } } - println(result.ir.program) assert(expectedCallTransform.isEmpty) } @@ -160,7 +158,6 @@ class IndirectCallsTests extends AnyFunSuite with OneInstancePerTest with Before case _ => } } - println(result.ir.program) assert(expectedCallTransform.isEmpty) } @@ -204,7 +201,6 @@ class IndirectCallsTests extends AnyFunSuite with OneInstancePerTest with Before case _ => } } - println(result.ir.program) assert(expectedCallTransform.isEmpty) } @@ -248,7 +244,6 @@ class IndirectCallsTests extends AnyFunSuite with OneInstancePerTest with Before case _ => } } - println(result.ir.program) assert(expectedCallTransform.isEmpty) } @@ -291,7 +286,6 @@ class IndirectCallsTests extends AnyFunSuite with OneInstancePerTest with Before case _ => } } - println(result.ir.program) assert(expectedCallTransform.isEmpty) } @@ -335,7 +329,6 @@ class IndirectCallsTests extends AnyFunSuite with OneInstancePerTest with Before case _ => } } - println(result.ir.program) assert(expectedCallTransform.isEmpty) } @@ -363,7 +356,6 @@ class IndirectCallsTests extends AnyFunSuite with OneInstancePerTest with Before "l000004f3set_seven" -> ("set_seven", "R0") ) - println("prev " + result.ir.program) // Traverse the statements in the main function result.ir.program.mainProcedure.blocks.foreach { block => @@ -375,7 +367,6 @@ class IndirectCallsTests extends AnyFunSuite with OneInstancePerTest with Before case _ => } } - println(result.ir.program) assert(expectedCallTransform.isEmpty) } @@ -414,7 +405,6 @@ class IndirectCallsTests extends AnyFunSuite with OneInstancePerTest with Before case _ => } } - println(result.ir.program) assert(expectedCallTransform.isEmpty) } @@ -453,7 +443,6 @@ class IndirectCallsTests extends AnyFunSuite with OneInstancePerTest with Before } // Traverse the statements in the main function - println(result.ir.program) assert(expectedCallTransform.isEmpty) } diff --git a/src/test/scala/LiveVarsAnalysisTests.scala b/src/test/scala/LiveVarsAnalysisTests.scala index e1b142001..881ad61fc 100644 --- a/src/test/scala/LiveVarsAnalysisTests.scala +++ b/src/test/scala/LiveVarsAnalysisTests.scala @@ -115,10 +115,12 @@ class LiveVarsAnalysisTests extends AnyFunSuite, TestUtil { var program = prog( proc("main", block("main_first_call", - directCall("wrapper1"), goto("main_second_call") + directCall("wrapper1"), + goto("main_second_call") ), block("main_second_call", - directCall("wrapper2"), goto("main_return") + directCall("wrapper2"), + goto("main_return") ), block("main_return", ret) ), @@ -128,10 +130,12 @@ class LiveVarsAnalysisTests extends AnyFunSuite, TestUtil { proc("wrapper1", block("wrapper1_first_call", Assign(R1, constant1), - directCall("callee"), goto("wrapper1_second_call") + directCall("callee"), + goto("wrapper1_second_call") ), block("wrapper1_second_call", - directCall("callee2"), goto("wrapper1_return")), + directCall("callee2"), + goto("wrapper1_return")), block("wrapper1_return", ret) ), proc("wrapper2", @@ -349,12 +353,11 @@ class LiveVarsAnalysisTests extends AnyFunSuite, TestUtil { val blocks = result.ir.program.blocks // main has no parameters, get_two has three and a return - assert(analysisResults(blocks("lmain").jump) == Map(R29 -> TwoElementTop, R31 -> TwoElementTop)) - assert(analysisResults(blocks("l000003ec").jump) == Map(R0 -> TwoElementTop, R31 -> TwoElementTop)) // get_two aftercall - assert(analysisResults(blocks("l00000430").jump) == Map(R31 -> TwoElementTop)) // printf aftercall - assert(analysisResults(blocks("main_basil_return").jump) == Map(R30 -> TwoElementTop)) - assert(analysisResults(blocks("lget_two").jump) == Map(R0 -> TwoElementTop, R1 -> TwoElementTop, R2 -> TwoElementTop, R30 -> TwoElementTop, R31 -> TwoElementTop)) - assert(analysisResults(blocks("get_two_basil_return").jump) == Map(R0 -> TwoElementTop, R30 -> TwoElementTop, R31 -> TwoElementTop)) + assert(analysisResults(blocks("lmain")) == Map(R29 -> TwoElementTop, R31 -> TwoElementTop, R30 -> TwoElementTop)) + assert(analysisResults(blocks("l000003ec")) == Map(R0 -> TwoElementTop, R31 -> TwoElementTop)) // get_two aftercall + assert(analysisResults(blocks("l00000430")) == Map(R31 -> TwoElementTop)) // printf aftercall + assert(analysisResults(blocks("lget_two")) == Map(R0 -> TwoElementTop, R1 -> TwoElementTop, R2 -> TwoElementTop, R31 -> TwoElementTop)) + assert(analysisResults(blocks("get_two_basil_return")) == Map(R0 -> TwoElementTop, R31 -> TwoElementTop)) } test("ifbranches") { diff --git a/src/test/scala/MemoryRegionAnalysisMiscTest.scala b/src/test/scala/MemoryRegionAnalysisMiscTest.scala index 2844548d1..a7d7e4e73 100644 --- a/src/test/scala/MemoryRegionAnalysisMiscTest.scala +++ b/src/test/scala/MemoryRegionAnalysisMiscTest.scala @@ -1,4 +1,4 @@ -import analysis.{CfgNode, LiftedElement, MemoryRegion} +import analysis.{LiftedElement, MemoryRegion} import org.scalatest.Inside.inside import org.scalatest.* import org.scalatest.funsuite.* diff --git a/src/test/scala/ir/IRTest.scala b/src/test/scala/ir/IRTest.scala index 7421c2a28..ac6df38cf 100644 --- a/src/test/scala/ir/IRTest.scala +++ b/src/test/scala/ir/IRTest.scala @@ -223,7 +223,7 @@ class IRTest extends AnyFunSuite { Assign(R0, bv64(22)), Assign(R0, bv64(22)), directCall("main"), - halt + unreachable ).resolve(p) val b2 = block("newblock1", Assign(R0, bv64(22)), @@ -249,7 +249,7 @@ class IRTest extends AnyFunSuite { val b3 = block("newblock3", Assign(R0, bv64(22)), directCall("called"), - halt + unreachable ).resolve(p) assert(b3.calls.toSet == Set(p.procs("called"))) diff --git a/src/test/scala/ir/InterpreterTests.scala b/src/test/scala/ir/InterpreterTests.scala index 47cf65e3c..57f9739fe 100644 --- a/src/test/scala/ir/InterpreterTests.scala +++ b/src/test/scala/ir/InterpreterTests.scala @@ -30,7 +30,7 @@ class InterpreterTests extends AnyFunSuite with BeforeAndAfter { var IRProgram = IRTranslator.translate IRProgram = ExternalRemover(externalFunctions.map(e => e.name)).visitProgram(IRProgram) IRProgram = Renamer(Set("free")).visitProgram(IRProgram) - IRProgram.stripUnreachableFunctions() + transforms.stripUnreachableFunctions(IRProgram) val stackIdentification = StackSubstituter() stackIdentification.visitProgram(IRProgram) IRProgram.setModifies(Map()) From bd6b2adb61b9752a9157355ddef98e84b2ed8e31 Mon Sep 17 00:00:00 2001 From: Alistair Michael Date: Mon, 19 Aug 2024 12:15:10 +1000 Subject: [PATCH 04/62] update docs --- docs/basil-ir.md | 45 +++++++++++++++++++++++++++++++++++++++------ 1 file changed, 39 insertions(+), 6 deletions(-) diff --git a/docs/basil-ir.md b/docs/basil-ir.md index fe30af78d..4c152db8d 100644 --- a/docs/basil-ir.md +++ b/docs/basil-ir.md @@ -3,6 +3,7 @@ BASIL IR is the intermediate representation used during static analysis. This is on contrast to Boogie IR which is used for specification annotation, and output to textual boogie syntax that can be run through the Boogie verifier. + The grammar is described below, note that the IR is a data-structure, without a concrete textual representation so the below grammar only represents the structure. We omit the full description of the expression language because it is relatively standard. @@ -13,13 +14,16 @@ The IR has a completely standard simple type system that is enforced at construc Program ::=&~ Procedure* \\ Procedure ::=&~ (name: ProcID) (entryBlock: Block) (returnBlock: Block) (blocks: Block*) \\ &~ \text{Where }entryBlock, returnBlock \in blocks \\ -Block ::=&~ BlockID \; (Statement*)\; Jump \; (fallthrough: (GoTo | None))\\ - &~ \text{Where $fallthough$ may be $GoTo$ IF $Jump$ is $Call$} \\ +Block_1 ::=&~ BlockID \; Statement*\; Call? \; Jump \; \\ +Block_2 ::=&~ BlockID \; (Statement | Call)*\; Jump \; \\ +\\ +&~ Block = Block_1 \text{ is a structural invariant that holds during all the early analysis/transform stages} +\\ Statement ::=&~ MemoryAssign ~|~ LocalAssign ~|~ Assume ~|~ Assert ~|~ NOP \\ ProcID ::=&~ String \\ BlockID ::=&~ String \\ \\ -Jump ::=&~ Call ~|~ GoTo \\ +Jump ::=&~ GoTo ~|~ Unreachable ~|~ Return \\ GoTo ::=&~ \text{goto } BlockID* \\ Call ::=&~ DirectCall ~|~ IndirectCall \\ DirectCall ::=&~ \text{call } ProcID \\ @@ -46,6 +50,33 @@ Endian ::=&~ BigEndian ~|~ LittleEndian \\ \end{align*} ``` +- The `GoTo` jump is a multi-target jump reprsenting non-deterministic choice between its targets. + Conditional structures are represented by these with a guard (an assume statement) beginning each target. +- The `Unreachable` jump is used to signify the absence of successors, it has the semantics of `assume false`. +- The `Return` jump passes control to the calling function, often this is over-approximated to all functions which call the statement's parent procedure. + +## Translation Phases + +#### IR With Returns + +- Immediately after loading the IR return statements may appear in any block, or may be represented by indirect calls. + The transform pass below replaces all calls to the link register (R30) with return statements. + In the future, more proof is required to implement this soundly. + +``` +cilvisitor.visit_prog(transforms.ReplaceReturns(), ctx.program) +transforms.addReturnBlocks(ctx.program, true) // add return to all blocks because IDE solver expects it +cilvisitor.visit_prog(transforms.ConvertSingleReturn(), ctx.program) +``` + +This ensures that all returning, non-stub procedures have exactly one return statement residing in their `returnBlock`. + +#### Calls appear only as the last statement in a block + +- The structure of the IR allows a call may appear anywhere in the block but for all the analysis passes we hold the invariant that it + only appears as the last statement. This is checked with the function `singleCallBlockEnd(p: Program)`. + And it means for any call statement `c` we may `assert(c.parent.statements.lastOption.contains(c))`. + ## Interaction with BASIL IR ### Constructing Programs in Code @@ -62,10 +93,12 @@ var program: Program = prog( block("first_call", Assign(R0, bv64(1), None) Assign(R1, bv64(1), None) - directCall("callee1", Some("second_call")) + directCall("callee1"), + goto("second_call")) ), block("second_call", - directCall("callee2", Some("returnBlock")) + directCall("callee2"), + goto("returnBlock") ), block("returnBlock", ret @@ -82,7 +115,7 @@ program ::= prog ( procedure+ ) procedure ::= proc (procname, block+) block ::= block(blocklabel, statement+, jump) statement ::= -jump ::= call_s | goto_s | ret +jump ::= goto_s | ret | unreachable call_s ::= directCall (procedurename, None | Some(blocklabel)) // target, fallthrough goto_s ::= goto(blocklabel+) // targets procname ::= String From 96b902837677228b6b2c8eb31bc80424b33bbf2a Mon Sep 17 00:00:00 2001 From: Alistair Michael Date: Mon, 19 Aug 2024 16:22:51 +1000 Subject: [PATCH 05/62] pe based expr evaluator --- src/main/scala/analysis/Lattice.scala | 4 +- src/main/scala/analysis/UtilMethods.scala | 66 +---- src/main/scala/ir/Expr.scala | 9 +- src/main/scala/ir/Program.scala | 2 +- .../eval/Bitvector.scala} | 36 +-- .../scala/ir/{ => eval}/Interpreter.scala | 196 ++++----------- src/main/scala/ir/eval/expr.scala | 228 ++++++++++++++++++ 7 files changed, 306 insertions(+), 235 deletions(-) rename src/main/scala/{analysis/BitVectorEval.scala => ir/eval/Bitvector.scala} (90%) rename src/main/scala/ir/{ => eval}/Interpreter.scala (54%) create mode 100644 src/main/scala/ir/eval/expr.scala diff --git a/src/main/scala/analysis/Lattice.scala b/src/main/scala/analysis/Lattice.scala index 0ef98020f..cb499f8c3 100644 --- a/src/main/scala/analysis/Lattice.scala +++ b/src/main/scala/analysis/Lattice.scala @@ -1,7 +1,7 @@ package analysis import ir._ -import analysis.BitVectorEval +import ir.eval.BitVectorEval import util.Logger /** Basic lattice @@ -244,4 +244,4 @@ class ConstantPropagationLatticeWithSSA extends PowersetLattice[BitVecLiteral] { apply(BitVectorEval.boogie_extract(high, low, _: BitVecLiteral), a) def concat(a: Set[BitVecLiteral], b: Set[BitVecLiteral]): Set[BitVecLiteral] = apply(BitVectorEval.smt_concat, a, b) -} \ No newline at end of file +} diff --git a/src/main/scala/analysis/UtilMethods.scala b/src/main/scala/analysis/UtilMethods.scala index 2d6c54090..0b7925894 100644 --- a/src/main/scala/analysis/UtilMethods.scala +++ b/src/main/scala/analysis/UtilMethods.scala @@ -1,6 +1,7 @@ package analysis import ir.* import util.Logger +import ir.eval.BitVectorEval /** Evaluate an expression in a hope of finding a global variable. * @@ -12,63 +13,14 @@ import util.Logger * The evaluated expression (e.g. 0x69632) */ def evaluateExpression(exp: Expr, constantPropResult: Map[Variable, FlatElement[BitVecLiteral]]): Option[BitVecLiteral] = { - Logger.debug(s"evaluateExpression: $exp") - exp match { - case binOp: BinaryExpr => - val lhs = evaluateExpression(binOp.arg1, constantPropResult) - val rhs = evaluateExpression(binOp.arg2, constantPropResult) - - (lhs, rhs) match { - case (Some(l: BitVecLiteral), Some(r: BitVecLiteral)) => - val result = binOp.op match { - case BVADD => BitVectorEval.smt_bvadd(l, r) - case BVSUB => BitVectorEval.smt_bvsub(l, r) - case BVMUL => BitVectorEval.smt_bvmul(l, r) - case BVUDIV => BitVectorEval.smt_bvudiv(l, r) - case BVSDIV => BitVectorEval.smt_bvsdiv(l, r) - case BVSREM => BitVectorEval.smt_bvsrem(l, r) - case BVUREM => BitVectorEval.smt_bvurem(l, r) - case BVSMOD => BitVectorEval.smt_bvsmod(l, r) - case BVAND => BitVectorEval.smt_bvand(l, r) - case BVOR => BitVectorEval.smt_bvxor(l, r) - case BVXOR => BitVectorEval.smt_bvxor(l, r) - case BVNAND => BitVectorEval.smt_bvnand(l, r) - case BVNOR => BitVectorEval.smt_bvnor(l, r) - case BVXNOR => BitVectorEval.smt_bvxnor(l, r) - case BVSHL => BitVectorEval.smt_bvshl(l, r) - case BVLSHR => BitVectorEval.smt_bvlshr(l, r) - case BVASHR => BitVectorEval.smt_bvashr(l, r) - case BVCOMP => BitVectorEval.smt_bvcomp(l, r) - case BVCONCAT => BitVectorEval.smt_concat(l, r) - case x => throw RuntimeException("Binary operation support not implemented: " + binOp.op) - } - Some(result) - case _ => None - } - case extend: ZeroExtend => - evaluateExpression(extend.body, constantPropResult) match { - case Some(b: BitVecLiteral) => Some(BitVectorEval.smt_zero_extend(extend.extension, b)) - case None => None - } - case extend: SignExtend => - evaluateExpression(extend.body, constantPropResult) match { - case Some(b: BitVecLiteral) => Some(BitVectorEval.smt_sign_extend(extend.extension, b)) - case None => None - } - case e: Extract => - evaluateExpression(e.body, constantPropResult) match { - case Some(b: BitVecLiteral) => Some(BitVectorEval.boogie_extract(e.end, e.start, b)) - case None => None - } - case variable: Variable => - constantPropResult(variable) match { + def value(v: Variable) = constantPropResult(v) match { case FlatEl(value) => Some(value) - case Top => None - case Bottom => None - } - case b: BitVecLiteral => Some(b) - case _ => //throw new RuntimeException("ERROR: CASE NOT HANDLED: " + exp + "\n") - None + case _ => None + } + + ir.eval.evalBVExpr(exp, value) match { + case Right(v) => Some(v) + case Left(_) => None } } @@ -170,4 +122,4 @@ def unwrapExpr(expr: Expr): Set[Expr] = { case _ => } buffers -} \ No newline at end of file +} diff --git a/src/main/scala/ir/Expr.scala b/src/main/scala/ir/Expr.scala index 6a7c862cd..4b10fc092 100644 --- a/src/main/scala/ir/Expr.scala +++ b/src/main/scala/ir/Expr.scala @@ -1,5 +1,6 @@ package ir + import boogie._ import scala.collection.mutable @@ -28,18 +29,22 @@ sealed trait Literal extends Expr { override def acceptVisit(visitor: Visitor): Literal = visitor.visitLiteral(this) } -sealed trait BoolLit extends Literal +sealed trait BoolLit extends Literal { + def value: Boolean +} case object TrueLiteral extends BoolLit { override def toBoogie: BoolBLiteral = TrueBLiteral override def getType: IRType = BoolType override def toString: String = "true" + override def value = true } case object FalseLiteral extends BoolLit { override def toBoogie: BoolBLiteral = FalseBLiteral override def getType: IRType = BoolType override def toString: String = "false" + override def value = false } case class BitVecLiteral(value: BigInt, size: Int) extends Literal { @@ -391,4 +396,4 @@ case class StackMemory(override val name: String, override val addressSize: Int, // A non-stack region of memory, which is shared between threads case class SharedMemory(override val name: String, override val addressSize: Int, override val valueSize: Int) extends Memory { override def acceptVisit(visitor: Visitor): Memory = visitor.visitSharedMemory(this) -} \ No newline at end of file +} diff --git a/src/main/scala/ir/Program.scala b/src/main/scala/ir/Program.scala index 54916b27c..0a19add9c 100644 --- a/src/main/scala/ir/Program.scala +++ b/src/main/scala/ir/Program.scala @@ -3,9 +3,9 @@ package ir import scala.collection.mutable.ArrayBuffer import scala.collection.{IterableOnceExtensionMethods, View, immutable, mutable} import boogie.* -import analysis.BitVectorEval import util.intrusive_list.* import translating.serialiseIL +import eval.BitVectorEval class Program(var procedures: ArrayBuffer[Procedure], var mainProcedure: Procedure, diff --git a/src/main/scala/analysis/BitVectorEval.scala b/src/main/scala/ir/eval/Bitvector.scala similarity index 90% rename from src/main/scala/analysis/BitVectorEval.scala rename to src/main/scala/ir/eval/Bitvector.scala index 0b5847506..b3e426010 100644 --- a/src/main/scala/analysis/BitVectorEval.scala +++ b/src/main/scala/ir/eval/Bitvector.scala @@ -1,6 +1,6 @@ -package analysis +package ir.eval + import ir._ -import analysis.BitVectorEval.* import scala.math.pow @@ -157,15 +157,15 @@ object BitVectorEval { /** (bvneq (_ BitVec m) (_ BitVec m)) * - not equal too */ - def smt_bveq(s: BitVecLiteral, t: BitVecLiteral): BoolLit = { - bool2BoolLit(s == t) + def smt_bveq(s: BitVecLiteral, t: BitVecLiteral): Boolean = { + s == t } /** (bvneq (_ BitVec m) (_ BitVec m)) * - not equal too */ - def smt_bvneq(s: BitVecLiteral, t: BitVecLiteral): BoolLit = { - bool2BoolLit(s != t) + def smt_bvneq(s: BitVecLiteral, t: BitVecLiteral): Boolean = { + s != t } /** (bvshl (_ BitVec m) (_ BitVec m) (_ BitVec m)) @@ -259,55 +259,55 @@ object BitVectorEval { /** (bvult (_ BitVec m) (_ BitVec m) Bool) * - binary predicate for unsigned less-than */ - def smt_bvult(s: BitVecLiteral, t: BitVecLiteral): BoolLit = { - bool2BoolLit(bv2nat(s) < bv2nat(t)) + def smt_bvult(s: BitVecLiteral, t: BitVecLiteral): Boolean = { + bv2nat(s) < bv2nat(t) } /** (bvule (_ BitVec m) (_ BitVec m) Bool) * - binary predicate for unsigned less than or equal */ - def smt_bvule(s: BitVecLiteral, t: BitVecLiteral): BoolLit = { - bool2BoolLit(bv2nat(s) <= bv2nat(t)) + def smt_bvule(s: BitVecLiteral, t: BitVecLiteral): Boolean = { + bv2nat(s) <= bv2nat(t) } /** (bvugt (_ BitVec m) (_ BitVec m) Bool) * - binary predicate for unsigned greater than */ - def smt_bvugt(s: BitVecLiteral, t: BitVecLiteral): BoolLit = { + def smt_bvugt(s: BitVecLiteral, t: BitVecLiteral): Boolean = { smt_bvult(t, s) } /** (bvuge (_ BitVec m) (_ BitVec m) Bool) * - binary predicate for unsigned greater than or equal */ - def smt_bvuge(s: BitVecLiteral, t: BitVecLiteral): BoolLit = smt_bvule(t, s) + def smt_bvuge(s: BitVecLiteral, t: BitVecLiteral): Boolean = smt_bvule(t, s) /** (bvslt (_ BitVec m) (_ BitVec m) Bool) * - binary predicate for signed less than */ - def smt_bvslt(s: BitVecLiteral, t: BitVecLiteral): BoolLit = { + def smt_bvslt(s: BitVecLiteral, t: BitVecLiteral): Boolean = { val sNeg = isNegative(s) val tNeg = isNegative(t) - bool2BoolLit((sNeg && !tNeg) || ((sNeg == tNeg) && (smt_bvult(s, t) == TrueLiteral))) + (sNeg && !tNeg) || ((sNeg == tNeg) && (smt_bvult(s, t))) } /** (bvsle (_ BitVec m) (_ BitVec m) Bool) * - binary predicate for signed less than or equal */ - def smt_bvsle(s: BitVecLiteral, t: BitVecLiteral): BoolLit = + def smt_bvsle(s: BitVecLiteral, t: BitVecLiteral): Boolean = val sNeg = isNegative(s) val tNeg = isNegative(t) - bool2BoolLit((sNeg && !tNeg) || ((sNeg == tNeg) && (smt_bvule(s, t) == TrueLiteral))) + (sNeg && !tNeg) || ((sNeg == tNeg) && (smt_bvule(s, t))) /** (bvsgt (_ BitVec m) (_ BitVec m) Bool) * - binary predicate for signed greater than */ - def smt_bvsgt(s: BitVecLiteral, t: BitVecLiteral): BoolLit = smt_bvslt(t, s) + def smt_bvsgt(s: BitVecLiteral, t: BitVecLiteral): Boolean = smt_bvslt(t, s) /** (bvsge (_ BitVec m) (_ BitVec m) Bool) * - binary predicate for signed greater than or equal */ - def smt_bvsge(s: BitVecLiteral, t: BitVecLiteral): BoolLit = smt_bvsle(t, s) + def smt_bvsge(s: BitVecLiteral, t: BitVecLiteral): Boolean = smt_bvsle(t, s) def smt_bvashr(s: BitVecLiteral, t: BitVecLiteral): BitVecLiteral = if (!isNegative(s)) { diff --git a/src/main/scala/ir/Interpreter.scala b/src/main/scala/ir/eval/Interpreter.scala similarity index 54% rename from src/main/scala/ir/Interpreter.scala rename to src/main/scala/ir/eval/Interpreter.scala index 0430ef66a..b2f45ed1a 100644 --- a/src/main/scala/ir/Interpreter.scala +++ b/src/main/scala/ir/eval/Interpreter.scala @@ -1,10 +1,29 @@ package ir -import analysis.BitVectorEval.* +import ir.eval.BitVectorEval.* import util.Logger import scala.collection.mutable import scala.util.control.Breaks.{break, breakable} + +// enum Asssumption: +// case Assume(x: Expr) +// case Jump(choice: Block) +// case Not(a: Assumption) +// +// +// class State() { +// // stack of assumptions +// var assumptions: mutable.Stack[Assumption] = mutable.Stack() +// var memory: mutable.Map[Memory, Map[BigInt, BigInt]] = Map() +// var bvValues: mutable.Map[Variable, BigInt] = Map() +// var intValues: mutable.Map[Variable, BigInt] = Map() +// +// var nextCmd: Option[Command] = None +// var returnCmd: mutable.Stack[Command] = mutable.Stack() +// } + + class Interpreter() { val regs: mutable.Map[Variable, BitVecLiteral] = mutable.Map() val mems: mutable.Map[Int, BitVecLiteral] = mutable.Map() @@ -15,155 +34,28 @@ class Interpreter() { private val returnCmd: mutable.Stack[Command] = mutable.Stack() def eval(exp: Expr, env: mutable.Map[Variable, BitVecLiteral]): BitVecLiteral = { - exp match { - case id: Variable => - env.get(id) match { - case Some(value) => - Logger.debug(s"\t${id.name} == 0x${value.value.toString(16)}[u${value.size}]") - value - case _ => throw new Exception(s"$id not found in env") - } - - case n: Literal => - n match { - case bv: BitVecLiteral => - Logger.debug(s"\tBitVecLiteral(0x${bv.value.toString(16)}[u${bv.size}])") - bv - case _ => ??? - } - - case ze: ZeroExtend => - Logger.debug(s"\t$ze") - smt_zero_extend(ze.extension, eval(ze.body, env)) - - case se: SignExtend => - Logger.debug(s"\t$se") - smt_sign_extend(se.extension, eval(se.body, env)) - - case e: Extract => - Logger.debug(s"\tExtract($e, ${e.start}, ${e.end})") - boogie_extract(e.end, e.start, eval(e.body, env)) - - case r: Repeat => - Logger.debug(s"\t$r") - ??? // TODO - - case bin: BinaryExpr => - val left = eval(bin.arg1, env) - val right = eval(bin.arg2, env) - Logger.debug( - s"\tBinaryExpr(0x${left.value.toString(16)}[u${left.size}] ${bin.op} 0x${right.value.toString(16)}[u${right.size}])" - ) - bin.op match { - case BVAND => smt_bvand(left, right) - case BVOR => smt_bvor(left, right) - case BVADD => smt_bvadd(left, right) - case BVMUL => smt_bvmul(left, right) - case BVUDIV => smt_bvudiv(left, right) - case BVUREM => smt_bvurem(left, right) - case BVSHL => smt_bvshl(left, right) - case BVLSHR => smt_bvlshr(left, right) - case BVNAND => smt_bvnand(left, right) - case BVNOR => smt_bvnor(left, right) - case BVXOR => smt_bvxor(left, right) - case BVXNOR => smt_bvxnor(left, right) - case BVCOMP => smt_bvcomp(left, right) - case BVSUB => smt_bvsub(left, right) - case BVSDIV => smt_bvsdiv(left, right) - case BVSREM => smt_bvsrem(left, right) - case BVSMOD => smt_bvsmod(left, right) - case BVASHR => smt_bvashr(left, right) - case BVCONCAT => smt_concat(left, right) - case _ => ??? - } - - case un: UnaryExpr => - val arg = eval(un.arg, env) - Logger.debug(s"\tUnaryExpr($un)") - un.op match { - case BVNEG => smt_bvneg(arg) - case BVNOT => smt_bvnot(arg) - } - - case ml: MemoryLoad => - Logger.debug(s"\t$ml") - val index: Int = eval(ml.index, env).value.toInt - getMemory(index, ml.size, ml.endian, mems) + def load(m: Memory, index: Expr, endian: Endian, size: Int) = { + val idx = evalInt(index, env).toInt + Some(getMemory(idx, size, endian, mems)) + } - case u: UninterpretedFunction => - Logger.debug(s"\t$u") - ??? + ir.eval.evalBVExpr(exp, x => env.get(x), load) match { + case Right(b) => b + case Left(e) => throw Exception(s"Failed to evaluate expr: residual $exp") } } - def evalBool(exp: Expr, env: mutable.Map[Variable, BitVecLiteral]): BoolLit = { - exp match { - case n: BoolLit => n - case bin: BinaryExpr => - bin.op match { - case b: BoolBinOp => - val arg1 = evalBool(bin.arg1, env) - val arg2 = evalBool(bin.arg2, env) - b match { - case BoolEQ => - if (arg1 == arg2) { - TrueLiteral - } else { - FalseLiteral - } - case BoolNEQ => - if (arg1 != arg2) { - TrueLiteral - } else { - FalseLiteral - } - case BoolAND => - (arg1, arg2) match { - case (TrueLiteral, TrueLiteral) => TrueLiteral - case _ => FalseLiteral - } - case BoolOR => - (arg1, arg2) match { - case (FalseLiteral, FalseLiteral) => FalseLiteral - case _ => TrueLiteral - } - case BoolIMPLIES => - (arg1, arg2) match { - case (TrueLiteral, FalseLiteral) => FalseLiteral - case _ => TrueLiteral - } - case BoolEQUIV => - if (arg1 == arg2) { - TrueLiteral - } else { - FalseLiteral - } - } - case b: BVBinOp => - val left = eval(bin.arg1, env) - val right = eval(bin.arg2, env) - b match { - case BVULT => smt_bvult(left, right) - case BVULE => smt_bvule(left, right) - case BVUGT => smt_bvugt(left, right) - case BVUGE => smt_bvuge(left, right) - case BVSLT => smt_bvslt(left, right) - case BVSLE => smt_bvsle(left, right) - case BVSGT => smt_bvsgt(left, right) - case BVSGE => smt_bvsge(left, right) - case BVEQ => smt_bveq(left, right) - case BVNEQ => smt_bvneq(left, right) - case _ => ??? - } - case _ => ??? - } + def evalBool(exp: Expr, env: mutable.Map[Variable, BitVecLiteral]): Boolean = { + ir.eval.evalLogExpr(exp, x => env.get(x)) match { + case Right(b) => b + case Left(e) => throw Exception(s"Failed to evaluate expr: residual $e") + } + } - case un: UnaryExpr => - un.op match { - case BoolNOT => if evalBool(un.arg, env) == TrueLiteral then FalseLiteral else TrueLiteral - case _ => ??? - } - case _ => ??? + def evalInt(exp: Expr, env: mutable.Map[Variable, BitVecLiteral]): BigInt = { + ir.eval.evalIntExpr(exp, x => env.get(x)) match { + case Right(b) => b + case Left(e) => throw Exception(s"Failed to evaluate expr: residual $e") } } @@ -234,11 +126,9 @@ class Interpreter() { for (g <- gt.targets) { val condition: Option[Expr] = g.statements.headOption.collect { case a: Assume => a.body } condition match { - case Some(e) => evalBool(e, regs) match { - case TrueLiteral => + case Some(e) => if (evalBool(e, regs)) { nextCmd = Some(g.statements.headOption.getOrElse(g.jump)) break - case _ => } case None => nextCmd = Some(g.statements.headOption.getOrElse(g.jump)) @@ -284,19 +174,15 @@ class Interpreter() { case assert: Assert => // TODO Logger.debug(assert) - evalBool(assert.body, regs) match { - case TrueLiteral => () - case FalseLiteral => throw Exception(s"Assertion failed ${assert}") + if (!evalBool(assert.body, regs)) { + throw Exception(s"Assertion failed ${assert}") } case assume: Assume => // TODO, but already taken into effect if it is a branch condition Logger.debug(assume) - evalBool(assume.body, regs) match { - case TrueLiteral => () - case FalseLiteral => { + if (!evalBool(assume.body, regs)) { nextCmd = None Logger.debug(s"Assumption not satisfied: $assume") - } } case dc: DirectCall => Logger.debug(s"$dc") diff --git a/src/main/scala/ir/eval/expr.scala b/src/main/scala/ir/eval/expr.scala new file mode 100644 index 000000000..a7f748546 --- /dev/null +++ b/src/main/scala/ir/eval/expr.scala @@ -0,0 +1,228 @@ +package ir.eval +import ir.eval.BitVectorEval +import ir._ + +/** + * We generalise the expression evaluator to a partial evaluator to simplify evaluating casts. + * This is not as nice or type-safe as we would like. + * + * - Program state is taken via a function from var -> value and for loads a function from (mem,addr,endian,size) -> value. + * - For conrete evaluators we prefer low-level representations (bool vs BoolLit) and wrap them at the expression eval level + * - Avoid using default cases so we have some idea of complete coverage + * + */ + + +def evalBVBinExpr(b: BVBinOp, l:BitVecLiteral, r:BitVecLiteral): BitVecLiteral = { + b match { + case BVADD => BitVectorEval.smt_bvadd(l, r) + case BVSUB => BitVectorEval.smt_bvsub(l, r) + case BVMUL => BitVectorEval.smt_bvmul(l, r) + case BVUDIV => BitVectorEval.smt_bvudiv(l, r) + case BVSDIV => BitVectorEval.smt_bvsdiv(l, r) + case BVSREM => BitVectorEval.smt_bvsrem(l, r) + case BVUREM => BitVectorEval.smt_bvurem(l, r) + case BVSMOD => BitVectorEval.smt_bvsmod(l, r) + case BVAND => BitVectorEval.smt_bvand(l, r) + case BVOR => BitVectorEval.smt_bvxor(l, r) + case BVXOR => BitVectorEval.smt_bvxor(l, r) + case BVNAND => BitVectorEval.smt_bvnand(l, r) + case BVNOR => BitVectorEval.smt_bvnor(l, r) + case BVXNOR => BitVectorEval.smt_bvxnor(l, r) + case BVSHL => BitVectorEval.smt_bvshl(l, r) + case BVLSHR => BitVectorEval.smt_bvlshr(l, r) + case BVASHR => BitVectorEval.smt_bvashr(l, r) + case BVCOMP => BitVectorEval.smt_bvcomp(l, r) + case BVCONCAT => BitVectorEval.smt_concat(l, r) + case BVULE => throw Exception("Did not expect logical op") + case BVULT => throw Exception("Did not expect logical op") + case BVUGT => throw Exception("Did not expect logical op") + case BVUGE => throw Exception("Did not expect logical op") + case BVSLT => throw Exception("Did not expect logical op") + case BVSLE => throw Exception("Did not expect logical op") + case BVSGT => throw Exception("Did not expect logical op") + case BVSGE => throw Exception("Did not expect logical op") + case BVEQ => throw Exception("Did not expect logical op") + case BVNEQ => throw Exception("Did not expect logical op") + } +} + +def evalBVLogBinExpr(b: BVBinOp, l: BitVecLiteral, r:BitVecLiteral) : Boolean = b match { + case BVULE => BitVectorEval.smt_bvule(l, r) + case BVUGT => BitVectorEval.smt_bvult(l, r) + case BVUGE => BitVectorEval.smt_bvuge(l, r) + case BVULT => BitVectorEval.smt_bvult(l, r) + case BVSLT => BitVectorEval.smt_bvslt(l, r) + case BVSLE => BitVectorEval.smt_bvsle(l, r) + case BVSGT => BitVectorEval.smt_bvsgt(l, r) + case BVSGE => BitVectorEval.smt_bvsge(l, r) + case BVEQ => BitVectorEval.smt_bveq(l, r) + case BVNEQ => BitVectorEval.smt_bvneq(l, r) + case BVADD => throw Exception("Did not expect non-logical op") + case BVSUB => throw Exception("Did not expect non-logical op") + case BVMUL => throw Exception("Did not expect non-logical op") + case BVUDIV => throw Exception("Did not expect non-logical op") + case BVSDIV => throw Exception("Did not expect non-logical op") + case BVSREM => throw Exception("Did not expect non-logical op") + case BVUREM => throw Exception("Did not expect non-logical op") + case BVSMOD => throw Exception("Did not expect non-logical op") + case BVAND => throw Exception("Did not expect non-logical op") + case BVOR => throw Exception("Did not expect non-logical op") + case BVXOR => throw Exception("Did not expect non-logical op") + case BVNAND => throw Exception("Did not expect non-logical op") + case BVNOR => throw Exception("Did not expect non-logical op") + case BVXNOR => throw Exception("Did not expect non-logical op") + case BVSHL => throw Exception("Did not expect non-logical op") + case BVLSHR => throw Exception("Did not expect non-logical op") + case BVASHR => throw Exception("Did not expect non-logical op") + case BVCOMP => throw Exception("Did not expect non-logical op") + case BVCONCAT => throw Exception("Did not expect non-logical op") +} + +def evalIntLogBinExpr(b: IntBinOp, l:BigInt, r:BigInt) : Boolean = b match { + case IntEQ => l == r + case IntNEQ => l != r + case IntLT => l < r + case IntLE => l <= r + case IntGT => l > r + case IntGE => l >= r + case IntADD => throw Exception("Did not expect non-logical op") + case IntSUB => throw Exception("Did not expect non-logical op") + case IntMUL => throw Exception("Did not expect non-logical op") + case IntDIV => throw Exception("Did not expect non-logical op") + case IntMOD => throw Exception("Did not expect non-logical op") +} + +def evalIntBinExpr(b: IntBinOp, l:BigInt, r: BigInt): BigInt = b match { + case IntADD => l + r + case IntSUB => l - r + case IntMUL => l * r + case IntDIV => l / r + case IntMOD => l % r + case IntEQ => throw Exception("Did not expect logical op") + case IntNEQ => throw Exception("Did not expect logical op") + case IntLT => throw Exception("Did not expect logical op") + case IntLE => throw Exception("Did not expect logical op") + case IntGT => throw Exception("Did not expect logical op") + case IntGE => throw Exception("Did not expect logical op") +} + + +def evalBoolLogBinExpr(b: BoolBinOp, l:Boolean, r:Boolean) : Boolean = b match { + case BoolEQ => l == r + case BoolEQUIV => l == r + case BoolNEQ => l != r + case BoolAND => l && r + case BoolOR => l || r + case BoolIMPLIES => l || (!r) +} + + +def evalUnOp(op: UnOp, body: Literal) : Expr = { + (body, op) match { + case (b: BitVecLiteral, BVNOT) => BitVectorEval.smt_bvnot(b) + case (b: BitVecLiteral, BVNEG) => BitVectorEval.smt_bvneg(b) + case (i: IntLiteral, IntNEG) => IntLiteral(-i.value) + case (FalseLiteral, BoolNOT) => TrueLiteral + case (TrueLiteral, BoolNOT) => FalseLiteral + } +} + +def partialEvalExpr(exp: Expr, variableAssignment: Variable => Option[Expr], memory: (Memory, Expr, Endian, Int) => Option[BitVecLiteral] = ((a,b,c,d) => None)): Expr = { + exp match { + case f: UninterpretedFunction => f + case unOp: UnaryExpr => { + val body = partialEvalExpr(unOp.arg, variableAssignment, memory) + body match { + case l: Literal => evalUnOp(unOp.op, l) + case o => UnaryExpr(unOp.op, body) + } + } + case binOp: BinaryExpr => + val lhs = partialEvalExpr(binOp.arg1, variableAssignment, memory) + val rhs = partialEvalExpr(binOp.arg2, variableAssignment, memory) + binOp.getType match { + case m: MapType => binOp + case b: BitVecType => { + (binOp.op, lhs, rhs) match { + case (o: BVBinOp, l: BitVecLiteral, r: BitVecLiteral) => evalBVBinExpr(o, l, r) + case _ => BinaryExpr(binOp.op, lhs, rhs) + } + } + case BoolType => { + def bool2lit(b: Boolean) = if b then TrueLiteral else FalseLiteral + (binOp.op, lhs, rhs) match { + case (o: BVBinOp, l: BitVecLiteral, r: BitVecLiteral) => bool2lit(evalBVLogBinExpr(o, l, r)) + case (o: IntBinOp, l: IntLiteral , r: IntLiteral) => bool2lit(evalIntLogBinExpr(o, l.value, r.value)) + case (o: BoolBinOp, l: BoolLit, r: BoolLit) => bool2lit(evalBoolLogBinExpr(o, l.value, r.value)) + case _ => BinaryExpr(binOp.op, lhs, rhs) + } + } + case IntType => { + (binOp.op, lhs, rhs) match { + case (o: IntBinOp, l: IntLiteral , r: IntLiteral) => IntLiteral(evalIntBinExpr(o, l.value, r.value)) + case _ => BinaryExpr(binOp.op, lhs, rhs) + } + } + } + case extend: ZeroExtend => partialEvalExpr(extend.body, variableAssignment, memory) match { + case b : BitVecLiteral => BitVectorEval.smt_zero_extend(extend.extension, b) + case o => extend.copy(body=o) + } + case extend: SignExtend => partialEvalExpr(extend.body, variableAssignment, memory) match { + case b: BitVecLiteral => BitVectorEval.smt_sign_extend(extend.extension, b) + case o => extend.copy(body=o) + } + case e: Extract => partialEvalExpr(e.body, variableAssignment, memory) match { + case b: BitVecLiteral => BitVectorEval.boogie_extract(e.end, e.start, b) + case o => e.copy(body=o) + } + case r: Repeat => { + partialEvalExpr(r.body, variableAssignment, memory) match { + case b: BitVecLiteral => { + assert(r.repeats > 0) + if (r.repeats == 1) b + else { + (2 to r.repeats).foldLeft(b)((acc, r) => BitVectorEval.smt_concat(acc, b)) + } + } + case o => r.copy(body=o) + } + + } + case variable: Variable => variableAssignment(variable).getOrElse(variable) + case l: MemoryLoad => memory(l.mem, partialEvalExpr(l.index, variableAssignment, memory), l.endian, l.size).getOrElse(l) + case b: BitVecLiteral => b + case b: IntLiteral => b + case b: BoolLit => b + } +} + +def evalIntExpr(exp: Expr, variableAssignment: Variable => Option[BitVecLiteral], memory: (Memory, Expr, Endian, Int) => Option[BitVecLiteral] = ((a,b,c,d) => None)): Either[Expr, BigInt] = { + partialEvalExpr(exp, variableAssignment, memory) match { + case i: IntLiteral => Right(i.value) + case o => Left(o) + } +} + +def evalBVExpr(exp: Expr, variableAssignment: Variable => Option[BitVecLiteral], memory: (Memory, Expr, Endian, Int) => Option[BitVecLiteral] = ((a,b,c,d) => None)): Either[Expr, BitVecLiteral] = { + partialEvalExpr(exp, variableAssignment, memory) match { + case b: BitVecLiteral => Right(b) + case o => Left(o) + } +} + +def evalLogExpr(exp: Expr, variableAssignment: Variable => Option[BitVecLiteral], memory: (Memory, Expr, Endian, Int) => Option[BitVecLiteral] = ((a,b,c, d) => None)): Either[Expr, Boolean] = { + partialEvalExpr(exp, variableAssignment, memory) match { + case TrueLiteral => Right(true) + case FalseLiteral => Right(false) + case o => Left(o) + } +} + +def evalExpr(exp: Expr, variableAssignment: Variable => Option[BitVecLiteral], memory: (Memory, Expr, Endian, Int) => Option[BitVecLiteral] = ((d, a,b,c) => None)): Option[Literal] = { + partialEvalExpr match { + case l: Literal => Some(l) + case _ => None + } +} From d264b97cfd07a52873bd9cec60f612adb60a9681 Mon Sep 17 00:00:00 2001 From: Alistair Michael Date: Mon, 19 Aug 2024 18:01:12 +1000 Subject: [PATCH 06/62] interpreter test --- src/main/scala/ir/dsl/DSL.scala | 1 + src/main/scala/ir/eval/Interpreter.scala | 149 ++++++++++++-------- src/test/scala/BitVectorAnalysisTests.scala | 42 +++--- src/test/scala/ir/InterpreterTests.scala | 58 +++++++- 4 files changed, 168 insertions(+), 82 deletions(-) diff --git a/src/main/scala/ir/dsl/DSL.scala b/src/main/scala/ir/dsl/DSL.scala index 3ebeefbc4..2d892a444 100644 --- a/src/main/scala/ir/dsl/DSL.scala +++ b/src/main/scala/ir/dsl/DSL.scala @@ -11,6 +11,7 @@ val R4: Register = Register("R4", 64) val R5: Register = Register("R5", 64) val R6: Register = Register("R6", 64) val R7: Register = Register("R7", 64) +val R8: Register = Register("R8", 64) val R29: Register = Register("R29", 64) val R30: Register = Register("R30", 64) val R31: Register = Register("R31", 64) diff --git a/src/main/scala/ir/eval/Interpreter.scala b/src/main/scala/ir/eval/Interpreter.scala index b2f45ed1a..f505425d3 100644 --- a/src/main/scala/ir/eval/Interpreter.scala +++ b/src/main/scala/ir/eval/Interpreter.scala @@ -5,57 +5,72 @@ import util.Logger import scala.collection.mutable import scala.util.control.Breaks.{break, breakable} - // enum Asssumption: // case Assume(x: Expr) // case Jump(choice: Block) // case Not(a: Assumption) -// -// -// class State() { + +sealed trait ExecutionState +case class FailedAssertion(a: Assert) extends ExecutionState +case class Stopped() extends ExecutionState +case class Run(val next: Command) extends ExecutionState +case class EscapedControlFlow(val call: IndirectCall) extends ExecutionState +case class Errored(val message: String = "") extends ExecutionState + +// case class Execution() extends State { // // stack of assumptions -// var assumptions: mutable.Stack[Assumption] = mutable.Stack() -// var memory: mutable.Map[Memory, Map[BigInt, BigInt]] = Map() -// var bvValues: mutable.Map[Variable, BigInt] = Map() -// var intValues: mutable.Map[Variable, BigInt] = Map() +// // var assumptions: mutable.Stack[] = mutable.Stack() +// var memory: mutable.Map[Memory, Map[BigInt, BigInt]] = mutable.Map() +// var bvValues: mutable.Map[Variable, BigInt] = mutable.Map() +// var intValues: mutable.Map[Variable, BigInt] = mutable.Map() // -// var nextCmd: Option[Command] = None -// var returnCmd: mutable.Stack[Command] = mutable.Stack() +// var nextCmd: ExecutionState = Stopped() +// var callStack: mutable.Stack[Command] = mutable.Stack() // } +case class InterpreterError(condinue: ExecutionState) extends Exception() + class Interpreter() { val regs: mutable.Map[Variable, BitVecLiteral] = mutable.Map() val mems: mutable.Map[Int, BitVecLiteral] = mutable.Map() private val SP: BitVecLiteral = BitVecLiteral(4096 - 16, 64) private val FP: BitVecLiteral = BitVecLiteral(4096 - 16, 64) private val LR: BitVecLiteral = BitVecLiteral(BigInt("FF", 16), 64) - private var nextCmd: Option[Command] = None - private val returnCmd: mutable.Stack[Command] = mutable.Stack() + var nextCmd: ExecutionState = Stopped() + private val callStack: mutable.Stack[Command] = mutable.Stack() def eval(exp: Expr, env: mutable.Map[Variable, BitVecLiteral]): BitVecLiteral = { def load(m: Memory, index: Expr, endian: Endian, size: Int) = { - val idx = evalInt(index, env).toInt + val idx = eval(index, env).value.toInt Some(getMemory(idx, size, endian, mems)) } ir.eval.evalBVExpr(exp, x => env.get(x), load) match { case Right(b) => b - case Left(e) => throw Exception(s"Failed to evaluate expr: residual $exp") + case Left(e) => throw InterpreterError(Errored(s"Failed to evaluate bv expr: residual $exp")) + } + } + + def doReturn() = { + if (callStack.nonEmpty) { + nextCmd = Run(callStack.pop()) + } else { + nextCmd = Stopped() } } def evalBool(exp: Expr, env: mutable.Map[Variable, BitVecLiteral]): Boolean = { ir.eval.evalLogExpr(exp, x => env.get(x)) match { case Right(b) => b - case Left(e) => throw Exception(s"Failed to evaluate expr: residual $e") + case Left(e) => throw InterpreterError(Errored(s"Failed to evaluate logical expr: residual $e")) } } def evalInt(exp: Expr, env: mutable.Map[Variable, BitVecLiteral]): BigInt = { ir.eval.evalIntExpr(exp, x => env.get(x)) match { case Right(b) => b - case Left(e) => throw Exception(s"Failed to evaluate expr: residual $e") + case Left(e) => throw InterpreterError(Errored(s"Failed to evaluate int expr: residual $e")) } } @@ -99,6 +114,7 @@ class Interpreter() { private def interpretProcedure(p: Procedure): Unit = { Logger.debug(s"Procedure(${p.name}, ${p.address.getOrElse("None")})") + Logger.debug(s"Regs $regs") // Procedure.in for ((in, index) <- p.in.zipWithIndex) { @@ -112,96 +128,89 @@ class Interpreter() { // Procedure.Block p.entryBlock match { - case Some(block) => nextCmd = Some(block.statements.headOption.getOrElse(block.jump)) - case None => nextCmd = Some(returnCmd.pop()) + case Some(block) => nextCmd = Run(block.statements.headOption.getOrElse(block.jump)) + case None => doReturn() } } - private def interpretJump(j: Jump) : Unit = { - Logger.debug(s"jump:") + private def interpretJump(j: Jump): Unit = { + Logger.debug(s"jump: $j") breakable { j match { case gt: GoTo => - Logger.debug(s"$gt") for (g <- gt.targets) { val condition: Option[Expr] = g.statements.headOption.collect { case a: Assume => a.body } condition match { - case Some(e) => if (evalBool(e, regs)) { - nextCmd = Some(g.statements.headOption.getOrElse(g.jump)) + case Some(e) => + if (evalBool(e, regs)) { + Logger.debug(s"chosen ${g.label}") + nextCmd = Run(g.statements.headOption.getOrElse(g.jump)) break - } + } case None => - nextCmd = Some(g.statements.headOption.getOrElse(g.jump)) + nextCmd = Run(g.statements.headOption.getOrElse(g.jump)) break } } - case r: Return => { - nextCmd = Some(returnCmd.pop()) - } + case r: Return => doReturn() case h: Unreachable => { Logger.debug("Unreachable") - nextCmd = None + nextCmd = Stopped() } } } } private def interpretStatement(s: Statement): Unit = { + Logger.debug(s"Regs $regs") Logger.debug(s"statement[$s]:") + + nextCmd = Run(s.successor) + s match { case assign: Assign => - Logger.debug(s"LocalAssign ${assign.lhs} = ${assign.rhs}") + //Logger.debug(s"LocalAssign ${assign.lhs} = ${assign.rhs}") val evalRight = eval(assign.rhs, regs) - Logger.debug(s"LocalAssign ${assign.lhs} := 0x${evalRight.value.toString(16)}[u${evalRight.size}]\n") + //Logger.debug(s"LocalAssign ${assign.lhs} := 0x${evalRight.value.toString(16)}[u${evalRight.size}]\n") regs += (assign.lhs -> evalRight) case assign: MemoryAssign => - Logger.debug(s"MemoryAssign ${assign.mem}[${assign.index}] = ${assign.value}") + //Logger.debug(s"MemoryAssign ${assign.mem}[${assign.index}] = ${assign.value}") val index: Int = eval(assign.index, regs).value.toInt val value: BitVecLiteral = eval(assign.value, regs) - Logger.debug(s"\tMemoryStore(mem:${assign.mem}, index:0x${index.toHexString}, value:0x${ - value.value - .toString(16) - }[u${value.size}], size:${assign.size})") + //Logger.debug(s"\tMemoryStore(mem:${assign.mem}, index:0x${index.toHexString}, value:0x${value.value + // .toString(16)}[u${value.size}], size:${assign.size})") val evalStore = setMemory(index, assign.size, assign.endian, value, mems) evalStore match { case BitVecLiteral(value, size) => - Logger.debug(s"MemoryAssign ${assign.mem} := 0x${value.toString(16)}[u$size]\n") + //Logger.debug(s"MemoryAssign ${assign.mem} := 0x${value.toString(16)}[u$size]\n") } - case _ : NOP => () case assert: Assert => - // TODO - Logger.debug(assert) + // Logger.debug(assert) if (!evalBool(assert.body, regs)) { - throw Exception(s"Assertion failed ${assert}") + nextCmd = FailedAssertion(assert) } case assume: Assume => // TODO, but already taken into effect if it is a branch condition - Logger.debug(assume) + // Logger.debug(assume) if (!evalBool(assume.body, regs)) { - nextCmd = None - Logger.debug(s"Assumption not satisfied: $assume") + nextCmd = Errored(s"Assumption not satisfied: $assume") } case dc: DirectCall => - Logger.debug(s"$dc") - returnCmd.push(dc.successor) + // Logger.debug(s"$dc") + callStack.push(dc.successor) interpretProcedure(dc.target) - break case ic: IndirectCall => - Logger.debug(s"$ic") + // Logger.debug(s"$ic") if (ic.target == Register("R30", 64)) { - if (returnCmd.nonEmpty) { - nextCmd = Some(returnCmd.pop()) - } else { - //Exit Interpreter - nextCmd = None - } - break + doReturn() + //Exit Interpreter } else { - ??? + nextCmd = EscapedControlFlow(ic) } + case _: NOP => () } } @@ -228,13 +237,33 @@ class Interpreter() { // Program.Procedure interpretProcedure(IRProgram.mainProcedure) - while (nextCmd.isDefined) { - nextCmd.get match { - case c: Statement => interpretStatement(c) - case c: Jump => interpretJump(c) + while ( + try {nextCmd match { + case Run(c: Statement) => { + interpretStatement(c) + true + } + case Run(c: Jump) => { + interpretJump(c) + true + } + case Stopped() => { + false + } + case errorstop => { + Logger.error(s"Interpreter $errorstop") + false + } + } + } catch { + case InterpreterError(e) => { + nextCmd = e + true } } + ) {} + regs } } diff --git a/src/test/scala/BitVectorAnalysisTests.scala b/src/test/scala/BitVectorAnalysisTests.scala index 485ba6a83..08710a940 100644 --- a/src/test/scala/BitVectorAnalysisTests.scala +++ b/src/test/scala/BitVectorAnalysisTests.scala @@ -1,4 +1,4 @@ -import analysis.BitVectorEval.* +import ir.eval.BitVectorEval._ import ir.* import org.scalatest.funsuite.AnyFunSuite import util.Logger @@ -181,20 +181,20 @@ class BitVectorAnalysisTests extends AnyFunSuite { // smt_bveq test("BitVector Equal - should return true if two BitVectors are equal") { val result = smt_bveq(BitVecLiteral(255, 8), BitVecLiteral(255, 8)) - assert(result == TrueLiteral) + assert(result) } test("BitVector Equal - should return false if two BitVectors are not equal") { val result = smt_bveq(BitVecLiteral(255, 8), BitVecLiteral(254, 8)) - assert(result == FalseLiteral) + assert(!result) } // smt_bvneq test("BitVector Not Equal - should return false if two BitVectors are equal") { val result = smt_bvneq(BitVecLiteral(255, 8), BitVecLiteral(255, 8)) - assert(result == FalseLiteral) + assert(!result) } test("BitVector Not Equal - should return true if two BitVectors are not equal") { val result = smt_bvneq(BitVecLiteral(255, 8), BitVecLiteral(254, 8)) - assert(result == TrueLiteral) + assert(result) } // smt_bvshl test("BitVector Shift Left - should shift bits left") { @@ -239,14 +239,14 @@ class BitVectorAnalysisTests extends AnyFunSuite { test("BitVector unsigned less then - should return true if first argument is less than second argument") { val result = smt_bvult(BitVecLiteral(254, 8), BitVecLiteral(255, 8)) - assert(result == TrueLiteral) + assert(result) } test( "BitVector unsigned less then - should return false if first argument is greater than or equal to second argument" ) { val result = smt_bvult(BitVecLiteral(255, 8), BitVecLiteral(254, 8)) - assert(result == FalseLiteral) + assert(!result) } // smt_bvule @@ -255,14 +255,14 @@ class BitVectorAnalysisTests extends AnyFunSuite { "BitVector unsigned less then or equal to - should return true if first argument is less equal to second argument" ) { val result = smt_bvule(BitVecLiteral(254, 8), BitVecLiteral(255, 8)) - assert(result == TrueLiteral) + assert(result) } test( "BitVector unsigned less then or equal to - should return false if first argument is greater than second argument" ) { val result = smt_bvule(BitVecLiteral(255, 8), BitVecLiteral(254, 8)) - assert(result == FalseLiteral) + assert(!result) } // smt_bvugt @@ -270,14 +270,14 @@ class BitVectorAnalysisTests extends AnyFunSuite { "BitVector unsinged greater than - should return true if first argument is greater equal to than second argument" ) { val result = smt_bvugt(BitVecLiteral(255, 8), BitVecLiteral(254, 8)) - assert(result == TrueLiteral) + assert(result) } test( "BitVector unsinged greater than - should return false if first argument is less than or equal to second argument" ) { val result = smt_bvugt(BitVecLiteral(254, 8), BitVecLiteral(255, 8)) - assert(result == FalseLiteral) + assert(!result) } // smt_bvuge @@ -285,27 +285,27 @@ class BitVectorAnalysisTests extends AnyFunSuite { "BitVector unsinged greater than or equal to - should return true if first argument is greater equal or equal to second argument" ) { val result = smt_bvuge(BitVecLiteral(255, 8), BitVecLiteral(254, 8)) - assert(result == TrueLiteral) + assert(result) } test( "BitVector unsinged greater than or equal to - should return false if first argument is less than second argument" ) { val result = smt_bvuge(BitVecLiteral(254, 8), BitVecLiteral(255, 8)) - assert(result == FalseLiteral) + assert(!result) } // smt_bvslt test("BitVector signed less than - should return true if first argument is less than second argument") { val result = smt_bvslt(BitVecLiteral(254, 8), BitVecLiteral(255, 8)) - assert(result == TrueLiteral) + assert(result) } test( "BitVector signed less than - should return false if first argument is greater than or equal to second argument" ) { val result = smt_bvslt(BitVecLiteral(254, 8), BitVecLiteral(254, 8)) - assert(result == FalseLiteral) + assert(!result) } // smt_bvsle @@ -313,25 +313,25 @@ class BitVectorAnalysisTests extends AnyFunSuite { "BitVector signed less than or equal to - should return true if first argument is less than or equal to second argument" ) { val result = smt_bvsle(BitVecLiteral(254, 8), BitVecLiteral(255, 8)) - assert(result == TrueLiteral) + assert(result) } test( "BitVector signed less than or equal to - should return false if first argument is greater than second argument" ) { val result = smt_bvsle(BitVecLiteral(255, 8), BitVecLiteral(254, 8)) - assert(result == FalseLiteral) + assert(!result) } // smt_bvsgt test("BitVector signed greater than - should return true if first argument is greater than second argument") { val result = smt_bvsgt(BitVecLiteral(255, 8), BitVecLiteral(254, 8)) - assert(result == TrueLiteral) + assert(result) } test("BitVector signed greater than - should return false if first argument is less than second argument") { val result = smt_bvsgt(BitVecLiteral(254, 8), BitVecLiteral(255, 8)) - assert(result == FalseLiteral) + assert(!result) } // smt_bvsge @@ -339,14 +339,14 @@ class BitVectorAnalysisTests extends AnyFunSuite { "BitVector signed greater than or equal to - should return true if first argument is greater than or equal to second argument" ) { val result = smt_bvsge(BitVecLiteral(255, 8), BitVecLiteral(254, 8)) - assert(result == TrueLiteral) + assert(result) } test( "BitVector signed greater than or equal to - should return false if first argument is less than second argument" ) { val result = smt_bvsge(BitVecLiteral(254, 8), BitVecLiteral(255, 8)) - assert(result == FalseLiteral) + assert(!result) } // smt_bvashr test("BitVector Arithmetic shift right - should return shift right a positive number") { diff --git a/src/test/scala/ir/InterpreterTests.scala b/src/test/scala/ir/InterpreterTests.scala index 57f9739fe..56bbbf7d3 100644 --- a/src/test/scala/ir/InterpreterTests.scala +++ b/src/test/scala/ir/InterpreterTests.scala @@ -1,6 +1,7 @@ package ir -import analysis.BitVectorEval.* +import ir.eval.* +import ir.dsl._ import org.scalatest.funsuite.AnyFunSuite import org.scalatest.BeforeAndAfter import specification.SpecGlobal @@ -201,4 +202,59 @@ class InterpreterTests extends AnyFunSuite with BeforeAndAfter { ) testInterpret("no_interference_update_y", expected) } + + test("fibonacci") { + val fib = prog( + proc("begin", + block("entry", + Assign(R8, Register("R31", 64)), + Assign(R0, bv64(8)), + directCall("fib"), + goto("done") + ), + block("done", + Assert(BinaryExpr(BVEQ, R0, bv64(21))), + ret + )), + proc("fib", + block("base", goto("base1", "base2", "dofib")), + block("base1", + Assume(BinaryExpr(BVEQ, R0, bv64(0))), + ret), + block("base2", + Assume(BinaryExpr(BVEQ, R0, bv64(1))), + ret), + block("dofib", + Assume(BinaryExpr(BoolAND, BinaryExpr(BVNEQ, R0, bv64(0)), BinaryExpr(BVNEQ, R0, bv64(1)))), + // R8 stack pointer preserved across calls + Assign(R7, BinaryExpr(BVADD, R8, bv64(8))), + MemoryAssign(stack, R7, R8, Endian.LittleEndian, 64), // sp + Assign(R8, R7), + Assign(R8, BinaryExpr(BVADD, R8, bv64(8))), // sp + 8 + MemoryAssign(stack, R8, R0, Endian.LittleEndian, 64), // [sp + 8] = arg0 + Assign(R0, BinaryExpr(BVSUB, R0, bv64(1))), + directCall("fib"), + Assign(R2, R8), // sp + 8 + Assign(R8, BinaryExpr(BVADD, R8, bv64(8))), // sp + 16 + MemoryAssign(stack, R8, R0, Endian.LittleEndian, 64), // [sp + 16] = r1 + Assign(R0, MemoryLoad(stack, R2, Endian.LittleEndian, 64)), // [sp + 8] + Assign(R0, BinaryExpr(BVSUB, R0, bv64(2))), + directCall("fib"), + Assign(R2, MemoryLoad(stack, R8, Endian.LittleEndian, 64)), // [sp + 16] (r1) + Assign(R0, BinaryExpr(BVADD, R0, R2)), + Assign(R8, MemoryLoad(stack, BinaryExpr(BVSUB, R8, bv64(16)), Endian.LittleEndian, 64)), + ret + ) + ) + ) + + val regs = i.interpret(fib) + + // Show interpreted result + Logger.info("Registers:") + regs.foreach { (key, value) => + Logger.info(s"$key := $value") + } + + } } From a3bf83b9897b99c22863bfc23b9f0dd6784a9c57 Mon Sep 17 00:00:00 2001 From: Alistair Michael Date: Thu, 22 Aug 2024 16:01:41 +1000 Subject: [PATCH 07/62] rewrite interpreter in functional style --- .../ir/eval/{expr.scala => ExprEval.scala} | 13 +- src/main/scala/ir/eval/Interpreter.scala | 734 +++++++++++++----- src/main/scala/util/RunUtils.scala | 4 +- src/test/scala/ir/InterpreterTests.scala | 254 +++--- 4 files changed, 708 insertions(+), 297 deletions(-) rename src/main/scala/ir/eval/{expr.scala => ExprEval.scala} (91%) diff --git a/src/main/scala/ir/eval/expr.scala b/src/main/scala/ir/eval/ExprEval.scala similarity index 91% rename from src/main/scala/ir/eval/expr.scala rename to src/main/scala/ir/eval/ExprEval.scala index a7f748546..7efcc8f7d 100644 --- a/src/main/scala/ir/eval/expr.scala +++ b/src/main/scala/ir/eval/ExprEval.scala @@ -4,11 +4,10 @@ import ir._ /** * We generalise the expression evaluator to a partial evaluator to simplify evaluating casts. - * This is not as nice or type-safe as we would like. * * - Program state is taken via a function from var -> value and for loads a function from (mem,addr,endian,size) -> value. * - For conrete evaluators we prefer low-level representations (bool vs BoolLit) and wrap them at the expression eval level - * - Avoid using default cases so we have some idea of complete coverage + * - Avoid using any default cases so we have some idea of complete coverage * */ @@ -128,7 +127,7 @@ def evalUnOp(op: UnOp, body: Literal) : Expr = { } } -def partialEvalExpr(exp: Expr, variableAssignment: Variable => Option[Expr], memory: (Memory, Expr, Endian, Int) => Option[BitVecLiteral] = ((a,b,c,d) => None)): Expr = { +def partialEvalExpr(exp: Expr, variableAssignment: Variable => Option[Expr], memory: (Memory, Expr, Endian, Int) => Option[Literal] = ((a,b,c,d) => None)): Expr = { exp match { case f: UninterpretedFunction => f case unOp: UnaryExpr => { @@ -198,21 +197,21 @@ def partialEvalExpr(exp: Expr, variableAssignment: Variable => Option[Expr], mem } } -def evalIntExpr(exp: Expr, variableAssignment: Variable => Option[BitVecLiteral], memory: (Memory, Expr, Endian, Int) => Option[BitVecLiteral] = ((a,b,c,d) => None)): Either[Expr, BigInt] = { +def evalIntExpr(exp: Expr, variableAssignment: Variable => Option[Literal], memory: (Memory, Expr, Endian, Int) => Option[Literal] = ((a,b,c,d) => None)): Either[Expr, BigInt] = { partialEvalExpr(exp, variableAssignment, memory) match { case i: IntLiteral => Right(i.value) case o => Left(o) } } -def evalBVExpr(exp: Expr, variableAssignment: Variable => Option[BitVecLiteral], memory: (Memory, Expr, Endian, Int) => Option[BitVecLiteral] = ((a,b,c,d) => None)): Either[Expr, BitVecLiteral] = { +def evalBVExpr(exp: Expr, variableAssignment: Variable => Option[Literal], memory: (Memory, Expr, Endian, Int) => Option[Literal] = ((a,b,c,d) => None)): Either[Expr, BitVecLiteral] = { partialEvalExpr(exp, variableAssignment, memory) match { case b: BitVecLiteral => Right(b) case o => Left(o) } } -def evalLogExpr(exp: Expr, variableAssignment: Variable => Option[BitVecLiteral], memory: (Memory, Expr, Endian, Int) => Option[BitVecLiteral] = ((a,b,c, d) => None)): Either[Expr, Boolean] = { +def evalLogExpr(exp: Expr, variableAssignment: Variable => Option[Literal], memory: (Memory, Expr, Endian, Int) => Option[Literal] = ((a,b,c, d) => None)): Either[Expr, Boolean] = { partialEvalExpr(exp, variableAssignment, memory) match { case TrueLiteral => Right(true) case FalseLiteral => Right(false) @@ -220,7 +219,7 @@ def evalLogExpr(exp: Expr, variableAssignment: Variable => Option[BitVecLiteral] } } -def evalExpr(exp: Expr, variableAssignment: Variable => Option[BitVecLiteral], memory: (Memory, Expr, Endian, Int) => Option[BitVecLiteral] = ((d, a,b,c) => None)): Option[Literal] = { +def evalExpr(exp: Expr, variableAssignment: Variable => Option[Literal], memory: (Memory, Expr, Endian, Int) => Option[Literal] = ((d, a,b,c) => None)): Option[Literal] = { partialEvalExpr match { case l: Literal => Some(l) case _ => None diff --git a/src/main/scala/ir/eval/Interpreter.scala b/src/main/scala/ir/eval/Interpreter.scala index f505425d3..e2fda9e6b 100644 --- a/src/main/scala/ir/eval/Interpreter.scala +++ b/src/main/scala/ir/eval/Interpreter.scala @@ -1,269 +1,605 @@ -package ir +package ir.eval import ir.eval.BitVectorEval.* +import ir._ import util.Logger +import boogie.Scope +import scala.annotation.tailrec import scala.collection.mutable +import scala.collection.immutable import scala.util.control.Breaks.{break, breakable} -// enum Asssumption: -// case Assume(x: Expr) -// case Jump(choice: Block) -// case Not(a: Assumption) -sealed trait ExecutionState -case class FailedAssertion(a: Assert) extends ExecutionState -case class Stopped() extends ExecutionState -case class Run(val next: Command) extends ExecutionState -case class EscapedControlFlow(val call: IndirectCall) extends ExecutionState -case class Errored(val message: String = "") extends ExecutionState +sealed trait ExecutionContinuation +case class FailedAssertion(a: Assert) extends ExecutionContinuation +case class Stopped() extends ExecutionContinuation -// case class Execution() extends State { -// // stack of assumptions -// // var assumptions: mutable.Stack[] = mutable.Stack() -// var memory: mutable.Map[Memory, Map[BigInt, BigInt]] = mutable.Map() -// var bvValues: mutable.Map[Variable, BigInt] = mutable.Map() -// var intValues: mutable.Map[Variable, BigInt] = mutable.Map() -// -// var nextCmd: ExecutionState = Stopped() -// var callStack: mutable.Stack[Command] = mutable.Stack() -// } +/** Normal stop * */ +case class Run(val next: Command) extends ExecutionContinuation +case class EscapedControlFlow(val call: Jump | Call) extends ExecutionContinuation +/** controlflow has reached somewhere eunrecoverable */ +case class Errored(val message: String = "") extends ExecutionContinuation +case class TypeError(val message: String = "") extends ExecutionContinuation /* type mismatch appeared */ +case class EvalError(val message: String = "") + extends ExecutionContinuation /* failed to evaluate an expression to a concrete value */ +case class MemoryError(val message: String = "") + extends ExecutionContinuation /* failed to evaluate an expression to a concrete value */ -case class InterpreterError(condinue: ExecutionState) extends Exception() +case class InterpreterError(continue: ExecutionContinuation) extends Exception() -class Interpreter() { - val regs: mutable.Map[Variable, BitVecLiteral] = mutable.Map() - val mems: mutable.Map[Int, BitVecLiteral] = mutable.Map() - private val SP: BitVecLiteral = BitVecLiteral(4096 - 16, 64) - private val FP: BitVecLiteral = BitVecLiteral(4096 - 16, 64) - private val LR: BitVecLiteral = BitVecLiteral(BigInt("FF", 16), 64) - var nextCmd: ExecutionState = Stopped() - private val callStack: mutable.Stack[Command] = mutable.Stack() +case class InterpreterSummary( + val exitState: ExecutionContinuation, + val regs: Map[Variable, BitVecLiteral], + val memory: Map[Int, BitVecLiteral] +) - def eval(exp: Expr, env: mutable.Map[Variable, BitVecLiteral]): BitVecLiteral = { - def load(m: Memory, index: Expr, endian: Endian, size: Int) = { - val idx = eval(index, env).value.toInt - Some(getMemory(idx, size, endian, mems)) +enum BasilValue(val irType: IRType): + case Scalar(val value: Literal) extends BasilValue(value.getType) + // Erase the type of basil values and enforce the invariant that + // \exists i . \forall v \in value.keys , v.irType = i and + // \exists j . \forall v \in value.values, v.irType = j + case MapValue(val value: Map[BasilValue, BasilValue], override val irType: MapType) extends BasilValue(irType) + +given Conversion[BitVecLiteral, Scalar] with + def apply(b: BitVecLiteral): Scalar = new Scalar(b) + +given Conversion[IntLiteral, Scalar] with + def apply(b: IntLiteral): Scalar = new Scalar(b) + +given Conversion[BoolLit, Scalar] with + def apply(b: BoolLit): Scalar = new Scalar(b) + +case object BasilValue: + + def size(v: IRType): Int = { + v match { + case BitVecType(sz) => sz + case _ => 1 } + } + + def size(v: BasilValue): Int = size(v.irType) + + def unsafeAdd(l: BasilValue, vr: Int): BasilValue = { + l match { + case Scalar(IntLiteral(vl)) => Scalar(IntLiteral(vl + vr)) + case Scalar(b1: BitVecLiteral) => Scalar(eval.evalBVBinExpr(BVADD, b1, BitVecLiteral(vr, b1.size))) + case _ => throw InterpreterError(TypeError(s"Operation add $vr undefined on $l")) + } + } - ir.eval.evalBVExpr(exp, x => env.get(x), load) match { - case Right(b) => b - case Left(e) => throw InterpreterError(Errored(s"Failed to evaluate bv expr: residual $exp")) + def add(l: BasilValue, r: BasilValue): BasilValue = { + (l, r) match { + case (Scalar(IntLiteral(vl)), Scalar(IntLiteral(vr))) => Scalar(IntLiteral(vl + vr)) + case (Scalar(b1: BitVecLiteral), Scalar(b2: BitVecLiteral)) => Scalar(eval.evalBVBinExpr(BVADD, b1, b2)) + case (Scalar(b1: BoolLit), Scalar(b2: BoolLit)) => + Scalar(if (b2.value || b2.value) then TrueLiteral else FalseLiteral) + case _ => throw InterpreterError(TypeError(s"Operation add undefined on $l $r")) } } - def doReturn() = { - if (callStack.nonEmpty) { - nextCmd = Run(callStack.pop()) - } else { - nextCmd = Stopped() + def concat(l: BasilValue, r: BasilValue): BasilValue = { + (l, r) match { + case (Scalar(b1: BitVecLiteral), Scalar(b2: BitVecLiteral)) => Scalar(eval.evalBVBinExpr(BVCONCAT, b1, b2)) + case _ => throw InterpreterError(TypeError(s"Operation concat undefined on $l $r")) } } - def evalBool(exp: Expr, env: mutable.Map[Variable, BitVecLiteral]): Boolean = { - ir.eval.evalLogExpr(exp, x => env.get(x)) match { - case Right(b) => b - case Left(e) => throw InterpreterError(Errored(s"Failed to evaluate logical expr: residual $e")) + def extract(l: BasilValue, high: Int, low: Int): BasilValue = { + (l) match { + case Scalar(b: BitVecLiteral) => Scalar(eval.BitVectorEval.boogie_extract(high, low, b)) + case _ => throw InterpreterError(TypeError(s"Operation extract($high, $low) undefined on $l")) } } - def evalInt(exp: Expr, env: mutable.Map[Variable, BitVecLiteral]): BigInt = { - ir.eval.evalIntExpr(exp, x => env.get(x)) match { - case Right(b) => b - case Left(e) => throw InterpreterError(Errored(s"Failed to evaluate int expr: residual $e")) + def fromIR(e: Expr) = { + e match { + case t: IntLiteral => Scalar(t) + case v: BitVecLiteral => Scalar(v) + case b: BoolLit => Scalar(b) + case _ => throw InterpreterError(EvalError(s"Failed to get value from non-literal expr $e")) + } } - def getMemory(index: Int, size: Int, endian: Endian, env: mutable.Map[Int, BitVecLiteral]): BitVecLiteral = { - val end = index + size / 8 - 1 - val memoryChunks = (index to end).map(i => env.getOrElse(i, BitVecLiteral(0, 8))) +export BasilValue._ + +def evalToConst( + to: IRType, + exp: Expr, + variable: Variable => Option[Expr], + load: (Memory, Expr, Endian, Int) => Option[Literal] +): BasilValue = { + + val res: Expr = ir.eval.partialEvalExpr(exp, variable, load) + res match { + case e: Literal if e.getType == to => Scalar(e) + case res => throw InterpreterError(EvalError(s"Failed to evaluate expr to constant ${to} literal: residual $res")) + } +} + +// case class BasilConstant(val basilType: BasilValue, val value: basilType.ReprType) + +type StackFrameID = String +val globalFrame: StackFrameID = "GLOBAL" + +case class MemoryState( + val stackFrames: Map[StackFrameID, Map[String, BasilValue]] = Map((globalFrame -> Map.empty)), + val activations: List[StackFrameID] = List.empty, + val activationCount: Map[String, Int] = Map.empty.withDefault(_ => 0) +) { + + /** Debug return useful values * */ + + def getGlobalVals: Map[String, BitVecLiteral] = { + stackFrames(globalFrame).collect { case (k, Scalar(b: BitVecLiteral)) => + k -> b + } + } - val (newValue, newSize) = memoryChunks.foldLeft(("", 0)) { (acc, current) => - val currentString: String = current.value.toString(2).reverse.padTo(8, '0').reverse - endian match { - case Endian.LittleEndian => (currentString + acc._1, acc._2 + current.size) - case Endian.BigEndian => (acc._1 + currentString, acc._2 + current.size) + def getMem(name: String): Map[BitVecLiteral, BitVecLiteral] = { + stackFrames(globalFrame)(name) match { + case MapValue(innerMap, MapType(BitVecType(ks), BitVecType(vs))) => { + def unwrap(v: BasilValue): BitVecLiteral = v match { + case Scalar(b: BitVecLiteral) => b + case v => throw Exception(s"Failed to convert map value to bitvector: $v (interpreter type error somewhere)") + } + innerMap.map((k, v) => unwrap(k) -> unwrap(v)) } + case v => throw Exception(s"$name not a bitvec map variable: ${v.irType}") } + } + + /** Local Variable Stack * */ - BitVecLiteral(BigInt(newValue, 2), newSize) + def pushStackFrame(function: String): MemoryState = { + val counts = activationCount + (function -> (activationCount(function) + 1)) + val frameName: StackFrameID = s"AR_${function}_${activationCount(function)}" + val frames = stackFrames + (frameName -> Map.empty) + MemoryState(frames, frameName :: activations, counts) } - def setMemory( - index: Int, - size: Int, - endian: Endian, - value: BitVecLiteral, - env: mutable.Map[Int, BitVecLiteral] - ): BitVecLiteral = { - val binaryString: String = value.value.toString(2).reverse.padTo(size, '0').reverse + def popStackFrame(): MemoryState = { + val (frame, remactivs) = activations match { + case Nil => throw InterpreterError(Errored("No stack frame to pop")) + case h :: Nil if h == globalFrame => throw InterpreterError(Errored("tried to pop global scope")) + case h :: tl => (h, tl) + } + val frames = stackFrames.removed(frame) + MemoryState(frames, remactivs, activationCount) + } + + /* Variable retrieval and setting */ + + def setVar(frame: StackFrameID, varname: String, value: BasilValue): MemoryState = { + val nv = stackFrames + (frame -> (stackFrames(frame) + (varname -> value))) + MemoryState(nv, activations, activationCount) + } + + def setVar(v: String, value: BasilValue): MemoryState = { + val frame = findVarOpt(v).map(_._1).getOrElse(activations.head) + setVar(frame, v, value) + } + + def setVar(v: String, value: Literal): MemoryState = { + setVar(v, Scalar(value)) + } + + def defVar(v: Variable, value: Literal): MemoryState = { + val frame = v.toBoogie.scope match { + case Scope.Global => globalFrame + case _ => activations.head + } + setVar(frame, v.name, Scalar(value)) + } + + def findVarOpt(name: String): Option[(StackFrameID, BasilValue)] = { + val searchScopes = globalFrame :: activations.headOption.toList + searchScopes.foldRight(None: Option[(StackFrameID, BasilValue)])((r, acc) => + acc match { + case None => stackFrames(r).get(name).map(v => (r, v)) + case s => s + } + ) + } + + def findVar(name: String): (StackFrameID, BasilValue) = { + findVarOpt(name: String).getOrElse(throw InterpreterError(Errored(s"Access to undefined variable $name"))) + } + + def getVarOpt(name: String): Option[BasilValue] = findVarOpt(name).map(_._2) + + def getVar(name: String): BasilValue = { + getVarOpt(name).getOrElse(throw InterpreterError(Errored(s"Access undefined variable $name"))) + } + + def getVar(v: Variable): BasilValue = { + val value = getVar(v.name) + value match { + case dv: BasilValue if v.getType != dv.irType => + throw InterpreterError( + Errored(s"Type mismatch on variable definition and load: defined ${dv.irType}, variable ${v.getType}") + ) + case o => o + } + } + + def getVarLiteralOpt(v: Variable): Option[Literal] = { + getVar(v) match { + case Scalar(v) => Some(v) + case _ => None + } + } + + /* Map variable accessing ; load and store operations */ + + /* canonical load operation */ + def load(vname: String, addr: Scalar, endian: Endian, count: Int): List[BasilValue] = { + val (frame, mem) = findVar(vname) + + val mapv: MapValue = mem match { + case m @ MapValue(innerMap, ty) => m + case _ => throw InterpreterError(Errored("Load from nonmap")) + } + + if (count == 0) { + throw InterpreterError(Errored(s"Attempted fractional load")) + } + + val keys = (0 until count).map(i => BasilValue.unsafeAdd(addr, i)) - val data: List[BitVecLiteral] = endian match { - case Endian.LittleEndian => - binaryString.grouped(8).toList.map(chunk => BitVecLiteral(BigInt(chunk, 2), 8)).reverse - case Endian.BigEndian => - binaryString.grouped(8).toList.map(chunk => BitVecLiteral(BigInt(chunk, 2), 8)) + val values = keys.map(k => + mapv.value.get(k).getOrElse(throw InterpreterError(MemoryError(s"Read from uninitialised $vname[$k]"))) + ) + + val vals = endian match { + case Endian.LittleEndian => values.reverse + case Endian.BigEndian => values + } + + vals.toList + } + + /** Load and concat bitvectors */ + def loadBV(vname: String, addr: Scalar, endian: Endian, size: Int): BitVecLiteral = { + val (frame, mem) = findVar(vname) + + val (valsize, mapv) = mem match { + case mapv @ MapValue(_, MapType(_, BitVecType(sz))) => (sz, mapv) + case _ => throw InterpreterError(Errored("Trued to load-concat non bv")) } - data.zipWithIndex.foreach { case (bv, i) => - env(index + i) = bv + val cells = size / BasilValue.size(mapv.irType.result) + + val bvs: List[BitVecLiteral] = { + val res = load(vname, addr, endian, cells) + val rr = res.map { + case Scalar(bv @ BitVecLiteral(v, sz)) if sz == valsize => bv + case _ => throw InterpreterError(Errored(s"Loaded value that did not match expected type bv$valsize")) + } + rr } - value + val bvres = bvs.foldLeft(BitVecLiteral(0, 0))((acc, r) => eval.evalBVBinExpr(BVCONCAT, acc, r)) + assert(bvres.size == size) + bvres + } + + def loadSingle(vname: String, addr: Scalar): BasilValue = { + load(vname, addr, Endian.LittleEndian, 1).head } - private def interpretProcedure(p: Procedure): Unit = { - Logger.debug(s"Procedure(${p.name}, ${p.address.getOrElse("None")})") - Logger.debug(s"Regs $regs") + /** Canonical store operation */ + def store(vname: String, addr: BasilValue, values: List[BasilValue], endian: Endian) = { + val (frame, mem) = findVar(vname) - // Procedure.in - for ((in, index) <- p.in.zipWithIndex) { - Logger.debug(s"\tin[$index]:${in.name} ${in.size} ${in.value}") + val (mapval, keytype, valtype) = mem match { + case m @ MapValue(_, MapType(kt, vt)) if kt == addr.irType && values.forall(v => v.irType == vt) => (m, kt, vt) + case _ => throw InterpreterError(Errored("Invalid map store operation.")) } + val keys = (0 until values.size).map(i => BasilValue.unsafeAdd(addr, i)) + val vals = endian match { + case Endian.LittleEndian => values.reverse + case Endian.BigEndian => values + } + + val nmap = MapValue(mapval.value ++ keys.zip(vals), mapval.irType) + setVar(frame, vname, nmap) + } - // Procedure.out - for ((out, index) <- p.out.zipWithIndex) { - Logger.debug(s"\tout[$index]:${out.name} ${out.size} ${out.value}") + /** Store extract bitvec to bytes and store bytes */ + def storeBV(vname: String, addr: BasilValue, value: BitVecLiteral, endian: Endian) = { + val (frame, mem) = findVar(vname) + val (mapval, vsize) = mem match { + case m @ MapValue(_, MapType(kt, BitVecType(size))) if kt == addr.irType => (m, size) + case _ => throw InterpreterError(Errored("Tried to extract-store non to bv map")) } + val cells = value.size / vsize + if (cells < 1) { + throw InterpreterError(Errored("Tried to execute fractional store")) + } + + val extractVals = (0 until cells).map(i => BitVectorEval.boogie_extract((i + 1) * vsize, i * vsize, value)).toList - // Procedure.Block - p.entryBlock match { - case Some(block) => nextCmd = Run(block.statements.headOption.getOrElse(block.jump)) - case None => doReturn() + val vs = endian match { + case Endian.LittleEndian => extractVals.reverse.map(Scalar(_)) + case Endian.BigEndian => extractVals.map(Scalar(_)) } + + store(vname, addr, vs, endian) } - private def interpretJump(j: Jump): Unit = { - Logger.debug(s"jump: $j") - breakable { - j match { - case gt: GoTo => - for (g <- gt.targets) { - val condition: Option[Expr] = g.statements.headOption.collect { case a: Assume => a.body } - condition match { - case Some(e) => - if (evalBool(e, regs)) { - Logger.debug(s"chosen ${g.label}") - nextCmd = Run(g.statements.headOption.getOrElse(g.jump)) - break - } - case None => - nextCmd = Run(g.statements.headOption.getOrElse(g.jump)) - break - } - } - case r: Return => doReturn() - case h: Unreachable => { - Logger.debug("Unreachable") - nextCmd = Stopped() + def storeSingle(vname: String, addr: BasilValue, value: BasilValue) = { + store(vname, addr, List(value), Endian.LittleEndian) + } + +} + +case object Eval { + + def getVar(s: MemoryState)(v: Variable) = s.getVarLiteralOpt(v) + + def doLoad(s: MemoryState)(m: Memory, addr: Expr, endian: Endian, sz: Int): Option[Literal] = { + addr match { + case l: Literal if sz == 1 => ( + s.loadSingle(m.name, Scalar(l)) match { + case Scalar(v) => Some(v) + case _ => None } + ) + case l: Literal => Some(s.loadBV(m.name, Scalar(l), endian, sz)) + case _ => None + } + } + + def evalBV(s: MemoryState, e: Expr): BitVecLiteral = { + ir.eval.evalBVExpr(e, Eval.getVar(s), Eval.doLoad(s)) match { + case Right(e) => e + case Left(e) => throw InterpreterError(Errored(s"Eval BV residual $e")) + } + } + + def evalInt(s: MemoryState, e: Expr): BigInt = { + ir.eval.evalIntExpr(e, Eval.getVar(s), Eval.doLoad(s)) match { + case Right(e) => e + case Left(e) => throw InterpreterError(Errored(s"Eval int residual $e")) + } + } + + def evalBool(s: MemoryState, e: Expr): Boolean = { + ir.eval.evalLogExpr(e, Eval.getVar(s), Eval.doLoad(s)) match { + case Right(e) => e + case Left(e) => throw InterpreterError(Errored(s"Eval bool residual $e")) + } + } + +} + +def initialState(): MemoryState = { + val SP: BitVecLiteral = BitVecLiteral(4096 - 16, 64) + val FP: BitVecLiteral = BitVecLiteral(4096 - 16, 64) + val LR: BitVecLiteral = BitVecLiteral(BigInt("FF", 16), 64) + + MemoryState() + .setVar(globalFrame, "mem", MapValue(Map.empty, MapType(BitVecType(64), BitVecType(8)))) + .setVar(globalFrame, "stack", MapValue(Map.empty, MapType(BitVecType(64), BitVecType(8)))) + .setVar(globalFrame, "R31", Scalar(SP)) + .setVar(globalFrame, "R29", Scalar(FP)) + .setVar(globalFrame, "R30", Scalar(LR)) +} + +sealed trait Effects[T <: Effects[T]] { + /* evaluation (may side-effect via InterpreterException on evaluation failure) */ + def evalBV(e: Expr): BitVecLiteral + + def evalInt(e: Expr): BigInt + + def evalBool(e: Expr): Boolean + + /** effects * */ + def setNext(c: ExecutionContinuation): T + + def call(c: DirectCall): T + + def doReturn(): T + + def storeVar(v: Variable, value: Literal): T + + def storeMem(vname: String, addr: BitVecLiteral, value: BitVecLiteral, endian: Endian, size: Int): T + + def initialiseProgram(p: Program): T +} + +case class InterpreterCFState( + val nextCmd: ExecutionContinuation = Stopped(), + val callStack: List[DirectCall] = List.empty, + val memoryState: MemoryState = MemoryState() +) extends Effects[InterpreterCFState] { + + /** eval * */ + def evalBV(e: Expr): BitVecLiteral = Eval.evalBV(memoryState, e) + + def evalInt(e: Expr): BigInt = Eval.evalInt(memoryState, e) + + def evalBool(e: Expr): Boolean = Eval.evalBool(memoryState, e) + + /** effects * */ + def setNext(c: ExecutionContinuation): InterpreterCFState = { + InterpreterCFState(c, callStack, memoryState) + } + + def call(c: DirectCall): InterpreterCFState = { + Logger.debug(s" eff : CALL $c") + c.target.entryBlock match { + case Some(block) => + InterpreterCFState( + Run(block.statements.headOption.getOrElse(block.jump)), + c :: callStack, + memoryState.pushStackFrame(c.target.name) + ) + case None => setNext(Run(c.successor)) + } + } + + def doReturn(): InterpreterCFState = { + callStack match { + case Nil => InterpreterCFState(Stopped(), Nil, memoryState) + case h :: tl => { + Logger.debug(s" eff : RETURN $h") + InterpreterCFState(Run(h.successor), tl, memoryState.popStackFrame()) } } } - private def interpretStatement(s: Statement): Unit = { - Logger.debug(s"Regs $regs") - Logger.debug(s"statement[$s]:") + def storeVar(v: Variable, value: Literal) = { + Logger.debug(s" eff : SET $v := $value") + InterpreterCFState(nextCmd, callStack, memoryState.defVar(v, value)) + } + + def storeMem( + vname: String, + addr: BitVecLiteral, + value: BitVecLiteral, + endian: Endian, + size: Int + ): InterpreterCFState = { + Logger.debug(s" eff : STORE $vname[$addr..$addr + $size] := $value ($endian)") + InterpreterCFState(nextCmd, callStack, memoryState.storeBV(vname, Scalar(addr), value, endian)) + } + + def initialiseProgram(p: Program): InterpreterCFState = { + val mem = p.initialMemory.foldLeft(initialState())((s, memory) => { + s.store("mem", Scalar(BitVecLiteral(memory.address, 64)), memory.bytes.map(Scalar(_)).toList, Endian.LittleEndian) + s.store( + "stack", + Scalar(BitVecLiteral(memory.address, 64)), + memory.bytes.map(Scalar(_)).toList, + Endian.LittleEndian + ) + }) + + InterpreterCFState( + Run(IRWalk.firstInBlock(p.mainProcedure.entryBlock.get)), + callStack, + mem.pushStackFrame(p.mainProcedure.name) + ) + } +} - nextCmd = Run(s.successor) +case object InterpFuns { - s match { - case assign: Assign => - //Logger.debug(s"LocalAssign ${assign.lhs} = ${assign.rhs}") - val evalRight = eval(assign.rhs, regs) - //Logger.debug(s"LocalAssign ${assign.lhs} := 0x${evalRight.value.toString(16)}[u${evalRight.size}]\n") - regs += (assign.lhs -> evalRight) - - case assign: MemoryAssign => - //Logger.debug(s"MemoryAssign ${assign.mem}[${assign.index}] = ${assign.value}") - - val index: Int = eval(assign.index, regs).value.toInt - val value: BitVecLiteral = eval(assign.value, regs) - //Logger.debug(s"\tMemoryStore(mem:${assign.mem}, index:0x${index.toHexString}, value:0x${value.value - // .toString(16)}[u${value.size}], size:${assign.size})") - - val evalStore = setMemory(index, assign.size, assign.endian, value, mems) - evalStore match { - case BitVecLiteral(value, size) => - //Logger.debug(s"MemoryAssign ${assign.mem} := 0x${value.toString(16)}[u$size]\n") + def interpretJump[T <: Effects[T]](s: T, j: Jump): T = { + j match { + case gt: GoTo if gt.targets.size == 1 => { + s.setNext(Run(IRWalk.firstInBlock(gt.targets.head))) + } + case gt: GoTo => + val condition = gt.targets.flatMap(_.statements.headOption).collect { case a: Assume => + (a, s.evalBool(a.body)) + } + + if (condition.size != gt.targets.size) { + throw InterpreterError(Errored(s"Some goto target missing guard $gt")) } - case assert: Assert => - // Logger.debug(assert) - if (!evalBool(assert.body, regs)) { - nextCmd = FailedAssertion(assert) + + val chosen = condition.filter(_._2).toList match { + case Nil => throw InterpreterError(Errored(s"No jump target satisfied $gt")) + case h :: Nil => h + case h :: tl => throw InterpreterError(Errored(s"More than one jump guard satisfied $gt")) + } + Logger.debug(s"Goto ${chosen._1.parent.label}") + + s.setNext(Run(chosen._1.successor)) + case r: Return => s.doReturn() + case h: Unreachable => s.setNext(EscapedControlFlow(h)) + } + } + + def interpretStatement[T <: Effects[T]](st: T, s: Statement): T = { + s match { + case assign: Assign => { + val rhs = st.evalBV(assign.rhs) + st.storeVar(assign.lhs, rhs).setNext(Run(s.successor)) + + } + case assign: MemoryAssign => { + val index: BitVecLiteral = st.evalBV(assign.index) + val value: BitVecLiteral = st.evalBV(assign.value) + st.storeMem(assign.mem.name, index, value, assign.endian, assign.size).setNext(Run(s.successor)) + + } + case assert: Assert => { + if (!st.evalBool(assert.body)) then { + st.setNext(FailedAssertion(assert)) + } else { + st.setNext(Run(s.successor)) + } + } case assume: Assume => - // TODO, but already taken into effect if it is a branch condition - // Logger.debug(assume) - if (!evalBool(assume.body, regs)) { - nextCmd = Errored(s"Assumption not satisfied: $assume") + if (!st.evalBool(assume.body)) { + st.setNext(Errored(s"Assumption not satisfied: $assume")) + } else { + st.setNext(Run(s.successor)) + } case dc: DirectCall => - // Logger.debug(s"$dc") - callStack.push(dc.successor) - interpretProcedure(dc.target) + st.call(dc) case ic: IndirectCall => - // Logger.debug(s"$ic") if (ic.target == Register("R30", 64)) { - doReturn() - //Exit Interpreter + st.doReturn() } else { - nextCmd = EscapedControlFlow(ic) - } - case _: NOP => () - } - } - - def interpret(IRProgram: Program): mutable.Map[Variable, BitVecLiteral] = { - // initialize memory array from IRProgram - var currentAddress = 0 - IRProgram.initialMemory - .sortBy(_.address) - .foreach { im => - if (im.address + im.size > currentAddress) { - val start = im.address.max(currentAddress) - val data = if (im.address < currentAddress) im.bytes.slice(currentAddress - im.address, im.size) else im.bytes - data.zipWithIndex.foreach { (byte, index) => - mems(start + index) = byte - } - currentAddress = im.address + im.size + st.setNext(EscapedControlFlow(ic)) } - } + case _: NOP => st.setNext(Run(s.successor)) + } + } - // Initial SP, FP and LR to regs - regs += (Register("R31", 64) -> SP) - regs += (Register("R29", 64) -> FP) - regs += (Register("R30", 64) -> LR) - - // Program.Procedure - interpretProcedure(IRProgram.mainProcedure) - while ( - try {nextCmd match { - case Run(c: Statement) => { - interpretStatement(c) - true - } - case Run(c: Jump) => { - interpretJump(c) - true - } - case Stopped() => { - false - } - case errorstop => { - Logger.error(s"Interpreter $errorstop") - false - } - } + def protect[T](x: () => T, fnly: PartialFunction[Exception, T]): T = { + try { + x() } catch { - case InterpreterError(e) => { - nextCmd = e - true - } + case e: Exception if fnly.isDefinedAt(e) => fnly(e) } + } - ) {} + @tailrec + def interpret(s: InterpreterCFState): InterpreterCFState = { + Logger.debug(s"interpret ${s.nextCmd}") + s.nextCmd match { + case Run(c: Statement) => + interpret( + protect[InterpreterCFState]( + () => interpretStatement(s, c), + { + case InterpreterError(e) => s.setNext(e) + case e: IllegalArgumentException => s.setNext(Errored(s"Evaluation error $e")) + } + ) + ) + case Run(c: Jump) => + interpret( + protect[InterpreterCFState]( + () => interpretJump(s, c), + { + case InterpreterError(e) => s.setNext(e) + case e: IllegalArgumentException => s.setNext(Errored(s"Evaluation error $e")) + } + ) + ) + case Stopped() => s + case errorstop => s + } + } - regs + def interpretProg(p: Program): InterpreterCFState = { + var s = InterpreterCFState().initialiseProgram(p) + interpret(s) } + +} + +def interpret(IRProgram: Program) = { + InterpFuns.interpretProg(IRProgram) } diff --git a/src/main/scala/util/RunUtils.scala b/src/main/scala/util/RunUtils.scala index aeaaa6916..73768e61b 100644 --- a/src/main/scala/util/RunUtils.scala +++ b/src/main/scala/util/RunUtils.scala @@ -461,8 +461,8 @@ object RunUtils { q.loading.dumpIL.foreach(s => writeToFile(serialiseIL(ctx.program), s"$s-after-analysis.il")) if (q.runInterpret) { - val interpreter = Interpreter() - interpreter.interpret(ctx.program) + // val interpreter = eval.Interpreter() + eval.interpret(ctx.program) } IRTransform.prepareForTranslation(q.loading, ctx) diff --git a/src/test/scala/ir/InterpreterTests.scala b/src/test/scala/ir/InterpreterTests.scala index 56bbbf7d3..85e8c5c70 100644 --- a/src/test/scala/ir/InterpreterTests.scala +++ b/src/test/scala/ir/InterpreterTests.scala @@ -1,6 +1,6 @@ package ir -import ir.eval.* +import ir.eval._ import ir.dsl._ import org.scalatest.funsuite.AnyFunSuite import org.scalatest.BeforeAndAfter @@ -10,13 +10,30 @@ import util.{LogLevel, Logger} import util.IRLoading.{loadBAP, loadReadELF} import util.ILLoadingConfig + +def load[T <: Effects[T]](s: T, global: SpecGlobal) : Option[BitVecLiteral] = { + // i.getMemory(global.address.toInt, global.size, Endian.LittleEndian, i.mems) + // m.evalBV("mem", BitVecLiteral(64, global.address), Endian.LittleEndian, global.size) // i.getMemory(global.address.toInt, global.size, Endian.LittleEndian, i.mems) + + try { + Some(s.evalBV(MemoryLoad(SharedMemory("mem", 64, 8), BitVecLiteral(global.address, 64), Endian.LittleEndian, global.size))) + } catch { + case e : InterpreterError => None + } +} + + +def mems[T <: Effects[T]](m: MemoryState) : Map[BigInt, BitVecLiteral] = { + m.getMem("mem").map((k,v) => k.value -> v) +} + class InterpreterTests extends AnyFunSuite with BeforeAndAfter { - var i: Interpreter = Interpreter() + // var i: Interpreter = Interpreter() Logger.setLevel(LogLevel.DEBUG) - def getProgram(name: String): (Program, Set[SpecGlobal]) = { + def getProgram(name: String): (Program, Set[SpecGlobal]) = { val loading = ILLoadingConfig( inputFile = s"examples/$name/$name.adt", @@ -41,7 +58,8 @@ class InterpreterTests extends AnyFunSuite with BeforeAndAfter { def testInterpret(name: String, expected: Map[String, Int]): Unit = { val (program, globals) = getProgram(name) - val regs = i.interpret(program) + val fstate = interpret(program) + val regs = fstate.memoryState.getGlobalVals // Show interpreted result Logger.info("Registers:") @@ -50,68 +68,114 @@ class InterpreterTests extends AnyFunSuite with BeforeAndAfter { } Logger.info("Globals:") + // def loadBV(vname: String, addr: BasilValue, valueSize: Int, endian: Endian, size: Int): List[BitVecLiteral] = { globals.foreach { global => - val mem = i.getMemory(global.address.toInt, global.size, Endian.LittleEndian, i.mems) - Logger.info(s"$global := $mem") + val mem = load(fstate, global) + mem.foreach(mem => Logger.info(s"$global := $mem")) } // Test expected value - expected.foreach { (name, expected) => - globals.find(_.name == name) match { - case Some(global) => - val actual = i.getMemory(global.address.toInt, global.size, Endian.LittleEndian, i.mems).value.toInt - assert(actual == expected) - case None => assert("None" == name) - } - } + val actual : Map[String, Int] = expected.flatMap ( (name, expected) => + globals.find(_.name == name).flatMap(global => + load(fstate, global).map(gv => name -> gv.value.toInt) + ) + ) + assert(expected == actual) } - before { - i = Interpreter() - } - test("getMemory in LittleEndian") { - i.mems(0) = BitVecLiteral(BigInt("0D", 16), 8) - i.mems(1) = BitVecLiteral(BigInt("0C", 16), 8) - i.mems(2) = BitVecLiteral(BigInt("0B", 16), 8) - i.mems(3) = BitVecLiteral(BigInt("0A", 16), 8) - val expected: BitVecLiteral = BitVecLiteral(BigInt("0A0B0C0D", 16), 32) - val actual: BitVecLiteral = i.getMemory(0, 32, Endian.LittleEndian, i.mems) + test("Store = Load LittleEndian") { + val ts = List( + BitVecLiteral(BigInt("0D", 16), 8), + BitVecLiteral(BigInt("0C", 16), 8), + BitVecLiteral(BigInt("0B", 16), 8), + BitVecLiteral(BigInt("0A", 16), 8)) + + val s = initialState().store("mem", Scalar(BitVecLiteral(0, 64)), ts.map(Scalar(_)), Endian.LittleEndian) + val expected: BitVecLiteral = BitVecLiteral(BigInt("0D0C0B0A", 16), 32) + val actual: BitVecLiteral = s.loadBV("mem", Scalar(BitVecLiteral(0, 64)), Endian.LittleEndian, 32) assert(actual == expected) + + val s2 = initialState().storeBV("mem", Scalar(BitVecLiteral(0, 64)), expected, Endian.LittleEndian) + val actual2: BitVecLiteral = s.loadBV("mem", Scalar(BitVecLiteral(0, 64)), Endian.LittleEndian, 32) + assert(actual2 == actual) + } - test("getMemory in BigEndian") { - i.mems(0) = BitVecLiteral(BigInt("0A", 16), 8) - i.mems(1) = BitVecLiteral(BigInt("0B", 16), 8) - i.mems(2) = BitVecLiteral(BigInt("0C", 16), 8) - i.mems(3) = BitVecLiteral(BigInt("0D", 16), 8) + + test("Store = Load BigEndian") { + val ts = List( + BitVecLiteral(BigInt("0D", 16), 8), + BitVecLiteral(BigInt("0C", 16), 8), + BitVecLiteral(BigInt("0B", 16), 8), + BitVecLiteral(BigInt("0A", 16), 8)) + + val s = initialState().store("mem", Scalar(BitVecLiteral(0, 64)), ts.map(Scalar(_)), Endian.LittleEndian) val expected: BitVecLiteral = BitVecLiteral(BigInt("0A0B0C0D", 16), 32) - val actual: BitVecLiteral = i.getMemory(0, 32, Endian.BigEndian, i.mems) + val actual: BitVecLiteral = s.loadBV("mem", Scalar(BitVecLiteral(0, 64)), Endian.BigEndian , 32) assert(actual == expected) + + } - test("setMemory in LittleEndian") { - i.mems(0) = BitVecLiteral(BigInt("FF", 16), 8) - i.mems(1) = BitVecLiteral(BigInt("FF", 16), 8) - i.mems(2) = BitVecLiteral(BigInt("FF", 16), 8) - i.mems(3) = BitVecLiteral(BigInt("FF", 16), 8) + test("getMemory in LittleEndian") { + val ts = List((BitVecLiteral(0, 64), BitVecLiteral(BigInt("0D", 16), 8)), + (BitVecLiteral(1, 64) , BitVecLiteral(BigInt("0C", 16), 8)), + (BitVecLiteral(2, 64) , BitVecLiteral(BigInt("0B", 16), 8)), + (BitVecLiteral(3, 64) , BitVecLiteral(BigInt("0A", 16), 8))) + val s = ts.foldLeft(initialState())((m, v) => m.storeSingle("mem", Scalar(v._1), Scalar(v._2))) + // val s = initialState().store("mem") + // val r = s.loadBV("mem", BitVecLiteral(0, 64)) + val expected: BitVecLiteral = BitVecLiteral(BigInt("0A0B0C0D", 16), 32) - i.setMemory(0, 32, Endian.LittleEndian, expected, i.mems) - val actual: BitVecLiteral = i.getMemory(0, 32, Endian.LittleEndian, i.mems) + + // def loadBV(vname: String, addr: Scalar, endian: Endian, size: Int): BitVecLiteral = { + val actual: BitVecLiteral = s.loadBV("mem", Scalar(BitVecLiteral(0, 64)), Endian.LittleEndian, 32) assert(actual == expected) } - test("setMemory in BigEndian") { - i.mems(0) = BitVecLiteral(BigInt("FF", 16), 8) - i.mems(1) = BitVecLiteral(BigInt("FF", 16), 8) - i.mems(2) = BitVecLiteral(BigInt("FF", 16), 8) - i.mems(3) = BitVecLiteral(BigInt("FF", 16), 8) + + test("StoreBV = LoadBV LE ") { val expected: BitVecLiteral = BitVecLiteral(BigInt("0A0B0C0D", 16), 32) - i.setMemory(0, 32, Endian.BigEndian, expected, i.mems) - val actual: BitVecLiteral = i.getMemory(0, 32, Endian.BigEndian, i.mems) + + val s = initialState().storeBV("mem", Scalar(BitVecLiteral(0, 64)), expected, Endian.LittleEndian) + val actual: BitVecLiteral = s.loadBV("mem", Scalar(BitVecLiteral(0, 64)), Endian.LittleEndian, 32) + println(s"${actual.value.toInt.toHexString} == ${expected.value.toInt.toHexString}") assert(actual == expected) } + // test("getMemory in BigEndian") { + // i.mems(0) = BitVecLiteral(BigInt("0A", 16), 8) + // i.mems(1) = BitVecLiteral(BigInt("0B", 16), 8) + // i.mems(2) = BitVecLiteral(BigInt("0C", 16), 8) + // i.mems(3) = BitVecLiteral(BigInt("0D", 16), 8) + // val expected: BitVecLiteral = BitVecLiteral(BigInt("0A0B0C0D", 16), 32) + // val actual: BitVecLiteral = i.getMemory(0, 32, Endian.BigEndian, i.mems) + // assert(actual == expected) + // } + + // test("setMemory in LittleEndian") { + // i.mems(0) = BitVecLiteral(BigInt("FF", 16), 8) + // i.mems(1) = BitVecLiteral(BigInt("FF", 16), 8) + // i.mems(2) = BitVecLiteral(BigInt("FF", 16), 8) + // i.mems(3) = BitVecLiteral(BigInt("FF", 16), 8) + // val expected: BitVecLiteral = BitVecLiteral(BigInt("0A0B0C0D", 16), 32) + // i.setMemory(0, 32, Endian.LittleEndian, expected, i.mems) + // val actual: BitVecLiteral = i.getMemory(0, 32, Endian.LittleEndian, i.mems) + // assert(actual == expected) + // } + + // test("setMemory in BigEndian") { + // i.mems(0) = BitVecLiteral(BigInt("FF", 16), 8) + // i.mems(1) = BitVecLiteral(BigInt("FF", 16), 8) + // i.mems(2) = BitVecLiteral(BigInt("FF", 16), 8) + // i.mems(3) = BitVecLiteral(BigInt("FF", 16), 8) + // val expected: BitVecLiteral = BitVecLiteral(BigInt("0A0B0C0D", 16), 32) + // i.setMemory(0, 32, Endian.BigEndian, expected, i.mems) + // val actual: BitVecLiteral = i.getMemory(0, 32, Endian.BigEndian, i.mems) + // assert(actual == expected) + // } + test("basic_arrays_read") { val expected = Map( "arr" -> 0 @@ -204,57 +268,69 @@ class InterpreterTests extends AnyFunSuite with BeforeAndAfter { } test("fibonacci") { - val fib = prog( - proc("begin", - block("entry", - Assign(R8, Register("R31", 64)), - Assign(R0, bv64(8)), - directCall("fib"), - goto("done") - ), - block("done", - Assert(BinaryExpr(BVEQ, R0, bv64(21))), - ret - )), - proc("fib", - block("base", goto("base1", "base2", "dofib")), - block("base1", - Assume(BinaryExpr(BVEQ, R0, bv64(0))), - ret), - block("base2", - Assume(BinaryExpr(BVEQ, R0, bv64(1))), - ret), - block("dofib", - Assume(BinaryExpr(BoolAND, BinaryExpr(BVNEQ, R0, bv64(0)), BinaryExpr(BVNEQ, R0, bv64(1)))), - // R8 stack pointer preserved across calls - Assign(R7, BinaryExpr(BVADD, R8, bv64(8))), - MemoryAssign(stack, R7, R8, Endian.LittleEndian, 64), // sp - Assign(R8, R7), - Assign(R8, BinaryExpr(BVADD, R8, bv64(8))), // sp + 8 - MemoryAssign(stack, R8, R0, Endian.LittleEndian, 64), // [sp + 8] = arg0 - Assign(R0, BinaryExpr(BVSUB, R0, bv64(1))), - directCall("fib"), - Assign(R2, R8), // sp + 8 - Assign(R8, BinaryExpr(BVADD, R8, bv64(8))), // sp + 16 - MemoryAssign(stack, R8, R0, Endian.LittleEndian, 64), // [sp + 16] = r1 - Assign(R0, MemoryLoad(stack, R2, Endian.LittleEndian, 64)), // [sp + 8] - Assign(R0, BinaryExpr(BVSUB, R0, bv64(2))), - directCall("fib"), - Assign(R2, MemoryLoad(stack, R8, Endian.LittleEndian, 64)), // [sp + 16] (r1) - Assign(R0, BinaryExpr(BVADD, R0, R2)), - Assign(R8, MemoryLoad(stack, BinaryExpr(BVSUB, R8, bv64(16)), Endian.LittleEndian, 64)), - ret + + def fibonacciProg(n: Int) = { + def expected(n: Int) : Int = { + n match { + case 0 => 0 + case 1 => 1 + case n => expected(n - 1) + expected(n - 2) + } + } + prog( + proc("begin", + block("entry", + Assign(R8, Register("R31", 64)), + Assign(R0, bv64(n)), + directCall("fib"), + goto("done") + ), + block("done", + Assert(BinaryExpr(BVEQ, R0, bv64(expected(n)))), + ret + )), + proc("fib", + block("base", goto("base1", "base2", "dofib")), + block("base1", + Assume(BinaryExpr(BVEQ, R0, bv64(0))), + ret), + block("base2", + Assume(BinaryExpr(BVEQ, R0, bv64(1))), + ret), + block("dofib", + Assume(BinaryExpr(BoolAND, BinaryExpr(BVNEQ, R0, bv64(0)), BinaryExpr(BVNEQ, R0, bv64(1)))), + // R8 stack pointer preserved across calls + Assign(R7, BinaryExpr(BVADD, R8, bv64(8))), + MemoryAssign(stack, R7, R8, Endian.LittleEndian, 64), // sp + Assign(R8, R7), + Assign(R8, BinaryExpr(BVADD, R8, bv64(8))), // sp + 8 + MemoryAssign(stack, R8, R0, Endian.LittleEndian, 64), // [sp + 8] = arg0 + Assign(R0, BinaryExpr(BVSUB, R0, bv64(1))), + directCall("fib"), + Assign(R2, R8), // sp + 8 + Assign(R8, BinaryExpr(BVADD, R8, bv64(8))), // sp + 16 + MemoryAssign(stack, R8, R0, Endian.LittleEndian, 64), // [sp + 16] = r1 + Assign(R0, MemoryLoad(stack, R2, Endian.LittleEndian, 64)), // [sp + 8] + Assign(R0, BinaryExpr(BVSUB, R0, bv64(2))), + directCall("fib"), + Assign(R2, MemoryLoad(stack, R8, Endian.LittleEndian, 64)), // [sp + 16] (r1) + Assign(R0, BinaryExpr(BVADD, R0, R2)), + Assign(R8, MemoryLoad(stack, BinaryExpr(BVSUB, R8, bv64(16)), Endian.LittleEndian, 64)), + ret + ) ) ) - ) + } - val regs = i.interpret(fib) + val fib = fibonacciProg(8) + val r = interpret(fib) + assert(r.nextCmd == Stopped()) // Show interpreted result Logger.info("Registers:") - regs.foreach { (key, value) => - Logger.info(s"$key := $value") - } + // r.regs.foreach { (key, value) => + // Logger.info(s"$key := $value") + // } } } From 74c4bc275d9efb89ebb68a40c585f48a64890f39 Mon Sep 17 00:00:00 2001 From: Alistair Michael Date: Thu, 22 Aug 2024 17:29:23 +1000 Subject: [PATCH 08/62] cleanup call/return --- src/main/scala/ir/eval/Interpreter.scala | 231 ++++++++++++++--------- src/test/scala/ir/InterpreterTests.scala | 24 ++- 2 files changed, 159 insertions(+), 96 deletions(-) diff --git a/src/main/scala/ir/eval/Interpreter.scala b/src/main/scala/ir/eval/Interpreter.scala index e2fda9e6b..40d3e4bc5 100644 --- a/src/main/scala/ir/eval/Interpreter.scala +++ b/src/main/scala/ir/eval/Interpreter.scala @@ -9,22 +9,20 @@ import scala.collection.mutable import scala.collection.immutable import scala.util.control.Breaks.{break, breakable} - sealed trait ExecutionContinuation case class FailedAssertion(a: Assert) extends ExecutionContinuation -case class Stopped() extends ExecutionContinuation -/** Normal stop * */ -case class Run(val next: Command) extends ExecutionContinuation -case class EscapedControlFlow(val call: Jump | Call) extends ExecutionContinuation +case class Stopped() extends ExecutionContinuation /* normal program stop */ +case class Run(val next: Command) extends ExecutionContinuation /* continue by executing next command */ + +case class EscapedControlFlow(val call: Jump | Call) + extends ExecutionContinuation /* controlflow has reached somewhere eunrecoverable */ -/** controlflow has reached somewhere eunrecoverable */ case class Errored(val message: String = "") extends ExecutionContinuation case class TypeError(val message: String = "") extends ExecutionContinuation /* type mismatch appeared */ case class EvalError(val message: String = "") extends ExecutionContinuation /* failed to evaluate an expression to a concrete value */ -case class MemoryError(val message: String = "") - extends ExecutionContinuation /* failed to evaluate an expression to a concrete value */ +case class MemoryError(val message: String = "") extends ExecutionContinuation /* An error to do with memory */ case class InterpreterError(continue: ExecutionContinuation) extends Exception() @@ -41,15 +39,6 @@ enum BasilValue(val irType: IRType): // \exists j . \forall v \in value.values, v.irType = j case MapValue(val value: Map[BasilValue, BasilValue], override val irType: MapType) extends BasilValue(irType) -given Conversion[BitVecLiteral, Scalar] with - def apply(b: BitVecLiteral): Scalar = new Scalar(b) - -given Conversion[IntLiteral, Scalar] with - def apply(b: IntLiteral): Scalar = new Scalar(b) - -given Conversion[BoolLit, Scalar] with - def apply(b: BoolLit): Scalar = new Scalar(b) - case object BasilValue: def size(v: IRType): Int = { @@ -63,6 +52,7 @@ case object BasilValue: def unsafeAdd(l: BasilValue, vr: Int): BasilValue = { l match { + case _ if vr == 0 => l case Scalar(IntLiteral(vl)) => Scalar(IntLiteral(vl + vr)) case Scalar(b1: BitVecLiteral) => Scalar(eval.evalBVBinExpr(BVADD, b1, BitVecLiteral(vr, b1.size))) case _ => throw InterpreterError(TypeError(s"Operation add $vr undefined on $l")) @@ -172,28 +162,30 @@ case class MemoryState( /* Variable retrieval and setting */ + /* Set variable in a given frame */ def setVar(frame: StackFrameID, varname: String, value: BasilValue): MemoryState = { val nv = stackFrames + (frame -> (stackFrames(frame) + (varname -> value))) MemoryState(nv, activations, activationCount) } + /* Find variable definition scope and set it in the correct frame */ def setVar(v: String, value: BasilValue): MemoryState = { val frame = findVarOpt(v).map(_._1).getOrElse(activations.head) setVar(frame, v, value) } - def setVar(v: String, value: Literal): MemoryState = { - setVar(v, Scalar(value)) - } - - def defVar(v: Variable, value: Literal): MemoryState = { - val frame = v.toBoogie.scope match { + /* Define a variable in the scope specified + * ignoring whether it may already be defined + */ + def defVar(name: String, s: Scope , value: BasilValue): MemoryState = { + val frame = s match { case Scope.Global => globalFrame case _ => activations.head } - setVar(frame, v.name, Scalar(value)) + setVar(frame, name, value) } + /* Lookup the value of a variable */ def findVarOpt(name: String): Option[(StackFrameID, BasilValue)] = { val searchScopes = globalFrame :: activations.headOption.toList searchScopes.foldRight(None: Option[(StackFrameID, BasilValue)])((r, acc) => @@ -296,7 +288,7 @@ case class MemoryState( val (mapval, keytype, valtype) = mem match { case m @ MapValue(_, MapType(kt, vt)) if kt == addr.irType && values.forall(v => v.irType == vt) => (m, kt, vt) - case _ => throw InterpreterError(Errored("Invalid map store operation.")) + case v => throw InterpreterError(TypeError(s"Invalid map store operation to $vname : ${v.irType}")) } val keys = (0 until values.size).map(i => BasilValue.unsafeAdd(addr, i)) val vals = endian match { @@ -308,16 +300,21 @@ case class MemoryState( setVar(frame, vname, nmap) } - /** Store extract bitvec to bytes and store bytes */ + /** Extract bitvec to bytes and store bytes */ def storeBV(vname: String, addr: BasilValue, value: BitVecLiteral, endian: Endian) = { val (frame, mem) = findVar(vname) val (mapval, vsize) = mem match { case m @ MapValue(_, MapType(kt, BitVecType(size))) if kt == addr.irType => (m, size) - case _ => throw InterpreterError(Errored("Tried to extract-store non to bv map")) + case v => + throw InterpreterError( + TypeError( + s"Invalid map store operation to $vname : ${v.irType} (expect [${addr.irType}] <- ${value.getType})" + ) + ) } val cells = value.size / vsize if (cells < 1) { - throw InterpreterError(Errored("Tried to execute fractional store")) + throw InterpreterError(MemoryError("Tried to execute fractional store")) } val extractVals = (0 until cells).map(i => BitVectorEval.boogie_extract((i + 1) * vsize, i * vsize, value)).toList @@ -376,18 +373,6 @@ case object Eval { } -def initialState(): MemoryState = { - val SP: BitVecLiteral = BitVecLiteral(4096 - 16, 64) - val FP: BitVecLiteral = BitVecLiteral(4096 - 16, 64) - val LR: BitVecLiteral = BitVecLiteral(BigInt("FF", 16), 64) - - MemoryState() - .setVar(globalFrame, "mem", MapValue(Map.empty, MapType(BitVecType(64), BitVecType(8)))) - .setVar(globalFrame, "stack", MapValue(Map.empty, MapType(BitVecType(64), BitVecType(8)))) - .setVar(globalFrame, "R31", Scalar(SP)) - .setVar(globalFrame, "R29", Scalar(FP)) - .setVar(globalFrame, "R30", Scalar(LR)) -} sealed trait Effects[T <: Effects[T]] { /* evaluation (may side-effect via InterpreterException on evaluation failure) */ @@ -400,22 +385,69 @@ sealed trait Effects[T <: Effects[T]] { /** effects * */ def setNext(c: ExecutionContinuation): T - def call(c: DirectCall): T + def call(target: String, beginFrom: ExecutionContinuation, returnTo: ExecutionContinuation): T def doReturn(): T - def storeVar(v: Variable, value: Literal): T + def storeVar(v: String, scope: Scope, value: BasilValue): T def storeMem(vname: String, addr: BitVecLiteral, value: BitVecLiteral, endian: Endian, size: Int): T - def initialiseProgram(p: Program): T + def storeMultiMem(vname: String, addr: Literal, value: List[Literal], endian: Endian): T } -case class InterpreterCFState( + + +// enum Effect: +// case Call(c: DirectCall) +// case SetNext(c: ExecutionContinuation) +// case Return +// case StoreVar(v: Variable, value: Literal) +// case StoreMem(vname: String, addr: BitVecLiteral, value: BitVecLiteral, endian: Endian, size: Int) +// case StoreMultiMem(vname: String, addr: Literal, value: List[Literal], endian: Endian) +// +// case class TracingInterpreter( +// val s: InterpreterState, +// val trace: List[Effect] +// ) extends Effects[TracingInterpreter] { +// +// +// def evalBV(e: Expr): BitVecLiteral = s.evalBV(e) +// def evalInt(e: Expr): BigInt = s.evalInt(e) +// def evalBool(e: Expr): Boolean = s.evalBool(e) +// +// /** effects * */ +// def setNext(c: ExecutionContinuation) = { +// TracingInterpreter(s.setNext(c), Effect.SetNext(c)::trace) +// } +// +// def call(c: DirectCall) = { +// TracingInterpreter(s.call(c), Effect.Call(c)::trace) +// } +// +// def doReturn() = { +// TracingInterpreter(s.doReturn(), Effect.Return::trace) +// } +// +// def storeVar(v: Variable, value: Literal) = { +// TracingInterpreter(s.storeVar(v, value), Effect.StoreVar(v,value)::trace) +// } +// +// def storeMem(vname: String, addr: BitVecLiteral, value: BitVecLiteral, endian: Endian, size: Int) = { +// TracingInterpreter(s.storeMem(vname, addr, value, endian, size), Effect.StoreMem(vname,addr,value,endian,size)::trace) +// } +// +// def storeMultiMem(vname: String, addr: Literal, value: List[Literal], endian: Endian) = { +// TracingInterpreter(s.storeMultiMem(vname, addr, value, endian), Effect.StoreMultiMem(vname,addr,value,endian)::trace) +// } +// +// } + +case class InterpreterState( val nextCmd: ExecutionContinuation = Stopped(), - val callStack: List[DirectCall] = List.empty, + val callStack: List[ExecutionContinuation] = List.empty, val memoryState: MemoryState = MemoryState() -) extends Effects[InterpreterCFState] { +) extends Effects[InterpreterState] { /** eval * */ def evalBV(e: Expr): BitVecLiteral = Eval.evalBV(memoryState, e) @@ -425,36 +457,32 @@ case class InterpreterCFState( def evalBool(e: Expr): Boolean = Eval.evalBool(memoryState, e) /** effects * */ - def setNext(c: ExecutionContinuation): InterpreterCFState = { - InterpreterCFState(c, callStack, memoryState) - } - - def call(c: DirectCall): InterpreterCFState = { - Logger.debug(s" eff : CALL $c") - c.target.entryBlock match { - case Some(block) => - InterpreterCFState( - Run(block.statements.headOption.getOrElse(block.jump)), - c :: callStack, - memoryState.pushStackFrame(c.target.name) - ) - case None => setNext(Run(c.successor)) - } + def setNext(c: ExecutionContinuation): InterpreterState = { + InterpreterState(c, callStack, memoryState) + } + + def call(target: String, beginFrom: ExecutionContinuation, returnTo: ExecutionContinuation): InterpreterState = { + Logger.debug(s" eff : CALL $target") + InterpreterState( + beginFrom, + returnTo :: callStack, + memoryState.pushStackFrame(target) + ) } - def doReturn(): InterpreterCFState = { + def doReturn(): InterpreterState = { callStack match { - case Nil => InterpreterCFState(Stopped(), Nil, memoryState) + case Nil => InterpreterState(Stopped(), Nil, memoryState) case h :: tl => { Logger.debug(s" eff : RETURN $h") - InterpreterCFState(Run(h.successor), tl, memoryState.popStackFrame()) + InterpreterState(h, tl, memoryState.popStackFrame()) } } } - def storeVar(v: Variable, value: Literal) = { + def storeVar(v: String, scope: Scope, value: BasilValue) = { Logger.debug(s" eff : SET $v := $value") - InterpreterCFState(nextCmd, callStack, memoryState.defVar(v, value)) + InterpreterState(nextCmd, callStack, memoryState.defVar(v, scope, value)) } def storeMem( @@ -463,31 +491,50 @@ case class InterpreterCFState( value: BitVecLiteral, endian: Endian, size: Int - ): InterpreterCFState = { + ): InterpreterState = { Logger.debug(s" eff : STORE $vname[$addr..$addr + $size] := $value ($endian)") - InterpreterCFState(nextCmd, callStack, memoryState.storeBV(vname, Scalar(addr), value, endian)) + InterpreterState(nextCmd, callStack, memoryState.storeBV(vname, Scalar(addr), value, endian)) + } + + def storeMultiMem( + vname: String, + addr: Literal, + value: List[Literal], + endian: Endian + ): InterpreterState = { + Logger.debug(s" eff : STOREMULTI $vname[$addr] := $value ($endian)") + InterpreterState(nextCmd, callStack, memoryState.store(vname, Scalar(addr), value.map(Scalar(_)), endian)) + } + +} + +case object InterpFuns { + + def initialState[T <: Effects[T]](s: T): T = { + val SP: BitVecLiteral = BitVecLiteral(4096 - 16, 64) + val FP: BitVecLiteral = BitVecLiteral(4096 - 16, 64) + val LR: BitVecLiteral = BitVecLiteral(BigInt("FF", 16), 64) + + s.storeVar("mem", Scope.Global, MapValue(Map.empty, MapType(BitVecType(64), BitVecType(8)))) + .storeVar("stack", Scope.Global, MapValue(Map.empty, MapType(BitVecType(64), BitVecType(8)))) + .storeVar("R31", Scope.Global, Scalar(SP)) + .storeVar("R29", Scope.Global, Scalar(FP)) + .storeVar("R30", Scope.Global, Scalar(LR)) } - def initialiseProgram(p: Program): InterpreterCFState = { - val mem = p.initialMemory.foldLeft(initialState())((s, memory) => { - s.store("mem", Scalar(BitVecLiteral(memory.address, 64)), memory.bytes.map(Scalar(_)).toList, Endian.LittleEndian) - s.store( + def initialiseProgram[T <: Effects[T]](p: Program, s: T): T = { + val mem = p.initialMemory.foldLeft(initialState(s))((s, memory) => { + s.storeMultiMem("mem", BitVecLiteral(memory.address, 64), memory.bytes.toList, Endian.LittleEndian) + s.storeMultiMem( "stack", - Scalar(BitVecLiteral(memory.address, 64)), - memory.bytes.map(Scalar(_)).toList, + BitVecLiteral(memory.address, 64), + memory.bytes.toList, Endian.LittleEndian ) }) - InterpreterCFState( - Run(IRWalk.firstInBlock(p.mainProcedure.entryBlock.get)), - callStack, - mem.pushStackFrame(p.mainProcedure.name) - ) + mem.call(p.mainProcedure.name, Run(IRWalk.firstInBlock(p.mainProcedure.entryBlock.get)), Stopped()) } -} - -case object InterpFuns { def interpretJump[T <: Effects[T]](s: T, j: Jump): T = { j match { @@ -520,8 +567,7 @@ case object InterpFuns { s match { case assign: Assign => { val rhs = st.evalBV(assign.rhs) - st.storeVar(assign.lhs, rhs).setNext(Run(s.successor)) - + st.storeVar(assign.lhs.name, assign.lhs.toBoogie.scope, Scalar(rhs)).setNext(Run(s.successor)) } case assign: MemoryAssign => { val index: BitVecLiteral = st.evalBV(assign.index) @@ -545,7 +591,12 @@ case object InterpFuns { } case dc: DirectCall => - st.call(dc) + if (dc.target.entryBlock.isDefined) { + val block = dc.target.entryBlock.get + st.call(dc.target.name, Run(block.statements.headOption.getOrElse(block.jump)), Run(dc.successor)) + } else { + st.setNext(Run(dc.successor)) + } case ic: IndirectCall => if (ic.target == Register("R30", 64)) { st.doReturn() @@ -565,12 +616,12 @@ case object InterpFuns { } @tailrec - def interpret(s: InterpreterCFState): InterpreterCFState = { + def interpret(s: InterpreterState): InterpreterState = { Logger.debug(s"interpret ${s.nextCmd}") s.nextCmd match { case Run(c: Statement) => interpret( - protect[InterpreterCFState]( + protect[InterpreterState]( () => interpretStatement(s, c), { case InterpreterError(e) => s.setNext(e) @@ -580,7 +631,7 @@ case object InterpFuns { ) case Run(c: Jump) => interpret( - protect[InterpreterCFState]( + protect[InterpreterState]( () => interpretJump(s, c), { case InterpreterError(e) => s.setNext(e) @@ -593,8 +644,8 @@ case object InterpFuns { } } - def interpretProg(p: Program): InterpreterCFState = { - var s = InterpreterCFState().initialiseProgram(p) + def interpretProg(p: Program): InterpreterState = { + var s = initialiseProgram(p, InterpreterState()) interpret(s) } diff --git a/src/test/scala/ir/InterpreterTests.scala b/src/test/scala/ir/InterpreterTests.scala index 85e8c5c70..30a36af91 100644 --- a/src/test/scala/ir/InterpreterTests.scala +++ b/src/test/scala/ir/InterpreterTests.scala @@ -10,6 +10,18 @@ import util.{LogLevel, Logger} import util.IRLoading.{loadBAP, loadReadELF} import util.ILLoadingConfig +def initialMem(): MemoryState = { + val SP: BitVecLiteral = BitVecLiteral(4096 - 16, 64) + val FP: BitVecLiteral = BitVecLiteral(4096 - 16, 64) + val LR: BitVecLiteral = BitVecLiteral(BigInt("FF", 16), 64) + + MemoryState() + .setVar(globalFrame, "mem", MapValue(Map.empty, MapType(BitVecType(64), BitVecType(8)))) + .setVar(globalFrame, "stack", MapValue(Map.empty, MapType(BitVecType(64), BitVecType(8)))) + .setVar(globalFrame, "R31", Scalar(SP)) + .setVar(globalFrame, "R29", Scalar(FP)) + .setVar(globalFrame, "R30", Scalar(LR)) +} def load[T <: Effects[T]](s: T, global: SpecGlobal) : Option[BitVecLiteral] = { // i.getMemory(global.address.toInt, global.size, Endian.LittleEndian, i.mems) @@ -91,12 +103,12 @@ class InterpreterTests extends AnyFunSuite with BeforeAndAfter { BitVecLiteral(BigInt("0B", 16), 8), BitVecLiteral(BigInt("0A", 16), 8)) - val s = initialState().store("mem", Scalar(BitVecLiteral(0, 64)), ts.map(Scalar(_)), Endian.LittleEndian) + val s = initialMem().store("mem", Scalar(BitVecLiteral(0, 64)), ts.map(Scalar(_)), Endian.LittleEndian) val expected: BitVecLiteral = BitVecLiteral(BigInt("0D0C0B0A", 16), 32) val actual: BitVecLiteral = s.loadBV("mem", Scalar(BitVecLiteral(0, 64)), Endian.LittleEndian, 32) assert(actual == expected) - val s2 = initialState().storeBV("mem", Scalar(BitVecLiteral(0, 64)), expected, Endian.LittleEndian) + val s2 = initialMem().storeBV("mem", Scalar(BitVecLiteral(0, 64)), expected, Endian.LittleEndian) val actual2: BitVecLiteral = s.loadBV("mem", Scalar(BitVecLiteral(0, 64)), Endian.LittleEndian, 32) assert(actual2 == actual) @@ -110,7 +122,7 @@ class InterpreterTests extends AnyFunSuite with BeforeAndAfter { BitVecLiteral(BigInt("0B", 16), 8), BitVecLiteral(BigInt("0A", 16), 8)) - val s = initialState().store("mem", Scalar(BitVecLiteral(0, 64)), ts.map(Scalar(_)), Endian.LittleEndian) + val s = initialMem().store("mem", Scalar(BitVecLiteral(0, 64)), ts.map(Scalar(_)), Endian.LittleEndian) val expected: BitVecLiteral = BitVecLiteral(BigInt("0A0B0C0D", 16), 32) val actual: BitVecLiteral = s.loadBV("mem", Scalar(BitVecLiteral(0, 64)), Endian.BigEndian , 32) assert(actual == expected) @@ -123,8 +135,8 @@ class InterpreterTests extends AnyFunSuite with BeforeAndAfter { (BitVecLiteral(1, 64) , BitVecLiteral(BigInt("0C", 16), 8)), (BitVecLiteral(2, 64) , BitVecLiteral(BigInt("0B", 16), 8)), (BitVecLiteral(3, 64) , BitVecLiteral(BigInt("0A", 16), 8))) - val s = ts.foldLeft(initialState())((m, v) => m.storeSingle("mem", Scalar(v._1), Scalar(v._2))) - // val s = initialState().store("mem") + val s = ts.foldLeft(initialMem())((m, v) => m.storeSingle("mem", Scalar(v._1), Scalar(v._2))) + // val s = initialMem().store("mem") // val r = s.loadBV("mem", BitVecLiteral(0, 64)) val expected: BitVecLiteral = BitVecLiteral(BigInt("0A0B0C0D", 16), 32) @@ -138,7 +150,7 @@ class InterpreterTests extends AnyFunSuite with BeforeAndAfter { test("StoreBV = LoadBV LE ") { val expected: BitVecLiteral = BitVecLiteral(BigInt("0A0B0C0D", 16), 32) - val s = initialState().storeBV("mem", Scalar(BitVecLiteral(0, 64)), expected, Endian.LittleEndian) + val s = initialMem().storeBV("mem", Scalar(BitVecLiteral(0, 64)), expected, Endian.LittleEndian) val actual: BitVecLiteral = s.loadBV("mem", Scalar(BitVecLiteral(0, 64)), Endian.LittleEndian, 32) println(s"${actual.value.toInt.toHexString} == ${expected.value.toInt.toHexString}") assert(actual == expected) From 9e235589220d2fc2a8d2f844a73dad38c244cad7 Mon Sep 17 00:00:00 2001 From: Alistair Michael Date: Thu, 22 Aug 2024 19:29:21 +1000 Subject: [PATCH 09/62] cleanup memory ops to enter effects --- src/main/scala/ir/eval/Interpreter.scala | 352 ++++++++++++----------- src/test/scala/ir/InterpreterTests.scala | 53 ++-- 2 files changed, 219 insertions(+), 186 deletions(-) diff --git a/src/main/scala/ir/eval/Interpreter.scala b/src/main/scala/ir/eval/Interpreter.scala index 40d3e4bc5..3d52be5cf 100644 --- a/src/main/scala/ir/eval/Interpreter.scala +++ b/src/main/scala/ir/eval/Interpreter.scala @@ -32,12 +32,19 @@ case class InterpreterSummary( val memory: Map[Int, BitVecLiteral] ) -enum BasilValue(val irType: IRType): - case Scalar(val value: Literal) extends BasilValue(value.getType) - // Erase the type of basil values and enforce the invariant that - // \exists i . \forall v \in value.keys , v.irType = i and - // \exists j . \forall v \in value.values, v.irType = j - case MapValue(val value: Map[BasilValue, BasilValue], override val irType: MapType) extends BasilValue(irType) +sealed trait BasilValue(val irType: IRType) +case class Scalar(val value: Literal) extends BasilValue(value.getType) { + override def toString = value match { + case b: BitVecLiteral => "0x%x:bv%d".format(b.value, b.size) + case c => c.toString + } +} +// Erase the type of basil values and enforce the invariant that +// \exists i . \forall v \in value.keys , v.irType = i and +// \exists j . \forall v \in value.values, v.irType = j +case class MapValue(val value: Map[BasilValue, BasilValue], override val irType: MapType) extends BasilValue(irType) { + override def toString = s"MapValue : $irType" +} case object BasilValue: @@ -177,7 +184,7 @@ case class MemoryState( /* Define a variable in the scope specified * ignoring whether it may already be defined */ - def defVar(name: String, s: Scope , value: BasilValue): MemoryState = { + def defVar(name: String, s: Scope, value: BasilValue): MemoryState = { val frame = s match { case Scope.Global => globalFrame case _ => activations.head @@ -225,26 +232,90 @@ case class MemoryState( } /* Map variable accessing ; load and store operations */ - - /* canonical load operation */ - def load(vname: String, addr: Scalar, endian: Endian, count: Int): List[BasilValue] = { + def doLoad(vname: String, addr: List[BasilValue]): List[BasilValue] = { val (frame, mem) = findVar(vname) - val mapv: MapValue = mem match { case m @ MapValue(innerMap, ty) => m - case _ => throw InterpreterError(Errored("Load from nonmap")) + case m => throw InterpreterError(TypeError(s"Load from nonmap ${m.irType}")) + } + + addr.map(k => + mapv.value.get(k).getOrElse(throw InterpreterError(MemoryError(s"Read from uninitialised $vname[$k]"))) + ) + } + + /** typecheck and some fields of a map variable */ + def doStore(vname: String, values: Map[BasilValue, BasilValue]) = { + val (frame, mem) = findVar(vname) + + val (mapval, keytype, valtype) = mem match { + case m @ MapValue(_, MapType(kt, vt)) => (m, kt, vt) + case v => throw InterpreterError(TypeError(s"Invalid map store operation to $vname : ${v.irType}")) + } + + (values.find((k, v) => k.irType != keytype || v.irType != valtype)) match { + case Some(v) => + throw InterpreterError( + TypeError( + s"Invalid addr or value type (${v._1.irType}, ${v._2.irType}) does not match map type $vname : ($keytype, $valtype)" + ) + ) + case None => () + } + + val nmap = MapValue(mapval.value ++ values, mapval.irType) + setVar(frame, vname, nmap) + } +} + +case object Eval { + def getVar[T <: Effects[T]](s: T)(v: Variable): Option[Literal] = s.loadVar(v.name) match { + case Scalar(l) => Some(l) + case _ => None + } + + def doLoad[T <: Effects[T]](s: T)(m: Memory, addr: Expr, endian: Endian, sz: Int): Option[Literal] = { + addr match { + case l: Literal if sz == 1 => ( + loadSingle(s, m.name, Scalar(l)) match { + case Scalar(v) => Some(v) + case _ => None + } + ) + case l: Literal => Some(loadBV(s, m.name, Scalar(l), endian, sz)) + case _ => None + } + } + + def evalBV[T <: Effects[T]](s: T, e: Expr): BitVecLiteral = { + ir.eval.evalBVExpr(e, Eval.getVar(s), Eval.doLoad(s)) match { + case Right(e) => e + case Left(e) => throw InterpreterError(Errored(s"Eval BV residual $e")) + } + } + + def evalInt[T <: Effects[T]](s: T, e: Expr): BigInt = { + ir.eval.evalIntExpr(e, Eval.getVar(s), Eval.doLoad(s)) match { + case Right(e) => e + case Left(e) => throw InterpreterError(Errored(s"Eval int residual $e")) + } + } + + def evalBool[T <: Effects[T]](s: T, e: Expr): Boolean = { + ir.eval.evalLogExpr(e, Eval.getVar(s), Eval.doLoad(s)) match { + case Right(e) => e + case Left(e) => throw InterpreterError(Errored(s"Eval bool residual $e")) } + } + /** Load helpers * */ + def load[T <: Effects[T]](s: T, vname: String, addr: Scalar, endian: Endian, count: Int): List[BasilValue] = { if (count == 0) { throw InterpreterError(Errored(s"Attempted fractional load")) } val keys = (0 until count).map(i => BasilValue.unsafeAdd(addr, i)) - - val values = keys.map(k => - mapv.value.get(k).getOrElse(throw InterpreterError(MemoryError(s"Read from uninitialised $vname[$k]"))) - ) - + val values = s.loadMem(vname, keys.toList) val vals = endian match { case Endian.LittleEndian => values.reverse case Endian.BigEndian => values @@ -254,21 +325,22 @@ case class MemoryState( } /** Load and concat bitvectors */ - def loadBV(vname: String, addr: Scalar, endian: Endian, size: Int): BitVecLiteral = { - val (frame, mem) = findVar(vname) + def loadBV[T <: Effects[T]](s: T, vname: String, addr: Scalar, endian: Endian, size: Int): BitVecLiteral = { + val mem = s.loadVar(vname) val (valsize, mapv) = mem match { case mapv @ MapValue(_, MapType(_, BitVecType(sz))) => (sz, mapv) case _ => throw InterpreterError(Errored("Trued to load-concat non bv")) } - val cells = size / BasilValue.size(mapv.irType.result) + val cells = size / valsize val bvs: List[BitVecLiteral] = { - val res = load(vname, addr, endian, cells) + val res = load(s, vname, addr, endian, cells) val rr = res.map { case Scalar(bv @ BitVecLiteral(v, sz)) if sz == valsize => bv - case _ => throw InterpreterError(Errored(s"Loaded value that did not match expected type bv$valsize")) + case c => + throw InterpreterError(TypeError(s"Loaded value of type ${c.irType} did not match expected type bv$valsize")) } rr } @@ -278,13 +350,16 @@ case class MemoryState( bvres } - def loadSingle(vname: String, addr: Scalar): BasilValue = { - load(vname, addr, Endian.LittleEndian, 1).head + def loadSingle[T <: Effects[T]](s: T, vname: String, addr: Scalar): BasilValue = { + load(s, vname, addr, Endian.LittleEndian, 1).head } - /** Canonical store operation */ - def store(vname: String, addr: BasilValue, values: List[BasilValue], endian: Endian) = { - val (frame, mem) = findVar(vname) + /** State modifying helpers, e.g. store + */ + + /* Expand addr for number of values to store */ + def store[T <: Effects[T]](st: T, vname: String, addr: BasilValue, values: List[BasilValue], endian: Endian): T = { + val mem = st.loadVar(vname) val (mapval, keytype, valtype) = mem match { case m @ MapValue(_, MapType(kt, vt)) if kt == addr.irType && values.forall(v => v.irType == vt) => (m, kt, vt) @@ -296,13 +371,12 @@ case class MemoryState( case Endian.BigEndian => values } - val nmap = MapValue(mapval.value ++ keys.zip(vals), mapval.irType) - setVar(frame, vname, nmap) + st.storeMem(vname, keys.zip(vals).toMap) } /** Extract bitvec to bytes and store bytes */ - def storeBV(vname: String, addr: BasilValue, value: BitVecLiteral, endian: Endian) = { - val (frame, mem) = findVar(vname) + def storeBV[T <: Effects[T]](st: T, vname: String, addr: BasilValue, value: BitVecLiteral, endian: Endian): T = { + val mem = st.loadVar(vname) val (mapval, vsize) = mem match { case m @ MapValue(_, MapType(kt, BitVecType(size))) if kt == addr.irType => (m, size) case v => @@ -320,60 +394,20 @@ case class MemoryState( val extractVals = (0 until cells).map(i => BitVectorEval.boogie_extract((i + 1) * vsize, i * vsize, value)).toList val vs = endian match { - case Endian.LittleEndian => extractVals.reverse.map(Scalar(_)) - case Endian.BigEndian => extractVals.map(Scalar(_)) + case Endian.LittleEndian => extractVals.map(Scalar(_)) + case Endian.BigEndian => extractVals.reverse.map(Scalar(_)) } - store(vname, addr, vs, endian) + val keys = (0 until cells).map(i => BasilValue.unsafeAdd(addr, i)) + st.storeMem(vname, keys.zip(vs).toMap) } - def storeSingle(vname: String, addr: BasilValue, value: BasilValue) = { - store(vname, addr, List(value), Endian.LittleEndian) + def storeSingle[T <: Effects[T]](st: T, vname: String, addr: BasilValue, value: BasilValue): T = { + st.storeMem(vname, Map((addr -> value))) } } -case object Eval { - - def getVar(s: MemoryState)(v: Variable) = s.getVarLiteralOpt(v) - - def doLoad(s: MemoryState)(m: Memory, addr: Expr, endian: Endian, sz: Int): Option[Literal] = { - addr match { - case l: Literal if sz == 1 => ( - s.loadSingle(m.name, Scalar(l)) match { - case Scalar(v) => Some(v) - case _ => None - } - ) - case l: Literal => Some(s.loadBV(m.name, Scalar(l), endian, sz)) - case _ => None - } - } - - def evalBV(s: MemoryState, e: Expr): BitVecLiteral = { - ir.eval.evalBVExpr(e, Eval.getVar(s), Eval.doLoad(s)) match { - case Right(e) => e - case Left(e) => throw InterpreterError(Errored(s"Eval BV residual $e")) - } - } - - def evalInt(s: MemoryState, e: Expr): BigInt = { - ir.eval.evalIntExpr(e, Eval.getVar(s), Eval.doLoad(s)) match { - case Right(e) => e - case Left(e) => throw InterpreterError(Errored(s"Eval int residual $e")) - } - } - - def evalBool(s: MemoryState, e: Expr): Boolean = { - ir.eval.evalLogExpr(e, Eval.getVar(s), Eval.doLoad(s)) match { - case Right(e) => e - case Left(e) => throw InterpreterError(Errored(s"Eval bool residual $e")) - } - } - -} - - sealed trait Effects[T <: Effects[T]] { /* evaluation (may side-effect via InterpreterException on evaluation failure) */ def evalBV(e: Expr): BitVecLiteral @@ -382,6 +416,10 @@ sealed trait Effects[T <: Effects[T]] { def evalBool(e: Expr): Boolean + def loadVar(v: String): BasilValue + + def loadMem(v: String, addrs: List[BasilValue]): List[BasilValue] + /** effects * */ def setNext(c: ExecutionContinuation): T @@ -391,57 +429,56 @@ sealed trait Effects[T <: Effects[T]] { def storeVar(v: String, scope: Scope, value: BasilValue): T - def storeMem(vname: String, addr: BitVecLiteral, value: BitVecLiteral, endian: Endian, size: Int): T + def storeMem(vname: String, update: Map[BasilValue, BasilValue]): T - def storeMultiMem(vname: String, addr: Literal, value: List[Literal], endian: Endian): T } +enum Effect: + case Call(target: String, begin: ExecutionContinuation, returnTo: ExecutionContinuation) + case SetNext(c: ExecutionContinuation) + case Return + case StoreVar(v: String, s: Scope, value: BasilValue) + case StoreMem(vname: String, update: Map[BasilValue, BasilValue]) + +case class TracingInterpreter( + val s: InterpreterState, + val trace: List[Effect] +) extends Effects[TracingInterpreter] { + + def evalBV(e: Expr): BitVecLiteral = s.evalBV(e) + def evalInt(e: Expr): BigInt = s.evalInt(e) + def evalBool(e: Expr): Boolean = s.evalBool(e) + + def loadVar(v: String): BasilValue = s.loadVar(v) + def loadMem(v: String, addrs: List[BasilValue]) = s.loadMem(v, addrs) + /** effects * */ + def setNext(c: ExecutionContinuation) = { + Logger.debug(s" eff : DONEXT $c") + TracingInterpreter(s.setNext(c), Effect.SetNext(c)::trace) + } + + def call(target: String, beginFrom: ExecutionContinuation, returnTo: ExecutionContinuation) = { + Logger.debug(s" eff : CALL $target") + TracingInterpreter(s.call(target, beginFrom, returnTo), Effect.Call(target, beginFrom, returnTo)::trace) + } -// enum Effect: -// case Call(c: DirectCall) -// case SetNext(c: ExecutionContinuation) -// case Return -// case StoreVar(v: Variable, value: Literal) -// case StoreMem(vname: String, addr: BitVecLiteral, value: BitVecLiteral, endian: Endian, size: Int) -// case StoreMultiMem(vname: String, addr: Literal, value: List[Literal], endian: Endian) -// -// case class TracingInterpreter( -// val s: InterpreterState, -// val trace: List[Effect] -// ) extends Effects[TracingInterpreter] { -// -// -// def evalBV(e: Expr): BitVecLiteral = s.evalBV(e) -// def evalInt(e: Expr): BigInt = s.evalInt(e) -// def evalBool(e: Expr): Boolean = s.evalBool(e) -// -// /** effects * */ -// def setNext(c: ExecutionContinuation) = { -// TracingInterpreter(s.setNext(c), Effect.SetNext(c)::trace) -// } -// -// def call(c: DirectCall) = { -// TracingInterpreter(s.call(c), Effect.Call(c)::trace) -// } -// -// def doReturn() = { -// TracingInterpreter(s.doReturn(), Effect.Return::trace) -// } -// -// def storeVar(v: Variable, value: Literal) = { -// TracingInterpreter(s.storeVar(v, value), Effect.StoreVar(v,value)::trace) -// } -// -// def storeMem(vname: String, addr: BitVecLiteral, value: BitVecLiteral, endian: Endian, size: Int) = { -// TracingInterpreter(s.storeMem(vname, addr, value, endian, size), Effect.StoreMem(vname,addr,value,endian,size)::trace) -// } -// -// def storeMultiMem(vname: String, addr: Literal, value: List[Literal], endian: Endian) = { -// TracingInterpreter(s.storeMultiMem(vname, addr, value, endian), Effect.StoreMultiMem(vname,addr,value,endian)::trace) -// } -// -// } + def doReturn() = { + Logger.debug(s" eff : RETURN") + TracingInterpreter(s.doReturn(), Effect.Return::trace) + } + + def storeVar(v: String, c: Scope, value: BasilValue) = { + Logger.debug(s" eff : SET $v := $value") + TracingInterpreter(s.storeVar(v, c, value), Effect.StoreVar(v, c, value)::trace) + } + + def storeMem(vname: String, update: Map[BasilValue, BasilValue]) = { + Logger.debug(s" eff : STORE $vname <- $update") + TracingInterpreter(s.storeMem(vname, update), Effect.StoreMem(vname, update)::trace) + } + +} case class InterpreterState( val nextCmd: ExecutionContinuation = Stopped(), @@ -449,12 +486,16 @@ case class InterpreterState( val memoryState: MemoryState = MemoryState() ) extends Effects[InterpreterState] { - /** eval * */ - def evalBV(e: Expr): BitVecLiteral = Eval.evalBV(memoryState, e) + /* eval */ + def evalBV(e: Expr): BitVecLiteral = Eval.evalBV(this, e) + + def evalInt(e: Expr): BigInt = Eval.evalInt(this, e) - def evalInt(e: Expr): BigInt = Eval.evalInt(memoryState, e) + def evalBool(e: Expr): Boolean = Eval.evalBool(this, e) - def evalBool(e: Expr): Boolean = Eval.evalBool(memoryState, e) + def loadVar(v: String): BasilValue = memoryState.getVar(v) + + def loadMem(v: String, addrs: List[BasilValue]): List[BasilValue] = memoryState.doLoad(v, addrs) /** effects * */ def setNext(c: ExecutionContinuation): InterpreterState = { @@ -462,7 +503,6 @@ case class InterpreterState( } def call(target: String, beginFrom: ExecutionContinuation, returnTo: ExecutionContinuation): InterpreterState = { - Logger.debug(s" eff : CALL $target") InterpreterState( beginFrom, returnTo :: callStack, @@ -474,36 +514,17 @@ case class InterpreterState( callStack match { case Nil => InterpreterState(Stopped(), Nil, memoryState) case h :: tl => { - Logger.debug(s" eff : RETURN $h") InterpreterState(h, tl, memoryState.popStackFrame()) } } } - def storeVar(v: String, scope: Scope, value: BasilValue) = { - Logger.debug(s" eff : SET $v := $value") + def storeVar(v: String, scope: Scope, value: BasilValue) = { InterpreterState(nextCmd, callStack, memoryState.defVar(v, scope, value)) } - def storeMem( - vname: String, - addr: BitVecLiteral, - value: BitVecLiteral, - endian: Endian, - size: Int - ): InterpreterState = { - Logger.debug(s" eff : STORE $vname[$addr..$addr + $size] := $value ($endian)") - InterpreterState(nextCmd, callStack, memoryState.storeBV(vname, Scalar(addr), value, endian)) - } - - def storeMultiMem( - vname: String, - addr: Literal, - value: List[Literal], - endian: Endian - ): InterpreterState = { - Logger.debug(s" eff : STOREMULTI $vname[$addr] := $value ($endian)") - InterpreterState(nextCmd, callStack, memoryState.store(vname, Scalar(addr), value.map(Scalar(_)), endian)) + def storeMem(vname: String, update: Map[BasilValue, BasilValue]) = { + InterpreterState(nextCmd, callStack, memoryState.doStore(vname, update)) } } @@ -515,20 +536,27 @@ case object InterpFuns { val FP: BitVecLiteral = BitVecLiteral(4096 - 16, 64) val LR: BitVecLiteral = BitVecLiteral(BigInt("FF", 16), 64) - s.storeVar("mem", Scope.Global, MapValue(Map.empty, MapType(BitVecType(64), BitVecType(8)))) - .storeVar("stack", Scope.Global, MapValue(Map.empty, MapType(BitVecType(64), BitVecType(8)))) - .storeVar("R31", Scope.Global, Scalar(SP)) - .storeVar("R29", Scope.Global, Scalar(FP)) - .storeVar("R30", Scope.Global, Scalar(LR)) + s.storeVar("mem", Scope.Global, MapValue(Map.empty, MapType(BitVecType(64), BitVecType(8)))) + .storeVar("stack", Scope.Global, MapValue(Map.empty, MapType(BitVecType(64), BitVecType(8)))) + .storeVar("R31", Scope.Global, Scalar(SP)) + .storeVar("R29", Scope.Global, Scalar(FP)) + .storeVar("R30", Scope.Global, Scalar(LR)) } def initialiseProgram[T <: Effects[T]](p: Program, s: T): T = { val mem = p.initialMemory.foldLeft(initialState(s))((s, memory) => { - s.storeMultiMem("mem", BitVecLiteral(memory.address, 64), memory.bytes.toList, Endian.LittleEndian) - s.storeMultiMem( + val s1 = Eval.store( + s, + "mem", + Scalar(BitVecLiteral(memory.address, 64)), + memory.bytes.toList.map(Scalar(_)), + Endian.LittleEndian + ) + Eval.store( + s1, "stack", - BitVecLiteral(memory.address, 64), - memory.bytes.toList, + Scalar(BitVecLiteral(memory.address, 64)), + memory.bytes.toList.map(Scalar(_)), Endian.LittleEndian ) }) @@ -555,7 +583,6 @@ case object InterpFuns { case h :: Nil => h case h :: tl => throw InterpreterError(Errored(s"More than one jump guard satisfied $gt")) } - Logger.debug(s"Goto ${chosen._1.parent.label}") s.setNext(Run(chosen._1.successor)) case r: Return => s.doReturn() @@ -572,15 +599,14 @@ case object InterpFuns { case assign: MemoryAssign => { val index: BitVecLiteral = st.evalBV(assign.index) val value: BitVecLiteral = st.evalBV(assign.value) - st.storeMem(assign.mem.name, index, value, assign.endian, assign.size).setNext(Run(s.successor)) - + // st.storeMem(assign.mem.name, index, value, assign.endian, assign.size).setNext(Run(s.successor)) + Eval.storeBV(st, assign.mem.name, Scalar(index), value, assign.endian).setNext(Run(s.successor)) } case assert: Assert => { if (!st.evalBool(assert.body)) then { st.setNext(FailedAssertion(assert)) } else { st.setNext(Run(s.successor)) - } } case assume: Assume => diff --git a/src/test/scala/ir/InterpreterTests.scala b/src/test/scala/ir/InterpreterTests.scala index 30a36af91..05bf954d5 100644 --- a/src/test/scala/ir/InterpreterTests.scala +++ b/src/test/scala/ir/InterpreterTests.scala @@ -10,18 +10,21 @@ import util.{LogLevel, Logger} import util.IRLoading.{loadBAP, loadReadELF} import util.ILLoadingConfig -def initialMem(): MemoryState = { - val SP: BitVecLiteral = BitVecLiteral(4096 - 16, 64) - val FP: BitVecLiteral = BitVecLiteral(4096 - 16, 64) - val LR: BitVecLiteral = BitVecLiteral(BigInt("FF", 16), 64) - - MemoryState() - .setVar(globalFrame, "mem", MapValue(Map.empty, MapType(BitVecType(64), BitVecType(8)))) - .setVar(globalFrame, "stack", MapValue(Map.empty, MapType(BitVecType(64), BitVecType(8)))) - .setVar(globalFrame, "R31", Scalar(SP)) - .setVar(globalFrame, "R29", Scalar(FP)) - .setVar(globalFrame, "R30", Scalar(LR)) -} +// def initialMem(): MemoryState = { +// val SP: BitVecLiteral = BitVecLiteral(4096 - 16, 64) +// val FP: BitVecLiteral = BitVecLiteral(4096 - 16, 64) +// val LR: BitVecLiteral = BitVecLiteral(BigInt("FF", 16), 64) +// +// MemoryState() +// .setVar(globalFrame, "mem", MapValue(Map.empty, MapType(BitVecType(64), BitVecType(8)))) +// .setVar(globalFrame, "stack", MapValue(Map.empty, MapType(BitVecType(64), BitVecType(8)))) +// .setVar(globalFrame, "R31", Scalar(SP)) +// .setVar(globalFrame, "R29", Scalar(FP)) +// .setVar(globalFrame, "R30", Scalar(LR)) +// } + + +def initialMem() = InterpFuns.initialState(TracingInterpreter(InterpreterState(), List())) def load[T <: Effects[T]](s: T, global: SpecGlobal) : Option[BitVecLiteral] = { // i.getMemory(global.address.toInt, global.size, Endian.LittleEndian, i.mems) @@ -103,17 +106,21 @@ class InterpreterTests extends AnyFunSuite with BeforeAndAfter { BitVecLiteral(BigInt("0B", 16), 8), BitVecLiteral(BigInt("0A", 16), 8)) - val s = initialMem().store("mem", Scalar(BitVecLiteral(0, 64)), ts.map(Scalar(_)), Endian.LittleEndian) + val s = Eval.store(initialMem(), "mem", Scalar(BitVecLiteral(0, 64)), ts.map(Scalar(_)), Endian.LittleEndian) val expected: BitVecLiteral = BitVecLiteral(BigInt("0D0C0B0A", 16), 32) - val actual: BitVecLiteral = s.loadBV("mem", Scalar(BitVecLiteral(0, 64)), Endian.LittleEndian, 32) + val actual: BitVecLiteral = Eval.loadBV(s, "mem", Scalar(BitVecLiteral(0, 64)), Endian.LittleEndian, 32) assert(actual == expected) - val s2 = initialMem().storeBV("mem", Scalar(BitVecLiteral(0, 64)), expected, Endian.LittleEndian) - val actual2: BitVecLiteral = s.loadBV("mem", Scalar(BitVecLiteral(0, 64)), Endian.LittleEndian, 32) - assert(actual2 == actual) } + test("store bv = loadbv le") { + val expected: BitVecLiteral = BitVecLiteral(BigInt("0D0C0B0A", 16), 32) + val s2 = Eval.storeBV(initialMem(), "mem", Scalar(BitVecLiteral(0, 64)), expected, Endian.LittleEndian) + val actual2: BitVecLiteral = Eval.loadBV(s2, "mem", Scalar(BitVecLiteral(0, 64)), Endian.LittleEndian, 32) + assert(actual2 == expected) + } + test("Store = Load BigEndian") { val ts = List( @@ -122,9 +129,9 @@ class InterpreterTests extends AnyFunSuite with BeforeAndAfter { BitVecLiteral(BigInt("0B", 16), 8), BitVecLiteral(BigInt("0A", 16), 8)) - val s = initialMem().store("mem", Scalar(BitVecLiteral(0, 64)), ts.map(Scalar(_)), Endian.LittleEndian) + val s = Eval.store(initialMem(), "mem", Scalar(BitVecLiteral(0, 64)), ts.map(Scalar(_)), Endian.LittleEndian) val expected: BitVecLiteral = BitVecLiteral(BigInt("0A0B0C0D", 16), 32) - val actual: BitVecLiteral = s.loadBV("mem", Scalar(BitVecLiteral(0, 64)), Endian.BigEndian , 32) + val actual: BitVecLiteral = Eval.loadBV(s, "mem", Scalar(BitVecLiteral(0, 64)), Endian.BigEndian , 32) assert(actual == expected) @@ -135,14 +142,14 @@ class InterpreterTests extends AnyFunSuite with BeforeAndAfter { (BitVecLiteral(1, 64) , BitVecLiteral(BigInt("0C", 16), 8)), (BitVecLiteral(2, 64) , BitVecLiteral(BigInt("0B", 16), 8)), (BitVecLiteral(3, 64) , BitVecLiteral(BigInt("0A", 16), 8))) - val s = ts.foldLeft(initialMem())((m, v) => m.storeSingle("mem", Scalar(v._1), Scalar(v._2))) + val s = ts.foldLeft(initialMem())((m, v) => Eval.storeSingle(m, "mem", Scalar(v._1), Scalar(v._2))) // val s = initialMem().store("mem") // val r = s.loadBV("mem", BitVecLiteral(0, 64)) val expected: BitVecLiteral = BitVecLiteral(BigInt("0A0B0C0D", 16), 32) // def loadBV(vname: String, addr: Scalar, endian: Endian, size: Int): BitVecLiteral = { - val actual: BitVecLiteral = s.loadBV("mem", Scalar(BitVecLiteral(0, 64)), Endian.LittleEndian, 32) + val actual: BitVecLiteral = Eval.loadBV(s, "mem", Scalar(BitVecLiteral(0, 64)), Endian.LittleEndian, 32) assert(actual == expected) } @@ -150,8 +157,8 @@ class InterpreterTests extends AnyFunSuite with BeforeAndAfter { test("StoreBV = LoadBV LE ") { val expected: BitVecLiteral = BitVecLiteral(BigInt("0A0B0C0D", 16), 32) - val s = initialMem().storeBV("mem", Scalar(BitVecLiteral(0, 64)), expected, Endian.LittleEndian) - val actual: BitVecLiteral = s.loadBV("mem", Scalar(BitVecLiteral(0, 64)), Endian.LittleEndian, 32) + val s = Eval.storeBV(initialMem(), "mem", Scalar(BitVecLiteral(0, 64)), expected, Endian.LittleEndian) + val actual: BitVecLiteral = Eval.loadBV(s, "mem", Scalar(BitVecLiteral(0, 64)), Endian.LittleEndian, 32) println(s"${actual.value.toInt.toHexString} == ${expected.value.toInt.toHexString}") assert(actual == expected) } From 7571aa3a82d5454a3513e2fcf980139b68524b67 Mon Sep 17 00:00:00 2001 From: Alistair Michael Date: Mon, 26 Aug 2024 14:12:52 +1000 Subject: [PATCH 10/62] tracing interpreter --- src/main/scala/ir/eval/Interpreter.scala | 42 ++++++--- src/test/scala/ir/InterpreterTests.scala | 113 +++++++++++++---------- 2 files changed, 90 insertions(+), 65 deletions(-) diff --git a/src/main/scala/ir/eval/Interpreter.scala b/src/main/scala/ir/eval/Interpreter.scala index 3d52be5cf..6e4a60555 100644 --- a/src/main/scala/ir/eval/Interpreter.scala +++ b/src/main/scala/ir/eval/Interpreter.scala @@ -431,6 +431,8 @@ sealed trait Effects[T <: Effects[T]] { def storeMem(vname: String, update: Map[BasilValue, BasilValue]): T + def getNext : ExecutionContinuation + } enum Effect: @@ -454,30 +456,32 @@ case class TracingInterpreter( /** effects * */ def setNext(c: ExecutionContinuation) = { - Logger.debug(s" eff : DONEXT $c") - TracingInterpreter(s.setNext(c), Effect.SetNext(c)::trace) + // Logger.debug(s" eff : DONEXT $c") + TracingInterpreter(s.setNext(c), trace) } def call(target: String, beginFrom: ExecutionContinuation, returnTo: ExecutionContinuation) = { - Logger.debug(s" eff : CALL $target") + //Logger.debug(s" eff : CALL $target") TracingInterpreter(s.call(target, beginFrom, returnTo), Effect.Call(target, beginFrom, returnTo)::trace) } def doReturn() = { - Logger.debug(s" eff : RETURN") + //Logger.debug(s" eff : RETURN") TracingInterpreter(s.doReturn(), Effect.Return::trace) } def storeVar(v: String, c: Scope, value: BasilValue) = { - Logger.debug(s" eff : SET $v := $value") + //Logger.debug(s" eff : SET $v := $value") TracingInterpreter(s.storeVar(v, c, value), Effect.StoreVar(v, c, value)::trace) } def storeMem(vname: String, update: Map[BasilValue, BasilValue]) = { - Logger.debug(s" eff : STORE $vname <- $update") + //Logger.debug(s" eff : STORE $vname <- $update") TracingInterpreter(s.storeMem(vname, update), Effect.StoreMem(vname, update)::trace) } + def getNext = s.getNext + } case class InterpreterState( @@ -497,6 +501,8 @@ case class InterpreterState( def loadMem(v: String, addrs: List[BasilValue]): List[BasilValue] = memoryState.doLoad(v, addrs) + def getNext = nextCmd + /** effects * */ def setNext(c: ExecutionContinuation): InterpreterState = { InterpreterState(c, callStack, memoryState) @@ -642,12 +648,12 @@ case object InterpFuns { } @tailrec - def interpret(s: InterpreterState): InterpreterState = { - Logger.debug(s"interpret ${s.nextCmd}") - s.nextCmd match { + def interpret[T <: Effects[T]](s: T): T = { + Logger.debug(s"interpret ${s.getNext}") + s.getNext match { case Run(c: Statement) => interpret( - protect[InterpreterState]( + protect[T]( () => interpretStatement(s, c), { case InterpreterError(e) => s.setNext(e) @@ -657,7 +663,7 @@ case object InterpFuns { ) case Run(c: Jump) => interpret( - protect[InterpreterState]( + protect[T]( () => interpretJump(s, c), { case InterpreterError(e) => s.setNext(e) @@ -670,13 +676,19 @@ case object InterpFuns { } } - def interpretProg(p: Program): InterpreterState = { - var s = initialiseProgram(p, InterpreterState()) + def interpretProg[T <: Effects[T]](p: Program, i: T): T = { + var s = initialiseProgram(p, i) interpret(s) } } -def interpret(IRProgram: Program) = { - InterpFuns.interpretProg(IRProgram) +def interpret(IRProgram: Program) : InterpreterState = { + InterpFuns.interpretProg(IRProgram, InterpreterState()) } + +def interpretTrace(IRProgram: Program) : TracingInterpreter = { + val s : TracingInterpreter = InterpFuns.interpretProg(IRProgram, TracingInterpreter(InterpreterState(), List())) + s +} + diff --git a/src/test/scala/ir/InterpreterTests.scala b/src/test/scala/ir/InterpreterTests.scala index 05bf954d5..ecf521a51 100644 --- a/src/test/scala/ir/InterpreterTests.scala +++ b/src/test/scala/ir/InterpreterTests.scala @@ -286,61 +286,61 @@ class InterpreterTests extends AnyFunSuite with BeforeAndAfter { testInterpret("no_interference_update_y", expected) } - test("fibonacci") { - def fibonacciProg(n: Int) = { - def expected(n: Int) : Int = { - n match { - case 0 => 0 - case 1 => 1 - case n => expected(n - 1) + expected(n - 2) - } + def fibonacciProg(n: Int) = { + def expected(n: Int) : Int = { + n match { + case 0 => 0 + case 1 => 1 + case n => expected(n - 1) + expected(n - 2) } - prog( - proc("begin", - block("entry", - Assign(R8, Register("R31", 64)), - Assign(R0, bv64(n)), - directCall("fib"), - goto("done") - ), - block("done", - Assert(BinaryExpr(BVEQ, R0, bv64(expected(n)))), - ret - )), - proc("fib", - block("base", goto("base1", "base2", "dofib")), - block("base1", - Assume(BinaryExpr(BVEQ, R0, bv64(0))), - ret), - block("base2", - Assume(BinaryExpr(BVEQ, R0, bv64(1))), - ret), - block("dofib", - Assume(BinaryExpr(BoolAND, BinaryExpr(BVNEQ, R0, bv64(0)), BinaryExpr(BVNEQ, R0, bv64(1)))), - // R8 stack pointer preserved across calls - Assign(R7, BinaryExpr(BVADD, R8, bv64(8))), - MemoryAssign(stack, R7, R8, Endian.LittleEndian, 64), // sp - Assign(R8, R7), - Assign(R8, BinaryExpr(BVADD, R8, bv64(8))), // sp + 8 - MemoryAssign(stack, R8, R0, Endian.LittleEndian, 64), // [sp + 8] = arg0 - Assign(R0, BinaryExpr(BVSUB, R0, bv64(1))), - directCall("fib"), - Assign(R2, R8), // sp + 8 - Assign(R8, BinaryExpr(BVADD, R8, bv64(8))), // sp + 16 - MemoryAssign(stack, R8, R0, Endian.LittleEndian, 64), // [sp + 16] = r1 - Assign(R0, MemoryLoad(stack, R2, Endian.LittleEndian, 64)), // [sp + 8] - Assign(R0, BinaryExpr(BVSUB, R0, bv64(2))), - directCall("fib"), - Assign(R2, MemoryLoad(stack, R8, Endian.LittleEndian, 64)), // [sp + 16] (r1) - Assign(R0, BinaryExpr(BVADD, R0, R2)), - Assign(R8, MemoryLoad(stack, BinaryExpr(BVSUB, R8, bv64(16)), Endian.LittleEndian, 64)), - ret - ) + } + prog( + proc("begin", + block("entry", + Assign(R8, Register("R31", 64)), + Assign(R0, bv64(n)), + directCall("fib"), + goto("done") + ), + block("done", + Assert(BinaryExpr(BVEQ, R0, bv64(expected(n)))), + ret + )), + proc("fib", + block("base", goto("base1", "base2", "dofib")), + block("base1", + Assume(BinaryExpr(BVEQ, R0, bv64(0))), + ret), + block("base2", + Assume(BinaryExpr(BVEQ, R0, bv64(1))), + ret), + block("dofib", + Assume(BinaryExpr(BoolAND, BinaryExpr(BVNEQ, R0, bv64(0)), BinaryExpr(BVNEQ, R0, bv64(1)))), + // R8 stack pointer preserved across calls + Assign(R7, BinaryExpr(BVADD, R8, bv64(8))), + MemoryAssign(stack, R7, R8, Endian.LittleEndian, 64), // sp + Assign(R8, R7), + Assign(R8, BinaryExpr(BVADD, R8, bv64(8))), // sp + 8 + MemoryAssign(stack, R8, R0, Endian.LittleEndian, 64), // [sp + 8] = arg0 + Assign(R0, BinaryExpr(BVSUB, R0, bv64(1))), + directCall("fib"), + Assign(R2, R8), // sp + 8 + Assign(R8, BinaryExpr(BVADD, R8, bv64(8))), // sp + 16 + MemoryAssign(stack, R8, R0, Endian.LittleEndian, 64), // [sp + 16] = r1 + Assign(R0, MemoryLoad(stack, R2, Endian.LittleEndian, 64)), // [sp + 8] + Assign(R0, BinaryExpr(BVSUB, R0, bv64(2))), + directCall("fib"), + Assign(R2, MemoryLoad(stack, R8, Endian.LittleEndian, 64)), // [sp + 16] (r1) + Assign(R0, BinaryExpr(BVADD, R0, R2)), + Assign(R8, MemoryLoad(stack, BinaryExpr(BVSUB, R8, bv64(16)), Endian.LittleEndian, 64)), + ret ) ) - } + ) + } + test("fibonacci") { val fib = fibonacciProg(8) val r = interpret(fib) @@ -352,4 +352,17 @@ class InterpreterTests extends AnyFunSuite with BeforeAndAfter { // } } + + + test("fibonacci Trace") { + + val fib = fibonacciProg(8) + val r = interpretTrace(fib) + assert(r.getNext == Stopped()) + // Show interpreted result + // + info(r.trace.reverse.mkString("\n")) + + } + } From f36c13e0d0bd1e50d55df281ad6cf731b1037461 Mon Sep 17 00:00:00 2001 From: Alistair Michael Date: Tue, 27 Aug 2024 16:02:38 +1000 Subject: [PATCH 11/62] compile with state monad --- src/main/scala/ir/eval/ExprEval.scala | 141 +++-- src/main/scala/ir/eval/Interpreter.scala | 700 +++++++++++++---------- src/test/scala/ir/InterpreterTests.scala | 219 +++---- 3 files changed, 607 insertions(+), 453 deletions(-) diff --git a/src/main/scala/ir/eval/ExprEval.scala b/src/main/scala/ir/eval/ExprEval.scala index 7efcc8f7d..5cb5979f7 100644 --- a/src/main/scala/ir/eval/ExprEval.scala +++ b/src/main/scala/ir/eval/ExprEval.scala @@ -127,19 +127,68 @@ def evalUnOp(op: UnOp, body: Literal) : Expr = { } } -def partialEvalExpr(exp: Expr, variableAssignment: Variable => Option[Expr], memory: (Memory, Expr, Endian, Int) => Option[Literal] = ((a,b,c,d) => None)): Expr = { + + +def evalIntExpr(exp: Expr, variableAssignment: Variable => Option[Literal], memory: (Memory, Expr, Endian, Int) => Option[Literal] = ((a,b,c,d) => None)): Either[Expr, BigInt] = { + partialEvalExpr(exp, variableAssignment, memory) match { + case i: IntLiteral => Right(i.value) + case o => Left(o) + } +} + +def evalBVExpr(exp: Expr, variableAssignment: Variable => Option[Literal], memory: (Memory, Expr, Endian, Int) => Option[Literal] = ((a,b,c,d) => None)): Either[Expr, BitVecLiteral] = { + partialEvalExpr(exp, variableAssignment, memory) match { + case b: BitVecLiteral => Right(b) + case o => Left(o) + } +} + +def evalLogExpr(exp: Expr, variableAssignment: Variable => Option[Literal], memory: (Memory, Expr, Endian, Int) => Option[Literal] = ((a,b,c, d) => None)): Either[Expr, Boolean] = { + partialEvalExpr(exp, variableAssignment, memory) match { + case TrueLiteral => Right(true) + case FalseLiteral => Right(false) + case o => Left(o) + } +} + +def evalExpr(exp: Expr, variableAssignment: Variable => Option[Literal], memory: (Memory, Expr, Endian, Int) => Option[Literal] = ((d, a,b,c) => None)): Option[Literal] = { + partialEvalExpr match { + case l: Literal => Some(l) + case _ => None + } +} + + +def mkSt[S, F](f: () => F): State[S, F] = for { + n <- get((s:S) => f()) + } yield (n) + +/** + * typeclass defining variable and memory laoding from state S + */ +trait Loader[S] { + def getVariable(v: Variable) : State[S, Option[Expr]] + def loadMemory(m: Memory, addr: Expr, endian: Endian, size: Int) : State[S, Option[Literal]] = { + mkSt(() => None) + } +} + + +def statePartialEvalExpr[S](l: Loader[S])(exp: Expr): State[S, Expr] = { + val eval = statePartialEvalExpr(l) exp match { - case f: UninterpretedFunction => f - case unOp: UnaryExpr => { - val body = partialEvalExpr(unOp.arg, variableAssignment, memory) + case f: UninterpretedFunction => mkSt(() => f) + case unOp: UnaryExpr => for { + body <- eval(unOp.arg) + } yield ( body match { case l: Literal => evalUnOp(unOp.op, l) case o => UnaryExpr(unOp.op, body) - } - } - case binOp: BinaryExpr => - val lhs = partialEvalExpr(binOp.arg1, variableAssignment, memory) - val rhs = partialEvalExpr(binOp.arg2, variableAssignment, memory) + }) + case binOp: BinaryExpr => for { + lhs <- eval(binOp.arg1) + rhs <- eval(binOp.arg2) + } yield ( binOp.getType match { case m: MapType => binOp case b: BitVecType => { @@ -163,21 +212,28 @@ def partialEvalExpr(exp: Expr, variableAssignment: Variable => Option[Expr], mem case _ => BinaryExpr(binOp.op, lhs, rhs) } } - } - case extend: ZeroExtend => partialEvalExpr(extend.body, variableAssignment, memory) match { + }) + case extend: ZeroExtend => for { + body <- eval(extend.body) + } yield (body match { case b : BitVecLiteral => BitVectorEval.smt_zero_extend(extend.extension, b) case o => extend.copy(body=o) - } - case extend: SignExtend => partialEvalExpr(extend.body, variableAssignment, memory) match { + }) + case extend: SignExtend => for { + body <- eval(extend.body) + } yield (body match { case b: BitVecLiteral => BitVectorEval.smt_sign_extend(extend.extension, b) case o => extend.copy(body=o) - } - case e: Extract => partialEvalExpr(e.body, variableAssignment, memory) match { + }) + case e: Extract => for { + body <- eval(e.body) + } yield (body match { case b: BitVecLiteral => BitVectorEval.boogie_extract(e.end, e.start, b) case o => e.copy(body=o) - } - case r: Repeat => { - partialEvalExpr(r.body, variableAssignment, memory) match { + }) + case r: Repeat => for { + body <- eval(r.body) + } yield (body match { case b: BitVecLiteral => { assert(r.repeats > 0) if (r.repeats == 1) b @@ -186,42 +242,29 @@ def partialEvalExpr(exp: Expr, variableAssignment: Variable => Option[Expr], mem } } case o => r.copy(body=o) - } - - } - case variable: Variable => variableAssignment(variable).getOrElse(variable) - case l: MemoryLoad => memory(l.mem, partialEvalExpr(l.index, variableAssignment, memory), l.endian, l.size).getOrElse(l) - case b: BitVecLiteral => b - case b: IntLiteral => b - case b: BoolLit => b + }) + case variable: Variable => for { + v <- l.getVariable(variable) + } yield (v.getOrElse(variable)) + case ml: MemoryLoad => for { + addr <- eval(ml.index) + mem <- l.loadMemory(ml.mem, addr, ml.endian, ml.size) + } yield (mem.getOrElse(ml)) + case b: BitVecLiteral => mkSt(() => b) + case b: IntLiteral => mkSt(() => b) + case b: BoolLit => mkSt(() => b) } } -def evalIntExpr(exp: Expr, variableAssignment: Variable => Option[Literal], memory: (Memory, Expr, Endian, Int) => Option[Literal] = ((a,b,c,d) => None)): Either[Expr, BigInt] = { - partialEvalExpr(exp, variableAssignment, memory) match { - case i: IntLiteral => Right(i.value) - case o => Left(o) - } -} -def evalBVExpr(exp: Expr, variableAssignment: Variable => Option[Literal], memory: (Memory, Expr, Endian, Int) => Option[Literal] = ((a,b,c,d) => None)): Either[Expr, BitVecLiteral] = { - partialEvalExpr(exp, variableAssignment, memory) match { - case b: BitVecLiteral => Right(b) - case o => Left(o) - } +class StatelessLoader(getVar: Variable => Option[Expr], loadMem: (Memory, Expr, Endian, Int) => Option[Literal] = ((a,b,c,d) => None)) extends Loader[Unit] { + def getVariable(v: Variable) : State[Unit, Option[Expr]] = mkSt(() => getVar(v)) + override def loadMemory(m: Memory, addr: Expr, endian: Endian, size: Int) : State[Unit, Option[Literal]] = mkSt(() => loadMem(m, addr, endian, size)) } -def evalLogExpr(exp: Expr, variableAssignment: Variable => Option[Literal], memory: (Memory, Expr, Endian, Int) => Option[Literal] = ((a,b,c, d) => None)): Either[Expr, Boolean] = { - partialEvalExpr(exp, variableAssignment, memory) match { - case TrueLiteral => Right(true) - case FalseLiteral => Right(false) - case o => Left(o) - } -} -def evalExpr(exp: Expr, variableAssignment: Variable => Option[Literal], memory: (Memory, Expr, Endian, Int) => Option[Literal] = ((d, a,b,c) => None)): Option[Literal] = { - partialEvalExpr match { - case l: Literal => Some(l) - case _ => None - } +def partialEvalExpr(exp: Expr, variableAssignment: Variable => Option[Expr], memory: (Memory, Expr, Endian, Int) => Option[Literal] = ((a,b,c,d) => None)): Expr = { + val l = StatelessLoader(variableAssignment, memory) + statePartialEvalExpr(l)(exp).f(())._2 } + diff --git a/src/main/scala/ir/eval/Interpreter.scala b/src/main/scala/ir/eval/Interpreter.scala index 6e4a60555..15165c199 100644 --- a/src/main/scala/ir/eval/Interpreter.scala +++ b/src/main/scala/ir/eval/Interpreter.scala @@ -3,12 +3,67 @@ import ir.eval.BitVectorEval.* import ir._ import util.Logger import boogie.Scope +import scala.collection.WithFilter import scala.annotation.tailrec import scala.collection.mutable import scala.collection.immutable import scala.util.control.Breaks.{break, breakable} + +case class State[S, +A](f: S => (S, A)) {// extends WithFilter[A, ({ type l[x] = State[S, x] })#l] { + + def unit[A](a: A): State[S, A] = State(s => (s, a)) + + def foreach[U](f: A => U): Unit = () + + // def flatMap[B](f: A => IterableOnce[B]): ir.eval.State[S, B] = ??? + def flatMap[B](f: A => State[S, B]): State[S, B] = State(s => { + val (s2, a) = this.f(s) + f(a).f(s2) + }) + + def map[B](f: A => B): State[S, B] = { + State(s => { + val (s2, a) = this.f(s) + (s2, f(a)) + }) + } + + def withFilter(q : A => Boolean) = { + this + } +} + + +def stateCompose[S, A, B](s1: State[S, A], s2: State[S, B]) : State[S, B] = { + State((s: S) => s1.f(s) match { + case (s, _) => s2.f(s) + }) +} + +def pure[S, A](a: A) : State[S, A] = State((s:S) => (s, a)) + +def sequence[S, V](xs: Iterable[State[S,V]]) : State[S,V] = { + xs.reduceRight(stateCompose) +} + +def sequence[V](xs: Iterable[Option[V]]) : Option[V] = { + xs.reduceRight((a, b) => a match { + case Some(x) => Some(x) + case None => b + }) +} + +def filterM[A, S](m : (A => State[S, Boolean]), xs: Iterable[A]): State[S, List[A]] = { + xs.foldRight(pure(List[A]()))((b,acc) => acc.flatMap(c => m(b).map(v => if v then b::c else c))) +} + + +def get[S,A](f: S => A) : State[S, A] = State(s => (s, f(s))) +def modify[S](f: S => S) : State[S, Unit] = State(s => (f(s), ())) +def execute[S, A](s: S, c: State[S,A]) = c.f(s) + sealed trait ExecutionContinuation case class FailedAssertion(a: Assert) extends ExecutionContinuation @@ -102,6 +157,35 @@ case object BasilValue: export BasilValue._ + +sealed trait Effects[T] { + /* evaluation (may side-effect via InterpreterException on evaluation failure) */ + def evalBV(e: Expr): State[T, BitVecLiteral] + + def evalInt(e: Expr): State[T, BigInt] + + def evalBool(e: Expr): State[T, Boolean] + + def loadVar(v: String): State[T, BasilValue] + + def loadMem(v: String, addrs: List[BasilValue]): State[T, List[BasilValue]] + + def getNext: State[T, ExecutionContinuation] + + /** effects * */ + def setNext(c: ExecutionContinuation): State[T, Unit] + + def call(target: String, beginFrom: ExecutionContinuation, returnTo: ExecutionContinuation): State[T, Unit] + + def doReturn(): State[T, Unit] + + def storeVar(v: String, scope: Scope, value: BasilValue): State[T, Unit] + + def storeMem(vname: String, update: Map[BasilValue, BasilValue]): State[T, Unit] +} + + + def evalToConst( to: IRType, exp: Expr, @@ -268,116 +352,169 @@ case class MemoryState( } } -case object Eval { - def getVar[T <: Effects[T]](s: T)(v: Variable): Option[Literal] = s.loadVar(v.name) match { - case Scalar(l) => Some(l) - case _ => None - } +case class StVarLoader[S, F <: Effects[S]](f : F) extends Loader[S] { - def doLoad[T <: Effects[T]](s: T)(m: Memory, addr: Expr, endian: Endian, sz: Int): Option[Literal] = { - addr match { - case l: Literal if sz == 1 => ( - loadSingle(s, m.name, Scalar(l)) match { - case Scalar(v) => Some(v) - case _ => None - } - ) - case l: Literal => Some(loadBV(s, m.name, Scalar(l), endian, sz)) - case _ => None + /** Load helpers * */ + def load(vname: String, addr: Scalar, endian: Endian, count: Int): State[S, List[BasilValue]] = { + if (count == 0) { + throw InterpreterError(Errored(s"Attempted fractional load")) } - } - - def evalBV[T <: Effects[T]](s: T, e: Expr): BitVecLiteral = { - ir.eval.evalBVExpr(e, Eval.getVar(s), Eval.doLoad(s)) match { - case Right(e) => e - case Left(e) => throw InterpreterError(Errored(s"Eval BV residual $e")) + val keys = (0 until count).map(i => BasilValue.unsafeAdd(addr, i)) + for { + values <- f.loadMem(vname, keys.toList) + vals = endian match { + case Endian.LittleEndian => values.reverse + case Endian.BigEndian => values + } } + yield (vals.toList) } - def evalInt[T <: Effects[T]](s: T, e: Expr): BigInt = { - ir.eval.evalIntExpr(e, Eval.getVar(s), Eval.doLoad(s)) match { - case Right(e) => e - case Left(e) => throw InterpreterError(Errored(s"Eval int residual $e")) - } - } - def evalBool[T <: Effects[T]](s: T, e: Expr): Boolean = { - ir.eval.evalLogExpr(e, Eval.getVar(s), Eval.doLoad(s)) match { - case Right(e) => e - case Left(e) => throw InterpreterError(Errored(s"Eval bool residual $e")) - } - } + /** Load and concat bitvectors */ + def loadBV( vname: String, addr: Scalar, endian: Endian, size: Int): State[S, BitVecLiteral] = for { + mem <- f.loadVar(vname) + (valsize, mapv) = mem match { + case mapv @ MapValue(_, MapType(_, BitVecType(sz))) => (sz, mapv) + case _ => throw InterpreterError(Errored("Trued to load-concat non bv")) + } - /** Load helpers * */ - def load[T <: Effects[T]](s: T, vname: String, addr: Scalar, endian: Endian, count: Int): List[BasilValue] = { - if (count == 0) { - throw InterpreterError(Errored(s"Attempted fractional load")) - } + cells = size / valsize - val keys = (0 until count).map(i => BasilValue.unsafeAdd(addr, i)) - val values = s.loadMem(vname, keys.toList) - val vals = endian match { - case Endian.LittleEndian => values.reverse - case Endian.BigEndian => values - } + res <- load(vname, addr, endian, cells) + bvs: List[BitVecLiteral] = { + val rr = res.map { + case Scalar(bv @ BitVecLiteral(v, sz)) if sz == valsize => bv + case c => + throw InterpreterError(TypeError(s"Loaded value of type ${c.irType} did not match expected type bv$valsize")) + } + rr + } - vals.toList + bvres = bvs.foldLeft(BitVecLiteral(0, 0))((acc, r) => eval.evalBVBinExpr(BVCONCAT, acc, r)) + _ = {assert(bvres.size == size)} + } yield(bvres) + + def loadSingle(vname: String, addr: Scalar): State[S, BasilValue] = { + for { + m <- load(vname, addr, Endian.LittleEndian, 1) + } yield (m.head) } - /** Load and concat bitvectors */ - def loadBV[T <: Effects[T]](s: T, vname: String, addr: Scalar, endian: Endian, size: Int): BitVecLiteral = { - val mem = s.loadVar(vname) + def getVariable(v: Variable) : State[S, Option[Expr]] = { + for { + v <- f.loadVar(v.name) + } yield ( + v match { + case Scalar(l) => Some(l) + case _ => None + }) + } - val (valsize, mapv) = mem match { - case mapv @ MapValue(_, MapType(_, BitVecType(sz))) => (sz, mapv) - case _ => throw InterpreterError(Errored("Trued to load-concat non bv")) + override def loadMemory(m: Memory, addr: Expr, endian: Endian, size: Int) : State[S, Option[Literal]] = { + for { + r <- addr match { + case l: Literal if size== 1 => loadSingle(m.name, Scalar(l)).map((v : BasilValue) => v match { + case Scalar(l) => Some(l) + case _ => None + }) + case l: Literal => loadBV(m.name, Scalar(l), endian, size).map(Some(_)) + case _ => get((s:S) => None) } + } yield (r) + } - val cells = size / valsize +} - val bvs: List[BitVecLiteral] = { - val res = load(s, vname, addr, endian, cells) - val rr = res.map { - case Scalar(bv @ BitVecLiteral(v, sz)) if sz == valsize => bv - case c => - throw InterpreterError(TypeError(s"Loaded value of type ${c.irType} did not match expected type bv$valsize")) - } - rr - } - val bvres = bvs.foldLeft(BitVecLiteral(0, 0))((acc, r) => eval.evalBVBinExpr(BVCONCAT, acc, r)) - assert(bvres.size == size) - bvres +case object Eval { + //def getVar[S, F <: Effects[S]](f: F)(s: S)(v: Variable): Option[Literal] = + // f.loadVar(v.name).f(s) match { + // case Scalar(l) => Some(l) + // case _ => None + //} + + //def doLoad[S, T <: Effects[S]](f: T)(s: S)(m: Memory, addr: Expr, endian: Endian, sz: Int): Option[Literal] = { + // addr match { + // case l: Literal if sz == 1 => ( + // loadSingle(f)(s)(m.name, Scalar(l)) match { + // case Scalar(v) => Some(v) + // case _ => None + // } + // ) + // case l: Literal => Some(loadBV(f)(s)(m.name, Scalar(l), endian, sz)) + // case _ => None + // } + //} + + def evalBV[S, T <: Effects[S]](f: T)(e: Expr): State[S, BitVecLiteral] = { + val ldr = StVarLoader[S, T](f) + for { + res <- ir.eval.statePartialEvalExpr[S](ldr)(e) + } yield ( + e match { + case l: BitVecLiteral => l + case _ => throw InterpreterError(Errored(s"Eval BV residual $e")) + }) + } + + def evalInt[S, T <: Effects[S]](f: T)(e: Expr): State[S, BigInt] = { + val ldr = StVarLoader[S, T](f) + for { + res <- ir.eval.statePartialEvalExpr[S](ldr)(e) + } yield ( + e match { + case l: IntLiteral => l.value + case _ => throw InterpreterError(Errored(s"Eval BV residual $e")) + }) + } + + def evalBool[S, T <: Effects[S]](f: T)(e: Expr): State[S, Boolean] = { + val ldr = StVarLoader[S, T](f) + for { + res <- ir.eval.statePartialEvalExpr[S](ldr)(e) + } yield ( + e match { + case l: BoolLit => l == TrueLiteral + case _ => throw InterpreterError(Errored(s"Eval BV residual $e")) + }) } - def loadSingle[T <: Effects[T]](s: T, vname: String, addr: Scalar): BasilValue = { - load(s, vname, addr, Endian.LittleEndian, 1).head - } + /** State modifying helpers, e.g. store */ /* Expand addr for number of values to store */ - def store[T <: Effects[T]](st: T, vname: String, addr: BasilValue, values: List[BasilValue], endian: Endian): T = { - val mem = st.loadVar(vname) + def store[S, T <: Effects[S]](f: T)( + vname: String, + addr: BasilValue, + values: List[BasilValue], + endian: Endian + ): State[S, Unit] = for { + mem <- f.loadVar(vname) + (mapval, keytype, valtype) = mem match { + case m @ MapValue(_, MapType(kt, vt)) if kt == addr.irType && values.forall(v => v.irType == vt) => (m, kt, vt) + case v => throw InterpreterError(TypeError(s"Invalid map store operation to $vname : $v")) + } + keys = (0 until values.size).map(i => BasilValue.unsafeAdd(addr, i)) + vals = endian match { + case Endian.LittleEndian => values.reverse + case Endian.BigEndian => values + } + x <- f.storeMem(vname, keys.zip(vals).toMap) + } yield (x) - val (mapval, keytype, valtype) = mem match { - case m @ MapValue(_, MapType(kt, vt)) if kt == addr.irType && values.forall(v => v.irType == vt) => (m, kt, vt) - case v => throw InterpreterError(TypeError(s"Invalid map store operation to $vname : ${v.irType}")) - } - val keys = (0 until values.size).map(i => BasilValue.unsafeAdd(addr, i)) - val vals = endian match { - case Endian.LittleEndian => values.reverse - case Endian.BigEndian => values - } - - st.storeMem(vname, keys.zip(vals).toMap) - } /** Extract bitvec to bytes and store bytes */ - def storeBV[T <: Effects[T]](st: T, vname: String, addr: BasilValue, value: BitVecLiteral, endian: Endian): T = { - val mem = st.loadVar(vname) - val (mapval, vsize) = mem match { + def storeBV[S, T <: Effects[S]](f: T)( + vname: String, + addr: BasilValue, + value: BitVecLiteral, + endian: Endian + ): State[S, Unit] = for { + mem <- f.loadVar(vname) + (mapval, vsize) = mem match { case m @ MapValue(_, MapType(kt, BitVecType(size))) if kt == addr.irType => (m, size) case v => throw InterpreterError( @@ -386,54 +523,26 @@ case object Eval { ) ) } - val cells = value.size / vsize + cells = value.size / vsize + _ = { if (cells < 1) { throw InterpreterError(MemoryError("Tried to execute fractional store")) - } - - val extractVals = (0 until cells).map(i => BitVectorEval.boogie_extract((i + 1) * vsize, i * vsize, value)).toList + }} - val vs = endian match { + extractVals = (0 until cells).map(i => BitVectorEval.boogie_extract((i + 1) * vsize, i * vsize, value)).toList + vs = endian match { case Endian.LittleEndian => extractVals.map(Scalar(_)) case Endian.BigEndian => extractVals.reverse.map(Scalar(_)) } - val keys = (0 until cells).map(i => BasilValue.unsafeAdd(addr, i)) - st.storeMem(vname, keys.zip(vs).toMap) - } + keys = (0 until cells).map(i => BasilValue.unsafeAdd(addr, i)) + } yield (f.storeMem(vname, keys.zip(vs).toMap)) - def storeSingle[T <: Effects[T]](st: T, vname: String, addr: BasilValue, value: BasilValue): T = { - st.storeMem(vname, Map((addr -> value))) + def storeSingle[S, T <: Effects[S]](f: T)(vname: String, addr: BasilValue, value: BasilValue): State[S, Unit] = { + f.storeMem(vname, Map((addr -> value))) } - } -sealed trait Effects[T <: Effects[T]] { - /* evaluation (may side-effect via InterpreterException on evaluation failure) */ - def evalBV(e: Expr): BitVecLiteral - - def evalInt(e: Expr): BigInt - - def evalBool(e: Expr): Boolean - - def loadVar(v: String): BasilValue - - def loadMem(v: String, addrs: List[BasilValue]): List[BasilValue] - - /** effects * */ - def setNext(c: ExecutionContinuation): T - - def call(target: String, beginFrom: ExecutionContinuation, returnTo: ExecutionContinuation): T - - def doReturn(): T - - def storeVar(v: String, scope: Scope, value: BasilValue): T - - def storeMem(vname: String, update: Map[BasilValue, BasilValue]): T - - def getNext : ExecutionContinuation - -} enum Effect: case Call(target: String, begin: ExecutionContinuation, returnTo: ExecutionContinuation) @@ -442,200 +551,216 @@ enum Effect: case StoreVar(v: String, s: Scope, value: BasilValue) case StoreMem(vname: String, update: Map[BasilValue, BasilValue]) -case class TracingInterpreter( - val s: InterpreterState, - val trace: List[Effect] -) extends Effects[TracingInterpreter] { - def evalBV(e: Expr): BitVecLiteral = s.evalBV(e) - def evalInt(e: Expr): BigInt = s.evalInt(e) - def evalBool(e: Expr): Boolean = s.evalBool(e) - - def loadVar(v: String): BasilValue = s.loadVar(v) - def loadMem(v: String, addrs: List[BasilValue]) = s.loadMem(v, addrs) - - /** effects * */ - def setNext(c: ExecutionContinuation) = { - // Logger.debug(s" eff : DONEXT $c") - TracingInterpreter(s.setNext(c), trace) - } - - def call(target: String, beginFrom: ExecutionContinuation, returnTo: ExecutionContinuation) = { - //Logger.debug(s" eff : CALL $target") - TracingInterpreter(s.call(target, beginFrom, returnTo), Effect.Call(target, beginFrom, returnTo)::trace) - } - - def doReturn() = { - //Logger.debug(s" eff : RETURN") - TracingInterpreter(s.doReturn(), Effect.Return::trace) - } - - def storeVar(v: String, c: Scope, value: BasilValue) = { - //Logger.debug(s" eff : SET $v := $value") - TracingInterpreter(s.storeVar(v, c, value), Effect.StoreVar(v, c, value)::trace) - } - - def storeMem(vname: String, update: Map[BasilValue, BasilValue]) = { - //Logger.debug(s" eff : STORE $vname <- $update") - TracingInterpreter(s.storeMem(vname, update), Effect.StoreMem(vname, update)::trace) - } - - def getNext = s.getNext - -} +// case class TracingInterpreter( +// val s: InterpreterState, +// val trace: List[Effect] +// ) extends Effects[TracingInterpreter] { +// +// def evalBV(e: Expr) = Eval.evalBV(this)(e) +// def evalInt(e: Expr) = Eval.evalInt(this)(e) +// def evalBool(e: Expr) = Eval.evalBool(this)(e) +// +// def loadVar(v: String) = +// def loadMem(v: String, addrs: List[BasilValue]) = s.loadMem(v, addrs) +// +// /** effects * */ +// def setNext(c: ExecutionContinuation) = { +// // Logger.debug(s" eff : DONEXT $c") +// TracingInterpreter(s.setNext(c), trace) +// } +// +// def call(target: String, beginFrom: ExecutionContinuation, returnTo: ExecutionContinuation) = { +// //Logger.debug(s" eff : CALL $target") +// TracingInterpreter(s.call(target, beginFrom, returnTo), Effect.Call(target, beginFrom, returnTo) :: trace) +// } +// +// def doReturn() = { +// //Logger.debug(s" eff : RETURN") +// TracingInterpreter(s.doReturn(), Effect.Return :: trace) +// } +// +// def storeVar(v: String, c: Scope, value: BasilValue) = { +// //Logger.debug(s" eff : SET $v := $value") +// TracingInterpreter(s.storeVar(v, c, value), Effect.StoreVar(v, c, value) :: trace) +// } +// +// def storeMem(vname: String, update: Map[BasilValue, BasilValue]) = { +// //Logger.debug(s" eff : STORE $vname <- $update") +// TracingInterpreter(s.storeMem(vname, update), Effect.StoreMem(vname, update) :: trace) +// } +// +// def getNext = s.getNext +// +// } case class InterpreterState( val nextCmd: ExecutionContinuation = Stopped(), val callStack: List[ExecutionContinuation] = List.empty, val memoryState: MemoryState = MemoryState() -) extends Effects[InterpreterState] { +) + +case class NormalInterpreter() extends Effects[InterpreterState] { /* eval */ - def evalBV(e: Expr): BitVecLiteral = Eval.evalBV(this, e) + def evalBV(e: Expr) = Eval.evalBV(this)(e) - def evalInt(e: Expr): BigInt = Eval.evalInt(this, e) + def evalInt(e: Expr) = Eval.evalInt(this)(e) - def evalBool(e: Expr): Boolean = Eval.evalBool(this, e) + def evalBool(e: Expr) = Eval.evalBool(this)(e) - def loadVar(v: String): BasilValue = memoryState.getVar(v) + def loadVar(v: String) = get((s: InterpreterState) => s.memoryState.getVar(v)) - def loadMem(v: String, addrs: List[BasilValue]): List[BasilValue] = memoryState.doLoad(v, addrs) + def loadMem(v: String, addrs: List[BasilValue]) = get((s: InterpreterState) => s.memoryState.doLoad(v, addrs)) - def getNext = nextCmd + def getNext = get ((s: InterpreterState) => s.nextCmd) /** effects * */ - def setNext(c: ExecutionContinuation): InterpreterState = { - InterpreterState(c, callStack, memoryState) - } + def setNext(c: ExecutionContinuation) = modify ((s: InterpreterState) => { + InterpreterState(c, s.callStack, s.memoryState) + }) - def call(target: String, beginFrom: ExecutionContinuation, returnTo: ExecutionContinuation): InterpreterState = { + def call(target: String, beginFrom: ExecutionContinuation, returnTo: ExecutionContinuation) = modify ((s:InterpreterState) => { + Logger.info(s" eff : CALL $target") InterpreterState( beginFrom, - returnTo :: callStack, - memoryState.pushStackFrame(target) + returnTo :: s.callStack, + s.memoryState.pushStackFrame(target) ) - } + }) - def doReturn(): InterpreterState = { - callStack match { - case Nil => InterpreterState(Stopped(), Nil, memoryState) + def doReturn() = { + modify ((s: InterpreterState) => {s.callStack match { + case Nil => InterpreterState(Stopped(), Nil, s.memoryState) case h :: tl => { - InterpreterState(h, tl, memoryState.popStackFrame()) + InterpreterState(h, tl, s.memoryState.popStackFrame()) } } + }) } def storeVar(v: String, scope: Scope, value: BasilValue) = { - InterpreterState(nextCmd, callStack, memoryState.defVar(v, scope, value)) + Logger.debug(s" eff : SET $v := $value") + modify ((s: InterpreterState) => InterpreterState(s.nextCmd, s.callStack, s.memoryState.defVar(v, scope, value))) } - def storeMem(vname: String, update: Map[BasilValue, BasilValue]) = { - InterpreterState(nextCmd, callStack, memoryState.doStore(vname, update)) - } + def storeMem(vname: String, update: Map[BasilValue, BasilValue]) = modify ((s:InterpreterState) => { + Logger.debug(s" eff : STORE $vname <- $update") + InterpreterState(s.nextCmd, s.callStack, s.memoryState.doStore(vname, update)) + }) } case object InterpFuns { - def initialState[T <: Effects[T]](s: T): T = { + def initialState[S, T <: Effects[S]](s: T): State[S, Unit] = { val SP: BitVecLiteral = BitVecLiteral(4096 - 16, 64) val FP: BitVecLiteral = BitVecLiteral(4096 - 16, 64) val LR: BitVecLiteral = BitVecLiteral(BigInt("FF", 16), 64) - s.storeVar("mem", Scope.Global, MapValue(Map.empty, MapType(BitVecType(64), BitVecType(8)))) - .storeVar("stack", Scope.Global, MapValue(Map.empty, MapType(BitVecType(64), BitVecType(8)))) - .storeVar("R31", Scope.Global, Scalar(SP)) - .storeVar("R29", Scope.Global, Scalar(FP)) - .storeVar("R30", Scope.Global, Scalar(LR)) - } - - def initialiseProgram[T <: Effects[T]](p: Program, s: T): T = { - val mem = p.initialMemory.foldLeft(initialState(s))((s, memory) => { - val s1 = Eval.store( - s, - "mem", - Scalar(BitVecLiteral(memory.address, 64)), - memory.bytes.toList.map(Scalar(_)), - Endian.LittleEndian - ) - Eval.store( - s1, - "stack", - Scalar(BitVecLiteral(memory.address, 64)), - memory.bytes.toList.map(Scalar(_)), - Endian.LittleEndian - ) - }) - - mem.call(p.mainProcedure.name, Run(IRWalk.firstInBlock(p.mainProcedure.entryBlock.get)), Stopped()) - } - - def interpretJump[T <: Effects[T]](s: T, j: Jump): T = { + for { + l <- s.storeVar("R30", Scope.Global, Scalar(LR)) + h <- s.storeVar("mem", Scope.Global, MapValue(Map.empty, MapType(BitVecType(64), BitVecType(8)))) + i <- s.storeVar("stack", Scope.Global, MapValue(Map.empty, MapType(BitVecType(64), BitVecType(8)))) + j <- s.storeVar("R31", Scope.Global, Scalar(SP)) + k <- s.storeVar("R29", Scope.Global, Scalar(FP)) + } yield (l) + // s.storeVar("mem", Scope.Global, MapValue(Map.empty, MapType(BitVecType(64), BitVecType(8)))) + } + + def initialiseProgram[S, T <: Effects[S]](f:T)(p: Program): State[S, Unit] = { + for { + d <- initialState(f) + mem <- sequence(p.initialMemory.map(memory => + Eval.store(f)( + "mem", + Scalar(BitVecLiteral(memory.address, 64)), + memory.bytes.toList.map(Scalar(_)), + Endian.LittleEndian))) + mem <- sequence(p.initialMemory.map(memory => + Eval.store(f)( + "stack", + Scalar(BitVecLiteral(memory.address, 64)), + memory.bytes.toList.map(Scalar(_)), + Endian.LittleEndian))) + r <- f.call(p.mainProcedure.name, Run(IRWalk.firstInBlock(p.mainProcedure.entryBlock.get)), Stopped()) + } yield (r) + } + + def interpretJump[S, T <: Effects[S]](f: T)(j: Jump): State[S, Unit] = { j match { case gt: GoTo if gt.targets.size == 1 => { - s.setNext(Run(IRWalk.firstInBlock(gt.targets.head))) + f.setNext(Run(IRWalk.firstInBlock(gt.targets.head))) } case gt: GoTo => - val condition = gt.targets.flatMap(_.statements.headOption).collect { case a: Assume => - (a, s.evalBool(a.body)) - } - - if (condition.size != gt.targets.size) { + val assumes = gt.targets.flatMap(_.statements.headOption).collect { + case a: Assume => a + } + if (assumes.size != gt.targets.size) { throw InterpreterError(Errored(s"Some goto target missing guard $gt")) } + for { + chosen : List[Assume] <- filterM((a:Assume) => f.evalBool(a.body), assumes) - val chosen = condition.filter(_._2).toList match { - case Nil => throw InterpreterError(Errored(s"No jump target satisfied $gt")) - case h :: Nil => h - case h :: tl => throw InterpreterError(Errored(s"More than one jump guard satisfied $gt")) - } - - s.setNext(Run(chosen._1.successor)) - case r: Return => s.doReturn() - case h: Unreachable => s.setNext(EscapedControlFlow(h)) + res <- chosen match { + case Nil => f.setNext(Errored(s"No jump target satisfied $gt")) + case h :: Nil => f.setNext(Run(h)) + case h :: tl => f.setNext(Errored(s"More than one jump guard satisfied $gt")) + } + } yield (res) + case r: Return => f.doReturn() + case h: Unreachable => f.setNext(EscapedControlFlow(h)) } } - def interpretStatement[T <: Effects[T]](st: T, s: Statement): T = { + def interpretStatement[S, T <: Effects[S]](f: T)(s: Statement): State[S, Unit] = { s match { case assign: Assign => { - val rhs = st.evalBV(assign.rhs) - st.storeVar(assign.lhs.name, assign.lhs.toBoogie.scope, Scalar(rhs)).setNext(Run(s.successor)) + for { + rhs <- f.evalBV(assign.rhs) + st <- f.storeVar(assign.lhs.name, assign.lhs.toBoogie.scope, Scalar(rhs)) + n <- f.setNext(Run(s.successor)) + } yield (st) } - case assign: MemoryAssign => { - val index: BitVecLiteral = st.evalBV(assign.index) - val value: BitVecLiteral = st.evalBV(assign.value) + case assign: MemoryAssign => for { + index : BitVecLiteral <- f.evalBV(assign.index) + value : BitVecLiteral <- f.evalBV(assign.value) // st.storeMem(assign.mem.name, index, value, assign.endian, assign.size).setNext(Run(s.successor)) - Eval.storeBV(st, assign.mem.name, Scalar(index), value, assign.endian).setNext(Run(s.successor)) - } - case assert: Assert => { - if (!st.evalBool(assert.body)) then { - st.setNext(FailedAssertion(assert)) + _ <- Eval.storeBV(f)(assign.mem.name, Scalar(index), value, assign.endian) + n <- f.setNext(Run(s.successor)) + } yield (n) + case assert: Assert => for { + b <- f.evalBool(assert.body) + n <- (if (!b) then { + f.setNext(FailedAssertion(assert)) } else { - st.setNext(Run(s.successor)) - } - } - case assume: Assume => - if (!st.evalBool(assume.body)) { - st.setNext(Errored(s"Assumption not satisfied: $assume")) + f.setNext(Run(s.successor)) + }) + } yield (n) + case assume: Assume => for { + b <- f.evalBool(assume.body) + n <- (if (!b) { + f.setNext(Errored(s"Assumption not satisfied: $assume")) } else { - st.setNext(Run(s.successor)) - - } - case dc: DirectCall => - if (dc.target.entryBlock.isDefined) { + f.setNext(Run(s.successor)) + }) + } yield (n) + case dc: DirectCall => for { + n <- if (dc.target.entryBlock.isDefined) { val block = dc.target.entryBlock.get - st.call(dc.target.name, Run(block.statements.headOption.getOrElse(block.jump)), Run(dc.successor)) + f.call(dc.target.name, Run(block.statements.headOption.getOrElse(block.jump)), Run(dc.successor)) } else { - st.setNext(Run(dc.successor)) + f.setNext(Run(dc.successor)) } - case ic: IndirectCall => - if (ic.target == Register("R30", 64)) { - st.doReturn() + } yield (n) + case ic: IndirectCall => for { + n <- (if (ic.target == Register("R30", 64)) { + f.doReturn() } else { - st.setNext(EscapedControlFlow(ic)) - } - case _: NOP => st.setNext(Run(s.successor)) + f.setNext(EscapedControlFlow(ic)) + }) + } yield (n) + case _: NOP => f.setNext(Run(s.successor)) } } @@ -647,48 +772,33 @@ case object InterpFuns { } } - @tailrec - def interpret[T <: Effects[T]](s: T): T = { - Logger.debug(s"interpret ${s.getNext}") - s.getNext match { - case Run(c: Statement) => - interpret( - protect[T]( - () => interpretStatement(s, c), - { - case InterpreterError(e) => s.setNext(e) - case e: IllegalArgumentException => s.setNext(Errored(s"Evaluation error $e")) - } - ) - ) - case Run(c: Jump) => - interpret( - protect[T]( - () => interpretJump(s, c), - { - case InterpreterError(e) => s.setNext(e) - case e: IllegalArgumentException => s.setNext(Errored(s"Evaluation error $e")) - } - ) - ) - case Stopped() => s - case errorstop => s - } + def interpret[S, T <: Effects[S]](f: T, m: State[S, Unit]): State[S, Unit] = { + for { + n <- f.getNext + _ = { + Logger.debug(s"interpret ${n}") + } + c <- n match { + case Run(c: Statement) => interpret(f, interpretStatement(f)(c)) + case Run(c: Jump) => interpret(f, interpretJump(f)(c)) + case Stopped() => State((s:S) => (s,())) + case errorstop => State((s:S) => (s,())) + } + } yield(c) } - def interpretProg[T <: Effects[T]](p: Program, i: T): T = { - var s = initialiseProgram(p, i) - interpret(s) + def interpretProg[S, T <: Effects[S]](f: T)(p: Program, is: S): S = { + val (fs: S,_) = execute[S, Unit](is, interpret(f, initialiseProgram(f)(p))) + fs } } -def interpret(IRProgram: Program) : InterpreterState = { - InterpFuns.interpretProg(IRProgram, InterpreterState()) -} - -def interpretTrace(IRProgram: Program) : TracingInterpreter = { - val s : TracingInterpreter = InterpFuns.interpretProg(IRProgram, TracingInterpreter(InterpreterState(), List())) - s +def interpret(IRProgram: Program): InterpreterState = { + InterpFuns.interpretProg(NormalInterpreter())(IRProgram, InterpreterState()) } +// def interpretTrace(IRProgram: Program): TracingInterpreter = { +// val s: TracingInterpreter = InterpFuns.interpretProg(IRProgram, TracingInterpreter(InterpreterState(), List())) +// s +// } diff --git a/src/test/scala/ir/InterpreterTests.scala b/src/test/scala/ir/InterpreterTests.scala index ecf521a51..1efabf0c1 100644 --- a/src/test/scala/ir/InterpreterTests.scala +++ b/src/test/scala/ir/InterpreterTests.scala @@ -24,14 +24,15 @@ import util.ILLoadingConfig // } -def initialMem() = InterpFuns.initialState(TracingInterpreter(InterpreterState(), List())) +// def initialMem() = InterpFuns.initialState(InterpreterState(), List()) -def load[T <: Effects[T]](s: T, global: SpecGlobal) : Option[BitVecLiteral] = { +def load(s: InterpreterState, global: SpecGlobal) : Option[BitVecLiteral] = { + val f = NormalInterpreter() // i.getMemory(global.address.toInt, global.size, Endian.LittleEndian, i.mems) // m.evalBV("mem", BitVecLiteral(64, global.address), Endian.LittleEndian, global.size) // i.getMemory(global.address.toInt, global.size, Endian.LittleEndian, i.mems) try { - Some(s.evalBV(MemoryLoad(SharedMemory("mem", 64, 8), BitVecLiteral(global.address, 64), Endian.LittleEndian, global.size))) + Some(f.evalBV(MemoryLoad(SharedMemory("mem", 64, 8), BitVecLiteral(global.address, 64), Endian.LittleEndian, global.size)).f(s)._2) } catch { case e : InterpreterError => None } @@ -99,102 +100,102 @@ class InterpreterTests extends AnyFunSuite with BeforeAndAfter { } - test("Store = Load LittleEndian") { - val ts = List( - BitVecLiteral(BigInt("0D", 16), 8), - BitVecLiteral(BigInt("0C", 16), 8), - BitVecLiteral(BigInt("0B", 16), 8), - BitVecLiteral(BigInt("0A", 16), 8)) - - val s = Eval.store(initialMem(), "mem", Scalar(BitVecLiteral(0, 64)), ts.map(Scalar(_)), Endian.LittleEndian) - val expected: BitVecLiteral = BitVecLiteral(BigInt("0D0C0B0A", 16), 32) - val actual: BitVecLiteral = Eval.loadBV(s, "mem", Scalar(BitVecLiteral(0, 64)), Endian.LittleEndian, 32) - assert(actual == expected) - - - } - - test("store bv = loadbv le") { - val expected: BitVecLiteral = BitVecLiteral(BigInt("0D0C0B0A", 16), 32) - val s2 = Eval.storeBV(initialMem(), "mem", Scalar(BitVecLiteral(0, 64)), expected, Endian.LittleEndian) - val actual2: BitVecLiteral = Eval.loadBV(s2, "mem", Scalar(BitVecLiteral(0, 64)), Endian.LittleEndian, 32) - assert(actual2 == expected) - } - - - test("Store = Load BigEndian") { - val ts = List( - BitVecLiteral(BigInt("0D", 16), 8), - BitVecLiteral(BigInt("0C", 16), 8), - BitVecLiteral(BigInt("0B", 16), 8), - BitVecLiteral(BigInt("0A", 16), 8)) - - val s = Eval.store(initialMem(), "mem", Scalar(BitVecLiteral(0, 64)), ts.map(Scalar(_)), Endian.LittleEndian) - val expected: BitVecLiteral = BitVecLiteral(BigInt("0A0B0C0D", 16), 32) - val actual: BitVecLiteral = Eval.loadBV(s, "mem", Scalar(BitVecLiteral(0, 64)), Endian.BigEndian , 32) - assert(actual == expected) - - - } - - test("getMemory in LittleEndian") { - val ts = List((BitVecLiteral(0, 64), BitVecLiteral(BigInt("0D", 16), 8)), - (BitVecLiteral(1, 64) , BitVecLiteral(BigInt("0C", 16), 8)), - (BitVecLiteral(2, 64) , BitVecLiteral(BigInt("0B", 16), 8)), - (BitVecLiteral(3, 64) , BitVecLiteral(BigInt("0A", 16), 8))) - val s = ts.foldLeft(initialMem())((m, v) => Eval.storeSingle(m, "mem", Scalar(v._1), Scalar(v._2))) - // val s = initialMem().store("mem") - // val r = s.loadBV("mem", BitVecLiteral(0, 64)) - - val expected: BitVecLiteral = BitVecLiteral(BigInt("0A0B0C0D", 16), 32) - - // def loadBV(vname: String, addr: Scalar, endian: Endian, size: Int): BitVecLiteral = { - val actual: BitVecLiteral = Eval.loadBV(s, "mem", Scalar(BitVecLiteral(0, 64)), Endian.LittleEndian, 32) - assert(actual == expected) - } - - - test("StoreBV = LoadBV LE ") { - val expected: BitVecLiteral = BitVecLiteral(BigInt("0A0B0C0D", 16), 32) - - val s = Eval.storeBV(initialMem(), "mem", Scalar(BitVecLiteral(0, 64)), expected, Endian.LittleEndian) - val actual: BitVecLiteral = Eval.loadBV(s, "mem", Scalar(BitVecLiteral(0, 64)), Endian.LittleEndian, 32) - println(s"${actual.value.toInt.toHexString} == ${expected.value.toInt.toHexString}") - assert(actual == expected) - } - - // test("getMemory in BigEndian") { - // i.mems(0) = BitVecLiteral(BigInt("0A", 16), 8) - // i.mems(1) = BitVecLiteral(BigInt("0B", 16), 8) - // i.mems(2) = BitVecLiteral(BigInt("0C", 16), 8) - // i.mems(3) = BitVecLiteral(BigInt("0D", 16), 8) - // val expected: BitVecLiteral = BitVecLiteral(BigInt("0A0B0C0D", 16), 32) - // val actual: BitVecLiteral = i.getMemory(0, 32, Endian.BigEndian, i.mems) - // assert(actual == expected) - // } - - // test("setMemory in LittleEndian") { - // i.mems(0) = BitVecLiteral(BigInt("FF", 16), 8) - // i.mems(1) = BitVecLiteral(BigInt("FF", 16), 8) - // i.mems(2) = BitVecLiteral(BigInt("FF", 16), 8) - // i.mems(3) = BitVecLiteral(BigInt("FF", 16), 8) - // val expected: BitVecLiteral = BitVecLiteral(BigInt("0A0B0C0D", 16), 32) - // i.setMemory(0, 32, Endian.LittleEndian, expected, i.mems) - // val actual: BitVecLiteral = i.getMemory(0, 32, Endian.LittleEndian, i.mems) - // assert(actual == expected) - // } - - // test("setMemory in BigEndian") { - // i.mems(0) = BitVecLiteral(BigInt("FF", 16), 8) - // i.mems(1) = BitVecLiteral(BigInt("FF", 16), 8) - // i.mems(2) = BitVecLiteral(BigInt("FF", 16), 8) - // i.mems(3) = BitVecLiteral(BigInt("FF", 16), 8) - // val expected: BitVecLiteral = BitVecLiteral(BigInt("0A0B0C0D", 16), 32) - // i.setMemory(0, 32, Endian.BigEndian, expected, i.mems) - // val actual: BitVecLiteral = i.getMemory(0, 32, Endian.BigEndian, i.mems) - // assert(actual == expected) - // } - +// test("Store = Load LittleEndian") { +// val ts = List( +// BitVecLiteral(BigInt("0D", 16), 8), +// BitVecLiteral(BigInt("0C", 16), 8), +// BitVecLiteral(BigInt("0B", 16), 8), +// BitVecLiteral(BigInt("0A", 16), 8)) +// +// val s = Eval.store(initialMem(), "mem", Scalar(BitVecLiteral(0, 64)), ts.map(Scalar(_)), Endian.LittleEndian) +// val expected: BitVecLiteral = BitVecLiteral(BigInt("0D0C0B0A", 16), 32) +// val actual: BitVecLiteral = Eval.loadBV(s, "mem", Scalar(BitVecLiteral(0, 64)), Endian.LittleEndian, 32) +// assert(actual == expected) +// +// +// } +// +// test("store bv = loadbv le") { +// val expected: BitVecLiteral = BitVecLiteral(BigInt("0D0C0B0A", 16), 32) +// val s2 = Eval.storeBV(initialMem(), "mem", Scalar(BitVecLiteral(0, 64)), expected, Endian.LittleEndian) +// val actual2: BitVecLiteral = Eval.loadBV(s2, "mem", Scalar(BitVecLiteral(0, 64)), Endian.LittleEndian, 32) +// assert(actual2 == expected) +// } +// +// +// test("Store = Load BigEndian") { +// val ts = List( +// BitVecLiteral(BigInt("0D", 16), 8), +// BitVecLiteral(BigInt("0C", 16), 8), +// BitVecLiteral(BigInt("0B", 16), 8), +// BitVecLiteral(BigInt("0A", 16), 8)) +// +// val s = Eval.store(initialMem(), "mem", Scalar(BitVecLiteral(0, 64)), ts.map(Scalar(_)), Endian.LittleEndian) +// val expected: BitVecLiteral = BitVecLiteral(BigInt("0A0B0C0D", 16), 32) +// val actual: BitVecLiteral = Eval.loadBV(s, "mem", Scalar(BitVecLiteral(0, 64)), Endian.BigEndian , 32) +// assert(actual == expected) +// +// +// } +// +// test("getMemory in LittleEndian") { +// val ts = List((BitVecLiteral(0, 64), BitVecLiteral(BigInt("0D", 16), 8)), +// (BitVecLiteral(1, 64) , BitVecLiteral(BigInt("0C", 16), 8)), +// (BitVecLiteral(2, 64) , BitVecLiteral(BigInt("0B", 16), 8)), +// (BitVecLiteral(3, 64) , BitVecLiteral(BigInt("0A", 16), 8))) +// val s = ts.foldLeft(initialMem())((m, v) => Eval.storeSingle(m, "mem", Scalar(v._1), Scalar(v._2))) +// // val s = initialMem().store("mem") +// // val r = s.loadBV("mem", BitVecLiteral(0, 64)) +// +// val expected: BitVecLiteral = BitVecLiteral(BigInt("0A0B0C0D", 16), 32) +// +// // def loadBV(vname: String, addr: Scalar, endian: Endian, size: Int): BitVecLiteral = { +// val actual: BitVecLiteral = Eval.loadBV(s, "mem", Scalar(BitVecLiteral(0, 64)), Endian.LittleEndian, 32) +// assert(actual == expected) +// } +// +// +// test("StoreBV = LoadBV LE ") { +// val expected: BitVecLiteral = BitVecLiteral(BigInt("0A0B0C0D", 16), 32) +// +// val s = Eval.storeBV(initialMem(), "mem", Scalar(BitVecLiteral(0, 64)), expected, Endian.LittleEndian) +// val actual: BitVecLiteral = Eval.loadBV(s, "mem", Scalar(BitVecLiteral(0, 64)), Endian.LittleEndian, 32) +// println(s"${actual.value.toInt.toHexString} == ${expected.value.toInt.toHexString}") +// assert(actual == expected) +// } +// +// // test("getMemory in BigEndian") { +// // i.mems(0) = BitVecLiteral(BigInt("0A", 16), 8) +// // i.mems(1) = BitVecLiteral(BigInt("0B", 16), 8) +// // i.mems(2) = BitVecLiteral(BigInt("0C", 16), 8) +// // i.mems(3) = BitVecLiteral(BigInt("0D", 16), 8) +// // val expected: BitVecLiteral = BitVecLiteral(BigInt("0A0B0C0D", 16), 32) +// // val actual: BitVecLiteral = i.getMemory(0, 32, Endian.BigEndian, i.mems) +// // assert(actual == expected) +// // } +// +// // test("setMemory in LittleEndian") { +// // i.mems(0) = BitVecLiteral(BigInt("FF", 16), 8) +// // i.mems(1) = BitVecLiteral(BigInt("FF", 16), 8) +// // i.mems(2) = BitVecLiteral(BigInt("FF", 16), 8) +// // i.mems(3) = BitVecLiteral(BigInt("FF", 16), 8) +// // val expected: BitVecLiteral = BitVecLiteral(BigInt("0A0B0C0D", 16), 32) +// // i.setMemory(0, 32, Endian.LittleEndian, expected, i.mems) +// // val actual: BitVecLiteral = i.getMemory(0, 32, Endian.LittleEndian, i.mems) +// // assert(actual == expected) +// // } +// +// // test("setMemory in BigEndian") { +// // i.mems(0) = BitVecLiteral(BigInt("FF", 16), 8) +// // i.mems(1) = BitVecLiteral(BigInt("FF", 16), 8) +// // i.mems(2) = BitVecLiteral(BigInt("FF", 16), 8) +// // i.mems(3) = BitVecLiteral(BigInt("FF", 16), 8) +// // val expected: BitVecLiteral = BitVecLiteral(BigInt("0A0B0C0D", 16), 32) +// // i.setMemory(0, 32, Endian.BigEndian, expected, i.mems) +// // val actual: BitVecLiteral = i.getMemory(0, 32, Endian.BigEndian, i.mems) +// // assert(actual == expected) +// // } +// test("basic_arrays_read") { val expected = Map( "arr" -> 0 @@ -354,15 +355,15 @@ class InterpreterTests extends AnyFunSuite with BeforeAndAfter { } - test("fibonacci Trace") { - - val fib = fibonacciProg(8) - val r = interpretTrace(fib) - assert(r.getNext == Stopped()) - // Show interpreted result - // - info(r.trace.reverse.mkString("\n")) - - } +// test("fibonacci Trace") { +// +// val fib = fibonacciProg(8) +// val r = interpretTrace(fib) +// assert(r.getNext == Stopped()) +// // Show interpreted result +// // +// info(r.trace.reverse.mkString("\n")) +// +// } } From 2c16c838ec1f911e532f472d3734d16a36c2c9e3 Mon Sep 17 00:00:00 2001 From: Alistair Michael Date: Wed, 28 Aug 2024 10:47:36 +1000 Subject: [PATCH 12/62] fix state monad interp --- src/main/scala/ir/eval/ExprEval.scala | 27 ++- src/main/scala/ir/eval/Interpreter.scala | 234 +++++++++++------------ src/main/scala/util/functional.scala | 59 ++++++ src/test/scala/ir/InterpreterTests.scala | 63 ++++-- src/test/scala/util/StateMonad.scala | 31 +++ 5 files changed, 265 insertions(+), 149 deletions(-) create mode 100644 src/main/scala/util/functional.scala create mode 100644 src/test/scala/util/StateMonad.scala diff --git a/src/main/scala/ir/eval/ExprEval.scala b/src/main/scala/ir/eval/ExprEval.scala index 5cb5979f7..85d71d7cb 100644 --- a/src/main/scala/ir/eval/ExprEval.scala +++ b/src/main/scala/ir/eval/ExprEval.scala @@ -1,5 +1,6 @@ package ir.eval import ir.eval.BitVectorEval +import util.functional.State import ir._ /** @@ -159,17 +160,13 @@ def evalExpr(exp: Expr, variableAssignment: Variable => Option[Literal], memory: } -def mkSt[S, F](f: () => F): State[S, F] = for { - n <- get((s:S) => f()) - } yield (n) - /** * typeclass defining variable and memory laoding from state S */ trait Loader[S] { - def getVariable(v: Variable) : State[S, Option[Expr]] + def getVariable(v: Variable) : State[S, Option[Literal]] def loadMemory(m: Memory, addr: Expr, endian: Endian, size: Int) : State[S, Option[Literal]] = { - mkSt(() => None) + State.pure(None) } } @@ -177,7 +174,7 @@ trait Loader[S] { def statePartialEvalExpr[S](l: Loader[S])(exp: Expr): State[S, Expr] = { val eval = statePartialEvalExpr(l) exp match { - case f: UninterpretedFunction => mkSt(() => f) + case f: UninterpretedFunction => State.pure(f) case unOp: UnaryExpr => for { body <- eval(unOp.arg) } yield ( @@ -244,26 +241,26 @@ def statePartialEvalExpr[S](l: Loader[S])(exp: Expr): State[S, Expr] = { case o => r.copy(body=o) }) case variable: Variable => for { - v <- l.getVariable(variable) + v : Option[Literal] <- l.getVariable(variable) } yield (v.getOrElse(variable)) case ml: MemoryLoad => for { addr <- eval(ml.index) mem <- l.loadMemory(ml.mem, addr, ml.endian, ml.size) } yield (mem.getOrElse(ml)) - case b: BitVecLiteral => mkSt(() => b) - case b: IntLiteral => mkSt(() => b) - case b: BoolLit => mkSt(() => b) + case b: BitVecLiteral => State.pure(b) + case b: IntLiteral => State.pure(b) + case b: BoolLit => State.pure(b) } } -class StatelessLoader(getVar: Variable => Option[Expr], loadMem: (Memory, Expr, Endian, Int) => Option[Literal] = ((a,b,c,d) => None)) extends Loader[Unit] { - def getVariable(v: Variable) : State[Unit, Option[Expr]] = mkSt(() => getVar(v)) - override def loadMemory(m: Memory, addr: Expr, endian: Endian, size: Int) : State[Unit, Option[Literal]] = mkSt(() => loadMem(m, addr, endian, size)) +class StatelessLoader(getVar: Variable => Option[Literal], loadMem: (Memory, Expr, Endian, Int) => Option[Literal] = ((a,b,c,d) => None)) extends Loader[Unit] { + def getVariable(v: Variable) : State[Unit, Option[Literal]] = State.pure(getVar(v)) + override def loadMemory(m: Memory, addr: Expr, endian: Endian, size: Int) : State[Unit, Option[Literal]] = State.pure(loadMem(m, addr, endian, size)) } -def partialEvalExpr(exp: Expr, variableAssignment: Variable => Option[Expr], memory: (Memory, Expr, Endian, Int) => Option[Literal] = ((a,b,c,d) => None)): Expr = { +def partialEvalExpr(exp: Expr, variableAssignment: Variable => Option[Literal], memory: (Memory, Expr, Endian, Int) => Option[Literal] = ((a,b,c,d) => None)): Expr = { val l = StatelessLoader(variableAssignment, memory) statePartialEvalExpr(l)(exp).f(())._2 } diff --git a/src/main/scala/ir/eval/Interpreter.scala b/src/main/scala/ir/eval/Interpreter.scala index 15165c199..d68ec467e 100644 --- a/src/main/scala/ir/eval/Interpreter.scala +++ b/src/main/scala/ir/eval/Interpreter.scala @@ -1,7 +1,9 @@ package ir.eval import ir.eval.BitVectorEval.* -import ir._ +import ir.* import util.Logger +import util.functional.* +import util.functional.State.* import boogie.Scope import scala.collection.WithFilter @@ -11,58 +13,6 @@ import scala.collection.immutable import scala.util.control.Breaks.{break, breakable} -case class State[S, +A](f: S => (S, A)) {// extends WithFilter[A, ({ type l[x] = State[S, x] })#l] { - - def unit[A](a: A): State[S, A] = State(s => (s, a)) - - def foreach[U](f: A => U): Unit = () - - // def flatMap[B](f: A => IterableOnce[B]): ir.eval.State[S, B] = ??? - def flatMap[B](f: A => State[S, B]): State[S, B] = State(s => { - val (s2, a) = this.f(s) - f(a).f(s2) - }) - - def map[B](f: A => B): State[S, B] = { - State(s => { - val (s2, a) = this.f(s) - (s2, f(a)) - }) - } - - def withFilter(q : A => Boolean) = { - this - } -} - - -def stateCompose[S, A, B](s1: State[S, A], s2: State[S, B]) : State[S, B] = { - State((s: S) => s1.f(s) match { - case (s, _) => s2.f(s) - }) -} - -def pure[S, A](a: A) : State[S, A] = State((s:S) => (s, a)) - -def sequence[S, V](xs: Iterable[State[S,V]]) : State[S,V] = { - xs.reduceRight(stateCompose) -} - -def sequence[V](xs: Iterable[Option[V]]) : Option[V] = { - xs.reduceRight((a, b) => a match { - case Some(x) => Some(x) - case None => b - }) -} - -def filterM[A, S](m : (A => State[S, Boolean]), xs: Iterable[A]): State[S, List[A]] = { - xs.foldRight(pure(List[A]()))((b,acc) => acc.flatMap(c => m(b).map(v => if v then b::c else c))) -} - - -def get[S,A](f: S => A) : State[S, A] = State(s => (s, f(s))) -def modify[S](f: S => S) : State[S, Unit] = State(s => (f(s), ())) -def execute[S, A](s: S, c: State[S,A]) = c.f(s) sealed trait ExecutionContinuation case class FailedAssertion(a: Assert) extends ExecutionContinuation @@ -186,19 +136,6 @@ sealed trait Effects[T] { -def evalToConst( - to: IRType, - exp: Expr, - variable: Variable => Option[Expr], - load: (Memory, Expr, Endian, Int) => Option[Literal] -): BasilValue = { - - val res: Expr = ir.eval.partialEvalExpr(exp, variable, load) - res match { - case e: Literal if e.getType == to => Scalar(e) - case res => throw InterpreterError(EvalError(s"Failed to evaluate expr to constant ${to} literal: residual $res")) - } -} // case class BasilConstant(val basilType: BasilValue, val value: basilType.ReprType) @@ -381,7 +318,7 @@ case class StVarLoader[S, F <: Effects[S]](f : F) extends Loader[S] { cells = size / valsize - res <- load(vname, addr, endian, cells) + res <- load(vname, addr, endian, cells) // actual load bvs: List[BitVecLiteral] = { val rr = res.map { case Scalar(bv @ BitVecLiteral(v, sz)) if sz == valsize => bv @@ -390,10 +327,7 @@ case class StVarLoader[S, F <: Effects[S]](f : F) extends Loader[S] { } rr } - - bvres = bvs.foldLeft(BitVecLiteral(0, 0))((acc, r) => eval.evalBVBinExpr(BVCONCAT, acc, r)) - _ = {assert(bvres.size == size)} - } yield(bvres) + } yield(bvs.foldLeft(BitVecLiteral(0, 0))((acc, r) => eval.evalBVBinExpr(BVCONCAT, acc, r))) def loadSingle(vname: String, addr: Scalar): State[S, BasilValue] = { for { @@ -401,14 +335,14 @@ case class StVarLoader[S, F <: Effects[S]](f : F) extends Loader[S] { } yield (m.head) } - def getVariable(v: Variable) : State[S, Option[Expr]] = { + def getVariable(v: Variable) : State[S, Option[Literal]] = { for { v <- f.loadVar(v.name) } yield ( - v match { + (v match { case Scalar(l) => Some(l) case _ => None - }) + })) } override def loadMemory(m: Memory, addr: Expr, endian: Endian, size: Int) : State[S, Option[Literal]] = { @@ -452,7 +386,7 @@ case object Eval { for { res <- ir.eval.statePartialEvalExpr[S](ldr)(e) } yield ( - e match { + res match { case l: BitVecLiteral => l case _ => throw InterpreterError(Errored(s"Eval BV residual $e")) }) @@ -463,7 +397,7 @@ case object Eval { for { res <- ir.eval.statePartialEvalExpr[S](ldr)(e) } yield ( - e match { + res match { case l: IntLiteral => l.value case _ => throw InterpreterError(Errored(s"Eval BV residual $e")) }) @@ -474,7 +408,7 @@ case object Eval { for { res <- ir.eval.statePartialEvalExpr[S](ldr)(e) } yield ( - e match { + res match { case l: BoolLit => l == TrueLiteral case _ => throw InterpreterError(Errored(s"Eval BV residual $e")) }) @@ -536,7 +470,8 @@ case object Eval { } keys = (0 until cells).map(i => BasilValue.unsafeAdd(addr, i)) - } yield (f.storeMem(vname, keys.zip(vs).toMap)) + s <- f.storeMem(vname, keys.zip(vs).toMap) + } yield (s) def storeSingle[S, T <: Effects[S]](f: T)(vname: String, addr: BasilValue, value: BasilValue): State[S, Unit] = { f.storeMem(vname, Map((addr -> value))) @@ -600,28 +535,83 @@ case class InterpreterState( val memoryState: MemoryState = MemoryState() ) -case class NormalInterpreter() extends Effects[InterpreterState] { +object NormalInterpreter extends Effects[InterpreterState] { /* eval */ - def evalBV(e: Expr) = Eval.evalBV(this)(e) + def evalBV(e: Expr) = { + Eval.evalBV(this)(e) + } - def evalInt(e: Expr) = Eval.evalInt(this)(e) + def evalInt(e: Expr) = { + Eval.evalInt(this)(e) + } - def evalBool(e: Expr) = Eval.evalBool(this)(e) + def evalBool(e: Expr) = { + Eval.evalBool(this)(e) + } - def loadVar(v: String) = get((s: InterpreterState) => s.memoryState.getVar(v)) + def loadVar(v: String) = { + State.get((s: InterpreterState) => { + s.memoryState.getVar(v) + }) + } - def loadMem(v: String, addrs: List[BasilValue]) = get((s: InterpreterState) => s.memoryState.doLoad(v, addrs)) + def formatStore(varname: String, update: Map[BasilValue, BasilValue]) = { + val ks = update.toList.sortWith((x,y) => { + def conv(v:BasilValue): BigInt = v match { + case (Scalar(b: BitVecLiteral)) => b.value + case (Scalar(b: IntLiteral)) => b.value + case _ => BigInt(0) + } + conv(x._1) <= conv(y._1) + }) - def getNext = get ((s: InterpreterState) => s.nextCmd) + val rs = ks.foldLeft(Some((None,List[BitVecLiteral]())): Option[(Option[BigInt], List[BitVecLiteral])])((acc, v) => + v match { + case (Scalar(bv : BitVecLiteral), Scalar(bv2 : BitVecLiteral)) => { + acc match { + case None => None + case Some(None, l) => Some(Some(bv.value), bv2::l) + case Some(Some(v), l) if bv.value == v + 1 => Some(Some(bv.value), bv2::l) + case Some(Some(v), l) if bv.value != v + 1 => { + println(s"$v != ${bv.value} + 1") + None + } + } + } + case (bv, bv2) => None + } + ) + + rs match { + case Some(_, l) => { + val vs = Scalar(l.foldLeft(BitVecLiteral(0, 0))((acc, r) => eval.evalBVBinExpr(BVCONCAT, acc, r))).toString + s"$varname[${ks.head._1}] := $vs" + } + case None if ks.length < 8 => s"$varname[${ks.map(_._1).mkString(",")}] := ${ks.map(_._2).mkString(",")}" + case None => s"$varname[${ks.map(_._1).take(8).mkString(",")}...] := ${ks.map(_._2).take(8).mkString(", ")}... " + } + + } + + def loadMem(v: String, addrs: List[BasilValue]) = { + State.get((s: InterpreterState) => { + val r = s.memoryState.doLoad(v, addrs) + Logger.debug(s" eff : LOAD ${addrs.head} x ${addrs.size}") + r + }) + } + + def getNext = State.get ((s: InterpreterState) => s.nextCmd) /** effects * */ - def setNext(c: ExecutionContinuation) = modify ((s: InterpreterState) => { + def setNext(c: ExecutionContinuation) = State.modify ((s: InterpreterState) => { + // Logger.debug(s" eff : setNext $c") InterpreterState(c, s.callStack, s.memoryState) }) def call(target: String, beginFrom: ExecutionContinuation, returnTo: ExecutionContinuation) = modify ((s:InterpreterState) => { - Logger.info(s" eff : CALL $target") + Logger.debug(s" eff : CALL $target") InterpreterState( beginFrom, returnTo :: s.callStack, @@ -630,6 +620,7 @@ case class NormalInterpreter() extends Effects[InterpreterState] { }) def doReturn() = { + Logger.debug(s" eff : RETURN") modify ((s: InterpreterState) => {s.callStack match { case Nil => InterpreterState(Stopped(), Nil, s.memoryState) case h :: tl => { @@ -639,13 +630,13 @@ case class NormalInterpreter() extends Effects[InterpreterState] { }) } - def storeVar(v: String, scope: Scope, value: BasilValue) = { + def storeVar(v: String, scope: Scope, value: BasilValue) : State[InterpreterState, Unit] = { Logger.debug(s" eff : SET $v := $value") - modify ((s: InterpreterState) => InterpreterState(s.nextCmd, s.callStack, s.memoryState.defVar(v, scope, value))) + State.modify ((s: InterpreterState) => InterpreterState(s.nextCmd, s.callStack, s.memoryState.defVar(v, scope, value))) } - def storeMem(vname: String, update: Map[BasilValue, BasilValue]) = modify ((s:InterpreterState) => { - Logger.debug(s" eff : STORE $vname <- $update") + def storeMem(vname: String, update: Map[BasilValue, BasilValue]) = State.modify ((s:InterpreterState) => { + Logger.debug(s" eff : STORE ${formatStore(vname, update)}") InterpreterState(s.nextCmd, s.callStack, s.memoryState.doStore(vname, update)) }) @@ -659,11 +650,11 @@ case object InterpFuns { val LR: BitVecLiteral = BitVecLiteral(BigInt("FF", 16), 64) for { - l <- s.storeVar("R30", Scope.Global, Scalar(LR)) h <- s.storeVar("mem", Scope.Global, MapValue(Map.empty, MapType(BitVecType(64), BitVecType(8)))) i <- s.storeVar("stack", Scope.Global, MapValue(Map.empty, MapType(BitVecType(64), BitVecType(8)))) j <- s.storeVar("R31", Scope.Global, Scalar(SP)) k <- s.storeVar("R29", Scope.Global, Scalar(FP)) + l <- s.storeVar("R30", Scope.Global, Scalar(LR)) } yield (l) // s.storeVar("mem", Scope.Global, MapValue(Map.empty, MapType(BitVecType(64), BitVecType(8)))) } @@ -671,13 +662,13 @@ case object InterpFuns { def initialiseProgram[S, T <: Effects[S]](f:T)(p: Program): State[S, Unit] = { for { d <- initialState(f) - mem <- sequence(p.initialMemory.map(memory => + mem <- State.sequence(State.pure(()), p.initialMemory.map(memory => Eval.store(f)( "mem", Scalar(BitVecLiteral(memory.address, 64)), memory.bytes.toList.map(Scalar(_)), Endian.LittleEndian))) - mem <- sequence(p.initialMemory.map(memory => + mem <- State.sequence(State.pure(()), p.initialMemory.map(memory => Eval.store(f)( "stack", Scalar(BitVecLiteral(memory.address, 64)), @@ -764,41 +755,48 @@ case object InterpFuns { } } - def protect[T](x: () => T, fnly: PartialFunction[Exception, T]): T = { - try { - x() - } catch { - case e: Exception if fnly.isDefinedAt(e) => fnly(e) - } - } - def interpret[S, T <: Effects[S]](f: T, m: State[S, Unit]): State[S, Unit] = { - for { - n <- f.getNext - _ = { - Logger.debug(s"interpret ${n}") - } - c <- n match { - case Run(c: Statement) => interpret(f, interpretStatement(f)(c)) - case Run(c: Jump) => interpret(f, interpretJump(f)(c)) - case Stopped() => State((s:S) => (s,())) - case errorstop => State((s:S) => (s,())) - } - } yield(c) + def interpret[S, T <: Effects[S]](f: T, m: S): S = { + val next = State.evaluate(m, f.getNext) + next match { + case Run(c: Statement) => interpret(f, + protect((() => execute(m, interpretStatement(f)(c))), + { + case x @ InterpreterError(e) => { + Logger.error(s"${x.getStackTrace.mkString("\n")}") + execute(m, f.setNext(e)) + } + case e: IllegalArgumentException => execute(m, f.setNext(Errored(e.toString))) + } + )) + case Run(c: Jump) => interpret(f, + protect((() => execute(m, interpretJump(f)(c))), + { + case x @ InterpreterError(e) => { + Logger.error(s"${x.getStackTrace.mkString("\n")}") + execute(m, f.setNext(e)) + } + case e: IllegalArgumentException => execute(m, f.setNext(Errored(e.toString))) + } + )) + case Stopped() => m + case errorstop => m + } } def interpretProg[S, T <: Effects[S]](f: T)(p: Program, is: S): S = { - val (fs: S,_) = execute[S, Unit](is, interpret(f, initialiseProgram(f)(p))) - fs + val begin = State.execute(is, initialiseProgram(f)(p)) + // State.execute[S,Unit](is, ) + interpret(f, begin) } } def interpret(IRProgram: Program): InterpreterState = { - InterpFuns.interpretProg(NormalInterpreter())(IRProgram, InterpreterState()) + InterpFuns.interpretProg(NormalInterpreter)(IRProgram, InterpreterState()) } // def interpretTrace(IRProgram: Program): TracingInterpreter = { // val s: TracingInterpreter = InterpFuns.interpretProg(IRProgram, TracingInterpreter(InterpreterState(), List())) // s -// } +//e diff --git a/src/main/scala/util/functional.scala b/src/main/scala/util/functional.scala new file mode 100644 index 000000000..410fa5f71 --- /dev/null +++ b/src/main/scala/util/functional.scala @@ -0,0 +1,59 @@ +package util.functional + +case class State[S, +A](f: S => (S, A)) { + + def unit[A](a: A): State[S, A] = State(s => (s, a)) + + + def flatMap[B](f: A => State[S, B]): State[S, B] = State(s => { + // println(s"flatmap ${this.f} $f") + val (s2, a) = this.f(s) + f(a).f(s2) + }) + + + def map[B](f: A => B): State[S, B] = { + State(s => { + val (s2, a) = this.f(s) + (s2, f(a)) + }) + } +} + + +object State { + def get[S,A](f: S => A) : State[S, A] = State(s => (s, f(s))) + def getS[S] : State[S,S] = State((s:S) => (s,s)) + def putS[S](s: S) : State[S,_] = State((_) => (s,())) + def modify[S](f: S => S) : State[S, Unit] = State(s => (f(s), ())) + def execute[S, A](s: S, c: State[S,A]) : S = c.f(s)._1 + def evaluate[S, A](s: S, c: State[S,A]) : A = c.f(s)._2 + + def pure[S, A](a: A) : State[S, A] = State((s:S) => (s, a)) + + def sequence[S, V](ident: State[S,V], xs: Iterable[State[S,V]]) : State[S,V] = { + xs.foldRight(ident)((l,r) => for { + x <- l + y <- r + } yield(y)) + } + + def sequence[V](xs: Iterable[Option[V]]) : Option[V] = { + xs.reduceRight((a, b) => a match { + case Some(x) => Some(x) + case None => b + }) + } + + def filterM[A, S](m : (A => State[S, Boolean]), xs: Iterable[A]): State[S, List[A]] = { + xs.foldRight(pure(List[A]()))((b,acc) => acc.flatMap(c => m(b).map(v => if v then b::c else c))) + } +} + +def protect[T](x: () => T, fnly: PartialFunction[Exception, T]): T = { + try { + x() + } catch { + case e: Exception if fnly.isDefinedAt(e) => fnly(e) + } +} diff --git a/src/test/scala/ir/InterpreterTests.scala b/src/test/scala/ir/InterpreterTests.scala index 1efabf0c1..15da62074 100644 --- a/src/test/scala/ir/InterpreterTests.scala +++ b/src/test/scala/ir/InterpreterTests.scala @@ -1,5 +1,6 @@ package ir +import util.functional._ import ir.eval._ import ir.dsl._ import org.scalatest.funsuite.AnyFunSuite @@ -27,7 +28,8 @@ import util.ILLoadingConfig // def initialMem() = InterpFuns.initialState(InterpreterState(), List()) def load(s: InterpreterState, global: SpecGlobal) : Option[BitVecLiteral] = { - val f = NormalInterpreter() + println(s) + val f = NormalInterpreter // i.getMemory(global.address.toInt, global.size, Endian.LittleEndian, i.mems) // m.evalBV("mem", BitVecLiteral(64, global.address), Endian.LittleEndian, global.size) // i.getMemory(global.address.toInt, global.size, Endian.LittleEndian, i.mems) @@ -96,25 +98,54 @@ class InterpreterTests extends AnyFunSuite with BeforeAndAfter { load(fstate, global).map(gv => name -> gv.value.toInt) ) ) + assert(fstate.nextCmd == Stopped()) assert(expected == actual) } + test("initialise") { -// test("Store = Load LittleEndian") { -// val ts = List( -// BitVecLiteral(BigInt("0D", 16), 8), -// BitVecLiteral(BigInt("0C", 16), 8), -// BitVecLiteral(BigInt("0B", 16), 8), -// BitVecLiteral(BigInt("0A", 16), 8)) -// -// val s = Eval.store(initialMem(), "mem", Scalar(BitVecLiteral(0, 64)), ts.map(Scalar(_)), Endian.LittleEndian) -// val expected: BitVecLiteral = BitVecLiteral(BigInt("0D0C0B0A", 16), 32) -// val actual: BitVecLiteral = Eval.loadBV(s, "mem", Scalar(BitVecLiteral(0, 64)), Endian.LittleEndian, 32) -// assert(actual == expected) -// -// -// } -// + val init = InterpFuns.initialState(NormalInterpreter) + + val s = State.execute(InterpreterState(), init) + assert(s.memoryState.getVarOpt("mem").isDefined) + assert(s.memoryState.getVarOpt("stack").isDefined) + assert(s.memoryState.getVarOpt("R31").isDefined) + assert(s.memoryState.getVarOpt("R29").isDefined) + + + } + test("var load store") { + val s = for { + s <- InterpFuns.initialState(NormalInterpreter) + v <- NormalInterpreter.loadVar("R31") + } yield (v) + val l = State.evaluate(InterpreterState(), s) + + assert(l == Scalar(BitVecLiteral(4096 - 16, 64))) + + } + + test("Store = Load LittleEndian") { + val ts = List( + BitVecLiteral(BigInt("0D", 16), 8), + BitVecLiteral(BigInt("0C", 16), 8), + BitVecLiteral(BigInt("0B", 16), 8), + BitVecLiteral(BigInt("0A", 16), 8)) + + val loader = StVarLoader(NormalInterpreter) + + val s = for { + _ <- InterpFuns.initialState(NormalInterpreter) + _ <- Eval.store(NormalInterpreter)("mem", Scalar(BitVecLiteral(0, 64)), ts.map(Scalar(_)), Endian.LittleEndian) + r <- loader.loadBV("mem", Scalar(BitVecLiteral(0, 64)), Endian.LittleEndian, 32) + } yield(r) + val expected: BitVecLiteral = BitVecLiteral(BigInt("0D0C0B0A", 16), 32) + val actual: BitVecLiteral = State.evaluate(InterpreterState(), s) + assert(actual == expected) + + + } + // test("store bv = loadbv le") { // val expected: BitVecLiteral = BitVecLiteral(BigInt("0D0C0B0A", 16), 32) // val s2 = Eval.storeBV(initialMem(), "mem", Scalar(BitVecLiteral(0, 64)), expected, Endian.LittleEndian) diff --git a/src/test/scala/util/StateMonad.scala b/src/test/scala/util/StateMonad.scala new file mode 100644 index 000000000..29721174a --- /dev/null +++ b/src/test/scala/util/StateMonad.scala @@ -0,0 +1,31 @@ +import ir._ +import util.functional._ + + +import ir.eval._ +import ir.dsl._ +import org.scalatest.funsuite.AnyFunSuite +import org.scalatest.BeforeAndAfter +import specification.SpecGlobal +import translating.BAPToIR +import util.{LogLevel, Logger} +import util.IRLoading.{loadBAP, loadReadELF} +import util.ILLoadingConfig + + +def add: State[Int, Unit] = State(s => (s+1, ())) + +class StateMonadTest extends AnyFunSuite { + + test("forcompre") { + val s = for { + _ <- add + _ <- add + _ <- add + } yield () + + + val res = State.execute(0, s) + assert(res == 3) + } +} From dd11175d2c935d79364a3f18ebb410546cc0631d Mon Sep 17 00:00:00 2001 From: Alistair Michael Date: Wed, 28 Aug 2024 13:38:36 +1000 Subject: [PATCH 13/62] indirect calls --- src/main/scala/ir/eval/Interpreter.scala | 110 ++++++++++++++++------- src/test/scala/ir/InterpreterTests.scala | 10 ++- 2 files changed, 84 insertions(+), 36 deletions(-) diff --git a/src/main/scala/ir/eval/Interpreter.scala b/src/main/scala/ir/eval/Interpreter.scala index d68ec467e..07cec32d3 100644 --- a/src/main/scala/ir/eval/Interpreter.scala +++ b/src/main/scala/ir/eval/Interpreter.scala @@ -44,6 +44,9 @@ case class Scalar(val value: Literal) extends BasilValue(value.getType) { case c => c.toString } } + +case class FunPointer(val addr: BitVecLiteral, val name: String, val call: ExecutionContinuation) extends BasilValue(addr.getType) + // Erase the type of basil values and enforce the invariant that // \exists i . \forall v \in value.keys , v.irType = i and // \exists j . \forall v \in value.values, v.irType = j @@ -120,6 +123,8 @@ sealed trait Effects[T] { def loadMem(v: String, addrs: List[BasilValue]): State[T, List[BasilValue]] + def evalAddrToProc(addr: Int): State[T, Option[FunPointer]] + def getNext: State[T, ExecutionContinuation] /** effects * */ @@ -253,16 +258,27 @@ case class MemoryState( } /* Map variable accessing ; load and store operations */ - def doLoad(vname: String, addr: List[BasilValue]): List[BasilValue] = { + def doLoadOpt(vname: String, addr: List[BasilValue]): Option[List[BasilValue]] = { val (frame, mem) = findVar(vname) val mapv: MapValue = mem match { case m @ MapValue(innerMap, ty) => m case m => throw InterpreterError(TypeError(s"Load from nonmap ${m.irType}")) } - addr.map(k => - mapv.value.get(k).getOrElse(throw InterpreterError(MemoryError(s"Read from uninitialised $vname[$k]"))) - ) + val rs = addr.map(k => mapv.value.get(k)) + if (rs.forall(_.isDefined)) { + Some(rs.map(_.get)) + } else { + None + } + } + def doLoad(vname: String, addr: List[BasilValue]): List[BasilValue] = { + doLoadOpt(vname, addr) match { + case Some(vs) => vs + case None => { + throw InterpreterError(MemoryError(s"Read from uninitialised $vname[${addr.head} .. ${addr.last}]")) + } + } } /** typecheck and some fields of a map variable */ @@ -556,6 +572,18 @@ object NormalInterpreter extends Effects[InterpreterState] { }) } + def evalAddrToProc(addr: Int): State[InterpreterState, Option[FunPointer]] = + Logger.debug(s" eff : FIND PROC $addr") + val load = StVarLoader(this) + for { + res <- get ((s: InterpreterState) => s.memoryState.doLoadOpt("funtable", List(Scalar(BitVecLiteral(addr, 64))))) + } yield { + res match { + case Some((f: FunPointer)::Nil) => Some(f) + case _ => None + } + } + def formatStore(varname: String, update: Map[BasilValue, BasilValue]) = { val ks = update.toList.sortWith((x,y) => { def conv(v:BasilValue): BigInt = v match { @@ -573,7 +601,7 @@ object NormalInterpreter extends Effects[InterpreterState] { case None => None case Some(None, l) => Some(Some(bv.value), bv2::l) case Some(Some(v), l) if bv.value == v + 1 => Some(Some(bv.value), bv2::l) - case Some(Some(v), l) if bv.value != v + 1 => { + case Some(Some(v), l) => { println(s"$v != ${bv.value} + 1") None } @@ -607,37 +635,35 @@ object NormalInterpreter extends Effects[InterpreterState] { /** effects * */ def setNext(c: ExecutionContinuation) = State.modify ((s: InterpreterState) => { // Logger.debug(s" eff : setNext $c") - InterpreterState(c, s.callStack, s.memoryState) + s.copy(nextCmd = c) }) def call(target: String, beginFrom: ExecutionContinuation, returnTo: ExecutionContinuation) = modify ((s:InterpreterState) => { Logger.debug(s" eff : CALL $target") - InterpreterState( - beginFrom, - returnTo :: s.callStack, - s.memoryState.pushStackFrame(target) + s.copy( + nextCmd=beginFrom, + callStack=returnTo :: s.callStack, + memoryState=s.memoryState.pushStackFrame(target) ) }) def doReturn() = { Logger.debug(s" eff : RETURN") modify ((s: InterpreterState) => {s.callStack match { - case Nil => InterpreterState(Stopped(), Nil, s.memoryState) - case h :: tl => { - InterpreterState(h, tl, s.memoryState.popStackFrame()) - } + case Nil => s.copy(nextCmd=Stopped()) + case h :: tl => s.copy(nextCmd=h,callStack=tl,memoryState=s.memoryState.popStackFrame()) } }) } def storeVar(v: String, scope: Scope, value: BasilValue) : State[InterpreterState, Unit] = { Logger.debug(s" eff : SET $v := $value") - State.modify ((s: InterpreterState) => InterpreterState(s.nextCmd, s.callStack, s.memoryState.defVar(v, scope, value))) + State.modify ((s: InterpreterState) => s.copy(memoryState=s.memoryState.defVar(v, scope, value))) } def storeMem(vname: String, update: Map[BasilValue, BasilValue]) = State.modify ((s:InterpreterState) => { Logger.debug(s" eff : STORE ${formatStore(vname, update)}") - InterpreterState(s.nextCmd, s.callStack, s.memoryState.doStore(vname, update)) + s.copy(memoryState=s.memoryState.doStore(vname, update)) }) } @@ -650,30 +676,41 @@ case object InterpFuns { val LR: BitVecLiteral = BitVecLiteral(BigInt("FF", 16), 64) for { + h <- s.storeVar("funtable", Scope.Global, MapValue(Map.empty, MapType(BitVecType(64), BitVecType(64)))) h <- s.storeVar("mem", Scope.Global, MapValue(Map.empty, MapType(BitVecType(64), BitVecType(8)))) i <- s.storeVar("stack", Scope.Global, MapValue(Map.empty, MapType(BitVecType(64), BitVecType(8)))) j <- s.storeVar("R31", Scope.Global, Scalar(SP)) k <- s.storeVar("R29", Scope.Global, Scalar(FP)) l <- s.storeVar("R30", Scope.Global, Scalar(LR)) } yield (l) - // s.storeVar("mem", Scope.Global, MapValue(Map.empty, MapType(BitVecType(64), BitVecType(8)))) } def initialiseProgram[S, T <: Effects[S]](f:T)(p: Program): State[S, Unit] = { - for { - d <- initialState(f) - mem <- State.sequence(State.pure(()), p.initialMemory.map(memory => - Eval.store(f)( - "mem", - Scalar(BitVecLiteral(memory.address, 64)), - memory.bytes.toList.map(Scalar(_)), - Endian.LittleEndian))) - mem <- State.sequence(State.pure(()), p.initialMemory.map(memory => + def initMemory(mem: String, mems: Iterable[MemorySection]) ={ + for { + m <- State.sequence(State.pure(()), mems.filter(m => m.address != 0).map(memory => Eval.store(f)( - "stack", + mem, Scalar(BitVecLiteral(memory.address, 64)), memory.bytes.toList.map(Scalar(_)), Endian.LittleEndian))) + } yield () + } + + println(p.initialMemory) + + for { + d <- initialState(f) + funs <- State.sequence(State.pure(()), p.procedures.filter(p => p.blocks.nonEmpty && p.address.isDefined).map((proc: Procedure) => + Eval.storeSingle(f)( + "funtable", + Scalar(BitVecLiteral(proc.address.get, 64)), + FunPointer(BitVecLiteral(proc.address.get, 64), proc.name, Run(IRWalk.firstInBlock(proc.entryBlock.get))) + ))) + mem <- initMemory("mem", p.initialMemory) + mem <- initMemory("stack", p.initialMemory) + mem <- initMemory("mem", p.readOnlyMemory) + mem <- initMemory("stack", p.readOnlyMemory) r <- f.call(p.mainProcedure.name, Run(IRWalk.firstInBlock(p.mainProcedure.entryBlock.get)), Stopped()) } yield (r) } @@ -716,7 +753,6 @@ case object InterpFuns { case assign: MemoryAssign => for { index : BitVecLiteral <- f.evalBV(assign.index) value : BitVecLiteral <- f.evalBV(assign.value) - // st.storeMem(assign.mem.name, index, value, assign.endian, assign.size).setNext(Run(s.successor)) _ <- Eval.storeBV(f)(assign.mem.name, Scalar(index), value, assign.endian) n <- f.setNext(Run(s.successor)) } yield (n) @@ -744,13 +780,20 @@ case object InterpFuns { f.setNext(Run(dc.successor)) } } yield (n) - case ic: IndirectCall => for { - n <- (if (ic.target == Register("R30", 64)) { + case ic: IndirectCall => { + if (ic.target == Register("R30", 64)) { f.doReturn() } else { - f.setNext(EscapedControlFlow(ic)) - }) - } yield (n) + for { + addr <- f.evalBV(ic.target) + fp <- f.evalAddrToProc(addr.value.toInt) + _ <- fp match { + case Some(fp) => f.call(fp.name, fp.call, Run(ic.successor)) + case none => f.setNext(EscapedControlFlow(ic)) + } + } yield () + } + } case _: NOP => f.setNext(Run(s.successor)) } } @@ -758,6 +801,7 @@ case object InterpFuns { def interpret[S, T <: Effects[S]](f: T, m: S): S = { val next = State.evaluate(m, f.getNext) + Logger.debug(s"eval $next") next match { case Run(c: Statement) => interpret(f, protect((() => execute(m, interpretStatement(f)(c))), diff --git a/src/test/scala/ir/InterpreterTests.scala b/src/test/scala/ir/InterpreterTests.scala index 15da62074..f6bdd0198 100644 --- a/src/test/scala/ir/InterpreterTests.scala +++ b/src/test/scala/ir/InterpreterTests.scala @@ -66,9 +66,7 @@ class InterpreterTests extends AnyFunSuite with BeforeAndAfter { var IRProgram = IRTranslator.translate IRProgram = ExternalRemover(externalFunctions.map(e => e.name)).visitProgram(IRProgram) IRProgram = Renamer(Set("free")).visitProgram(IRProgram) - transforms.stripUnreachableFunctions(IRProgram) - val stackIdentification = StackSubstituter() - stackIdentification.visitProgram(IRProgram) + // transforms.stripUnreachableFunctions(IRProgram) IRProgram.setModifies(Map()) (IRProgram, globals) @@ -289,6 +287,12 @@ class InterpreterTests extends AnyFunSuite with BeforeAndAfter { testInterpret("secret_write", expected) } + test("indirect_call") { + val expected = Map[String, Int]() + testInterpret("indirect_call_outparam", expected) + } + + test("ifglobal") { val expected = Map( "x" -> 1 From c7b4a56e95b79e49cbb79a1d7a1f529ca1650bb2 Mon Sep 17 00:00:00 2001 From: Alistair Michael Date: Wed, 28 Aug 2024 15:54:58 +1000 Subject: [PATCH 14/62] tracing interpreter --- src/main/scala/ir/eval/InterpretBasilIR.scala | 402 ++++++++++++ .../scala/ir/eval/InterpretBasilTrace.scala | 82 +++ src/main/scala/ir/eval/Interpreter.scala | 570 +++--------------- .../scala/ir/eval/InterpreterProduct.scala | 98 +++ src/main/scala/util/PerformanceTimer.scala | 2 +- src/test/scala/ir/InterpreterTests.scala | 71 ++- 6 files changed, 709 insertions(+), 516 deletions(-) create mode 100644 src/main/scala/ir/eval/InterpretBasilIR.scala create mode 100644 src/main/scala/ir/eval/InterpretBasilTrace.scala create mode 100644 src/main/scala/ir/eval/InterpreterProduct.scala diff --git a/src/main/scala/ir/eval/InterpretBasilIR.scala b/src/main/scala/ir/eval/InterpretBasilIR.scala new file mode 100644 index 000000000..8cf8f384b --- /dev/null +++ b/src/main/scala/ir/eval/InterpretBasilIR.scala @@ -0,0 +1,402 @@ +package ir.eval +import ir._ +import ir.eval.BitVectorEval.* +import ir.* +import util.Logger +import util.functional.* +import util.functional.State.* +import boogie.Scope +import scala.collection.WithFilter + +import scala.annotation.tailrec +import scala.collection.mutable +import scala.collection.immutable +import scala.util.control.Breaks.{break, breakable} + + +/** Abstraction for memload and variable lookup used by the expression evaluator. + */ +case class StVarLoader[S, F <: Effects[S]](f: F) extends Loader[S] { + + def getVariable(v: Variable): State[S, Option[Literal]] = { + for { + v <- f.loadVar(v.name) + } yield ((v match { + case Scalar(l) => Some(l) + case _ => None + })) + } + + override def loadMemory(m: Memory, addr: Expr, endian: Endian, size: Int): State[S, Option[Literal]] = { + for { + r <- addr match { + case l: Literal if size == 1 => + Eval + .loadSingle(f)(m.name, Scalar(l)) + .map((v: BasilValue) => + v match { + case Scalar(l) => Some(l) + case _ => None + } + ) + case l: Literal => Eval.loadBV(f)(m.name, Scalar(l), endian, size).map(Some(_)) + case _ => get((s: S) => None) + } + } yield (r) + } + +} + +/* + * Helper functions for compiling high level structures to the interpreter effects. + * All are parametric in concrete state S and Effects[S] + */ +case object Eval { + + /*--------------------------------------------------------------------------------*/ + /* Eval functions */ + /*--------------------------------------------------------------------------------*/ + + def evalBV[S, T <: Effects[S]](f: T)(e: Expr): State[S, BitVecLiteral] = { + val ldr = StVarLoader[S, T](f) + for { + res <- ir.eval.statePartialEvalExpr[S](ldr)(e) + } yield (res match { + case l: BitVecLiteral => l + case _ => throw InterpreterError(Errored(s"Eval BV residual $e")) + }) + } + + def evalInt[S, T <: Effects[S]](f: T)(e: Expr): State[S, BigInt] = { + val ldr = StVarLoader[S, T](f) + for { + res <- ir.eval.statePartialEvalExpr[S](ldr)(e) + } yield (res match { + case l: IntLiteral => l.value + case _ => throw InterpreterError(Errored(s"Eval BV residual $e")) + }) + } + + def evalBool[S, T <: Effects[S]](f: T)(e: Expr): State[S, Boolean] = { + val ldr = StVarLoader[S, T](f) + for { + res <- ir.eval.statePartialEvalExpr[S](ldr)(e) + } yield (res match { + case l: BoolLit => l == TrueLiteral + case _ => throw InterpreterError(Errored(s"Eval BV residual $e")) + }) + } + + /*--------------------------------------------------------------------------------*/ + /* Load functions */ + /*--------------------------------------------------------------------------------*/ + + def load[S, T <: Effects[S]]( + f: T + )(vname: String, addr: Scalar, endian: Endian, count: Int): State[S, List[BasilValue]] = { + if (count == 0) { + throw InterpreterError(Errored(s"Attempted fractional load")) + } + val keys = (0 until count).map(i => BasilValue.unsafeAdd(addr, i)) + for { + values <- f.loadMem(vname, keys.toList) + vals = endian match { + case Endian.LittleEndian => values.reverse + case Endian.BigEndian => values + } + } yield (vals.toList) + } + + /** Load and concat bitvectors */ + def loadBV[S, T <: Effects[S]]( + f: T + )(vname: String, addr: Scalar, endian: Endian, size: Int): State[S, BitVecLiteral] = for { + mem <- f.loadVar(vname) + (valsize, mapv) = mem match { + case mapv @ MapValue(_, MapType(_, BitVecType(sz))) => (sz, mapv) + case _ => throw InterpreterError(Errored("Trued to load-concat non bv")) + } + + cells = size / valsize + + res <- load(f)(vname, addr, endian, cells) // actual load + bvs: List[BitVecLiteral] = { + val rr = res.map { + case Scalar(bv @ BitVecLiteral(v, sz)) if sz == valsize => bv + case c => + throw InterpreterError(TypeError(s"Loaded value of type ${c.irType} did not match expected type bv$valsize")) + } + rr + } + } yield (bvs.foldLeft(BitVecLiteral(0, 0))((acc, r) => eval.evalBVBinExpr(BVCONCAT, acc, r))) + + def loadSingle[S, T <: Effects[S]](f: T)(vname: String, addr: Scalar): State[S, BasilValue] = { + for { + m <- load(f)(vname, addr, Endian.LittleEndian, 1) + } yield (m.head) + } + + /*--------------------------------------------------------------------------------*/ + /* Store functions */ + /*--------------------------------------------------------------------------------*/ + + /* Expand addr for number of values to store */ + def store[S, T <: Effects[S]](f: T)( + vname: String, + addr: BasilValue, + values: List[BasilValue], + endian: Endian + ): State[S, Unit] = for { + mem <- f.loadVar(vname) + (mapval, keytype, valtype) = mem match { + case m @ MapValue(_, MapType(kt, vt)) if kt == addr.irType && values.forall(v => v.irType == vt) => (m, kt, vt) + case v => throw InterpreterError(TypeError(s"Invalid map store operation to $vname : $v")) + } + keys = (0 until values.size).map(i => BasilValue.unsafeAdd(addr, i)) + vals = endian match { + case Endian.LittleEndian => values.reverse + case Endian.BigEndian => values + } + x <- f.storeMem(vname, keys.zip(vals).toMap) + } yield (x) + + /** Extract bitvec to bytes and store bytes */ + def storeBV[S, T <: Effects[S]](f: T)( + vname: String, + addr: BasilValue, + value: BitVecLiteral, + endian: Endian + ): State[S, Unit] = for { + mem <- f.loadVar(vname) + (mapval, vsize) = mem match { + case m @ MapValue(_, MapType(kt, BitVecType(size))) if kt == addr.irType => (m, size) + case v => + throw InterpreterError( + TypeError( + s"Invalid map store operation to $vname : ${v.irType} (expect [${addr.irType}] <- ${value.getType})" + ) + ) + } + cells = value.size / vsize + _ = { + if (cells < 1) { + throw InterpreterError(MemoryError("Tried to execute fractional store")) + } + } + + extractVals = (0 until cells).map(i => BitVectorEval.boogie_extract((i + 1) * vsize, i * vsize, value)).toList + vs = endian match { + case Endian.LittleEndian => extractVals.map(Scalar(_)) + case Endian.BigEndian => extractVals.reverse.map(Scalar(_)) + } + + keys = (0 until cells).map(i => BasilValue.unsafeAdd(addr, i)) + s <- f.storeMem(vname, keys.zip(vs).toMap) + } yield (s) + + def storeSingle[S, T <: Effects[S]](f: T)(vname: String, addr: BasilValue, value: BasilValue): State[S, Unit] = { + f.storeMem(vname, Map((addr -> value))) + } +} + +case object InterpFuns { + + /** Functions which compile BASIL IR down to the minimal interpreter effects. + * + * Each function takes as parameter an implementation of Effects[S] + */ + + def initialState[S, T <: Effects[S]](s: T): State[S, Unit] = { + val SP: BitVecLiteral = BitVecLiteral(4096 - 16, 64) + val FP: BitVecLiteral = BitVecLiteral(4096 - 16, 64) + val LR: BitVecLiteral = BitVecLiteral(BigInt("FF", 16), 64) + + for { + h <- s.storeVar("funtable", Scope.Global, MapValue(Map.empty, MapType(BitVecType(64), BitVecType(64)))) + h <- s.storeVar("mem", Scope.Global, MapValue(Map.empty, MapType(BitVecType(64), BitVecType(8)))) + i <- s.storeVar("stack", Scope.Global, MapValue(Map.empty, MapType(BitVecType(64), BitVecType(8)))) + j <- s.storeVar("R31", Scope.Global, Scalar(SP)) + k <- s.storeVar("R29", Scope.Global, Scalar(FP)) + l <- s.storeVar("R30", Scope.Global, Scalar(LR)) + } yield (l) + } + + def initialiseProgram[S, T <: Effects[S]](f: T)(p: Program): State[S, Unit] = { + def initMemory(mem: String, mems: Iterable[MemorySection]) = { + for { + m <- State.sequence( + State.pure(()), + mems + .filter(m => m.address != 0) + .map(memory => + Eval.store(f)( + mem, + Scalar(BitVecLiteral(memory.address, 64)), + memory.bytes.toList.map(Scalar(_)), + Endian.LittleEndian + ) + ) + ) + } yield () + } + + for { + d <- initialState(f) + funs <- State.sequence( + State.pure(()), + p.procedures + .filter(p => p.blocks.nonEmpty && p.address.isDefined) + .map((proc: Procedure) => + Eval.storeSingle(f)( + "funtable", + Scalar(BitVecLiteral(proc.address.get, 64)), + FunPointer(BitVecLiteral(proc.address.get, 64), proc.name, Run(IRWalk.firstInBlock(proc.entryBlock.get))) + ) + ) + ) + mem <- initMemory("mem", p.initialMemory) + mem <- initMemory("stack", p.initialMemory) + mem <- initMemory("mem", p.readOnlyMemory) + mem <- initMemory("stack", p.readOnlyMemory) + r <- f.call(p.mainProcedure.name, Run(IRWalk.firstInBlock(p.mainProcedure.entryBlock.get)), Stopped()) + } yield (r) + } + + def interpretJump[S, T <: Effects[S]](f: T)(j: Jump): State[S, Unit] = { + j match { + case gt: GoTo if gt.targets.size == 1 => { + f.setNext(Run(IRWalk.firstInBlock(gt.targets.head))) + } + case gt: GoTo => + val assumes = gt.targets.flatMap(_.statements.headOption).collect { case a: Assume => + a + } + if (assumes.size != gt.targets.size) { + throw InterpreterError(Errored(s"Some goto target missing guard $gt")) + } + for { + chosen: List[Assume] <- filterM((a: Assume) => f.evalBool(a.body), assumes) + + res <- chosen match { + case Nil => f.setNext(Errored(s"No jump target satisfied $gt")) + case h :: Nil => f.setNext(Run(h)) + case h :: tl => f.setNext(Errored(s"More than one jump guard satisfied $gt")) + } + } yield (res) + case r: Return => f.doReturn() + case h: Unreachable => f.setNext(EscapedControlFlow(h)) + } + } + + def interpretStatement[S, T <: Effects[S]](f: T)(s: Statement): State[S, Unit] = { + s match { + case assign: Assign => { + for { + rhs <- f.evalBV(assign.rhs) + st <- f.storeVar(assign.lhs.name, assign.lhs.toBoogie.scope, Scalar(rhs)) + n <- f.setNext(Run(s.successor)) + } yield (st) + } + case assign: MemoryAssign => + for { + index: BitVecLiteral <- f.evalBV(assign.index) + value: BitVecLiteral <- f.evalBV(assign.value) + _ <- Eval.storeBV(f)(assign.mem.name, Scalar(index), value, assign.endian) + n <- f.setNext(Run(s.successor)) + } yield (n) + case assert: Assert => + for { + b <- f.evalBool(assert.body) + n <- + (if (!b) then { + f.setNext(FailedAssertion(assert)) + } else { + f.setNext(Run(s.successor)) + }) + } yield (n) + case assume: Assume => + for { + b <- f.evalBool(assume.body) + n <- + (if (!b) { + f.setNext(Errored(s"Assumption not satisfied: $assume")) + } else { + f.setNext(Run(s.successor)) + }) + } yield (n) + case dc: DirectCall => + for { + n <- + if (dc.target.entryBlock.isDefined) { + val block = dc.target.entryBlock.get + f.call(dc.target.name, Run(block.statements.headOption.getOrElse(block.jump)), Run(dc.successor)) + } else { + f.setNext(Run(dc.successor)) + } + } yield (n) + case ic: IndirectCall => { + if (ic.target == Register("R30", 64)) { + f.doReturn() + } else { + for { + addr <- f.evalBV(ic.target) + fp <- f.evalAddrToProc(addr.value.toInt) + _ <- fp match { + case Some(fp) => f.call(fp.name, fp.call, Run(ic.successor)) + case none => f.setNext(EscapedControlFlow(ic)) + } + } yield () + } + } + case _: NOP => f.setNext(Run(s.successor)) + } + } + + def interpret[S, T <: Effects[S]](f: T, m: S): S = { + val next = State.evaluate(m, f.getNext) + Logger.debug(s"eval $next") + next match { + case Run(c: Statement) => + interpret( + f, + protect( + (() => execute(m, interpretStatement(f)(c))), + { + case x @ InterpreterError(e) => { + Logger.error(s"${x.getStackTrace.mkString("\n")}") + execute(m, f.setNext(e)) + } + case e: IllegalArgumentException => execute(m, f.setNext(Errored(e.toString))) + } + ) + ) + case Run(c: Jump) => + interpret( + f, + protect( + (() => execute(m, interpretJump(f)(c))), + { + case x @ InterpreterError(e) => { + Logger.error(s"${x.getStackTrace.mkString("\n")}") + execute(m, f.setNext(e)) + } + case e: IllegalArgumentException => execute(m, f.setNext(Errored(e.toString))) + } + ) + ) + case Stopped() => m + case errorstop => m + } + } + + def interpretProg[S, T <: Effects[S]](f: T)(p: Program, is: S): S = { + val begin = State.execute(is, initialiseProgram(f)(p)) + // State.execute[S,Unit](is, ) + interpret(f, begin) + } +} + +def interpret(IRProgram: Program): InterpreterState = { + InterpFuns.interpretProg(NormalInterpreter)(IRProgram, InterpreterState()) +} + diff --git a/src/main/scala/ir/eval/InterpretBasilTrace.scala b/src/main/scala/ir/eval/InterpretBasilTrace.scala new file mode 100644 index 000000000..4ba8c6836 --- /dev/null +++ b/src/main/scala/ir/eval/InterpretBasilTrace.scala @@ -0,0 +1,82 @@ +package ir.eval +import ir._ +import ir.eval.BitVectorEval.* +import ir.* +import util.Logger +import util.functional.* +import util.functional.State.* +import boogie.Scope +import scala.collection.WithFilter + +import scala.annotation.tailrec +import scala.collection.mutable +import scala.collection.immutable +import scala.util.control.Breaks.{break, breakable} + + +enum ExecEffect: + case Call(target: String, begin: ExecutionContinuation, returnTo: ExecutionContinuation) + case SetNext(c: ExecutionContinuation) + case Return + case StoreVar(v: String, s: Scope, value: BasilValue) + case LoadVar(v: String) + case StoreMem(vname: String, update: Map[BasilValue, BasilValue]) + case LoadMem(vname: String, addrs: List[BasilValue]) + case FindProc(addr: Int) + +case class Trace(val t: List[ExecEffect]) + +case object Trace { + def add(e: ExecEffect) : State[Trace, Unit] = { + modify ((t: Trace) => Trace(t.t.appended(e))) + } +} + + +object TraceGen extends Effects[Trace] { + /** Values are discarded by ProductInterpreter so do not matter */ + def evalBV(e: Expr) = State.pure(BitVecLiteral(0,0)) + + def evalInt(e: Expr) = State.pure(BigInt(0)) + + def evalBool(e: Expr) = State.pure(false) + + def loadVar(v: String) = for { + s <- Trace.add(ExecEffect.LoadVar(v)) + } yield (Scalar(FalseLiteral)) + + def loadMem(v: String, addrs: List[BasilValue]) = for { + s <- Trace.add(ExecEffect.LoadMem(v, addrs)) + } yield (List()) + + def evalAddrToProc(addr: Int) = for { + s <- Trace.add(ExecEffect.FindProc(addr)) + } yield (None) + + def getNext = State.pure(Stopped()) + + def setNext(c: ExecutionContinuation) = State.pure(()) + + def call(target: String, beginFrom: ExecutionContinuation, returnTo: ExecutionContinuation) = for { + s <- Trace.add(ExecEffect.Call(target, beginFrom, returnTo)) + } yield (()) + + def doReturn() = for { + s <- Trace.add(ExecEffect.Return) + } yield (()) + + def storeVar(v: String, scope: Scope, value: BasilValue) = for { + s <- Trace.add(ExecEffect.StoreVar(v, scope, value)) + } yield (()) + + def storeMem(vname: String, update: Map[BasilValue, BasilValue]) = for { + s <- Trace.add(ExecEffect.StoreMem(vname, update)) + } yield (()) +} + +def tracingInterpreter = ProductInterpreter(NormalInterpreter, TraceGen) + +def interpretTrace(p: Program) : (InterpreterState, Trace) = { + InterpFuns.interpretProg(tracingInterpreter)(p, (InterpreterState(), Trace(List()))) +} + diff --git a/src/main/scala/ir/eval/Interpreter.scala b/src/main/scala/ir/eval/Interpreter.scala index 07cec32d3..12ed90010 100644 --- a/src/main/scala/ir/eval/Interpreter.scala +++ b/src/main/scala/ir/eval/Interpreter.scala @@ -12,8 +12,8 @@ import scala.collection.mutable import scala.collection.immutable import scala.util.control.Breaks.{break, breakable} - - +/** Interpreter status type, either stopped, run next command or error + */ sealed trait ExecutionContinuation case class FailedAssertion(a: Assert) extends ExecutionContinuation @@ -29,14 +29,11 @@ case class EvalError(val message: String = "") extends ExecutionContinuation /* failed to evaluate an expression to a concrete value */ case class MemoryError(val message: String = "") extends ExecutionContinuation /* An error to do with memory */ +/** TODO: errors should be encapsualted in error monad, rather than mapping exceptions back into state transitions at + * State.execute() */ case class InterpreterError(continue: ExecutionContinuation) extends Exception() -case class InterpreterSummary( - val exitState: ExecutionContinuation, - val regs: Map[Variable, BitVecLiteral], - val memory: Map[Int, BitVecLiteral] -) - +/* Concrete value type of the interpreter. */ sealed trait BasilValue(val irType: IRType) case class Scalar(val value: Literal) extends BasilValue(value.getType) { override def toString = value match { @@ -45,7 +42,9 @@ case class Scalar(val value: Literal) extends BasilValue(value.getType) { } } -case class FunPointer(val addr: BitVecLiteral, val name: String, val call: ExecutionContinuation) extends BasilValue(addr.getType) +/** Slightly hacky way of mapping addresses to function calls within the interpreter dynamic state */ +case class FunPointer(val addr: BitVecLiteral, val name: String, val call: ExecutionContinuation) + extends BasilValue(addr.getType) // Erase the type of basil values and enforce the invariant that // \exists i . \forall v \in value.keys , v.irType = i and @@ -54,7 +53,7 @@ case class MapValue(val value: Map[BasilValue, BasilValue], override val irType: override def toString = s"MapValue : $irType" } -case object BasilValue: +case object BasilValue { def size(v: IRType): Int = { v match { @@ -83,36 +82,15 @@ case object BasilValue: case _ => throw InterpreterError(TypeError(s"Operation add undefined on $l $r")) } } +} - def concat(l: BasilValue, r: BasilValue): BasilValue = { - (l, r) match { - case (Scalar(b1: BitVecLiteral), Scalar(b2: BitVecLiteral)) => Scalar(eval.evalBVBinExpr(BVCONCAT, b1, b2)) - case _ => throw InterpreterError(TypeError(s"Operation concat undefined on $l $r")) - } - } - - def extract(l: BasilValue, high: Int, low: Int): BasilValue = { - (l) match { - case Scalar(b: BitVecLiteral) => Scalar(eval.BitVectorEval.boogie_extract(high, low, b)) - case _ => throw InterpreterError(TypeError(s"Operation extract($high, $low) undefined on $l")) - } - } - - def fromIR(e: Expr) = { - e match { - case t: IntLiteral => Scalar(t) - case v: BitVecLiteral => Scalar(v) - case b: BoolLit => Scalar(b) - case _ => throw InterpreterError(EvalError(s"Failed to get value from non-literal expr $e")) - - } - } - -export BasilValue._ - +/** + * Minimal language defining all state transitions in the interpreter, + * defined for the interpreter's concrete state T. + */ +trait Effects[T] { + /* evaluation (may side-effect on error) */ -sealed trait Effects[T] { - /* evaluation (may side-effect via InterpreterException on evaluation failure) */ def evalBV(e: Expr): State[T, BitVecLiteral] def evalInt(e: Expr): State[T, BigInt] @@ -123,13 +101,20 @@ sealed trait Effects[T] { def loadMem(v: String, addrs: List[BasilValue]): State[T, List[BasilValue]] - def evalAddrToProc(addr: Int): State[T, Option[FunPointer]] + def evalAddrToProc(addr: Int): State[T, Option[FunPointer]] def getNext: State[T, ExecutionContinuation] - /** effects * */ + /** state effects */ + + /* High-level implementation of a program counter that leverages the intrusive CFG. */ def setNext(c: ExecutionContinuation): State[T, Unit] + /* Perform a call: + * target: arbitrary target name + * beginFrom: ExecutionContinuation which begins executing the procedure + * returnTo: ExecutionContinuation which begins executing after procedure return + */ def call(target: String, beginFrom: ExecutionContinuation, returnTo: ExecutionContinuation): State[T, Unit] def doReturn(): State[T, Unit] @@ -139,15 +124,20 @@ sealed trait Effects[T] { def storeMem(vname: String, update: Map[BasilValue, BasilValue]): State[T, Unit] } - +/** -------------------------------------------------------------------------------- + * Definition of concrete state + * -------------------------------------------------------------------------------- */ -// case class BasilConstant(val basilType: BasilValue, val value: basilType.ReprType) type StackFrameID = String val globalFrame: StackFrameID = "GLOBAL" case class MemoryState( + /* We have a very permissive value reprsentation and store all dynamic state in `stackFrames`. + * - activations is the call stack, the top of which indicates the current stackFrame. + * - activationCount: (procedurename -> int) is used to create uniquely-named stackframes. + */ val stackFrames: Map[StackFrameID, Map[String, BasilValue]] = Map((globalFrame -> Map.empty)), val activations: List[StackFrameID] = List.empty, val activationCount: Map[String, Int] = Map.empty.withDefault(_ => 0) @@ -305,245 +295,6 @@ case class MemoryState( } } -case class StVarLoader[S, F <: Effects[S]](f : F) extends Loader[S] { - - /** Load helpers * */ - def load(vname: String, addr: Scalar, endian: Endian, count: Int): State[S, List[BasilValue]] = { - if (count == 0) { - throw InterpreterError(Errored(s"Attempted fractional load")) - } - val keys = (0 until count).map(i => BasilValue.unsafeAdd(addr, i)) - for { - values <- f.loadMem(vname, keys.toList) - vals = endian match { - case Endian.LittleEndian => values.reverse - case Endian.BigEndian => values - } - } - yield (vals.toList) - } - - - /** Load and concat bitvectors */ - def loadBV( vname: String, addr: Scalar, endian: Endian, size: Int): State[S, BitVecLiteral] = for { - mem <- f.loadVar(vname) - (valsize, mapv) = mem match { - case mapv @ MapValue(_, MapType(_, BitVecType(sz))) => (sz, mapv) - case _ => throw InterpreterError(Errored("Trued to load-concat non bv")) - } - - cells = size / valsize - - res <- load(vname, addr, endian, cells) // actual load - bvs: List[BitVecLiteral] = { - val rr = res.map { - case Scalar(bv @ BitVecLiteral(v, sz)) if sz == valsize => bv - case c => - throw InterpreterError(TypeError(s"Loaded value of type ${c.irType} did not match expected type bv$valsize")) - } - rr - } - } yield(bvs.foldLeft(BitVecLiteral(0, 0))((acc, r) => eval.evalBVBinExpr(BVCONCAT, acc, r))) - - def loadSingle(vname: String, addr: Scalar): State[S, BasilValue] = { - for { - m <- load(vname, addr, Endian.LittleEndian, 1) - } yield (m.head) - } - - def getVariable(v: Variable) : State[S, Option[Literal]] = { - for { - v <- f.loadVar(v.name) - } yield ( - (v match { - case Scalar(l) => Some(l) - case _ => None - })) - } - - override def loadMemory(m: Memory, addr: Expr, endian: Endian, size: Int) : State[S, Option[Literal]] = { - for { - r <- addr match { - case l: Literal if size== 1 => loadSingle(m.name, Scalar(l)).map((v : BasilValue) => v match { - case Scalar(l) => Some(l) - case _ => None - }) - case l: Literal => loadBV(m.name, Scalar(l), endian, size).map(Some(_)) - case _ => get((s:S) => None) - } - } yield (r) - } - -} - - -case object Eval { - //def getVar[S, F <: Effects[S]](f: F)(s: S)(v: Variable): Option[Literal] = - // f.loadVar(v.name).f(s) match { - // case Scalar(l) => Some(l) - // case _ => None - //} - - //def doLoad[S, T <: Effects[S]](f: T)(s: S)(m: Memory, addr: Expr, endian: Endian, sz: Int): Option[Literal] = { - // addr match { - // case l: Literal if sz == 1 => ( - // loadSingle(f)(s)(m.name, Scalar(l)) match { - // case Scalar(v) => Some(v) - // case _ => None - // } - // ) - // case l: Literal => Some(loadBV(f)(s)(m.name, Scalar(l), endian, sz)) - // case _ => None - // } - //} - - def evalBV[S, T <: Effects[S]](f: T)(e: Expr): State[S, BitVecLiteral] = { - val ldr = StVarLoader[S, T](f) - for { - res <- ir.eval.statePartialEvalExpr[S](ldr)(e) - } yield ( - res match { - case l: BitVecLiteral => l - case _ => throw InterpreterError(Errored(s"Eval BV residual $e")) - }) - } - - def evalInt[S, T <: Effects[S]](f: T)(e: Expr): State[S, BigInt] = { - val ldr = StVarLoader[S, T](f) - for { - res <- ir.eval.statePartialEvalExpr[S](ldr)(e) - } yield ( - res match { - case l: IntLiteral => l.value - case _ => throw InterpreterError(Errored(s"Eval BV residual $e")) - }) - } - - def evalBool[S, T <: Effects[S]](f: T)(e: Expr): State[S, Boolean] = { - val ldr = StVarLoader[S, T](f) - for { - res <- ir.eval.statePartialEvalExpr[S](ldr)(e) - } yield ( - res match { - case l: BoolLit => l == TrueLiteral - case _ => throw InterpreterError(Errored(s"Eval BV residual $e")) - }) - } - - - - /** State modifying helpers, e.g. store - */ - - /* Expand addr for number of values to store */ - def store[S, T <: Effects[S]](f: T)( - vname: String, - addr: BasilValue, - values: List[BasilValue], - endian: Endian - ): State[S, Unit] = for { - mem <- f.loadVar(vname) - (mapval, keytype, valtype) = mem match { - case m @ MapValue(_, MapType(kt, vt)) if kt == addr.irType && values.forall(v => v.irType == vt) => (m, kt, vt) - case v => throw InterpreterError(TypeError(s"Invalid map store operation to $vname : $v")) - } - keys = (0 until values.size).map(i => BasilValue.unsafeAdd(addr, i)) - vals = endian match { - case Endian.LittleEndian => values.reverse - case Endian.BigEndian => values - } - x <- f.storeMem(vname, keys.zip(vals).toMap) - } yield (x) - - - /** Extract bitvec to bytes and store bytes */ - def storeBV[S, T <: Effects[S]](f: T)( - vname: String, - addr: BasilValue, - value: BitVecLiteral, - endian: Endian - ): State[S, Unit] = for { - mem <- f.loadVar(vname) - (mapval, vsize) = mem match { - case m @ MapValue(_, MapType(kt, BitVecType(size))) if kt == addr.irType => (m, size) - case v => - throw InterpreterError( - TypeError( - s"Invalid map store operation to $vname : ${v.irType} (expect [${addr.irType}] <- ${value.getType})" - ) - ) - } - cells = value.size / vsize - _ = { - if (cells < 1) { - throw InterpreterError(MemoryError("Tried to execute fractional store")) - }} - - extractVals = (0 until cells).map(i => BitVectorEval.boogie_extract((i + 1) * vsize, i * vsize, value)).toList - vs = endian match { - case Endian.LittleEndian => extractVals.map(Scalar(_)) - case Endian.BigEndian => extractVals.reverse.map(Scalar(_)) - } - - keys = (0 until cells).map(i => BasilValue.unsafeAdd(addr, i)) - s <- f.storeMem(vname, keys.zip(vs).toMap) - } yield (s) - - def storeSingle[S, T <: Effects[S]](f: T)(vname: String, addr: BasilValue, value: BasilValue): State[S, Unit] = { - f.storeMem(vname, Map((addr -> value))) - } -} - - -enum Effect: - case Call(target: String, begin: ExecutionContinuation, returnTo: ExecutionContinuation) - case SetNext(c: ExecutionContinuation) - case Return - case StoreVar(v: String, s: Scope, value: BasilValue) - case StoreMem(vname: String, update: Map[BasilValue, BasilValue]) - - -// case class TracingInterpreter( -// val s: InterpreterState, -// val trace: List[Effect] -// ) extends Effects[TracingInterpreter] { -// -// def evalBV(e: Expr) = Eval.evalBV(this)(e) -// def evalInt(e: Expr) = Eval.evalInt(this)(e) -// def evalBool(e: Expr) = Eval.evalBool(this)(e) -// -// def loadVar(v: String) = -// def loadMem(v: String, addrs: List[BasilValue]) = s.loadMem(v, addrs) -// -// /** effects * */ -// def setNext(c: ExecutionContinuation) = { -// // Logger.debug(s" eff : DONEXT $c") -// TracingInterpreter(s.setNext(c), trace) -// } -// -// def call(target: String, beginFrom: ExecutionContinuation, returnTo: ExecutionContinuation) = { -// //Logger.debug(s" eff : CALL $target") -// TracingInterpreter(s.call(target, beginFrom, returnTo), Effect.Call(target, beginFrom, returnTo) :: trace) -// } -// -// def doReturn() = { -// //Logger.debug(s" eff : RETURN") -// TracingInterpreter(s.doReturn(), Effect.Return :: trace) -// } -// -// def storeVar(v: String, c: Scope, value: BasilValue) = { -// //Logger.debug(s" eff : SET $v := $value") -// TracingInterpreter(s.storeVar(v, c, value), Effect.StoreVar(v, c, value) :: trace) -// } -// -// def storeMem(vname: String, update: Map[BasilValue, BasilValue]) = { -// //Logger.debug(s" eff : STORE $vname <- $update") -// TracingInterpreter(s.storeMem(vname, update), Effect.StoreMem(vname, update) :: trace) -// } -// -// def getNext = s.getNext -// -// } case class InterpreterState( val nextCmd: ExecutionContinuation = Stopped(), @@ -551,6 +302,8 @@ case class InterpreterState( val memoryState: MemoryState = MemoryState() ) +/** Implementation of Effects for InterpreterState concrete state representation. + */ object NormalInterpreter extends Effects[InterpreterState] { /* eval */ @@ -572,42 +325,40 @@ object NormalInterpreter extends Effects[InterpreterState] { }) } - def evalAddrToProc(addr: Int): State[InterpreterState, Option[FunPointer]] = + def evalAddrToProc(addr: Int): State[InterpreterState, Option[FunPointer]] = Logger.debug(s" eff : FIND PROC $addr") - val load = StVarLoader(this) for { - res <- get ((s: InterpreterState) => s.memoryState.doLoadOpt("funtable", List(Scalar(BitVecLiteral(addr, 64))))) - } yield { - res match { - case Some((f: FunPointer)::Nil) => Some(f) - case _ => None + res <- get((s: InterpreterState) => s.memoryState.doLoadOpt("funtable", List(Scalar(BitVecLiteral(addr, 64))))) + } yield { + res match { + case Some((f: FunPointer) :: Nil) => Some(f) + case _ => None + } } - } - def formatStore(varname: String, update: Map[BasilValue, BasilValue]) = { - val ks = update.toList.sortWith((x,y) => { - def conv(v:BasilValue): BigInt = v match { + def formatStore(varname: String, update: Map[BasilValue, BasilValue]) : String = { + val ks = update.toList.sortWith((x, y) => { + def conv(v: BasilValue): BigInt = v match { case (Scalar(b: BitVecLiteral)) => b.value - case (Scalar(b: IntLiteral)) => b.value - case _ => BigInt(0) + case (Scalar(b: IntLiteral)) => b.value + case _ => BigInt(0) } conv(x._1) <= conv(y._1) }) - val rs = ks.foldLeft(Some((None,List[BitVecLiteral]())): Option[(Option[BigInt], List[BitVecLiteral])])((acc, v) => + val rs = ks.foldLeft(Some((None, List[BitVecLiteral]())): Option[(Option[BigInt], List[BitVecLiteral])])((acc, v) => v match { - case (Scalar(bv : BitVecLiteral), Scalar(bv2 : BitVecLiteral)) => { - acc match { - case None => None - case Some(None, l) => Some(Some(bv.value), bv2::l) - case Some(Some(v), l) if bv.value == v + 1 => Some(Some(bv.value), bv2::l) - case Some(Some(v), l) => { - println(s"$v != ${bv.value} + 1") - None + case (Scalar(bv: BitVecLiteral), Scalar(bv2: BitVecLiteral)) => { + acc match { + case None => None + case Some(None, l) => Some(Some(bv.value), bv2 :: l) + case Some(Some(v), l) if bv.value == v + 1 => Some(Some(bv.value), bv2 :: l) + case Some(Some(v), l) => { + None + } } } - } - case (bv, bv2) => None + case (bv, bv2) => None } ) @@ -630,215 +381,46 @@ object NormalInterpreter extends Effects[InterpreterState] { }) } - def getNext = State.get ((s: InterpreterState) => s.nextCmd) + def getNext = State.get((s: InterpreterState) => s.nextCmd) /** effects * */ - def setNext(c: ExecutionContinuation) = State.modify ((s: InterpreterState) => { + def setNext(c: ExecutionContinuation) = State.modify((s: InterpreterState) => { // Logger.debug(s" eff : setNext $c") s.copy(nextCmd = c) }) - def call(target: String, beginFrom: ExecutionContinuation, returnTo: ExecutionContinuation) = modify ((s:InterpreterState) => { - Logger.debug(s" eff : CALL $target") - s.copy( - nextCmd=beginFrom, - callStack=returnTo :: s.callStack, - memoryState=s.memoryState.pushStackFrame(target) - ) - }) + def call(target: String, beginFrom: ExecutionContinuation, returnTo: ExecutionContinuation) = + modify((s: InterpreterState) => { + Logger.debug(s" eff : CALL $target") + s.copy( + nextCmd = beginFrom, + callStack = returnTo :: s.callStack, + memoryState = s.memoryState.pushStackFrame(target) + ) + }) def doReturn() = { Logger.debug(s" eff : RETURN") - modify ((s: InterpreterState) => {s.callStack match { - case Nil => s.copy(nextCmd=Stopped()) - case h :: tl => s.copy(nextCmd=h,callStack=tl,memoryState=s.memoryState.popStackFrame()) - } + modify((s: InterpreterState) => { + s.callStack match { + case Nil => s.copy(nextCmd = Stopped()) + case h :: tl => s.copy(nextCmd = h, callStack = tl, memoryState = s.memoryState.popStackFrame()) + } }) } - def storeVar(v: String, scope: Scope, value: BasilValue) : State[InterpreterState, Unit] = { + def storeVar(v: String, scope: Scope, value: BasilValue): State[InterpreterState, Unit] = { Logger.debug(s" eff : SET $v := $value") - State.modify ((s: InterpreterState) => s.copy(memoryState=s.memoryState.defVar(v, scope, value))) + State.modify((s: InterpreterState) => s.copy(memoryState = s.memoryState.defVar(v, scope, value))) } - def storeMem(vname: String, update: Map[BasilValue, BasilValue]) = State.modify ((s:InterpreterState) => { + def storeMem(vname: String, update: Map[BasilValue, BasilValue]) = State.modify((s: InterpreterState) => { Logger.debug(s" eff : STORE ${formatStore(vname, update)}") - s.copy(memoryState=s.memoryState.doStore(vname, update)) + s.copy(memoryState = s.memoryState.doStore(vname, update)) }) } -case object InterpFuns { - - def initialState[S, T <: Effects[S]](s: T): State[S, Unit] = { - val SP: BitVecLiteral = BitVecLiteral(4096 - 16, 64) - val FP: BitVecLiteral = BitVecLiteral(4096 - 16, 64) - val LR: BitVecLiteral = BitVecLiteral(BigInt("FF", 16), 64) - - for { - h <- s.storeVar("funtable", Scope.Global, MapValue(Map.empty, MapType(BitVecType(64), BitVecType(64)))) - h <- s.storeVar("mem", Scope.Global, MapValue(Map.empty, MapType(BitVecType(64), BitVecType(8)))) - i <- s.storeVar("stack", Scope.Global, MapValue(Map.empty, MapType(BitVecType(64), BitVecType(8)))) - j <- s.storeVar("R31", Scope.Global, Scalar(SP)) - k <- s.storeVar("R29", Scope.Global, Scalar(FP)) - l <- s.storeVar("R30", Scope.Global, Scalar(LR)) - } yield (l) - } - - def initialiseProgram[S, T <: Effects[S]](f:T)(p: Program): State[S, Unit] = { - def initMemory(mem: String, mems: Iterable[MemorySection]) ={ - for { - m <- State.sequence(State.pure(()), mems.filter(m => m.address != 0).map(memory => - Eval.store(f)( - mem, - Scalar(BitVecLiteral(memory.address, 64)), - memory.bytes.toList.map(Scalar(_)), - Endian.LittleEndian))) - } yield () - } - - println(p.initialMemory) - - for { - d <- initialState(f) - funs <- State.sequence(State.pure(()), p.procedures.filter(p => p.blocks.nonEmpty && p.address.isDefined).map((proc: Procedure) => - Eval.storeSingle(f)( - "funtable", - Scalar(BitVecLiteral(proc.address.get, 64)), - FunPointer(BitVecLiteral(proc.address.get, 64), proc.name, Run(IRWalk.firstInBlock(proc.entryBlock.get))) - ))) - mem <- initMemory("mem", p.initialMemory) - mem <- initMemory("stack", p.initialMemory) - mem <- initMemory("mem", p.readOnlyMemory) - mem <- initMemory("stack", p.readOnlyMemory) - r <- f.call(p.mainProcedure.name, Run(IRWalk.firstInBlock(p.mainProcedure.entryBlock.get)), Stopped()) - } yield (r) - } - - def interpretJump[S, T <: Effects[S]](f: T)(j: Jump): State[S, Unit] = { - j match { - case gt: GoTo if gt.targets.size == 1 => { - f.setNext(Run(IRWalk.firstInBlock(gt.targets.head))) - } - case gt: GoTo => - val assumes = gt.targets.flatMap(_.statements.headOption).collect { - case a: Assume => a - } - if (assumes.size != gt.targets.size) { - throw InterpreterError(Errored(s"Some goto target missing guard $gt")) - } - for { - chosen : List[Assume] <- filterM((a:Assume) => f.evalBool(a.body), assumes) - - res <- chosen match { - case Nil => f.setNext(Errored(s"No jump target satisfied $gt")) - case h :: Nil => f.setNext(Run(h)) - case h :: tl => f.setNext(Errored(s"More than one jump guard satisfied $gt")) - } - } yield (res) - case r: Return => f.doReturn() - case h: Unreachable => f.setNext(EscapedControlFlow(h)) - } - } - - def interpretStatement[S, T <: Effects[S]](f: T)(s: Statement): State[S, Unit] = { - s match { - case assign: Assign => { - for { - rhs <- f.evalBV(assign.rhs) - st <- f.storeVar(assign.lhs.name, assign.lhs.toBoogie.scope, Scalar(rhs)) - n <- f.setNext(Run(s.successor)) - } yield (st) - } - case assign: MemoryAssign => for { - index : BitVecLiteral <- f.evalBV(assign.index) - value : BitVecLiteral <- f.evalBV(assign.value) - _ <- Eval.storeBV(f)(assign.mem.name, Scalar(index), value, assign.endian) - n <- f.setNext(Run(s.successor)) - } yield (n) - case assert: Assert => for { - b <- f.evalBool(assert.body) - n <- (if (!b) then { - f.setNext(FailedAssertion(assert)) - } else { - f.setNext(Run(s.successor)) - }) - } yield (n) - case assume: Assume => for { - b <- f.evalBool(assume.body) - n <- (if (!b) { - f.setNext(Errored(s"Assumption not satisfied: $assume")) - } else { - f.setNext(Run(s.successor)) - }) - } yield (n) - case dc: DirectCall => for { - n <- if (dc.target.entryBlock.isDefined) { - val block = dc.target.entryBlock.get - f.call(dc.target.name, Run(block.statements.headOption.getOrElse(block.jump)), Run(dc.successor)) - } else { - f.setNext(Run(dc.successor)) - } - } yield (n) - case ic: IndirectCall => { - if (ic.target == Register("R30", 64)) { - f.doReturn() - } else { - for { - addr <- f.evalBV(ic.target) - fp <- f.evalAddrToProc(addr.value.toInt) - _ <- fp match { - case Some(fp) => f.call(fp.name, fp.call, Run(ic.successor)) - case none => f.setNext(EscapedControlFlow(ic)) - } - } yield () - } - } - case _: NOP => f.setNext(Run(s.successor)) - } - } - - - def interpret[S, T <: Effects[S]](f: T, m: S): S = { - val next = State.evaluate(m, f.getNext) - Logger.debug(s"eval $next") - next match { - case Run(c: Statement) => interpret(f, - protect((() => execute(m, interpretStatement(f)(c))), - { - case x @ InterpreterError(e) => { - Logger.error(s"${x.getStackTrace.mkString("\n")}") - execute(m, f.setNext(e)) - } - case e: IllegalArgumentException => execute(m, f.setNext(Errored(e.toString))) - } - )) - case Run(c: Jump) => interpret(f, - protect((() => execute(m, interpretJump(f)(c))), - { - case x @ InterpreterError(e) => { - Logger.error(s"${x.getStackTrace.mkString("\n")}") - execute(m, f.setNext(e)) - } - case e: IllegalArgumentException => execute(m, f.setNext(Errored(e.toString))) - } - )) - case Stopped() => m - case errorstop => m - } - } - - def interpretProg[S, T <: Effects[S]](f: T)(p: Program, is: S): S = { - val begin = State.execute(is, initialiseProgram(f)(p)) - // State.execute[S,Unit](is, ) - interpret(f, begin) - } - -} - -def interpret(IRProgram: Program): InterpreterState = { - InterpFuns.interpretProg(NormalInterpreter)(IRProgram, InterpreterState()) -} // def interpretTrace(IRProgram: Program): TracingInterpreter = { // val s: TracingInterpreter = InterpFuns.interpretProg(IRProgram, TracingInterpreter(InterpreterState(), List())) diff --git a/src/main/scala/ir/eval/InterpreterProduct.scala b/src/main/scala/ir/eval/InterpreterProduct.scala new file mode 100644 index 000000000..53bc8f779 --- /dev/null +++ b/src/main/scala/ir/eval/InterpreterProduct.scala @@ -0,0 +1,98 @@ + +package ir.eval +import ir._ +import ir.eval.BitVectorEval.* +import ir.* +import util.Logger +import util.functional.* +import util.functional.State.* +import boogie.Scope +import scala.collection.WithFilter + +import scala.annotation.tailrec +import scala.collection.mutable +import scala.collection.immutable +import scala.util.control.Breaks.{break, breakable} + +/** + * Runs two interpreters independently, the returns the value from inner, and ignores before + */ +case class ProductInterpreter[L, T](val inner: Effects[L], val before: Effects[T]) extends Effects[(L, T)] { + def doLeft[V](f: State[L, V]) : State[(L, T), V] = for { + f <- State[(L, T), V]((s: (L, T)) => { + val r = f.f(s._1) + ((r._1, s._2), r._2) + }) + } yield (f) + + def doRight[V](f: State[T, V]) : State[(L, T), V] = for { + f <- State[(L, T), V]((s: (L, T)) => { + val r = f.f(s._2) + ((s._1, r._1), r._2) + }) + } yield (f) + + + def evalBV(e: Expr) = for { + n <- doRight(before.evalBV(e)) + f <- doLeft(inner.evalBV(e)) + } yield (f) + + def evalInt(e: Expr) = for { + n <- doRight(before.evalBV(e)) + f <- doLeft(inner.evalInt(e)) + } yield (f) + + def evalBool(e: Expr) = for { + n <- doRight(before.evalBool(e)) + f <- doLeft(inner.evalBool(e)) + } yield (f) + + def loadVar(v: String) = for { + n <- doRight(before.loadVar(v)) + f <- doLeft(inner.loadVar(v)) + } yield (f) + + def loadMem(v: String, addrs: List[BasilValue]) = for { + n <- doRight(before.loadMem(v, addrs)) + f <- doLeft(inner.loadMem(v, addrs)) + } yield (f) + + def evalAddrToProc(addr: Int) = for { + n <- doRight(before.evalAddrToProc(addr: Int)) + f <- doLeft(inner.evalAddrToProc(addr)) + } yield(f) + + def getNext = for { + n <- doRight(before.getNext) + f <- doLeft(inner.getNext) + } yield(f) + + /** state effects */ + def setNext(c: ExecutionContinuation) = for { + n <- doRight(before.setNext(c)) + f <- doLeft(inner.setNext(c)) + } yield (f) + + def call(target: String, beginFrom: ExecutionContinuation, returnTo: ExecutionContinuation) = for { + n <- doRight(before.call(target, beginFrom, returnTo)) + f <- doLeft(inner.call(target, beginFrom, returnTo)) + } yield (f) + + def doReturn() = for { + n <- doRight(before.doReturn()) + f <- doLeft(inner.doReturn()) + } yield (f) + + def storeVar(v: String, scope: Scope, value: BasilValue) = for { + n <- doRight(before.storeVar(v, scope, value)) + f <- doLeft(inner.storeVar(v, scope, value)) + } yield(f) + + def storeMem(vname: String, update: Map[BasilValue, BasilValue]) = for { + n <- doRight(before.storeMem(vname,update)) + f <- doLeft(inner.storeMem(vname, update)) + } yield(f) +} + + diff --git a/src/main/scala/util/PerformanceTimer.scala b/src/main/scala/util/PerformanceTimer.scala index a0be917f3..d233513cb 100644 --- a/src/main/scala/util/PerformanceTimer.scala +++ b/src/main/scala/util/PerformanceTimer.scala @@ -15,7 +15,7 @@ case class PerformanceTimer(timerName: String = "") { Logger.info(s"PerformanceTimer $timerName [$name]: ${delta}ms") delta } - private def elapsed() : Long = { + def elapsed() : Long = { System.currentTimeMillis() - lastCheckpoint } diff --git a/src/test/scala/ir/InterpreterTests.scala b/src/test/scala/ir/InterpreterTests.scala index f6bdd0198..51c2e7218 100644 --- a/src/test/scala/ir/InterpreterTests.scala +++ b/src/test/scala/ir/InterpreterTests.scala @@ -1,5 +1,6 @@ package ir +import util.PerformanceTimer import util.functional._ import ir.eval._ import ir.dsl._ @@ -28,7 +29,6 @@ import util.ILLoadingConfig // def initialMem() = InterpFuns.initialState(InterpreterState(), List()) def load(s: InterpreterState, global: SpecGlobal) : Option[BitVecLiteral] = { - println(s) val f = NormalInterpreter // i.getMemory(global.address.toInt, global.size, Endian.LittleEndian, i.mems) // m.evalBV("mem", BitVecLiteral(64, global.address), Endian.LittleEndian, global.size) // i.getMemory(global.address.toInt, global.size, Endian.LittleEndian, i.mems) @@ -135,7 +135,7 @@ class InterpreterTests extends AnyFunSuite with BeforeAndAfter { val s = for { _ <- InterpFuns.initialState(NormalInterpreter) _ <- Eval.store(NormalInterpreter)("mem", Scalar(BitVecLiteral(0, 64)), ts.map(Scalar(_)), Endian.LittleEndian) - r <- loader.loadBV("mem", Scalar(BitVecLiteral(0, 64)), Endian.LittleEndian, 32) + r <- Eval.loadBV(NormalInterpreter)("mem", Scalar(BitVecLiteral(0, 64)), Endian.LittleEndian, 32) } yield(r) val expected: BitVecLiteral = BitVecLiteral(BigInt("0D0C0B0A", 16), 32) val actual: BitVecLiteral = State.evaluate(InterpreterState(), s) @@ -323,14 +323,15 @@ class InterpreterTests extends AnyFunSuite with BeforeAndAfter { } - def fibonacciProg(n: Int) = { - def expected(n: Int) : Int = { - n match { - case 0 => 0 - case 1 => 1 - case n => expected(n - 1) + expected(n - 2) - } + def fib(n: Int) : Int = { + n match { + case 0 => 0 + case 1 => 1 + case n => fib(n - 1) + fib(n - 2) } + } + + def fibonacciProg(n: Int) = { prog( proc("begin", block("entry", @@ -340,7 +341,7 @@ class InterpreterTests extends AnyFunSuite with BeforeAndAfter { goto("done") ), block("done", - Assert(BinaryExpr(BVEQ, R0, bv64(expected(n)))), + Assert(BinaryExpr(BVEQ, R0, bv64(fib(n)))), ret )), proc("fib", @@ -378,27 +379,55 @@ class InterpreterTests extends AnyFunSuite with BeforeAndAfter { test("fibonacci") { + Logger.setLevel(LogLevel.ERROR) val fib = fibonacciProg(8) val r = interpret(fib) assert(r.nextCmd == Stopped()) // Show interpreted result - Logger.info("Registers:") // r.regs.foreach { (key, value) => // Logger.info(s"$key := $value") // } } + test("fibonaccistress") { -// test("fibonacci Trace") { -// -// val fib = fibonacciProg(8) -// val r = interpretTrace(fib) -// assert(r.getNext == Stopped()) -// // Show interpreted result -// // -// info(r.trace.reverse.mkString("\n")) -// -// } + Logger.setLevel(LogLevel.ERROR) + var res = List[(Int, Double, Double)]() + + for (i <- 0 to 25) { + val prog = fibonacciProg(i) + + val t = PerformanceTimer("native") + val r = fib(i) + val native = t.elapsed() + + val intt = PerformanceTimer("interp") + val ir = interpret(prog) + val it = intt.elapsed() + + res = (i,native,it)::res + + println(s"${res.head}") + } + + println(("fib number,native time,interp time"::(res.map(x => s"${x._1},${x._2},${x._3}"))).mkString("\n")) + + } + + + + test("fibonacci Trace") { + + val fib = fibonacciProg(8) + + val r = interpretTrace(fib) + + assert(r._1.nextCmd == Stopped()) + info(r._2.t.mkString("\n")) + // Show interpreted result + // + + } } From f64f904999192fee7e3a52c66ec514aaddc56f6c Mon Sep 17 00:00:00 2001 From: Alistair Michael Date: Wed, 28 Aug 2024 17:26:57 +1000 Subject: [PATCH 15/62] breakpoints --- src/main/scala/ir/eval/InterpretBasilIR.scala | 29 +----- .../scala/ir/eval/InterpretBasilTrace.scala | 14 ++- src/main/scala/ir/eval/Interpreter.scala | 38 ++++++++ .../scala/ir/eval/InterpreterProduct.scala | 89 ++++++++++++++++++- src/test/scala/ir/InterpreterTests.scala | 12 +++ 5 files changed, 151 insertions(+), 31 deletions(-) diff --git a/src/main/scala/ir/eval/InterpretBasilIR.scala b/src/main/scala/ir/eval/InterpretBasilIR.scala index 8cf8f384b..7eb6cf671 100644 --- a/src/main/scala/ir/eval/InterpretBasilIR.scala +++ b/src/main/scala/ir/eval/InterpretBasilIR.scala @@ -356,34 +356,7 @@ case object InterpFuns { val next = State.evaluate(m, f.getNext) Logger.debug(s"eval $next") next match { - case Run(c: Statement) => - interpret( - f, - protect( - (() => execute(m, interpretStatement(f)(c))), - { - case x @ InterpreterError(e) => { - Logger.error(s"${x.getStackTrace.mkString("\n")}") - execute(m, f.setNext(e)) - } - case e: IllegalArgumentException => execute(m, f.setNext(Errored(e.toString))) - } - ) - ) - case Run(c: Jump) => - interpret( - f, - protect( - (() => execute(m, interpretJump(f)(c))), - { - case x @ InterpreterError(e) => { - Logger.error(s"${x.getStackTrace.mkString("\n")}") - execute(m, f.setNext(e)) - } - case e: IllegalArgumentException => execute(m, f.setNext(Errored(e.toString))) - } - ) - ) + case Run(c) => interpret(f, State.execute(m, f.interpretOne)) case Stopped() => m case errorstop => m } diff --git a/src/main/scala/ir/eval/InterpretBasilTrace.scala b/src/main/scala/ir/eval/InterpretBasilTrace.scala index 4ba8c6836..4ae1af9ed 100644 --- a/src/main/scala/ir/eval/InterpretBasilTrace.scala +++ b/src/main/scala/ir/eval/InterpretBasilTrace.scala @@ -16,7 +16,6 @@ import scala.util.control.Breaks.{break, breakable} enum ExecEffect: case Call(target: String, begin: ExecutionContinuation, returnTo: ExecutionContinuation) - case SetNext(c: ExecutionContinuation) case Return case StoreVar(v: String, s: Scope, value: BasilValue) case LoadVar(v: String) @@ -32,7 +31,6 @@ case object Trace { } } - object TraceGen extends Effects[Trace] { /** Values are discarded by ProductInterpreter so do not matter */ def evalBV(e: Expr) = State.pure(BitVecLiteral(0,0)) @@ -72,6 +70,8 @@ object TraceGen extends Effects[Trace] { def storeMem(vname: String, update: Map[BasilValue, BasilValue]) = for { s <- Trace.add(ExecEffect.StoreMem(vname, update)) } yield (()) + + def interpretOne = State.pure(()) } def tracingInterpreter = ProductInterpreter(NormalInterpreter, TraceGen) @@ -80,3 +80,13 @@ def interpretTrace(p: Program) : (InterpreterState, Trace) = { InterpFuns.interpretProg(tracingInterpreter)(p, (InterpreterState(), Trace(List()))) } + +case class RememberBreakpoints[T, I <: Effects[T]](val f: I, val breaks: Set[Command]) extends NopEffects[(T, List[(Command, T)])] { + override def interpretOne = { + State.modify ((thisState, sl) => State.evaluate(thisState, f.getNext) match { + case Run(s) if breaks.contains(s) => (thisState, (s,thisState)::sl) + case _ => (thisState, sl) + }) + } +} + diff --git a/src/main/scala/ir/eval/Interpreter.scala b/src/main/scala/ir/eval/Interpreter.scala index 12ed90010..7834daddd 100644 --- a/src/main/scala/ir/eval/Interpreter.scala +++ b/src/main/scala/ir/eval/Interpreter.scala @@ -89,6 +89,10 @@ case object BasilValue { * defined for the interpreter's concrete state T. */ trait Effects[T] { + + // perform an execution step + def interpretOne: State[T, Unit] + /* evaluation (may side-effect on error) */ def evalBV(e: Expr): State[T, BitVecLiteral] @@ -125,6 +129,25 @@ trait Effects[T] { } + +trait NopEffects[T] extends Effects[T] { + def interpretOne = State.pure(()) + def evalBV(e: Expr) = State.pure(BitVecLiteral(0,0)) + def evalInt(e: Expr) = State.pure(BigInt(0)) + def evalBool(e: Expr) = State.pure(false) + def loadVar(v: String) = State.pure(Scalar(FalseLiteral)) + def loadMem(v: String, addrs: List[BasilValue]) = State.pure(List()) + def evalAddrToProc(addr: Int) = State.pure(None) + def getNext = State.pure(Stopped()) + def setNext(c: ExecutionContinuation) = State.pure(()) + + def call(target: String, beginFrom: ExecutionContinuation, returnTo: ExecutionContinuation) = State.pure(()) + def doReturn() = State.pure(()) + + def storeVar(v: String, scope: Scope, value: BasilValue) = State.pure(()) + def storeMem(vname: String, update: Map[BasilValue, BasilValue]) = State.pure(()) +} + /** -------------------------------------------------------------------------------- * Definition of concrete state * -------------------------------------------------------------------------------- */ @@ -419,6 +442,21 @@ object NormalInterpreter extends Effects[InterpreterState] { s.copy(memoryState = s.memoryState.doStore(vname, update)) }) + def interpretOne: State[InterpreterState, Unit] = for { + next <- getNext + _ <- try { + next match { + case Run(c: Statement) => InterpFuns.interpretStatement(this)(c) + case Run(c: Jump) => InterpFuns.interpretJump(this)(c) + case Stopped() => State.pure (()) + case errorstop => State.pure (()) + } + } catch { + case InterpreterError(e) => setNext(e) + case e: java.lang.IllegalArgumentException => setNext(Errored(e.getStackTrace.take(5).mkString("\n"))) + } + } yield () + } diff --git a/src/main/scala/ir/eval/InterpreterProduct.scala b/src/main/scala/ir/eval/InterpreterProduct.scala index 53bc8f779..e908cdd84 100644 --- a/src/main/scala/ir/eval/InterpreterProduct.scala +++ b/src/main/scala/ir/eval/InterpreterProduct.scala @@ -15,7 +15,7 @@ import scala.collection.immutable import scala.util.control.Breaks.{break, breakable} /** - * Runs two interpreters independently, the returns the value from inner, and ignores before + * Runs two interpreters inner and before simultaneously, returning the value from inner, and ignoring before */ case class ProductInterpreter[L, T](val inner: Effects[L], val before: Effects[T]) extends Effects[(L, T)] { def doLeft[V](f: State[L, V]) : State[(L, T), V] = for { @@ -33,6 +33,11 @@ case class ProductInterpreter[L, T](val inner: Effects[L], val before: Effects[T } yield (f) + def interpretOne = for { + n <- doRight(before.interpretOne) + f <- doLeft(inner.interpretOne) + } yield () + def evalBV(e: Expr) = for { n <- doRight(before.evalBV(e)) f <- doLeft(inner.evalBV(e)) @@ -96,3 +101,85 @@ case class ProductInterpreter[L, T](val inner: Effects[L], val before: Effects[T } +case class LayerInterpreter[L, T](val inner: Effects[L], val before: Effects[(L, T)]) extends Effects[(L, T)] { + def doLeft[V](f: State[L, V]) : State[(L, T), V] = for { + f <- State[(L, T), V]((s: (L, T)) => { + val r = f.f(s._1) + ((r._1, s._2), r._2) + }) + } yield (f) + + def doRight[V](f: State[T, V]) : State[(L, T), V] = for { + f <- State[(L, T), V]((s: (L, T)) => { + val r = f.f(s._2) + ((s._1, r._1), r._2) + }) + } yield (f) + + def interpretOne = for { + n <- (before.interpretOne) + f <- doLeft(inner.interpretOne) + } yield () + + def evalBV(e: Expr) = for { + n <- (before.evalBV(e)) + f <- doLeft(inner.evalBV(e)) + } yield (f) + + def evalInt(e: Expr) = for { + n <- (before.evalBV(e)) + f <- doLeft(inner.evalInt(e)) + } yield (f) + + def evalBool(e: Expr) = for { + n <- (before.evalBool(e)) + f <- doLeft(inner.evalBool(e)) + } yield (f) + + def loadVar(v: String) = for { + n <- (before.loadVar(v)) + f <- doLeft(inner.loadVar(v)) + } yield (f) + + def loadMem(v: String, addrs: List[BasilValue]) = for { + n <- (before.loadMem(v, addrs)) + f <- doLeft(inner.loadMem(v, addrs)) + } yield (f) + + def evalAddrToProc(addr: Int) = for { + n <- (before.evalAddrToProc(addr: Int)) + f <- doLeft(inner.evalAddrToProc(addr)) + } yield(f) + + def getNext = for { + n <- (before.getNext) + f <- doLeft(inner.getNext) + } yield(f) + + /** state effects */ + def setNext(c: ExecutionContinuation) = for { + n <- (before.setNext(c)) + f <- doLeft(inner.setNext(c)) + } yield (f) + + def call(target: String, beginFrom: ExecutionContinuation, returnTo: ExecutionContinuation) = for { + n <- (before.call(target, beginFrom, returnTo)) + f <- doLeft(inner.call(target, beginFrom, returnTo)) + } yield (f) + + def doReturn() = for { + n <- (before.doReturn()) + f <- doLeft(inner.doReturn()) + } yield (f) + + def storeVar(v: String, scope: Scope, value: BasilValue) = for { + n <- (before.storeVar(v, scope, value)) + f <- doLeft(inner.storeVar(v, scope, value)) + } yield(f) + + def storeMem(vname: String, update: Map[BasilValue, BasilValue]) = for { + n <- (before.storeMem(vname,update)) + f <- doLeft(inner.storeMem(vname, update)) + } yield(f) +} + diff --git a/src/test/scala/ir/InterpreterTests.scala b/src/test/scala/ir/InterpreterTests.scala index 51c2e7218..6fec8bf78 100644 --- a/src/test/scala/ir/InterpreterTests.scala +++ b/src/test/scala/ir/InterpreterTests.scala @@ -430,4 +430,16 @@ class InterpreterTests extends AnyFunSuite with BeforeAndAfter { } + test("fib breakpoints") { + + // Logger.setLevel(LogLevel.ERROR) + val fib = fibonacciProg(8) + val watch = IRWalk.firstInProc((fib.procedures.find(_.name == "fib")).get) + val interp = LayerInterpreter(NormalInterpreter, RememberBreakpoints(NormalInterpreter, Set(watch))) + val res = InterpFuns.interpretProg(interp)(fib, (InterpreterState(), List[(Command, InterpreterState)]())) + println(res) + + + } + } From 897c5655743429ab13c6dfb9995aafca3c2a4b79 Mon Sep 17 00:00:00 2001 From: Alistair Michael Date: Thu, 29 Aug 2024 11:44:02 +1000 Subject: [PATCH 16/62] improve breakpoints --- src/main/scala/ir/eval/InterpretBasilIR.scala | 21 ++++-- .../scala/ir/eval/InterpretBasilTrace.scala | 62 +++++++++++++-- src/main/scala/ir/eval/Interpreter.scala | 23 ------ .../scala/ir/eval/InterpreterProduct.scala | 75 ++++--------------- src/main/scala/util/functional.scala | 6 ++ src/test/scala/ir/InterpreterTests.scala | 9 ++- 6 files changed, 96 insertions(+), 100 deletions(-) diff --git a/src/main/scala/ir/eval/InterpretBasilIR.scala b/src/main/scala/ir/eval/InterpretBasilIR.scala index 7eb6cf671..35b5000fb 100644 --- a/src/main/scala/ir/eval/InterpretBasilIR.scala +++ b/src/main/scala/ir/eval/InterpretBasilIR.scala @@ -57,6 +57,13 @@ case object Eval { /* Eval functions */ /*--------------------------------------------------------------------------------*/ + def evalExpr[S, T <: Effects[S]](f: T)(e: Expr): State[S, Expr] = { + val ldr = StVarLoader[S, T](f) + for { + res <- ir.eval.statePartialEvalExpr[S](ldr)(e) + } yield (res) + } + def evalBV[S, T <: Effects[S]](f: T)(e: Expr): State[S, BitVecLiteral] = { val ldr = StVarLoader[S, T](f) for { @@ -275,7 +282,7 @@ case object InterpFuns { throw InterpreterError(Errored(s"Some goto target missing guard $gt")) } for { - chosen: List[Assume] <- filterM((a: Assume) => f.evalBool(a.body), assumes) + chosen: List[Assume] <- filterM((a: Assume) => Eval.evalBool(f)(a.body), assumes) res <- chosen match { case Nil => f.setNext(Errored(s"No jump target satisfied $gt")) @@ -292,21 +299,21 @@ case object InterpFuns { s match { case assign: Assign => { for { - rhs <- f.evalBV(assign.rhs) + rhs <- Eval.evalBV(f)(assign.rhs) st <- f.storeVar(assign.lhs.name, assign.lhs.toBoogie.scope, Scalar(rhs)) n <- f.setNext(Run(s.successor)) } yield (st) } case assign: MemoryAssign => for { - index: BitVecLiteral <- f.evalBV(assign.index) - value: BitVecLiteral <- f.evalBV(assign.value) + index: BitVecLiteral <- Eval.evalBV(f)(assign.index) + value: BitVecLiteral <- Eval.evalBV(f)(assign.value) _ <- Eval.storeBV(f)(assign.mem.name, Scalar(index), value, assign.endian) n <- f.setNext(Run(s.successor)) } yield (n) case assert: Assert => for { - b <- f.evalBool(assert.body) + b <- Eval.evalBool(f)(assert.body) n <- (if (!b) then { f.setNext(FailedAssertion(assert)) @@ -316,7 +323,7 @@ case object InterpFuns { } yield (n) case assume: Assume => for { - b <- f.evalBool(assume.body) + b <- Eval.evalBool(f)(assume.body) n <- (if (!b) { f.setNext(Errored(s"Assumption not satisfied: $assume")) @@ -339,7 +346,7 @@ case object InterpFuns { f.doReturn() } else { for { - addr <- f.evalBV(ic.target) + addr <- Eval.evalBV(f)(ic.target) fp <- f.evalAddrToProc(addr.value.toInt) _ <- fp match { case Some(fp) => f.call(fp.name, fp.call, Run(ic.successor)) diff --git a/src/main/scala/ir/eval/InterpretBasilTrace.scala b/src/main/scala/ir/eval/InterpretBasilTrace.scala index 4ae1af9ed..fba0adb11 100644 --- a/src/main/scala/ir/eval/InterpretBasilTrace.scala +++ b/src/main/scala/ir/eval/InterpretBasilTrace.scala @@ -27,7 +27,7 @@ case class Trace(val t: List[ExecEffect]) case object Trace { def add(e: ExecEffect) : State[Trace, Unit] = { - modify ((t: Trace) => Trace(t.t.appended(e))) + State.modify ((t: Trace) => Trace(t.t.appended(e))) } } @@ -80,13 +80,61 @@ def interpretTrace(p: Program) : (InterpreterState, Trace) = { InterpFuns.interpretProg(tracingInterpreter)(p, (InterpreterState(), Trace(List()))) } +enum BreakPointLoc: + case CMD(c: Command) + case CMDCond(c: Command, condition: Expr) -case class RememberBreakpoints[T, I <: Effects[T]](val f: I, val breaks: Set[Command]) extends NopEffects[(T, List[(Command, T)])] { - override def interpretOne = { - State.modify ((thisState, sl) => State.evaluate(thisState, f.getNext) match { - case Run(s) if breaks.contains(s) => (thisState, (s,thisState)::sl) - case _ => (thisState, sl) - }) +case class BreakPointAction(saveState: Boolean = true, stop: Boolean = false, evalExprs: List[Expr] = List(), log: Boolean = false) + +case class BreakPoint(name: String = "", location: BreakPointLoc, action: BreakPointAction) + +case class RememberBreakpoints[T, I <: Effects[T]](val f: I, val breaks: List[BreakPoint]) extends NopEffects[(T, List[(BreakPoint, Option[T], List[(Expr, Expr)])])] { + + + def findBreaks[R](c: Command) : State[(T,R), List[BreakPoint]] = { + State.filterM[BreakPoint, (T,R)](b => b.location match { + case BreakPointLoc.CMD(bc) if (bc == c) => State.pure(true) + case BreakPointLoc.CMDCond(bc, e) if bc == c => doLeft(Eval.evalBool(f)(e)) + case _ => State.pure(false) + }, breaks) } + + override def interpretOne : State[(T, List[(BreakPoint, Option[T], List[(Expr, Expr)])]), Unit] = for { + v : ExecutionContinuation <- doLeft(f.getNext) + n <- v match { + case Run(s) => for { + breaks : List[BreakPoint] <- findBreaks(s) + res <- State.sequence[(T, List[(BreakPoint, Option[T], List[(Expr, Expr)])]), Unit](State.pure(()), + breaks.map((breakpoint: BreakPoint) => (breakpoint match { + case breakpoint @ BreakPoint(name, stopcond, action) => (for { + saved <- doLeft(if action.saveState then State.getS[T].map(s => Some(s)) else State.pure(None)) + evals <- (State.mapM((e:Expr) => for { + ev <- doLeft(Eval.evalExpr(f)(e)) + } yield (e, ev) + , action.evalExprs)) + _ <- if action.stop then doLeft(f.setNext(Errored(s"Stopped at breakpoint ${name}"))) else doLeft(State.pure(())) + _ <- State.pure({ + if (action.log) { + val bpn = breakpoint.name + val bpcond = breakpoint.location match { + case BreakPointLoc.CMD(c) => c.toString + case BreakPointLoc.CMDCond(c, e) => s"$c when $e" + } + val saving = if action.saveState then " stashing state, " else "" + val stopping = if action.stop then " stopping. " else "" + val evalstr = evals.map(e => s"\n eval(${e._1}) = ${e._2}").mkString("") + Logger.warn(s"Breakpoint $bpn@$bpcond.$saving$stopping$evalstr") + } + }) + _ <- State.modify ((istate:(T, List[(BreakPoint, Option[T], List[(Expr, Expr)])])) => + (istate._1, ((breakpoint, saved, evals)::istate._2))) + } yield () + ) + }))) + } yield () + case _ => State.pure(()) + } + } yield () + } diff --git a/src/main/scala/ir/eval/Interpreter.scala b/src/main/scala/ir/eval/Interpreter.scala index 7834daddd..52d74d3cc 100644 --- a/src/main/scala/ir/eval/Interpreter.scala +++ b/src/main/scala/ir/eval/Interpreter.scala @@ -93,14 +93,6 @@ trait Effects[T] { // perform an execution step def interpretOne: State[T, Unit] - /* evaluation (may side-effect on error) */ - - def evalBV(e: Expr): State[T, BitVecLiteral] - - def evalInt(e: Expr): State[T, BigInt] - - def evalBool(e: Expr): State[T, Boolean] - def loadVar(v: String): State[T, BasilValue] def loadMem(v: String, addrs: List[BasilValue]): State[T, List[BasilValue]] @@ -132,9 +124,6 @@ trait Effects[T] { trait NopEffects[T] extends Effects[T] { def interpretOne = State.pure(()) - def evalBV(e: Expr) = State.pure(BitVecLiteral(0,0)) - def evalInt(e: Expr) = State.pure(BigInt(0)) - def evalBool(e: Expr) = State.pure(false) def loadVar(v: String) = State.pure(Scalar(FalseLiteral)) def loadMem(v: String, addrs: List[BasilValue]) = State.pure(List()) def evalAddrToProc(addr: Int) = State.pure(None) @@ -329,18 +318,6 @@ case class InterpreterState( */ object NormalInterpreter extends Effects[InterpreterState] { - /* eval */ - def evalBV(e: Expr) = { - Eval.evalBV(this)(e) - } - - def evalInt(e: Expr) = { - Eval.evalInt(this)(e) - } - - def evalBool(e: Expr) = { - Eval.evalBool(this)(e) - } def loadVar(v: String) = { State.get((s: InterpreterState) => { diff --git a/src/main/scala/ir/eval/InterpreterProduct.scala b/src/main/scala/ir/eval/InterpreterProduct.scala index e908cdd84..75e9a50fa 100644 --- a/src/main/scala/ir/eval/InterpreterProduct.scala +++ b/src/main/scala/ir/eval/InterpreterProduct.scala @@ -14,45 +14,30 @@ import scala.collection.mutable import scala.collection.immutable import scala.util.control.Breaks.{break, breakable} -/** - * Runs two interpreters inner and before simultaneously, returning the value from inner, and ignoring before - */ -case class ProductInterpreter[L, T](val inner: Effects[L], val before: Effects[T]) extends Effects[(L, T)] { - def doLeft[V](f: State[L, V]) : State[(L, T), V] = for { - f <- State[(L, T), V]((s: (L, T)) => { - val r = f.f(s._1) - ((r._1, s._2), r._2) - }) - } yield (f) - def doRight[V](f: State[T, V]) : State[(L, T), V] = for { - f <- State[(L, T), V]((s: (L, T)) => { - val r = f.f(s._2) - ((s._1, r._1), r._2) - }) - } yield (f) +def doLeft[L, T, V](f: State[L, V]) : State[(L, T), V] = for { + f <- State[(L, T), V]((s: (L, T)) => { + val r = f.f(s._1) + ((r._1, s._2), r._2) + }) +} yield (f) +def doRight[L, T, V](f: State[T, V]) : State[(L, T), V] = for { + f <- State[(L, T), V]((s: (L, T)) => { + val r = f.f(s._2) + ((s._1, r._1), r._2) + }) +} yield (f) +/** + * Runs two interpreters "inner" and "before" simultaneously, returning the value from inner, and ignoring before + */ +case class ProductInterpreter[L, T](val inner: Effects[L], val before: Effects[T]) extends Effects[(L, T)] { def interpretOne = for { n <- doRight(before.interpretOne) f <- doLeft(inner.interpretOne) } yield () - def evalBV(e: Expr) = for { - n <- doRight(before.evalBV(e)) - f <- doLeft(inner.evalBV(e)) - } yield (f) - - def evalInt(e: Expr) = for { - n <- doRight(before.evalBV(e)) - f <- doLeft(inner.evalInt(e)) - } yield (f) - - def evalBool(e: Expr) = for { - n <- doRight(before.evalBool(e)) - f <- doLeft(inner.evalBool(e)) - } yield (f) - def loadVar(v: String) = for { n <- doRight(before.loadVar(v)) f <- doLeft(inner.loadVar(v)) @@ -102,40 +87,12 @@ case class ProductInterpreter[L, T](val inner: Effects[L], val before: Effects[T case class LayerInterpreter[L, T](val inner: Effects[L], val before: Effects[(L, T)]) extends Effects[(L, T)] { - def doLeft[V](f: State[L, V]) : State[(L, T), V] = for { - f <- State[(L, T), V]((s: (L, T)) => { - val r = f.f(s._1) - ((r._1, s._2), r._2) - }) - } yield (f) - - def doRight[V](f: State[T, V]) : State[(L, T), V] = for { - f <- State[(L, T), V]((s: (L, T)) => { - val r = f.f(s._2) - ((s._1, r._1), r._2) - }) - } yield (f) def interpretOne = for { n <- (before.interpretOne) f <- doLeft(inner.interpretOne) } yield () - def evalBV(e: Expr) = for { - n <- (before.evalBV(e)) - f <- doLeft(inner.evalBV(e)) - } yield (f) - - def evalInt(e: Expr) = for { - n <- (before.evalBV(e)) - f <- doLeft(inner.evalInt(e)) - } yield (f) - - def evalBool(e: Expr) = for { - n <- (before.evalBool(e)) - f <- doLeft(inner.evalBool(e)) - } yield (f) - def loadVar(v: String) = for { n <- (before.loadVar(v)) f <- doLeft(inner.loadVar(v)) diff --git a/src/main/scala/util/functional.scala b/src/main/scala/util/functional.scala index 410fa5f71..e666bee36 100644 --- a/src/main/scala/util/functional.scala +++ b/src/main/scala/util/functional.scala @@ -48,6 +48,12 @@ object State { def filterM[A, S](m : (A => State[S, Boolean]), xs: Iterable[A]): State[S, List[A]] = { xs.foldRight(pure(List[A]()))((b,acc) => acc.flatMap(c => m(b).map(v => if v then b::c else c))) } + + def mapM[A, B, S](m : (A => State[S, B]), xs: Iterable[A]): State[S, List[B]] = { + xs.foldRight(pure(List[B]()))((b,acc) => acc.flatMap(c => m(b).map(v => v::c))) + } + + } def protect[T](x: () => T, fnly: PartialFunction[Exception, T]): T = { diff --git a/src/test/scala/ir/InterpreterTests.scala b/src/test/scala/ir/InterpreterTests.scala index 6fec8bf78..d92161460 100644 --- a/src/test/scala/ir/InterpreterTests.scala +++ b/src/test/scala/ir/InterpreterTests.scala @@ -34,7 +34,7 @@ def load(s: InterpreterState, global: SpecGlobal) : Option[BitVecLiteral] = { // m.evalBV("mem", BitVecLiteral(64, global.address), Endian.LittleEndian, global.size) // i.getMemory(global.address.toInt, global.size, Endian.LittleEndian, i.mems) try { - Some(f.evalBV(MemoryLoad(SharedMemory("mem", 64, 8), BitVecLiteral(global.address, 64), Endian.LittleEndian, global.size)).f(s)._2) + Some(Eval.evalBV(f)(MemoryLoad(SharedMemory("mem", 64, 8), BitVecLiteral(global.address, 64), Endian.LittleEndian, global.size)).f(s)._2) } catch { case e : InterpreterError => None } @@ -432,11 +432,12 @@ class InterpreterTests extends AnyFunSuite with BeforeAndAfter { test("fib breakpoints") { - // Logger.setLevel(LogLevel.ERROR) + Logger.setLevel(LogLevel.WARN) val fib = fibonacciProg(8) val watch = IRWalk.firstInProc((fib.procedures.find(_.name == "fib")).get) - val interp = LayerInterpreter(NormalInterpreter, RememberBreakpoints(NormalInterpreter, Set(watch))) - val res = InterpFuns.interpretProg(interp)(fib, (InterpreterState(), List[(Command, InterpreterState)]())) + val bp = BreakPoint("Fibentry", BreakPointLoc.CMDCond(watch, BinaryExpr(BVEQ, BitVecLiteral(5, 64), Register("R0", 64))), BreakPointAction(true, true, List(Register("R0", 64)), true)) + val interp = LayerInterpreter(NormalInterpreter, RememberBreakpoints(NormalInterpreter, List(bp))) + val res = InterpFuns.interpretProg(interp)(fib, (InterpreterState(), List())) println(res) From cc97052702274fb30767559958e1a07d703b4710 Mon Sep 17 00:00:00 2001 From: Alistair Michael Date: Thu, 29 Aug 2024 12:02:41 +1000 Subject: [PATCH 17/62] reorg --- ...Trace.scala => InterpretBreakpoints.scala} | 72 ++-------------- src/main/scala/ir/eval/InterpretTrace.scala | 83 +++++++++++++++++++ src/test/scala/ir/InterpreterTests.scala | 5 +- 3 files changed, 92 insertions(+), 68 deletions(-) rename src/main/scala/ir/eval/{InterpretBasilTrace.scala => InterpretBreakpoints.scala} (60%) create mode 100644 src/main/scala/ir/eval/InterpretTrace.scala diff --git a/src/main/scala/ir/eval/InterpretBasilTrace.scala b/src/main/scala/ir/eval/InterpretBreakpoints.scala similarity index 60% rename from src/main/scala/ir/eval/InterpretBasilTrace.scala rename to src/main/scala/ir/eval/InterpretBreakpoints.scala index fba0adb11..a21f2006f 100644 --- a/src/main/scala/ir/eval/InterpretBasilTrace.scala +++ b/src/main/scala/ir/eval/InterpretBreakpoints.scala @@ -14,72 +14,6 @@ import scala.collection.immutable import scala.util.control.Breaks.{break, breakable} -enum ExecEffect: - case Call(target: String, begin: ExecutionContinuation, returnTo: ExecutionContinuation) - case Return - case StoreVar(v: String, s: Scope, value: BasilValue) - case LoadVar(v: String) - case StoreMem(vname: String, update: Map[BasilValue, BasilValue]) - case LoadMem(vname: String, addrs: List[BasilValue]) - case FindProc(addr: Int) - -case class Trace(val t: List[ExecEffect]) - -case object Trace { - def add(e: ExecEffect) : State[Trace, Unit] = { - State.modify ((t: Trace) => Trace(t.t.appended(e))) - } -} - -object TraceGen extends Effects[Trace] { - /** Values are discarded by ProductInterpreter so do not matter */ - def evalBV(e: Expr) = State.pure(BitVecLiteral(0,0)) - - def evalInt(e: Expr) = State.pure(BigInt(0)) - - def evalBool(e: Expr) = State.pure(false) - - def loadVar(v: String) = for { - s <- Trace.add(ExecEffect.LoadVar(v)) - } yield (Scalar(FalseLiteral)) - - def loadMem(v: String, addrs: List[BasilValue]) = for { - s <- Trace.add(ExecEffect.LoadMem(v, addrs)) - } yield (List()) - - def evalAddrToProc(addr: Int) = for { - s <- Trace.add(ExecEffect.FindProc(addr)) - } yield (None) - - def getNext = State.pure(Stopped()) - - def setNext(c: ExecutionContinuation) = State.pure(()) - - def call(target: String, beginFrom: ExecutionContinuation, returnTo: ExecutionContinuation) = for { - s <- Trace.add(ExecEffect.Call(target, beginFrom, returnTo)) - } yield (()) - - def doReturn() = for { - s <- Trace.add(ExecEffect.Return) - } yield (()) - - def storeVar(v: String, scope: Scope, value: BasilValue) = for { - s <- Trace.add(ExecEffect.StoreVar(v, scope, value)) - } yield (()) - - def storeMem(vname: String, update: Map[BasilValue, BasilValue]) = for { - s <- Trace.add(ExecEffect.StoreMem(vname, update)) - } yield (()) - - def interpretOne = State.pure(()) -} - -def tracingInterpreter = ProductInterpreter(NormalInterpreter, TraceGen) - -def interpretTrace(p: Program) : (InterpreterState, Trace) = { - InterpFuns.interpretProg(tracingInterpreter)(p, (InterpreterState(), Trace(List()))) -} - enum BreakPointLoc: case CMD(c: Command) case CMDCond(c: Command, condition: Expr) @@ -138,3 +72,9 @@ case class RememberBreakpoints[T, I <: Effects[T]](val f: I, val breaks: List[Br } + +def interpretWithBreakPoints[I](p: Program, breakpoints: List[BreakPoint], innerInterpreter: Effects[I], innerInitialState: I) : (I, List[(BreakPoint, Option[I], List[(Expr, Expr)])]) = { + val interp = LayerInterpreter(innerInterpreter, RememberBreakpoints(innerInterpreter, breakpoints)) + val res = InterpFuns.interpretProg(interp)(p, (innerInitialState, List())) + res +} diff --git a/src/main/scala/ir/eval/InterpretTrace.scala b/src/main/scala/ir/eval/InterpretTrace.scala new file mode 100644 index 000000000..455e4d43d --- /dev/null +++ b/src/main/scala/ir/eval/InterpretTrace.scala @@ -0,0 +1,83 @@ +package ir.eval +import ir._ +import ir.eval.BitVectorEval.* +import ir.* +import util.Logger +import util.functional.* +import util.functional.State.* +import boogie.Scope +import scala.collection.WithFilter + +import scala.annotation.tailrec +import scala.collection.mutable +import scala.collection.immutable +import scala.util.control.Breaks.{break, breakable} + + +enum ExecEffect: + case Call(target: String, begin: ExecutionContinuation, returnTo: ExecutionContinuation) + case Return + case StoreVar(v: String, s: Scope, value: BasilValue) + case LoadVar(v: String) + case StoreMem(vname: String, update: Map[BasilValue, BasilValue]) + case LoadMem(vname: String, addrs: List[BasilValue]) + case FindProc(addr: Int) + +case class Trace(val t: List[ExecEffect]) + +case object Trace { + def add(e: ExecEffect) : State[Trace, Unit] = { + State.modify ((t: Trace) => Trace(t.t.appended(e))) + } +} + +object TraceGen extends Effects[Trace] { + /** Values are discarded by ProductInterpreter so do not matter */ + def evalBV(e: Expr) = State.pure(BitVecLiteral(0,0)) + + def evalInt(e: Expr) = State.pure(BigInt(0)) + + def evalBool(e: Expr) = State.pure(false) + + def loadVar(v: String) = for { + s <- Trace.add(ExecEffect.LoadVar(v)) + } yield (Scalar(FalseLiteral)) + + def loadMem(v: String, addrs: List[BasilValue]) = for { + s <- Trace.add(ExecEffect.LoadMem(v, addrs)) + } yield (List()) + + def evalAddrToProc(addr: Int) = for { + s <- Trace.add(ExecEffect.FindProc(addr)) + } yield (None) + + def getNext = State.pure(Stopped()) + + def setNext(c: ExecutionContinuation) = State.pure(()) + + def call(target: String, beginFrom: ExecutionContinuation, returnTo: ExecutionContinuation) = for { + s <- Trace.add(ExecEffect.Call(target, beginFrom, returnTo)) + } yield (()) + + def doReturn() = for { + s <- Trace.add(ExecEffect.Return) + } yield (()) + + def storeVar(v: String, scope: Scope, value: BasilValue) = for { + s <- Trace.add(ExecEffect.StoreVar(v, scope, value)) + } yield (()) + + def storeMem(vname: String, update: Map[BasilValue, BasilValue]) = for { + s <- Trace.add(ExecEffect.StoreMem(vname, update)) + } yield (()) + + def interpretOne = State.pure(()) +} + +def tracingInterpreter = ProductInterpreter(NormalInterpreter, TraceGen) + +def interpretTrace(p: Program) : (InterpreterState, Trace) = { + InterpFuns.interpretProg(tracingInterpreter)(p, (InterpreterState(), Trace(List()))) +} + + diff --git a/src/test/scala/ir/InterpreterTests.scala b/src/test/scala/ir/InterpreterTests.scala index d92161460..a085f6ae8 100644 --- a/src/test/scala/ir/InterpreterTests.scala +++ b/src/test/scala/ir/InterpreterTests.scala @@ -436,8 +436,9 @@ class InterpreterTests extends AnyFunSuite with BeforeAndAfter { val fib = fibonacciProg(8) val watch = IRWalk.firstInProc((fib.procedures.find(_.name == "fib")).get) val bp = BreakPoint("Fibentry", BreakPointLoc.CMDCond(watch, BinaryExpr(BVEQ, BitVecLiteral(5, 64), Register("R0", 64))), BreakPointAction(true, true, List(Register("R0", 64)), true)) - val interp = LayerInterpreter(NormalInterpreter, RememberBreakpoints(NormalInterpreter, List(bp))) - val res = InterpFuns.interpretProg(interp)(fib, (InterpreterState(), List())) + // val interp = LayerInterpreter(NormalInterpreter, RememberBreakpoints(NormalInterpreter, List(bp))) + // val res = InterpFuns.interpretProg(interp)(fib, (InterpreterState(), List())) + val res = interpretWithBreakPoints(fib, List(bp), NormalInterpreter, InterpreterState()) println(res) From 695b2d236ad1001221cfd047ffa9a11ec252a83d Mon Sep 17 00:00:00 2001 From: Alistair Michael Date: Thu, 29 Aug 2024 14:33:53 +1000 Subject: [PATCH 18/62] refactor with statemonad[s, either[v]] --- src/main/scala/ir/eval/ExprEval.scala | 16 ++--- src/main/scala/ir/eval/InterpretBasilIR.scala | 62 +++++++++---------- .../scala/ir/eval/InterpretBreakpoints.scala | 14 ++--- src/main/scala/ir/eval/InterpretTrace.scala | 6 +- src/main/scala/ir/eval/Interpreter.scala | 36 ++++++----- .../scala/ir/eval/InterpreterProduct.scala | 12 ++-- src/main/scala/util/functional.scala | 52 ++++++++-------- src/test/scala/ir/InterpreterTests.scala | 4 +- src/test/scala/util/StateMonad.scala | 2 +- 9 files changed, 104 insertions(+), 100 deletions(-) diff --git a/src/main/scala/ir/eval/ExprEval.scala b/src/main/scala/ir/eval/ExprEval.scala index 85d71d7cb..999155ae1 100644 --- a/src/main/scala/ir/eval/ExprEval.scala +++ b/src/main/scala/ir/eval/ExprEval.scala @@ -163,15 +163,15 @@ def evalExpr(exp: Expr, variableAssignment: Variable => Option[Literal], memory: /** * typeclass defining variable and memory laoding from state S */ -trait Loader[S] { - def getVariable(v: Variable) : State[S, Option[Literal]] - def loadMemory(m: Memory, addr: Expr, endian: Endian, size: Int) : State[S, Option[Literal]] = { +trait Loader[S, E] { + def getVariable(v: Variable) : State[S, Option[Literal], E] + def loadMemory(m: Memory, addr: Expr, endian: Endian, size: Int) : State[S, Option[Literal], E] = { State.pure(None) } } -def statePartialEvalExpr[S](l: Loader[S])(exp: Expr): State[S, Expr] = { +def statePartialEvalExpr[S, E](l: Loader[S, E])(exp: Expr): State[S, Expr, E] = { val eval = statePartialEvalExpr(l) exp match { case f: UninterpretedFunction => State.pure(f) @@ -254,14 +254,14 @@ def statePartialEvalExpr[S](l: Loader[S])(exp: Expr): State[S, Expr] = { } -class StatelessLoader(getVar: Variable => Option[Literal], loadMem: (Memory, Expr, Endian, Int) => Option[Literal] = ((a,b,c,d) => None)) extends Loader[Unit] { - def getVariable(v: Variable) : State[Unit, Option[Literal]] = State.pure(getVar(v)) - override def loadMemory(m: Memory, addr: Expr, endian: Endian, size: Int) : State[Unit, Option[Literal]] = State.pure(loadMem(m, addr, endian, size)) +class StatelessLoader[E](getVar: Variable => Option[Literal], loadMem: (Memory, Expr, Endian, Int) => Option[Literal] = ((a,b,c,d) => None)) extends Loader[Unit, E] { + def getVariable(v: Variable) : State[Unit, Option[Literal], E] = State.pure(getVar(v)) + override def loadMemory(m: Memory, addr: Expr, endian: Endian, size: Int) : State[Unit, Option[Literal], E] = State.pure(loadMem(m, addr, endian, size)) } def partialEvalExpr(exp: Expr, variableAssignment: Variable => Option[Literal], memory: (Memory, Expr, Endian, Int) => Option[Literal] = ((a,b,c,d) => None)): Expr = { val l = StatelessLoader(variableAssignment, memory) - statePartialEvalExpr(l)(exp).f(())._2 + State.evaluate((), statePartialEvalExpr(l)(exp)) } diff --git a/src/main/scala/ir/eval/InterpretBasilIR.scala b/src/main/scala/ir/eval/InterpretBasilIR.scala index 35b5000fb..9f47b7c1b 100644 --- a/src/main/scala/ir/eval/InterpretBasilIR.scala +++ b/src/main/scala/ir/eval/InterpretBasilIR.scala @@ -16,9 +16,9 @@ import scala.util.control.Breaks.{break, breakable} /** Abstraction for memload and variable lookup used by the expression evaluator. */ -case class StVarLoader[S, F <: Effects[S]](f: F) extends Loader[S] { +case class StVarLoader[S, E, F <: Effects[S, E]](f: F) extends Loader[S, E] { - def getVariable(v: Variable): State[S, Option[Literal]] = { + def getVariable(v: Variable): State[S, Option[Literal], E] = { for { v <- f.loadVar(v.name) } yield ((v match { @@ -27,7 +27,7 @@ case class StVarLoader[S, F <: Effects[S]](f: F) extends Loader[S] { })) } - override def loadMemory(m: Memory, addr: Expr, endian: Endian, size: Int): State[S, Option[Literal]] = { + override def loadMemory(m: Memory, addr: Expr, endian: Endian, size: Int): State[S, Option[Literal], E] = { for { r <- addr match { case l: Literal if size == 1 => @@ -57,37 +57,37 @@ case object Eval { /* Eval functions */ /*--------------------------------------------------------------------------------*/ - def evalExpr[S, T <: Effects[S]](f: T)(e: Expr): State[S, Expr] = { - val ldr = StVarLoader[S, T](f) + def evalExpr[S, E, T <: Effects[S, E]](f: T)(e: Expr): State[S, Expr, E] = { + val ldr = StVarLoader[S, E, T](f) for { - res <- ir.eval.statePartialEvalExpr[S](ldr)(e) + res <- ir.eval.statePartialEvalExpr[S, E](ldr)(e) } yield (res) } - def evalBV[S, T <: Effects[S]](f: T)(e: Expr): State[S, BitVecLiteral] = { - val ldr = StVarLoader[S, T](f) + def evalBV[S, E, T <: Effects[S, E]](f: T)(e: Expr): State[S, BitVecLiteral, E] = { + val ldr = StVarLoader[S, E, T](f) for { - res <- ir.eval.statePartialEvalExpr[S](ldr)(e) + res <- ir.eval.statePartialEvalExpr[S, E](ldr)(e) } yield (res match { case l: BitVecLiteral => l case _ => throw InterpreterError(Errored(s"Eval BV residual $e")) }) } - def evalInt[S, T <: Effects[S]](f: T)(e: Expr): State[S, BigInt] = { - val ldr = StVarLoader[S, T](f) + def evalInt[S, E, T <: Effects[S, E]](f: T)(e: Expr): State[S, BigInt, E] = { + val ldr = StVarLoader[S, E, T](f) for { - res <- ir.eval.statePartialEvalExpr[S](ldr)(e) + res <- ir.eval.statePartialEvalExpr[S, E](ldr)(e) } yield (res match { case l: IntLiteral => l.value case _ => throw InterpreterError(Errored(s"Eval BV residual $e")) }) } - def evalBool[S, T <: Effects[S]](f: T)(e: Expr): State[S, Boolean] = { - val ldr = StVarLoader[S, T](f) + def evalBool[S, E, T <: Effects[S, E]](f: T)(e: Expr): State[S, Boolean, E] = { + val ldr = StVarLoader[S, E, T](f) for { - res <- ir.eval.statePartialEvalExpr[S](ldr)(e) + res <- ir.eval.statePartialEvalExpr[S, E](ldr)(e) } yield (res match { case l: BoolLit => l == TrueLiteral case _ => throw InterpreterError(Errored(s"Eval BV residual $e")) @@ -98,9 +98,9 @@ case object Eval { /* Load functions */ /*--------------------------------------------------------------------------------*/ - def load[S, T <: Effects[S]]( + def load[S, E, T <: Effects[S, E]]( f: T - )(vname: String, addr: Scalar, endian: Endian, count: Int): State[S, List[BasilValue]] = { + )(vname: String, addr: Scalar, endian: Endian, count: Int): State[S, List[BasilValue], E] = { if (count == 0) { throw InterpreterError(Errored(s"Attempted fractional load")) } @@ -115,9 +115,9 @@ case object Eval { } /** Load and concat bitvectors */ - def loadBV[S, T <: Effects[S]]( + def loadBV[S, E, T <: Effects[S, E]]( f: T - )(vname: String, addr: Scalar, endian: Endian, size: Int): State[S, BitVecLiteral] = for { + )(vname: String, addr: Scalar, endian: Endian, size: Int): State[S, BitVecLiteral, E] = for { mem <- f.loadVar(vname) (valsize, mapv) = mem match { case mapv @ MapValue(_, MapType(_, BitVecType(sz))) => (sz, mapv) @@ -137,7 +137,7 @@ case object Eval { } } yield (bvs.foldLeft(BitVecLiteral(0, 0))((acc, r) => eval.evalBVBinExpr(BVCONCAT, acc, r))) - def loadSingle[S, T <: Effects[S]](f: T)(vname: String, addr: Scalar): State[S, BasilValue] = { + def loadSingle[S, E, T <: Effects[S, E]](f: T)(vname: String, addr: Scalar): State[S, BasilValue, E] = { for { m <- load(f)(vname, addr, Endian.LittleEndian, 1) } yield (m.head) @@ -148,12 +148,12 @@ case object Eval { /*--------------------------------------------------------------------------------*/ /* Expand addr for number of values to store */ - def store[S, T <: Effects[S]](f: T)( + def store[S, E, T <: Effects[S, E]](f: T)( vname: String, addr: BasilValue, values: List[BasilValue], endian: Endian - ): State[S, Unit] = for { + ): State[S, Unit, E] = for { mem <- f.loadVar(vname) (mapval, keytype, valtype) = mem match { case m @ MapValue(_, MapType(kt, vt)) if kt == addr.irType && values.forall(v => v.irType == vt) => (m, kt, vt) @@ -168,12 +168,12 @@ case object Eval { } yield (x) /** Extract bitvec to bytes and store bytes */ - def storeBV[S, T <: Effects[S]](f: T)( + def storeBV[S, E, T <: Effects[S, E]](f: T)( vname: String, addr: BasilValue, value: BitVecLiteral, endian: Endian - ): State[S, Unit] = for { + ): State[S, Unit, E] = for { mem <- f.loadVar(vname) (mapval, vsize) = mem match { case m @ MapValue(_, MapType(kt, BitVecType(size))) if kt == addr.irType => (m, size) @@ -201,7 +201,7 @@ case object Eval { s <- f.storeMem(vname, keys.zip(vs).toMap) } yield (s) - def storeSingle[S, T <: Effects[S]](f: T)(vname: String, addr: BasilValue, value: BasilValue): State[S, Unit] = { + def storeSingle[S, E, T <: Effects[S, E]](f: T)(vname: String, addr: BasilValue, value: BasilValue): State[S, Unit, E] = { f.storeMem(vname, Map((addr -> value))) } } @@ -213,7 +213,7 @@ case object InterpFuns { * Each function takes as parameter an implementation of Effects[S] */ - def initialState[S, T <: Effects[S]](s: T): State[S, Unit] = { + def initialState[S, E, T <: Effects[S, E]](s: T): State[S, Unit, E] = { val SP: BitVecLiteral = BitVecLiteral(4096 - 16, 64) val FP: BitVecLiteral = BitVecLiteral(4096 - 16, 64) val LR: BitVecLiteral = BitVecLiteral(BigInt("FF", 16), 64) @@ -228,7 +228,7 @@ case object InterpFuns { } yield (l) } - def initialiseProgram[S, T <: Effects[S]](f: T)(p: Program): State[S, Unit] = { + def initialiseProgram[S, E, T <: Effects[S, E]](f: T)(p: Program): State[S, Unit, E] = { def initMemory(mem: String, mems: Iterable[MemorySection]) = { for { m <- State.sequence( @@ -269,7 +269,7 @@ case object InterpFuns { } yield (r) } - def interpretJump[S, T <: Effects[S]](f: T)(j: Jump): State[S, Unit] = { + def interpretJump[S, E, T <: Effects[S, E]](f: T)(j: Jump): State[S, Unit, E] = { j match { case gt: GoTo if gt.targets.size == 1 => { f.setNext(Run(IRWalk.firstInBlock(gt.targets.head))) @@ -295,7 +295,7 @@ case object InterpFuns { } } - def interpretStatement[S, T <: Effects[S]](f: T)(s: Statement): State[S, Unit] = { + def interpretStatement[S, E, T <: Effects[S, E]](f: T)(s: Statement): State[S, Unit, E] = { s match { case assign: Assign => { for { @@ -359,7 +359,7 @@ case object InterpFuns { } } - def interpret[S, T <: Effects[S]](f: T, m: S): S = { + def interpret[S, E, T <: Effects[S, E]](f: T, m: S): S = { val next = State.evaluate(m, f.getNext) Logger.debug(s"eval $next") next match { @@ -369,7 +369,7 @@ case object InterpFuns { } } - def interpretProg[S, T <: Effects[S]](f: T)(p: Program, is: S): S = { + def interpretProg[S, E, T <: Effects[S, E]](f: T)(p: Program, is: S): S = { val begin = State.execute(is, initialiseProgram(f)(p)) // State.execute[S,Unit](is, ) interpret(f, begin) diff --git a/src/main/scala/ir/eval/InterpretBreakpoints.scala b/src/main/scala/ir/eval/InterpretBreakpoints.scala index a21f2006f..29ad8d57b 100644 --- a/src/main/scala/ir/eval/InterpretBreakpoints.scala +++ b/src/main/scala/ir/eval/InterpretBreakpoints.scala @@ -22,26 +22,26 @@ case class BreakPointAction(saveState: Boolean = true, stop: Boolean = false, ev case class BreakPoint(name: String = "", location: BreakPointLoc, action: BreakPointAction) -case class RememberBreakpoints[T, I <: Effects[T]](val f: I, val breaks: List[BreakPoint]) extends NopEffects[(T, List[(BreakPoint, Option[T], List[(Expr, Expr)])])] { +case class RememberBreakpoints[T, E, I <: Effects[T, E]](val f: I, val breaks: List[BreakPoint]) extends NopEffects[(T, List[(BreakPoint, Option[T], List[(Expr, Expr)])]), E] { - def findBreaks[R](c: Command) : State[(T,R), List[BreakPoint]] = { - State.filterM[BreakPoint, (T,R)](b => b.location match { + def findBreaks[R](c: Command) : State[(T,R), List[BreakPoint], E] = { + State.filterM(b => b.location match { case BreakPointLoc.CMD(bc) if (bc == c) => State.pure(true) case BreakPointLoc.CMDCond(bc, e) if bc == c => doLeft(Eval.evalBool(f)(e)) case _ => State.pure(false) }, breaks) } - override def interpretOne : State[(T, List[(BreakPoint, Option[T], List[(Expr, Expr)])]), Unit] = for { + override def interpretOne : State[(T, List[(BreakPoint, Option[T], List[(Expr, Expr)])]), Unit, E] = for { v : ExecutionContinuation <- doLeft(f.getNext) n <- v match { case Run(s) => for { breaks : List[BreakPoint] <- findBreaks(s) - res <- State.sequence[(T, List[(BreakPoint, Option[T], List[(Expr, Expr)])]), Unit](State.pure(()), + res <- State.sequence[(T, List[(BreakPoint, Option[T], List[(Expr, Expr)])]), Unit, E](State.pure(()), breaks.map((breakpoint: BreakPoint) => (breakpoint match { case breakpoint @ BreakPoint(name, stopcond, action) => (for { - saved <- doLeft(if action.saveState then State.getS[T].map(s => Some(s)) else State.pure(None)) + saved <- doLeft(if action.saveState then State.getS[T, E].map(s => Some(s)) else State.pure(None)) evals <- (State.mapM((e:Expr) => for { ev <- doLeft(Eval.evalExpr(f)(e)) } yield (e, ev) @@ -73,7 +73,7 @@ case class RememberBreakpoints[T, I <: Effects[T]](val f: I, val breaks: List[Br } -def interpretWithBreakPoints[I](p: Program, breakpoints: List[BreakPoint], innerInterpreter: Effects[I], innerInitialState: I) : (I, List[(BreakPoint, Option[I], List[(Expr, Expr)])]) = { +def interpretWithBreakPoints[I, E](p: Program, breakpoints: List[BreakPoint], innerInterpreter: Effects[I, E], innerInitialState: I) : (I, List[(BreakPoint, Option[I], List[(Expr, Expr)])]) = { val interp = LayerInterpreter(innerInterpreter, RememberBreakpoints(innerInterpreter, breakpoints)) val res = InterpFuns.interpretProg(interp)(p, (innerInitialState, List())) res diff --git a/src/main/scala/ir/eval/InterpretTrace.scala b/src/main/scala/ir/eval/InterpretTrace.scala index 455e4d43d..4a3e78891 100644 --- a/src/main/scala/ir/eval/InterpretTrace.scala +++ b/src/main/scala/ir/eval/InterpretTrace.scala @@ -26,12 +26,12 @@ enum ExecEffect: case class Trace(val t: List[ExecEffect]) case object Trace { - def add(e: ExecEffect) : State[Trace, Unit] = { + def add[E](e: ExecEffect) : State[Trace, Unit, E] = { State.modify ((t: Trace) => Trace(t.t.appended(e))) } } -object TraceGen extends Effects[Trace] { +case class TraceGen[E]() extends Effects[Trace, E] { /** Values are discarded by ProductInterpreter so do not matter */ def evalBV(e: Expr) = State.pure(BitVecLiteral(0,0)) @@ -74,7 +74,7 @@ object TraceGen extends Effects[Trace] { def interpretOne = State.pure(()) } -def tracingInterpreter = ProductInterpreter(NormalInterpreter, TraceGen) +def tracingInterpreter = ProductInterpreter(NormalInterpreter, TraceGen()) def interpretTrace(p: Program) : (InterpreterState, Trace) = { InterpFuns.interpretProg(tracingInterpreter)(p, (InterpreterState(), Trace(List()))) diff --git a/src/main/scala/ir/eval/Interpreter.scala b/src/main/scala/ir/eval/Interpreter.scala index 52d74d3cc..da4fd87b0 100644 --- a/src/main/scala/ir/eval/Interpreter.scala +++ b/src/main/scala/ir/eval/Interpreter.scala @@ -17,6 +17,7 @@ import scala.util.control.Breaks.{break, breakable} sealed trait ExecutionContinuation case class FailedAssertion(a: Assert) extends ExecutionContinuation + case class Stopped() extends ExecutionContinuation /* normal program stop */ case class Run(val next: Command) extends ExecutionContinuation /* continue by executing next command */ @@ -29,6 +30,9 @@ case class EvalError(val message: String = "") extends ExecutionContinuation /* failed to evaluate an expression to a concrete value */ case class MemoryError(val message: String = "") extends ExecutionContinuation /* An error to do with memory */ + +// type InterpreterError = EscapedControlFlow | Errored | TypeError | EvalError | MemoryError + /** TODO: errors should be encapsualted in error monad, rather than mapping exceptions back into state transitions at * State.execute() */ case class InterpreterError(continue: ExecutionContinuation) extends Exception() @@ -88,41 +92,41 @@ case object BasilValue { * Minimal language defining all state transitions in the interpreter, * defined for the interpreter's concrete state T. */ -trait Effects[T] { +trait Effects[T, E] { // perform an execution step - def interpretOne: State[T, Unit] + def interpretOne: State[T, Unit, E] - def loadVar(v: String): State[T, BasilValue] + def loadVar(v: String): State[T, BasilValue, E] - def loadMem(v: String, addrs: List[BasilValue]): State[T, List[BasilValue]] + def loadMem(v: String, addrs: List[BasilValue]): State[T, List[BasilValue], E] - def evalAddrToProc(addr: Int): State[T, Option[FunPointer]] + def evalAddrToProc(addr: Int): State[T, Option[FunPointer], E] - def getNext: State[T, ExecutionContinuation] + def getNext: State[T, ExecutionContinuation, E] /** state effects */ /* High-level implementation of a program counter that leverages the intrusive CFG. */ - def setNext(c: ExecutionContinuation): State[T, Unit] + def setNext(c: ExecutionContinuation): State[T, Unit, E] /* Perform a call: * target: arbitrary target name * beginFrom: ExecutionContinuation which begins executing the procedure * returnTo: ExecutionContinuation which begins executing after procedure return */ - def call(target: String, beginFrom: ExecutionContinuation, returnTo: ExecutionContinuation): State[T, Unit] + def call(target: String, beginFrom: ExecutionContinuation, returnTo: ExecutionContinuation): State[T, Unit, E] - def doReturn(): State[T, Unit] + def doReturn(): State[T, Unit, E] - def storeVar(v: String, scope: Scope, value: BasilValue): State[T, Unit] + def storeVar(v: String, scope: Scope, value: BasilValue): State[T, Unit, E] - def storeMem(vname: String, update: Map[BasilValue, BasilValue]): State[T, Unit] + def storeMem(vname: String, update: Map[BasilValue, BasilValue]): State[T, Unit, E] } -trait NopEffects[T] extends Effects[T] { +trait NopEffects[T, E] extends Effects[T, E] { def interpretOne = State.pure(()) def loadVar(v: String) = State.pure(Scalar(FalseLiteral)) def loadMem(v: String, addrs: List[BasilValue]) = State.pure(List()) @@ -316,7 +320,7 @@ case class InterpreterState( /** Implementation of Effects for InterpreterState concrete state representation. */ -object NormalInterpreter extends Effects[InterpreterState] { +object NormalInterpreter extends Effects[InterpreterState, InterpreterError] { def loadVar(v: String) = { @@ -325,7 +329,7 @@ object NormalInterpreter extends Effects[InterpreterState] { }) } - def evalAddrToProc(addr: Int): State[InterpreterState, Option[FunPointer]] = + def evalAddrToProc(addr: Int) = Logger.debug(s" eff : FIND PROC $addr") for { res <- get((s: InterpreterState) => s.memoryState.doLoadOpt("funtable", List(Scalar(BitVecLiteral(addr, 64))))) @@ -409,7 +413,7 @@ object NormalInterpreter extends Effects[InterpreterState] { }) } - def storeVar(v: String, scope: Scope, value: BasilValue): State[InterpreterState, Unit] = { + def storeVar(v: String, scope: Scope, value: BasilValue): State[InterpreterState, Unit, InterpreterError] = { Logger.debug(s" eff : SET $v := $value") State.modify((s: InterpreterState) => s.copy(memoryState = s.memoryState.defVar(v, scope, value))) } @@ -419,7 +423,7 @@ object NormalInterpreter extends Effects[InterpreterState] { s.copy(memoryState = s.memoryState.doStore(vname, update)) }) - def interpretOne: State[InterpreterState, Unit] = for { + def interpretOne: State[InterpreterState, Unit, InterpreterError] = for { next <- getNext _ <- try { next match { diff --git a/src/main/scala/ir/eval/InterpreterProduct.scala b/src/main/scala/ir/eval/InterpreterProduct.scala index 75e9a50fa..7383e46d5 100644 --- a/src/main/scala/ir/eval/InterpreterProduct.scala +++ b/src/main/scala/ir/eval/InterpreterProduct.scala @@ -15,15 +15,15 @@ import scala.collection.immutable import scala.util.control.Breaks.{break, breakable} -def doLeft[L, T, V](f: State[L, V]) : State[(L, T), V] = for { - f <- State[(L, T), V]((s: (L, T)) => { +def doLeft[L, T, V, E](f: State[L, V, E]) : State[(L, T), V, E] = for { + f <- State[(L, T), V, E]((s: (L, T)) => { val r = f.f(s._1) ((r._1, s._2), r._2) }) } yield (f) -def doRight[L, T, V](f: State[T, V]) : State[(L, T), V] = for { - f <- State[(L, T), V]((s: (L, T)) => { +def doRight[L, T, V, E](f: State[T, V, E]) : State[(L, T), V, E] = for { + f <- State[(L, T), V, E]((s: (L, T)) => { val r = f.f(s._2) ((s._1, r._1), r._2) }) @@ -32,7 +32,7 @@ def doRight[L, T, V](f: State[T, V]) : State[(L, T), V] = for { /** * Runs two interpreters "inner" and "before" simultaneously, returning the value from inner, and ignoring before */ -case class ProductInterpreter[L, T](val inner: Effects[L], val before: Effects[T]) extends Effects[(L, T)] { +case class ProductInterpreter[L, T, E](val inner: Effects[L, E], val before: Effects[T, E]) extends Effects[(L, T), E] { def interpretOne = for { n <- doRight(before.interpretOne) f <- doLeft(inner.interpretOne) @@ -86,7 +86,7 @@ case class ProductInterpreter[L, T](val inner: Effects[L], val before: Effects[T } -case class LayerInterpreter[L, T](val inner: Effects[L], val before: Effects[(L, T)]) extends Effects[(L, T)] { +case class LayerInterpreter[L, T, E](val inner: Effects[L, E], val before: Effects[(L, T), E]) extends Effects[(L, T), E] { def interpretOne = for { n <- (before.interpretOne) diff --git a/src/main/scala/util/functional.scala b/src/main/scala/util/functional.scala index e666bee36..d81b12e93 100644 --- a/src/main/scala/util/functional.scala +++ b/src/main/scala/util/functional.scala @@ -1,59 +1,59 @@ package util.functional -case class State[S, +A](f: S => (S, A)) { +case class State[S, +A, E](f: S => (S, Either[E, A])) { - def unit[A](a: A): State[S, A] = State(s => (s, a)) + def unit[A](a: A): State[S, A, E] = State(s => (s, Right(a))) - - def flatMap[B](f: A => State[S, B]): State[S, B] = State(s => { - // println(s"flatmap ${this.f} $f") + def flatMap[B](f: A => State[S, B, E]): State[S, B, E] = State(s => { val (s2, a) = this.f(s) - f(a).f(s2) + val r = a match { + case Left(l) => (s2, Left(l)) + case Right(a) => f(a).f(s2) + } + r }) - def map[B](f: A => B): State[S, B] = { + def map[B](f: A => B): State[S, B, E] = { State(s => { val (s2, a) = this.f(s) - (s2, f(a)) + a match { + case Left(l) => (s2, Left(l)) + case Right(a) => (s2, Right(f(a))) + } }) } } object State { - def get[S,A](f: S => A) : State[S, A] = State(s => (s, f(s))) - def getS[S] : State[S,S] = State((s:S) => (s,s)) - def putS[S](s: S) : State[S,_] = State((_) => (s,())) - def modify[S](f: S => S) : State[S, Unit] = State(s => (f(s), ())) - def execute[S, A](s: S, c: State[S,A]) : S = c.f(s)._1 - def evaluate[S, A](s: S, c: State[S,A]) : A = c.f(s)._2 + def get[S, A, E](f: S => A) : State[S, A, E] = State(s => (s, Right(f(s)))) + def getS[S,E] : State[S,S,E] = State((s:S) => (s,Right(s))) + def putS[S,E](s: S) : State[S,Unit,E] = State((_) => (s,Right(()))) + def modify[S, E](f: S => S) : State[S, Unit, E] = State(s => (f(s), Right(()))) + def execute[S, A, E](s: S, c: State[S,A, E]) : S = c.f(s)._1 + def evaluate[S, A, E](s: S, c: State[S,A, E]) : A = c.f(s)._2 match { + case Right(r) => r + case Left(l) => throw Exception(s"Wrong $l") + } - def pure[S, A](a: A) : State[S, A] = State((s:S) => (s, a)) + def pure[S, A, E](a: A) : State[S, A, E] = State((s:S) => (s, Right(a))) - def sequence[S, V](ident: State[S,V], xs: Iterable[State[S,V]]) : State[S,V] = { + def sequence[S, V, E](ident: State[S,V, E], xs: Iterable[State[S,V, E]]) : State[S, V, E] = { xs.foldRight(ident)((l,r) => for { x <- l y <- r } yield(y)) } - def sequence[V](xs: Iterable[Option[V]]) : Option[V] = { - xs.reduceRight((a, b) => a match { - case Some(x) => Some(x) - case None => b - }) - } - - def filterM[A, S](m : (A => State[S, Boolean]), xs: Iterable[A]): State[S, List[A]] = { + def filterM[A, S, E](m : (A => State[S, Boolean, E]), xs: Iterable[A]): State[S, List[A], E] = { xs.foldRight(pure(List[A]()))((b,acc) => acc.flatMap(c => m(b).map(v => if v then b::c else c))) } - def mapM[A, B, S](m : (A => State[S, B]), xs: Iterable[A]): State[S, List[B]] = { + def mapM[A, B, S, E](m : (A => State[S, B, E]), xs: Iterable[A]): State[S, List[B], E] = { xs.foldRight(pure(List[B]()))((b,acc) => acc.flatMap(c => m(b).map(v => v::c))) } - } def protect[T](x: () => T, fnly: PartialFunction[Exception, T]): T = { diff --git a/src/test/scala/ir/InterpreterTests.scala b/src/test/scala/ir/InterpreterTests.scala index a085f6ae8..edadf1295 100644 --- a/src/test/scala/ir/InterpreterTests.scala +++ b/src/test/scala/ir/InterpreterTests.scala @@ -34,14 +34,14 @@ def load(s: InterpreterState, global: SpecGlobal) : Option[BitVecLiteral] = { // m.evalBV("mem", BitVecLiteral(64, global.address), Endian.LittleEndian, global.size) // i.getMemory(global.address.toInt, global.size, Endian.LittleEndian, i.mems) try { - Some(Eval.evalBV(f)(MemoryLoad(SharedMemory("mem", 64, 8), BitVecLiteral(global.address, 64), Endian.LittleEndian, global.size)).f(s)._2) + Some(State.evaluate(s, Eval.evalBV(f)(MemoryLoad(SharedMemory("mem", 64, 8), BitVecLiteral(global.address, 64), Endian.LittleEndian, global.size)))) } catch { case e : InterpreterError => None } } -def mems[T <: Effects[T]](m: MemoryState) : Map[BigInt, BitVecLiteral] = { +def mems[E, T <: Effects[T, E]](m: MemoryState) : Map[BigInt, BitVecLiteral] = { m.getMem("mem").map((k,v) => k.value -> v) } diff --git a/src/test/scala/util/StateMonad.scala b/src/test/scala/util/StateMonad.scala index 29721174a..efcf4e767 100644 --- a/src/test/scala/util/StateMonad.scala +++ b/src/test/scala/util/StateMonad.scala @@ -13,7 +13,7 @@ import util.IRLoading.{loadBAP, loadReadELF} import util.ILLoadingConfig -def add: State[Int, Unit] = State(s => (s+1, ())) +def add: State[Int, Unit, Unit] = State(s => (s+1, Right(()))) class StateMonadTest extends AnyFunSuite { From 00b482abbffc3d874c698266af115cc559e28c0e Mon Sep 17 00:00:00 2001 From: Alistair Michael Date: Thu, 29 Aug 2024 16:57:53 +1000 Subject: [PATCH 19/62] redo error handling --- src/main/scala/ir/eval/InterpretBasilIR.scala | 110 ++++----- .../scala/ir/eval/InterpretBreakpoints.scala | 12 +- src/main/scala/ir/eval/Interpreter.scala | 228 +++++++++--------- src/main/scala/util/functional.scala | 25 +- 4 files changed, 199 insertions(+), 176 deletions(-) diff --git a/src/main/scala/ir/eval/InterpretBasilIR.scala b/src/main/scala/ir/eval/InterpretBasilIR.scala index 9f47b7c1b..3dc394a61 100644 --- a/src/main/scala/ir/eval/InterpretBasilIR.scala +++ b/src/main/scala/ir/eval/InterpretBasilIR.scala @@ -16,9 +16,9 @@ import scala.util.control.Breaks.{break, breakable} /** Abstraction for memload and variable lookup used by the expression evaluator. */ -case class StVarLoader[S, E, F <: Effects[S, E]](f: F) extends Loader[S, E] { +case class StVarLoader[S, F <: Effects[S, InterpreterError]](f: F) extends Loader[S, InterpreterError] { - def getVariable(v: Variable): State[S, Option[Literal], E] = { + def getVariable(v: Variable): State[S, Option[Literal], InterpreterError] = { for { v <- f.loadVar(v.name) } yield ((v match { @@ -27,7 +27,7 @@ case class StVarLoader[S, E, F <: Effects[S, E]](f: F) extends Loader[S, E] { })) } - override def loadMemory(m: Memory, addr: Expr, endian: Endian, size: Int): State[S, Option[Literal], E] = { + override def loadMemory(m: Memory, addr: Expr, endian: Endian, size: Int): State[S, Option[Literal], InterpreterError] = { for { r <- addr match { case l: Literal if size == 1 => @@ -57,55 +57,53 @@ case object Eval { /* Eval functions */ /*--------------------------------------------------------------------------------*/ - def evalExpr[S, E, T <: Effects[S, E]](f: T)(e: Expr): State[S, Expr, E] = { - val ldr = StVarLoader[S, E, T](f) + def evalExpr[S, T <: Effects[S, InterpreterError]](f: T)(e: Expr): State[S, Expr, InterpreterError] = { + val ldr = StVarLoader[S, T](f) for { - res <- ir.eval.statePartialEvalExpr[S, E](ldr)(e) + res <- ir.eval.statePartialEvalExpr[S, InterpreterError](ldr)(e) } yield (res) } - def evalBV[S, E, T <: Effects[S, E]](f: T)(e: Expr): State[S, BitVecLiteral, E] = { - val ldr = StVarLoader[S, E, T](f) + def evalBV[S, T <: Effects[S, InterpreterError]](f: T)(e: Expr): State[S, BitVecLiteral, InterpreterError] = { for { - res <- ir.eval.statePartialEvalExpr[S, E](ldr)(e) - } yield (res match { - case l: BitVecLiteral => l - case _ => throw InterpreterError(Errored(s"Eval BV residual $e")) + res <- evalExpr(f)(e) + r <- State.pureE(res match { + case l: BitVecLiteral => Right(l) + case _ => Left(InterpreterError(Errored(s"Eval BV residual $e"))) }) + } yield (r) } - def evalInt[S, E, T <: Effects[S, E]](f: T)(e: Expr): State[S, BigInt, E] = { - val ldr = StVarLoader[S, E, T](f) + def evalInt[S, T <: Effects[S, InterpreterError]](f: T)(e: Expr): State[S, BigInt, InterpreterError] = { for { - res <- ir.eval.statePartialEvalExpr[S, E](ldr)(e) - } yield (res match { - case l: IntLiteral => l.value - case _ => throw InterpreterError(Errored(s"Eval BV residual $e")) + res <- evalExpr(f)(e) + r <- State.pureE(res match { + case l: IntLiteral => Right(l.value) + case _ => Left(InterpreterError(Errored(s"Eval Int residual $e"))) }) + } yield (r) } - def evalBool[S, E, T <: Effects[S, E]](f: T)(e: Expr): State[S, Boolean, E] = { - val ldr = StVarLoader[S, E, T](f) + def evalBool[S, T <: Effects[S, InterpreterError]](f: T)(e: Expr): State[S, Boolean, InterpreterError] = { for { - res <- ir.eval.statePartialEvalExpr[S, E](ldr)(e) - } yield (res match { - case l: BoolLit => l == TrueLiteral - case _ => throw InterpreterError(Errored(s"Eval BV residual $e")) + res <- evalExpr(f)(e) + r <- State.pureE(res match { + case l: BoolLit => Right(l == TrueLiteral) + case _ => Left(InterpreterError(Errored(s"Eval Bool residual $e"))) }) + } yield (r) } /*--------------------------------------------------------------------------------*/ /* Load functions */ /*--------------------------------------------------------------------------------*/ - def load[S, E, T <: Effects[S, E]]( + def load[S, T <: Effects[S, InterpreterError]]( f: T - )(vname: String, addr: Scalar, endian: Endian, count: Int): State[S, List[BasilValue], E] = { - if (count == 0) { - throw InterpreterError(Errored(s"Attempted fractional load")) - } - val keys = (0 until count).map(i => BasilValue.unsafeAdd(addr, i)) + )(vname: String, addr: Scalar, endian: Endian, count: Int): State[S, List[BasilValue], InterpreterError] = { for { + _ <- if (count == 0) then State.setError(InterpreterError(Errored(s"Attempted fractional load"))) else State.pure(()) + keys <- State.mapM(((i:Int) => State.pureE(BasilValue.unsafeAdd(addr, i))), (0 until count)) values <- f.loadMem(vname, keys.toList) vals = endian match { case Endian.LittleEndian => values.reverse @@ -115,29 +113,26 @@ case object Eval { } /** Load and concat bitvectors */ - def loadBV[S, E, T <: Effects[S, E]]( + def loadBV[S, T <: Effects[S, InterpreterError]]( f: T - )(vname: String, addr: Scalar, endian: Endian, size: Int): State[S, BitVecLiteral, E] = for { + )(vname: String, addr: Scalar, endian: Endian, size: Int): State[S, BitVecLiteral, InterpreterError] = for { mem <- f.loadVar(vname) - (valsize, mapv) = mem match { - case mapv @ MapValue(_, MapType(_, BitVecType(sz))) => (sz, mapv) - case _ => throw InterpreterError(Errored("Trued to load-concat non bv")) + x <- mem match { + case mapv @ MapValue(_, MapType(_, BitVecType(sz))) => State.pure((sz, mapv)) + case _ => State.setError(InterpreterError(Errored("Trued to load-concat non bv"))) } + (valsize, mapv) = x cells = size / valsize res <- load(f)(vname, addr, endian, cells) // actual load - bvs: List[BitVecLiteral] = { - val rr = res.map { - case Scalar(bv @ BitVecLiteral(v, sz)) if sz == valsize => bv - case c => - throw InterpreterError(TypeError(s"Loaded value of type ${c.irType} did not match expected type bv$valsize")) - } - rr - } + bvs: List[BitVecLiteral] <- (State.mapM ((c : BasilValue) => c match { + case Scalar(bv @ BitVecLiteral(v, sz)) if sz == valsize => State.pure(bv) + case c => State.setError(InterpreterError(TypeError(s"Loaded value of type ${c.irType} did not match expected type bv$valsize"))) + },res)) } yield (bvs.foldLeft(BitVecLiteral(0, 0))((acc, r) => eval.evalBVBinExpr(BVCONCAT, acc, r))) - def loadSingle[S, E, T <: Effects[S, E]](f: T)(vname: String, addr: Scalar): State[S, BasilValue, E] = { + def loadSingle[S, T <: Effects[S, InterpreterError]](f: T)(vname: String, addr: Scalar): State[S, BasilValue, InterpreterError] = { for { m <- load(f)(vname, addr, Endian.LittleEndian, 1) } yield (m.head) @@ -148,18 +143,19 @@ case object Eval { /*--------------------------------------------------------------------------------*/ /* Expand addr for number of values to store */ - def store[S, E, T <: Effects[S, E]](f: T)( + def store[S, T <: Effects[S, InterpreterError]](f: T)( vname: String, addr: BasilValue, values: List[BasilValue], endian: Endian - ): State[S, Unit, E] = for { + ): State[S, Unit, InterpreterError] = for { mem <- f.loadVar(vname) - (mapval, keytype, valtype) = mem match { - case m @ MapValue(_, MapType(kt, vt)) if kt == addr.irType && values.forall(v => v.irType == vt) => (m, kt, vt) - case v => throw InterpreterError(TypeError(s"Invalid map store operation to $vname : $v")) + x <- mem match { + case m @ MapValue(_, MapType(kt, vt)) if kt == addr.irType && values.forall(v => v.irType == vt) => State.pure((m, kt, vt)) + case v => State.setError(InterpreterError(TypeError(s"Invalid map store operation to $vname : $v"))) } - keys = (0 until values.size).map(i => BasilValue.unsafeAdd(addr, i)) + (mapval, keytype, valtype) = x + keys <- State.mapM((i: Int) => State.pureE(BasilValue.unsafeAdd(addr, i)), (0 until values.size)) vals = endian match { case Endian.LittleEndian => values.reverse case Endian.BigEndian => values @@ -168,12 +164,12 @@ case object Eval { } yield (x) /** Extract bitvec to bytes and store bytes */ - def storeBV[S, E, T <: Effects[S, E]](f: T)( + def storeBV[S, T <: Effects[S, InterpreterError]](f: T)( vname: String, addr: BasilValue, value: BitVecLiteral, endian: Endian - ): State[S, Unit, E] = for { + ): State[S, Unit, InterpreterError] = for { mem <- f.loadVar(vname) (mapval, vsize) = mem match { case m @ MapValue(_, MapType(kt, BitVecType(size))) if kt == addr.irType => (m, size) @@ -197,7 +193,7 @@ case object Eval { case Endian.BigEndian => extractVals.reverse.map(Scalar(_)) } - keys = (0 until cells).map(i => BasilValue.unsafeAdd(addr, i)) + keys <- State.mapM((i: Int) => State.pureE(BasilValue.unsafeAdd(addr, i)), (0 until cells)) s <- f.storeMem(vname, keys.zip(vs).toMap) } yield (s) @@ -228,7 +224,7 @@ case object InterpFuns { } yield (l) } - def initialiseProgram[S, E, T <: Effects[S, E]](f: T)(p: Program): State[S, Unit, E] = { + def initialiseProgram[S, T <: Effects[S, InterpreterError]](f: T)(p: Program): State[S, Unit, InterpreterError] = { def initMemory(mem: String, mems: Iterable[MemorySection]) = { for { m <- State.sequence( @@ -269,7 +265,7 @@ case object InterpFuns { } yield (r) } - def interpretJump[S, E, T <: Effects[S, E]](f: T)(j: Jump): State[S, Unit, E] = { + def interpretJump[S, T <: Effects[S, InterpreterError]](f: T)(j: Jump): State[S, Unit, InterpreterError] = { j match { case gt: GoTo if gt.targets.size == 1 => { f.setNext(Run(IRWalk.firstInBlock(gt.targets.head))) @@ -295,7 +291,7 @@ case object InterpFuns { } } - def interpretStatement[S, E, T <: Effects[S, E]](f: T)(s: Statement): State[S, Unit, E] = { + def interpretStatement[S, T <: Effects[S, InterpreterError]](f: T)(s: Statement): State[S, Unit, InterpreterError] = { s match { case assign: Assign => { for { @@ -369,7 +365,7 @@ case object InterpFuns { } } - def interpretProg[S, E, T <: Effects[S, E]](f: T)(p: Program, is: S): S = { + def interpretProg[S, T <: Effects[S, InterpreterError]](f: T)(p: Program, is: S): S = { val begin = State.execute(is, initialiseProgram(f)(p)) // State.execute[S,Unit](is, ) interpret(f, begin) diff --git a/src/main/scala/ir/eval/InterpretBreakpoints.scala b/src/main/scala/ir/eval/InterpretBreakpoints.scala index 29ad8d57b..4668b7075 100644 --- a/src/main/scala/ir/eval/InterpretBreakpoints.scala +++ b/src/main/scala/ir/eval/InterpretBreakpoints.scala @@ -22,10 +22,10 @@ case class BreakPointAction(saveState: Boolean = true, stop: Boolean = false, ev case class BreakPoint(name: String = "", location: BreakPointLoc, action: BreakPointAction) -case class RememberBreakpoints[T, E, I <: Effects[T, E]](val f: I, val breaks: List[BreakPoint]) extends NopEffects[(T, List[(BreakPoint, Option[T], List[(Expr, Expr)])]), E] { +case class RememberBreakpoints[T, I <: Effects[T, InterpreterError]](val f: I, val breaks: List[BreakPoint]) extends NopEffects[(T, List[(BreakPoint, Option[T], List[(Expr, Expr)])]), InterpreterError] { - def findBreaks[R](c: Command) : State[(T,R), List[BreakPoint], E] = { + def findBreaks[R](c: Command) : State[(T,R), List[BreakPoint], InterpreterError] = { State.filterM(b => b.location match { case BreakPointLoc.CMD(bc) if (bc == c) => State.pure(true) case BreakPointLoc.CMDCond(bc, e) if bc == c => doLeft(Eval.evalBool(f)(e)) @@ -33,15 +33,15 @@ case class RememberBreakpoints[T, E, I <: Effects[T, E]](val f: I, val breaks: L }, breaks) } - override def interpretOne : State[(T, List[(BreakPoint, Option[T], List[(Expr, Expr)])]), Unit, E] = for { + override def interpretOne : State[(T, List[(BreakPoint, Option[T], List[(Expr, Expr)])]), Unit, InterpreterError] = for { v : ExecutionContinuation <- doLeft(f.getNext) n <- v match { case Run(s) => for { breaks : List[BreakPoint] <- findBreaks(s) - res <- State.sequence[(T, List[(BreakPoint, Option[T], List[(Expr, Expr)])]), Unit, E](State.pure(()), + res <- State.sequence[(T, List[(BreakPoint, Option[T], List[(Expr, Expr)])]), Unit, InterpreterError](State.pure(()), breaks.map((breakpoint: BreakPoint) => (breakpoint match { case breakpoint @ BreakPoint(name, stopcond, action) => (for { - saved <- doLeft(if action.saveState then State.getS[T, E].map(s => Some(s)) else State.pure(None)) + saved <- doLeft(if action.saveState then State.getS[T, InterpreterError].map(s => Some(s)) else State.pure(None)) evals <- (State.mapM((e:Expr) => for { ev <- doLeft(Eval.evalExpr(f)(e)) } yield (e, ev) @@ -73,7 +73,7 @@ case class RememberBreakpoints[T, E, I <: Effects[T, E]](val f: I, val breaks: L } -def interpretWithBreakPoints[I, E](p: Program, breakpoints: List[BreakPoint], innerInterpreter: Effects[I, E], innerInitialState: I) : (I, List[(BreakPoint, Option[I], List[(Expr, Expr)])]) = { +def interpretWithBreakPoints[I](p: Program, breakpoints: List[BreakPoint], innerInterpreter: Effects[I, InterpreterError], innerInitialState: I) : (I, List[(BreakPoint, Option[I], List[(Expr, Expr)])]) = { val interp = LayerInterpreter(innerInterpreter, RememberBreakpoints(innerInterpreter, breakpoints)) val res = InterpFuns.interpretProg(interp)(p, (innerInitialState, List())) res diff --git a/src/main/scala/ir/eval/Interpreter.scala b/src/main/scala/ir/eval/Interpreter.scala index da4fd87b0..82f0aff71 100644 --- a/src/main/scala/ir/eval/Interpreter.scala +++ b/src/main/scala/ir/eval/Interpreter.scala @@ -17,7 +17,6 @@ import scala.util.control.Breaks.{break, breakable} sealed trait ExecutionContinuation case class FailedAssertion(a: Assert) extends ExecutionContinuation - case class Stopped() extends ExecutionContinuation /* normal program stop */ case class Run(val next: Command) extends ExecutionContinuation /* continue by executing next command */ @@ -30,11 +29,11 @@ case class EvalError(val message: String = "") extends ExecutionContinuation /* failed to evaluate an expression to a concrete value */ case class MemoryError(val message: String = "") extends ExecutionContinuation /* An error to do with memory */ - // type InterpreterError = EscapedControlFlow | Errored | TypeError | EvalError | MemoryError /** TODO: errors should be encapsualted in error monad, rather than mapping exceptions back into state transitions at - * State.execute() */ + * State.execute() + */ case class InterpreterError(continue: ExecutionContinuation) extends Exception() /* Concrete value type of the interpreter. */ @@ -68,30 +67,28 @@ case object BasilValue { def size(v: BasilValue): Int = size(v.irType) - def unsafeAdd(l: BasilValue, vr: Int): BasilValue = { + def unsafeAdd[S, E](l: BasilValue, vr: Int): Either[InterpreterError, BasilValue] = { l match { - case _ if vr == 0 => l - case Scalar(IntLiteral(vl)) => Scalar(IntLiteral(vl + vr)) - case Scalar(b1: BitVecLiteral) => Scalar(eval.evalBVBinExpr(BVADD, b1, BitVecLiteral(vr, b1.size))) - case _ => throw InterpreterError(TypeError(s"Operation add $vr undefined on $l")) + case _ if vr == 0 => Right(l) + case Scalar(IntLiteral(vl)) => Right(Scalar(IntLiteral(vl + vr))) + case Scalar(b1: BitVecLiteral) => Right(Scalar(eval.evalBVBinExpr(BVADD, b1, BitVecLiteral(vr, b1.size)))) + case _ => Left(InterpreterError(TypeError(s"Operation add $vr undefined on $l"))) } } - def add(l: BasilValue, r: BasilValue): BasilValue = { - (l, r) match { - case (Scalar(IntLiteral(vl)), Scalar(IntLiteral(vr))) => Scalar(IntLiteral(vl + vr)) - case (Scalar(b1: BitVecLiteral), Scalar(b2: BitVecLiteral)) => Scalar(eval.evalBVBinExpr(BVADD, b1, b2)) - case (Scalar(b1: BoolLit), Scalar(b2: BoolLit)) => - Scalar(if (b2.value || b2.value) then TrueLiteral else FalseLiteral) - case _ => throw InterpreterError(TypeError(s"Operation add undefined on $l $r")) - } - } + // def add(l: BasilValue, r: BasilValue): BasilValue = { + // (l, r) match { + // case (Scalar(IntLiteral(vl)), Scalar(IntLiteral(vr))) => Scalar(IntLiteral(vl + vr)) + // case (Scalar(b1: BitVecLiteral), Scalar(b2: BitVecLiteral)) => Scalar(eval.evalBVBinExpr(BVADD, b1, b2)) + // case (Scalar(b1: BoolLit), Scalar(b2: BoolLit)) => + // Scalar(if (b2.value || b2.value) then TrueLiteral else FalseLiteral) + // case _ => throw InterpreterError(TypeError(s"Operation add undefined on $l $r")) + // } + // } } -/** - * Minimal language defining all state transitions in the interpreter, - * defined for the interpreter's concrete state T. - */ +/** Minimal language defining all state transitions in the interpreter, defined for the interpreter's concrete state T. + */ trait Effects[T, E] { // perform an execution step @@ -124,15 +121,13 @@ trait Effects[T, E] { def storeMem(vname: String, update: Map[BasilValue, BasilValue]): State[T, Unit, E] } - - trait NopEffects[T, E] extends Effects[T, E] { def interpretOne = State.pure(()) def loadVar(v: String) = State.pure(Scalar(FalseLiteral)) - def loadMem(v: String, addrs: List[BasilValue]) = State.pure(List()) + def loadMem(v: String, addrs: List[BasilValue]) = State.pure(List()) def evalAddrToProc(addr: Int) = State.pure(None) def getNext = State.pure(Stopped()) - def setNext(c: ExecutionContinuation) = State.pure(()) + def setNext(c: ExecutionContinuation) = State.pure(()) def call(target: String, beginFrom: ExecutionContinuation, returnTo: ExecutionContinuation) = State.pure(()) def doReturn() = State.pure(()) @@ -141,10 +136,9 @@ trait NopEffects[T, E] extends Effects[T, E] { def storeMem(vname: String, update: Map[BasilValue, BasilValue]) = State.pure(()) } -/** -------------------------------------------------------------------------------- - * Definition of concrete state - * -------------------------------------------------------------------------------- */ - +/** -------------------------------------------------------------------------------- Definition of concrete state + * -------------------------------------------------------------------------------- + */ type StackFrameID = String val globalFrame: StackFrameID = "GLOBAL" @@ -189,14 +183,17 @@ case class MemoryState( MemoryState(frames, frameName :: activations, counts) } - def popStackFrame(): MemoryState = { - val (frame, remactivs) = activations match { - case Nil => throw InterpreterError(Errored("No stack frame to pop")) - case h :: Nil if h == globalFrame => throw InterpreterError(Errored("tried to pop global scope")) - case h :: tl => (h, tl) + def popStackFrame(): Either[InterpreterError, MemoryState] = { + val hv = activations match { + case Nil => Left(InterpreterError(Errored("No stack frame to pop"))) + case h :: Nil if h == globalFrame => Left(InterpreterError(Errored("tried to pop global scope"))) + case h :: tl => Right((h, tl)) } - val frames = stackFrames.removed(frame) - MemoryState(frames, remactivs, activationCount) + hv.map((hv) => { + val (frame, remactivs) = hv + val frames = stackFrames.removed(frame) + MemoryState(frames, remactivs, activationCount) + }) } /* Variable retrieval and setting */ @@ -235,83 +232,88 @@ case class MemoryState( ) } - def findVar(name: String): (StackFrameID, BasilValue) = { - findVarOpt(name: String).getOrElse(throw InterpreterError(Errored(s"Access to undefined variable $name"))) + def findVar(name: String): Either[InterpreterError, (StackFrameID, BasilValue)] = { + findVarOpt(name: String) + .map(Right(_)) + .getOrElse(Left(InterpreterError(Errored(s"Access to undefined variable $name")))) } def getVarOpt(name: String): Option[BasilValue] = findVarOpt(name).map(_._2) - def getVar(name: String): BasilValue = { - getVarOpt(name).getOrElse(throw InterpreterError(Errored(s"Access undefined variable $name"))) + def getVar(name: String): Either[InterpreterError, BasilValue] = { + getVarOpt(name).map(Right(_)).getOrElse(Left(InterpreterError(Errored(s"Access undefined variable $name")))) } - def getVar(v: Variable): BasilValue = { + def getVar(v: Variable): Either[InterpreterError, BasilValue] = { val value = getVar(v.name) value match { - case dv: BasilValue if v.getType != dv.irType => - throw InterpreterError( - Errored(s"Type mismatch on variable definition and load: defined ${dv.irType}, variable ${v.getType}") + case Right(dv: BasilValue) if v.getType != dv.irType => + Left( + InterpreterError( + Errored(s"Type mismatch on variable definition and load: defined ${dv.irType}, variable ${v.getType}") + ) ) - case o => o + case Right(o) => Right(o) + case o => o } } - def getVarLiteralOpt(v: Variable): Option[Literal] = { - getVar(v) match { - case Scalar(v) => Some(v) - case _ => None - } - } /* Map variable accessing ; load and store operations */ - def doLoadOpt(vname: String, addr: List[BasilValue]): Option[List[BasilValue]] = { - val (frame, mem) = findVar(vname) - val mapv: MapValue = mem match { - case m @ MapValue(innerMap, ty) => m - case m => throw InterpreterError(TypeError(s"Load from nonmap ${m.irType}")) - } - - val rs = addr.map(k => mapv.value.get(k)) - if (rs.forall(_.isDefined)) { - Some(rs.map(_.get)) - } else { - None - } - } - def doLoad(vname: String, addr: List[BasilValue]): List[BasilValue] = { - doLoadOpt(vname, addr) match { - case Some(vs) => vs - case None => { - throw InterpreterError(MemoryError(s"Read from uninitialised $vname[${addr.head} .. ${addr.last}]")) - } + def doLoad(vname: String, addr: List[BasilValue]): Either[InterpreterError, List[BasilValue]] = for { + v <- findVar(vname) + mapv: MapValue <- v._2 match { + case m @ MapValue(innerMap, ty) => Right(m) + case m => Left(InterpreterError(TypeError(s"Load from nonmap ${m.irType}"))) } - } + rs: List[Option[BasilValue]] = addr.map(k => mapv.value.get(k)) + xs <- + (if (rs.forall(_.isDefined)) { + Right(rs.map(_.get)) + } else { + Left(InterpreterError(MemoryError(s"Read from uninitialised $vname[${addr.head} .. ${addr.last}]"))) + }) + } yield (xs) + + // def doLoad[S](vname: String, addr: List[BasilValue]): State[S, List[BasilValue], InterpreterError] = for { + // v <- doLoadOpt(vname, addr) + // r <- v match { + // case Some(vs) => vs + // case None => { + // throw InterpreterError(MemoryError(s"Read from uninitialised ")) + // } + // } + // } /** typecheck and some fields of a map variable */ - def doStore(vname: String, values: Map[BasilValue, BasilValue]) = { - val (frame, mem) = findVar(vname) - - val (mapval, keytype, valtype) = mem match { - case m @ MapValue(_, MapType(kt, vt)) => (m, kt, vt) - case v => throw InterpreterError(TypeError(s"Invalid map store operation to $vname : ${v.irType}")) + def doStore(vname: String, values: Map[BasilValue, BasilValue]): Either[InterpreterError, MemoryState] = for { + // val (frame, mem) = findVar(vname) + v <- findVar(vname) + (frame, mem) = v + // val (mapval, keytype, valtype) = + mapi <- mem match { + case m @ MapValue(_, MapType(kt, vt)) => Right((m, kt, vt)) + case v => Left(InterpreterError(TypeError(s"Invalid map store operation to $vname : ${v.irType}"))) } + (mapval, keytype, valtype) = mapi - (values.find((k, v) => k.irType != keytype || v.irType != valtype)) match { + checkTypes <- (values.find((k, v) => k.irType != keytype || v.irType != valtype)) match { case Some(v) => - throw InterpreterError( - TypeError( - s"Invalid addr or value type (${v._1.irType}, ${v._2.irType}) does not match map type $vname : ($keytype, $valtype)" + Left( + InterpreterError( + TypeError( + s"Invalid addr or value type (${v._1.irType}, ${v._2.irType}) does not match map type $vname : ($keytype, $valtype)" + ) ) ) - case None => () + case None => Right(()) } - val nmap = MapValue(mapval.value ++ values, mapval.irType) - setVar(frame, vname, nmap) - } + nmap = MapValue(mapval.value ++ values, mapval.irType) + ms <- Right(setVar(frame, vname, nmap)) + } yield (ms) } - case class InterpreterState( val nextCmd: ExecutionContinuation = Stopped(), val callStack: List[ExecutionContinuation] = List.empty, @@ -322,9 +324,8 @@ case class InterpreterState( */ object NormalInterpreter extends Effects[InterpreterState, InterpreterError] { - def loadVar(v: String) = { - State.get((s: InterpreterState) => { + State.getE((s: InterpreterState) => { s.memoryState.getVar(v) }) } @@ -332,15 +333,17 @@ object NormalInterpreter extends Effects[InterpreterState, InterpreterError] { def evalAddrToProc(addr: Int) = Logger.debug(s" eff : FIND PROC $addr") for { - res <- get((s: InterpreterState) => s.memoryState.doLoadOpt("funtable", List(Scalar(BitVecLiteral(addr, 64))))) + res: List[BasilValue] <- getE((s: InterpreterState) => + s.memoryState.doLoad("funtable", List(Scalar(BitVecLiteral(addr, 64)))) + ) } yield { res match { - case Some((f: FunPointer) :: Nil) => Some(f) - case _ => None + case ((f: FunPointer) :: Nil) => Some(f) + case _ => None } } - def formatStore(varname: String, update: Map[BasilValue, BasilValue]) : String = { + def formatStore(varname: String, update: Map[BasilValue, BasilValue]): String = { val ks = update.toList.sortWith((x, y) => { def conv(v: BasilValue): BigInt = v match { case (Scalar(b: BitVecLiteral)) => b.value @@ -378,7 +381,7 @@ object NormalInterpreter extends Effects[InterpreterState, InterpreterError] { } def loadMem(v: String, addrs: List[BasilValue]) = { - State.get((s: InterpreterState) => { + State.getE((s: InterpreterState) => { val r = s.memoryState.doLoad(v, addrs) Logger.debug(s" eff : LOAD ${addrs.head} x ${addrs.size}") r @@ -405,10 +408,12 @@ object NormalInterpreter extends Effects[InterpreterState, InterpreterError] { def doReturn() = { Logger.debug(s" eff : RETURN") - modify((s: InterpreterState) => { + modifyE((s: InterpreterState) => { s.callStack match { - case Nil => s.copy(nextCmd = Stopped()) - case h :: tl => s.copy(nextCmd = h, callStack = tl, memoryState = s.memoryState.popStackFrame()) + case Nil => Right(s.copy(nextCmd = Stopped())) + case h :: tl => for { + ms <- s.memoryState.popStackFrame() + } yield (s.copy(nextCmd = h, callStack = tl, memoryState = ms)) } }) } @@ -418,29 +423,30 @@ object NormalInterpreter extends Effects[InterpreterState, InterpreterError] { State.modify((s: InterpreterState) => s.copy(memoryState = s.memoryState.defVar(v, scope, value))) } - def storeMem(vname: String, update: Map[BasilValue, BasilValue]) = State.modify((s: InterpreterState) => { + def storeMem(vname: String, update: Map[BasilValue, BasilValue]) = + State.modifyE((s: InterpreterState) => { Logger.debug(s" eff : STORE ${formatStore(vname, update)}") - s.copy(memoryState = s.memoryState.doStore(vname, update)) + for { + ms <- s.memoryState.doStore(vname, update) + } yield(s.copy(memoryState = ms)) }) def interpretOne: State[InterpreterState, Unit, InterpreterError] = for { next <- getNext - _ <- try { - next match { - case Run(c: Statement) => InterpFuns.interpretStatement(this)(c) - case Run(c: Jump) => InterpFuns.interpretJump(this)(c) - case Stopped() => State.pure (()) - case errorstop => State.pure (()) - } - } catch { - case InterpreterError(e) => setNext(e) - case e: java.lang.IllegalArgumentException => setNext(Errored(e.getStackTrace.take(5).mkString("\n"))) - } + _ <- (next match { + case Run(c: Statement) => InterpFuns.interpretStatement(this)(c) + case Run(c: Jump) => InterpFuns.interpretJump(this)(c) + case Stopped() => State.pure(()) + case errorstop => State.pure(()) + }).flatMapE((e: InterpreterError) => setNext(e.continue)) + // } catch { + // case InterpreterError(e) => setNext(e) + // case e: java.lang.IllegalArgumentException => setNext(Errored(e.getStackTrace.take(5).mkString("\n"))) + // } } yield () } - // def interpretTrace(IRProgram: Program): TracingInterpreter = { // val s: TracingInterpreter = InterpFuns.interpretProg(IRProgram, TracingInterpreter(InterpreterState(), List())) // s diff --git a/src/main/scala/util/functional.scala b/src/main/scala/util/functional.scala index d81b12e93..5f78e6d76 100644 --- a/src/main/scala/util/functional.scala +++ b/src/main/scala/util/functional.scala @@ -1,6 +1,9 @@ package util.functional -case class State[S, +A, E](f: S => (S, Either[E, A])) { +/* + * Flattened state monad with error. + */ +case class State[S, A, E](f: S => (S, Either[E, A])) { def unit[A](a: A): State[S, A, E] = State(s => (s, Right(a))) @@ -23,21 +26,39 @@ case class State[S, +A, E](f: S => (S, Either[E, A])) { } }) } + + def flatMapE(f: E => State[S, A, E]): State[S, A, E] = { + State(s => { + val (s2, a) = this.f(s) + a match { + case Left(l) => f(l).f(s2) + case Right(_) => (s2, a) + } + }) + } } object State { def get[S, A, E](f: S => A) : State[S, A, E] = State(s => (s, Right(f(s)))) + def getE[S, A, E](f: S => Either[E,A]) : State[S, A, E] = State(s => (s, f(s))) def getS[S,E] : State[S,S,E] = State((s:S) => (s,Right(s))) def putS[S,E](s: S) : State[S,Unit,E] = State((_) => (s,Right(()))) def modify[S, E](f: S => S) : State[S, Unit, E] = State(s => (f(s), Right(()))) + def modifyE[S, E](f: S => Either[E, S]) : State[S, Unit, E] = State(s => f(s) match { + case Right(ns) => (ns, Right(())) + case Left(e) => (s, Left(e)) + }) def execute[S, A, E](s: S, c: State[S,A, E]) : S = c.f(s)._1 def evaluate[S, A, E](s: S, c: State[S,A, E]) : A = c.f(s)._2 match { case Right(r) => r - case Left(l) => throw Exception(s"Wrong $l") + case Left(l) => throw Exception(s"Evaluation error $l") } + def setError[S,A,E](e: E) : State[S,A,E] = State(s => (s, Left(e))) + def pure[S, A, E](a: A) : State[S, A, E] = State((s:S) => (s, Right(a))) + def pureE[S, A, E](a: Either[E, A]) : State[S, A, E] = State((s:S) => (s, a)) def sequence[S, V, E](ident: State[S,V, E], xs: Iterable[State[S,V, E]]) : State[S, V, E] = { xs.foldRight(ident)((l,r) => for { From 1a139d24819bd43952d9cf497c947c594ef8aa1f Mon Sep 17 00:00:00 2001 From: Alistair Michael Date: Mon, 19 Aug 2024 11:55:23 +1000 Subject: [PATCH 20/62] refactor call to a statement & add unreachable and return jumps --- src/main/scala/analysis/Cfg.scala | 137 ++++++++++-------- src/main/scala/analysis/IDEAnalysis.scala | 4 +- .../analysis/InterLiveVarsAnalysis.scala | 6 +- .../analysis/IntraLiveVarsAnalysis.scala | 4 +- src/main/scala/analysis/VSA.scala | 10 +- .../scala/analysis/solvers/IDESolver.scala | 36 ++--- src/main/scala/ir/IRCursor.scala | 63 ++++---- src/main/scala/ir/Interpreter.scala | 99 +++++++------ src/main/scala/ir/Program.scala | 74 ++++------ src/main/scala/ir/Statement.scala | 42 +++--- src/main/scala/ir/Visitor.scala | 31 +--- src/main/scala/ir/cilvisitor/CILVisitor.scala | 9 +- src/main/scala/ir/dsl/DSL.scala | 29 ++-- .../scala/ir/transforms/ReplaceReturn.scala | 52 +++++++ src/main/scala/translating/BAPToIR.scala | 21 +-- src/main/scala/translating/GTIRBToIR.scala | 54 ++++--- src/main/scala/translating/ILtoIL.scala | 4 +- src/main/scala/translating/IRToBoogie.scala | 55 +++---- src/main/scala/util/RunUtils.scala | 98 +++++++------ 19 files changed, 441 insertions(+), 387 deletions(-) create mode 100644 src/main/scala/ir/transforms/ReplaceReturn.scala diff --git a/src/main/scala/analysis/Cfg.scala b/src/main/scala/analysis/Cfg.scala index d508353c1..59d512cd2 100644 --- a/src/main/scala/analysis/Cfg.scala +++ b/src/main/scala/analysis/Cfg.scala @@ -174,7 +174,7 @@ class CfgStatementNode( /** CFG's representation of a jump. This is used as a general jump node, for both indirect and direct calls. */ class CfgJumpNode( - val data: Jump, + val data: Jump | DirectCall | IndirectCall, val block: Block, val parent: CfgFunctionEntryNode ) extends CfgCommandNode: @@ -486,6 +486,72 @@ class ProgramCfgFactory: cfg.addEdge(prevNode, firstNode) visitedBlocks += (block -> firstNode) // This is guaranteed to be entrance to block if we are here + val statements = List.from(stmts).map(s => s match { + case d: DirectCall => CfgJumpNode(d, block, funcEntryNode) + case d: IndirectCall => CfgJumpNode(d, block, funcEntryNode) + case o => CfgStatementNode(o, block, funcEntryNode) + }) + val succs = if (statements.nonEmpty) then statements.zip(statements.tail ++ List(CfgJumpNode(statements.head.data.parent.jump, block, funcEntryNode))) else List() + + for ((s,nexts) <- succs) { + s.data match { + case dCall: DirectCall => + + var precNode = prevNode + + val targetProc: Procedure = dCall.target + funcEntryNode.callers.add(procToCfg(targetProc)._1) + + val callNode = CfgJumpNode(dCall, block, funcEntryNode) + + // Branch to this call + cfg.addEdge(precNode, callNode) + + procToCalls(proc) += callNode + procToCallers(targetProc) += callNode + callToNodes(funcEntryNode) += callNode + + // Record call association + + // Jump to return location + val returnTarget = nexts + // Add intermediary return node (split call into call and return) + val callRet = CfgCallReturnNode() + cfg.addEdge(callNode, callRet) + cfg.addEdge(callRet, returnTarget) + case iCall: IndirectCall => + Logger.debug(s"Indirect call found: $iCall in ${proc.name}") + var precNode = prevNode + + val jmpNode = CfgJumpNode(iCall, block, funcEntryNode) + // Branch to this call + cfg.addEdge(precNode, jmpNode) + + // Record call association + procToCalls(proc) += jmpNode + callToNodes(funcEntryNode) += jmpNode + + // R30 is the link register - this stores the address to return to. + // For now just add a node expressing that we are to return to the previous context. + if (iCall.target == Register("R30", 64)) { + val returnNode = CfgProcedureReturnNode() + cfg.addEdge(jmpNode, returnNode) + cfg.addEdge(returnNode, funcExitNode) + } + + val callRet = CfgCallReturnNode() + cfg.addEdge(jmpNode, callRet) + val returnTarget = nexts + cfg.addEdge(callRet, jmpNode) + case h: Halt => { + assert(false); + // not possible since s is only Statement. + } + case _ => () + } + } + + if (stmts.size == 1) { return firstNode } @@ -548,42 +614,10 @@ class ProgramCfgFactory: visitBlock(targetBlock, precNode) } } - case dCall: DirectCall => - val targetProc: Procedure = dCall.target - funcEntryNode.callers.add(procToCfg(targetProc)._1) - - val callNode = CfgJumpNode(dCall, block, funcEntryNode) - - // Branch to this call - cfg.addEdge(precNode, callNode) - - procToCalls(proc) += callNode - procToCallers(targetProc) += callNode - callToNodes(funcEntryNode) += callNode - - // Record call association - - // Jump to return location - dCall.returnTarget match { - case Some(retBlock) => - // Add intermediary return node (split call into call and return) - val callRet = CfgCallReturnNode() - - cfg.addEdge(callNode, callRet) - if (visitedBlocks.contains(retBlock)) { - val retBlockEntry: CfgCommandNode = visitedBlocks(retBlock) - cfg.addEdge(callRet, retBlockEntry) - } else { - visitBlock(retBlock, callRet) - } - case None => - val noReturn = CfgCallNoReturnNode() - cfg.addEdge(callNode, noReturn) - cfg.addEdge(noReturn, funcExitNode) - } - case iCall: IndirectCall => - Logger.debug(s"Indirect call found: $iCall in ${proc.name}") - + case h: Halt => { + cfg.addEdge(jmpNode, funcExitNode) + } + case r: Return => // Branch to this call cfg.addEdge(precNode, jmpNode) @@ -591,32 +625,9 @@ class ProgramCfgFactory: procToCalls(proc) += jmpNode callToNodes(funcEntryNode) += jmpNode - // R30 is the link register - this stores the address to return to. - // For now just add a node expressing that we are to return to the previous context. - if (iCall.target == Register("R30", 64)) { - val returnNode = CfgProcedureReturnNode() - cfg.addEdge(jmpNode, returnNode) - cfg.addEdge(returnNode, funcExitNode) - return - } - - // Jump to return location - iCall.returnTarget match { - case Some(retBlock) => // Add intermediary return node (split call into call and return) - val callRet = CfgCallReturnNode() - cfg.addEdge(jmpNode, callRet) - - if (visitedBlocks.contains(retBlock)) { - val retBlockEntry = visitedBlocks(retBlock) - cfg.addEdge(callRet, retBlockEntry) - } else { - visitBlock(retBlock, callRet) - } - case None => - val noReturn = CfgCallNoReturnNode() - cfg.addEdge(jmpNode, noReturn) - cfg.addEdge(noReturn, funcExitNode) - } + val returnNode = CfgProcedureReturnNode() + cfg.addEdge(jmpNode, returnNode) + cfg.addEdge(returnNode, funcExitNode) } // `jmps.head` match } // `visitJumps` function } // `visitBlocks` function diff --git a/src/main/scala/analysis/IDEAnalysis.scala b/src/main/scala/analysis/IDEAnalysis.scala index 8f2a7001c..e77627717 100644 --- a/src/main/scala/analysis/IDEAnalysis.scala +++ b/src/main/scala/analysis/IDEAnalysis.scala @@ -55,6 +55,6 @@ trait IDEAnalysis[E, EE, C, R, D, T, L <: Lattice[T]] { } // IndirectCall in these is because they are returns so that can be further tightened in future -trait ForwardIDEAnalysis[D, T, L <: Lattice[T]] extends IDEAnalysis[Procedure, IndirectCall, DirectCall, GoTo, D, T, L] +trait ForwardIDEAnalysis[D, T, L <: Lattice[T]] extends IDEAnalysis[Procedure, IndirectCall, DirectCall, Command, D, T, L] -trait BackwardIDEAnalysis[D, T, L <: Lattice[T]] extends IDEAnalysis[IndirectCall, Procedure, GoTo, DirectCall, D, T, L] +trait BackwardIDEAnalysis[D, T, L <: Lattice[T]] extends IDEAnalysis[IndirectCall, Procedure, Command, DirectCall, D, T, L] diff --git a/src/main/scala/analysis/InterLiveVarsAnalysis.scala b/src/main/scala/analysis/InterLiveVarsAnalysis.scala index 93a34d076..b20edf478 100644 --- a/src/main/scala/analysis/InterLiveVarsAnalysis.scala +++ b/src/main/scala/analysis/InterLiveVarsAnalysis.scala @@ -19,7 +19,7 @@ trait LiveVarsAnalysisFunctions extends BackwardIDEAnalysis[Variable, TwoElement val edgelattice: EdgeFunctionLattice[TwoElement, TwoElementLattice] = EdgeFunctionLattice(valuelattice) import edgelattice.{IdEdge, ConstEdge} - def edgesCallToEntry(call: GoTo, entry: IndirectCall)(d: DL): Map[DL, EdgeFunction[TwoElement]] = { + def edgesCallToEntry(call: Command, entry: IndirectCall)(d: DL): Map[DL, EdgeFunction[TwoElement]] = { Map(d -> IdEdge()) } @@ -27,7 +27,7 @@ trait LiveVarsAnalysisFunctions extends BackwardIDEAnalysis[Variable, TwoElement Map(d -> IdEdge()) } - def edgesCallToAfterCall(call: GoTo, aftercall: DirectCall)(d: DL): Map[DL, EdgeFunction[TwoElement]] = { + def edgesCallToAfterCall(call: Command, aftercall: DirectCall)(d: DL): Map[DL, EdgeFunction[TwoElement]] = { d match case Left(value) => Map() // maps all variables before the call to bottom case Right(_) => Map(d -> IdEdge()) @@ -70,7 +70,7 @@ trait LiveVarsAnalysisFunctions extends BackwardIDEAnalysis[Variable, TwoElement expr.variables.foldLeft(Map[DL, EdgeFunction[TwoElement]](d -> IdEdge())) { (mp, expVar) => mp + (Left(expVar) -> ConstEdge(TwoElementTop)) } - case IndirectCall(variable, _, _) => + case IndirectCall(variable, _) => d match case Left(value) => if value != variable then Map(d -> IdEdge()) else Map() case Right(_) => Map(d -> IdEdge(), Left(variable) -> ConstEdge(TwoElementTop)) diff --git a/src/main/scala/analysis/IntraLiveVarsAnalysis.scala b/src/main/scala/analysis/IntraLiveVarsAnalysis.scala index 1624dcdfb..f3f322321 100644 --- a/src/main/scala/analysis/IntraLiveVarsAnalysis.scala +++ b/src/main/scala/analysis/IntraLiveVarsAnalysis.scala @@ -15,7 +15,7 @@ abstract class LivenessAnalysis(program: Program) extends Analysis[Any]: case MemoryAssign(_, index, value, _, _, _) => s ++ index.variables ++ value.variables case Assume(expr, _, _, _) => s ++ expr.variables case Assert(expr, _, _) => s ++ expr.variables - case IndirectCall(variable, _, _) => s + variable + case IndirectCall(variable, _) => s + variable case c: DirectCall => s case g: GoTo => s case _ => ??? @@ -25,4 +25,4 @@ abstract class LivenessAnalysis(program: Program) extends Analysis[Any]: class IntraLiveVarsAnalysis(program: Program) extends LivenessAnalysis(program) with SimpleWorklistFixpointSolver[CFGPosition, Set[Variable], PowersetLattice[Variable]] - with IRIntraproceduralBackwardDependencies \ No newline at end of file + with IRIntraproceduralBackwardDependencies diff --git a/src/main/scala/analysis/VSA.scala b/src/main/scala/analysis/VSA.scala index 03ef8ff60..f7d9a55e7 100644 --- a/src/main/scala/analysis/VSA.scala +++ b/src/main/scala/analysis/VSA.scala @@ -121,7 +121,7 @@ trait ValueSetAnalysis(program: Program, m = m + (localAssign.lhs -> m(r)) m case None => - Logger.warn("could not find region for " + localAssign) + Logger.debug("could not find region for " + localAssign) m case e: Expr => evaluateExpression(e, constantProp(n)) match { @@ -129,7 +129,7 @@ trait ValueSetAnalysis(program: Program, m = m + (localAssign.lhs -> Set(getValueType(bv))) m case None => - Logger.warn("could not evaluate expression" + e) + Logger.debug("could not evaluate expression" + e) m } case memAssign: MemoryAssign => @@ -154,11 +154,11 @@ trait ValueSetAnalysis(program: Program, m = m + (r -> m(v)) m case _ => - Logger.warn(s"Too Complex: $storeValue") // do nothing + Logger.debug(s"Too Complex: $storeValue") // do nothing m } case None => - Logger.warn("could not find region for " + memAssign) + Logger.debug("could not find region for " + memAssign) m case _ => m @@ -207,4 +207,4 @@ class ValueSetAnalysisSolver( case _ => super.funsub(n, x) } } -} \ No newline at end of file +} diff --git a/src/main/scala/analysis/solvers/IDESolver.scala b/src/main/scala/analysis/solvers/IDESolver.scala index ef1d34b54..cef0f6eb2 100644 --- a/src/main/scala/analysis/solvers/IDESolver.scala +++ b/src/main/scala/analysis/solvers/IDESolver.scala @@ -1,7 +1,7 @@ package analysis.solvers import analysis.{BackwardIDEAnalysis, Dependencies, EdgeFunction, EdgeFunctionLattice, ForwardIDEAnalysis, IDEAnalysis, IRInterproceduralBackwardDependencies, IRInterproceduralForwardDependencies, Lambda, Lattice, MapLattice} -import ir.{CFGPosition, Command, DirectCall, GoTo, IRWalk, IndirectCall, InterProcIRCursor, Procedure, Program, end, isAfterCall} +import ir.{CFGPosition, Command, DirectCall, GoTo, IRWalk, IndirectCall, InterProcIRCursor, Procedure, Program, end, isAfterCall, Halt, Statement, Jump} import util.Logger import scala.collection.immutable.Map @@ -12,7 +12,7 @@ import scala.collection.mutable * Adapted from Tip * https://github.com/cs-au-dk/TIP/blob/master/src/tip/solvers/IDESolver.scala */ -abstract class IDESolver[E <: Procedure | Command, EE <: Procedure | Command, C <: DirectCall | GoTo, R <: DirectCall | GoTo, D, T, L <: Lattice[T]](val program: Program, val startNode: CFGPosition) +abstract class IDESolver[E <: Procedure | Command, EE <: Procedure | Command, C <: Command, R <: Command, D, T, L <: Lattice[T]](val program: Program, val startNode: CFGPosition) extends IDEAnalysis[E, EE, C, R, D, T, L], Dependencies[CFGPosition] { protected def entryToExit(entry: E): EE @@ -208,22 +208,25 @@ abstract class IDESolver[E <: Procedure | Command, EE <: Procedure | Command, C abstract class ForwardIDESolver[D, T, L <: Lattice[T]](program: Program) - extends IDESolver[Procedure, IndirectCall, DirectCall, GoTo, D, T, L](program, program.mainProcedure), + extends IDESolver[Procedure, IndirectCall, DirectCall, Command, D, T, L](program, program.mainProcedure), ForwardIDEAnalysis[D, T, L], IRInterproceduralForwardDependencies { protected def entryToExit(entry: Procedure): IndirectCall = entry.end.asInstanceOf[IndirectCall] protected def exitToEntry(exit: IndirectCall): Procedure = IRWalk.procedure(exit) - protected def callToReturn(call: DirectCall): GoTo = call.parent.fallthrough.get + protected def callToReturn(call: DirectCall): Command = call.successor - protected def returnToCall(ret: GoTo): DirectCall = ret.parent.jump.asInstanceOf[DirectCall] + protected def returnToCall(ret: Command): DirectCall = ret match { + case ret: Statement => ret.parent.statements.getPrev(ret).asInstanceOf[DirectCall] + case r: Jump => ret.parent.statements.last.asInstanceOf[DirectCall] + } protected def getCallee(call: DirectCall): Procedure = call.target protected def isCall(call: CFGPosition): Boolean = call match - case directCall: DirectCall if directCall.returnTarget.isDefined && directCall.target.returnBlock.isDefined => true + case directCall: DirectCall if (!directCall.successor.isInstanceOf[Halt]) => true case _ => false protected def isExit(exit: CFGPosition): Boolean = @@ -232,33 +235,32 @@ abstract class ForwardIDESolver[D, T, L <: Lattice[T]](program: Program) case command: Command => IRWalk.procedure(command).end == command case _ => false - protected def getAfterCalls(exit: IndirectCall): Set[GoTo] = - InterProcIRCursor.succ(exit).foreach(s => assert(s.isInstanceOf[GoTo])) - InterProcIRCursor.succ(exit).filter(_.isInstanceOf[GoTo]).map(_.asInstanceOf[GoTo]) + protected def getAfterCalls(exit: IndirectCall): Set[Command] = + InterProcIRCursor.succ(exit).filter(_.isInstanceOf[Command]).map(_.asInstanceOf[Command]) } abstract class BackwardIDESolver[D, T, L <: Lattice[T]](program: Program) - extends IDESolver[IndirectCall, Procedure, GoTo, DirectCall, D, T, L](program, program.mainProcedure.end), + extends IDESolver[IndirectCall, Procedure, Command, DirectCall, D, T, L](program, program.mainProcedure.end), BackwardIDEAnalysis[D, T, L], IRInterproceduralBackwardDependencies { protected def entryToExit(entry: IndirectCall): Procedure = IRWalk.procedure(entry) protected def exitToEntry(exit: Procedure): IndirectCall = exit.end.asInstanceOf[IndirectCall] - protected def callToReturn(call: GoTo): DirectCall = call.parent.jump.asInstanceOf[DirectCall] + protected def callToReturn(call: Command): DirectCall = call match { + case ret: Statement => ret.parent.statements.getPrev(ret).asInstanceOf[DirectCall] + case r: Jump => r.parent.statements.last.asInstanceOf[DirectCall] + } - protected def returnToCall(ret: DirectCall): GoTo = ret.parent.fallthrough.get + protected def returnToCall(ret: DirectCall): Command = ret.successor - protected def getCallee(call: GoTo): IndirectCall = callToReturn(call).target.end.asInstanceOf[IndirectCall] + protected def getCallee(call: Command): IndirectCall = callToReturn(call: Command).target.end.asInstanceOf[IndirectCall] protected def isCall(call: CFGPosition): Boolean = call match - case goto: GoTo if goto.isAfterCall => - goto.parent.jump match - case directCall: DirectCall => directCall.returnTarget.isDefined && directCall.target.returnBlock.isDefined - case _ => false + case directCall: DirectCall => (!directCall.successor.isInstanceOf[Halt]) case _ => false protected def isExit(exit: CFGPosition): Boolean = diff --git a/src/main/scala/ir/IRCursor.scala b/src/main/scala/ir/IRCursor.scala index 08ed26681..4e9cb3948 100644 --- a/src/main/scala/ir/IRCursor.scala +++ b/src/main/scala/ir/IRCursor.scala @@ -52,11 +52,11 @@ object IRWalk: } } -extension (p: Jump) +extension (p: Command) def isAfterCall : Boolean = { p match { - case g: GoTo => g.parent.fallthrough.contains(g) - case _ => false + case g: Jump => g.parent.statements.lastOption.map(_.isInstanceOf[Call]).getOrElse(false) + case g: Statement => g.parent.statements.prevOption(g).map(_.isInstanceOf[Call]).getOrElse(false) } } @@ -82,9 +82,10 @@ trait IntraProcIRCursor extends IRWalk[CFGPosition, CFGPosition] { pos match { case proc: Procedure => proc.entryBlock.toSet case b: Block => Set(b.statements.headOption.getOrElse(b.jump)) - case s: Statement => Set(s.succ().getOrElse(s.parent.jump)) + case s: Statement => Set(s.successor) case n: GoTo => n.targets.asInstanceOf[Set[CFGPosition]] - case c: Call => c.parent.fallthrough.toSet + case h: Halt => Set() + case h: Return => Set() } } @@ -143,43 +144,43 @@ trait InterProcIRCursor extends IRWalk[CFGPosition, CFGPosition] { IntraProcIRCursor.succ(pos) ++ (pos match case c: DirectCall if c.target.blocks.nonEmpty => Set(c.target) - case c: IndirectCall if c.parent.isProcReturn => c.parent.parent.incomingCalls().flatMap(_.parent.fallthrough.toSet).toSet + case c: IndirectCall if c.parent.isProcReturn => c.parent.parent.incomingCalls().map(_.successor).toSet case _ => Set.empty) } final def pred(pos: CFGPosition): Set[CFGPosition] = { IntraProcIRCursor.pred(pos) ++ (pos match + case d: DirectCall if d.target.blocks.nonEmpty => d.target.returnBlock.toSet case c: Procedure => c.incomingCalls().toSet.asInstanceOf[Set[CFGPosition]] - case b: GoTo if b.isAfterCall => b.parent.jump match { - case DirectCall(t,_, _) if t.blocks.nonEmpty => t.returnBlock.toSet - case _ => Set(b) - } case _ => Set.empty) } } -trait InterProcBlockIRCursor extends IRWalk[CFGPosition, Block] { - - final def succ(pos: CFGPosition): Set[Block] = { - IntraProcBlockIRCursor.succ(pos) ++ - (pos match { - case s: DirectCall if s.target.blocks.nonEmpty => s.target.entryBlock.toSet - case b: Block if b.isProcReturn => b.parent.incomingCalls().map(_.parent).toSet - case _ => Set.empty - }) - } +// less meaningful with call statements + +// trait InterProcBlockIRCursor extends IRWalk[CFGPosition, Block] { +// +// final def succ(pos: CFGPosition): Set[Block] = { +// IntraProcBlockIRCursor.succ(pos) ++ +// (pos match { +// case s: DirectCall if s.target.blocks.nonEmpty => s.target.entryBlock.toSet +// case b: Block if b.isProcReturn => b.parent.incomingCalls().map(_.parent).toSet +// case _ => Set.empty +// }) +// } +// +// final def pred(pos: CFGPosition): Set[Block] = { +// IntraProcBlockIRCursor.pred(pos) ++ +// (pos match { +// case b: Block if b.isAfterCall => b.incomingJumps.collect {_.parent.jump match +// case d: DirectCall => d.target }.flatMap(_.returnBlock).toSet +// case b: Block if b.isProcEntry => b.parent.incomingCalls().map(_.parent).toSet +// case _ => Set.empty +// }) +// } +// } - final def pred(pos: CFGPosition): Set[Block] = { - IntraProcBlockIRCursor.pred(pos) ++ - (pos match { - case b: Block if b.isAfterCall => b.incomingJumps.collect {_.parent.jump match - case d: DirectCall => d.target }.flatMap(_.returnBlock).toSet - case b: Block if b.isProcEntry => b.parent.incomingCalls().map(_.parent).toSet - case _ => Set.empty - }) - } -} object InterProcIRCursor extends InterProcIRCursor trait CallGraph extends IRWalk[Procedure, Procedure] { @@ -190,7 +191,7 @@ trait CallGraph extends IRWalk[Procedure, Procedure] { object CallGraph extends CallGraph -object InterProcBlockIRCursor extends InterProcBlockIRCursor +// object InterProcBlockIRCursor extends InterProcBlockIRCursor /** Computes the reachability transitive closure of the CFGPositions in initial under the successor relation defined by * walker. diff --git a/src/main/scala/ir/Interpreter.scala b/src/main/scala/ir/Interpreter.scala index f5437015b..de470d5d9 100644 --- a/src/main/scala/ir/Interpreter.scala +++ b/src/main/scala/ir/Interpreter.scala @@ -11,8 +11,8 @@ class Interpreter() { private val SP: BitVecLiteral = BitVecLiteral(4096 - 16, 64) private val FP: BitVecLiteral = BitVecLiteral(4096 - 16, 64) private val LR: BitVecLiteral = BitVecLiteral(BigInt("FF", 16), 64) - private var nextBlock: Option[Block] = None - private val returnBlock: mutable.Stack[Block] = mutable.Stack() + private var nextCmd: Option[Command] = None + private val returnCmd: mutable.Stack[Command] = mutable.Stack() def eval(exp: Expr, env: mutable.Map[Variable, BitVecLiteral]): BitVecLiteral = { exp match { @@ -220,23 +220,15 @@ class Interpreter() { // Procedure.Block p.entryBlock match { - case Some(block) => nextBlock = Some(block) - case None => nextBlock = Some(returnBlock.pop()) + case Some(block) => nextCmd = Some(block.statements.headOption.getOrElse(block.jump)) + case None => nextCmd = Some(returnCmd.pop()) } } - private def interpretBlock(b: Block): Unit = { - Logger.debug(s"Block:${b.label} ${b.address}") - // Block.Statement - for ((statement, index) <- b.statements.zipWithIndex) { - Logger.debug(s"statement[$index]:") - interpretStatement(statement) - } - - // Block.Jump + private def interpretJump(j: Jump) : Unit = { + Logger.debug(s"jump:") breakable { - Logger.debug(s"jump:") - b.jump match { + j match { case gt: GoTo => Logger.debug(s"$gt") for (g <- gt.targets) { @@ -244,40 +236,28 @@ class Interpreter() { condition match { case Some(e) => evalBool(e, regs) match { case TrueLiteral => - nextBlock = Some(g) + nextCmd = Some(g.statements.headOption.getOrElse(g.jump)) break case _ => } case None => - nextBlock = Some(g) + nextCmd = Some(g.statements.headOption.getOrElse(g.jump)) break } } - case dc: DirectCall => - Logger.debug(s"$dc") - if (dc.returnTarget.isDefined) { - returnBlock.push(dc.returnTarget.get) - } - interpretProcedure(dc.target) - break - case ic: IndirectCall => - Logger.debug(s"$ic") - if (ic.target == Register("R30", 64) && ic.returnTarget.isEmpty) { - if (returnBlock.nonEmpty) { - nextBlock = Some(returnBlock.pop()) - } else { - //Exit Interpreter - nextBlock = None - } - break - } else { - ??? - } + case r: Return => { + nextCmd = Some(returnCmd.pop()) + } + case h: Halt => { + Logger.debug("Halt") + nextCmd = None + } } } } private def interpretStatement(s: Statement): Unit = { + Logger.debug(s"statement[$s]:") s match { case assign: Assign => Logger.debug(s"LocalAssign ${assign.lhs} = ${assign.rhs}") @@ -300,14 +280,42 @@ class Interpreter() { case BitVecLiteral(value, size) => Logger.debug(s"MemoryAssign ${assign.mem} := 0x${value.toString(16)}[u$size]\n") } - case _ : NOP => + case _ : NOP => () case assert: Assert => - Logger.debug(assert) // TODO - + Logger.debug(assert) + evalBool(assert.body, regs) match { + case TrueLiteral => () + case FalseLiteral => throw Exception(s"Assertion failed ${assert}") + } case assume: Assume => - Logger.debug(assume) // TODO, but already taken into effect if it is a branch condition + Logger.debug(assume) + evalBool(assume.body, regs) match { + case TrueLiteral => () + case FalseLiteral => { + nextCmd = None + Logger.debug(s"Assumption not satisfied: $assume") + } + } + case dc: DirectCall => + Logger.debug(s"$dc") + returnCmd.push(dc.successor) + interpretProcedure(dc.target) + break + case ic: IndirectCall => + Logger.debug(s"$ic") + if (ic.target == Register("R30", 64)) { + if (returnCmd.nonEmpty) { + nextCmd = Some(returnCmd.pop()) + } else { + //Exit Interpreter + nextCmd = None + } + break + } else { + ??? + } } } @@ -334,10 +342,13 @@ class Interpreter() { // Program.Procedure interpretProcedure(IRProgram.mainProcedure) - while (nextBlock.isDefined) { - interpretBlock(nextBlock.get) + while (nextCmd.isDefined) { + nextCmd.get match { + case c: Statement => interpretStatement(c) + case c: Jump => interpretJump(c) + } } regs } -} \ No newline at end of file +} diff --git a/src/main/scala/ir/Program.scala b/src/main/scala/ir/Program.scala index 9c87e3d6a..c2d479cca 100644 --- a/src/main/scala/ir/Program.scala +++ b/src/main/scala/ir/Program.scala @@ -141,7 +141,7 @@ class Program(var procedures: ArrayBuffer[Procedure], stack.pushAll(n match { case p: Procedure => p.blocks - case b: Block => Seq() ++ b.statements ++ Seq(b.jump) ++ b.fallthrough.toSet + case b: Block => Seq() ++ b.statements ++ Seq(b.jump) case s: Command => Seq() }) n @@ -291,30 +291,28 @@ class Procedure private ( block } - /** - * Remove blocks with the semantics of replacing them with a noop. The incoming jumps to this are replaced - * with a jump(s) to this blocks jump target(s). If this block ends in a call then only its statements are removed. - * @param blocks the block/blocks to remove - */ - def removeBlocksInline(blocks: Iterable[Block]): Unit = { - for (elem <- blocks) { - elem.jump match { - case g: GoTo => - // rewrite all the jumps to include our jump targets - elem.incomingJumps.foreach(_.removeTarget(elem)) - elem.incomingJumps.foreach(_.addAllTargets(g.targets)) - removeBlocks(elem) - case c: Call => - // just remove statements, keep call - elem.statements.clear() - } - } - } - - - def removeBlocksInline(blocks: Block*): Unit = { - removeBlocksInline(blocks.toSeq) - } +// unused +// /** +// * Remove blocks with the semantics of replacing them with a noop. The incoming jumps to this are replaced +// * with a jump(s) to this blocks jump target(s). If this block ends in a call then only its statements are removed. +// * @param blocks the block/blocks to remove +// */ +// def removeBlocksInline(blocks: Iterable[Block]): Unit = { +// for (elem <- blocks) { +// elem.jump match { +// case g: GoTo => +// // rewrite all the jumps to include our jump targets +// elem.incomingJumps.foreach(_.removeTarget(elem)) +// elem.incomingJumps.foreach(_.addAllTargets(g.targets)) +// removeBlocks(elem) +// } +// } +// } +// +// +// def removeBlocksInline(blocks: Block*): Unit = { +// removeBlocksInline(blocks.toSeq) +// } /** * Remove block(s) and all jumps that target it @@ -382,7 +380,6 @@ class Block private ( val statements: IntrusiveList[Statement], private var _jump: Jump, private val _incomingJumps: mutable.HashSet[GoTo], - var _fallthrough: Option[GoTo], ) extends HasParent[Procedure] { _jump.setParent(this) statements.foreach(_.setParent(this)) @@ -391,23 +388,11 @@ class Block private ( statements.onRemove = x => x.deParent() def this(label: String, address: Option[Int] = None, statements: IterableOnce[Statement] = Set.empty, jump: Jump = GoTo(Set.empty)) = { - this(label, address, IntrusiveList().addAll(statements), jump, mutable.HashSet.empty, None) + this(label, address, IntrusiveList().addAll(statements), jump, mutable.HashSet.empty) } def jump: Jump = _jump - def fallthrough: Option[GoTo] = _fallthrough - - def fallthrough_=(g: Option[GoTo]): Unit = { - /* - * Fallthrough is only set if Jump is a call, this is maintained maintained at the - * linkParent implementation on FallThrough of Call. - */ - _fallthrough.foreach(_.deParent()) - g.foreach(x => x.parent = this) - _fallthrough = g - } - private def jump_=(j: Jump): Unit = { require(!j.hasParent) if (j ne _jump) { @@ -436,7 +421,9 @@ class Block private ( assert(!incomingJumps.contains(g)) } - def calls: Set[Procedure] = _jump.calls + def calls: Set[Procedure] = statements.toSet.collect { + case d: DirectCall => d.target + } def modifies: Set[Global] = statements.flatMap(_.modifies).toSet //def locals: Set[Variable] = statements.flatMap(_.locals).toSet ++ jumps.flatMap(_.locals).toSet @@ -456,10 +443,7 @@ class Block private ( def nextBlocks: Iterable[Block] = { jump match { case c: GoTo => c.targets - case c: Call => fallthrough match { - case Some(x) => x.targets - case _ => Seq() - } + case _ => Seq() } } @@ -506,7 +490,7 @@ class Block private ( object Block { def procedureReturn(from: Procedure): Block = { - Block(from.name + "_basil_return", None, List(), IndirectCall(Register("R30", 64))) + Block(from.name + "_basil_return", None, List(), Return()) } } diff --git a/src/main/scala/ir/Statement.scala b/src/main/scala/ir/Statement.scala index 8d89a9726..74aaa18b5 100644 --- a/src/main/scala/ir/Statement.scala +++ b/src/main/scala/ir/Statement.scala @@ -23,6 +23,9 @@ sealed trait Statement extends Command, IntrusiveListElement[Statement] { def acceptVisit(visitor: Visitor): Statement = throw new Exception( "visitor " + visitor + " unimplemented for: " + this ) + + def successor: Command = parent.statements.nextOption(this).getOrElse(parent.jump) + } // invariant: rhs contains at most one MemoryLoad @@ -76,10 +79,18 @@ object Assume: sealed trait Jump extends Command { def modifies: Set[Global] = Set() //def locals: Set[Variable] = Set() - def calls: Set[Procedure] = Set() def acceptVisit(visitor: Visitor): Jump = throw new Exception("visitor " + visitor + " unimplemented for: " + this) } +class Halt(override val label: Option[String] = None) extends Jump { + override def acceptVisit(visitor: Visitor): Jump = this +} + +class Return(override val label: Option[String] = None) extends Jump { + override def acceptVisit(visitor: Visitor): Jump = this +} + + class GoTo private (private val _targets: mutable.LinkedHashSet[Block], override val label: Option[String]) extends Jump { def this(targets: Iterable[Block], label: Option[String] = None) = this(mutable.LinkedHashSet.from(targets), label) @@ -125,30 +136,18 @@ object GoTo: def unapply(g: GoTo): Option[(Set[Block], Option[String])] = Some(g.targets, g.label) -sealed trait Call extends Jump { - val returnTarget: Option[Block] - - // moving a call between blocks - override def linkParent(p: Block): Unit = { - returnTarget.foreach(t => parent.fallthrough = Some(GoTo(Set(t)))) - } - - override def unlinkParent(): Unit = { - parent.fallthrough = None - } -} +sealed trait Call extends Statement class DirectCall(val target: Procedure, - override val returnTarget: Option[Block] = None, override val label: Option[String] = None ) extends Call { /* override def locals: Set[Variable] = condition match { case Some(c) => c.locals case None => Set() } */ - override def calls: Set[Procedure] = Set(target) - override def toString: String = s"${labelStr}DirectCall(${target.name}, ${returnTarget.map(_.label)})" - override def acceptVisit(visitor: Visitor): Jump = visitor.visitDirectCall(this) + def calls: Set[Procedure] = Set(target) + override def toString: String = s"${labelStr}DirectCall(${target.name})" + override def acceptVisit(visitor: Visitor): Statement = visitor.visitDirectCall(this) override def linkParent(p: Block): Unit = { super.linkParent(p) @@ -163,19 +162,18 @@ class DirectCall(val target: Procedure, } object DirectCall: - def unapply(i: DirectCall): Option[(Procedure, Option[Block], Option[String])] = Some(i.target, i.returnTarget, i.label) + def unapply(i: DirectCall): Option[(Procedure, Option[String])] = Some(i.target, i.label) class IndirectCall(var target: Variable, - override val returnTarget: Option[Block] = None, override val label: Option[String] = None ) extends Call { /* override def locals: Set[Variable] = condition match { case Some(c) => c.locals + target case None => Set(target) } */ - override def toString: String = s"${labelStr}IndirectCall($target, ${returnTarget.map(_.label)})" - override def acceptVisit(visitor: Visitor): Jump = visitor.visitIndirectCall(this) + override def toString: String = s"${labelStr}IndirectCall($target)" + override def acceptVisit(visitor: Visitor): Statement = visitor.visitIndirectCall(this) } object IndirectCall: - def unapply(i: IndirectCall): Option[(Variable, Option[Block], Option[String])] = Some(i.target, i.returnTarget, i.label) \ No newline at end of file + def unapply(i: IndirectCall): Option[(Variable, Option[String])] = Some(i.target, i.label) diff --git a/src/main/scala/ir/Visitor.scala b/src/main/scala/ir/Visitor.scala index b9fd91b3e..1cc9c1b40 100644 --- a/src/main/scala/ir/Visitor.scala +++ b/src/main/scala/ir/Visitor.scala @@ -39,11 +39,11 @@ abstract class Visitor { node } - def visitDirectCall(node: DirectCall): Jump = { + def visitDirectCall(node: DirectCall): Statement = { node } - def visitIndirectCall(node: IndirectCall): Jump = { + def visitIndirectCall(node: IndirectCall): Statement = { node.target = visitVariable(node.target) node } @@ -199,11 +199,11 @@ abstract class ReadOnlyVisitor extends Visitor { node } - override def visitDirectCall(node: DirectCall): Jump = { + override def visitDirectCall(node: DirectCall): Statement = { node } - override def visitIndirectCall(node: IndirectCall): Jump = { + override def visitIndirectCall(node: IndirectCall): Statement = { visitVariable(node.target) node } @@ -281,14 +281,12 @@ abstract class IntraproceduralControlFlowVisitor extends Visitor { node } - override def visitDirectCall(node: DirectCall): Jump = { - node.returnTarget.foreach(visitBlock) + override def visitDirectCall(node: DirectCall): Statement = { node } - override def visitIndirectCall(node: IndirectCall): Jump = { + override def visitIndirectCall(node: IndirectCall): Statement = { node.target = visitVariable(node.target) - node.returnTarget.foreach(visitBlock) node } } @@ -431,20 +429,3 @@ class VariablesWithoutStoresLoads extends ReadOnlyVisitor { } } - -class ConvertToSingleProcedureReturn extends Visitor { - override def visitJump(node: Jump): Jump = { - node match - case c: IndirectCall => - val returnBlock = node.parent.parent.returnBlock match { - case Some(b) => b - case None => - val b = Block.procedureReturn(node.parent.parent) - node.parent.parent.returnBlock = b - b - } - // if we are return outside the return block then replace with a goto to the return block - if c.target.name == "R30" && c.returnTarget.isEmpty && !c.parent.isProcReturn then GoTo(Seq(returnBlock)) else node - case _ => node - } -} diff --git a/src/main/scala/ir/cilvisitor/CILVisitor.scala b/src/main/scala/ir/cilvisitor/CILVisitor.scala index a405d4fa1..5583b12da 100644 --- a/src/main/scala/ir/cilvisitor/CILVisitor.scala +++ b/src/main/scala/ir/cilvisitor/CILVisitor.scala @@ -95,6 +95,11 @@ class CILVisitorImpl(val v: CILVisitor) { def visit_stmt(s: Statement): List[Statement] = { def continue(n: Statement) = n match { + case d: DirectCall => d + case i: IndirectCall => { + i.target = visit_var(i.target) + i + } case m: MemoryAssign => { m.mem = visit_mem(m.mem) m.index = visit_expr(m.index) @@ -131,7 +136,6 @@ class CILVisitorImpl(val v: CILVisitor) { } }) b.replaceJump(visit_jump(b.jump)) - b.fallthrough = visit_fallthrough(b.fallthrough) b } @@ -153,7 +157,7 @@ class CILVisitorImpl(val v: CILVisitor) { doVisitList(v, v.vproc(p), p, continue) } - def visit_proc(p: Program): Program = { + def visit_prog(p: Program): Program = { def continue(p: Program) = { p.procedures = p.procedures.flatMap(visit_proc) p @@ -164,6 +168,7 @@ class CILVisitorImpl(val v: CILVisitor) { def visit_block(v: CILVisitor, b: Block): Block = CILVisitorImpl(v).visit_block(b) def visit_proc(v: CILVisitor, b: Procedure): List[Procedure] = CILVisitorImpl(v).visit_proc(b) +def visit_prog(v: CILVisitor, b: Program): Program = CILVisitorImpl(v).visit_prog(b) def visit_stmt(v: CILVisitor, e: Statement): List[Statement] = CILVisitorImpl(v).visit_stmt(e) def visit_jump(v: CILVisitor, e: Jump): Jump = CILVisitorImpl(v).visit_jump(e) def visit_expr(v: CILVisitor, e: Expr): Expr = CILVisitorImpl(v).visit_expr(e) diff --git a/src/main/scala/ir/dsl/DSL.scala b/src/main/scala/ir/dsl/DSL.scala index 7fdd8bdaa..3c55e2dfc 100644 --- a/src/main/scala/ir/dsl/DSL.scala +++ b/src/main/scala/ir/dsl/DSL.scala @@ -35,24 +35,31 @@ case class DelayNameResolve(ident: String) { } } +trait EventuallyStatement { + def resolve(p: Program): Statement +} + +case class ResolvableStatement(s: Statement) extends EventuallyStatement { + override def resolve(p: Program) = s +} + trait EventuallyJump { def resolve(p: Program): Jump } -case class EventuallyIndirectCall(target: Variable, fallthrough: Option[DelayNameResolve]) extends EventuallyJump { +case class EventuallyIndirectCall(target: Variable, fallthrough: Option[DelayNameResolve]) extends EventuallyStatement { override def resolve(p: Program): IndirectCall = { - IndirectCall(target, fallthrough.flatMap(_.resolveBlock(p))) + IndirectCall(target) } } -case class EventuallyCall(target: DelayNameResolve, fallthrough: Option[DelayNameResolve]) extends EventuallyJump { +case class EventuallyCall(target: DelayNameResolve, fallthrough: Option[DelayNameResolve]) extends EventuallyStatement { override def resolve(p: Program): DirectCall = { val t = target.resolveProc(p) match { case Some(x) => x case None => throw Exception("can't resolve proc " + p) } - val ft = fallthrough.flatMap(_.resolveBlock(p)) - DirectCall(t, ft) + DirectCall(t) } } @@ -79,18 +86,20 @@ def indirectCall(tgt: Variable, fallthrough: Option[String]): EventuallyIndirect // def directcall(tgt: String) = EventuallyCall(DelayNameResolve(tgt), None) -case class EventuallyBlock(label: String, sl: Seq[Statement], j: EventuallyJump) { - val tempBlock: Block = Block(label, None, sl, GoTo(List.empty)) +case class EventuallyBlock(label: String, sl: Seq[EventuallyStatement], j: EventuallyJump) { + val tempBlock: Block = Block(label, None, List(), GoTo(List.empty)) def resolve(prog: Program): Block = { + tempBlock.statements.addAll(sl.map(_.resolve(prog))) tempBlock.replaceJump(j.resolve(prog)) tempBlock } } -def block(label: String, sl: (Statement | EventuallyJump)*): EventuallyBlock = { - val statements = sl.collect { - case s: Statement => s +def block(label: String, sl: (Statement | EventuallyStatement | EventuallyJump)*): EventuallyBlock = { + val statements : Seq[EventuallyStatement] = sl.collect { + case s: Statement => ResolvableStatement(s) + case o: EventuallyStatement => o } val jump = sl.collectFirst { case j: EventuallyJump => j diff --git a/src/main/scala/ir/transforms/ReplaceReturn.scala b/src/main/scala/ir/transforms/ReplaceReturn.scala new file mode 100644 index 000000000..61b276833 --- /dev/null +++ b/src/main/scala/ir/transforms/ReplaceReturn.scala @@ -0,0 +1,52 @@ +package ir.transforms + +import util.Logger +import ir.cilvisitor._ +import ir._ + + +class ReplaceReturns extends CILVisitor { + /** + * Assumes IR with 1 call per block which appears as the last statement. + */ + override def vstmt(j: Statement): VisitAction[List[Statement]] = { + j match { + case IndirectCall(Register("R30", _), _) => { + assert(j.parent.statements.lastOption.contains(j)) + if (j.parent.jump.isInstanceOf[Halt | Return]) { + j.parent.replaceJump(Return()) + ChangeTo(List()) + } else { + SkipChildren() + } + } + case _ => SkipChildren() + } + } + + override def vjump(j: Jump) = SkipChildren() +} + + +def addReturnBlocks(p: Program) = { + p.procedures.foreach(p => { + val containsReturn = p.blocks.map(_.jump).find(_.isInstanceOf[Return]).isDefined + if (containsReturn) { + p.returnBlock = p.addBlocks(Block(label=p.name + "_return",jump=Return())) + } + }) +} + + +class ConvertSingleReturn extends CILVisitor { + /** + * Assumes procedures have defined return blocks if they contain a return statement. + */ + override def vjump(j: Jump) = j match { + case r: Return if !(j.parent.parent.returnBlock.contains(j.parent)) => ChangeTo(GoTo(Seq(j.parent.parent.returnBlock.get))) + case _ => SkipChildren() + } + + override def vstmt(s: Statement) = SkipChildren() +} + diff --git a/src/main/scala/translating/BAPToIR.scala b/src/main/scala/translating/BAPToIR.scala index 85ed82f21..90ba61335 100644 --- a/src/main/scala/translating/BAPToIR.scala +++ b/src/main/scala/translating/BAPToIR.scala @@ -48,8 +48,9 @@ class BAPToIR(var program: BAPProgram, mainAddress: Int) { for (st <- b.statements) { block.statements.append(translate(st)) } - val (jump, newBlocks) = translate(b.jumps, block) + val (call, jump, newBlocks) = translate(b.jumps, block) procedure.addBlocks(newBlocks) + call.foreach(c => block.statements.append(c)) block.replaceJump(jump) assert(jump.hasParent) } @@ -85,7 +86,7 @@ class BAPToIR(var program: BAPProgram, mainAddress: Int) { * Translates a list of jumps from BAP into a single Jump at the IR level by moving any conditions on jumps to * Assume statements in new blocks * */ - private def translate(jumps: List[BAPJump], block: Block): (Jump, ArrayBuffer[Block]) = { + private def translate(jumps: List[BAPJump], block: Block): (Option[Call], Jump, ArrayBuffer[Block]) = { if (jumps.size > 1) { val targets = ArrayBuffer[Block]() val conditions = ArrayBuffer[BAPExpr]() @@ -130,26 +131,28 @@ class BAPToIR(var program: BAPProgram, mainAddress: Int) { case _ => throw Exception("translation error, call where not expected: " + jumps.mkString(", ")) } } - (GoTo(targets, Some(line)), newBlocks) + (None, GoTo(targets, Some(line)), newBlocks) } else { jumps.head match { case b: BAPDirectCall => - val call = DirectCall(nameToProcedure(b.target), b.returnTarget.map(t => labelToBlock(t)), Some(b.line)) - (call, ArrayBuffer()) + val call = Some(DirectCall(nameToProcedure(b.target),Some(b.line))) + val ft = (b.returnTarget.map(t => labelToBlock(t))).map(x => GoTo(Set(x))).getOrElse(Halt()) + (call, ft, ArrayBuffer()) case b: BAPIndirectCall => - val call = IndirectCall(b.target.toIR, b.returnTarget.map(t => labelToBlock(t)), Some(b.line)) - (call, ArrayBuffer()) + val call = IndirectCall(b.target.toIR, Some(b.line)) + val ft = (b.returnTarget.map(t => labelToBlock(t))).map(x => GoTo(Set(x))).getOrElse(Halt()) + (Some(call), ft, ArrayBuffer()) case b: BAPGoTo => val target = labelToBlock(b.target) b.condition match { // condition is true case l: BAPLiteral if l.value > BigInt(0) => - (GoTo(ArrayBuffer(target), Some(b.line)), ArrayBuffer()) + (None, GoTo(ArrayBuffer(target), Some(b.line)), ArrayBuffer()) // non-true condition case _ => val condition = convertConditionBool(b.condition, false) val newBlock = newBlockCondition(block, target, condition) - (GoTo(ArrayBuffer(newBlock), Some(b.line)), ArrayBuffer(newBlock)) + (None, GoTo(ArrayBuffer(newBlock), Some(b.line)), ArrayBuffer(newBlock)) } } } diff --git a/src/main/scala/translating/GTIRBToIR.scala b/src/main/scala/translating/GTIRBToIR.scala index 37f6360ae..eb8281599 100644 --- a/src/main/scala/translating/GTIRBToIR.scala +++ b/src/main/scala/translating/GTIRBToIR.scala @@ -182,12 +182,13 @@ class GTIRBToIR(mods: Seq[Module], parserMap: immutable.Map[String, Array[Array[ throw Exception(s"block ${block.label} in subroutine ${procedure.name} has no outgoing edges") } - val jump = if (outgoingEdges.size == 1) { + val (calls, jump) = if (outgoingEdges.size == 1) { val edge = outgoingEdges.head handleSingleEdge(block, edge, procedure, procedures) } else { handleMultipleEdges(block, outgoingEdges, procedure) } + calls.foreach(c => block.statements.append(c)) block.replaceJump(jump) if (block.statements.nonEmpty) { @@ -363,8 +364,6 @@ class GTIRBToIR(mods: Seq[Module], parserMap: immutable.Map[String, Array[Array[ // need to copy jump as it can't have multiple parents val jumpCopy = currentBlock.jump match { case GoTo(targets, label) => GoTo(targets, label) - case IndirectCall(target, returnTarget, label) => IndirectCall(target, returnTarget, label) - case DirectCall(target, returnTarget, label) => DirectCall(target, returnTarget, label) case _ => throw Exception("this shouldn't be reachable") } trueBlock.replaceJump(currentBlock.jump) @@ -377,7 +376,7 @@ class GTIRBToIR(mods: Seq[Module], parserMap: immutable.Map[String, Array[Array[ } // Handles the case where a block has one outgoing edge using gtirb cfg labelling - private def handleSingleEdge(block: Block, edge: Edge, procedure: Procedure, procedures: ArrayBuffer[Procedure]): Jump = { + private def handleSingleEdge(block: Block, edge: Edge, procedure: Procedure, procedures: ArrayBuffer[Procedure]): (Option[Call], Jump) = { edge.getLabel match { case EdgeLabel(false, false, Type_Branch, _) => // indirect jump, possibly to external subroutine, possibly to another block in procedure @@ -391,7 +390,7 @@ class GTIRBToIR(mods: Seq[Module], parserMap: immutable.Map[String, Array[Array[ case _ => throw Exception(s"no assignment to program counter found before indirect call in block ${block.label}") } block.statements.remove(block.statements.last) // remove _PC assignment - IndirectCall(target, None) + (Some(IndirectCall(target)), Halt()) } else if (proxySymbols.size > 1) { // TODO requires further consideration once encountered throw Exception(s"multiple uuidToSymbol ${proxySymbols.map(_.name).mkString(", ")} associated with proxy block ${byteStringToString(edge.targetUuid)}, target of indirect call from block ${block.label}") @@ -407,14 +406,14 @@ class GTIRBToIR(mods: Seq[Module], parserMap: immutable.Map[String, Array[Array[ proc } removePCAssign(block) - DirectCall(target, None) + (Some(DirectCall(target)), Halt()) } } else if (uuidToBlock.contains(edge.targetUuid)) { // resolved indirect jump // TODO consider possibility this can go to another procedure? val target = uuidToBlock(edge.targetUuid) removePCAssign(block) - GoTo(mutable.Set(target)) + (None, GoTo(mutable.Set(target))) } else { throw Exception(s"edge from ${block.label} to ${byteStringToString(edge.targetUuid)} does not point to a known block or proxy block") } @@ -425,23 +424,23 @@ class GTIRBToIR(mods: Seq[Module], parserMap: immutable.Map[String, Array[Array[ // direct jump to start of own subroutine is treated as GoTo, not DirectCall // should probably investigate recursive cases to determine if this happens/is correct val jump = if (procedure == targetProc) { - GoTo(mutable.Set(uuidToBlock(edge.targetUuid))) + (None, GoTo(mutable.Set(uuidToBlock(edge.targetUuid)))) } else { - DirectCall(targetProc, None) + (Some(DirectCall(targetProc)), Halt()) } removePCAssign(block) jump } else if (uuidToBlock.contains(edge.targetUuid)) { val target = uuidToBlock(edge.targetUuid) removePCAssign(block) - GoTo(mutable.Set(target)) + (None, GoTo(mutable.Set(target))) } else { throw Exception(s"edge from ${block.label} to ${byteStringToString(edge.targetUuid)} does not point to a known block") } case EdgeLabel(false, _, Type_Return, _) => // return statement, value of 'direct' is just whether DDisasm has resolved the return target removePCAssign(block) - IndirectCall(Register("R30", 64), None) + (Some(IndirectCall(Register("R30", 64), None)), Halt()) case EdgeLabel(false, true, Type_Fallthrough, _) => // end of block that doesn't end in a control flow instruction and falls through to next if (entranceUUIDtoProcedure.contains(edge.targetUuid)) { @@ -449,10 +448,10 @@ class GTIRBToIR(mods: Seq[Module], parserMap: immutable.Map[String, Array[Array[ // probably doesn't actually happen in practice since it seems to be after brk instructions? val targetProc = entranceUUIDtoProcedure(edge.targetUuid) // assuming fallthrough won't fall through to start of own procedure - DirectCall(targetProc, None) + (Some(DirectCall(targetProc)), Halt()) } else if (uuidToBlock.contains(edge.targetUuid)) { val target = uuidToBlock(edge.targetUuid) - GoTo(mutable.Set(target)) + (None, GoTo(mutable.Set(target))) } else { throw Exception(s"edge from ${block.label} to ${byteStringToString(edge.targetUuid)} does not point to a known block") } @@ -462,7 +461,7 @@ class GTIRBToIR(mods: Seq[Module], parserMap: immutable.Map[String, Array[Array[ if (entranceUUIDtoProcedure.contains(edge.targetUuid)) { val target = entranceUUIDtoProcedure(edge.targetUuid) removePCAssign(block) - DirectCall(target, None) + (Some(DirectCall(target)), Halt()) } else { throw Exception(s"edge from ${block.label} to ${byteStringToString(edge.targetUuid)} does not point to a known procedure entrance") } @@ -473,14 +472,13 @@ class GTIRBToIR(mods: Seq[Module], parserMap: immutable.Map[String, Array[Array[ } } - def handleMultipleEdges(block: Block, outgoingEdges: mutable.Set[Edge], procedure: Procedure): Jump = { + def handleMultipleEdges(block: Block, outgoingEdges: mutable.Set[Edge], procedure: Procedure): (Option[Call], Jump) = { val edgeLabels = outgoingEdges.map(_.getLabel) if (edgeLabels.forall { (e: EdgeLabel) => !e.conditional && e.direct && e.`type` == Type_Return }) { // multiple resolved returns, translate as single return removePCAssign(block) - IndirectCall(Register("R30", 64), None) - + (None, Return()) } else if (edgeLabels.forall { (e: EdgeLabel) => !e.conditional && !e.direct && e.`type` == Type_Branch }) { // resolved indirect call with multiple blocks as targets val targets = mutable.Set[Block]() @@ -495,7 +493,7 @@ class GTIRBToIR(mods: Seq[Module], parserMap: immutable.Map[String, Array[Array[ } // TODO add assertion that target register is low removePCAssign(block) - GoTo(targets) + (None, GoTo(targets)) // TODO possibility not yet encountered: resolved indirect call that goes to multiple procedures? } else if (outgoingEdges.size == 2) { @@ -519,9 +517,9 @@ class GTIRBToIR(mods: Seq[Module], parserMap: immutable.Map[String, Array[Array[ handleIndirectCallWithReturn(edge1, edge0, block) // conditional branch case (EdgeLabel(true, true, Type_Fallthrough, _), EdgeLabel(true, true, Type_Branch, _)) => - handleConditionalBranch(edge0, edge1, block, procedure) + (None, handleConditionalBranch(edge0, edge1, block, procedure)) case (EdgeLabel(true, true, Type_Branch, _), EdgeLabel(true, true, Type_Fallthrough, _)) => - handleConditionalBranch(edge1, edge0, block, procedure) + (None, handleConditionalBranch(edge1, edge0, block, procedure)) case _ => throw Exception(s"cannot resolve outgoing edges from block ${block.label}") } @@ -542,7 +540,7 @@ class GTIRBToIR(mods: Seq[Module], parserMap: immutable.Map[String, Array[Array[ if (fallthroughs.size != 1 || indirectCallTargets.isEmpty) { throw Exception(s"cannot resolve outgoing edges from block ${block.label}") } - handleIndirectCallMultipleResolvedTargets(fallthroughs.head, indirectCallTargets, block, procedure) + (None, handleIndirectCallMultipleResolvedTargets(fallthroughs.head, indirectCallTargets, block, procedure)) } else { throw Exception(s"cannot resolve outgoing edges from block ${block.label}") } @@ -564,18 +562,18 @@ class GTIRBToIR(mods: Seq[Module], parserMap: immutable.Map[String, Array[Array[ } val target = entranceUUIDtoProcedure(call.targetUuid) - val resolvedCall = DirectCall(target, Some(returnTarget)) + val resolvedCall = DirectCall(target) val assume = Assume(BinaryExpr(BVEQ, targetRegister, BitVecLiteral(target.address.get, 64))) val label = block.label + "$" + target.name - newBlocks.append(Block(label, None, ArrayBuffer(assume), resolvedCall)) + newBlocks.append(Block(label, None, ArrayBuffer(assume, resolvedCall), GoTo(returnTarget))) } removePCAssign(block) procedure.addBlocks(newBlocks) GoTo(newBlocks) } - private def handleIndirectCallWithReturn(fallthrough: Edge, call: Edge, block: Block): Call = { + private def handleIndirectCallWithReturn(fallthrough: Edge, call: Edge, block: Block): (Option[Call], GoTo) = { if (!uuidToBlock.contains(fallthrough.targetUuid)) { throw Exception(s"block ${block.label} has fallthrough edge to ${byteStringToString(fallthrough.targetUuid)} that does not point to a known block") } @@ -586,16 +584,16 @@ class GTIRBToIR(mods: Seq[Module], parserMap: immutable.Map[String, Array[Array[ val target = getPCTarget(block) removePCAssign(block) - IndirectCall(target, Some(returnTarget)) + (Some(IndirectCall(target)), GoTo(Set(returnTarget))) } else { // resolved indirect call val target = entranceUUIDtoProcedure(call.targetUuid) removePCAssign(block) - DirectCall(target, Some(returnTarget)) + (Some(DirectCall(target)), GoTo(Set(returnTarget))) } } - private def handleDirectCallWithReturn(fallthrough: Edge, call: Edge, block: Block): DirectCall = { + private def handleDirectCallWithReturn(fallthrough: Edge, call: Edge, block: Block): (Option[Call], GoTo) = { if (!entranceUUIDtoProcedure.contains(call.targetUuid)) { throw Exception(s"block ${block.label} has direct call edge to ${byteStringToString(call.targetUuid)} that does not point to a known procedure") } @@ -607,7 +605,7 @@ class GTIRBToIR(mods: Seq[Module], parserMap: immutable.Map[String, Array[Array[ val target = entranceUUIDtoProcedure(call.targetUuid) val returnTarget = uuidToBlock(fallthrough.targetUuid) removePCAssign(block) - DirectCall(target, Some(returnTarget)) + (Some(DirectCall(target)), GoTo(Set(returnTarget))) } private def handleConditionalBranch(fallthrough: Edge, branch: Edge, block: Block, procedure: Procedure): GoTo = { diff --git a/src/main/scala/translating/ILtoIL.scala b/src/main/scala/translating/ILtoIL.scala index b34100704..99b64bc6e 100644 --- a/src/main/scala/translating/ILtoIL.scala +++ b/src/main/scala/translating/ILtoIL.scala @@ -74,7 +74,7 @@ private class ILSerialiser extends ReadOnlyVisitor { } - override def visitDirectCall(node: DirectCall): Jump = { + override def visitDirectCall(node: DirectCall): Statement = { program ++= "DirectCall(" program ++= procedureIdentifier(node.target) program ++= ", " @@ -82,7 +82,7 @@ private class ILSerialiser extends ReadOnlyVisitor { node } - override def visitIndirectCall(node: IndirectCall): Jump = { + override def visitIndirectCall(node: IndirectCall): Statement = { program ++= "IndirectCall(" visitVariable(node.target) program ++= ", " diff --git a/src/main/scala/translating/IRToBoogie.scala b/src/main/scala/translating/IRToBoogie.scala index 928517515..18c69a95f 100644 --- a/src/main/scala/translating/IRToBoogie.scala +++ b/src/main/scala/translating/IRToBoogie.scala @@ -632,39 +632,7 @@ class IRToBoogie(var program: Program, var spec: Specification, var thread: Opti } ) } - def translate(j: Jump): List[BCmd] = j match { - case d: DirectCall => - val call = BProcedureCall(d.target.name) - val returnTarget = d.returnTarget match { - case Some(r) => GoToCmd(Seq(r.label)) - case None => BAssume(FalseBLiteral, Some("no return target")) - } - - (config.procedureRely match { - case Some(ProcRelyVersion.Function) => - if (libRelies.contains(d.target.name) && libGuarantees.contains(d.target.name) && libRelies(d.target.name).nonEmpty && libGuarantees(d.target.name).nonEmpty) { - val invCall1 = BProcedureCall(d.target.name + "$inv", List(mem_inv1, Gamma_mem_inv1), List(mem, Gamma_mem)) - val invCall2 = BProcedureCall("rely$inv", List(mem_inv2, Gamma_mem_inv2), List(mem_inv1, Gamma_mem_inv1)) - val libRGAssert = libRelies(d.target.name).map(r => BAssert(r.resolveSpecInv)) - List(invCall1, invCall2) ++ libRGAssert - } else { - List() - } - case Some(ProcRelyVersion.IfCommandContradiction) => relyfun(d.target.name).toList - case None => List() - }) ++ List(call, returnTarget) - case i: IndirectCall => - // TODO put this elsewhere - if (i.target.name == "R30") { - List(ReturnCmd) - } else { - val unresolved: List[BCmd] = List(Comment(s"UNRESOLVED: call ${i.target.name}"), BAssert(FalseBLiteral)) - i.returnTarget match { - case Some(r) => unresolved :+ GoToCmd(Seq(r.label)) - case None => unresolved ++ List(Comment("no return target"), BAssume(FalseBLiteral)) - } - } case g: GoTo => // collects all targets of the goto with a branch condition that we need to check the security level for // and collects the variables for that @@ -681,9 +649,32 @@ class IRToBoogie(var program: Program, var spec: Specification, var thread: Opti } val jump = GoToCmd(g.targets.map(_.label).toSeq) conditionAssert :+ jump + case r: Return => List(ReturnCmd) + case r: Halt => List(BAssert(FalseBLiteral)) + } + + def translate(j: Call): List[BCmd] = j match { + case d: DirectCall => + val call = BProcedureCall(d.target.name) + + (config.procedureRely match { + case Some(ProcRelyVersion.Function) => + if (libRelies.contains(d.target.name) && libGuarantees.contains(d.target.name) && libRelies(d.target.name).nonEmpty && libGuarantees(d.target.name).nonEmpty) { + val invCall1 = BProcedureCall(d.target.name + "$inv", List(mem_inv1, Gamma_mem_inv1), List(mem, Gamma_mem)) + val invCall2 = BProcedureCall("rely$inv", List(mem_inv2, Gamma_mem_inv2), List(mem_inv1, Gamma_mem_inv1)) + val libRGAssert = libRelies(d.target.name).map(r => BAssert(r.resolveSpecInv)) + List(invCall1, invCall2) ++ libRGAssert + } else { + List() + } + case Some(ProcRelyVersion.IfCommandContradiction) => relyfun(d.target.name).toList + case None => List() + }) ++ List(call) + case i: IndirectCall => List(Comment(s"UNRESOLVED: call ${i.target.name}"), BAssert(FalseBLiteral)) } def translate(s: Statement): List[BCmd] = s match { + case d: Call => translate(d) case m: NOP => List.empty case m: MemoryAssign => val lhs = m.mem.toBoogie diff --git a/src/main/scala/util/RunUtils.scala b/src/main/scala/util/RunUtils.scala index d7dba1c42..cbc5cee37 100644 --- a/src/main/scala/util/RunUtils.scala +++ b/src/main/scala/util/RunUtils.scala @@ -29,6 +29,7 @@ import java.util.Base64 import spray.json.DefaultJsonProtocol.* import util.intrusive_list.IntrusiveList import analysis.CfgCommandNode +import cilvisitor._ import scala.annotation.tailrec import scala.collection.mutable @@ -198,11 +199,13 @@ object IRTransform { } val externalRemover = ExternalRemover(externalNamesLibRemoved.toSet) val renamer = Renamer(boogieReserved) - val returnUnifier = ConvertToSingleProcedureReturn() + + cilvisitor.visit_prog(transforms.ReplaceReturns(), ctx.program) + transforms.addReturnBlocks(ctx.program) + cilvisitor.visit_prog(transforms.ConvertSingleReturn(), ctx.program) externalRemover.visitProgram(ctx.program) renamer.visitProgram(ctx.program) - returnUnifier.visitProgram(ctx.program) ctx } @@ -275,8 +278,8 @@ object IRTransform { modified = true // indirectCall.parent.parent.removeBlocks(indirectCall.returnTarget) - val newCall = DirectCall(targets.head, indirectCall.returnTarget, indirectCall.label) - block.replaceJump(newCall) + val newCall = DirectCall(targets.head, indirectCall.label) + block.statements.replace(indirectCall, newCall) } else if (targets.size > 1) { modified = true val procedure = c.parent.data @@ -284,10 +287,14 @@ object IRTransform { for (t <- targets) { val assume = Assume(BinaryExpr(BVEQ, indirectCall.target, BitVecLiteral(t.address.get, 64))) val newLabel: String = block.label + t.name - val directCall = DirectCall(t, indirectCall.returnTarget) + val directCall = DirectCall(t) directCall.parent = indirectCall.parent - newBlocks.append(Block(newLabel, None, ArrayBuffer(assume), directCall)) + // assume indircall is the last statement in block + assert(indirectCall.parent.statements.lastOption.contains(indirectCall)) + val fallthrough = indirectCall.parent.jump + + newBlocks.append(Block(newLabel, None, ArrayBuffer(assume, directCall), fallthrough)) } procedure.addBlocks(newBlocks) val newCall = GoTo(newBlocks, indirectCall.label) @@ -430,8 +437,8 @@ object IRTransform { modified = true // indirectCall.parent.parent.removeBlocks(indirectCall.returnTarget) - val newCall = DirectCall(targets.head, indirectCall.returnTarget, indirectCall.label) - block.replaceJump(newCall) + val newCall = DirectCall(targets.head, indirectCall.label) + block.statements.replace(indirectCall, newCall) } else if (targets.size > 1) { modified = true val procedure = c.parent.data @@ -449,17 +456,20 @@ object IRTransform { addressExprs ::= addressExpr val assume = Assume(addressExpr) val newLabel: String = block.label + t.name - val directCall = DirectCall(t, indirectCall.returnTarget) + val directCall = DirectCall(t) directCall.parent = indirectCall.parent - newBlocks.append(Block(newLabel, None, ArrayBuffer(assume), directCall)) + + assert(indirectCall.parent.statements.lastOption.contains(indirectCall)) + val fallthrough = indirectCall.parent.jump + newBlocks.append(Block(newLabel, None, ArrayBuffer(assume, directCall), fallthrough)) } procedure.addBlocks(newBlocks) val newCall = GoTo(newBlocks, indirectCall.label) val addressExprOr = addressExprs.tail.foldLeft(addressExprs.head) { (a: BinaryExpr, b: BinaryExpr) => BinaryExpr(BoolOR, a, b) } - val assert = Assert(addressExprOr, Some("check indirect call underapproximation")) - block.statements.append(assert) + val assertion = Assert(addressExprOr, Some("check indirect call underapproximation")) + block.statements.append(assertion) block.replaceJump(newCall) } case _ => @@ -501,42 +511,40 @@ object IRTransform { ): Unit = { // iterate over all commands - if call is to pthread_create, look up? - for (p <- program.procedures) { - for (b <- p.blocks) { - b.jump match { - case d: DirectCall if d.target.name == "pthread_create" => - - // R2 should hold the function pointer of the function that begins the thread - // look up R2 value using points to results - val R2 = Register("R2", 64) - val b = reachingDefs(d) - val R2Wrapper = RegisterVariableWrapper(R2, getDefinition(R2, d, reachingDefs)) - val threadTargets = pointsTo(R2Wrapper) - - if (threadTargets.size > 1) { - // currently can't handle case where the thread created is ambiguous - throw Exception("can't handle thread creation with more than one possible target") - } + program.foreach(c => + c match { + case d: DirectCall if d.target.name == "pthread_create" => + + // R2 should hold the function pointer of the function that begins the thread + // look up R2 value using points to results + val R2 = Register("R2", 64) + val b = reachingDefs(d) + val R2Wrapper = RegisterVariableWrapper(R2, getDefinition(R2, d, reachingDefs)) + val threadTargets = pointsTo(R2Wrapper) + + if (threadTargets.size > 1) { + // currently can't handle case where the thread created is ambiguous + throw Exception("can't handle thread creation with more than one possible target") + } - if (threadTargets.size == 1) { + if (threadTargets.size == 1) { - // not trying to untangle the very messy region resolution at present, just dealing with simplest case - threadTargets.head match { - case data: DataRegion => - val threadEntrance = program.procedures.find(_.name == data.regionIdentifier) match { - case Some(proc) => proc - case None => throw Exception("could not find procedure with name " + data.regionIdentifier) - } - val thread = ProgramThread(threadEntrance, mutable.LinkedHashSet(threadEntrance), Some(d)) - program.threads.addOne(thread) - case _ => - throw Exception("unexpected non-data region " + threadTargets.head + " as PointsTo result for R2 at " + d) - } + // not trying to untangle the very messy region resolution at present, just dealing with simplest case + threadTargets.head match { + case data: DataRegion => + val threadEntrance = program.procedures.find(_.name == data.regionIdentifier) match { + case Some(proc) => proc + case None => throw Exception("could not find procedure with name " + data.regionIdentifier) + } + val thread = ProgramThread(threadEntrance, mutable.LinkedHashSet(threadEntrance), Some(d)) + program.threads.addOne(thread) + case _ => + throw Exception("unexpected non-data region " + threadTargets.head + " as PointsTo result for R2 at " + d) } - case _ => - } - } - } + } + case _ => + }) + if (program.threads.nonEmpty) { val mainThread = ProgramThread(program.mainProcedure, mutable.LinkedHashSet(program.mainProcedure), None) From 7e9453693a339a9a928a61b1349a27feecdf741f Mon Sep 17 00:00:00 2001 From: Alistair Michael Date: Mon, 19 Aug 2024 11:56:49 +1000 Subject: [PATCH 21/62] move transforms out of RunUtils.scala --- src/main/scala/analysis/Cfg.scala | 4 +- src/main/scala/analysis/IDEAnalysis.scala | 6 +- .../analysis/InterLiveVarsAnalysis.scala | 10 +- .../analysis/IntraLiveVarsAnalysis.scala | 4 +- src/main/scala/analysis/VSA.scala | 2 +- .../scala/analysis/solvers/IDESolver.scala | 40 +- src/main/scala/ir/IRCursor.scala | 171 ++++----- src/main/scala/ir/Program.scala | 38 +- src/main/scala/ir/Statement.scala | 8 +- src/main/scala/ir/dsl/DSL.scala | 39 +- .../ir/invariant/EarlyCallStatement.scala | 15 + .../transforms/IndirectCallResolution.scala | 295 +++++++++++++++ .../scala/ir/transforms/ReplaceReturn.scala | 10 +- .../scala/ir/transforms/SplitThreads.scala | 84 +++++ src/main/scala/translating/GTIRBToIR.scala | 4 +- src/main/scala/translating/IRToBoogie.scala | 2 +- src/main/scala/util/RunUtils.scala | 356 +----------------- src/test/scala/IndirectCallsTests.scala | 87 +++-- src/test/scala/LiveVarsAnalysisTests.scala | 129 ++++--- src/test/scala/PointsToTest.scala | 24 +- src/test/scala/ir/IRTest.scala | 90 ++--- src/test/scala/ir/SingleCallInvariant.scala | 83 ++++ 22 files changed, 829 insertions(+), 672 deletions(-) create mode 100644 src/main/scala/ir/invariant/EarlyCallStatement.scala create mode 100644 src/main/scala/ir/transforms/IndirectCallResolution.scala create mode 100644 src/main/scala/ir/transforms/SplitThreads.scala create mode 100644 src/test/scala/ir/SingleCallInvariant.scala diff --git a/src/main/scala/analysis/Cfg.scala b/src/main/scala/analysis/Cfg.scala index 59d512cd2..8807f2d2a 100644 --- a/src/main/scala/analysis/Cfg.scala +++ b/src/main/scala/analysis/Cfg.scala @@ -502,7 +502,7 @@ class ProgramCfgFactory: val targetProc: Procedure = dCall.target funcEntryNode.callers.add(procToCfg(targetProc)._1) - val callNode = CfgJumpNode(dCall, block, funcEntryNode) + val callNode : CfgJumpNode = s.asInstanceOf[CfgJumpNode] // Branch to this call cfg.addEdge(precNode, callNode) @@ -523,7 +523,7 @@ class ProgramCfgFactory: Logger.debug(s"Indirect call found: $iCall in ${proc.name}") var precNode = prevNode - val jmpNode = CfgJumpNode(iCall, block, funcEntryNode) + val jmpNode = s.asInstanceOf[CfgJumpNode] // Branch to this call cfg.addEdge(precNode, jmpNode) diff --git a/src/main/scala/analysis/IDEAnalysis.scala b/src/main/scala/analysis/IDEAnalysis.scala index e77627717..c7ce74559 100644 --- a/src/main/scala/analysis/IDEAnalysis.scala +++ b/src/main/scala/analysis/IDEAnalysis.scala @@ -1,6 +1,6 @@ package analysis -import ir.{CFGPosition, Command, DirectCall, GoTo, IndirectCall, Procedure, Program} +import ir.{CFGPosition, Command, DirectCall, GoTo, Return, IndirectCall, Procedure, Program} final case class Lambda() @@ -55,6 +55,6 @@ trait IDEAnalysis[E, EE, C, R, D, T, L <: Lattice[T]] { } // IndirectCall in these is because they are returns so that can be further tightened in future -trait ForwardIDEAnalysis[D, T, L <: Lattice[T]] extends IDEAnalysis[Procedure, IndirectCall, DirectCall, Command, D, T, L] +trait ForwardIDEAnalysis[D, T, L <: Lattice[T]] extends IDEAnalysis[Procedure, Return, DirectCall, Command, D, T, L] -trait BackwardIDEAnalysis[D, T, L <: Lattice[T]] extends IDEAnalysis[IndirectCall, Procedure, Command, DirectCall, D, T, L] +trait BackwardIDEAnalysis[D, T, L <: Lattice[T]] extends IDEAnalysis[Return, Procedure, Command, DirectCall, D, T, L] diff --git a/src/main/scala/analysis/InterLiveVarsAnalysis.scala b/src/main/scala/analysis/InterLiveVarsAnalysis.scala index b20edf478..7a93266e8 100644 --- a/src/main/scala/analysis/InterLiveVarsAnalysis.scala +++ b/src/main/scala/analysis/InterLiveVarsAnalysis.scala @@ -1,7 +1,7 @@ package analysis import analysis.solvers.BackwardIDESolver -import ir.{Assert, Assume, GoTo, CFGPosition, Command, DirectCall, IndirectCall, Assign, MemoryAssign, Procedure, Program, Variable, toShortString} +import ir.{Assert, Assume, Block, GoTo, CFGPosition, Command, DirectCall, IndirectCall, Assign, MemoryAssign, Halt, Return, Procedure, Program, Variable, toShortString} /** * Micro-transfer-functions for LiveVar analysis @@ -19,7 +19,7 @@ trait LiveVarsAnalysisFunctions extends BackwardIDEAnalysis[Variable, TwoElement val edgelattice: EdgeFunctionLattice[TwoElement, TwoElementLattice] = EdgeFunctionLattice(valuelattice) import edgelattice.{IdEdge, ConstEdge} - def edgesCallToEntry(call: Command, entry: IndirectCall)(d: DL): Map[DL, EdgeFunction[TwoElement]] = { + def edgesCallToEntry(call: Command, entry: Return)(d: DL): Map[DL, EdgeFunction[TwoElement]] = { Map(d -> IdEdge()) } @@ -74,7 +74,11 @@ trait LiveVarsAnalysisFunctions extends BackwardIDEAnalysis[Variable, TwoElement d match case Left(value) => if value != variable then Map(d -> IdEdge()) else Map() case Right(_) => Map(d -> IdEdge(), Left(variable) -> ConstEdge(TwoElementTop)) - case _ => Map(d -> IdEdge()) + case r: Return => Map(d -> IdEdge()) + case h: Halt => Map(d -> IdEdge()) + case c: DirectCall => Map(d -> IdEdge()) + case c: Block => Map(d -> IdEdge()) + case c: GoTo => Map(d -> IdEdge()) } } diff --git a/src/main/scala/analysis/IntraLiveVarsAnalysis.scala b/src/main/scala/analysis/IntraLiveVarsAnalysis.scala index f3f322321..75fa1dbb0 100644 --- a/src/main/scala/analysis/IntraLiveVarsAnalysis.scala +++ b/src/main/scala/analysis/IntraLiveVarsAnalysis.scala @@ -1,7 +1,7 @@ package analysis import analysis.solvers.SimpleWorklistFixpointSolver -import ir.{Assert, Assume, Block, CFGPosition, Call, DirectCall, GoTo, IndirectCall, Jump, Assign, MemoryAssign, NOP, Procedure, Program, Statement, Variable} +import ir.{Assert, Assume, Block, CFGPosition, Call, DirectCall, GoTo, IndirectCall, Jump, Assign, MemoryAssign, NOP, Procedure, Program, Statement, Variable, Return, Halt} abstract class LivenessAnalysis(program: Program) extends Analysis[Any]: val lattice: MapLattice[CFGPosition, Set[Variable], PowersetLattice[Variable]] = MapLattice(PowersetLattice()) @@ -18,6 +18,8 @@ abstract class LivenessAnalysis(program: Program) extends Analysis[Any]: case IndirectCall(variable, _) => s + variable case c: DirectCall => s case g: GoTo => s + case r: Return => s + case r: Halt => s case _ => ??? } } diff --git a/src/main/scala/analysis/VSA.scala b/src/main/scala/analysis/VSA.scala index f7d9a55e7..734f9bf11 100644 --- a/src/main/scala/analysis/VSA.scala +++ b/src/main/scala/analysis/VSA.scala @@ -172,7 +172,7 @@ trait ValueSetAnalysis(program: Program, if (IRWalk.procedure(n) == n) { mmm.pushContext(n.asInstanceOf[Procedure].name) s - } else if (IRWalk.procedure(n).end == n) { + } else if (IRWalk.lastInProc(IRWalk.procedure(n)) == n) { mmm.popContext() s } else n match diff --git a/src/main/scala/analysis/solvers/IDESolver.scala b/src/main/scala/analysis/solvers/IDESolver.scala index cef0f6eb2..231c87521 100644 --- a/src/main/scala/analysis/solvers/IDESolver.scala +++ b/src/main/scala/analysis/solvers/IDESolver.scala @@ -1,7 +1,7 @@ package analysis.solvers import analysis.{BackwardIDEAnalysis, Dependencies, EdgeFunction, EdgeFunctionLattice, ForwardIDEAnalysis, IDEAnalysis, IRInterproceduralBackwardDependencies, IRInterproceduralForwardDependencies, Lambda, Lattice, MapLattice} -import ir.{CFGPosition, Command, DirectCall, GoTo, IRWalk, IndirectCall, InterProcIRCursor, Procedure, Program, end, isAfterCall, Halt, Statement, Jump} +import ir.{CFGPosition, Command, DirectCall, GoTo, IRWalk, IndirectCall, Return, InterProcIRCursor, Procedure, Program, isAfterCall, Halt, Statement, Jump} import util.Logger import scala.collection.immutable.Map @@ -208,10 +208,10 @@ abstract class IDESolver[E <: Procedure | Command, EE <: Procedure | Command, C abstract class ForwardIDESolver[D, T, L <: Lattice[T]](program: Program) - extends IDESolver[Procedure, IndirectCall, DirectCall, Command, D, T, L](program, program.mainProcedure), + extends IDESolver[Procedure, Return, DirectCall, Command, D, T, L](program, program.mainProcedure), ForwardIDEAnalysis[D, T, L], IRInterproceduralForwardDependencies { - protected def entryToExit(entry: Procedure): IndirectCall = entry.end.asInstanceOf[IndirectCall] + protected def entryToExit(entry: Procedure): Return = IRWalk.lastInProc(entry).asInstanceOf[Return] protected def exitToEntry(exit: IndirectCall): Procedure = IRWalk.procedure(exit) @@ -222,7 +222,10 @@ abstract class ForwardIDESolver[D, T, L <: Lattice[T]](program: Program) case r: Jump => ret.parent.statements.last.asInstanceOf[DirectCall] } - protected def getCallee(call: DirectCall): Procedure = call.target + protected def getCallee(call: DirectCall): Procedure = { + require(isCall(call)) + call.target + } protected def isCall(call: CFGPosition): Boolean = call match @@ -232,41 +235,46 @@ abstract class ForwardIDESolver[D, T, L <: Lattice[T]](program: Program) protected def isExit(exit: CFGPosition): Boolean = exit match // only looking at functions with statements - case command: Command => IRWalk.procedure(command).end == command + case command: Command => IRWalk.lastInProc(IRWalk.procedure(command)) == command case _ => false protected def getAfterCalls(exit: IndirectCall): Set[Command] = InterProcIRCursor.succ(exit).filter(_.isInstanceOf[Command]).map(_.asInstanceOf[Command]) - } abstract class BackwardIDESolver[D, T, L <: Lattice[T]](program: Program) - extends IDESolver[IndirectCall, Procedure, Command, DirectCall, D, T, L](program, program.mainProcedure.end), + extends IDESolver[Return, Procedure, Command, DirectCall, D, T, L](program, IRWalk.lastInProc(program.mainProcedure)), BackwardIDEAnalysis[D, T, L], IRInterproceduralBackwardDependencies { - protected def entryToExit(entry: IndirectCall): Procedure = IRWalk.procedure(entry) + protected def entryToExit(entry: Return): Procedure = IRWalk.procedure(entry) - protected def exitToEntry(exit: Procedure): IndirectCall = exit.end.asInstanceOf[IndirectCall] + protected def exitToEntry(exit: Procedure): Return = exit.returnBlock.get.jump.asInstanceOf[Return] - protected def callToReturn(call: Command): DirectCall = call match { - case ret: Statement => ret.parent.statements.getPrev(ret).asInstanceOf[DirectCall] - case r: Jump => r.parent.statements.last.asInstanceOf[DirectCall] + protected def callToReturn(call: Command): DirectCall = { + IRWalk.prevCommandInBlock(call) match { + case Some(x : DirectCall) => x + case p => throw Exception(s"Not a return/aftercall node $call .... prev = $p") + } } protected def returnToCall(ret: DirectCall): Command = ret.successor - protected def getCallee(call: Command): IndirectCall = callToReturn(call: Command).target.end.asInstanceOf[IndirectCall] + protected def getCallee(call: Command): Return = { + require(isCall(call)) + val procCalled = callToReturn(call).target + procCalled.returnBlock.getOrElse(throw Exception(s"No return node for procedure ${procCalled}")).jump.asInstanceOf[Return] + } protected def isCall(call: CFGPosition): Boolean = call match - case directCall: DirectCall => (!directCall.successor.isInstanceOf[Halt]) + case c : Command => isAfterCall(c) && IRWalk.prevCommandInBlock(c).map(_.isInstanceOf[DirectCall]).getOrElse(false) case _ => false protected def isExit(exit: CFGPosition): Boolean = exit match - case procedure: Procedure => procedure.blocks.nonEmpty + case procedure: Procedure => true case _ => false - protected def getAfterCalls(exit: Procedure): Set[DirectCall] = InterProcIRCursor.pred(exit).filter(_.isInstanceOf[DirectCall]).map(_.asInstanceOf[DirectCall]) + protected def getAfterCalls(exit: Procedure): Set[DirectCall] = exit.incomingCalls().toSet } diff --git a/src/main/scala/ir/IRCursor.scala b/src/main/scala/ir/IRCursor.scala index 4e9cb3948..52c585aaf 100644 --- a/src/main/scala/ir/IRCursor.scala +++ b/src/main/scala/ir/IRCursor.scala @@ -14,12 +14,19 @@ import scala.annotation.tailrec */ type CFGPosition = Procedure | Block | Command +def isAfterCall(c: Command) = { + (IRWalk.prevCommandInBlock(c)) match { + case Some(c: Call) => true + case _ => false + } +} + extension (p: CFGPosition) def toShortString: String = p match case procedure: Procedure => procedure.toString - case block: Block => s"Block ${block.label}" - case command: Command => command.toString + case block: Block => s"Block ${block.label}" + case command: Command => command.toString // todo: we could just use the dependencies trait directly instead to avoid the instantiation issue trait IRWalk[IN <: CFGPosition, NT <: CFGPosition & IN] { @@ -28,75 +35,79 @@ trait IRWalk[IN <: CFGPosition, NT <: CFGPosition & IN] { } object IRWalk: - def procedure(pos: CFGPosition) : Procedure = { + + def prevCommandInBlock(c: Command): Option[Command] = c match { + case s: Statement => c.parent.statements.prevOption(s) + case j: Jump => c.parent.statements.lastOption + } + + def nextCommandInBlock(c: Command): Option[Command] = c match { + case s: Statement => Some(s.successor) + case j: Jump => None + } + + def procedure(pos: CFGPosition): Procedure = { pos match { case p: Procedure => p - case b: Block => b.parent - case c: Command => c.parent.parent + case b: Block => b.parent + case c: Command => c.parent.parent } } - def blockBegin(pos: CFGPosition) : Option[Block] = { + def blockBegin(pos: CFGPosition): Option[Block] = { pos match { case p: Procedure => p.entryBlock - case b: Block => Some(b) - case c: Command => Some(c.parent) + case b: Block => Some(b) + case c: Command => Some(c.parent) } } - def commandBegin(pos: CFGPosition) : Option[Command] = { + def commandBegin(pos: CFGPosition): Option[Command] = { pos match { case p: Procedure => p.entryBlock.map(b => b.statements.headOption.getOrElse(b.jump)) - case b: Block => Some(b.statements.headOption.getOrElse(b.jump)) - case c: Command => Some(c) - } - } - -extension (p: Command) - def isAfterCall : Boolean = { - p match { - case g: Jump => g.parent.statements.lastOption.map(_.isInstanceOf[Call]).getOrElse(false) - case g: Statement => g.parent.statements.prevOption(g).map(_.isInstanceOf[Call]).getOrElse(false) + case b: Block => Some(b.statements.headOption.getOrElse(b.jump)) + case c: Command => Some(c) } } -extension (p: Block) - def isProcEntry : Boolean = p.parent.entryBlock.contains(p) - def isProcReturn : Boolean = p.parent.returnBlock.contains(p) - // TODO: this method doesn't require aftercall blocks only have 1 incoming jump - def isAfterCall : Boolean = p.incomingJumps.nonEmpty && p.incomingJumps.forall(_.isAfterCall) + def lastInBlock(p: Block): Command = p.jump + def firstInBlock(p: Block): Command = p.statements.headOption.getOrElse(p.jump) - def begin: CFGPosition = p - def end: CFGPosition = p.jump + def firstInProc(p: Procedure): Command = firstInBlock(p.entryBlock.get) + def lastInProc(p: Procedure): Command = lastInBlock(p.returnBlock.get) -extension (p: Procedure) - def begin: CFGPosition = p - def end: CFGPosition = p.returnBlock.map(_.end).getOrElse(p) +// extension (p: Block) +// def isProcEntry: Boolean = p.parent.entryBlock.contains(p) +// def isProcReturn: Boolean = p.parent.returnBlock.contains(p) +// +// def begin: CFGPosition = p +// def end: CFGPosition = p.jump +// +// extension (p: Procedure) +// def begin: CFGPosition = p +// def end: CFGPosition = p.returnBlock.map(_.end).getOrElse(p) -/** - * Does not include edges between procedures. - */ +/** Does not include edges between procedures. + */ trait IntraProcIRCursor extends IRWalk[CFGPosition, CFGPosition] { def succ(pos: CFGPosition): Set[CFGPosition] = { pos match { case proc: Procedure => proc.entryBlock.toSet - case b: Block => Set(b.statements.headOption.getOrElse(b.jump)) - case s: Statement => Set(s.successor) + case b: Block => b.statements.headOption.orElse(Some(b.jump)).toSet case n: GoTo => n.targets.asInstanceOf[Set[CFGPosition]] case h: Halt => Set() case h: Return => Set() + case c: Statement => IRWalk.nextCommandInBlock(c).toSet } } def pred(pos: CFGPosition): Set[CFGPosition] = { pos match { - case s: Statement => Set(s.pred().getOrElse(s.parent)) - case j: GoTo if j.isAfterCall => Set(j.parent.jump) - case j: Jump => Set(j.parent.statements.lastOption.getOrElse(j.parent)) - case b: Block if b.isProcEntry => Set(b.parent) - case b: Block => b.incomingJumps.asInstanceOf[Set[CFGPosition]] - case proc: Procedure => Set() // intraproc + case c: Command => Set(IRWalk.prevCommandInBlock(c).getOrElse(c.parent)) + case b: Block if b.isEntry => Set(b.parent) + case b: Block => b.incomingJumps.asInstanceOf[Set[CFGPosition]] + case proc: Procedure => Set() // intraproc } } } @@ -117,70 +128,44 @@ trait IntraProcBlockIRCursor extends IRWalk[CFGPosition, Block] { @tailrec final def pred(pos: CFGPosition): Set[Block] = { pos match { - case b: Block if b.isProcEntry => Set.empty - case b: Block => b.incomingJumps.map(_.parent).toSet - case j: Command => pred(j.parent) - case s: Procedure => Set.empty + case b: Block if b.isEntry => Set.empty + case b: Block => b.incomingJumps.map(_.parent).toSet + case j: Command => pred(j.parent) + case s: Procedure => Set.empty } } } object IntraProcBlockIRCursor extends IntraProcBlockIRCursor -/** - * Includes all intraproc edges as well as edges between procedures. - * - * forwards: - * Direct call -> target - * return indirect Call -> the procedure return block for all possible direct-call sites - * - * backwards: - * Procedure -> all possible direct-call sites - * Call-return block -> return Call of the procedure called - * - */ +/** Includes all intraproc edges as well as edges between procedures. + * + * forwards: Direct call -> target return indirect Call -> the procedure return block for all possible direct-call + * sites + * + * backwards: Procedure -> all possible direct-call sites Call-return block -> return Call of the procedure called + */ trait InterProcIRCursor extends IRWalk[CFGPosition, CFGPosition] { final def succ(pos: CFGPosition): Set[CFGPosition] = { - IntraProcIRCursor.succ(pos) ++ - (pos match - case c: DirectCall if c.target.blocks.nonEmpty => Set(c.target) - case c: IndirectCall if c.parent.isProcReturn => c.parent.parent.incomingCalls().map(_.successor).toSet - case _ => Set.empty) + IntraProcIRCursor.succ(pos) ++ + (pos match + case c: DirectCall if c.target.blocks.nonEmpty => Set(c.target) + // case c: IndirectCall if c.parent.isProcReturn => c.parent.parent.incomingCalls().map(_.successor).toSet + case c: Return => c.parent.parent.incomingCalls().map(_.successor).toSet + case _ => Set.empty + ) } final def pred(pos: CFGPosition): Set[CFGPosition] = { IntraProcIRCursor.pred(pos) ++ - (pos match - case d: DirectCall if d.target.blocks.nonEmpty => d.target.returnBlock.toSet - case c: Procedure => c.incomingCalls().toSet.asInstanceOf[Set[CFGPosition]] - case _ => Set.empty) + (pos match + case d: DirectCall if d.target.blocks.nonEmpty => d.target.returnBlock.toSet + case c: Procedure => c.incomingCalls().toSet.asInstanceOf[Set[CFGPosition]] + case _ => Set.empty + ) } } -// less meaningful with call statements - -// trait InterProcBlockIRCursor extends IRWalk[CFGPosition, Block] { -// -// final def succ(pos: CFGPosition): Set[Block] = { -// IntraProcBlockIRCursor.succ(pos) ++ -// (pos match { -// case s: DirectCall if s.target.blocks.nonEmpty => s.target.entryBlock.toSet -// case b: Block if b.isProcReturn => b.parent.incomingCalls().map(_.parent).toSet -// case _ => Set.empty -// }) -// } -// -// final def pred(pos: CFGPosition): Set[Block] = { -// IntraProcBlockIRCursor.pred(pos) ++ -// (pos match { -// case b: Block if b.isAfterCall => b.incomingJumps.collect {_.parent.jump match -// case d: DirectCall => d.target }.flatMap(_.returnBlock).toSet -// case b: Block if b.isProcEntry => b.parent.incomingCalls().map(_.parent).toSet -// case _ => Set.empty -// }) -// } -// } - object InterProcIRCursor extends InterProcIRCursor trait CallGraph extends IRWalk[Procedure, Procedure] { @@ -316,17 +301,17 @@ def toDot[T <: CFGPosition]( def getArrow(s: CFGPosition, n: CFGPosition) = { if (IRWalk.procedure(n) eq IRWalk.procedure(s)) { - DotRegularArrow(dotNodes(s),dotNodes(n)) + DotRegularArrow(dotNodes(s), dotNodes(n)) } else { - DotInterArrow(dotNodes(s),dotNodes(n)) + DotInterArrow(dotNodes(s), dotNodes(n)) } } for (node <- domain) { node match { case s => - iterator.succ(s).foreach(n => dotArrows.addOne(getArrow(s,n))) - // iterator.pred(s).foreach(n => dotArrows.addOne(getArrow(s,n))) + iterator.succ(s).foreach(n => dotArrows.addOne(getArrow(s, n))) + // iterator.pred(s).foreach(n => dotArrows.addOne(getArrow(s,n))) } } diff --git a/src/main/scala/ir/Program.scala b/src/main/scala/ir/Program.scala index c2d479cca..5911be3a8 100644 --- a/src/main/scala/ir/Program.scala +++ b/src/main/scala/ir/Program.scala @@ -5,6 +5,7 @@ import scala.collection.{IterableOnceExtensionMethods, View, immutable, mutable} import boogie.* import analysis.BitVectorEval import util.intrusive_list.* +import translating.serialiseIL class Program(var procedures: ArrayBuffer[Procedure], var mainProcedure: Procedure, @@ -13,6 +14,10 @@ class Program(var procedures: ArrayBuffer[Procedure], val threads: ArrayBuffer[ProgramThread] = ArrayBuffer() + override def toString(): String = { + serialiseIL(this) + } + // This shouldn't be run before indirect calls are resolved def stripUnreachableFunctions(depth: Int = Int.MaxValue): Unit = { val procedureCalleeNames = procedures.map(f => f.name -> f.calls.map(_.name)).toMap @@ -141,7 +146,7 @@ class Program(var procedures: ArrayBuffer[Procedure], stack.pushAll(n match { case p: Procedure => p.blocks - case b: Block => Seq() ++ b.statements ++ Seq(b.jump) + case b: Block => Seq() ++ b.statements.toSeq ++ Seq(b.jump) case s: Command => Seq() }) n @@ -211,7 +216,7 @@ class Procedure private ( def returnBlock_=(value: Block): Unit = { if (!returnBlock.contains(value)) { - removeBlocks(_returnBlock) + _returnBlock.foreach(removeBlocks(_)) _returnBlock = Some(addBlocks(value)) } } @@ -220,7 +225,7 @@ class Procedure private ( def entryBlock_=(value: Block): Unit = { if (!entryBlock.contains(value)) { - removeBlocks(_entryBlock) + _entryBlock.foreach(removeBlocks(_)) _entryBlock = Some(addBlocks(value)) } } @@ -230,9 +235,6 @@ class Procedure private ( if (!_blocks.contains(block)) { block.parent = this _blocks.add(block) - if (entryBlock.isEmpty) { - entryBlock = block - } } block } @@ -291,28 +293,6 @@ class Procedure private ( block } -// unused -// /** -// * Remove blocks with the semantics of replacing them with a noop. The incoming jumps to this are replaced -// * with a jump(s) to this blocks jump target(s). If this block ends in a call then only its statements are removed. -// * @param blocks the block/blocks to remove -// */ -// def removeBlocksInline(blocks: Iterable[Block]): Unit = { -// for (elem <- blocks) { -// elem.jump match { -// case g: GoTo => -// // rewrite all the jumps to include our jump targets -// elem.incomingJumps.foreach(_.removeTarget(elem)) -// elem.incomingJumps.foreach(_.addAllTargets(g.targets)) -// removeBlocks(elem) -// } -// } -// } -// -// -// def removeBlocksInline(blocks: Block*): Unit = { -// removeBlocksInline(blocks.toSeq) -// } /** * Remove block(s) and all jumps that target it @@ -391,6 +371,8 @@ class Block private ( this(label, address, IntrusiveList().addAll(statements), jump, mutable.HashSet.empty) } + def isEntry: Boolean = parent.entryBlock.contains(this) + def jump: Jump = _jump private def jump_=(j: Jump): Unit = { diff --git a/src/main/scala/ir/Statement.scala b/src/main/scala/ir/Statement.scala index 74aaa18b5..2dea68f46 100644 --- a/src/main/scala/ir/Statement.scala +++ b/src/main/scala/ir/Statement.scala @@ -83,6 +83,7 @@ sealed trait Jump extends Command { } class Halt(override val label: Option[String] = None) extends Jump { + /* Terminate / No successors / assume false */ override def acceptVisit(visitor: Visitor): Jump = this } @@ -136,7 +137,12 @@ object GoTo: def unapply(g: GoTo): Option[(Set[Block], Option[String])] = Some(g.targets, g.label) -sealed trait Call extends Statement +sealed trait Call extends Statement { + def returnTarget: Option[Command] = successor match { + case h: Halt => None + case o => Some(o) + } +} class DirectCall(val target: Procedure, override val label: Option[String] = None diff --git a/src/main/scala/ir/dsl/DSL.scala b/src/main/scala/ir/dsl/DSL.scala index 3c55e2dfc..6a1b96742 100644 --- a/src/main/scala/ir/dsl/DSL.scala +++ b/src/main/scala/ir/dsl/DSL.scala @@ -14,7 +14,7 @@ val R7: Register = Register("R7", 64) val R29: Register = Register("R29", 64) val R30: Register = Register("R30", 64) val R31: Register = Register("R31", 64) -val ret: EventuallyIndirectCall = EventuallyIndirectCall(Register("R30", 64), None) + def bv32(i: Int): BitVecLiteral = BitVecLiteral(i, 32) @@ -40,21 +40,21 @@ trait EventuallyStatement { } case class ResolvableStatement(s: Statement) extends EventuallyStatement { - override def resolve(p: Program) = s + override def resolve(p: Program) : Statement = s } trait EventuallyJump { def resolve(p: Program): Jump } -case class EventuallyIndirectCall(target: Variable, fallthrough: Option[DelayNameResolve]) extends EventuallyStatement { - override def resolve(p: Program): IndirectCall = { +case class EventuallyIndirectCall(target: Variable) extends EventuallyStatement { + override def resolve(p: Program): Statement = { IndirectCall(target) } } -case class EventuallyCall(target: DelayNameResolve, fallthrough: Option[DelayNameResolve]) extends EventuallyStatement { - override def resolve(p: Program): DirectCall = { +case class EventuallyCall(target: DelayNameResolve) extends EventuallyStatement { + override def resolve(p: Program): Statement = { val t = target.resolveProc(p) match { case Some(x) => x case None => throw Exception("can't resolve proc " + p) @@ -63,12 +63,19 @@ case class EventuallyCall(target: DelayNameResolve, fallthrough: Option[DelayNam } } + case class EventuallyGoto(targets: List[DelayNameResolve]) extends EventuallyJump { override def resolve(p: Program): GoTo = { val tgs = targets.flatMap(tn => tn.resolveBlock(p)) GoTo(tgs) } } +case class EventuallyReturn() extends EventuallyJump { + override def resolve(p: Program) = Return() +} +case class EventuallyHalt() extends EventuallyJump { + override def resolve(p: Program) = Halt() +} def goto(): EventuallyGoto = EventuallyGoto(List.empty) @@ -76,13 +83,16 @@ def goto(targets: String*): EventuallyGoto = { EventuallyGoto(targets.map(p => DelayNameResolve(p)).toList) } +def ret: EventuallyReturn = EventuallyReturn() +def halt: EventuallyHalt= EventuallyHalt() + def goto(targets: List[String]): EventuallyGoto = { EventuallyGoto(targets.map(p => DelayNameResolve(p))) } -def directCall(tgt: String, fallthrough: Option[String]): EventuallyCall = EventuallyCall(DelayNameResolve(tgt), fallthrough.map(x => DelayNameResolve(x))) +def directCall(tgt: String): EventuallyCall = EventuallyCall(DelayNameResolve(tgt)) -def indirectCall(tgt: Variable, fallthrough: Option[String]): EventuallyIndirectCall = EventuallyIndirectCall(tgt, fallthrough.map(x => DelayNameResolve(x))) +def indirectCall(tgt: Variable): EventuallyIndirectCall = EventuallyIndirectCall(tgt) // def directcall(tgt: String) = EventuallyCall(DelayNameResolve(tgt), None) @@ -90,16 +100,20 @@ case class EventuallyBlock(label: String, sl: Seq[EventuallyStatement], j: Event val tempBlock: Block = Block(label, None, List(), GoTo(List.empty)) def resolve(prog: Program): Block = { - tempBlock.statements.addAll(sl.map(_.resolve(prog))) + val resolved = sl.map(_.resolve(prog)) + tempBlock.statements.addAll(resolved) tempBlock.replaceJump(j.resolve(prog)) tempBlock } } def block(label: String, sl: (Statement | EventuallyStatement | EventuallyJump)*): EventuallyBlock = { - val statements : Seq[EventuallyStatement] = sl.collect { - case s: Statement => ResolvableStatement(s) - case o: EventuallyStatement => o + val statements : Seq[EventuallyStatement] = sl.flatMap { + case s: Statement => Some(ResolvableStatement(s)) + case o: EventuallyStatement => Some(o) + case o: EventuallyCall => Some(o) + case o: EventuallyIndirectCall => Some(o) + case g: EventuallyJump => None } val jump = sl.collectFirst { case j: EventuallyJump => j @@ -113,6 +127,7 @@ case class EventuallyProcedure(label: String, blocks: Seq[EventuallyBlock]) { val jumps: Map[Block, EventuallyJump] = blocks.map(b => b.tempBlock -> b.j).toMap def resolve(prog: Program): Procedure = { + blocks.foreach(b => b.resolve(prog)) jumps.map((b, j) => b.replaceJump(j.resolve(prog))) tempProc } diff --git a/src/main/scala/ir/invariant/EarlyCallStatement.scala b/src/main/scala/ir/invariant/EarlyCallStatement.scala new file mode 100644 index 000000000..bb4504343 --- /dev/null +++ b/src/main/scala/ir/invariant/EarlyCallStatement.scala @@ -0,0 +1,15 @@ +package ir.invariant +import ir._ + + +def singleCallBlockEnd(p: Program) : Boolean = { + p.forall { + case b: Block => { + val calls = (b.statements.collect { + case c: Call => b.statements.lastOption.contains(c) + }) + (calls.size <= 1) && calls.headOption.getOrElse(true) + } + case _ => true + } +} diff --git a/src/main/scala/ir/transforms/IndirectCallResolution.scala b/src/main/scala/ir/transforms/IndirectCallResolution.scala new file mode 100644 index 000000000..e9c9fddf7 --- /dev/null +++ b/src/main/scala/ir/transforms/IndirectCallResolution.scala @@ -0,0 +1,295 @@ +package ir.transforms + + + +import scala.collection.mutable.ListBuffer +import scala.collection.mutable.ArrayBuffer +import analysis.solvers.* +import analysis.* +import bap.* +import ir.* +import translating.* +import util.Logger +import util.intrusive_list.IntrusiveList +import analysis.CfgCommandNode +import scala.collection.mutable +import cilvisitor._ + + +/** Resolve indirect calls to an address-conditional choice between direct calls using the Value Set Analysis results. + * Dead code, and currently broken by statement calls + * +def resolveIndirectCalls( + cfg: ProgramCfg, + valueSets: Map[CfgNode, LiftedElement[Map[Variable | MemoryRegion, Set[Value]]]], + IRProgram: Program +): Boolean = { + var modified: Boolean = false + val worklist = ListBuffer[CfgNode]() + cfg.startNode.succIntra.union(cfg.startNode.succInter).foreach(node => worklist.addOne(node)) + + val visited = mutable.Set[CfgNode]() + while (worklist.nonEmpty) { + val node = worklist.remove(0) + if (!visited.contains(node)) { + process(node) + node.succIntra.union(node.succInter).foreach(node => worklist.addOne(node)) + visited.add(node) + } + } + + def process(n: CfgNode): Unit = n match { + /* + case c: CfgStatementNode => + c.data match + + //We do not want to insert the VSA results into the IR like this + case localAssign: Assign => + localAssign.rhs match + case _: MemoryLoad => + if (valueSets(n).contains(localAssign.lhs) && valueSets(n).get(localAssign.lhs).head.size == 1) { + val extractedValue = extractExprFromValue(valueSets(n).get(localAssign.lhs).head.head) + localAssign.rhs = extractedValue + Logger.info(s"RESOLVED: Memory load ${localAssign.lhs} resolved to ${extractedValue}") + } else if (valueSets(n).contains(localAssign.lhs) && valueSets(n).get(localAssign.lhs).head.size > 1) { + Logger.info(s"RESOLVED: WARN Memory load ${localAssign.lhs} resolved to multiple values, cannot replace") + + /* + // must merge into a single memory variable to represent the possible values + // Make a binary OR of all the possible values takes two at a time (incorrect to do BVOR) + val values = valueSets(n).get(localAssign.lhs).head + val exprValues = values.map(extractExprFromValue) + val result = exprValues.reduce((a, b) => BinaryExpr(BVOR, a, b)) // need to express nondeterministic + // choice between these specific options + localAssign.rhs = result + */ + } + case _ => + */ + case c: CfgJumpNode => + val block = c.block + c.data match + case indirectCall: IndirectCall => + if (block.jump != indirectCall) { + // We only replace the calls with DirectCalls in the IR, and don't replace the CommandNode.data + // Hence if we have already processed this CFG node there will be no corresponding IndirectCall in the IR + // to replace. + // We want to replace all possible indirect calls based on this CFG, before regenerating it from the IR + return + } + valueSets(n) match { + case Lift(valueSet) => + val targetNames = resolveAddresses(valueSet(indirectCall.target)).map(_.name).toList.sorted + val targets = targetNames.map(name => IRProgram.procedures.filter(_.name.equals(name)).head) + + if (targets.size == 1) { + modified = true + + // indirectCall.parent.parent.removeBlocks(indirectCall.returnTarget) + val newCall = DirectCall(targets.head, indirectCall.label) + block.statements.replace(indirectCall, newCall) + } else if (targets.size > 1) { + modified = true + val procedure = c.parent.data + val newBlocks = ArrayBuffer[Block]() + for (t <- targets) { + val assume = Assume(BinaryExpr(BVEQ, indirectCall.target, BitVecLiteral(t.address.get, 64))) + val newLabel: String = block.label + t.name + val directCall = DirectCall(t) + directCall.parent = indirectCall.parent + + // assume indircall is the last statement in block + assert(indirectCall.parent.statements.lastOption.contains(indirectCall)) + val fallthrough = indirectCall.parent.jump + + newBlocks.append(Block(newLabel, None, ArrayBuffer(assume, directCall), fallthrough)) + } + procedure.addBlocks(newBlocks) + val newCall = GoTo(newBlocks, indirectCall.label) + block.replaceJump(newCall) + } + case LiftedBottom => + } + case _ => + case _ => + } + + def nameExists(name: String): Boolean = { + IRProgram.procedures.exists(_.name.equals(name)) + } + + def addFakeProcedure(name: String): Unit = { + IRProgram.procedures += Procedure(name) + } + + def resolveAddresses(valueSet: Set[Value]): Set[AddressValue] = { + var functionNames: Set[AddressValue] = Set() + valueSet.foreach { + case globalAddress: GlobalAddress => + if (nameExists(globalAddress.name)) { + functionNames += globalAddress + Logger.info(s"RESOLVED: Call to Global address ${globalAddress.name} rt statuesolved.") + } else { + addFakeProcedure(globalAddress.name) + functionNames += globalAddress + Logger.info(s"Global address ${globalAddress.name} does not exist in the program. Added a fake function.") + } + case localAddress: LocalAddress => + if (nameExists(localAddress.name)) { + functionNames += localAddress + Logger.info(s"RESOLVED: Call to Local address ${localAddress.name}") + } else { + addFakeProcedure(localAddress.name) + functionNames += localAddress + Logger.info(s"Local address ${localAddress.name} does not exist in the program. Added a fake function.") + } + case _ => + } + functionNames + } + + modified +} + + */ + +def resolveIndirectCallsUsingPointsTo( + cfg: ProgramCfg, + pointsTos: Map[RegisterVariableWrapper, Set[RegisterVariableWrapper | MemoryRegion]], + regionContents: Map[MemoryRegion, Set[BitVecLiteral | MemoryRegion]], + reachingDefs: Map[CFGPosition, (Map[Variable, Set[Assign]], Map[Variable, Set[Assign]])], + IRProgram: Program + ): Boolean = { + var modified: Boolean = false + val worklist = ListBuffer[CfgNode]() + cfg.startNode.succIntra.union(cfg.startNode.succInter).foreach(node => worklist.addOne(node)) + + val visited = mutable.Set[CfgNode]() + while (worklist.nonEmpty) { + val node = worklist.remove(0) + if (!visited.contains(node)) { + process(node) + node.succIntra.union(node.succInter).foreach(node => worklist.addOne(node)) + visited.add(node) + } + } + + def searchRegion(region: MemoryRegion): mutable.Set[String] = { + val result = mutable.Set[String]() + region match { + case stackRegion: StackRegion => + if (regionContents.contains(stackRegion)) { + for (c <- regionContents(stackRegion)) { + c match { + case bitVecLiteral: BitVecLiteral => Logger.debug("hi: " + bitVecLiteral)//??? + case memoryRegion: MemoryRegion => + result.addAll(searchRegion(memoryRegion)) + } + } + } + result + case dataRegion: DataRegion => + if (!regionContents.contains(dataRegion) || regionContents(dataRegion).isEmpty) { + result.add(dataRegion.regionIdentifier) + } else { + result.add(dataRegion.regionIdentifier) // TODO: may need to investigate if we should add the parent region + for (c <- regionContents(dataRegion)) { + c match { + case bitVecLiteral: BitVecLiteral => Logger.debug("hi: " + bitVecLiteral)//??? + case memoryRegion: MemoryRegion => + result.addAll(searchRegion(memoryRegion)) + } + } + } + result + } + } + + def addFakeProcedure(name: String): Procedure = { + val newProcedure = Procedure(name) + IRProgram.procedures += newProcedure + newProcedure + } + + def resolveAddresses(variable: Variable, i: IndirectCall): mutable.Set[String] = { + val names = mutable.Set[String]() + val variableWrapper = RegisterVariableWrapper(variable, getUse(variable, i, reachingDefs)) + pointsTos.get(variableWrapper) match { + case Some(value) => + value.map { + case v: RegisterVariableWrapper => names.addAll(resolveAddresses(v.variable, i)) + case m: MemoryRegion => names.addAll(searchRegion(m)) + } + names + case None => names + } + } + + def process(n: CfgNode): Unit = n match { + case c: CfgJumpNode => + val block = c.block + c.data match + // don't try to resolve returns + case indirectCall: IndirectCall if indirectCall.target != Register("R30", 64) => + if (!indirectCall.hasParent) { + // We only replace the calls with DirectCalls in the IR, and don't replace the CommandNode.data + // Hence if we have already processed this CFG node there will be no corresponding IndirectCall in the IR + // to replace. + // We want to replace all possible indirect calls based on this CFG, before regenerating it from the IR + return + } + assert(indirectCall.parent.statements.lastOption.contains(indirectCall)) + + val targetNames = resolveAddresses(indirectCall.target, indirectCall) + Logger.debug(s"Points-To approximated call ${indirectCall.target} with $targetNames") + Logger.debug(IRProgram.procedures) + val targets: mutable.Set[Procedure] = targetNames.map(name => IRProgram.procedures.find(_.name == name).getOrElse(addFakeProcedure(name))) + + if (targets.size > 1) { + Logger.info(s"Resolved indirect call $indirectCall") + } + + + if (targets.size == 1) { + modified = true + + // indirectCall.parent.parent.removeBlocks(indirectCall.returnTarget) + val newCall = DirectCall(targets.head, indirectCall.label) + block.statements.replace(indirectCall, newCall) + } else if (targets.size > 1) { + + val oft = indirectCall.parent.jump + + modified = true + val procedure = c.parent.data + val newBlocks = ArrayBuffer[Block]() + // indirectCall.parent.parent.removeBlocks(indirectCall.returnTarget) + for (t <- targets) { + Logger.debug(targets) + val address = t.address.match { + case Some(a) => a + case None => throw Exception(s"resolved indirect call $indirectCall to procedure which does not have address: $t") + } + val assume = Assume(BinaryExpr(BVEQ, indirectCall.target, BitVecLiteral(address, 64))) + val newLabel: String = block.label + t.name + val directCall = DirectCall(t) + + /* copy the goto node resulting */ + val fallthrough = oft match { + case g: GoTo => GoTo(g.targets, g.label) + case h: Halt => Halt() + case r: Return => Return() + } + newBlocks.append(Block(newLabel, None, ArrayBuffer(assume, directCall), fallthrough)) + } + block.statements.remove(indirectCall) + procedure.addBlocks(newBlocks) + val newCall = GoTo(newBlocks, indirectCall.label) + block.replaceJump(newCall) + } + case _ => + case _ => + } + + modified +} diff --git a/src/main/scala/ir/transforms/ReplaceReturn.scala b/src/main/scala/ir/transforms/ReplaceReturn.scala index 61b276833..91681ab8e 100644 --- a/src/main/scala/ir/transforms/ReplaceReturn.scala +++ b/src/main/scala/ir/transforms/ReplaceReturn.scala @@ -28,11 +28,15 @@ class ReplaceReturns extends CILVisitor { } -def addReturnBlocks(p: Program) = { +def addReturnBlocks(p: Program, toAll: Boolean = false) = { p.procedures.foreach(p => { val containsReturn = p.blocks.map(_.jump).find(_.isInstanceOf[Return]).isDefined - if (containsReturn) { - p.returnBlock = p.addBlocks(Block(label=p.name + "_return",jump=Return())) + if (toAll && p.blocks.isEmpty && p.entryBlock.isEmpty && p.returnBlock.isEmpty) { + Logger.info(s"proc ${p.name} ${p.entryBlock}, ${p.returnBlock}") + p.returnBlock = (Block(label=p.name + "_basil_return",jump=Return())) + p.entryBlock = (Block(label=p.name + "_basil_entry",jump=GoTo(p.returnBlock.get))) + } else if (p.returnBlock.isEmpty && (toAll || containsReturn)) { + p.returnBlock = p.addBlocks(Block(label=p.name + "_basil_return",jump=Return())) } }) } diff --git a/src/main/scala/ir/transforms/SplitThreads.scala b/src/main/scala/ir/transforms/SplitThreads.scala new file mode 100644 index 000000000..7f36fdb3c --- /dev/null +++ b/src/main/scala/ir/transforms/SplitThreads.scala @@ -0,0 +1,84 @@ +package ir.transforms + +import scala.collection.mutable.ListBuffer +import scala.collection.mutable.ArrayBuffer +import analysis.solvers.* +import analysis.* +import bap.* +import ir.* +import translating.* +import util.Logger +import java.util.Base64 +import spray.json.DefaultJsonProtocol.* +import util.intrusive_list.IntrusiveList +import analysis.CfgCommandNode +import scala.collection.mutable +import cilvisitor._ + +// identify calls to pthread_create +// use analysis result to determine the third parameter's value (the function pointer) +// split off that procedure into new thread +// do reachability analysis +// also need a bit in the IR where it creates separate files +def splitThreads(program: Program, + pointsTo: Map[RegisterVariableWrapper, Set[RegisterVariableWrapper | MemoryRegion]], + regionContents: Map[MemoryRegion, Set[BitVecLiteral | MemoryRegion]], + reachingDefs: Map[CFGPosition, (Map[Variable, Set[Assign]], Map[Variable, Set[Assign]])] + ): Unit = { + + // iterate over all commands - if call is to pthread_create, look up? + program.foreach(c => + c match { + case d: DirectCall if d.target.name == "pthread_create" => + + // R2 should hold the function pointer of the function that begins the thread + // look up R2 value using points to results + val R2 = Register("R2", 64) + val b = reachingDefs(d) + val R2Wrapper = RegisterVariableWrapper(R2, getDefinition(R2, d, reachingDefs)) + val threadTargets = pointsTo(R2Wrapper) + + if (threadTargets.size > 1) { + // currently can't handle case where the thread created is ambiguous + throw Exception("can't handle thread creation with more than one possible target") + } + + if (threadTargets.size == 1) { + + // not trying to untangle the very messy region resolution at present, just dealing with simplest case + threadTargets.head match { + case data: DataRegion => + val threadEntrance = program.procedures.find(_.name == data.regionIdentifier) match { + case Some(proc) => proc + case None => throw Exception("could not find procedure with name " + data.regionIdentifier) + } + val thread = ProgramThread(threadEntrance, mutable.LinkedHashSet(threadEntrance), Some(d)) + program.threads.addOne(thread) + case _ => + throw Exception("unexpected non-data region " + threadTargets.head + " as PointsTo result for R2 at " + d) + } + } + case _ => + }) + + + if (program.threads.nonEmpty) { + val mainThread = ProgramThread(program.mainProcedure, mutable.LinkedHashSet(program.mainProcedure), None) + program.threads.addOne(mainThread) + + val programProcs = program.procedures + + // do reachability for all threads + for (thread <- program.threads) { + val reachable = thread.entry.reachableFrom + + // add procedures to thread in way that maintains original ordering + for (p <- programProcs) { + if (reachable.contains(p)) { + thread.procedures.add(p) + } + } + + } + } +} diff --git a/src/main/scala/translating/GTIRBToIR.scala b/src/main/scala/translating/GTIRBToIR.scala index eb8281599..d47bd0b35 100644 --- a/src/main/scala/translating/GTIRBToIR.scala +++ b/src/main/scala/translating/GTIRBToIR.scala @@ -364,6 +364,8 @@ class GTIRBToIR(mods: Seq[Module], parserMap: immutable.Map[String, Array[Array[ // need to copy jump as it can't have multiple parents val jumpCopy = currentBlock.jump match { case GoTo(targets, label) => GoTo(targets, label) + case h: Halt => Halt() + case r: Return => Return() case _ => throw Exception("this shouldn't be reachable") } trueBlock.replaceJump(currentBlock.jump) @@ -440,7 +442,7 @@ class GTIRBToIR(mods: Seq[Module], parserMap: immutable.Map[String, Array[Array[ case EdgeLabel(false, _, Type_Return, _) => // return statement, value of 'direct' is just whether DDisasm has resolved the return target removePCAssign(block) - (Some(IndirectCall(Register("R30", 64), None)), Halt()) + (None, Return()) case EdgeLabel(false, true, Type_Fallthrough, _) => // end of block that doesn't end in a control flow instruction and falls through to next if (entranceUUIDtoProcedure.contains(edge.targetUuid)) { diff --git a/src/main/scala/translating/IRToBoogie.scala b/src/main/scala/translating/IRToBoogie.scala index 18c69a95f..3422c6daa 100644 --- a/src/main/scala/translating/IRToBoogie.scala +++ b/src/main/scala/translating/IRToBoogie.scala @@ -650,7 +650,7 @@ class IRToBoogie(var program: Program, var spec: Specification, var thread: Opti val jump = GoToCmd(g.targets.map(_.label).toSeq) conditionAssert :+ jump case r: Return => List(ReturnCmd) - case r: Halt => List(BAssert(FalseBLiteral)) + case r: Halt => List(BAssume(FalseBLiteral)) } def translate(j: Call): List[BCmd] = j match { diff --git a/src/main/scala/util/RunUtils.scala b/src/main/scala/util/RunUtils.scala index cbc5cee37..8c8e85e66 100644 --- a/src/main/scala/util/RunUtils.scala +++ b/src/main/scala/util/RunUtils.scala @@ -201,7 +201,7 @@ object IRTransform { val renamer = Renamer(boogieReserved) cilvisitor.visit_prog(transforms.ReplaceReturns(), ctx.program) - transforms.addReturnBlocks(ctx.program) + transforms.addReturnBlocks(ctx.program, true) // add return to all blocks because IDE solver expects it cilvisitor.visit_prog(transforms.ConvertSingleReturn(), ctx.program) externalRemover.visitProgram(ctx.program) @@ -209,276 +209,6 @@ object IRTransform { ctx } - /** Resolve indirect calls to an address-conditional choice between direct calls using the Value Set Analysis results. - */ - def resolveIndirectCalls( - cfg: ProgramCfg, - valueSets: Map[CfgNode, LiftedElement[Map[Variable | MemoryRegion, Set[Value]]]], - IRProgram: Program - ): Boolean = { - var modified: Boolean = false - val worklist = ListBuffer[CfgNode]() - cfg.startNode.succIntra.union(cfg.startNode.succInter).foreach(node => worklist.addOne(node)) - - val visited = mutable.Set[CfgNode]() - while (worklist.nonEmpty) { - val node = worklist.remove(0) - if (!visited.contains(node)) { - process(node) - node.succIntra.union(node.succInter).foreach(node => worklist.addOne(node)) - visited.add(node) - } - } - - def process(n: CfgNode): Unit = n match { - /* - case c: CfgStatementNode => - c.data match - - //We do not want to insert the VSA results into the IR like this - case localAssign: Assign => - localAssign.rhs match - case _: MemoryLoad => - if (valueSets(n).contains(localAssign.lhs) && valueSets(n).get(localAssign.lhs).head.size == 1) { - val extractedValue = extractExprFromValue(valueSets(n).get(localAssign.lhs).head.head) - localAssign.rhs = extractedValue - Logger.info(s"RESOLVED: Memory load ${localAssign.lhs} resolved to ${extractedValue}") - } else if (valueSets(n).contains(localAssign.lhs) && valueSets(n).get(localAssign.lhs).head.size > 1) { - Logger.info(s"RESOLVED: WARN Memory load ${localAssign.lhs} resolved to multiple values, cannot replace") - - /* - // must merge into a single memory variable to represent the possible values - // Make a binary OR of all the possible values takes two at a time (incorrect to do BVOR) - val values = valueSets(n).get(localAssign.lhs).head - val exprValues = values.map(extractExprFromValue) - val result = exprValues.reduce((a, b) => BinaryExpr(BVOR, a, b)) // need to express nondeterministic - // choice between these specific options - localAssign.rhs = result - */ - } - case _ => - */ - case c: CfgJumpNode => - val block = c.block - c.data match - case indirectCall: IndirectCall => - if (block.jump != indirectCall) { - // We only replace the calls with DirectCalls in the IR, and don't replace the CommandNode.data - // Hence if we have already processed this CFG node there will be no corresponding IndirectCall in the IR - // to replace. - // We want to replace all possible indirect calls based on this CFG, before regenerating it from the IR - return - } - valueSets(n) match { - case Lift(valueSet) => - val targetNames = resolveAddresses(valueSet(indirectCall.target)).map(_.name).toList.sorted - val targets = targetNames.map(name => IRProgram.procedures.filter(_.name.equals(name)).head) - - if (targets.size == 1) { - modified = true - - // indirectCall.parent.parent.removeBlocks(indirectCall.returnTarget) - val newCall = DirectCall(targets.head, indirectCall.label) - block.statements.replace(indirectCall, newCall) - } else if (targets.size > 1) { - modified = true - val procedure = c.parent.data - val newBlocks = ArrayBuffer[Block]() - for (t <- targets) { - val assume = Assume(BinaryExpr(BVEQ, indirectCall.target, BitVecLiteral(t.address.get, 64))) - val newLabel: String = block.label + t.name - val directCall = DirectCall(t) - directCall.parent = indirectCall.parent - - // assume indircall is the last statement in block - assert(indirectCall.parent.statements.lastOption.contains(indirectCall)) - val fallthrough = indirectCall.parent.jump - - newBlocks.append(Block(newLabel, None, ArrayBuffer(assume, directCall), fallthrough)) - } - procedure.addBlocks(newBlocks) - val newCall = GoTo(newBlocks, indirectCall.label) - block.replaceJump(newCall) - } - case LiftedBottom => - } - case _ => - case _ => - } - - def nameExists(name: String): Boolean = { - IRProgram.procedures.exists(_.name.equals(name)) - } - - def addFakeProcedure(name: String): Unit = { - IRProgram.procedures += Procedure(name) - } - - def resolveAddresses(valueSet: Set[Value]): Set[AddressValue] = { - var functionNames: Set[AddressValue] = Set() - valueSet.foreach { - case globalAddress: GlobalAddress => - if (nameExists(globalAddress.name)) { - functionNames += globalAddress - Logger.info(s"RESOLVED: Call to Global address ${globalAddress.name} rt statuesolved.") - } else { - addFakeProcedure(globalAddress.name) - functionNames += globalAddress - Logger.info(s"Global address ${globalAddress.name} does not exist in the program. Added a fake function.") - } - case localAddress: LocalAddress => - if (nameExists(localAddress.name)) { - functionNames += localAddress - Logger.info(s"RESOLVED: Call to Local address ${localAddress.name}") - } else { - addFakeProcedure(localAddress.name) - functionNames += localAddress - Logger.info(s"Local address ${localAddress.name} does not exist in the program. Added a fake function.") - } - case _ => - } - functionNames - } - - modified - } - - def resolveIndirectCallsUsingPointsTo( - cfg: ProgramCfg, - pointsTos: Map[RegisterVariableWrapper, Set[RegisterVariableWrapper | MemoryRegion]], - regionContents: Map[MemoryRegion, Set[BitVecLiteral | MemoryRegion]], - reachingDefs: Map[CFGPosition, (Map[Variable, Set[Assign]], Map[Variable, Set[Assign]])], - IRProgram: Program - ): Boolean = { - var modified: Boolean = false - val worklist = ListBuffer[CfgNode]() - cfg.startNode.succIntra.union(cfg.startNode.succInter).foreach(node => worklist.addOne(node)) - - val visited = mutable.Set[CfgNode]() - while (worklist.nonEmpty) { - val node = worklist.remove(0) - if (!visited.contains(node)) { - process(node) - node.succIntra.union(node.succInter).foreach(node => worklist.addOne(node)) - visited.add(node) - } - } - - def searchRegion(region: MemoryRegion): mutable.Set[String] = { - val result = mutable.Set[String]() - region match { - case stackRegion: StackRegion => - if (regionContents.contains(stackRegion)) { - for (c <- regionContents(stackRegion)) { - c match { - case bitVecLiteral: BitVecLiteral => Logger.debug("hi: " + bitVecLiteral)//??? - case memoryRegion: MemoryRegion => - result.addAll(searchRegion(memoryRegion)) - } - } - } - result - case dataRegion: DataRegion => - if (!regionContents.contains(dataRegion) || regionContents(dataRegion).isEmpty) { - result.add(dataRegion.regionIdentifier) - } else { - result.add(dataRegion.regionIdentifier) // TODO: may need to investigate if we should add the parent region - for (c <- regionContents(dataRegion)) { - c match { - case bitVecLiteral: BitVecLiteral => Logger.debug("hi: " + bitVecLiteral)//??? - case memoryRegion: MemoryRegion => - result.addAll(searchRegion(memoryRegion)) - } - } - } - result - } - } - - def addFakeProcedure(name: String): Procedure = { - val newProcedure = Procedure(name) - IRProgram.procedures += newProcedure - newProcedure - } - - def resolveAddresses(variable: Variable, i: IndirectCall): mutable.Set[String] = { - val names = mutable.Set[String]() - val variableWrapper = RegisterVariableWrapper(variable, getUse(variable, i, reachingDefs)) - pointsTos.get(variableWrapper) match { - case Some(value) => - value.map { - case v: RegisterVariableWrapper => names.addAll(resolveAddresses(v.variable, i)) - case m: MemoryRegion => names.addAll(searchRegion(m)) - } - names - case None => names - } - } - - def process(n: CfgNode): Unit = n match { - case c: CfgJumpNode => - val block = c.block - c.data match - // don't try to resolve returns - case indirectCall: IndirectCall if indirectCall.target != Register("R30", 64) => - if (block.jump != indirectCall) { - // We only replace the calls with DirectCalls in the IR, and don't replace the CommandNode.data - // Hence if we have already processed this CFG node there will be no corresponding IndirectCall in the IR - // to replace. - // We want to replace all possible indirect calls based on this CFG, before regenerating it from the IR - return - } - val targetNames = resolveAddresses(indirectCall.target, indirectCall) - Logger.debug(s"Points-To approximated call ${indirectCall.target} with $targetNames") - Logger.debug(IRProgram.procedures) - val targets: mutable.Set[Procedure] = targetNames.map(name => IRProgram.procedures.find(_.name == name).getOrElse(addFakeProcedure(name))) - - if (targets.size == 1) { - modified = true - - // indirectCall.parent.parent.removeBlocks(indirectCall.returnTarget) - val newCall = DirectCall(targets.head, indirectCall.label) - block.statements.replace(indirectCall, newCall) - } else if (targets.size > 1) { - modified = true - val procedure = c.parent.data - val newBlocks = ArrayBuffer[Block]() - // indirectCall.parent.parent.removeBlocks(indirectCall.returnTarget) - var addressExprs = List[BinaryExpr]() - for (t <- targets) { - Logger.debug(targets) - // TODO handle external procedures without a set address but this requires more information than the analysis gives at present - val address = t.address.match { - case Some(a) => a - case None => throw Exception(s"resolved indirect call $indirectCall to procedure which does not have address: $t") - } - val addressExpr = BinaryExpr(BVEQ, indirectCall.target, BitVecLiteral(address, 64)) - addressExprs ::= addressExpr - val assume = Assume(addressExpr) - val newLabel: String = block.label + t.name - val directCall = DirectCall(t) - directCall.parent = indirectCall.parent - - assert(indirectCall.parent.statements.lastOption.contains(indirectCall)) - val fallthrough = indirectCall.parent.jump - newBlocks.append(Block(newLabel, None, ArrayBuffer(assume, directCall), fallthrough)) - } - procedure.addBlocks(newBlocks) - val newCall = GoTo(newBlocks, indirectCall.label) - val addressExprOr = addressExprs.tail.foldLeft(addressExprs.head) { - (a: BinaryExpr, b: BinaryExpr) => BinaryExpr(BoolOR, a, b) - } - val assertion = Assert(addressExprOr, Some("check indirect call underapproximation")) - block.statements.append(assertion) - block.replaceJump(newCall) - } - case _ => - case _ => - } - - modified - } - /** Cull unneccessary information that does not need to be included in the translation, and infer stack regions, and * add in modifies from the spec. */ @@ -497,75 +227,9 @@ object IRTransform { val specModifies = ctx.specification.subroutines.map(s => s.name -> s.modifies).toMap ctx.program.setModifies(specModifies) + assert(invariant.singleCallBlockEnd(ctx.program)) } - // identify calls to pthread_create - // use analysis result to determine the third parameter's value (the function pointer) - // split off that procedure into new thread - // do reachability analysis - // also need a bit in the IR where it creates separate files - def splitThreads(program: Program, - pointsTo: Map[RegisterVariableWrapper, Set[RegisterVariableWrapper | MemoryRegion]], - regionContents: Map[MemoryRegion, Set[BitVecLiteral | MemoryRegion]], - reachingDefs: Map[CFGPosition, (Map[Variable, Set[Assign]], Map[Variable, Set[Assign]])] - ): Unit = { - - // iterate over all commands - if call is to pthread_create, look up? - program.foreach(c => - c match { - case d: DirectCall if d.target.name == "pthread_create" => - - // R2 should hold the function pointer of the function that begins the thread - // look up R2 value using points to results - val R2 = Register("R2", 64) - val b = reachingDefs(d) - val R2Wrapper = RegisterVariableWrapper(R2, getDefinition(R2, d, reachingDefs)) - val threadTargets = pointsTo(R2Wrapper) - - if (threadTargets.size > 1) { - // currently can't handle case where the thread created is ambiguous - throw Exception("can't handle thread creation with more than one possible target") - } - - if (threadTargets.size == 1) { - - // not trying to untangle the very messy region resolution at present, just dealing with simplest case - threadTargets.head match { - case data: DataRegion => - val threadEntrance = program.procedures.find(_.name == data.regionIdentifier) match { - case Some(proc) => proc - case None => throw Exception("could not find procedure with name " + data.regionIdentifier) - } - val thread = ProgramThread(threadEntrance, mutable.LinkedHashSet(threadEntrance), Some(d)) - program.threads.addOne(thread) - case _ => - throw Exception("unexpected non-data region " + threadTargets.head + " as PointsTo result for R2 at " + d) - } - } - case _ => - }) - - - if (program.threads.nonEmpty) { - val mainThread = ProgramThread(program.mainProcedure, mutable.LinkedHashSet(program.mainProcedure), None) - program.threads.addOne(mainThread) - - val programProcs = program.procedures - - // do reachability for all threads - for (thread <- program.threads) { - val reachable = thread.entry.reachableFrom - - // add procedures to thread in way that maintains original ordering - for (p <- programProcs) { - if (reachable.contains(p)) { - thread.procedures.add(p) - } - } - - } - } - } def generateProcedureSummaries( ctx: IRContext, @@ -736,18 +400,20 @@ object StaticAnalysis { val memoryRegionContents = steensgaardSolver.getMemoryRegionContents mmm.logRegions(memoryRegionContents) + // turn fake procedures into diamonds + transforms.addReturnBlocks(ctx.program, true) // add return to all blocks because IDE solver expects it Logger.info("[!] Running VSA") val vsaSolver = ValueSetAnalysisSolver(IRProgram, globalAddresses, externalAddresses, globalOffsets, subroutines, mmm, constPropResult) val vsaResult: Map[CFGPosition, LiftedElement[Map[Variable | MemoryRegion, Set[Value]]]] = vsaSolver.analyze() Logger.info("[!] Running Interprocedural Live Variables Analysis") - //val interLiveVarsResults = InterLiveVarsAnalysis(IRProgram).analyze() - val interLiveVarsResults = Map[CFGPosition, Map[Variable, TwoElement]]() + val interLiveVarsResults = InterLiveVarsAnalysis(IRProgram).analyze() + // val interLiveVarsResults = Map[CFGPosition, Map[Variable, TwoElement]]() Logger.info("[!] Running Parameter Analysis") - //val paramResults = ParamAnalysis(IRProgram).analyze() - val paramResults = Map[Procedure, Set[Variable]]() + val paramResults = ParamAnalysis(IRProgram).analyze() + // val paramResults = Map[Procedure, Set[Variable]]() StaticAnalysisContext( cfg = cfg, @@ -963,6 +629,7 @@ object RunUtils { val boogieTranslator = IRToBoogie(ctx.program, ctx.specification, None, q.outputPrefix) ArrayBuffer(boogieTranslator.translate(q.boogieTranslation)) } + assert(invariant.singleCallBlockEnd(ctx.program)) BASILResult(ctx, analysis, boogiePrograms) } @@ -978,7 +645,7 @@ object RunUtils { val result = StaticAnalysis.analyse(ctx, config, iteration) analysisResult.append(result) Logger.info("[!] Replacing Indirect Calls") - modified = IRTransform.resolveIndirectCallsUsingPointsTo(result.cfg, + modified = transforms.resolveIndirectCallsUsingPointsTo(result.cfg, result.steensgaardResults, result.memoryRegionContents, result.reachingDefs, @@ -997,7 +664,7 @@ object RunUtils { // should later move this to be inside while (modified) loop and have splitting threads cause further iterations if (config.threadSplit) { - IRTransform.splitThreads(ctx.program, analysisResult.last.steensgaardResults, analysisResult.last.memoryRegionContents, analysisResult.last.reachingDefs) + transforms.splitThreads(ctx.program, analysisResult.last.steensgaardResults, analysisResult.last.memoryRegionContents, analysisResult.last.reachingDefs) } config.analysisDotPath.foreach { s => @@ -1005,6 +672,7 @@ object RunUtils { writeToFile(newCFG.toDot(x => x.toString, Output.dotIder), s"${s}_resolvedCFG.dot") } + assert(invariant.singleCallBlockEnd(ctx.program)) Logger.info(s"[!] Finished indirect call resolution after $iteration iterations") analysisResult.last } diff --git a/src/test/scala/IndirectCallsTests.scala b/src/test/scala/IndirectCallsTests.scala index 60f7fdd67..15c635fc9 100644 --- a/src/test/scala/IndirectCallsTests.scala +++ b/src/test/scala/IndirectCallsTests.scala @@ -3,7 +3,7 @@ import ir.Endian.LittleEndian import org.scalatest.* import org.scalatest.funsuite.* import specification.* -import util.{BASILConfig, ILLoadingConfig, IRContext, RunUtils, StaticAnalysis, StaticAnalysisConfig, StaticAnalysisContext} +import util.{BASILConfig, ILLoadingConfig, IRContext, RunUtils, StaticAnalysis, StaticAnalysisConfig, StaticAnalysisContext, BASILResult} import java.io.IOException import java.nio.file.* @@ -76,14 +76,15 @@ class IndirectCallsTests extends AnyFunSuite with OneInstancePerTest with Before // Traverse the statements in the main function result.ir.program.mainProcedure.blocks.foreach { block => - block.jump match { - case directCall: DirectCall if expectedCallTransform.contains(directCall.label.getOrElse("")) => + block.statements.lastOption match { + case Some(directCall: DirectCall) if expectedCallTransform.contains(directCall.label.getOrElse("")) => val callTransform = expectedCallTransform(directCall.label.getOrElse("")) assert(callTransform._1 == directCall.target.name) expectedCallTransform.remove(directCall.label.getOrElse("")) - case _ => + case _ => } } + println(result.ir.program) assert(expectedCallTransform.isEmpty) } @@ -113,14 +114,15 @@ class IndirectCallsTests extends AnyFunSuite with OneInstancePerTest with Before // Traverse the statements in the main function result.ir.program.mainProcedure.blocks.foreach { block => - block.jump match { - case directCall: DirectCall if expectedCallTransform.contains(directCall.label.getOrElse("")) => + block.statements.lastOption match { + case Some(directCall: DirectCall) if expectedCallTransform.contains(directCall.label.getOrElse("")) => val callTransform = expectedCallTransform(directCall.label.getOrElse("")) assert(callTransform._1 == directCall.target.name) expectedCallTransform.remove(directCall.label.getOrElse("")) case _ => } } + println(result.ir.program) assert(expectedCallTransform.isEmpty) } @@ -150,14 +152,15 @@ class IndirectCallsTests extends AnyFunSuite with OneInstancePerTest with Before // Traverse the statements in the main function result.ir.program.mainProcedure.blocks.foreach { block => - block.jump match { - case directCall: DirectCall if expectedCallTransform.contains(directCall.label.getOrElse("")) => + block.statements.lastOption match { + case Some(directCall: DirectCall) if expectedCallTransform.contains(directCall.label.getOrElse("")) => val callTransform = expectedCallTransform(directCall.label.getOrElse("")) assert(callTransform._1 == directCall.target.name) expectedCallTransform.remove(directCall.label.getOrElse("")) case _ => } } + println(result.ir.program) assert(expectedCallTransform.isEmpty) } @@ -193,14 +196,15 @@ class IndirectCallsTests extends AnyFunSuite with OneInstancePerTest with Before // Traverse the statements in the main function result.ir.program.mainProcedure.blocks.foreach { block => - block.jump match { - case directCall: DirectCall if expectedCallTransform.contains(directCall.label.getOrElse("")) => + block.statements.lastOption match { + case Some(directCall: DirectCall) if expectedCallTransform.contains(directCall.label.getOrElse("")) => val callTransform = expectedCallTransform(directCall.label.getOrElse("")) assert(callTransform._1 == directCall.target.name) expectedCallTransform.remove(directCall.label.getOrElse("")) case _ => } } + println(result.ir.program) assert(expectedCallTransform.isEmpty) } @@ -236,14 +240,15 @@ class IndirectCallsTests extends AnyFunSuite with OneInstancePerTest with Before // Traverse the statements in the main function result.ir.program.mainProcedure.blocks.foreach { block => - block.jump match { - case directCall: DirectCall if expectedCallTransform.contains(directCall.label.getOrElse("")) => + block.statements.lastOption match { + case Some(directCall: DirectCall) if expectedCallTransform.contains(directCall.label.getOrElse("")) => val callTransform = expectedCallTransform(directCall.label.getOrElse("")) assert(callTransform._1 == directCall.target.name) expectedCallTransform.remove(directCall.label.getOrElse("")) case _ => } } + println(result.ir.program) assert(expectedCallTransform.isEmpty) } @@ -278,14 +283,15 @@ class IndirectCallsTests extends AnyFunSuite with OneInstancePerTest with Before // Traverse the statements in the main function result.ir.program.mainProcedure.blocks.foreach { block => - block.jump match { - case directCall: DirectCall if expectedCallTransform.contains(directCall.label.getOrElse("")) => + block.statements.lastOption match { + case Some(directCall: DirectCall) if expectedCallTransform.contains(directCall.label.getOrElse("")) => val callTransform = expectedCallTransform(directCall.label.getOrElse("")) assert(callTransform._1 == directCall.target.name) expectedCallTransform.remove(directCall.label.getOrElse("")) case _ => } } + println(result.ir.program) assert(expectedCallTransform.isEmpty) } @@ -321,14 +327,15 @@ class IndirectCallsTests extends AnyFunSuite with OneInstancePerTest with Before // Traverse the statements in the main function result.ir.program.mainProcedure.blocks.foreach { block => - block.jump match { - case directCall: DirectCall if expectedCallTransform.contains(directCall.label.getOrElse("")) => + block.statements.lastOption match { + case Some(directCall: DirectCall) if expectedCallTransform.contains(directCall.label.getOrElse("")) => val callTransform = expectedCallTransform(directCall.label.getOrElse("")) assert(callTransform._1 == directCall.target.name) expectedCallTransform.remove(directCall.label.getOrElse("")) case _ => } } + println(result.ir.program) assert(expectedCallTransform.isEmpty) } @@ -341,7 +348,7 @@ class IndirectCallsTests extends AnyFunSuite with OneInstancePerTest with Before dumpIL = Some(tempPath + testName), ), outputPrefix = tempPath + testName, - staticAnalysis = Some(StaticAnalysisConfig(None, None, None)), + staticAnalysis = Some(StaticAnalysisConfig(Some("functionpointer"), None, None)), ) val result = loadAndTranslate(basilConfig) /* in this example we must find: @@ -356,17 +363,19 @@ class IndirectCallsTests extends AnyFunSuite with OneInstancePerTest with Before "l000004f3set_seven" -> ("set_seven", "R0") ) + println("prev " + result.ir.program) // Traverse the statements in the main function result.ir.program.mainProcedure.blocks.foreach { block => - block.jump match { - case directCall: DirectCall if expectedCallTransform.contains(block.label) => + block.statements.lastOption match { + case Some(directCall: DirectCall) if expectedCallTransform.contains(block.label) => val callTransform = expectedCallTransform(block.label) assert(callTransform._1 == directCall.target.name) expectedCallTransform.remove(block.label) case _ => } } + println(result.ir.program) assert(expectedCallTransform.isEmpty) } @@ -397,14 +406,15 @@ class IndirectCallsTests extends AnyFunSuite with OneInstancePerTest with Before // Traverse the statements in the main function result.ir.program.mainProcedure.blocks.foreach { block => - block.jump match { - case directCall: DirectCall if expectedCallTransform.contains(block.label) => + block.statements.lastOption match { + case Some(directCall: DirectCall) if expectedCallTransform.contains(block.label) => val callTransform = expectedCallTransform(block.label) assert(callTransform._1 == directCall.target.name) expectedCallTransform.remove(block.label) case _ => } } + println(result.ir.program) assert(expectedCallTransform.isEmpty) } @@ -431,18 +441,19 @@ class IndirectCallsTests extends AnyFunSuite with OneInstancePerTest with Before "l0000044dset_two" -> ("set_two", "R0"), "l0000044dset_seven" -> ("set_seven", "R0") ) + result.ir.program.mainProcedure.blocks.foreach { + block => + block.statements.lastOption match { + case Some(directCall: DirectCall) if expectedCallTransform.contains(block.label) => + val callTransform = expectedCallTransform(block.label) + assert(callTransform._1 == directCall.target.name) + expectedCallTransform.remove(block.label) + case _ => + } + } // Traverse the statements in the main function - result.ir.program.mainProcedure.blocks.foreach { - block => - block.jump match { - case directCall: DirectCall if expectedCallTransform.contains(block.label) => - val callTransform = expectedCallTransform(block.label) - assert(callTransform._1 == directCall.target.name) - expectedCallTransform.remove(block.label) - case _ => - } - } + println(result.ir.program) assert(expectedCallTransform.isEmpty) } @@ -470,8 +481,8 @@ class IndirectCallsTests extends AnyFunSuite with OneInstancePerTest with Before // Traverse the statements in the main function result.ir.program.mainProcedure.blocks.foreach { block => - block.jump match { - case directCall: DirectCall if expectedCallTransform.contains(directCall.label.getOrElse("")) => + block.statements.lastOption match { + case Some(directCall: DirectCall) if expectedCallTransform.contains(directCall.label.getOrElse("")) => val callTransform = expectedCallTransform(directCall.label.getOrElse("")) assert(callTransform._1 == directCall.target.name) expectedCallTransform.remove(directCall.label.getOrElse("")) @@ -505,8 +516,8 @@ class IndirectCallsTests extends AnyFunSuite with OneInstancePerTest with Before // Traverse the statements in the main function result.ir.program.mainProcedure.blocks.foreach { block => - block.jump match { - case directCall: DirectCall if expectedCallTransform.contains(directCall.label.getOrElse("")) => + block.statements.lastOption match { + case Some(directCall: DirectCall) if expectedCallTransform.contains(directCall.label.getOrElse("")) => val callTransform = expectedCallTransform(directCall.label.getOrElse("")) assert(callTransform._1 == directCall.target.name) expectedCallTransform.remove(directCall.label.getOrElse("")) @@ -540,8 +551,8 @@ class IndirectCallsTests extends AnyFunSuite with OneInstancePerTest with Before // Traverse the statements in the main function result.ir.program.mainProcedure.blocks.foreach { block => - block.jump match { - case directCall: DirectCall if expectedCallTransform.contains(directCall.label.getOrElse("")) => + block.statements.lastOption match { + case Some(directCall: DirectCall) if expectedCallTransform.contains(directCall.label.getOrElse("")) => val callTransform = expectedCallTransform(directCall.label.getOrElse("")) assert(callTransform._1 == directCall.target.name) expectedCallTransform.remove(directCall.label.getOrElse("")) @@ -550,4 +561,4 @@ class IndirectCallsTests extends AnyFunSuite with OneInstancePerTest with Before } assert(expectedCallTransform.isEmpty) } -} \ No newline at end of file +} diff --git a/src/test/scala/LiveVarsAnalysisTests.scala b/src/test/scala/LiveVarsAnalysisTests.scala index 443dc011f..e1b142001 100644 --- a/src/test/scala/LiveVarsAnalysisTests.scala +++ b/src/test/scala/LiveVarsAnalysisTests.scala @@ -1,12 +1,14 @@ import analysis.{InterLiveVarsAnalysis, TwoElementTop} import ir.dsl.* -import ir.{BitVecLiteral, BitVecType, ConvertToSingleProcedureReturn, dsl, Assign, LocalVar, Program, Register, Statement, Variable} +import ir.{BitVecLiteral, BitVecType, dsl, Assign, LocalVar, Program, Register, Statement, Variable, transforms, cilvisitor, Procedure} +import util.{Logger, LogLevel} import org.scalatest.funsuite.AnyFunSuite import test_util.TestUtil import util.BASILResult class LiveVarsAnalysisTests extends AnyFunSuite, TestUtil { + Logger.setLevel(LogLevel.ERROR) def createSimpleProc(name: String, statements: Seq[Statement | EventuallyJump]): EventuallyProcedure = { proc(name, @@ -31,10 +33,12 @@ class LiveVarsAnalysisTests extends AnyFunSuite, TestUtil { block("first_call", r0ConstantAssign, r1ConstantAssign, - directCall("callee1", Some("second_call")) + directCall("callee1"), + goto("second_call") ), block("second_call", - directCall("callee2", Some("returnBlock")) + directCall("callee2"), + goto("returnBlock") ), block("returnBlock", ret @@ -44,15 +48,21 @@ class LiveVarsAnalysisTests extends AnyFunSuite, TestUtil { createSimpleProc("callee2", Seq(r2r1Assign)) ) - val returnUnifier = ConvertToSingleProcedureReturn() - program = returnUnifier.visitProgram(program) + cilvisitor.visit_prog(transforms.ReplaceReturns(), program) + transforms.addReturnBlocks(program) + cilvisitor.visit_prog(transforms.ConvertSingleReturn(), program) val liveVarAnalysisResults = InterLiveVarsAnalysis(program).analyze() + // fix for DSA pairs of results? val procs = program.procs - assert(liveVarAnalysisResults(procs("main")) == Map(R30 -> TwoElementTop)) - assert(liveVarAnalysisResults(procs("callee1")) == Map(R0 -> TwoElementTop, R1 -> TwoElementTop, R30 -> TwoElementTop)) - assert(liveVarAnalysisResults(procs("callee2")) == Map(R1 -> TwoElementTop, R30 -> TwoElementTop)) + println(liveVarAnalysisResults.filter((k,n) => k match { + case p => true + case _ => false + })) + // assert(liveVarAnalysisResults(procs("main")) == Map(R30 -> TwoElementTop)) + assert(liveVarAnalysisResults(procs("callee1")) == Map(R0 -> TwoElementTop, R1 -> TwoElementTop)) + assert(liveVarAnalysisResults(procs("callee2")) == Map(R1 -> TwoElementTop)) } @@ -69,10 +79,10 @@ class LiveVarsAnalysisTests extends AnyFunSuite, TestUtil { block("first_call", r0ConstantAssign, r1ConstantAssign, - directCall("callee1", Some("second_call")) + directCall("callee1"), goto("second_call") ), block("second_call", - directCall("callee2", Some("returnBlock")) + directCall("callee2"), goto("returnBlock") ), block("returnBlock", ret @@ -82,15 +92,16 @@ class LiveVarsAnalysisTests extends AnyFunSuite, TestUtil { createSimpleProc("callee2", Seq(r2r1Assign)) ) - val returnUnifier = ConvertToSingleProcedureReturn() - program = returnUnifier.visitProgram(program) + cilvisitor.visit_prog(transforms.ReplaceReturns(), program) + transforms.addReturnBlocks(program) + cilvisitor.visit_prog(transforms.ConvertSingleReturn(), program) val liveVarAnalysisResults = InterLiveVarsAnalysis(program).analyze() val procs = program.procs - assert(liveVarAnalysisResults(procs("main")) == Map(R30 -> TwoElementTop)) - assert(liveVarAnalysisResults(procs("callee1")) == Map(R0 -> TwoElementTop, R30 -> TwoElementTop)) - assert(liveVarAnalysisResults(procs("callee2")) == Map(R1 -> TwoElementTop, R30 -> TwoElementTop)) + // assert(liveVarAnalysisResults(procs("main")) == Map()) + assert(liveVarAnalysisResults(procs("callee1")) == Map(R0 -> TwoElementTop)) + assert(liveVarAnalysisResults(procs("callee2")) == Map(R1 -> TwoElementTop)) } def twoCallers(): Unit = { @@ -104,10 +115,10 @@ class LiveVarsAnalysisTests extends AnyFunSuite, TestUtil { var program = prog( proc("main", block("main_first_call", - directCall("wrapper1", Some("main_second_call")) + directCall("wrapper1"), goto("main_second_call") ), block("main_second_call", - directCall("wrapper2", Some("main_return")) + directCall("wrapper2"), goto("main_return") ), block("main_return", ret) ), @@ -117,30 +128,31 @@ class LiveVarsAnalysisTests extends AnyFunSuite, TestUtil { proc("wrapper1", block("wrapper1_first_call", Assign(R1, constant1), - directCall("callee", Some("wrapper1_second_call")) + directCall("callee"), goto("wrapper1_second_call") ), block("wrapper1_second_call", - directCall("callee2", Some("wrapper1_return"))), + directCall("callee2"), goto("wrapper1_return")), block("wrapper1_return", ret) ), proc("wrapper2", block("wrapper2_first_call", Assign(R2, constant1), - directCall("callee", Some("wrapper2_second_call")) + directCall("callee"), goto("wrapper2_second_call") ), block("wrapper2_second_call", - directCall("callee3", Some("wrapper2_return"))), + directCall("callee3"), goto("wrapper2_return")), block("wrapper2_return", ret) ) ) - val returnUnifier = ConvertToSingleProcedureReturn() - program = returnUnifier.visitProgram(program) + cilvisitor.visit_prog(transforms.ReplaceReturns(), program) + transforms.addReturnBlocks(program) + cilvisitor.visit_prog(transforms.ConvertSingleReturn(), program) val liveVarAnalysisResults = InterLiveVarsAnalysis(program).analyze() val blocks = program.blocks - assert(liveVarAnalysisResults(blocks("wrapper1_first_call").jump) == Map(R1 -> TwoElementTop, R30 -> TwoElementTop)) - assert(liveVarAnalysisResults(blocks("wrapper2_first_call").jump) == Map(R2 -> TwoElementTop, R30 -> TwoElementTop)) + assert(liveVarAnalysisResults(blocks("wrapper1_first_call").jump) == Map(R1 -> TwoElementTop)) + assert(liveVarAnalysisResults(blocks("wrapper2_first_call").jump) == Map(R2 -> TwoElementTop)) } @@ -148,7 +160,7 @@ class LiveVarsAnalysisTests extends AnyFunSuite, TestUtil { var program = prog( proc("main", block("lmain", - directCall("killer", Some("aftercall")) + directCall("killer"), goto("aftercall") ), block("aftercall", Assign(R0, R1), @@ -158,14 +170,15 @@ class LiveVarsAnalysisTests extends AnyFunSuite, TestUtil { createSimpleProc("killer", Seq(Assign(R1, bv64(1)))) ) - val returnUnifier = ConvertToSingleProcedureReturn() - program = returnUnifier.visitProgram(program) + cilvisitor.visit_prog(transforms.ReplaceReturns(), program) + transforms.addReturnBlocks(program) + cilvisitor.visit_prog(transforms.ConvertSingleReturn(), program) val liveVarAnalysisResults = InterLiveVarsAnalysis(program).analyze() val blocks = program.blocks - assert(liveVarAnalysisResults(blocks("aftercall")) == Map(R1 -> TwoElementTop, R30 -> TwoElementTop)) - assert(liveVarAnalysisResults(blocks("lmain")) == Map(R30 -> TwoElementTop)) + assert(liveVarAnalysisResults(blocks("aftercall")) == Map(R1 -> TwoElementTop)) + // assert(liveVarAnalysisResults(blocks("lmain")) == Map()) } def simpleBranch(): Unit = { @@ -193,15 +206,16 @@ class LiveVarsAnalysisTests extends AnyFunSuite, TestUtil { ) ) - val returnUnifier = ConvertToSingleProcedureReturn() - program = returnUnifier.visitProgram(program) + cilvisitor.visit_prog(transforms.ReplaceReturns(), program) + transforms.addReturnBlocks(program) + cilvisitor.visit_prog(transforms.ConvertSingleReturn(), program) val blocks = program.blocks val liveVarAnalysisResults = InterLiveVarsAnalysis(program).analyze() - assert(liveVarAnalysisResults(blocks("branch1")) == Map(R1 -> TwoElementTop, R30 -> TwoElementTop)) - assert(liveVarAnalysisResults(blocks("branch2")) == Map(R2 -> TwoElementTop, R30 -> TwoElementTop)) - assert(liveVarAnalysisResults(blocks("lmain")) == Map(R1 -> TwoElementTop, R2 -> TwoElementTop, R30 -> TwoElementTop)) + assert(liveVarAnalysisResults(blocks("branch1")) == Map(R1 -> TwoElementTop)) + assert(liveVarAnalysisResults(blocks("branch2")) == Map(R2 -> TwoElementTop)) + assert(liveVarAnalysisResults(blocks("lmain")) == Map(R1 -> TwoElementTop, R2 -> TwoElementTop)) } @@ -212,7 +226,7 @@ class LiveVarsAnalysisTests extends AnyFunSuite, TestUtil { block( "lmain", Assign(R0, R1), - directCall("main", Some("return")) + directCall("main"), goto("return") ), block("return", Assign(R0, R2), @@ -221,13 +235,14 @@ class LiveVarsAnalysisTests extends AnyFunSuite, TestUtil { ) ) - val returnUnifier = ConvertToSingleProcedureReturn() - program = returnUnifier.visitProgram(program) + cilvisitor.visit_prog(transforms.ReplaceReturns(), program) + transforms.addReturnBlocks(program) + cilvisitor.visit_prog(transforms.ConvertSingleReturn(), program) val liveVarAnalysisResults = InterLiveVarsAnalysis(program).analyze() val blocks = program.blocks - assert(liveVarAnalysisResults(program.mainProcedure) == Map(R1 -> TwoElementTop, R2 -> TwoElementTop, R30 -> TwoElementTop)) + assert(liveVarAnalysisResults(program.mainProcedure) == Map(R1 -> TwoElementTop, R2 -> TwoElementTop)) } def recursionBaseCase(): Unit = { @@ -240,7 +255,7 @@ class LiveVarsAnalysisTests extends AnyFunSuite, TestUtil { ), block( "recursion", - directCall("main", Some("assign")) + directCall("main"), goto("assign") ), block("assign", Assign(R0, R2), @@ -256,13 +271,14 @@ class LiveVarsAnalysisTests extends AnyFunSuite, TestUtil { ) ) - val returnUnifier = ConvertToSingleProcedureReturn() - program = returnUnifier.visitProgram(program) + cilvisitor.visit_prog(transforms.ReplaceReturns(), program) + transforms.addReturnBlocks(program) + cilvisitor.visit_prog(transforms.ConvertSingleReturn(), program) val liveVarAnalysisResults = InterLiveVarsAnalysis(program).analyze() val blocks = program.blocks - assert(liveVarAnalysisResults(program.mainProcedure) == Map(R1 -> TwoElementTop, R2 -> TwoElementTop, R30 -> TwoElementTop)) + assert(liveVarAnalysisResults(program.mainProcedure) == Map(R1 -> TwoElementTop, R2 -> TwoElementTop)) } test("differentCalleesBothAlive") { @@ -299,7 +315,7 @@ class LiveVarsAnalysisTests extends AnyFunSuite, TestUtil { val blocks = result.ir.program.blocks // main has a parameter, R0 should be alive - assert(analysisResults(blocks("lmain")) == Map(R0 -> TwoElementTop, R30 -> TwoElementTop, R31 -> TwoElementTop)) + assert(analysisResults(blocks("lmain")) == Map(R0 -> TwoElementTop, R31 -> TwoElementTop)) } test("function") { @@ -309,9 +325,8 @@ class LiveVarsAnalysisTests extends AnyFunSuite, TestUtil { // checks function call blocks assert(analysisResults(blocks("lmain")) == Map(R29 -> TwoElementTop, R30 -> TwoElementTop, R31 -> TwoElementTop)) - assert(analysisResults(blocks("lget_two")) == Map(R30 -> TwoElementTop, R31 -> TwoElementTop)) + assert(analysisResults(blocks("lget_two")) == Map(R31 -> TwoElementTop)) assert(analysisResults(blocks("l00000946")) == Map(R0 -> TwoElementTop, R31 -> TwoElementTop)) // aftercall block - assert(analysisResults(blocks("main_basil_return")) == Map(R30 -> TwoElementTop)) } @@ -323,9 +338,9 @@ class LiveVarsAnalysisTests extends AnyFunSuite, TestUtil { // main has parameter, callee (zero) has return and no parameter assert(analysisResults(blocks("lmain")) == Map(R0 -> TwoElementTop, R29 -> TwoElementTop, R30 -> TwoElementTop, R31 -> TwoElementTop)) - assert(analysisResults(blocks("lzero")) == Map(R30 -> TwoElementTop, R31 -> TwoElementTop)) + assert(analysisResults(blocks("lzero")) == Map(R31 -> TwoElementTop)) assert(analysisResults(blocks("l00000323")) == Map(R0 -> TwoElementTop, R31 -> TwoElementTop)) // aftercall block - assert(analysisResults(blocks("zero_basil_return")) == Map(R0 -> TwoElementTop, R30 -> TwoElementTop, R31 -> TwoElementTop)) + assert(analysisResults(blocks("zero_basil_return")) == Map(R0 -> TwoElementTop, R31 -> TwoElementTop)) } test("function1") { @@ -334,12 +349,12 @@ class LiveVarsAnalysisTests extends AnyFunSuite, TestUtil { val blocks = result.ir.program.blocks // main has no parameters, get_two has three and a return - assert(analysisResults(blocks("lmain")) == Map(R29 -> TwoElementTop, R30 -> TwoElementTop, R31 -> TwoElementTop)) - assert(analysisResults(blocks("l000003ec")) == Map(R0 -> TwoElementTop, R31 -> TwoElementTop)) // get_two aftercall - assert(analysisResults(blocks("l00000430")) == Map(R31 -> TwoElementTop)) // printf aftercall - assert(analysisResults(blocks("main_basil_return")) == Map(R30 -> TwoElementTop)) - assert(analysisResults(blocks("lget_two")) == Map(R0 -> TwoElementTop, R1 -> TwoElementTop, R2 -> TwoElementTop, R30 -> TwoElementTop, R31 -> TwoElementTop)) - assert(analysisResults(blocks("get_two_basil_return")) == Map(R0 -> TwoElementTop, R30 -> TwoElementTop, R31 -> TwoElementTop)) + assert(analysisResults(blocks("lmain").jump) == Map(R29 -> TwoElementTop, R31 -> TwoElementTop)) + assert(analysisResults(blocks("l000003ec").jump) == Map(R0 -> TwoElementTop, R31 -> TwoElementTop)) // get_two aftercall + assert(analysisResults(blocks("l00000430").jump) == Map(R31 -> TwoElementTop)) // printf aftercall + assert(analysisResults(blocks("main_basil_return").jump) == Map(R30 -> TwoElementTop)) + assert(analysisResults(blocks("lget_two").jump) == Map(R0 -> TwoElementTop, R1 -> TwoElementTop, R2 -> TwoElementTop, R30 -> TwoElementTop, R31 -> TwoElementTop)) + assert(analysisResults(blocks("get_two_basil_return").jump) == Map(R0 -> TwoElementTop, R30 -> TwoElementTop, R31 -> TwoElementTop)) } test("ifbranches") { @@ -348,11 +363,11 @@ class LiveVarsAnalysisTests extends AnyFunSuite, TestUtil { val blocks = result.ir.program.blocks // block after branch - assert(analysisResults(blocks("l00000342")) == Map(R30 -> TwoElementTop, R31 -> TwoElementTop)) + assert(analysisResults(blocks("l00000342")) == Map(R31 -> TwoElementTop)) // branch blocks assert(analysisResults(blocks("lmain_goto_l00000330")) == Map(Register("ZF", 1) -> TwoElementTop, - R30 -> TwoElementTop, R31 -> TwoElementTop)) + R31 -> TwoElementTop)) assert(analysisResults(blocks("lmain_goto_l00000369")) == Map(Register("ZF", 1) -> TwoElementTop, - R30 -> TwoElementTop, R31 -> TwoElementTop)) + R31 -> TwoElementTop)) } } diff --git a/src/test/scala/PointsToTest.scala b/src/test/scala/PointsToTest.scala index 32131ed46..9534053a3 100644 --- a/src/test/scala/PointsToTest.scala +++ b/src/test/scala/PointsToTest.scala @@ -70,8 +70,8 @@ class PointsToTest extends AnyFunSuite with OneInstancePerTest with BeforeAndAft ) ) - val returnUnifier = ConvertToSingleProcedureReturn() - program = returnUnifier.visitProgram(program) + transforms.addReturnBlocks(program) + cilvisitor.visit_prog(transforms.ConvertSingleReturn(), program) val results = runAnalyses(program) results.mmmResults.pushContext("main") @@ -99,8 +99,8 @@ class PointsToTest extends AnyFunSuite with OneInstancePerTest with BeforeAndAft ) ) ) - val returnUnifier = ConvertToSingleProcedureReturn() - program = returnUnifier.visitProgram(program) + transforms.addReturnBlocks(program) + cilvisitor.visit_prog(transforms.ConvertSingleReturn(), program) val results = runAnalyses(program) results.mmmResults.pushContext("main") @@ -168,7 +168,7 @@ class PointsToTest extends AnyFunSuite with OneInstancePerTest with BeforeAndAft goto("0x1") ), block("0x1", - directCall("p2", Some("returntarget")) + directCall("p2"), goto("returntarget") ), block("returntarget", ret @@ -186,8 +186,8 @@ class PointsToTest extends AnyFunSuite with OneInstancePerTest with BeforeAndAft ) ) - val returnUnifier = ConvertToSingleProcedureReturn() - program = returnUnifier.visitProgram(program) + transforms.addReturnBlocks(program) + cilvisitor.visit_prog(transforms.ConvertSingleReturn(), program) val results = runAnalyses(program) results.mmmResults.pushContext("main") @@ -217,7 +217,7 @@ class PointsToTest extends AnyFunSuite with OneInstancePerTest with BeforeAndAft goto("0x1") ), block("0x1", - directCall("p2", Some("returntarget")) + directCall("p2"), goto("returntarget") ), block("returntarget", ret @@ -227,7 +227,7 @@ class PointsToTest extends AnyFunSuite with OneInstancePerTest with BeforeAndAft block("l_foo", Assign(getRegister("R0"), MemoryLoad(mem, BinaryExpr(BVADD, getRegister("R31"), bv64(6)), LittleEndian, 64)), Assign(getRegister("R1"), BinaryExpr(BVADD, getRegister("R31"), bv64(10))), - directCall("p2", Some("l_foo_1")) + directCall("p2"), goto("l_foo_1") ), block("l_foo_1", ret, @@ -245,8 +245,8 @@ class PointsToTest extends AnyFunSuite with OneInstancePerTest with BeforeAndAft ) ) - val returnUnifier = ConvertToSingleProcedureReturn() - program = returnUnifier.visitProgram(program) + transforms.addReturnBlocks(program) + cilvisitor.visit_prog(transforms.ConvertSingleReturn(), program) val results = runAnalyses(program) results.mmmResults.pushContext("main") @@ -303,4 +303,4 @@ class PointsToTest extends AnyFunSuite with OneInstancePerTest with BeforeAndAft // // runSteensgaardAnalysis(program, globals = globals, globalOffsets = globalOffsets) // } -} \ No newline at end of file +} diff --git a/src/test/scala/ir/IRTest.scala b/src/test/scala/ir/IRTest.scala index 7c06315d3..7421c2a28 100644 --- a/src/test/scala/ir/IRTest.scala +++ b/src/test/scala/ir/IRTest.scala @@ -4,7 +4,9 @@ import scala.collection.mutable import scala.collection.immutable.* import org.scalatest.funsuite.AnyFunSuite import util.intrusive_list.* +import translating.serialiseIL import ir.dsl.* +import ir._ class IRTest extends AnyFunSuite { @@ -57,31 +59,6 @@ class IRTest extends AnyFunSuite { } - test("removeblockinline") { - - val p = prog( - proc("main", - block("lmain", - goto("lmain1") - ), - block("lmain1", - goto("lmain2")), - block("lmain2", - ret) - ) - ) - - val blocks = p.collect { - case b: Block => b.label -> b - }.toMap - - p.procedures.head.removeBlocksInline(blocks("lmain1")) - - blocks("lmain").singleSuccessor.contains(blocks("lmain2")) - blocks("lmain2").singlePredecessor.contains(blocks("lmain")) - - } - test("simple replace jump") { val p = prog( @@ -142,7 +119,8 @@ class IRTest extends AnyFunSuite { ), block("l_main_1", Assign(R0, bv64(22)), - directCall("p2", Some("returntarget")) + directCall("p2"), + goto("returntarget") ), block("returntarget", ret @@ -154,34 +132,32 @@ class IRTest extends AnyFunSuite { ) ) + val blocks = p.collect { case b: Block => b.label -> b }.toMap - val directcalls = p.collect { case c: DirectCall => c } - assert(blocks("l_main_1").fallthrough.nonEmpty) - assert(p.toSet.contains(blocks("l_main_1").fallthrough.get)) - assert(directcalls.forall(c => IntraProcIRCursor.succ(c).count(_.asInstanceOf[GoTo].isAfterCall) == 1)) - assert(directcalls.forall(c => IntraProcBlockIRCursor.succ(c).count(_.isAfterCall) == 1)) + assert(p.toSet.contains(blocks("l_main_1").jump)) + assert(directcalls.forall(c => IntraProcIRCursor.succ(c).count(c => isAfterCall(c.asInstanceOf[Command])) == 1)) val afterCalls = p.collect { - case b: Block if b.isAfterCall => b + case b: Command if isAfterCall(b) => b }.toSet - assert(afterCalls.toSet == Set(blocks("returntarget"))) + assert(afterCalls.toSet == Set(blocks("l_main_1").jump)) val aftercallGotos = p.collect { - case c: Jump if c.isAfterCall => c + case c: Command if isAfterCall(c) => c }.toSet - assert(aftercallGotos == Set(blocks("l_main_1").fallthrough.get)) + // assert(aftercallGotos == Set(blocks("l_main_1").fallthrough.get)) assert(1 == aftercallGotos.count(b => IntraProcIRCursor.pred(b).contains(blocks("l_main_1").jump))) - assert(1 == aftercallGotos.count(b => IntraProcIRCursor.succ(b).contains(blocks("l_main_1").fallthrough.map(_.targets.head).head))) - - assert(afterCalls.forall(b => IntraProcBlockIRCursor.pred(b).contains(blocks("l_main_1")))) + assert(1 == aftercallGotos.count(b => IntraProcIRCursor.succ(b).contains(blocks("l_main_1").jump match { + case GoTo(targets, _) => targets.head + }))) } @@ -246,7 +222,8 @@ class IRTest extends AnyFunSuite { Assign(R0, bv64(22)), Assign(R0, bv64(22)), Assign(R0, bv64(22)), - directCall("main", None) + directCall("main"), + halt ).resolve(p) val b2 = block("newblock1", Assign(R0, bv64(22)), @@ -271,7 +248,8 @@ class IRTest extends AnyFunSuite { assert(called.incomingCalls().isEmpty) val b3 = block("newblock3", Assign(R0, bv64(22)), - directCall("called", None) + directCall("called"), + halt ).resolve(p) assert(b3.calls.toSet == Set(p.procs("called"))) @@ -283,11 +261,11 @@ class IRTest extends AnyFunSuite { assert(!oldb.hasParent) assert(oldb.incomingJumps.isEmpty) assert(!blocks("lmain").jump.asInstanceOf[GoTo].targets.contains(oldb)) - assert(called.incomingCalls().toSet == Set(b3.jump)) + assert(called.incomingCalls().toSet == Set(b3.statements.last)) assert(called.incomingCalls().map(_.parent.parent).toSet == called.callers().toSet) val olds = blocks.size p.mainProcedure.replaceBlock(b3, b3) - assert(called.incomingCalls().toSet == Set(b3.jump)) + assert(called.incomingCalls().toSet == Set(b3.statements.last)) assert(olds == blocks.size) p.mainProcedure.addBlocks(block("test", ret).resolve(p)) assert(olds != blocks.size) @@ -333,35 +311,35 @@ class IRTest extends AnyFunSuite { proc("main", block("l_main", Assign(R0, bv64(10)), - directCall("p1", Some("returntarget")) + directCall("p1"), goto("returntarget") ), block("returntarget", ret ) ), ) - val returnUnifier = ConvertToSingleProcedureReturn() - returnUnifier.visitProgram(p) + + cilvisitor.visit_prog(transforms.ReplaceReturns(), p) + transforms.addReturnBlocks(p) + cilvisitor.visit_prog(transforms.ConvertSingleReturn(), p) val next = InterProcIRCursor.succ(p.blocks("l_main").jump) val prev = InterProcIRCursor.pred(p.blocks("returntarget")) assert(prev.size == 1 && prev.collect { - case c : GoTo => (c.parent == p.blocks("l_main")) && c.isAfterCall + case c : GoTo => (c.parent == p.blocks("l_main")) }.contains(true)) - assert(next == Set(p.procs("p1"), p.blocks("l_main").fallthrough.get)) + // assert(next == Set(p.procs("p1"), p.blocks("l_main").fallthrough.get)) - val prevB: Block = (p.blocks("l_main").jump match - case c: IndirectCall => c.returnTarget - case c: DirectCall => c.returnTarget - case _ => None + val prevB: Command = (p.blocks("l_main").statements.lastOption match + case Some(c: IndirectCall) => c.returnTarget + case Some(c: DirectCall) => c.returnTarget + case o => None ).get - assert(prevB.isAfterCall) + assert(isAfterCall(prevB)) assert(InterProcIRCursor.pred(prevB).size == 1) - assert(InterProcIRCursor.pred(prevB).head == p.blocks("l_main").fallthrough.get) - assert(InterProcBlockIRCursor.pred(prevB).head == p.blocks("l_main"), p.procs("p1").returnBlock.get) } @@ -374,10 +352,10 @@ class IRTest extends AnyFunSuite { ), proc("main", block("l_main", - indirectCall(R1, Some("returntarget")) + indirectCall(R1), goto("returntarget") ), block("block2", - directCall("p1", Some("returntarget")) + directCall("p1"), goto("returntarget") ), block("returntarget", ret diff --git a/src/test/scala/ir/SingleCallInvariant.scala b/src/test/scala/ir/SingleCallInvariant.scala new file mode 100644 index 000000000..d8efb6fc2 --- /dev/null +++ b/src/test/scala/ir/SingleCallInvariant.scala @@ -0,0 +1,83 @@ +package ir + + +import ir.dsl._ + +import org.scalatest.funsuite.AnyFunSuite +class InvariantTest extends AnyFunSuite { + + test("sat singleCallBlockEnd case") { + var program: Program = prog( + proc("main", + block("first_call", + Assign(R0, bv64(10)), + Assign(R1, bv64(10)), + directCall("callee1"), + ret + ), + block("second_call", + Assign(R0, bv64(10)), + directCall("callee2"), + ret + ), + block("returnBlock", + ret + ) + ), + proc("callee1", block("bye1", ret)), + proc("callee2", block("bye2", ret)), + ) + + assert(invariant.singleCallBlockEnd(program)) + } + + test("unsat singleCallBlockEnd 1 (two calls)") { + var program: Program = prog( + proc("main", + block("first_call", + Assign(R0, bv64(10)), + directCall("callee2"), + Assign(R1, bv64(10)), + directCall("callee1"), + ret + ), + block("second_call", + Assign(R0, bv64(10)), + ret + ), + block("returnBlock", + ret + ) + ), + proc("callee1", block("bye1", ret)), + proc("callee2", block("bye2", ret)), + ) + + assert(!invariant.singleCallBlockEnd(program)) + } + + test("unsat singleCallBlockEnd 2 (not at end)") { + var program: Program = prog( + proc("main", + block("first_call", + Assign(R0, bv64(10)), + Assign(R1, bv64(10)), + ret + ), + block("second_call", + directCall("callee2"), + Assign(R0, bv64(10)), + ret + ), + block("returnBlock", + ret + ) + ), + proc("callee1", block("bye1", ret)), + proc("callee2", block("bye2", ret)), + ) + + assert(!invariant.singleCallBlockEnd(program)) + } + +} From aae006495876eb0a3f547076f4de2868f841e040 Mon Sep 17 00:00:00 2001 From: Alistair Michael Date: Fri, 9 Aug 2024 16:43:58 +1000 Subject: [PATCH 22/62] remove old cfg --- src/main/scala/analysis/Cfg.scala | 829 ------------------ src/main/scala/analysis/Dependencies.scala | 20 - .../analysis/InterLiveVarsAnalysis.scala | 8 +- .../InterprocSteensgaardAnalysis.scala | 4 +- .../analysis/IntraLiveVarsAnalysis.scala | 4 +- .../scala/analysis/MemoryRegionAnalysis.scala | 6 +- .../scala/analysis/RegToMemAnalysis.scala | 38 +- .../scala/analysis/solvers/IDESolver.scala | 19 +- src/main/scala/cfg_visualiser/Output.scala | 37 - src/main/scala/ir/IRCursor.scala | 11 +- src/main/scala/ir/Interpreter.scala | 4 +- src/main/scala/ir/Program.scala | 39 +- src/main/scala/ir/Statement.scala | 4 +- src/main/scala/ir/dsl/DSL.scala | 8 +- .../transforms/IndirectCallResolution.scala | 280 ++---- .../scala/ir/transforms/ReplaceReturn.scala | 3 +- .../scala/ir/transforms/SplitThreads.scala | 1 - .../StripUnreachableFunctions.scala | 38 + src/main/scala/translating/BAPToIR.scala | 4 +- src/main/scala/translating/GTIRBToIR.scala | 12 +- src/main/scala/translating/ILtoIL.scala | 16 +- src/main/scala/translating/IRToBoogie.scala | 2 +- src/main/scala/util/RunUtils.scala | 142 +-- src/test/scala/IndirectCallsTests.scala | 11 - src/test/scala/LiveVarsAnalysisTests.scala | 23 +- .../scala/MemoryRegionAnalysisMiscTest.scala | 2 +- src/test/scala/ir/IRTest.scala | 4 +- src/test/scala/ir/InterpreterTests.scala | 2 +- 28 files changed, 215 insertions(+), 1356 deletions(-) delete mode 100644 src/main/scala/analysis/Cfg.scala delete mode 100644 src/main/scala/cfg_visualiser/Output.scala create mode 100644 src/main/scala/ir/transforms/StripUnreachableFunctions.scala diff --git a/src/main/scala/analysis/Cfg.scala b/src/main/scala/analysis/Cfg.scala deleted file mode 100644 index 8807f2d2a..000000000 --- a/src/main/scala/analysis/Cfg.scala +++ /dev/null @@ -1,829 +0,0 @@ -package analysis - -import scala.collection.mutable -import ir.* -import cfg_visualiser.{DotArrow, DotGraph, DotInlineArrow, DotInterArrow, DotIntraArrow, DotNode, DotRegularArrow} - -import scala.collection.mutable.{ArrayBuffer, ListBuffer} -import scala.util.control.Breaks.break -import util.Logger - -import scala.annotation.tailrec - -/** Node in the control-flow graph. - */ -object CfgNode: - - var id: Int = 0 - - def nextId(): Int = - id += 1 - id - -/** Node in the control-flow graph. Each node has a (simple incremental) unique identifier used to distinguish it from - * other nodes in the cfg - this is mainly used for copying procedure cfgs when inlining them. - * - * Each node will store four separate sets: ingoing/outgoing of both inter-/intra-procedural CFG edges. Both intra and - * inter will also store regular edges in the cfg. This is duplication of storage, however is done so knowingly. - * - * By separating sets into inter/intra we are able to return these directly without doing any processing. Alternative - * means of achieving this same behaviour would involve some form of set operations, or a filter operation, both of - * which can be expensive, especially as the successors/predecessors will be accessed frequently by analyses. - * Additionally, inspecting the space complexity of these sets, we note that their sizes should be relatively limited: - * a. #(outgoing edges) <= 2 b. #(incoming edges) ~ N Thus in `succIntra` + `succInter` we have at most 4 elements. - * For `predIntra` + `predInter` we have a maximum of 2N, resulting from the case that this node is a block entry - * that is jumped to by N other nodes. It should be noted that in the majority of cases (statements which are - * neither the start of blocks nor jumps), both sets will be of size 1, making the storage complexity negligible. - * This is a point which can be optimised upon however. - * - * A node can have three main types of connected edges: - * a. A regular edge A regular edge connects two statements which only relate to the current procedure's context. b. - * Intra-procedural edge Intra-procedural edges connect a call node with the subsequent cfg node in a way that - * bypasses dealing with the semantics of the callee - it is up to analyses to determine how to treat such a call. - * c. Inter-procedural edge These are split into two cases: - * i. Inline edge These connect call nodes with an inlined copy of the target's procedure body. The exit of the - * target procedure's clone is also linked back to the caller via an inline edge. For an inline limit of `n`, - * these are the inter-procedural edges for depth 0 <= i < n ii. Call edge These connect leaf call nodes (calls - * which are not inlined) to the start of the independent cfg of the call's target. For an inline limit of `n`, - * these are the inter-procedural edges at depth i == n. - */ -trait CfgNode: - - /** Edges to this node from regular statements or ignored procedure calls. - * - */ - val predIntra: mutable.Set[CfgNode] = mutable.Set() - - /** Edges to this node from procedure calls. Likely empty unless this node is a [[CfgFunctionEntryNode]] - * - */ - val predInter: mutable.Set[CfgNode] = mutable.Set() - - /** Edges to successor nodes, either regular or ignored procedure calls - * - */ - val succIntra: mutable.Set[CfgNode] = mutable.Set() - - /** Edges to successor procedure calls. Used when walking inter-proc cfg. - * - */ - val succInter: mutable.Set[CfgNode] = mutable.Set() - - /** Unique identifier. */ - val id: Int = CfgNode.nextId() - def copyNode(): CfgNode - - override def equals(obj: scala.Any): Boolean = - obj match - case o: CfgNode => o.id == this.id - case _ => false - - override def hashCode(): Int = id.hashCode() - -/** Control-flow graph node that additionally stores an AST node. - */ -trait CfgNodeWithData[T] extends CfgNode { - val data: T -} - -/** Control-flow graph node for the entry of a function. - */ -class CfgFunctionEntryNode(val data: Procedure) extends CfgNodeWithData[Procedure]: - val callers: mutable.Set[CfgFunctionEntryNode] = mutable.Set[CfgFunctionEntryNode]() - override def toString: String = s"[FunctionEntry] $data" - - /** Copy this node, but give unique ID and reset edges */ - override def copyNode(): CfgFunctionEntryNode = CfgFunctionEntryNode(data) - -/** Control-flow graph node for the exit of a function. - */ -class CfgFunctionExitNode(val data: Procedure) extends CfgNodeWithData[Procedure]: - override def toString: String = s"[FunctionExit] $data" - - /** Copy this node, but give unique ID and reset edges */ - override def copyNode(): CfgFunctionExitNode = CfgFunctionExitNode(data) - -/** CFG node immediately proceeding a indirect call. This signifies that the call is a return from the current context - * (i.e., likely an indirect call to R30). Its purpose is to provide a way for analyses to identify whether they should - * return to the previous function context, if it is a context dependent analyses, and otherwise can be ignored. - * - * In the cfg we treat this as a stepping stone to `CfgFunctionExitNode`, as a way to emphasise that the current - * procedure has no functionality past this point. - */ -class CfgProcedureReturnNode() extends CfgNode: - override def toString: String = s"[ProcedureReturn]" - - /** Copy this node, but give unique ID and reset edges */ - override def copyNode(): CfgProcedureReturnNode = CfgProcedureReturnNode() - -/** CFG node immediately proceeding a direct/indirect call, if that call has no specified return block. There are a few - * reasons this can occur: - * a. It is not expected that the program will return from the callee b. The lifter has erroneously labelled a call - * as a jump / mislabelled a function name, which was then not associated with a block in some later stage. This - * happened with a call to `__gmon_start_`, which was optimised from a call to a jump which was incorrectly - * interpreted by the lifter. c. The indirect call is some other form of return-to-caller (which does not use - * R30). These are currently unhandled, and could potentially be integrated - e.g., sometimes R17 and R16 can be - * used in a similar way to R30. - * https://blog.tomzhao.me/wp-content/uploads/2021/08/Procedure_Call_Standard_in_Armv8_54f88cbfe905409aaff956ac2d1ad059.pdf - * - * In the cfg this is similarly used as a stepping stone to `CfgFunctionExitNode`. - */ -class CfgCallNoReturnNode() extends CfgNode: - override def toString: String = s"[Call NoReturn]" - - /** Copy this node, but give unique ID and reset edges */ - override def copyNode(): CfgCallNoReturnNode = CfgCallNoReturnNode() - -/** CFG node immediately proceeding a direct/indirect call, if that call has a return location specified. This serves as - * a point for analysis to stop and process update their states after handling a procedure call before continuing - * within the current contex. For example, a context sensitive analysis will return to this node after reaching a - * procedure return within the caller. It will then restore and update its context from before the call, before - * continuing on within the original procedure. - * - * Effectively, this just splits a procedure call from a single `Jmp` node into two - the call, and the return point. - * Incoming edges to the `Jmp` are then incoming edges to the respective `CfgJumpnode`, and outgoing edges from the - * `Jmp` are then outgoing edges of the `CfgCallReturnNode`. It is functionally in the same spirit as - * `CfgCallNoReturnNode`, though handles the case that this procedure still has functionality to be explored. - */ -class CfgCallReturnNode() extends CfgNode: - override def toString: String = s"[Call Return]" - - /** Copy this node, but give unique ID and reset edges */ - override def copyNode(): CfgCallReturnNode = CfgCallReturnNode() - -/** Control-flow graph node for a command (statement or jump). - */ -trait CfgCommandNode extends CfgNodeWithData[Command] { - override def copyNode(): CfgCommandNode - val block: Block - val parent: CfgFunctionEntryNode -} - -/** CFG's representation of a single statement. - */ -class CfgStatementNode( - val data: Statement, - val block: Block, - val parent: CfgFunctionEntryNode -) extends CfgCommandNode: - override def toString: String = s"[Stmt] $data" - - /** Copy this node, but give unique ID and reset edges */ - override def copyNode(): CfgStatementNode = CfgStatementNode(data, block, parent) - -/** CFG's representation of a jump. This is used as a general jump node, for both indirect and direct calls. - */ -class CfgJumpNode( - val data: Jump | DirectCall | IndirectCall, - val block: Block, - val parent: CfgFunctionEntryNode -) extends CfgCommandNode: - override def toString: String = s"[Jmp] $data" - - /** Copy this node, but give unique ID and reset edges */ - override def copyNode(): CfgJumpNode = CfgJumpNode(data, block, parent) - -/** A general purpose node which in terms of the IR has no functionality, but can have purpose in the CFG. As example, - * this is used as a "block" start node for the case that a block contains no statements, but has a `GoTo` as its jump. - * In this case we introduce a ghost node as the start of the block for the case that some part of the program jumps - * back to this conditional jump (e.g. in the case of loops). - */ -class CfgGhostNode( - val block: Block, - val parent: CfgFunctionEntryNode, - val data: NOP -) extends CfgCommandNode: - override def toString: String = s"[NOP] $data" - - /** Copy this node, but give unique ID and reset edges */ - override def copyNode(): CfgGhostNode = CfgGhostNode(block, parent, data) - -/** A control-flow graph. Nodes provide the ability to walk it as both an intra and inter procedural CFG. - */ -class ProgramCfg: - - var startNode: CfgFunctionEntryNode = _ - var nodes: mutable.Set[CfgNode] = mutable.Set() - var funEntries: mutable.Set[CfgFunctionEntryNode] = mutable.Set() - - /** Inline edges are for connecting an intraprocedural cfg with a copy of another procedure's intraprocedural cfg - * which is placed inside this one. They are considered interprocedural edges, and will not be followed if the caller - * requests an intraprocedural cfg. - */ - def addInlineEdge(from: CfgNode, to: CfgNode): Unit = { - from.succInter += to - to.predInter += from - } - - /** Interprocedural call edges connect an intraprocedural cfg with another procedure's intraprocedural cfg that it is - * calling. - */ - def addInterprocCallEdge(from: CfgNode, to: CfgNode): Unit = { - from.succInter += to - to.predInter += from - } - - /** Intraprocedural edges are for connecting call nodes to the call's return node, without following the call itself - * (stepping over the call). - */ - def addIntraprocEdge(from: CfgNode, to: CfgNode): Unit = { - from.succIntra += to - to.predIntra += from - } - - /** Regular edges are normal control flow - used in both inter-/intra-procedural cfgs. - */ - def addRegularEdge(from: CfgNode, to: CfgNode): Unit = { - from.succInter += to - from.succIntra += to - to.predInter += from - to.predIntra += from - } - - /** Add an outgoing edge from the current node, taking into account any conditionals on this jump. Note that we have - * some duplication of storage here - this is a performance consideration. We don't expect too many edges for any - * given node, and so the increased storage is relatively minimal. This saves having to filter / union sets when - * trying to retrieve only an intra/inter cfg, hopefully improving computation time. - * - * NOTE: this function attempts to "smartly" identify how to connect two edges. Perhaps as the CFG changes however - * different requirements will be made of nodes, and so the conditions on edges below may change. In that case, - * either update the below, or explicitly specify the edge to be added between two nodes. - * - * @param from - * The originating node - * @param to - * The destination node - */ - def addEdge(from: CfgNode, to: CfgNode): Unit = { - - (from, to) match { - // Ignored procedure (e.g. library calls such as @printf) - case (from: CfgFunctionEntryNode, to: CfgFunctionExitNode) => addRegularEdge(from, to) - // Calling procedure (follow as inline) - // This to be used if inlining skips the call node and links the most recent statement to the first statement of the target - case (from: CfgCommandNode, to: CfgFunctionEntryNode) => addInlineEdge(from, to) - // Returning from procedure (follow as inline - see above) - case (from: CfgFunctionExitNode, to: CfgNode) => addInlineEdge(from, to) - // First instruction of procedure - case (from: CfgFunctionEntryNode, to: CfgNode) => addRegularEdge(from, to) - // Function call which returns to the previous context - case (from: CfgJumpNode, to: CfgProcedureReturnNode) => addRegularEdge(from, to) - // Edge to intermediary return node (no semantic meaning, a cfg convenience edge) - case (from: CfgJumpNode, to: (CfgCallReturnNode | CfgCallNoReturnNode)) => addIntraprocEdge(from, to) - // Pre-exit nodes - case (from: (CfgProcedureReturnNode | CfgCallNoReturnNode | CfgCallReturnNode), to: CfgFunctionExitNode) => - addRegularEdge(from, to) - // Regular continuation of execution - case (from: CfgCallReturnNode, to: CfgCommandNode) => addRegularEdge(from, to) - // Regular flow of instructions - case (from: CfgCommandNode, to: (CfgCommandNode | CfgFunctionExitNode)) => addRegularEdge(from, to) - case _ => throw new Exception(s"[!] Unexpected edge combination when adding cfg edge between $from -> $to.") - } - - nodes += from - nodes += to - } - - /** Returns a Graphviz dot representation of the CFG. Each node is labeled using the given function labeler. - */ - def toDot(labeler: CfgNode => String, idGen: (CfgNode, Int) => String): String = { - val dotNodes = mutable.Map[CfgNode, DotNode]() - var dotArrows = mutable.ListBuffer[DotArrow]() - var uniqueId = 0 - nodes.foreach { n => - dotNodes += (n -> DotNode(s"${idGen(n, uniqueId)}", labeler(n))) - uniqueId += 1 - } - nodes.foreach { n => - - val successors = n.succIntra.toSet.union(n.succInter) - - successors.foreach { s => - (n, s) match { - case (from: CfgFunctionEntryNode, to: CfgNode) => - dotArrows += DotRegularArrow(dotNodes(n), dotNodes(to)) - case (from: CfgJumpNode, to: CfgProcedureReturnNode) => - dotArrows += DotRegularArrow(dotNodes(n), dotNodes(to)) - case (from: (CfgProcedureReturnNode | CfgCallNoReturnNode | CfgCallReturnNode), to: CfgFunctionExitNode) => - dotArrows += DotRegularArrow(dotNodes(n), dotNodes(to)) - case (from: CfgCallReturnNode, to: CfgCommandNode) => - dotArrows += DotRegularArrow(dotNodes(n), dotNodes(to)) - case (from: CfgCommandNode, to: (CfgCommandNode | CfgFunctionExitNode)) => - dotArrows += DotRegularArrow(dotNodes(n), dotNodes(to)) - case (from: CfgCommandNode, to: CfgFunctionEntryNode) => - DotInlineArrow(dotNodes(n), dotNodes(to)) - case (from: CfgFunctionExitNode, to: CfgNode) => - DotInlineArrow(dotNodes(n), dotNodes(to)) - case (from: CfgJumpNode, to: (CfgCallReturnNode | CfgCallNoReturnNode)) => - dotArrows += DotIntraArrow(dotNodes(n), dotNodes(to)) - /* - Displaying the below in the CFG is mostly for debugging purposes. With it included the CFG becomes a little unreadable, but - will emphasise that the leaf-call nodes are linked to the start of the procedures they're calling (as green inter-procedural edges). - To verify this is still happening, simply uncomment the below and it will add these edges. - case (from: CfgCommandNode, to: CfgFunctionEntry) => - dotArrows += DotInterArrow(dotNodes(n), dotNodes(to)) - */ - - case _ => - } - } - } - dotArrows = dotArrows.sortBy(arr => arr.fromNode.id + "-" + arr.toNode.id) - val allNodes = dotNodes.values.toList.sortBy(n => n.id) - DotGraph("CFG", allNodes, dotArrows).toDotString - } - - override def toString: String = { - val sb = StringBuilder() - sb.append("CFG {") - sb.append(" nodes: ") - sb.append(nodes) - sb.append("}") - sb.toString() - } - -/** Control-flow graph for an entire program. We have a more granular approach, storing commands as nodes instead of - * basic blocks. - */ -class ProgramCfgFactory: - val cfg: ProgramCfg = ProgramCfg() - - // Mapping from procedures to the start of their individual (intra) cfgs - val procToCfg: mutable.Map[Procedure, (CfgFunctionEntryNode, CfgFunctionExitNode)] = mutable.Map() - // Mapping from procedures to procedure call nodes (all the calls made within this procedure, including inlined functions) - val procToCalls: mutable.Map[Procedure, mutable.Set[CfgJumpNode]] = mutable.Map() - // Mapping from procedure entry instances to procedure call nodes within that procedure's instance (`CfgCommandNode.data <: DirectCall`) - // Updated on first creation of a new procedure (e.g. in initial creation, or in cloning of a procedure's cfg) - val callToNodes: mutable.Map[CfgFunctionEntryNode, mutable.Set[CfgJumpNode]] = mutable.Map() - // Mapping from procedures to nodes in any node in the cfg which has a call to that procedure - val procToCallers: mutable.Map[Procedure, mutable.Set[CfgJumpNode]] = mutable.Map() - - /** Generate the cfg for each function of the program. NOTE: is this functionally different to a constructor? Do we - * ever expect to generate a CFG from any other data structure? If not then the `class` could probably be absorbed - * into this object. - * - * @param program - * Basil IR of the program - * @param inlineLimit - * How many levels deep to inline function calls. Default is 3 - */ - def fromIR(program: Program, unify: Boolean = true, inlineLimit: Int = 0): ProgramCfg = { - CfgNode.id = 0 - require(inlineLimit >= 0, "Can't inline procedures to negative depth...") - Logger.info("[+] Generating CFG...") - - // Have to initialise the map entries manually. Scala maps have a `.withDefaulValue`, but this is buggy and doesn't - // behave as you would expect: https://github.com/scala/bug/issues/8099 - thus the manual approach. - // We don't initialise `procToCfg` here, because it will never be accessed before `cfgForProcedure`, - // and because it relies on the entry/exit nodes be initialised. It is initialised in `cfgForProcedure`. - program.procedures.foreach(proc => - procToCalls += (proc -> mutable.Set()) - procToCallers += (proc -> mutable.Set()) - ) - - // Create CFG for individual procedures - program.procedures.foreach( - proc => { - val funcEntryNode: CfgFunctionEntryNode = CfgFunctionEntryNode(proc) - val funcExitNode: CfgFunctionExitNode = CfgFunctionExitNode(proc) - cfg.nodes += funcEntryNode - cfg.nodes += funcExitNode - cfg.funEntries += funcEntryNode - - procToCfg += (proc -> (funcEntryNode, funcExitNode)) - callToNodes += (funcEntryNode -> mutable.Set()) - } - ) - program.procedures.foreach(proc => cfgForProcedure(proc)) - - // Inline functions up to `inlineLimit` level - // EXTENSION; one way to improve this would be to specify inline depths for specific functions / situations. - // i.e. we may not want to inline self-recursive functions too much. - // Of note is whether we want this at all or note. If not, then we can simply remove the below and pass `procCallNodes` to - // `addInterprocEdges`. - val procCallNodes: Set[CfgJumpNode] = procToCalls.values.flatten.toSet - val leafCallNodes: Set[CfgJumpNode] = - if !unify then inlineProcedureCalls(procCallNodes, inlineLimit) else procCallNodes - - // Add inter-proc edges to leaf call nodes - if (leafCallNodes.nonEmpty) { - addInterprocEdges(leafCallNodes) - } - - cfg.startNode = procToCfg(program.mainProcedure)._1 - - cfg - } - - /** Create an intraprocedural CFG for the given IR procedure. The start of the CFG for a procedure is identified by - * its `CfgFunctionEntryNode`, and its closure is identified by the `CfgFunctionExitNode`. - * - * @param proc - * Procedure for which to generate the intraprocedural cfg - */ - private def cfgForProcedure(proc: Procedure): Unit = { - val funcEntryNode: CfgFunctionEntryNode = procToCfg(proc)._1 - val funcExitNode: CfgFunctionExitNode = procToCfg(proc)._2 - - // Track blocks we've already processed so we don't double up - val visitedBlocks: mutable.Map[Block, CfgCommandNode] = mutable.Map() - - // Procedure has no content (in our case this probably means it's an ignored procedure, e.g., an external function such as @printf) - if (proc.blocks.isEmpty) { - cfg.addEdge(funcEntryNode, funcExitNode) - } else { - // Recurse through blocks - visitBlock(proc.entryBlock.get, funcEntryNode) - } - - /** Add a block to the CFG. A block in this case is a basic block, so it contains a list of consecutive statements - * followed by a jump at the end to another block. We process statements in this block (if they exist), and then - * follow the jump to recurse through all other blocks. - * - * This recursive approach is effectively a "reaches" approach, and will miss cases that we encounter a jump we - * can't resolve, or cases where the lifter has not identified a section of code. In each case: - * a. The only jumps we can't resolve are indirect calls. It's the intent of the tool to attempt to resolve these - * through analysis however. The CFG can then be updated as these are resolved to incorporate their jumps. In - * construction we do a simple check for register R30 to identify if an indirect call is a return, but - * otherwise consider it as unresolved. b. If the lifter has failed to identify a region of code, then the - * problem exists at the lifter level. In that case we need a way to coerce the lifter into identifying it, or - * to use a new lifter. - * - * These visitations will also only produce the intra-procedural CFG - the burden of "creating" the - * inter-procedural CFG is left to processes later during CFG construction. The benefit of doing this is that we - * can completely resolve a procedure's CFG without jumping to other procedures mid-way through processing, which - * assures we don't have any issues with referencing nodes before they exist. Essentially this is a depth-first - * approach to CFG construction, as opposed to a breadth-first. - * - * @param block - * The block being added to the CFG. - * @param prevBlockEnd - * Preceding block's end node (jump) - */ - def visitBlock(block: Block, prevBlockEnd: CfgNode): Unit = { - - if (block.statements.nonEmpty) { - val endStmt = visitStmts(block.statements, prevBlockEnd) - visitJump(block.jump, endStmt, false) - } else { - // Only jumps in this block - visitJump(block.jump, prevBlockEnd, true) - } - - /** If a block has statements, we add them to the CFG. Blocks in this case are basic blocks, so we know - * consecutive statements will be linked by an unconditional, regular edge. - * - * @param stmts - * Statements in this block - * @param prevNode - * Preceding block's end node (jump) - * @return - * The last statement's CFG node - */ - def visitStmts(stmts: Iterable[Statement], prevNode: CfgNode): CfgCommandNode = { - - val firstNode = CfgStatementNode(stmts.head, block, funcEntryNode) - cfg.addEdge(prevNode, firstNode) - visitedBlocks += (block -> firstNode) // This is guaranteed to be entrance to block if we are here - - val statements = List.from(stmts).map(s => s match { - case d: DirectCall => CfgJumpNode(d, block, funcEntryNode) - case d: IndirectCall => CfgJumpNode(d, block, funcEntryNode) - case o => CfgStatementNode(o, block, funcEntryNode) - }) - val succs = if (statements.nonEmpty) then statements.zip(statements.tail ++ List(CfgJumpNode(statements.head.data.parent.jump, block, funcEntryNode))) else List() - - for ((s,nexts) <- succs) { - s.data match { - case dCall: DirectCall => - - var precNode = prevNode - - val targetProc: Procedure = dCall.target - funcEntryNode.callers.add(procToCfg(targetProc)._1) - - val callNode : CfgJumpNode = s.asInstanceOf[CfgJumpNode] - - // Branch to this call - cfg.addEdge(precNode, callNode) - - procToCalls(proc) += callNode - procToCallers(targetProc) += callNode - callToNodes(funcEntryNode) += callNode - - // Record call association - - // Jump to return location - val returnTarget = nexts - // Add intermediary return node (split call into call and return) - val callRet = CfgCallReturnNode() - cfg.addEdge(callNode, callRet) - cfg.addEdge(callRet, returnTarget) - case iCall: IndirectCall => - Logger.debug(s"Indirect call found: $iCall in ${proc.name}") - var precNode = prevNode - - val jmpNode = s.asInstanceOf[CfgJumpNode] - // Branch to this call - cfg.addEdge(precNode, jmpNode) - - // Record call association - procToCalls(proc) += jmpNode - callToNodes(funcEntryNode) += jmpNode - - // R30 is the link register - this stores the address to return to. - // For now just add a node expressing that we are to return to the previous context. - if (iCall.target == Register("R30", 64)) { - val returnNode = CfgProcedureReturnNode() - cfg.addEdge(jmpNode, returnNode) - cfg.addEdge(returnNode, funcExitNode) - } - - val callRet = CfgCallReturnNode() - cfg.addEdge(jmpNode, callRet) - val returnTarget = nexts - cfg.addEdge(callRet, jmpNode) - case h: Halt => { - assert(false); - // not possible since s is only Statement. - } - case _ => () - } - } - - - if (stmts.size == 1) { - return firstNode - } - - var prevStmtNode: CfgStatementNode = firstNode - - stmts.tail.foreach(stmt => - val stmtNode = CfgStatementNode(stmt, block, funcEntryNode) - cfg.addEdge(prevStmtNode, stmtNode) - prevStmtNode = stmtNode - ) - - prevStmtNode - } - - /** All blocks end with jump(s), whereas some also start with a jump (in the case of no statements). Add these to - * the CFG and visit their target blocks for processing. - * - * @param jmps - * Jumps in the current block being processed - * @param prevNode - * Either the previous statement in the block, or the previous block's end node (in the case that this block - * contains no statements) - * @param solitary - * `True` if this block contains no statements, `False` otherwise - */ - def visitJump(jmp: Jump, prevNode: CfgNode, solitary: Boolean): Unit = { - val jmpNode = CfgJumpNode(jmp, block, funcEntryNode) - var precNode = prevNode - - if (solitary) { - /* If the block contains only jumps (no statements), then the "start" of the block is a jump. - If this is a direct call, then we simply use that call node as the start of the block. - However, GoTos in the CFG are resolved as edges, and so there doesn't exist a node to use as - the start. Thus we introduce a "ghost" node to act as that jump point - it has no functionality - and will simply be skipped by analyses. - - Currently we display these nodes in the DOT view of the CFG, however these could be hidden if desired. - */ - jmp match { - case jmp: GoTo => - // `GoTo`s are just edges, so introduce a fake `start of block` that can be jmp'd to - val ghostNode = CfgGhostNode(block, funcEntryNode, NOP(jmp.label)) - cfg.addEdge(prevNode, ghostNode) - precNode = ghostNode - visitedBlocks += (block -> ghostNode) - case _ => - // (In)direct call - use this as entrance to block - visitedBlocks += (block -> jmpNode) - } - } - - jmp match { - case n: GoTo => - for (targetBlock <- n.targets) { - if (visitedBlocks.contains(targetBlock)) { - val targetBlockEntry: CfgCommandNode = visitedBlocks(targetBlock) - cfg.addEdge(precNode, targetBlockEntry) - } else { - visitBlock(targetBlock, precNode) - } - } - case h: Halt => { - cfg.addEdge(jmpNode, funcExitNode) - } - case r: Return => - // Branch to this call - cfg.addEdge(precNode, jmpNode) - - // Record call association - procToCalls(proc) += jmpNode - callToNodes(funcEntryNode) += jmpNode - - val returnNode = CfgProcedureReturnNode() - cfg.addEdge(jmpNode, returnNode) - cfg.addEdge(returnNode, funcExitNode) - } // `jmps.head` match - } // `visitJumps` function - } // `visitBlocks` function - } // `cfgForProcedure` function - - /** This takes an expression used in a conditional (jump) and tries to negate it in a (hopefully) nice way. Most - * conditional jumps are just bitvector comparisons. - * - * @param expr - * The expression to negate - * @return - * The negated expression - */ - private def negateConditional(expr: Expr): Expr = expr match { - case binop: BinaryExpr => - binop.op match { - case BVNEQ => - BinaryExpr( - BVEQ, - binop.arg1, - binop.arg2 - ) - case BVEQ => - BinaryExpr( - BVNEQ, - binop.arg1, - binop.arg2 - ) - case _ => - // Worst case scenario we just take the logical not of everything - UnaryExpr( - BoolNOT, - binop - ) - } - case unop: UnaryExpr => - unop.op match { - case BVNOT | BoolNOT => - unop.arg - case _ => - UnaryExpr( - BoolNOT, - unop - ) - } - case _ => - UnaryExpr( - BoolNOT, - expr - ) - } - - /** Recursively inline procedures. This has a dumb/flat approach - we simply continue inlining each all direct calls - * until we either run out of direct calls, or we are at our max inline depth. - * - * For each direct call to be inlined we make a copy of the target's intraprocedural cfg, which is then linked to the - * calling procedure's cfg. We keep track of newly found direct calls that come from inlined functions, which is what - * we pass to the next recursive call. At the end of recursion this set stores the leaf nodes of the cfg - this is - * then used later to link interprocedural calls. - * - * @param procNodes - * The call nodes to inline - * @param inlineAmount - * Maximum amount of inlining from this depth allowed - * @return - * Tthe next leaf call nodes - */ - @tailrec - private def inlineProcedureCalls(procNodes: Set[CfgJumpNode], inlineAmount: Int): Set[CfgJumpNode] = { - assert(inlineAmount >= 0) - Logger.info(s"[+] Inlining ${procNodes.size} leaf call nodes with $inlineAmount level(s) left") - - if (inlineAmount == 0 || procNodes.isEmpty) { - return procNodes - } - - // Set of procedure calls to be discovered by inlining the ones in `procNodes` - val nextProcNodes: mutable.Set[CfgJumpNode] = mutable.Set() - - procNodes.foreach { procNode => - procNode.data match { - case targetCall: DirectCall => - // Retrieve information about the call to the target procedure - val targetProc = targetCall.target - val (procEntry, procExit) = cloneProcedureCFG(targetProc) - - // Add link between call node and the procedure's `Entry`. - cfg.addInlineEdge(procNode, procEntry) - - // Link the procedure's `Exit` to the return point. There should only be one. - assert( - procNode.succIntra.size == 1, - s"More than 1 return node... $procNode has ${procNode.succIntra}" - ) - val returnNode = procNode.succIntra.head - cfg.addInlineEdge(procExit, returnNode) - - // Add new (un-inlined) function calls to be inlined - nextProcNodes ++= callToNodes(procEntry) - case _ => - } - } - - inlineProcedureCalls(nextProcNodes.toSet, inlineAmount - 1) - } - - /** Clones the intraproc-cfg of the given procedure, with unique CfgNode ids. Adds the new nodes to the cfg, and - * returns the start/end nodes of the new procedure cfg. - * - * @param proc - * The procedure to clone (used to index the pre-computed cfgs) - * @return - * (CfgFunctionEntryNode, CfgFunctionExitNode) of the cloned cfg - */ - private def cloneProcedureCFG(proc: Procedure): (CfgFunctionEntryNode, CfgFunctionExitNode) = { - - val (entryNode: CfgFunctionEntryNode, exitNode: CfgFunctionExitNode) = procToCfg(proc) - val (newEntry: CfgFunctionEntryNode, newExit: CfgFunctionExitNode) = (entryNode.copyNode(), exitNode.copyNode()) - - callToNodes += (newEntry -> mutable.Set()) - - // Entry is guaranteed to only have one successor (by our cfg design) - val currNode: CfgNode = entryNode.succIntra.head - visitNode(currNode, newEntry) - - /** Walk this proc's cfg until we reach the exit node on each branch. We do this recursively, tracking the previous - * node, to account for branches and loops. - * - * We can't represent the parameters as an edge as one node comes from the old cfg, and the other from the new cfg. - * - * @param node - * Node in the original procedure's cfg we're up to cloning - * @param prevNewNode - * The originating node in the new clone's cfg - */ - def visitNode(node: CfgNode, prevNewNode: CfgNode): Unit = { - - if (node == exitNode) { - cfg.addEdge(prevNewNode, newExit) - return - } - - node match { - case n: CfgJumpNode => - val newNode = n.copyNode() - - // Link this node with predecessor in the new cfg - cfg.addEdge(prevNewNode, newNode) - - n.data match { - case d: DirectCall => - procToCalls(proc) += newNode - callToNodes(newEntry) += newNode - procToCallers(d.target) += newNode - case i: IndirectCall => - procToCalls(proc) += newNode - callToNodes(newEntry) += newNode - case _ => - } - - // Get intra-cfg successors - val outNodes = node.succIntra - outNodes.foreach(node => visitNode(node, newNode)) - - // For other node types, link with predecessor and continue traversal - case _ => - val newNode = node.copyNode() - cfg.addEdge(prevNewNode, newNode) - - val outNodes = node.succIntra - outNodes.foreach(node => visitNode(node, newNode)) - } - } - - (newEntry, newExit) - } - - /** After inlining has been done, we link all residual direct calls (leaf nodes) to the start of the intraprocedural - * that are the target of the call. - * - * @param leaves - * The call nodes at edge of intraprocedural cfgs to be linked to their targets - */ - private def addInterprocEdges(leaves: Set[CfgJumpNode]): Unit = { - - leaves.foreach { callNode => - callNode.data match { - case targetCall: DirectCall => - val targetProc: Procedure = targetCall.target - - // this does not add returns for any of the calls, so the interprocedural analysis will not work if any - // calls are not in-lined - val (targetEntry: CfgFunctionEntryNode, _) = procToCfg(targetProc) - - cfg.addInterprocCallEdge(callNode, targetEntry) - case _ => - } - } - } diff --git a/src/main/scala/analysis/Dependencies.scala b/src/main/scala/analysis/Dependencies.scala index 040861803..4b5e15106 100644 --- a/src/main/scala/analysis/Dependencies.scala +++ b/src/main/scala/analysis/Dependencies.scala @@ -22,26 +22,6 @@ trait Dependencies[N]: */ def indep(n: N): Set[N] -trait InterproceduralForwardDependencies extends Dependencies[CfgNode] { - override def outdep(n: CfgNode): Set[CfgNode] = n.succInter.toSet - override def indep(n: CfgNode): Set[CfgNode] = n.predInter.toSet -} - -trait IntraproceduralForwardDependencies extends Dependencies[CfgNode] { - override def outdep(n: CfgNode): Set[CfgNode] = n.succIntra.toSet - override def indep(n: CfgNode): Set[CfgNode] = n.predIntra.toSet -} - -trait InterproceduralBackwardDependencies extends Dependencies[CfgNode] { - override def outdep(n: CfgNode): Set[CfgNode] = n.predInter.toSet - override def indep(n: CfgNode): Set[CfgNode] = n.succInter.toSet -} - -trait IntraproceduralBackwardDependencies extends Dependencies[CfgNode] { - override def outdep(n: CfgNode): Set[CfgNode] = n.predIntra.toSet - override def indep(n: CfgNode): Set[CfgNode] = n.succIntra.toSet -} - trait IRInterproceduralForwardDependencies extends Dependencies[CFGPosition] { override def outdep(n: CFGPosition): Set[CFGPosition] = InterProcIRCursor.succ(n) override def indep(n: CFGPosition): Set[CFGPosition] = InterProcIRCursor.pred(n) diff --git a/src/main/scala/analysis/InterLiveVarsAnalysis.scala b/src/main/scala/analysis/InterLiveVarsAnalysis.scala index 7a93266e8..cbe3076c6 100644 --- a/src/main/scala/analysis/InterLiveVarsAnalysis.scala +++ b/src/main/scala/analysis/InterLiveVarsAnalysis.scala @@ -1,7 +1,7 @@ package analysis import analysis.solvers.BackwardIDESolver -import ir.{Assert, Assume, Block, GoTo, CFGPosition, Command, DirectCall, IndirectCall, Assign, MemoryAssign, Halt, Return, Procedure, Program, Variable, toShortString} +import ir.{Assert, Assume, Block, GoTo, CFGPosition, Command, DirectCall, IndirectCall, Assign, MemoryAssign, Unreachable, Return, Procedure, Program, Variable, toShortString} /** * Micro-transfer-functions for LiveVar analysis @@ -74,11 +74,7 @@ trait LiveVarsAnalysisFunctions extends BackwardIDEAnalysis[Variable, TwoElement d match case Left(value) => if value != variable then Map(d -> IdEdge()) else Map() case Right(_) => Map(d -> IdEdge(), Left(variable) -> ConstEdge(TwoElementTop)) - case r: Return => Map(d -> IdEdge()) - case h: Halt => Map(d -> IdEdge()) - case c: DirectCall => Map(d -> IdEdge()) - case c: Block => Map(d -> IdEdge()) - case c: GoTo => Map(d -> IdEdge()) + case _ => Map(d -> IdEdge()) } } diff --git a/src/main/scala/analysis/InterprocSteensgaardAnalysis.scala b/src/main/scala/analysis/InterprocSteensgaardAnalysis.scala index 38153e277..4e7ba81f1 100644 --- a/src/main/scala/analysis/InterprocSteensgaardAnalysis.scala +++ b/src/main/scala/analysis/InterprocSteensgaardAnalysis.scala @@ -39,7 +39,7 @@ case class RegisterWrapperEqualSets(variable: Variable, assigns: Set[Assign]) { class InterprocSteensgaardAnalysis( program: Program, constantProp: Map[CFGPosition, Map[RegisterWrapperEqualSets, Set[BitVecLiteral]]], - regionAccesses: Map[CfgNode, Map[RegisterVariableWrapper, FlatElement[Expr]]], + regionAccesses: Map[CFGPosition, Map[RegisterVariableWrapper, FlatElement[Expr]]], mmm: MemoryModelMap, reachingDefs: Map[CFGPosition, (Map[Variable, Set[Assign]], Map[Variable, Set[Assign]])], globalOffsets: Map[BigInt, BigInt]) extends Analysis[Any] { @@ -431,4 +431,4 @@ object Fresh { n += 1 n } -} \ No newline at end of file +} diff --git a/src/main/scala/analysis/IntraLiveVarsAnalysis.scala b/src/main/scala/analysis/IntraLiveVarsAnalysis.scala index 75fa1dbb0..a576b27fb 100644 --- a/src/main/scala/analysis/IntraLiveVarsAnalysis.scala +++ b/src/main/scala/analysis/IntraLiveVarsAnalysis.scala @@ -1,7 +1,7 @@ package analysis import analysis.solvers.SimpleWorklistFixpointSolver -import ir.{Assert, Assume, Block, CFGPosition, Call, DirectCall, GoTo, IndirectCall, Jump, Assign, MemoryAssign, NOP, Procedure, Program, Statement, Variable, Return, Halt} +import ir.{Assert, Assume, Block, CFGPosition, Call, DirectCall, GoTo, IndirectCall, Jump, Assign, MemoryAssign, NOP, Procedure, Program, Statement, Variable, Return, Unreachable} abstract class LivenessAnalysis(program: Program) extends Analysis[Any]: val lattice: MapLattice[CFGPosition, Set[Variable], PowersetLattice[Variable]] = MapLattice(PowersetLattice()) @@ -19,7 +19,7 @@ abstract class LivenessAnalysis(program: Program) extends Analysis[Any]: case c: DirectCall => s case g: GoTo => s case r: Return => s - case r: Halt => s + case r: Unreachable => s case _ => ??? } } diff --git a/src/main/scala/analysis/MemoryRegionAnalysis.scala b/src/main/scala/analysis/MemoryRegionAnalysis.scala index 65aa1cc20..c7d888a43 100644 --- a/src/main/scala/analysis/MemoryRegionAnalysis.scala +++ b/src/main/scala/analysis/MemoryRegionAnalysis.scala @@ -14,7 +14,7 @@ trait MemoryRegionAnalysis(val program: Program, val constantProp: Map[CFGPosition, Map[Variable, FlatElement[BitVecLiteral]]], val ANRResult: Map[CFGPosition, Set[Variable]], val RNAResult: Map[CFGPosition, Set[Variable]], - val regionAccesses: Map[CfgNode, Map[RegisterVariableWrapper, FlatElement[Expr]]], + val regionAccesses: Map[CFGPosition, Map[RegisterVariableWrapper, FlatElement[Expr]]], reachingDefs: Map[CFGPosition, (Map[Variable, Set[Assign]], Map[Variable, Set[Assign]])]) { var mallocCount: Int = 0 @@ -234,7 +234,7 @@ class MemoryRegionAnalysisSolver( constantProp: Map[CFGPosition, Map[Variable, FlatElement[BitVecLiteral]]], ANRResult: Map[CFGPosition, Set[Variable]], RNAResult: Map[CFGPosition, Set[Variable]], - regionAccesses: Map[CfgNode, Map[RegisterVariableWrapper, FlatElement[Expr]]], + regionAccesses: Map[CFGPosition, Map[RegisterVariableWrapper, FlatElement[Expr]]], reachingDefs: Map[CFGPosition, (Map[Variable, Set[Assign]], Map[Variable, Set[Assign]])] ) extends MemoryRegionAnalysis(program, globals, globalOffsets, subroutines, constantProp, ANRResult, RNAResult, regionAccesses, reachingDefs) with IRIntraproceduralForwardDependencies @@ -249,4 +249,4 @@ class MemoryRegionAnalysisSolver( case _ => super.funsub(n, x) } } -} \ No newline at end of file +} diff --git a/src/main/scala/analysis/RegToMemAnalysis.scala b/src/main/scala/analysis/RegToMemAnalysis.scala index df7217f75..1afde6f4d 100644 --- a/src/main/scala/analysis/RegToMemAnalysis.scala +++ b/src/main/scala/analysis/RegToMemAnalysis.scala @@ -15,29 +15,29 @@ import scala.collection.immutable * * Both in which constant propagation mark as TOP which is not useful. */ -trait RegionAccessesAnalysis(cfg: ProgramCfg, constantProp: Map[CFGPosition, Map[Variable, FlatElement[BitVecLiteral]]], reachingDefs: Map[CFGPosition, (Map[Variable, Set[Assign]], Map[Variable, Set[Assign]])]) { +trait RegionAccessesAnalysis(program: Program, constantProp: Map[CFGPosition, Map[Variable, FlatElement[BitVecLiteral]]], reachingDefs: Map[CFGPosition, (Map[Variable, Set[Assign]], Map[Variable, Set[Assign]])]) { val mapLattice: MapLattice[RegisterVariableWrapper, FlatElement[Expr], FlatLattice[Expr]] = MapLattice(FlatLattice[_root_.ir.Expr]()) - val lattice: MapLattice[CfgNode, Map[RegisterVariableWrapper, FlatElement[Expr]], MapLattice[RegisterVariableWrapper, FlatElement[Expr], FlatLattice[Expr]]] = MapLattice(mapLattice) + val lattice: MapLattice[CFGPosition, Map[RegisterVariableWrapper, FlatElement[Expr]], MapLattice[RegisterVariableWrapper, FlatElement[Expr], FlatLattice[Expr]]] = MapLattice(mapLattice) - val domain: Set[CfgNode] = cfg.nodes.toSet + val domain: Set[CFGPosition] = program.toSet - val first: Set[CfgNode] = Set(cfg.startNode) + val first: Set[CFGPosition] = program.procedures.toSet /** Default implementation of eval. */ - def eval(cmd: CfgCommandNode, constants: Map[Variable, FlatElement[BitVecLiteral]], s: Map[RegisterVariableWrapper, FlatElement[Expr]]): Map[RegisterVariableWrapper, FlatElement[Expr]] = { - cmd.data match { + def eval(cmd: Statement, constants: Map[Variable, FlatElement[BitVecLiteral]], s: Map[RegisterVariableWrapper, FlatElement[Expr]]): Map[RegisterVariableWrapper, FlatElement[Expr]] = { + cmd match { case assign: Assign => assign.rhs match { case memoryLoad: MemoryLoad => - s + (RegisterVariableWrapper(assign.lhs, getDefinition(assign.lhs, cmd.data, reachingDefs)) -> FlatEl(memoryLoad)) + s + (RegisterVariableWrapper(assign.lhs, getDefinition(assign.lhs, cmd, reachingDefs)) -> FlatEl(memoryLoad)) case binaryExpr: BinaryExpr => if (evaluateExpression(binaryExpr.arg1, constants).isEmpty) { // approximates Base + Offset Logger.debug(s"Approximating $assign in $binaryExpr") - Logger.debug(s"Reaching defs: ${reachingDefs(cmd.data)}") - s + (RegisterVariableWrapper(assign.lhs, getDefinition(assign.lhs, cmd.data, reachingDefs)) -> FlatEl(binaryExpr)) + Logger.debug(s"Reaching defs: ${reachingDefs(cmd)}") + s + (RegisterVariableWrapper(assign.lhs, getDefinition(assign.lhs, cmd, reachingDefs)) -> FlatEl(binaryExpr)) } else { s } @@ -50,23 +50,23 @@ trait RegionAccessesAnalysis(cfg: ProgramCfg, constantProp: Map[CFGPosition, Map /** Transfer function for state lattice elements. */ - def localTransfer(n: CfgNode, s: Map[RegisterVariableWrapper, FlatElement[Expr]]): Map[RegisterVariableWrapper, FlatElement[Expr]] = n match { - case cmd: CfgCommandNode => - eval(cmd, constantProp(cmd.data), s) + def localTransfer(n: CFGPosition, s: Map[RegisterVariableWrapper, FlatElement[Expr]]): Map[RegisterVariableWrapper, FlatElement[Expr]] = n match { + case cmd: Statement => + eval(cmd, constantProp(cmd), s) case _ => s // ignore other kinds of nodes } /** Transfer function for state lattice elements. */ - def transfer(n: CfgNode, s: Map[RegisterVariableWrapper, FlatElement[Expr]]): Map[RegisterVariableWrapper, FlatElement[Expr]] = localTransfer(n, s) + def transfer(n: CFGPosition, s: Map[RegisterVariableWrapper, FlatElement[Expr]]): Map[RegisterVariableWrapper, FlatElement[Expr]] = localTransfer(n, s) } class RegionAccessesAnalysisSolver( - cfg: ProgramCfg, + program: Program, constantProp: Map[CFGPosition, Map[Variable, FlatElement[BitVecLiteral]]], reachingDefs: Map[CFGPosition, (Map[Variable, Set[Assign]], Map[Variable, Set[Assign]])], - ) extends RegionAccessesAnalysis(cfg, constantProp, reachingDefs) - with InterproceduralForwardDependencies - with Analysis[Map[CfgNode, Map[RegisterVariableWrapper, FlatElement[Expr]]]] - with SimpleWorklistFixpointSolver[CfgNode, Map[RegisterVariableWrapper, FlatElement[Expr]], MapLattice[RegisterVariableWrapper, FlatElement[Expr], FlatLattice[Expr]]] { -} \ No newline at end of file + ) extends RegionAccessesAnalysis(program, constantProp, reachingDefs) + with IRInterproceduralForwardDependencies + with Analysis[Map[CFGPosition, Map[RegisterVariableWrapper, FlatElement[Expr]]]] + with SimpleWorklistFixpointSolver[CFGPosition, Map[RegisterVariableWrapper, FlatElement[Expr]], MapLattice[RegisterVariableWrapper, FlatElement[Expr], FlatLattice[Expr]]] { +} diff --git a/src/main/scala/analysis/solvers/IDESolver.scala b/src/main/scala/analysis/solvers/IDESolver.scala index 231c87521..eaa9a2369 100644 --- a/src/main/scala/analysis/solvers/IDESolver.scala +++ b/src/main/scala/analysis/solvers/IDESolver.scala @@ -1,7 +1,7 @@ package analysis.solvers import analysis.{BackwardIDEAnalysis, Dependencies, EdgeFunction, EdgeFunctionLattice, ForwardIDEAnalysis, IDEAnalysis, IRInterproceduralBackwardDependencies, IRInterproceduralForwardDependencies, Lambda, Lattice, MapLattice} -import ir.{CFGPosition, Command, DirectCall, GoTo, IRWalk, IndirectCall, Return, InterProcIRCursor, Procedure, Program, isAfterCall, Halt, Statement, Jump} +import ir.{CFGPosition, Command, DirectCall, GoTo, IRWalk, IndirectCall, Return, InterProcIRCursor, Procedure, Program, isAfterCall, Unreachable, Statement, Jump} import util.Logger import scala.collection.immutable.Map @@ -213,7 +213,7 @@ abstract class ForwardIDESolver[D, T, L <: Lattice[T]](program: Program) protected def entryToExit(entry: Procedure): Return = IRWalk.lastInProc(entry).asInstanceOf[Return] - protected def exitToEntry(exit: IndirectCall): Procedure = IRWalk.procedure(exit) + protected def exitToEntry(exit: Return): Procedure = IRWalk.procedure(exit) protected def callToReturn(call: DirectCall): Command = call.successor @@ -229,13 +229,13 @@ abstract class ForwardIDESolver[D, T, L <: Lattice[T]](program: Program) protected def isCall(call: CFGPosition): Boolean = call match - case directCall: DirectCall if (!directCall.successor.isInstanceOf[Halt]) => true + case directCall: DirectCall if (!directCall.successor.isInstanceOf[Unreachable]) => true case _ => false protected def isExit(exit: CFGPosition): Boolean = exit match // only looking at functions with statements - case command: Command => IRWalk.lastInProc(IRWalk.procedure(command)) == command + case command: Return => true case _ => false protected def getAfterCalls(exit: IndirectCall): Set[Command] = @@ -268,12 +268,19 @@ abstract class BackwardIDESolver[D, T, L <: Lattice[T]](program: Program) protected def isCall(call: CFGPosition): Boolean = call match - case c : Command => isAfterCall(c) && IRWalk.prevCommandInBlock(c).map(_.isInstanceOf[DirectCall]).getOrElse(false) + case c: Unreachable => false /* don't process non-returning calls */ + case c : Command => { + val call = IRWalk.prevCommandInBlock(c) + call match { + case Some(d: DirectCall) if d.target.returnBlock.isDefined => true + case _ => false + } + } case _ => false protected def isExit(exit: CFGPosition): Boolean = exit match - case procedure: Procedure => true + case procedure: Procedure => procedure.blocks.nonEmpty case _ => false protected def getAfterCalls(exit: Procedure): Set[DirectCall] = exit.incomingCalls().toSet diff --git a/src/main/scala/cfg_visualiser/Output.scala b/src/main/scala/cfg_visualiser/Output.scala deleted file mode 100644 index cc730d62e..000000000 --- a/src/main/scala/cfg_visualiser/Output.scala +++ /dev/null @@ -1,37 +0,0 @@ -package cfg_visualiser - -import java.io.{File, PrintWriter} -import analysis._ - -/** Basic outputting functionality. - */ -object Output { - - /** Helper function for producing string output for a control-flow graph node after an analysis. - * @param res - * map from control-flow graph nodes to strings, as produced by the analysis - */ - def labeler(res: Map[CfgNode, _], stateAfterNode: Boolean)(n: CfgNode): String = { - val r = res.getOrElse(n, "-") - val desc = n match { - case entry: CfgFunctionEntryNode => s"Function ${entry.data.name} entry" - case exit: CfgFunctionExitNode => s"Function ${exit.data.name} exit" - case _ => n.toString + s" ${n.id}" - } - if (stateAfterNode) s"$desc\n$r" - else s"$r\n$desc" - } - - /** Generate an unique ID string for the given AST node. - */ - def dotIder(n: CfgNode, uniqueId: Int): String = - n match { - case real: CfgCommandNode => s"real${real.data}_$uniqueId" - case entry: CfgFunctionEntryNode => s"entry${entry.data}_$uniqueId" - case exit: CfgFunctionExitNode => s"exit${exit.data}_$uniqueId" - case ret: CfgProcedureReturnNode => s"return_$uniqueId" - case noCallRet: CfgCallNoReturnNode => s"callnoreturn_$uniqueId" - case callRet: CfgCallReturnNode => s"callreturn_$uniqueId" - case _ => ??? - } -} \ No newline at end of file diff --git a/src/main/scala/ir/IRCursor.scala b/src/main/scala/ir/IRCursor.scala index 52c585aaf..0e79fa422 100644 --- a/src/main/scala/ir/IRCursor.scala +++ b/src/main/scala/ir/IRCursor.scala @@ -96,7 +96,7 @@ trait IntraProcIRCursor extends IRWalk[CFGPosition, CFGPosition] { case proc: Procedure => proc.entryBlock.toSet case b: Block => b.statements.headOption.orElse(Some(b.jump)).toSet case n: GoTo => n.targets.asInstanceOf[Set[CFGPosition]] - case h: Halt => Set() + case h: Unreachable => Set() case h: Return => Set() case c: Statement => IRWalk.nextCommandInBlock(c).toSet } @@ -150,7 +150,6 @@ trait InterProcIRCursor extends IRWalk[CFGPosition, CFGPosition] { IntraProcIRCursor.succ(pos) ++ (pos match case c: DirectCall if c.target.blocks.nonEmpty => Set(c.target) - // case c: IndirectCall if c.parent.isProcReturn => c.parent.parent.incomingCalls().map(_.successor).toSet case c: Return => c.parent.parent.incomingCalls().map(_.successor).toSet case _ => Set.empty ) @@ -159,7 +158,13 @@ trait InterProcIRCursor extends IRWalk[CFGPosition, CFGPosition] { final def pred(pos: CFGPosition): Set[CFGPosition] = { IntraProcIRCursor.pred(pos) ++ (pos match - case d: DirectCall if d.target.blocks.nonEmpty => d.target.returnBlock.toSet + case c: Command => { + IRWalk.prevCommandInBlock(c) match { + case Some(d: DirectCall) if d.target.blocks.nonEmpty => d.target.returnBlock.toSet + case o => o.toSet + } + + } case c: Procedure => c.incomingCalls().toSet.asInstanceOf[Set[CFGPosition]] case _ => Set.empty ) diff --git a/src/main/scala/ir/Interpreter.scala b/src/main/scala/ir/Interpreter.scala index de470d5d9..0430ef66a 100644 --- a/src/main/scala/ir/Interpreter.scala +++ b/src/main/scala/ir/Interpreter.scala @@ -248,8 +248,8 @@ class Interpreter() { case r: Return => { nextCmd = Some(returnCmd.pop()) } - case h: Halt => { - Logger.debug("Halt") + case h: Unreachable => { + Logger.debug("Unreachable") nextCmd = None } } diff --git a/src/main/scala/ir/Program.scala b/src/main/scala/ir/Program.scala index 5911be3a8..ba8c67239 100644 --- a/src/main/scala/ir/Program.scala +++ b/src/main/scala/ir/Program.scala @@ -18,40 +18,6 @@ class Program(var procedures: ArrayBuffer[Procedure], serialiseIL(this) } - // This shouldn't be run before indirect calls are resolved - def stripUnreachableFunctions(depth: Int = Int.MaxValue): Unit = { - val procedureCalleeNames = procedures.map(f => f.name -> f.calls.map(_.name)).toMap - - val toVisit: mutable.LinkedHashSet[(Int, String)] = mutable.LinkedHashSet((0, mainProcedure.name)) - var reachableFound = true - val reachableNames = mutable.HashMap[String, Int]() - while (toVisit.nonEmpty) { - val next = toVisit.head - toVisit.remove(next) - - if (next._1 <= depth) { - - def addName(depth: Int, name: String): Unit = { - val oldDepth = reachableNames.getOrElse(name, Integer.MAX_VALUE) - reachableNames.put(next._2, if depth < oldDepth then depth else oldDepth) - } - addName(next._1, next._2) - - val callees = procedureCalleeNames(next._2) - - toVisit.addAll(callees.diff(reachableNames.keySet).map(c => (next._1 + 1, c))) - callees.foreach(c => addName(next._1 + 1, c)) - } - } - procedures = procedures.filter(f => reachableNames.keySet.contains(f.name)) - - for (elem <- procedures.filter(c => c.calls.exists(s => !procedures.contains(s)))) { - // last layer is analysed only as specifications so we remove the body for anything that calls - // a function we have removed - - elem.clearBlocks() - } - } def setModifies(specModifies: Map[String, List[String]]): Unit = { val procToCalls: mutable.Map[Procedure, Set[Procedure]] = mutable.Map() @@ -231,7 +197,6 @@ class Procedure private ( } def addBlocks(block: Block): Block = { - block.parent = this if (!_blocks.contains(block)) { block.parent = this _blocks.add(block) @@ -320,7 +285,8 @@ class Procedure private ( def clearBlocks(): Unit = { // O(n) because we are careful to unlink the parents etc. - removeBlocks(_blocks) + // .toList to avoid modifying our own iterator + removeBlocksDisconnect(_blocks.toList) } def callers(): Iterable[Procedure] = _callers.map(_.parent.parent).toSet[Procedure] @@ -371,6 +337,7 @@ class Block private ( this(label, address, IntrusiveList().addAll(statements), jump, mutable.HashSet.empty) } + def isReturn: Boolean = parent.returnBlock.contains(this) def isEntry: Boolean = parent.entryBlock.contains(this) def jump: Jump = _jump diff --git a/src/main/scala/ir/Statement.scala b/src/main/scala/ir/Statement.scala index 2dea68f46..ce49bc82e 100644 --- a/src/main/scala/ir/Statement.scala +++ b/src/main/scala/ir/Statement.scala @@ -82,7 +82,7 @@ sealed trait Jump extends Command { def acceptVisit(visitor: Visitor): Jump = throw new Exception("visitor " + visitor + " unimplemented for: " + this) } -class Halt(override val label: Option[String] = None) extends Jump { +class Unreachable(override val label: Option[String] = None) extends Jump { /* Terminate / No successors / assume false */ override def acceptVisit(visitor: Visitor): Jump = this } @@ -139,7 +139,7 @@ object GoTo: sealed trait Call extends Statement { def returnTarget: Option[Command] = successor match { - case h: Halt => None + case h: Unreachable => None case o => Some(o) } } diff --git a/src/main/scala/ir/dsl/DSL.scala b/src/main/scala/ir/dsl/DSL.scala index 6a1b96742..3ebeefbc4 100644 --- a/src/main/scala/ir/dsl/DSL.scala +++ b/src/main/scala/ir/dsl/DSL.scala @@ -73,8 +73,8 @@ case class EventuallyGoto(targets: List[DelayNameResolve]) extends EventuallyJum case class EventuallyReturn() extends EventuallyJump { override def resolve(p: Program) = Return() } -case class EventuallyHalt() extends EventuallyJump { - override def resolve(p: Program) = Halt() +case class EventuallyUnreachable() extends EventuallyJump { + override def resolve(p: Program) = Unreachable() } def goto(): EventuallyGoto = EventuallyGoto(List.empty) @@ -84,7 +84,7 @@ def goto(targets: String*): EventuallyGoto = { } def ret: EventuallyReturn = EventuallyReturn() -def halt: EventuallyHalt= EventuallyHalt() +def unreachable: EventuallyUnreachable= EventuallyUnreachable() def goto(targets: List[String]): EventuallyGoto = { EventuallyGoto(targets.map(p => DelayNameResolve(p))) @@ -111,8 +111,6 @@ def block(label: String, sl: (Statement | EventuallyStatement | EventuallyJump)* val statements : Seq[EventuallyStatement] = sl.flatMap { case s: Statement => Some(ResolvableStatement(s)) case o: EventuallyStatement => Some(o) - case o: EventuallyCall => Some(o) - case o: EventuallyIndirectCall => Some(o) case g: EventuallyJump => None } val jump = sl.collectFirst { diff --git a/src/main/scala/ir/transforms/IndirectCallResolution.scala b/src/main/scala/ir/transforms/IndirectCallResolution.scala index e9c9fddf7..4345dcaa1 100644 --- a/src/main/scala/ir/transforms/IndirectCallResolution.scala +++ b/src/main/scala/ir/transforms/IndirectCallResolution.scala @@ -1,7 +1,5 @@ package ir.transforms - - import scala.collection.mutable.ListBuffer import scala.collection.mutable.ArrayBuffer import analysis.solvers.* @@ -11,165 +9,27 @@ import ir.* import translating.* import util.Logger import util.intrusive_list.IntrusiveList -import analysis.CfgCommandNode import scala.collection.mutable import cilvisitor._ - -/** Resolve indirect calls to an address-conditional choice between direct calls using the Value Set Analysis results. - * Dead code, and currently broken by statement calls - * -def resolveIndirectCalls( - cfg: ProgramCfg, - valueSets: Map[CfgNode, LiftedElement[Map[Variable | MemoryRegion, Set[Value]]]], +def resolveIndirectCallsUsingPointsTo( + pointsTos: Map[RegisterVariableWrapper, Set[RegisterVariableWrapper | MemoryRegion]], + regionContents: Map[MemoryRegion, Set[BitVecLiteral | MemoryRegion]], + reachingDefs: Map[CFGPosition, (Map[Variable, Set[Assign]], Map[Variable, Set[Assign]])], IRProgram: Program ): Boolean = { var modified: Boolean = false - val worklist = ListBuffer[CfgNode]() - cfg.startNode.succIntra.union(cfg.startNode.succInter).foreach(node => worklist.addOne(node)) - - val visited = mutable.Set[CfgNode]() - while (worklist.nonEmpty) { - val node = worklist.remove(0) - if (!visited.contains(node)) { - process(node) - node.succIntra.union(node.succInter).foreach(node => worklist.addOne(node)) - visited.add(node) - } - } - - def process(n: CfgNode): Unit = n match { - /* - case c: CfgStatementNode => - c.data match + val worklist = ListBuffer[CFGPosition]() - //We do not want to insert the VSA results into the IR like this - case localAssign: Assign => - localAssign.rhs match - case _: MemoryLoad => - if (valueSets(n).contains(localAssign.lhs) && valueSets(n).get(localAssign.lhs).head.size == 1) { - val extractedValue = extractExprFromValue(valueSets(n).get(localAssign.lhs).head.head) - localAssign.rhs = extractedValue - Logger.info(s"RESOLVED: Memory load ${localAssign.lhs} resolved to ${extractedValue}") - } else if (valueSets(n).contains(localAssign.lhs) && valueSets(n).get(localAssign.lhs).head.size > 1) { - Logger.info(s"RESOLVED: WARN Memory load ${localAssign.lhs} resolved to multiple values, cannot replace") - - /* - // must merge into a single memory variable to represent the possible values - // Make a binary OR of all the possible values takes two at a time (incorrect to do BVOR) - val values = valueSets(n).get(localAssign.lhs).head - val exprValues = values.map(extractExprFromValue) - val result = exprValues.reduce((a, b) => BinaryExpr(BVOR, a, b)) // need to express nondeterministic - // choice between these specific options - localAssign.rhs = result - */ - } - case _ => - */ - case c: CfgJumpNode => - val block = c.block - c.data match - case indirectCall: IndirectCall => - if (block.jump != indirectCall) { - // We only replace the calls with DirectCalls in the IR, and don't replace the CommandNode.data - // Hence if we have already processed this CFG node there will be no corresponding IndirectCall in the IR - // to replace. - // We want to replace all possible indirect calls based on this CFG, before regenerating it from the IR - return - } - valueSets(n) match { - case Lift(valueSet) => - val targetNames = resolveAddresses(valueSet(indirectCall.target)).map(_.name).toList.sorted - val targets = targetNames.map(name => IRProgram.procedures.filter(_.name.equals(name)).head) - - if (targets.size == 1) { - modified = true - - // indirectCall.parent.parent.removeBlocks(indirectCall.returnTarget) - val newCall = DirectCall(targets.head, indirectCall.label) - block.statements.replace(indirectCall, newCall) - } else if (targets.size > 1) { - modified = true - val procedure = c.parent.data - val newBlocks = ArrayBuffer[Block]() - for (t <- targets) { - val assume = Assume(BinaryExpr(BVEQ, indirectCall.target, BitVecLiteral(t.address.get, 64))) - val newLabel: String = block.label + t.name - val directCall = DirectCall(t) - directCall.parent = indirectCall.parent - - // assume indircall is the last statement in block - assert(indirectCall.parent.statements.lastOption.contains(indirectCall)) - val fallthrough = indirectCall.parent.jump - - newBlocks.append(Block(newLabel, None, ArrayBuffer(assume, directCall), fallthrough)) - } - procedure.addBlocks(newBlocks) - val newCall = GoTo(newBlocks, indirectCall.label) - block.replaceJump(newCall) - } - case LiftedBottom => - } - case _ => - case _ => - } - - def nameExists(name: String): Boolean = { - IRProgram.procedures.exists(_.name.equals(name)) - } - - def addFakeProcedure(name: String): Unit = { - IRProgram.procedures += Procedure(name) - } + worklist.addAll(IRProgram) - def resolveAddresses(valueSet: Set[Value]): Set[AddressValue] = { - var functionNames: Set[AddressValue] = Set() - valueSet.foreach { - case globalAddress: GlobalAddress => - if (nameExists(globalAddress.name)) { - functionNames += globalAddress - Logger.info(s"RESOLVED: Call to Global address ${globalAddress.name} rt statuesolved.") - } else { - addFakeProcedure(globalAddress.name) - functionNames += globalAddress - Logger.info(s"Global address ${globalAddress.name} does not exist in the program. Added a fake function.") - } - case localAddress: LocalAddress => - if (nameExists(localAddress.name)) { - functionNames += localAddress - Logger.info(s"RESOLVED: Call to Local address ${localAddress.name}") - } else { - addFakeProcedure(localAddress.name) - functionNames += localAddress - Logger.info(s"Local address ${localAddress.name} does not exist in the program. Added a fake function.") - } - case _ => - } - functionNames - } - - modified -} - - */ - -def resolveIndirectCallsUsingPointsTo( - cfg: ProgramCfg, - pointsTos: Map[RegisterVariableWrapper, Set[RegisterVariableWrapper | MemoryRegion]], - regionContents: Map[MemoryRegion, Set[BitVecLiteral | MemoryRegion]], - reachingDefs: Map[CFGPosition, (Map[Variable, Set[Assign]], Map[Variable, Set[Assign]])], - IRProgram: Program - ): Boolean = { - var modified: Boolean = false - val worklist = ListBuffer[CfgNode]() - cfg.startNode.succIntra.union(cfg.startNode.succInter).foreach(node => worklist.addOne(node)) - - val visited = mutable.Set[CfgNode]() + val visited = mutable.Set[CFGPosition]() while (worklist.nonEmpty) { val node = worklist.remove(0) if (!visited.contains(node)) { + // add to worklist before we delete the node and can no longer find its successors + InterProcIRCursor.succ(node).foreach(node => worklist.addOne(node)) process(node) - node.succIntra.union(node.succInter).foreach(node => worklist.addOne(node)) visited.add(node) } } @@ -181,7 +41,7 @@ def resolveIndirectCallsUsingPointsTo( if (regionContents.contains(stackRegion)) { for (c <- regionContents(stackRegion)) { c match { - case bitVecLiteral: BitVecLiteral => Logger.debug("hi: " + bitVecLiteral)//??? + case bitVecLiteral: BitVecLiteral => Logger.debug("hi: " + bitVecLiteral) //??? case memoryRegion: MemoryRegion => result.addAll(searchRegion(memoryRegion)) } @@ -195,7 +55,7 @@ def resolveIndirectCallsUsingPointsTo( result.add(dataRegion.regionIdentifier) // TODO: may need to investigate if we should add the parent region for (c <- regionContents(dataRegion)) { c match { - case bitVecLiteral: BitVecLiteral => Logger.debug("hi: " + bitVecLiteral)//??? + case bitVecLiteral: BitVecLiteral => Logger.debug("hi: " + bitVecLiteral) //??? case memoryRegion: MemoryRegion => result.addAll(searchRegion(memoryRegion)) } @@ -218,76 +78,70 @@ def resolveIndirectCallsUsingPointsTo( case Some(value) => value.map { case v: RegisterVariableWrapper => names.addAll(resolveAddresses(v.variable, i)) - case m: MemoryRegion => names.addAll(searchRegion(m)) + case m: MemoryRegion => names.addAll(searchRegion(m)) } names case None => names } } - def process(n: CfgNode): Unit = n match { - case c: CfgJumpNode => - val block = c.block - c.data match - // don't try to resolve returns - case indirectCall: IndirectCall if indirectCall.target != Register("R30", 64) => - if (!indirectCall.hasParent) { - // We only replace the calls with DirectCalls in the IR, and don't replace the CommandNode.data - // Hence if we have already processed this CFG node there will be no corresponding IndirectCall in the IR - // to replace. - // We want to replace all possible indirect calls based on this CFG, before regenerating it from the IR - return - } - assert(indirectCall.parent.statements.lastOption.contains(indirectCall)) - - val targetNames = resolveAddresses(indirectCall.target, indirectCall) - Logger.debug(s"Points-To approximated call ${indirectCall.target} with $targetNames") - Logger.debug(IRProgram.procedures) - val targets: mutable.Set[Procedure] = targetNames.map(name => IRProgram.procedures.find(_.name == name).getOrElse(addFakeProcedure(name))) - - if (targets.size > 1) { - Logger.info(s"Resolved indirect call $indirectCall") + def process(n: CFGPosition): Unit = n match { + case indirectCall: IndirectCall if indirectCall.target != Register("R30", 64) => + if (!indirectCall.hasParent) { + // skip if we have already processesd this call + return + } + // we need the single-call-at-end-of-block invariant + assert(indirectCall.parent.statements.lastOption.contains(indirectCall)) + + val block = indirectCall.parent + val procedure = block.parent + + val targetNames = resolveAddresses(indirectCall.target, indirectCall) + Logger.debug(s"Points-To approximated call ${indirectCall.target} with $targetNames") + Logger.debug(IRProgram.procedures) + val targets: mutable.Set[Procedure] = + targetNames.map(name => IRProgram.procedures.find(_.name == name).getOrElse(addFakeProcedure(name))) + + if (targets.size > 1) { + Logger.info(s"Resolved indirect call $indirectCall") + } + + if (targets.size == 1) { + modified = true + + val newCall = DirectCall(targets.head, indirectCall.label) + block.statements.replace(indirectCall, newCall) + } else if (targets.size > 1) { + + val oft = indirectCall.parent.jump + + modified = true + val newBlocks = ArrayBuffer[Block]() + for (t <- targets) { + Logger.debug(targets) + val address = t.address.match { + case Some(a) => a + case None => + throw Exception(s"resolved indirect call $indirectCall to procedure which does not have address: $t") } - - - if (targets.size == 1) { - modified = true - - // indirectCall.parent.parent.removeBlocks(indirectCall.returnTarget) - val newCall = DirectCall(targets.head, indirectCall.label) - block.statements.replace(indirectCall, newCall) - } else if (targets.size > 1) { - - val oft = indirectCall.parent.jump - - modified = true - val procedure = c.parent.data - val newBlocks = ArrayBuffer[Block]() - // indirectCall.parent.parent.removeBlocks(indirectCall.returnTarget) - for (t <- targets) { - Logger.debug(targets) - val address = t.address.match { - case Some(a) => a - case None => throw Exception(s"resolved indirect call $indirectCall to procedure which does not have address: $t") - } - val assume = Assume(BinaryExpr(BVEQ, indirectCall.target, BitVecLiteral(address, 64))) - val newLabel: String = block.label + t.name - val directCall = DirectCall(t) - - /* copy the goto node resulting */ - val fallthrough = oft match { - case g: GoTo => GoTo(g.targets, g.label) - case h: Halt => Halt() - case r: Return => Return() - } - newBlocks.append(Block(newLabel, None, ArrayBuffer(assume, directCall), fallthrough)) - } - block.statements.remove(indirectCall) - procedure.addBlocks(newBlocks) - val newCall = GoTo(newBlocks, indirectCall.label) - block.replaceJump(newCall) + val assume = Assume(BinaryExpr(BVEQ, indirectCall.target, BitVecLiteral(address, 64))) + val newLabel: String = block.label + t.name + val directCall = DirectCall(t) + + /* copy the goto node resulting */ + val fallthrough = oft match { + case g: GoTo => GoTo(g.targets, g.label) + case h: Unreachable => Unreachable() + case r: Return => Return() } - case _ => + newBlocks.append(Block(newLabel, None, ArrayBuffer(assume, directCall), fallthrough)) + } + block.statements.remove(indirectCall) + procedure.addBlocks(newBlocks) + val newCall = GoTo(newBlocks, indirectCall.label) + block.replaceJump(newCall) + } case _ => } diff --git a/src/main/scala/ir/transforms/ReplaceReturn.scala b/src/main/scala/ir/transforms/ReplaceReturn.scala index 91681ab8e..f41c25e3d 100644 --- a/src/main/scala/ir/transforms/ReplaceReturn.scala +++ b/src/main/scala/ir/transforms/ReplaceReturn.scala @@ -13,7 +13,7 @@ class ReplaceReturns extends CILVisitor { j match { case IndirectCall(Register("R30", _), _) => { assert(j.parent.statements.lastOption.contains(j)) - if (j.parent.jump.isInstanceOf[Halt | Return]) { + if (j.parent.jump.isInstanceOf[Unreachable | Return]) { j.parent.replaceJump(Return()) ChangeTo(List()) } else { @@ -32,7 +32,6 @@ def addReturnBlocks(p: Program, toAll: Boolean = false) = { p.procedures.foreach(p => { val containsReturn = p.blocks.map(_.jump).find(_.isInstanceOf[Return]).isDefined if (toAll && p.blocks.isEmpty && p.entryBlock.isEmpty && p.returnBlock.isEmpty) { - Logger.info(s"proc ${p.name} ${p.entryBlock}, ${p.returnBlock}") p.returnBlock = (Block(label=p.name + "_basil_return",jump=Return())) p.entryBlock = (Block(label=p.name + "_basil_entry",jump=GoTo(p.returnBlock.get))) } else if (p.returnBlock.isEmpty && (toAll || containsReturn)) { diff --git a/src/main/scala/ir/transforms/SplitThreads.scala b/src/main/scala/ir/transforms/SplitThreads.scala index 7f36fdb3c..8678720e7 100644 --- a/src/main/scala/ir/transforms/SplitThreads.scala +++ b/src/main/scala/ir/transforms/SplitThreads.scala @@ -11,7 +11,6 @@ import util.Logger import java.util.Base64 import spray.json.DefaultJsonProtocol.* import util.intrusive_list.IntrusiveList -import analysis.CfgCommandNode import scala.collection.mutable import cilvisitor._ diff --git a/src/main/scala/ir/transforms/StripUnreachableFunctions.scala b/src/main/scala/ir/transforms/StripUnreachableFunctions.scala new file mode 100644 index 000000000..96784cc28 --- /dev/null +++ b/src/main/scala/ir/transforms/StripUnreachableFunctions.scala @@ -0,0 +1,38 @@ +package ir.transforms +import ir._ +import collection.mutable + +// This shouldn't be run before indirect calls are resolved +def stripUnreachableFunctions(p: Program, depth: Int = Int.MaxValue): Unit = { + val procedureCalleeNames = p.procedures.map(f => f.name -> f.calls.map(_.name)).toMap + + val toVisit: mutable.LinkedHashSet[(Int, String)] = mutable.LinkedHashSet((0, p.mainProcedure.name)) + var reachableFound = true + val reachableNames = mutable.HashMap[String, Int]() + while (toVisit.nonEmpty) { + val next = toVisit.head + toVisit.remove(next) + + if (next._1 <= depth) { + + def addName(depth: Int, name: String): Unit = { + val oldDepth = reachableNames.getOrElse(name, Integer.MAX_VALUE) + reachableNames.put(next._2, if depth < oldDepth then depth else oldDepth) + } + addName(next._1, next._2) + + val callees = procedureCalleeNames(next._2) + + toVisit.addAll(callees.diff(reachableNames.keySet).map(c => (next._1 + 1, c))) + callees.foreach(c => addName(next._1 + 1, c)) + } + } + p.procedures = p.procedures.filter(f => reachableNames.keySet.contains(f.name)) + + for (elem <- p.procedures.filter(c => c.calls.exists(s => !p.procedures.contains(s)))) { + // last layer is analysed only as specifications so we remove the body for anything that calls + // a function we have removed + + elem.clearBlocks() + } +} diff --git a/src/main/scala/translating/BAPToIR.scala b/src/main/scala/translating/BAPToIR.scala index 90ba61335..93a5010cf 100644 --- a/src/main/scala/translating/BAPToIR.scala +++ b/src/main/scala/translating/BAPToIR.scala @@ -136,11 +136,11 @@ class BAPToIR(var program: BAPProgram, mainAddress: Int) { jumps.head match { case b: BAPDirectCall => val call = Some(DirectCall(nameToProcedure(b.target),Some(b.line))) - val ft = (b.returnTarget.map(t => labelToBlock(t))).map(x => GoTo(Set(x))).getOrElse(Halt()) + val ft = (b.returnTarget.map(t => labelToBlock(t))).map(x => GoTo(Set(x))).getOrElse(Unreachable()) (call, ft, ArrayBuffer()) case b: BAPIndirectCall => val call = IndirectCall(b.target.toIR, Some(b.line)) - val ft = (b.returnTarget.map(t => labelToBlock(t))).map(x => GoTo(Set(x))).getOrElse(Halt()) + val ft = (b.returnTarget.map(t => labelToBlock(t))).map(x => GoTo(Set(x))).getOrElse(Unreachable()) (Some(call), ft, ArrayBuffer()) case b: BAPGoTo => val target = labelToBlock(b.target) diff --git a/src/main/scala/translating/GTIRBToIR.scala b/src/main/scala/translating/GTIRBToIR.scala index d47bd0b35..3c8b1a13c 100644 --- a/src/main/scala/translating/GTIRBToIR.scala +++ b/src/main/scala/translating/GTIRBToIR.scala @@ -364,7 +364,7 @@ class GTIRBToIR(mods: Seq[Module], parserMap: immutable.Map[String, Array[Array[ // need to copy jump as it can't have multiple parents val jumpCopy = currentBlock.jump match { case GoTo(targets, label) => GoTo(targets, label) - case h: Halt => Halt() + case h: Unreachable => Unreachable() case r: Return => Return() case _ => throw Exception("this shouldn't be reachable") } @@ -392,7 +392,7 @@ class GTIRBToIR(mods: Seq[Module], parserMap: immutable.Map[String, Array[Array[ case _ => throw Exception(s"no assignment to program counter found before indirect call in block ${block.label}") } block.statements.remove(block.statements.last) // remove _PC assignment - (Some(IndirectCall(target)), Halt()) + (Some(IndirectCall(target)), Unreachable()) } else if (proxySymbols.size > 1) { // TODO requires further consideration once encountered throw Exception(s"multiple uuidToSymbol ${proxySymbols.map(_.name).mkString(", ")} associated with proxy block ${byteStringToString(edge.targetUuid)}, target of indirect call from block ${block.label}") @@ -408,7 +408,7 @@ class GTIRBToIR(mods: Seq[Module], parserMap: immutable.Map[String, Array[Array[ proc } removePCAssign(block) - (Some(DirectCall(target)), Halt()) + (Some(DirectCall(target)), Unreachable()) } } else if (uuidToBlock.contains(edge.targetUuid)) { // resolved indirect jump @@ -428,7 +428,7 @@ class GTIRBToIR(mods: Seq[Module], parserMap: immutable.Map[String, Array[Array[ val jump = if (procedure == targetProc) { (None, GoTo(mutable.Set(uuidToBlock(edge.targetUuid)))) } else { - (Some(DirectCall(targetProc)), Halt()) + (Some(DirectCall(targetProc)), Unreachable()) } removePCAssign(block) jump @@ -450,7 +450,7 @@ class GTIRBToIR(mods: Seq[Module], parserMap: immutable.Map[String, Array[Array[ // probably doesn't actually happen in practice since it seems to be after brk instructions? val targetProc = entranceUUIDtoProcedure(edge.targetUuid) // assuming fallthrough won't fall through to start of own procedure - (Some(DirectCall(targetProc)), Halt()) + (Some(DirectCall(targetProc)), Unreachable()) } else if (uuidToBlock.contains(edge.targetUuid)) { val target = uuidToBlock(edge.targetUuid) (None, GoTo(mutable.Set(target))) @@ -463,7 +463,7 @@ class GTIRBToIR(mods: Seq[Module], parserMap: immutable.Map[String, Array[Array[ if (entranceUUIDtoProcedure.contains(edge.targetUuid)) { val target = entranceUUIDtoProcedure(edge.targetUuid) removePCAssign(block) - (Some(DirectCall(target)), Halt()) + (Some(DirectCall(target)), Unreachable()) } else { throw Exception(s"edge from ${block.label} to ${byteStringToString(edge.targetUuid)} does not point to a known procedure entrance") } diff --git a/src/main/scala/translating/ILtoIL.scala b/src/main/scala/translating/ILtoIL.scala index 99b64bc6e..856b18934 100644 --- a/src/main/scala/translating/ILtoIL.scala +++ b/src/main/scala/translating/ILtoIL.scala @@ -61,7 +61,12 @@ private class ILSerialiser extends ReadOnlyVisitor { } override def visitJump(node: Jump): Jump = { - node.acceptVisit(this) + node match { + case j: GoTo => program ++= s"goTo(${j.targets.map(_.label).mkString(", ")})" + case h: Unreachable => program ++= "halt" + case h: Return => program ++= "return" + } + node } @@ -77,7 +82,6 @@ private class ILSerialiser extends ReadOnlyVisitor { override def visitDirectCall(node: DirectCall): Statement = { program ++= "DirectCall(" program ++= procedureIdentifier(node.target) - program ++= ", " program ++= ")" // DirectCall node } @@ -95,7 +99,10 @@ private class ILSerialiser extends ReadOnlyVisitor { program ++= "Block(" + blockIdentifier(node) + ",\n" indentLevel += 1 program ++= getIndent() - program ++= "statements(\n" + program ++= "statements(" + if (node.statements.size > 0) { + program ++= "\n" + } indentLevel += 1 for (s <- node.statements) { @@ -105,8 +112,7 @@ private class ILSerialiser extends ReadOnlyVisitor { } indentLevel -= 1 program ++= getIndent() + "),\n" - program ++= getIndent() + "jumps(\n" - program ++= getIndent() + program ++= getIndent() + "jump(" visitJump(node.jump) program ++= ")\n" indentLevel -= 1 diff --git a/src/main/scala/translating/IRToBoogie.scala b/src/main/scala/translating/IRToBoogie.scala index 3422c6daa..6b4cf740f 100644 --- a/src/main/scala/translating/IRToBoogie.scala +++ b/src/main/scala/translating/IRToBoogie.scala @@ -650,7 +650,7 @@ class IRToBoogie(var program: Program, var spec: Specification, var thread: Opti val jump = GoToCmd(g.targets.map(_.label).toSeq) conditionAssert :+ jump case r: Return => List(ReturnCmd) - case r: Halt => List(BAssume(FalseBLiteral)) + case r: Unreachable => List(BAssume(FalseBLiteral)) } def translate(j: Call): List[BCmd] = j match { diff --git a/src/main/scala/util/RunUtils.scala b/src/main/scala/util/RunUtils.scala index 8c8e85e66..f67087070 100644 --- a/src/main/scala/util/RunUtils.scala +++ b/src/main/scala/util/RunUtils.scala @@ -13,7 +13,6 @@ import java.io.{BufferedWriter, FileWriter, IOException} import scala.jdk.CollectionConverters.* import analysis.solvers.* import analysis.* -import cfg_visualiser.Output import bap.* import ir.* import boogie.* @@ -28,7 +27,6 @@ import util.Logger import java.util.Base64 import spray.json.DefaultJsonProtocol.* import util.intrusive_list.IntrusiveList -import analysis.CfgCommandNode import cilvisitor._ import scala.annotation.tailrec @@ -51,7 +49,6 @@ case class IRContext( /** Stores the results of the static analyses. */ case class StaticAnalysisContext( - cfg: ProgramCfg, constPropResult: Map[CFGPosition, Map[Variable, FlatElement[BitVecLiteral]]], IRconstPropResult: Map[CFGPosition, Map[Variable, FlatElement[BitVecLiteral]]], memoryRegionResult: Map[CFGPosition, LiftedElement[Set[MemoryRegion]]], @@ -217,7 +214,7 @@ object IRTransform { Logger.info("[!] Stripping unreachable") val before = ctx.program.procedures.size - ctx.program.stripUnreachableFunctions(config.procedureTrimDepth) + transforms.stripUnreachableFunctions(ctx.program, config.procedureTrimDepth) Logger.info( s"[!] Removed ${before - ctx.program.procedures.size} functions (${ctx.program.procedures.size} remaining)" ) @@ -304,15 +301,12 @@ object StaticAnalysis { newLoops.foreach(l => Logger.info(s"Loop found: ${l.name}")) config.analysisDotPath.foreach { s => - val newCFG = ProgramCfgFactory().fromIR(IRProgram) - writeToFile(newCFG.toDot(x => x.toString, Output.dotIder), s"${s}_resolvedCFG-reducible.dot") + writeToFile(dotBlockGraph(IRProgram, IRProgram.map(b => b -> b.toString).toMap), s"${s}_graph-after-reduce-$iteration.dot") writeToFile(dotBlockGraph(IRProgram, IRProgram.filter(_.isInstanceOf[Block]).map(b => b -> b.toString).toMap), s"${s}_blockgraph-after-reduce-$iteration.dot") } val mergedSubroutines = subroutines ++ externalAddresses - val cfg = ProgramCfgFactory().fromIR(IRProgram) - val domain = computeDomain(IntraProcIRCursor, IRProgram.procedures) Logger.info("[!] Running ANR") @@ -356,11 +350,17 @@ object StaticAnalysis { Logger.info("[!] Running RegToMemAnalysisSolver") - val regionAccessesAnalysisSolver = RegionAccessesAnalysisSolver(cfg, constPropResult, reachingDefinitionsAnalysisResults) + val regionAccessesAnalysisSolver = RegionAccessesAnalysisSolver(IRProgram, constPropResult, reachingDefinitionsAnalysisResults) val regionAccessesAnalysisResults = regionAccessesAnalysisSolver.analyze() - config.analysisDotPath.foreach(s => writeToFile(cfg.toDot(Output.labeler(regionAccessesAnalysisResults, true), Output.dotIder), s"${s}_RegTo$iteration.dot")) - config.analysisResultsPath.foreach(s => writeToFile(printAnalysisResults(cfg, regionAccessesAnalysisResults, iteration), s"${s}_RegTo$iteration.txt")) +// config.analysisDotPath.foreach(s => writeToFile(cfg.toDot(Output.labeler(regionAccessesAnalysisResults, true), Output.dotIder), s"${s}_RegTo$iteration.dot")) + config.analysisResultsPath.foreach(s => writeToFile(printAnalysisResults(IRProgram, regionAccessesAnalysisResults), s"${s}_RegTo$iteration.txt")) + config.analysisDotPath.foreach(s => { + writeToFile( + toDot(IRProgram, IRProgram.filter(_.isInstanceOf[Command]).map(b => b -> regionAccessesAnalysisResults(b).toString).toMap), + s"${s}_RegTo$iteration.dot" + ) + }) Logger.info("[!] Running Constant Propagation with SSA") val constPropSolverWithSSA = ConstantPropagationSolverWithSSA(IRProgram, reachingDefinitionsAnalysisResults) @@ -401,10 +401,8 @@ object StaticAnalysis { mmm.logRegions(memoryRegionContents) // turn fake procedures into diamonds - transforms.addReturnBlocks(ctx.program, true) // add return to all blocks because IDE solver expects it Logger.info("[!] Running VSA") - val vsaSolver = - ValueSetAnalysisSolver(IRProgram, globalAddresses, externalAddresses, globalOffsets, subroutines, mmm, constPropResult) + val vsaSolver = ValueSetAnalysisSolver(IRProgram, globalAddresses, externalAddresses, globalOffsets, subroutines, mmm, constPropResult) val vsaResult: Map[CFGPosition, LiftedElement[Map[Variable | MemoryRegion, Set[Value]]]] = vsaSolver.analyze() Logger.info("[!] Running Interprocedural Live Variables Analysis") @@ -416,7 +414,6 @@ object StaticAnalysis { // val paramResults = Map[Procedure, Set[Variable]]() StaticAnalysisContext( - cfg = cfg, constPropResult = constPropResult, IRconstPropResult = newCPResult, memoryRegionResult = mraResult, @@ -431,34 +428,6 @@ object StaticAnalysis { ) } - /** Converts MapLattice of CfgNodes to a MapLattice from IRPosition. - * @param cfg - * The CFG - * @param result - * The analysis result MapLattice - * @tparam T - * The analysis result type. - * @return - * The new map analysis result. - */ - def convertAnalysisResults[T](cfg: ProgramCfg, result: Map[CfgNode, T]): Map[CFGPosition, T] = { - val results = mutable.HashMap[CFGPosition, T]() - result.foreach((node, res) => - node match { - case s: CfgStatementNode => results.addOne(s.data -> res) - case s: CfgFunctionEntryNode => results.addOne(s.data -> res) - case s: CfgJumpNode => results.addOne(s.data -> res) - case s: CfgCommandNode => results.addOne(s.data -> res) - case _ => () - } - ) - - results.toMap - } - - def printAnalysisResults[T](program: Program, cfg: ProgramCfg, result: Map[CfgNode, T]): String = { - printAnalysisResults(program, convertAnalysisResults(cfg, result)) - } def printAnalysisResults(prog: Program, result: Map[CFGPosition, _]): String = { val results = mutable.ArrayBuffer[String]() @@ -502,86 +471,6 @@ object StaticAnalysis { results.mkString(System.lineSeparator()) } - def printAnalysisResults(cfg: ProgramCfg, result: Map[CfgNode, _], iteration: Int): String = { - val functionEntries = cfg.nodes.collect { case n: CfgFunctionEntryNode => n }.toSeq.sortBy(_.data.name) - val s = StringBuilder() - s.append(System.lineSeparator()) - for (f <- functionEntries) { - val stack: mutable.Stack[CfgNode] = mutable.Stack() - val visited: mutable.Set[CfgNode] = mutable.Set() - stack.push(f) - var previousBlock: String = "" - var isEntryNode = false - while (stack.nonEmpty) { - val next = stack.pop() - if (!visited.contains(next)) { - visited.add(next) - next.match { - case c: CfgCommandNode => - if (c.block.label != previousBlock) { - printBlock(c) - } - c match { - case _: CfgStatementNode => s.append(" ") - case _ => () - } - printNode(c) - previousBlock = c.block.label - isEntryNode = false - case c: CfgFunctionEntryNode => - printNode(c) - isEntryNode = true - case c: CfgCallNoReturnNode => - s.append(System.lineSeparator()) - isEntryNode = false - case _ => isEntryNode = false - } - val successors = next.succIntra - if (successors.size > 1) { - val successorsCmd = successors.collect { case c: CfgCommandNode => c }.toSeq.sortBy(_.data.toString) - printGoTo(successorsCmd) - for (s <- successorsCmd) { - if (!visited.contains(s)) { - stack.push(s) - } - } - } else if (successors.size == 1) { - val successor = successors.head - if (!visited.contains(successor)) { - stack.push(successor) - } - successor.match { - case c: CfgCommandNode if (c.block.label != previousBlock) && (!isEntryNode) => printGoTo(Seq(c)) - case _ => - } - } - } - } - s.append(System.lineSeparator()) - } - - def printNode(node: CfgNode): Unit = { - s.append(node) - s.append(" :: ") - s.append(result(node)) - s.append(System.lineSeparator()) - } - - def printGoTo(nodes: Seq[CfgCommandNode]): Unit = { - s.append("[GoTo] ") - s.append(nodes.map(_.block.label).mkString(", ")) - s.append(System.lineSeparator()) - s.append(System.lineSeparator()) - } - - def printBlock(node: CfgCommandNode): Unit = { - s.append("[Block] ") - s.append(node.block.label) - s.append(System.lineSeparator()) - } - - s.toString - } } @@ -645,7 +534,7 @@ object RunUtils { val result = StaticAnalysis.analyse(ctx, config, iteration) analysisResult.append(result) Logger.info("[!] Replacing Indirect Calls") - modified = transforms.resolveIndirectCallsUsingPointsTo(result.cfg, + modified = transforms.resolveIndirectCallsUsingPointsTo( result.steensgaardResults, result.memoryRegionContents, result.reachingDefs, @@ -667,11 +556,6 @@ object RunUtils { transforms.splitThreads(ctx.program, analysisResult.last.steensgaardResults, analysisResult.last.memoryRegionContents, analysisResult.last.reachingDefs) } - config.analysisDotPath.foreach { s => - val newCFG = analysisResult.last.cfg - writeToFile(newCFG.toDot(x => x.toString, Output.dotIder), s"${s}_resolvedCFG.dot") - } - assert(invariant.singleCallBlockEnd(ctx.program)) Logger.info(s"[!] Finished indirect call resolution after $iteration iterations") analysisResult.last diff --git a/src/test/scala/IndirectCallsTests.scala b/src/test/scala/IndirectCallsTests.scala index 15c635fc9..fbc5a4362 100644 --- a/src/test/scala/IndirectCallsTests.scala +++ b/src/test/scala/IndirectCallsTests.scala @@ -84,7 +84,6 @@ class IndirectCallsTests extends AnyFunSuite with OneInstancePerTest with Before case _ => } } - println(result.ir.program) assert(expectedCallTransform.isEmpty) } @@ -122,7 +121,6 @@ class IndirectCallsTests extends AnyFunSuite with OneInstancePerTest with Before case _ => } } - println(result.ir.program) assert(expectedCallTransform.isEmpty) } @@ -160,7 +158,6 @@ class IndirectCallsTests extends AnyFunSuite with OneInstancePerTest with Before case _ => } } - println(result.ir.program) assert(expectedCallTransform.isEmpty) } @@ -204,7 +201,6 @@ class IndirectCallsTests extends AnyFunSuite with OneInstancePerTest with Before case _ => } } - println(result.ir.program) assert(expectedCallTransform.isEmpty) } @@ -248,7 +244,6 @@ class IndirectCallsTests extends AnyFunSuite with OneInstancePerTest with Before case _ => } } - println(result.ir.program) assert(expectedCallTransform.isEmpty) } @@ -291,7 +286,6 @@ class IndirectCallsTests extends AnyFunSuite with OneInstancePerTest with Before case _ => } } - println(result.ir.program) assert(expectedCallTransform.isEmpty) } @@ -335,7 +329,6 @@ class IndirectCallsTests extends AnyFunSuite with OneInstancePerTest with Before case _ => } } - println(result.ir.program) assert(expectedCallTransform.isEmpty) } @@ -363,7 +356,6 @@ class IndirectCallsTests extends AnyFunSuite with OneInstancePerTest with Before "l000004f3set_seven" -> ("set_seven", "R0") ) - println("prev " + result.ir.program) // Traverse the statements in the main function result.ir.program.mainProcedure.blocks.foreach { block => @@ -375,7 +367,6 @@ class IndirectCallsTests extends AnyFunSuite with OneInstancePerTest with Before case _ => } } - println(result.ir.program) assert(expectedCallTransform.isEmpty) } @@ -414,7 +405,6 @@ class IndirectCallsTests extends AnyFunSuite with OneInstancePerTest with Before case _ => } } - println(result.ir.program) assert(expectedCallTransform.isEmpty) } @@ -453,7 +443,6 @@ class IndirectCallsTests extends AnyFunSuite with OneInstancePerTest with Before } // Traverse the statements in the main function - println(result.ir.program) assert(expectedCallTransform.isEmpty) } diff --git a/src/test/scala/LiveVarsAnalysisTests.scala b/src/test/scala/LiveVarsAnalysisTests.scala index e1b142001..881ad61fc 100644 --- a/src/test/scala/LiveVarsAnalysisTests.scala +++ b/src/test/scala/LiveVarsAnalysisTests.scala @@ -115,10 +115,12 @@ class LiveVarsAnalysisTests extends AnyFunSuite, TestUtil { var program = prog( proc("main", block("main_first_call", - directCall("wrapper1"), goto("main_second_call") + directCall("wrapper1"), + goto("main_second_call") ), block("main_second_call", - directCall("wrapper2"), goto("main_return") + directCall("wrapper2"), + goto("main_return") ), block("main_return", ret) ), @@ -128,10 +130,12 @@ class LiveVarsAnalysisTests extends AnyFunSuite, TestUtil { proc("wrapper1", block("wrapper1_first_call", Assign(R1, constant1), - directCall("callee"), goto("wrapper1_second_call") + directCall("callee"), + goto("wrapper1_second_call") ), block("wrapper1_second_call", - directCall("callee2"), goto("wrapper1_return")), + directCall("callee2"), + goto("wrapper1_return")), block("wrapper1_return", ret) ), proc("wrapper2", @@ -349,12 +353,11 @@ class LiveVarsAnalysisTests extends AnyFunSuite, TestUtil { val blocks = result.ir.program.blocks // main has no parameters, get_two has three and a return - assert(analysisResults(blocks("lmain").jump) == Map(R29 -> TwoElementTop, R31 -> TwoElementTop)) - assert(analysisResults(blocks("l000003ec").jump) == Map(R0 -> TwoElementTop, R31 -> TwoElementTop)) // get_two aftercall - assert(analysisResults(blocks("l00000430").jump) == Map(R31 -> TwoElementTop)) // printf aftercall - assert(analysisResults(blocks("main_basil_return").jump) == Map(R30 -> TwoElementTop)) - assert(analysisResults(blocks("lget_two").jump) == Map(R0 -> TwoElementTop, R1 -> TwoElementTop, R2 -> TwoElementTop, R30 -> TwoElementTop, R31 -> TwoElementTop)) - assert(analysisResults(blocks("get_two_basil_return").jump) == Map(R0 -> TwoElementTop, R30 -> TwoElementTop, R31 -> TwoElementTop)) + assert(analysisResults(blocks("lmain")) == Map(R29 -> TwoElementTop, R31 -> TwoElementTop, R30 -> TwoElementTop)) + assert(analysisResults(blocks("l000003ec")) == Map(R0 -> TwoElementTop, R31 -> TwoElementTop)) // get_two aftercall + assert(analysisResults(blocks("l00000430")) == Map(R31 -> TwoElementTop)) // printf aftercall + assert(analysisResults(blocks("lget_two")) == Map(R0 -> TwoElementTop, R1 -> TwoElementTop, R2 -> TwoElementTop, R31 -> TwoElementTop)) + assert(analysisResults(blocks("get_two_basil_return")) == Map(R0 -> TwoElementTop, R31 -> TwoElementTop)) } test("ifbranches") { diff --git a/src/test/scala/MemoryRegionAnalysisMiscTest.scala b/src/test/scala/MemoryRegionAnalysisMiscTest.scala index 2844548d1..a7d7e4e73 100644 --- a/src/test/scala/MemoryRegionAnalysisMiscTest.scala +++ b/src/test/scala/MemoryRegionAnalysisMiscTest.scala @@ -1,4 +1,4 @@ -import analysis.{CfgNode, LiftedElement, MemoryRegion} +import analysis.{LiftedElement, MemoryRegion} import org.scalatest.Inside.inside import org.scalatest.* import org.scalatest.funsuite.* diff --git a/src/test/scala/ir/IRTest.scala b/src/test/scala/ir/IRTest.scala index 7421c2a28..ac6df38cf 100644 --- a/src/test/scala/ir/IRTest.scala +++ b/src/test/scala/ir/IRTest.scala @@ -223,7 +223,7 @@ class IRTest extends AnyFunSuite { Assign(R0, bv64(22)), Assign(R0, bv64(22)), directCall("main"), - halt + unreachable ).resolve(p) val b2 = block("newblock1", Assign(R0, bv64(22)), @@ -249,7 +249,7 @@ class IRTest extends AnyFunSuite { val b3 = block("newblock3", Assign(R0, bv64(22)), directCall("called"), - halt + unreachable ).resolve(p) assert(b3.calls.toSet == Set(p.procs("called"))) diff --git a/src/test/scala/ir/InterpreterTests.scala b/src/test/scala/ir/InterpreterTests.scala index 47cf65e3c..57f9739fe 100644 --- a/src/test/scala/ir/InterpreterTests.scala +++ b/src/test/scala/ir/InterpreterTests.scala @@ -30,7 +30,7 @@ class InterpreterTests extends AnyFunSuite with BeforeAndAfter { var IRProgram = IRTranslator.translate IRProgram = ExternalRemover(externalFunctions.map(e => e.name)).visitProgram(IRProgram) IRProgram = Renamer(Set("free")).visitProgram(IRProgram) - IRProgram.stripUnreachableFunctions() + transforms.stripUnreachableFunctions(IRProgram) val stackIdentification = StackSubstituter() stackIdentification.visitProgram(IRProgram) IRProgram.setModifies(Map()) From e2c0cd4f85a507c06934fb1acc2f80df5ef04502 Mon Sep 17 00:00:00 2001 From: Alistair Michael Date: Mon, 19 Aug 2024 12:15:10 +1000 Subject: [PATCH 23/62] update docs --- docs/basil-ir.md | 45 +++++++++++++++++++++++++++++++++++++++------ 1 file changed, 39 insertions(+), 6 deletions(-) diff --git a/docs/basil-ir.md b/docs/basil-ir.md index fe30af78d..4c152db8d 100644 --- a/docs/basil-ir.md +++ b/docs/basil-ir.md @@ -3,6 +3,7 @@ BASIL IR is the intermediate representation used during static analysis. This is on contrast to Boogie IR which is used for specification annotation, and output to textual boogie syntax that can be run through the Boogie verifier. + The grammar is described below, note that the IR is a data-structure, without a concrete textual representation so the below grammar only represents the structure. We omit the full description of the expression language because it is relatively standard. @@ -13,13 +14,16 @@ The IR has a completely standard simple type system that is enforced at construc Program ::=&~ Procedure* \\ Procedure ::=&~ (name: ProcID) (entryBlock: Block) (returnBlock: Block) (blocks: Block*) \\ &~ \text{Where }entryBlock, returnBlock \in blocks \\ -Block ::=&~ BlockID \; (Statement*)\; Jump \; (fallthrough: (GoTo | None))\\ - &~ \text{Where $fallthough$ may be $GoTo$ IF $Jump$ is $Call$} \\ +Block_1 ::=&~ BlockID \; Statement*\; Call? \; Jump \; \\ +Block_2 ::=&~ BlockID \; (Statement | Call)*\; Jump \; \\ +\\ +&~ Block = Block_1 \text{ is a structural invariant that holds during all the early analysis/transform stages} +\\ Statement ::=&~ MemoryAssign ~|~ LocalAssign ~|~ Assume ~|~ Assert ~|~ NOP \\ ProcID ::=&~ String \\ BlockID ::=&~ String \\ \\ -Jump ::=&~ Call ~|~ GoTo \\ +Jump ::=&~ GoTo ~|~ Unreachable ~|~ Return \\ GoTo ::=&~ \text{goto } BlockID* \\ Call ::=&~ DirectCall ~|~ IndirectCall \\ DirectCall ::=&~ \text{call } ProcID \\ @@ -46,6 +50,33 @@ Endian ::=&~ BigEndian ~|~ LittleEndian \\ \end{align*} ``` +- The `GoTo` jump is a multi-target jump reprsenting non-deterministic choice between its targets. + Conditional structures are represented by these with a guard (an assume statement) beginning each target. +- The `Unreachable` jump is used to signify the absence of successors, it has the semantics of `assume false`. +- The `Return` jump passes control to the calling function, often this is over-approximated to all functions which call the statement's parent procedure. + +## Translation Phases + +#### IR With Returns + +- Immediately after loading the IR return statements may appear in any block, or may be represented by indirect calls. + The transform pass below replaces all calls to the link register (R30) with return statements. + In the future, more proof is required to implement this soundly. + +``` +cilvisitor.visit_prog(transforms.ReplaceReturns(), ctx.program) +transforms.addReturnBlocks(ctx.program, true) // add return to all blocks because IDE solver expects it +cilvisitor.visit_prog(transforms.ConvertSingleReturn(), ctx.program) +``` + +This ensures that all returning, non-stub procedures have exactly one return statement residing in their `returnBlock`. + +#### Calls appear only as the last statement in a block + +- The structure of the IR allows a call may appear anywhere in the block but for all the analysis passes we hold the invariant that it + only appears as the last statement. This is checked with the function `singleCallBlockEnd(p: Program)`. + And it means for any call statement `c` we may `assert(c.parent.statements.lastOption.contains(c))`. + ## Interaction with BASIL IR ### Constructing Programs in Code @@ -62,10 +93,12 @@ var program: Program = prog( block("first_call", Assign(R0, bv64(1), None) Assign(R1, bv64(1), None) - directCall("callee1", Some("second_call")) + directCall("callee1"), + goto("second_call")) ), block("second_call", - directCall("callee2", Some("returnBlock")) + directCall("callee2"), + goto("returnBlock") ), block("returnBlock", ret @@ -82,7 +115,7 @@ program ::= prog ( procedure+ ) procedure ::= proc (procname, block+) block ::= block(blocklabel, statement+, jump) statement ::= -jump ::= call_s | goto_s | ret +jump ::= goto_s | ret | unreachable call_s ::= directCall (procedurename, None | Some(blocklabel)) // target, fallthrough goto_s ::= goto(blocklabel+) // targets procname ::= String From a3adee3341c3ec19b33fae88a2e30cb982676969 Mon Sep 17 00:00:00 2001 From: Alistair Michael Date: Fri, 30 Aug 2024 10:04:52 +1000 Subject: [PATCH 24/62] fix --- .../scala/analysis/SummaryGenerator.scala | 2 +- src/main/scala/analysis/TaintAnalysis.scala | 4 +- .../analysis/VariableDependencyAnalysis.scala | 6 +- .../scala/analysis/solvers/IDESolver.scala | 4 +- src/main/scala/util/RunUtils.scala | 4 +- src/test/scala/TaintAnalysisTests.scala | 64 +++++++++++-------- 6 files changed, 47 insertions(+), 37 deletions(-) diff --git a/src/main/scala/analysis/SummaryGenerator.scala b/src/main/scala/analysis/SummaryGenerator.scala index 9d633f3a5..79833224a 100644 --- a/src/main/scala/analysis/SummaryGenerator.scala +++ b/src/main/scala/analysis/SummaryGenerator.scala @@ -142,7 +142,7 @@ class SummaryGenerator( // Use rnaResults to find stack function arguments val tainters = relevantVars.map { v => (v, Set()) - }.toMap ++ getTainters(procedure, variables ++ rnaResults(procedure.begin) + UnknownMemory()).filter { (variable, taints) => + }.toMap ++ getTainters(procedure, variables ++ rnaResults(IRWalk.firstInProc(procedure)) + UnknownMemory()).filter { (variable, taints) => relevantVars.contains(variable) } diff --git a/src/main/scala/analysis/TaintAnalysis.scala b/src/main/scala/analysis/TaintAnalysis.scala index 055188d3b..d51eda374 100644 --- a/src/main/scala/analysis/TaintAnalysis.scala +++ b/src/main/scala/analysis/TaintAnalysis.scala @@ -106,11 +106,11 @@ trait TaintAnalysisFunctions( Map(d -> IdEdge()) } - def edgesExitToAfterCall(exit: IndirectCall, aftercall: GoTo)(d: DL): Map[DL, EdgeFunction[TwoElement]] = { + def edgesExitToAfterCall(exit: Return, aftercall: Command)(d: DL): Map[DL, EdgeFunction[TwoElement]] = { Map(d -> IdEdge()) } - def edgesCallToAfterCall(call: DirectCall, aftercall: GoTo)(d: DL): Map[DL, EdgeFunction[TwoElement]] = { + def edgesCallToAfterCall(call: DirectCall, aftercall: Command)(d: DL): Map[DL, EdgeFunction[TwoElement]] = { Map(d -> IdEdge()) } diff --git a/src/main/scala/analysis/VariableDependencyAnalysis.scala b/src/main/scala/analysis/VariableDependencyAnalysis.scala index 26852f856..31fe43eb3 100644 --- a/src/main/scala/analysis/VariableDependencyAnalysis.scala +++ b/src/main/scala/analysis/VariableDependencyAnalysis.scala @@ -31,11 +31,11 @@ trait ProcVariableDependencyAnalysisFunctions( if varDepsSummaries.contains(entry) then Map() else Map(d -> IdEdge()) } - def edgesExitToAfterCall(exit: IndirectCall, aftercall: GoTo)(d: DL): Map[DL, EdgeFunction[Set[Taintable]]] = { + def edgesExitToAfterCall(exit: Return, aftercall: Command)(d: DL): Map[DL, EdgeFunction[Set[Taintable]]] = { if reachable.contains(aftercall.parent.parent) then Map(d -> IdEdge()) else Map() } - def edgesCallToAfterCall(call: DirectCall, aftercall: GoTo)(d: DL): Map[DL, EdgeFunction[Set[Taintable]]] = { + def edgesCallToAfterCall(call: DirectCall, aftercall: Command)(d: DL): Map[DL, EdgeFunction[Set[Taintable]]] = { d match { case Left(v) => varDepsSummaries.get(call.target).flatMap(_.get(v).map( _.foldLeft(Map[DL, EdgeFunction[Set[Taintable]]]()) { (m, d) => m + (Left(d) -> IdEdge()) @@ -120,7 +120,7 @@ class VariableDependencyAnalysis( procedure => { Logger.info("Generating variable dependencies for " + procedure) val varDepResults = ProcVariableDependencyAnalysis(program, varDepVariables, globals, constProp, varDepsSummariesTransposed, procedure).analyze() - val varDepMap = varDepResults.getOrElse(procedure.end, Map()) + val varDepMap = varDepResults.getOrElse(IRWalk.lastInProc(procedure), Map()) varDepsSummaries += procedure -> varDepMap varDepsSummariesTransposed += procedure -> varDepMap.foldLeft(Map[Taintable, Set[Taintable]]()) { (m, p) => { diff --git a/src/main/scala/analysis/solvers/IDESolver.scala b/src/main/scala/analysis/solvers/IDESolver.scala index eaa9a2369..fb6a0c042 100644 --- a/src/main/scala/analysis/solvers/IDESolver.scala +++ b/src/main/scala/analysis/solvers/IDESolver.scala @@ -229,7 +229,7 @@ abstract class ForwardIDESolver[D, T, L <: Lattice[T]](program: Program) protected def isCall(call: CFGPosition): Boolean = call match - case directCall: DirectCall if (!directCall.successor.isInstanceOf[Unreachable]) => true + case directCall: DirectCall if (!directCall.successor.isInstanceOf[Unreachable] && directCall.target.returnBlock.isDefined) => true case _ => false protected def isExit(exit: CFGPosition): Boolean = @@ -238,7 +238,7 @@ abstract class ForwardIDESolver[D, T, L <: Lattice[T]](program: Program) case command: Return => true case _ => false - protected def getAfterCalls(exit: IndirectCall): Set[Command] = + protected def getAfterCalls(exit: Return): Set[Command] = InterProcIRCursor.succ(exit).filter(_.isInstanceOf[Command]).map(_.asInstanceOf[Command]) } diff --git a/src/main/scala/util/RunUtils.scala b/src/main/scala/util/RunUtils.scala index f67087070..0fceafa1f 100644 --- a/src/main/scala/util/RunUtils.scala +++ b/src/main/scala/util/RunUtils.scala @@ -196,13 +196,13 @@ object IRTransform { } val externalRemover = ExternalRemover(externalNamesLibRemoved.toSet) val renamer = Renamer(boogieReserved) + externalRemover.visitProgram(ctx.program) + renamer.visitProgram(ctx.program) cilvisitor.visit_prog(transforms.ReplaceReturns(), ctx.program) transforms.addReturnBlocks(ctx.program, true) // add return to all blocks because IDE solver expects it cilvisitor.visit_prog(transforms.ConvertSingleReturn(), ctx.program) - externalRemover.visitProgram(ctx.program) - renamer.visitProgram(ctx.program) ctx } diff --git a/src/test/scala/TaintAnalysisTests.scala b/src/test/scala/TaintAnalysisTests.scala index 37a26f8aa..dfebf6db8 100644 --- a/src/test/scala/TaintAnalysisTests.scala +++ b/src/test/scala/TaintAnalysisTests.scala @@ -24,7 +24,8 @@ class TaintAnalysisTests extends AnyFunSuite, TestUtil { var program = prog( proc("main", block("main", - directCall("f", Some("mainRet")) + directCall("f"), + goto("mainRet") ), block("mainRet", ret) ), @@ -38,25 +39,27 @@ class TaintAnalysisTests extends AnyFunSuite, TestUtil { ), ) ) - val returnUnifier = ConvertToSingleProcedureReturn() - program = returnUnifier.visitProgram(program) + cilvisitor.visit_prog(transforms.ReplaceReturns(), program) + transforms.addReturnBlocks(program, true) // add return to all blocks because IDE solver expects it + cilvisitor.visit_prog(transforms.ConvertSingleReturn(), program) val f = program.procs("f") val taint: Map[CFGPosition, Set[Taintable]] = Map(f -> Set(R0)) val taintAnalysisResults = getTaintAnalysisResults(program, f, taint) - assert(taintAnalysisResults.get(f.end) == None) + assert(taintAnalysisResults.get(IRWalk.lastInProc(f)) == None) val varDepResults = getVarDepResults(program, f) - assert(varDepResults.get(f.end) == Some(baseRegisterMap - R0)) + assert(varDepResults.get(IRWalk.lastInProc(f)) == Some(baseRegisterMap - R0)) } test("arguments") { var program = prog( proc("main", block("main", - directCall("f", Some("mainRet")) + directCall("f"), + goto("mainRet") ), block("mainRet", ret) ), @@ -70,25 +73,27 @@ class TaintAnalysisTests extends AnyFunSuite, TestUtil { ), ), ) - val returnUnifier = ConvertToSingleProcedureReturn() - program = returnUnifier.visitProgram(program) + cilvisitor.visit_prog(transforms.ReplaceReturns(), program) + transforms.addReturnBlocks(program, true) // add return to all blocks because IDE solver expects it + cilvisitor.visit_prog(transforms.ConvertSingleReturn(), program) val f = program.procs("f") val taint: Map[CFGPosition, Set[Taintable]] = Map(f -> Set(R0)) val taintAnalysisResults = getTaintAnalysisResults(program, f, taint) - assert(taintAnalysisResults.get(f.end) == Some(Set(R0))) + assert(taintAnalysisResults.get(IRWalk.lastInProc(f)) == Some(Set(R0))) val varDepResults = getVarDepResults(program, f) - assert(varDepResults.get(f.end) == Some(baseRegisterMap + (R0 -> Set(R0, R1)))) + assert(varDepResults.get(IRWalk.lastInProc(f)) == Some(baseRegisterMap + (R0 -> Set(R0, R1)))) } test("branching") { var program = prog( proc("main", block("main", - directCall("f", Some("mainRet")) + directCall("f"), + goto("mainRet") ), block("mainRet", ret) ), @@ -109,25 +114,27 @@ class TaintAnalysisTests extends AnyFunSuite, TestUtil { ), ), ) - val returnUnifier = ConvertToSingleProcedureReturn() - program = returnUnifier.visitProgram(program) + cilvisitor.visit_prog(transforms.ReplaceReturns(), program) + transforms.addReturnBlocks(program, true) // add return to all blocks because IDE solver expects it + cilvisitor.visit_prog(transforms.ConvertSingleReturn(), program) val f = program.procs("f") val taint: Map[CFGPosition, Set[Taintable]] = Map(f -> Set(R1)) val taintAnalysisResults = getTaintAnalysisResults(program, f, taint) - assert(taintAnalysisResults.get(f.end) == Some(Set(R0, R1))) + assert(taintAnalysisResults.get(IRWalk.lastInProc(f)) == Some(Set(R0, R1))) val varDepResults = getVarDepResults(program, f) - assert(varDepResults.get(f.end) == Some(baseRegisterMap + (R0 -> Set(R1, R2)))) + assert(varDepResults.get(IRWalk.lastInProc(f)) == Some(baseRegisterMap + (R0 -> Set(R1, R2)))) } test("interproc") { var program = prog( proc("main", block("main", - directCall("f", Some("mainRet")) + directCall("f"), + goto("mainRet") ), block("mainRet", ret) ), @@ -137,12 +144,12 @@ class TaintAnalysisTests extends AnyFunSuite, TestUtil { ), block("a", Assign(R1, R1, None), - directCall("g", None), + directCall("g"), goto("returnBlock"), ), block("b", Assign(R1, R2, None), - directCall("g", None), + directCall("g"), goto("returnBlock"), ), block("returnBlock", @@ -159,25 +166,27 @@ class TaintAnalysisTests extends AnyFunSuite, TestUtil { ), ), ) - val returnUnifier = ConvertToSingleProcedureReturn() - program = returnUnifier.visitProgram(program) + cilvisitor.visit_prog(transforms.ReplaceReturns(), program) + transforms.addReturnBlocks(program, true) // add return to all blocks because IDE solver expects it + cilvisitor.visit_prog(transforms.ConvertSingleReturn(), program) val f = program.procs("f") val taint: Map[CFGPosition, Set[Taintable]] = Map(f -> Set(R1)) val taintAnalysisResults = getTaintAnalysisResults(program, f, taint) - assert(taintAnalysisResults.get(f.end) == Some(Set(R0, R1))) + assert(taintAnalysisResults.get(IRWalk.lastInProc(f)) == Some(Set(R0, R1))) val varDepResults = getVarDepResults(program, f) - assert(varDepResults.get(f.end) == Some(baseRegisterMap + (R0 -> Set(R1, R2)) + (R1 -> Set(R1, R2)))) + assert(varDepResults.get(IRWalk.lastInProc(f)) == Some(baseRegisterMap + (R0 -> Set(R1, R2)) + (R1 -> Set(R1, R2)))) } test("loop") { var program = prog( proc("main", block("main", - directCall("f", Some("mainRet")) + directCall("f"), + goto("mainRet") ), block("mainRet", ret) ), @@ -198,17 +207,18 @@ class TaintAnalysisTests extends AnyFunSuite, TestUtil { ), ), ) - val returnUnifier = ConvertToSingleProcedureReturn() - program = returnUnifier.visitProgram(program) + cilvisitor.visit_prog(transforms.ReplaceReturns(), program) + transforms.addReturnBlocks(program, true) // add return to all blocks because IDE solver expects it + cilvisitor.visit_prog(transforms.ConvertSingleReturn(), program) val f = program.procs("f") val taint: Map[CFGPosition, Set[Taintable]] = Map(f -> Set(R1)) val taintAnalysisResults = getTaintAnalysisResults(program, f, taint) - assert(taintAnalysisResults.get(f.end) == Some(Set(R1))) + assert(taintAnalysisResults.get(IRWalk.lastInProc(f)) == Some(Set(R1))) val varDepResults = getVarDepResults(program, f) - assert(varDepResults.get(f.end) == Some(baseRegisterMap + (R0 -> Set(R2)))) + assert(varDepResults.get(IRWalk.lastInProc(f)) == Some(baseRegisterMap + (R0 -> Set(R2)))) } } From 4a92c990e9c72eab8a62471f2c3423eb00d77038 Mon Sep 17 00:00:00 2001 From: Alistair Michael Date: Fri, 30 Aug 2024 10:31:40 +1000 Subject: [PATCH 25/62] fix externals --- .../scala/analysis/SummaryGenerator.scala | 2 +- .../analysis/VariableDependencyAnalysis.scala | 4 ++-- .../scala/analysis/solvers/IDESolver.scala | 6 +++--- src/main/scala/ir/IRCursor.scala | 5 +++-- src/main/scala/util/RunUtils.scala | 8 +++++--- src/test/scala/TaintAnalysisTests.scala | 20 +++++++++---------- 6 files changed, 24 insertions(+), 21 deletions(-) diff --git a/src/main/scala/analysis/SummaryGenerator.scala b/src/main/scala/analysis/SummaryGenerator.scala index 79833224a..552295f23 100644 --- a/src/main/scala/analysis/SummaryGenerator.scala +++ b/src/main/scala/analysis/SummaryGenerator.scala @@ -142,7 +142,7 @@ class SummaryGenerator( // Use rnaResults to find stack function arguments val tainters = relevantVars.map { v => (v, Set()) - }.toMap ++ getTainters(procedure, variables ++ rnaResults(IRWalk.firstInProc(procedure)) + UnknownMemory()).filter { (variable, taints) => + }.toMap ++ getTainters(procedure, variables ++ rnaResults(IRWalk.firstInProc(procedure).get) + UnknownMemory()).filter { (variable, taints) => relevantVars.contains(variable) } diff --git a/src/main/scala/analysis/VariableDependencyAnalysis.scala b/src/main/scala/analysis/VariableDependencyAnalysis.scala index 31fe43eb3..1b8258ad5 100644 --- a/src/main/scala/analysis/VariableDependencyAnalysis.scala +++ b/src/main/scala/analysis/VariableDependencyAnalysis.scala @@ -116,11 +116,11 @@ class VariableDependencyAnalysis( def analyze(): Map[Procedure, Map[Taintable, Set[Taintable]]] = { var varDepsSummaries = Map[Procedure, Map[Taintable, Set[Taintable]]]() var varDepsSummariesTransposed = Map[Procedure, Map[Taintable, Set[Taintable]]]() - scc.flatten.foreach { + scc.flatten.filter(_.blocks.nonEmpty).foreach { procedure => { Logger.info("Generating variable dependencies for " + procedure) val varDepResults = ProcVariableDependencyAnalysis(program, varDepVariables, globals, constProp, varDepsSummariesTransposed, procedure).analyze() - val varDepMap = varDepResults.getOrElse(IRWalk.lastInProc(procedure), Map()) + val varDepMap = varDepResults.getOrElse(IRWalk.lastInProc(procedure).getOrElse(procedure), Map()) varDepsSummaries += procedure -> varDepMap varDepsSummariesTransposed += procedure -> varDepMap.foldLeft(Map[Taintable, Set[Taintable]]()) { (m, p) => { diff --git a/src/main/scala/analysis/solvers/IDESolver.scala b/src/main/scala/analysis/solvers/IDESolver.scala index fb6a0c042..5e94726d9 100644 --- a/src/main/scala/analysis/solvers/IDESolver.scala +++ b/src/main/scala/analysis/solvers/IDESolver.scala @@ -211,7 +211,7 @@ abstract class ForwardIDESolver[D, T, L <: Lattice[T]](program: Program) extends IDESolver[Procedure, Return, DirectCall, Command, D, T, L](program, program.mainProcedure), ForwardIDEAnalysis[D, T, L], IRInterproceduralForwardDependencies { - protected def entryToExit(entry: Procedure): Return = IRWalk.lastInProc(entry).asInstanceOf[Return] + protected def entryToExit(entry: Procedure): Return = IRWalk.lastInProc(entry).get.asInstanceOf[Return] protected def exitToEntry(exit: Return): Procedure = IRWalk.procedure(exit) @@ -229,7 +229,7 @@ abstract class ForwardIDESolver[D, T, L <: Lattice[T]](program: Program) protected def isCall(call: CFGPosition): Boolean = call match - case directCall: DirectCall if (!directCall.successor.isInstanceOf[Unreachable] && directCall.target.returnBlock.isDefined) => true + case directCall: DirectCall if (!directCall.successor.isInstanceOf[Unreachable] && directCall.target.returnBlock.isDefined && directCall.target.entryBlock.isDefined) => true case _ => false protected def isExit(exit: CFGPosition): Boolean = @@ -244,7 +244,7 @@ abstract class ForwardIDESolver[D, T, L <: Lattice[T]](program: Program) abstract class BackwardIDESolver[D, T, L <: Lattice[T]](program: Program) - extends IDESolver[Return, Procedure, Command, DirectCall, D, T, L](program, IRWalk.lastInProc(program.mainProcedure)), + extends IDESolver[Return, Procedure, Command, DirectCall, D, T, L](program, IRWalk.lastInProc(program.mainProcedure).get), BackwardIDEAnalysis[D, T, L], IRInterproceduralBackwardDependencies { protected def entryToExit(entry: Return): Procedure = IRWalk.procedure(entry) diff --git a/src/main/scala/ir/IRCursor.scala b/src/main/scala/ir/IRCursor.scala index 0e79fa422..21944d4f1 100644 --- a/src/main/scala/ir/IRCursor.scala +++ b/src/main/scala/ir/IRCursor.scala @@ -73,8 +73,9 @@ object IRWalk: def lastInBlock(p: Block): Command = p.jump def firstInBlock(p: Block): Command = p.statements.headOption.getOrElse(p.jump) - def firstInProc(p: Procedure): Command = firstInBlock(p.entryBlock.get) - def lastInProc(p: Procedure): Command = lastInBlock(p.returnBlock.get) + def firstInProc(p: Procedure): Option[Command] = p.entryBlock.map(firstInBlock) + def lastInProc(p: Procedure): Option[Command] = p.returnBlock.map(lastInBlock) + // extension (p: Block) // def isProcEntry: Boolean = p.parent.entryBlock.contains(p) diff --git a/src/main/scala/util/RunUtils.scala b/src/main/scala/util/RunUtils.scala index 0fceafa1f..dd6ba9a5a 100644 --- a/src/main/scala/util/RunUtils.scala +++ b/src/main/scala/util/RunUtils.scala @@ -194,14 +194,16 @@ object IRTransform { externalNamesLibRemoved.add(e.split('@')(0)) } } + + cilvisitor.visit_prog(transforms.ReplaceReturns(), ctx.program) + transforms.addReturnBlocks(ctx.program) + cilvisitor.visit_prog(transforms.ConvertSingleReturn(), ctx.program) + val externalRemover = ExternalRemover(externalNamesLibRemoved.toSet) val renamer = Renamer(boogieReserved) externalRemover.visitProgram(ctx.program) renamer.visitProgram(ctx.program) - cilvisitor.visit_prog(transforms.ReplaceReturns(), ctx.program) - transforms.addReturnBlocks(ctx.program, true) // add return to all blocks because IDE solver expects it - cilvisitor.visit_prog(transforms.ConvertSingleReturn(), ctx.program) ctx } diff --git a/src/test/scala/TaintAnalysisTests.scala b/src/test/scala/TaintAnalysisTests.scala index dfebf6db8..f066acb01 100644 --- a/src/test/scala/TaintAnalysisTests.scala +++ b/src/test/scala/TaintAnalysisTests.scala @@ -47,11 +47,11 @@ class TaintAnalysisTests extends AnyFunSuite, TestUtil { val taint: Map[CFGPosition, Set[Taintable]] = Map(f -> Set(R0)) val taintAnalysisResults = getTaintAnalysisResults(program, f, taint) - assert(taintAnalysisResults.get(IRWalk.lastInProc(f)) == None) + assert(taintAnalysisResults.get(IRWalk.lastInProc(f).get) == None) val varDepResults = getVarDepResults(program, f) - assert(varDepResults.get(IRWalk.lastInProc(f)) == Some(baseRegisterMap - R0)) + assert(varDepResults.get(IRWalk.lastInProc(f).get) == Some(baseRegisterMap - R0)) } test("arguments") { @@ -81,11 +81,11 @@ class TaintAnalysisTests extends AnyFunSuite, TestUtil { val taint: Map[CFGPosition, Set[Taintable]] = Map(f -> Set(R0)) val taintAnalysisResults = getTaintAnalysisResults(program, f, taint) - assert(taintAnalysisResults.get(IRWalk.lastInProc(f)) == Some(Set(R0))) + assert(taintAnalysisResults.get(IRWalk.lastInProc(f).get) == Some(Set(R0))) val varDepResults = getVarDepResults(program, f) - assert(varDepResults.get(IRWalk.lastInProc(f)) == Some(baseRegisterMap + (R0 -> Set(R0, R1)))) + assert(varDepResults.get(IRWalk.lastInProc(f).get) == Some(baseRegisterMap + (R0 -> Set(R0, R1)))) } test("branching") { @@ -122,11 +122,11 @@ class TaintAnalysisTests extends AnyFunSuite, TestUtil { val taint: Map[CFGPosition, Set[Taintable]] = Map(f -> Set(R1)) val taintAnalysisResults = getTaintAnalysisResults(program, f, taint) - assert(taintAnalysisResults.get(IRWalk.lastInProc(f)) == Some(Set(R0, R1))) + assert(taintAnalysisResults.get(IRWalk.lastInProc(f).get) == Some(Set(R0, R1))) val varDepResults = getVarDepResults(program, f) - assert(varDepResults.get(IRWalk.lastInProc(f)) == Some(baseRegisterMap + (R0 -> Set(R1, R2)))) + assert(varDepResults.get(IRWalk.lastInProc(f).get) == Some(baseRegisterMap + (R0 -> Set(R1, R2)))) } test("interproc") { @@ -174,11 +174,11 @@ class TaintAnalysisTests extends AnyFunSuite, TestUtil { val taint: Map[CFGPosition, Set[Taintable]] = Map(f -> Set(R1)) val taintAnalysisResults = getTaintAnalysisResults(program, f, taint) - assert(taintAnalysisResults.get(IRWalk.lastInProc(f)) == Some(Set(R0, R1))) + assert(taintAnalysisResults.get(IRWalk.lastInProc(f).get) == Some(Set(R0, R1))) val varDepResults = getVarDepResults(program, f) - assert(varDepResults.get(IRWalk.lastInProc(f)) == Some(baseRegisterMap + (R0 -> Set(R1, R2)) + (R1 -> Set(R1, R2)))) + assert(varDepResults.get(IRWalk.lastInProc(f).get) == Some(baseRegisterMap + (R0 -> Set(R1, R2)) + (R1 -> Set(R1, R2)))) } test("loop") { @@ -215,10 +215,10 @@ class TaintAnalysisTests extends AnyFunSuite, TestUtil { val taint: Map[CFGPosition, Set[Taintable]] = Map(f -> Set(R1)) val taintAnalysisResults = getTaintAnalysisResults(program, f, taint) - assert(taintAnalysisResults.get(IRWalk.lastInProc(f)) == Some(Set(R1))) + assert(taintAnalysisResults.get(IRWalk.lastInProc(f).get) == Some(Set(R1))) val varDepResults = getVarDepResults(program, f) - assert(varDepResults.get(IRWalk.lastInProc(f)) == Some(baseRegisterMap + (R0 -> Set(R2)))) + assert(varDepResults.get(IRWalk.lastInProc(f).get) == Some(baseRegisterMap + (R0 -> Set(R2)))) } } From b994c9990aa5a4f1ebf79f0312d80049a9bf4c60 Mon Sep 17 00:00:00 2001 From: Alistair Michael Date: Fri, 30 Aug 2024 10:38:23 +1000 Subject: [PATCH 26/62] disable IDE analyses if mainproc is external --- src/main/scala/util/RunUtils.scala | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/src/main/scala/util/RunUtils.scala b/src/main/scala/util/RunUtils.scala index dd6ba9a5a..7a6e08c60 100644 --- a/src/main/scala/util/RunUtils.scala +++ b/src/main/scala/util/RunUtils.scala @@ -407,13 +407,20 @@ object StaticAnalysis { val vsaSolver = ValueSetAnalysisSolver(IRProgram, globalAddresses, externalAddresses, globalOffsets, subroutines, mmm, constPropResult) val vsaResult: Map[CFGPosition, LiftedElement[Map[Variable | MemoryRegion, Set[Value]]]] = vsaSolver.analyze() - Logger.info("[!] Running Interprocedural Live Variables Analysis") - val interLiveVarsResults = InterLiveVarsAnalysis(IRProgram).analyze() - // val interLiveVarsResults = Map[CFGPosition, Map[Variable, TwoElement]]() - Logger.info("[!] Running Parameter Analysis") - val paramResults = ParamAnalysis(IRProgram).analyze() - // val paramResults = Map[Procedure, Set[Variable]]() + var paramResults: Map[Procedure, Set[Variable]] = Map.empty + var interLiveVarsResults: Map[CFGPosition, Map[Variable, TwoElement]] = Map.empty + + if (IRProgram.mainProcedure.blocks.nonEmpty) { + Logger.info("[!] Running Interprocedural Live Variables Analysis") + interLiveVarsResults = InterLiveVarsAnalysis(IRProgram).analyze() + + Logger.info("[!] Running Parameter Analysis") + paramResults = ParamAnalysis(IRProgram).analyze() + + } else { + Logger.warn(s"Disabling IDE solver tests due to external main procedure: ${IRProgram.mainProcedure.name}") + } StaticAnalysisContext( constPropResult = constPropResult, From d6c5767c2b84e69f8feeef8aac79b149cf784b13 Mon Sep 17 00:00:00 2001 From: Alistair Michael Date: Fri, 30 Aug 2024 15:46:47 +1000 Subject: [PATCH 27/62] work on differential testing --- src/main/scala/ir/eval/ExprEval.scala | 5 +- src/main/scala/ir/eval/InterpretBasilIR.scala | 32 +++-- .../scala/ir/eval/InterpretBreakpoints.scala | 22 +-- src/main/scala/ir/eval/InterpretTrace.scala | 34 ++--- src/main/scala/ir/eval/Interpreter.scala | 2 +- src/main/scala/util/RunUtils.scala | 6 +- src/main/scala/util/functional.scala | 9 +- src/test/scala/IndircallDifferential.scala | 136 ++++++++++++++++++ src/test/scala/ir/InterpreterTests.scala | 49 +++++-- 9 files changed, 237 insertions(+), 58 deletions(-) create mode 100644 src/test/scala/IndircallDifferential.scala diff --git a/src/main/scala/ir/eval/ExprEval.scala b/src/main/scala/ir/eval/ExprEval.scala index 999155ae1..413b4ad5e 100644 --- a/src/main/scala/ir/eval/ExprEval.scala +++ b/src/main/scala/ir/eval/ExprEval.scala @@ -262,6 +262,9 @@ class StatelessLoader[E](getVar: Variable => Option[Literal], loadMem: (Memory, def partialEvalExpr(exp: Expr, variableAssignment: Variable => Option[Literal], memory: (Memory, Expr, Endian, Int) => Option[Literal] = ((a,b,c,d) => None)): Expr = { val l = StatelessLoader(variableAssignment, memory) - State.evaluate((), statePartialEvalExpr(l)(exp)) + State.evaluate((), statePartialEvalExpr(l)(exp)) match { + case Right(e) => e + case Left(e) => throw Exception("Unable to evaluate expr : " + e.toString) + } } diff --git a/src/main/scala/ir/eval/InterpretBasilIR.scala b/src/main/scala/ir/eval/InterpretBasilIR.scala index 3dc394a61..18e887014 100644 --- a/src/main/scala/ir/eval/InterpretBasilIR.scala +++ b/src/main/scala/ir/eval/InterpretBasilIR.scala @@ -215,12 +215,14 @@ case object InterpFuns { val LR: BitVecLiteral = BitVecLiteral(BigInt("FF", 16), 64) for { - h <- s.storeVar("funtable", Scope.Global, MapValue(Map.empty, MapType(BitVecType(64), BitVecType(64)))) + h <- s.storeVar("ghost-funtable", Scope.Global, MapValue(Map.empty, MapType(BitVecType(64), BitVecType(64)))) h <- s.storeVar("mem", Scope.Global, MapValue(Map.empty, MapType(BitVecType(64), BitVecType(8)))) i <- s.storeVar("stack", Scope.Global, MapValue(Map.empty, MapType(BitVecType(64), BitVecType(8)))) j <- s.storeVar("R31", Scope.Global, Scalar(SP)) k <- s.storeVar("R29", Scope.Global, Scalar(FP)) l <- s.storeVar("R30", Scope.Global, Scalar(LR)) + l <- s.storeVar("R0", Scope.Global, Scalar(BitVecLiteral(0, 64))) + l <- s.storeVar("R1", Scope.Global, Scalar(BitVecLiteral(0, 64))) } yield (l) } @@ -236,7 +238,7 @@ case object InterpFuns { mem, Scalar(BitVecLiteral(memory.address, 64)), memory.bytes.toList.map(Scalar(_)), - Endian.LittleEndian + Endian.BigEndian ) ) ) @@ -251,7 +253,7 @@ case object InterpFuns { .filter(p => p.blocks.nonEmpty && p.address.isDefined) .map((proc: Procedure) => Eval.storeSingle(f)( - "funtable", + "ghost-funtable", Scalar(BitVecLiteral(proc.address.get, 64)), FunPointer(BitVecLiteral(proc.address.get, 64), proc.name, Run(IRWalk.firstInBlock(proc.entryBlock.get))) ) @@ -261,7 +263,10 @@ case object InterpFuns { mem <- initMemory("stack", p.initialMemory) mem <- initMemory("mem", p.readOnlyMemory) mem <- initMemory("stack", p.readOnlyMemory) - r <- f.call(p.mainProcedure.name, Run(IRWalk.firstInBlock(p.mainProcedure.entryBlock.get)), Stopped()) + mainfun = { + p.mainProcedure + } + r <- f.call(mainfun.name, Run(IRWalk.firstInBlock(mainfun.entryBlock.get)), Stopped()) } yield (r) } @@ -334,7 +339,8 @@ case object InterpFuns { val block = dc.target.entryBlock.get f.call(dc.target.name, Run(block.statements.headOption.getOrElse(block.jump)), Run(dc.successor)) } else { - f.setNext(Run(dc.successor)) + f.setNext(EscapedControlFlow(dc)) + //f.setNext(Run(dc.successor)) } } yield (n) case ic: IndirectCall => { @@ -356,12 +362,16 @@ case object InterpFuns { } def interpret[S, E, T <: Effects[S, E]](f: T, m: S): S = { - val next = State.evaluate(m, f.getNext) - Logger.debug(s"eval $next") - next match { - case Run(c) => interpret(f, State.execute(m, f.interpretOne)) - case Stopped() => m - case errorstop => m + State.evaluate(m, f.getNext) match { + case Right(next) => { + Logger.debug(s"eval $next") + next match { + case Run(c) => interpret(f, State.execute(m, f.interpretOne)) + case Stopped() => m + case errorstop => m + } + } + case Left(err) => m } } diff --git a/src/main/scala/ir/eval/InterpretBreakpoints.scala b/src/main/scala/ir/eval/InterpretBreakpoints.scala index 4668b7075..b00622440 100644 --- a/src/main/scala/ir/eval/InterpretBreakpoints.scala +++ b/src/main/scala/ir/eval/InterpretBreakpoints.scala @@ -18,11 +18,11 @@ enum BreakPointLoc: case CMD(c: Command) case CMDCond(c: Command, condition: Expr) -case class BreakPointAction(saveState: Boolean = true, stop: Boolean = false, evalExprs: List[Expr] = List(), log: Boolean = false) +case class BreakPointAction(saveState: Boolean = true, stop: Boolean = false, evalExprs: List[(String,Expr)] = List(), log: Boolean = false) case class BreakPoint(name: String = "", location: BreakPointLoc, action: BreakPointAction) -case class RememberBreakpoints[T, I <: Effects[T, InterpreterError]](val f: I, val breaks: List[BreakPoint]) extends NopEffects[(T, List[(BreakPoint, Option[T], List[(Expr, Expr)])]), InterpreterError] { +case class RememberBreakpoints[T, I <: Effects[T, InterpreterError]](val f: I, val breaks: List[BreakPoint]) extends NopEffects[(T, List[(BreakPoint, Option[T], List[(String, Expr, Expr)])]), InterpreterError] { def findBreaks[R](c: Command) : State[(T,R), List[BreakPoint], InterpreterError] = { @@ -33,18 +33,18 @@ case class RememberBreakpoints[T, I <: Effects[T, InterpreterError]](val f: I, v }, breaks) } - override def interpretOne : State[(T, List[(BreakPoint, Option[T], List[(Expr, Expr)])]), Unit, InterpreterError] = for { + override def interpretOne : State[(T, List[(BreakPoint, Option[T], List[(String, Expr, Expr)])]), Unit, InterpreterError] = for { v : ExecutionContinuation <- doLeft(f.getNext) n <- v match { case Run(s) => for { breaks : List[BreakPoint] <- findBreaks(s) - res <- State.sequence[(T, List[(BreakPoint, Option[T], List[(Expr, Expr)])]), Unit, InterpreterError](State.pure(()), + res <- State.sequence[(T, List[(BreakPoint, Option[T], List[(String, Expr, Expr)])]), Unit, InterpreterError](State.pure(()), breaks.map((breakpoint: BreakPoint) => (breakpoint match { case breakpoint @ BreakPoint(name, stopcond, action) => (for { saved <- doLeft(if action.saveState then State.getS[T, InterpreterError].map(s => Some(s)) else State.pure(None)) - evals <- (State.mapM((e:Expr) => for { - ev <- doLeft(Eval.evalExpr(f)(e)) - } yield (e, ev) + evals <- (State.mapM((e:(String, Expr)) => for { + ev <- doLeft(Eval.evalExpr(f)(e._2)) + } yield (e._1, e._2, ev) , action.evalExprs)) _ <- if action.stop then doLeft(f.setNext(Errored(s"Stopped at breakpoint ${name}"))) else doLeft(State.pure(())) _ <- State.pure({ @@ -56,11 +56,11 @@ case class RememberBreakpoints[T, I <: Effects[T, InterpreterError]](val f: I, v } val saving = if action.saveState then " stashing state, " else "" val stopping = if action.stop then " stopping. " else "" - val evalstr = evals.map(e => s"\n eval(${e._1}) = ${e._2}").mkString("") + val evalstr = evals.map(e => s"\n ${e._1} : eval(${e._2}) = ${e._3}").mkString("") Logger.warn(s"Breakpoint $bpn@$bpcond.$saving$stopping$evalstr") } }) - _ <- State.modify ((istate:(T, List[(BreakPoint, Option[T], List[(Expr, Expr)])])) => + _ <- State.modify ((istate:(T, List[(BreakPoint, Option[T], List[(String, Expr, Expr)])])) => (istate._1, ((breakpoint, saved, evals)::istate._2))) } yield () ) @@ -73,7 +73,9 @@ case class RememberBreakpoints[T, I <: Effects[T, InterpreterError]](val f: I, v } -def interpretWithBreakPoints[I](p: Program, breakpoints: List[BreakPoint], innerInterpreter: Effects[I, InterpreterError], innerInitialState: I) : (I, List[(BreakPoint, Option[I], List[(Expr, Expr)])]) = { +def interpretWithBreakPoints[I](p: Program, breakpoints: List[BreakPoint], + innerInterpreter: Effects[I, InterpreterError], + innerInitialState: I) : (I, List[(BreakPoint, Option[I], List[(String, Expr, Expr)])]) = { val interp = LayerInterpreter(innerInterpreter, RememberBreakpoints(innerInterpreter, breakpoints)) val res = InterpFuns.interpretProg(interp)(p, (innerInitialState, List())) res diff --git a/src/main/scala/ir/eval/InterpretTrace.scala b/src/main/scala/ir/eval/InterpretTrace.scala index 4a3e78891..ad1ad1200 100644 --- a/src/main/scala/ir/eval/InterpretTrace.scala +++ b/src/main/scala/ir/eval/InterpretTrace.scala @@ -31,47 +31,47 @@ case object Trace { } } -case class TraceGen[E]() extends Effects[Trace, E] { +case class TraceGen[E]() extends NopEffects[Trace, E] { /** Values are discarded by ProductInterpreter so do not matter */ - def evalBV(e: Expr) = State.pure(BitVecLiteral(0,0)) + //def evalBV(e: Expr) = State.pure(BitVecLiteral(0,0)) - def evalInt(e: Expr) = State.pure(BigInt(0)) + //def evalInt(e: Expr) = State.pure(BigInt(0)) - def evalBool(e: Expr) = State.pure(false) + //def evalBool(e: Expr) = State.pure(false) - def loadVar(v: String) = for { + override def loadVar(v: String) = for { s <- Trace.add(ExecEffect.LoadVar(v)) } yield (Scalar(FalseLiteral)) - def loadMem(v: String, addrs: List[BasilValue]) = for { + override def loadMem(v: String, addrs: List[BasilValue]) = for { s <- Trace.add(ExecEffect.LoadMem(v, addrs)) } yield (List()) - def evalAddrToProc(addr: Int) = for { - s <- Trace.add(ExecEffect.FindProc(addr)) - } yield (None) + //def evalAddrToProc(addr: Int) = for { + // s <- Trace.add(ExecEffect.FindProc(addr)) + //} yield (None) - def getNext = State.pure(Stopped()) + // def getNext = State.pure(Stopped()) - def setNext(c: ExecutionContinuation) = State.pure(()) + // def setNext(c: ExecutionContinuation) = State.pure(()) - def call(target: String, beginFrom: ExecutionContinuation, returnTo: ExecutionContinuation) = for { + override def call(target: String, beginFrom: ExecutionContinuation, returnTo: ExecutionContinuation) = for { s <- Trace.add(ExecEffect.Call(target, beginFrom, returnTo)) } yield (()) - def doReturn() = for { + override def doReturn() = for { s <- Trace.add(ExecEffect.Return) } yield (()) - def storeVar(v: String, scope: Scope, value: BasilValue) = for { + override def storeVar(v: String, scope: Scope, value: BasilValue) = for { s <- Trace.add(ExecEffect.StoreVar(v, scope, value)) } yield (()) - def storeMem(vname: String, update: Map[BasilValue, BasilValue]) = for { - s <- Trace.add(ExecEffect.StoreMem(vname, update)) + override def storeMem(vname: String, update: Map[BasilValue, BasilValue]) = for { + s <- if (!vname.startsWith("ghost")) Trace.add(ExecEffect.StoreMem(vname, update)) else State.pure(()) } yield (()) - def interpretOne = State.pure(()) + // def interpretOne = State.pure(()) } def tracingInterpreter = ProductInterpreter(NormalInterpreter, TraceGen()) diff --git a/src/main/scala/ir/eval/Interpreter.scala b/src/main/scala/ir/eval/Interpreter.scala index 82f0aff71..51b71de3c 100644 --- a/src/main/scala/ir/eval/Interpreter.scala +++ b/src/main/scala/ir/eval/Interpreter.scala @@ -334,7 +334,7 @@ object NormalInterpreter extends Effects[InterpreterState, InterpreterError] { Logger.debug(s" eff : FIND PROC $addr") for { res: List[BasilValue] <- getE((s: InterpreterState) => - s.memoryState.doLoad("funtable", List(Scalar(BitVecLiteral(addr, 64)))) + s.memoryState.doLoad("ghost-funtable", List(Scalar(BitVecLiteral(addr, 64)))) ) } yield { res match { diff --git a/src/main/scala/util/RunUtils.scala b/src/main/scala/util/RunUtils.scala index 13d441035..11c7558eb 100644 --- a/src/main/scala/util/RunUtils.scala +++ b/src/main/scala/util/RunUtils.scala @@ -215,13 +215,13 @@ object IRTransform { Logger.info("[!] Stripping unreachable") val before = ctx.program.procedures.size - transforms.stripUnreachableFunctions(ctx.program, config.procedureTrimDepth) + // transforms.stripUnreachableFunctions(ctx.program, config.procedureTrimDepth) Logger.info( s"[!] Removed ${before - ctx.program.procedures.size} functions (${ctx.program.procedures.size} remaining)" ) - val stackIdentification = StackSubstituter() - stackIdentification.visitProgram(ctx.program) + // val stackIdentification = StackSubstituter() + // stackIdentification.visitProgram(ctx.program) val specModifies = ctx.specification.subroutines.map(s => s.name -> s.modifies).toMap ctx.program.setModifies(specModifies) diff --git a/src/main/scala/util/functional.scala b/src/main/scala/util/functional.scala index 5f78e6d76..7be824194 100644 --- a/src/main/scala/util/functional.scala +++ b/src/main/scala/util/functional.scala @@ -50,10 +50,11 @@ object State { case Left(e) => (s, Left(e)) }) def execute[S, A, E](s: S, c: State[S,A, E]) : S = c.f(s)._1 - def evaluate[S, A, E](s: S, c: State[S,A, E]) : A = c.f(s)._2 match { - case Right(r) => r - case Left(l) => throw Exception(s"Evaluation error $l") - } + // def evaluate[S, A, E](s: S, c: State[S,A, E]) : A = c.f(s)._2 match { + // case Right(r) => r + // case Left(l) => throw Exception(s"Evaluation error $l") + // } + def evaluate[S, A, E](s: S, c: State[S,A, E]) : Either[E,A] = c.f(s)._2 def setError[S,A,E](e: E) : State[S,A,E] = State(s => (s, Left(e))) diff --git a/src/test/scala/IndircallDifferential.scala b/src/test/scala/IndircallDifferential.scala new file mode 100644 index 000000000..e71e23389 --- /dev/null +++ b/src/test/scala/IndircallDifferential.scala @@ -0,0 +1,136 @@ + +import ir.* +import java.io.{BufferedWriter, File, FileWriter} +import ir.Endian.LittleEndian +import org.scalatest.* +import org.scalatest.funsuite.* +import specification.* +import util.{BASILConfig, IRLoading, ILLoadingConfig, IRContext, RunUtils, StaticAnalysis, StaticAnalysisConfig, StaticAnalysisContext, BASILResult, Logger, LogLevel} +import ir.eval.{interpretTrace, interpret, ExecEffect, Stopped} + + +import java.io.IOException +import java.nio.file.* +import java.nio.file.attribute.BasicFileAttributes +import ir.dsl.* +import util.RunUtils.loadAndTranslate + +import scala.collection.mutable + +class DifferentialIndirectCall extends AnyFunSuite { + + Logger.setLevel(LogLevel.WARN) + + def diffTest(initial: Program, transformed: Program) = { + val (initialRes,traceInit) = interpretTrace(initial) + val (result,traceRes) = interpretTrace(transformed) + + + def filterEvents(trace: List[ExecEffect]) = { + trace.collect { + case e @ ExecEffect.StoreMem("mem", _) => e + case e @ ExecEffect.LoadMem("mem", _) => e + } + } + + // println(traceInit.t.mkString("\n ")) + assert(initialRes.nextCmd == Stopped()) + assert(result.nextCmd == Stopped()) + assert(initialRes.memoryState.getGlobalVals == result.memoryState.getGlobalVals) + assert(initialRes.memoryState.getMem("mem") == result.memoryState.getMem("mem")) + assert(filterEvents(traceInit.t) == filterEvents(traceRes.t)) + } + + def testProgram(testName: String, examplePath: String) = { + val basilConfig = BASILConfig( + loading = ILLoadingConfig(inputFile = examplePath + testName + ".adt", + relfFile = examplePath + testName + ".relf", + dumpIL = None, + ), + outputPrefix = "basil-test", + staticAnalysis = Some(StaticAnalysisConfig(None, None, None)), + ) + + val basilConfigNoAnalysis = BASILConfig( + loading = ILLoadingConfig(inputFile = examplePath + testName + ".adt", + relfFile = examplePath + testName + ".relf", + dumpIL = None, + ), + outputPrefix = "basil-test", + staticAnalysis = None, + ) + + + val program = loadAndTranslate(basilConfigNoAnalysis).ir.program + val compare = loadAndTranslate(basilConfig).ir.program + diffTest(program, compare) + + } + + test("indirect_call_example") { + val testName = "indirect_call" + val examplePath = System.getProperty("user.dir") + s"/examples/$testName/" + testProgram(testName, examplePath) + } + + test("indirect_call_gcc_example") { + val testName = "indirect_call" + val examplePath = System.getProperty("user.dir") + s"/src/test/correct/$testName/gcc/" + testProgram(testName, examplePath) + } + + test("indirect_call_clang_example") { + val testName = "indirect_call" + val examplePath = System.getProperty("user.dir") + s"/src/test/correct/$testName/clang/" + testProgram(testName, examplePath) + } + + test("jumptable2_example") { + val testName = "jumptable2" + val examplePath = System.getProperty("user.dir") + s"/examples/$testName/" + testProgram(testName, examplePath) + } + + test("jumptable2_gcc_example") { + val testName = "jumptable2" + val examplePath = System.getProperty("user.dir") + s"/src/test/correct/$testName/gcc/" + testProgram(testName, examplePath) + } + + test("jumptable2_clang_example") { + val testName = "jumptable2" + val examplePath = System.getProperty("user.dir") + s"/src/test/correct/$testName/clang/" + testProgram(testName, examplePath) + } + + test("jumptable_example") { + val testName = "jumptable" + val examplePath = System.getProperty("user.dir") + s"/examples/$testName/" + testProgram(testName, examplePath) + } + + test("functionpointer_example") { + val testName = "functionpointer" + val examplePath = System.getProperty("user.dir") + s"/examples/$testName/" + testProgram(testName, examplePath) + } + + test("functionpointer_gcc_example") { + val testName = "functionpointer" + val examplePath = System.getProperty("user.dir") + s"/src/test/correct/$testName/gcc/" + testProgram(testName, examplePath) + } + + test("functionpointer_clang_example") { + val testName = "functionpointer" + val examplePath = System.getProperty("user.dir") + s"/src/test/correct/$testName/clang/" + testProgram(testName, examplePath) + } + + + test("function_got_example") { + val testName = "function_got" + val examplePath = System.getProperty("user.dir") + s"/examples/$testName/" + testProgram(testName, examplePath) + } +} diff --git a/src/test/scala/ir/InterpreterTests.scala b/src/test/scala/ir/InterpreterTests.scala index 0815f106d..966e37244 100644 --- a/src/test/scala/ir/InterpreterTests.scala +++ b/src/test/scala/ir/InterpreterTests.scala @@ -34,7 +34,13 @@ def load(s: InterpreterState, global: SpecGlobal) : Option[BitVecLiteral] = { // m.evalBV("mem", BitVecLiteral(64, global.address), Endian.LittleEndian, global.size) // i.getMemory(global.address.toInt, global.size, Endian.LittleEndian, i.mems) try { - Some(State.evaluate(s, Eval.evalBV(f)(MemoryLoad(SharedMemory("mem", 64, 8), BitVecLiteral(global.address, 64), Endian.LittleEndian, global.size)))) + State.evaluate(s, Eval.evalBV(f)(MemoryLoad(SharedMemory("mem", 64, 8), BitVecLiteral(global.address, 64), Endian.LittleEndian, global.size))) match { + case Right(e) => Some(e) + case Left(e) => { + None + } + + } } catch { case e : InterpreterError => None } @@ -48,7 +54,8 @@ def mems[E, T <: Effects[T, E]](m: MemoryState) : Map[BigInt, BitVecLiteral] = { class InterpreterTests extends AnyFunSuite with BeforeAndAfter { // var i: Interpreter = Interpreter() - Logger.setLevel(LogLevel.DEBUG) + // Logger.setLevel(LogLevel.DEBUG) + Logger.setLevel(LogLevel.ERROR) def getProgram(name: String): (Program, Set[SpecGlobal]) = { @@ -64,11 +71,11 @@ class InterpreterTests extends AnyFunSuite with BeforeAndAfter { val (externalFunctions, globals, _, mainAddress) = loadReadELF(loading.relfFile, loading) val IRTranslator = BAPToIR(bapProgram, mainAddress) var IRProgram = IRTranslator.translate - IRProgram = ExternalRemover(externalFunctions.map(e => e.name)).visitProgram(IRProgram) - IRProgram = Renamer(Set("free")).visitProgram(IRProgram) + // IRProgram = ExternalRemover(externalFunctions.map(e => e.name)).visitProgram(IRProgram) + // IRProgram = Renamer(Set("free")).visitProgram(IRProgram) //IRProgram.stripUnreachableFunctions() - val stackIdentification = StackSubstituter() - stackIdentification.visitProgram(IRProgram) + // val stackIdentification = StackSubstituter() + // stackIdentification.visitProgram(IRProgram) IRProgram.setModifies(Map()) (IRProgram, globals) @@ -121,7 +128,7 @@ class InterpreterTests extends AnyFunSuite with BeforeAndAfter { } yield (v) val l = State.evaluate(InterpreterState(), s) - assert(l == Scalar(BitVecLiteral(4096 - 16, 64))) + assert(l == Right(Scalar(BitVecLiteral(4096 - 16, 64)))) } @@ -140,8 +147,8 @@ class InterpreterTests extends AnyFunSuite with BeforeAndAfter { r <- Eval.loadBV(NormalInterpreter)("mem", Scalar(BitVecLiteral(0, 64)), Endian.LittleEndian, 32) } yield(r) val expected: BitVecLiteral = BitVecLiteral(BigInt("0D0C0B0A", 16), 32) - val actual: BitVecLiteral = State.evaluate(InterpreterState(), s) - assert(actual == expected) + val actual = State.evaluate(InterpreterState(), s) + assert(actual == Right(expected)) } @@ -310,6 +317,26 @@ class InterpreterTests extends AnyFunSuite with BeforeAndAfter { testInterpret("cjump", expected) } + + test("initialisation") { + + // Logger.setLevel(LogLevel.WARN) + val expected = Map( + "x" -> 6, + "y" -> ('b'.toInt), + ) + + val (program, globals) = getProgram("initialisation") + + // val watch = IRWalk.firstInProc((program.mainProcedure)).get + // val globloads = globals.map(global => (global.name, MemoryLoad(SharedMemory("mem", 64, 8), BitVecLiteral(global.address, 64), Endian.LittleEndian, global.size))).toList + // val bp = BreakPoint("beginproc", BreakPointLoc.CMD(watch), BreakPointAction(false, false, globloads, true)) + // val res = interpretWithBreakPoints(program, List(bp), NormalInterpreter, InterpreterState()) + + + testInterpret("initialisation", expected) + } + test("no_interference_update_x") { val expected = Map( "x" -> 1 @@ -397,7 +424,7 @@ class InterpreterTests extends AnyFunSuite with BeforeAndAfter { Logger.setLevel(LogLevel.ERROR) var res = List[(Int, Double, Double)]() - for (i <- 0 to 25) { + for (i <- 0 to 12) { val prog = fibonacciProg(i) val t = PerformanceTimer("native") @@ -437,7 +464,7 @@ class InterpreterTests extends AnyFunSuite with BeforeAndAfter { Logger.setLevel(LogLevel.WARN) val fib = fibonacciProg(8) val watch = IRWalk.firstInProc((fib.procedures.find(_.name == "fib")).get).get - val bp = BreakPoint("Fibentry", BreakPointLoc.CMDCond(watch, BinaryExpr(BVEQ, BitVecLiteral(5, 64), Register("R0", 64))), BreakPointAction(true, true, List(Register("R0", 64)), true)) + val bp = BreakPoint("Fibentry", BreakPointLoc.CMDCond(watch, BinaryExpr(BVEQ, BitVecLiteral(5, 64), Register("R0", 64))), BreakPointAction(true, true, List(("R0", Register("R0", 64))), true)) // val interp = LayerInterpreter(NormalInterpreter, RememberBreakpoints(NormalInterpreter, List(bp))) // val res = InterpFuns.interpretProg(interp)(fib, (InterpreterState(), List())) val res = interpretWithBreakPoints(fib, List(bp), NormalInterpreter, InterpreterState()) From 2a03a237bf68781cc554d0f574deaa7e63079901 Mon Sep 17 00:00:00 2001 From: Alistair Michael Date: Tue, 3 Sep 2024 08:50:18 +1000 Subject: [PATCH 28/62] hook for dynlinking --- src/main/scala/ir/eval/InterpretBasilIR.scala | 57 ++++++++++++++++++- src/main/scala/ir/eval/InterpretTrace.scala | 6 ++ src/main/scala/ir/eval/Interpreter.scala | 3 - .../scala/translating/ReadELFLoader.scala | 1 + src/test/scala/IndircallDifferential.scala | 7 ++- 5 files changed, 65 insertions(+), 9 deletions(-) diff --git a/src/main/scala/ir/eval/InterpretBasilIR.scala b/src/main/scala/ir/eval/InterpretBasilIR.scala index 18e887014..e255fc9d0 100644 --- a/src/main/scala/ir/eval/InterpretBasilIR.scala +++ b/src/main/scala/ir/eval/InterpretBasilIR.scala @@ -2,6 +2,7 @@ package ir.eval import ir._ import ir.eval.BitVectorEval.* import ir.* +import util.IRContext import util.Logger import util.functional.* import util.functional.State.* @@ -204,6 +205,42 @@ case object Eval { case object InterpFuns { + + def initRelocTable[S, E, T <: Effects[S, E]](s: T)(p: Program, reladyn: Set[(BigInt, String)]): State[S, Unit, E] = { + + val data = reladyn.toList.flatMap(r => { + val (offset, extfname) = r + p.procedures.find(proc => proc.name == extfname).map(p => (offset, p)).toList + }) + + // TODO: will have to store + // mem[rodata addr] = naddr + // ghost-funtable[naddr] = FunPointer - to intrinsic function + // We could also dynamic link against something like musl, for things like string.h + + val fptrs = data.map((p) => { + val (offset, proc) = p + val addr = proc.address match { + case Some(x) => x + case None => println(s"No address for function ${proc.name} ${"%x".format(offset)}"); 0 + } + // im guessing proc.address will be undefined and we will have to choose one for our intrinsic libc funcs + (offset, FunPointer(BitVecLiteral(addr, 64), proc.name, Run(DirectCall(proc)))) + }) + + val stores = fptrs.map((p) =>{ + val (offset, fptr) = p + Eval.storeSingle[S,E,T](s)( + "mem", + Scalar(BitVecLiteral(offset, 64)), + fptr + )}) + + for { + _ <- State.sequence[S,Unit,E](State.pure(()), stores) + } yield () + } + /** Functions which compile BASIL IR down to the minimal interpreter effects. * * Each function takes as parameter an implementation of Effects[S] @@ -339,7 +376,7 @@ case object InterpFuns { val block = dc.target.entryBlock.get f.call(dc.target.name, Run(block.statements.headOption.getOrElse(block.jump)), Run(dc.successor)) } else { - f.setNext(EscapedControlFlow(dc)) + State.setError(InterpreterError(EscapedControlFlow(dc))) //f.setNext(Run(dc.successor)) } } yield (n) @@ -352,7 +389,7 @@ case object InterpFuns { fp <- f.evalAddrToProc(addr.value.toInt) _ <- fp match { case Some(fp) => f.call(fp.name, fp.call, Run(ic.successor)) - case none => f.setNext(EscapedControlFlow(ic)) + case none => State.setError(InterpreterError(EscapedControlFlow(ic))) } } yield () } @@ -377,7 +414,16 @@ case object InterpFuns { def interpretProg[S, T <: Effects[S, InterpreterError]](f: T)(p: Program, is: S): S = { val begin = State.execute(is, initialiseProgram(f)(p)) - // State.execute[S,Unit](is, ) + interpret(f, begin) + } + + + def interpretProg[S, T <: Effects[S, InterpreterError]](f: T)(p: IRContext, is: S): S = { + val st = for { + _ <- initialiseProgram(f)(p.program) + _ <- InterpFuns.initRelocTable(f)(p.program, p.externalFunctions.map(f => (f.offset, f.name))) + } yield () + val begin = State.execute(is, st) interpret(f, begin) } } @@ -386,3 +432,8 @@ def interpret(IRProgram: Program): InterpreterState = { InterpFuns.interpretProg(NormalInterpreter)(IRProgram, InterpreterState()) } + +def interpret(IRProgram: IRContext): InterpreterState = { + InterpFuns.interpretProg(NormalInterpreter)(IRProgram, InterpreterState()) +} + diff --git a/src/main/scala/ir/eval/InterpretTrace.scala b/src/main/scala/ir/eval/InterpretTrace.scala index ad1ad1200..39ce9ac94 100644 --- a/src/main/scala/ir/eval/InterpretTrace.scala +++ b/src/main/scala/ir/eval/InterpretTrace.scala @@ -2,6 +2,7 @@ package ir.eval import ir._ import ir.eval.BitVectorEval.* import ir.* +import util.IRContext import util.Logger import util.functional.* import util.functional.State.* @@ -81,3 +82,8 @@ def interpretTrace(p: Program) : (InterpreterState, Trace) = { } +def interpretTrace(p: IRContext) : (InterpreterState, Trace) = { + InterpFuns.interpretProg(tracingInterpreter)(p, (InterpreterState(), Trace(List()))) +} + + diff --git a/src/main/scala/ir/eval/Interpreter.scala b/src/main/scala/ir/eval/Interpreter.scala index 51b71de3c..041402969 100644 --- a/src/main/scala/ir/eval/Interpreter.scala +++ b/src/main/scala/ir/eval/Interpreter.scala @@ -31,9 +31,6 @@ case class MemoryError(val message: String = "") extends ExecutionContinuation / // type InterpreterError = EscapedControlFlow | Errored | TypeError | EvalError | MemoryError -/** TODO: errors should be encapsualted in error monad, rather than mapping exceptions back into state transitions at - * State.execute() - */ case class InterpreterError(continue: ExecutionContinuation) extends Exception() /* Concrete value type of the interpreter. */ diff --git a/src/main/scala/translating/ReadELFLoader.scala b/src/main/scala/translating/ReadELFLoader.scala index d874f05fe..ed9135e1e 100644 --- a/src/main/scala/translating/ReadELFLoader.scala +++ b/src/main/scala/translating/ReadELFLoader.scala @@ -7,6 +7,7 @@ import util.ILLoadingConfig import scala.jdk.CollectionConverters.* object ReadELFLoader { + // TODO: load NOTYPE symbols, so we can get _bss_start ... _bss_end to zero-init in interpreter, as well as _end to find bottom of heap def visitSyms(ctx: SymsContext, config: ILLoadingConfig): (Set[ExternalFunction], Set[SpecGlobal], Map[BigInt, BigInt], Int) = { val externalFunctions = ctx.relocationTable.asScala.flatMap(r => visitRelocationTableExtFunc(r)).toSet val relocationOffsets = ctx.relocationTable.asScala.flatMap(r => visitRelocationTableOffsets(r)).toMap diff --git a/src/test/scala/IndircallDifferential.scala b/src/test/scala/IndircallDifferential.scala index e71e23389..1617d5479 100644 --- a/src/test/scala/IndircallDifferential.scala +++ b/src/test/scala/IndircallDifferential.scala @@ -1,5 +1,6 @@ import ir.* +import ir.eval._ import java.io.{BufferedWriter, File, FileWriter} import ir.Endian.LittleEndian import org.scalatest.* @@ -21,7 +22,7 @@ class DifferentialIndirectCall extends AnyFunSuite { Logger.setLevel(LogLevel.WARN) - def diffTest(initial: Program, transformed: Program) = { + def diffTest(initial: IRContext, transformed: IRContext) = { val (initialRes,traceInit) = interpretTrace(initial) val (result,traceRes) = interpretTrace(transformed) @@ -61,8 +62,8 @@ class DifferentialIndirectCall extends AnyFunSuite { ) - val program = loadAndTranslate(basilConfigNoAnalysis).ir.program - val compare = loadAndTranslate(basilConfig).ir.program + val program = loadAndTranslate(basilConfigNoAnalysis).ir + val compare = loadAndTranslate(basilConfig).ir diffTest(program, compare) } From 8b694c0871a672e39ea1fc0dc06af12ef20db058 Mon Sep 17 00:00:00 2001 From: Alistair Michael Date: Tue, 3 Sep 2024 09:56:35 +1000 Subject: [PATCH 29/62] load full symtab --- src/main/scala/ir/eval/InterpretBasilIR.scala | 2 +- .../scala/translating/ReadELFLoader.scala | 82 +++++++++++++++---- src/main/scala/util/RunUtils.scala | 10 ++- src/test/scala/IndirectCallsTests.scala | 2 +- src/test/scala/IrreducibleLoop.scala | 2 +- src/test/scala/PointsToTest.scala | 2 +- src/test/scala/ir/InterpreterTests.scala | 2 +- 7 files changed, 76 insertions(+), 26 deletions(-) diff --git a/src/main/scala/ir/eval/InterpretBasilIR.scala b/src/main/scala/ir/eval/InterpretBasilIR.scala index e255fc9d0..4c2babfc8 100644 --- a/src/main/scala/ir/eval/InterpretBasilIR.scala +++ b/src/main/scala/ir/eval/InterpretBasilIR.scala @@ -222,7 +222,7 @@ case object InterpFuns { val (offset, proc) = p val addr = proc.address match { case Some(x) => x - case None => println(s"No address for function ${proc.name} ${"%x".format(offset)}"); 0 + case None => /* println(s"No address for function ${proc.name} ${"%x".format(offset)}"); */ 0 } // im guessing proc.address will be undefined and we will have to choose one for our intrinsic libc funcs (offset, FunPointer(BitVecLiteral(addr, 64), proc.name, Run(DirectCall(proc)))) diff --git a/src/main/scala/translating/ReadELFLoader.scala b/src/main/scala/translating/ReadELFLoader.scala index ed9135e1e..e6c6336f8 100644 --- a/src/main/scala/translating/ReadELFLoader.scala +++ b/src/main/scala/translating/ReadELFLoader.scala @@ -6,17 +6,57 @@ import util.ILLoadingConfig import scala.jdk.CollectionConverters.* +/** + * https://refspecs.linuxfoundation.org/elf/elf.pdf + */ + +enum ELFSymType: + case NOTYPE /* absolute symbol or similar */ + case SECTION /* memory section */ + case FILE + case OBJECT + case FUNC /* code function */ + + +enum ELFBind: + case LOCAL /* local to the translation unit */ + case GLOBAL /* global to the program */ + case WEAK /* multiple versions of symbol may be exposed to the linker, and the last definition is used. */ + +enum ELFVis: + case HIDDEN + case DEFAULT + +enum ELFNDX: + case Section(num: Int) /* Section containing the symbol */ + case UND /* Undefined */ + case ABS /* Absolute, unaffected by relocation */ + +case class ELFSymbol(num: Int, /* symbol number */ + value: BigInt, /* symbol address */ + size: Int, /* symbol size (bytes) */ + etype: ELFSymType, + bind: ELFBind, + vis: ELFVis, + ndx: ELFNDX, /* The section containing the symbol */ + name: String) + object ReadELFLoader { // TODO: load NOTYPE symbols, so we can get _bss_start ... _bss_end to zero-init in interpreter, as well as _end to find bottom of heap - def visitSyms(ctx: SymsContext, config: ILLoadingConfig): (Set[ExternalFunction], Set[SpecGlobal], Map[BigInt, BigInt], Int) = { + def visitSyms(ctx: SymsContext, config: ILLoadingConfig): (List[ELFSymbol], Set[ExternalFunction], Set[SpecGlobal], Map[BigInt, BigInt], Int) = { val externalFunctions = ctx.relocationTable.asScala.flatMap(r => visitRelocationTableExtFunc(r)).toSet val relocationOffsets = ctx.relocationTable.asScala.flatMap(r => visitRelocationTableOffsets(r)).toMap - val globalVariables = ctx.symbolTable.asScala.flatMap(s => visitSymbolTable(s)).toSet val mainAddress = ctx.symbolTable.asScala.flatMap(s => getFunctionAddress(s, config.mainProcedureName)) + + val symbolTable = ctx.symbolTable.asScala.flatMap(s => visitSymbolTable(s)).toList + val globalVariables = (symbolTable.collect { + case ELFSymbol(num, value, size, ELFSymType.OBJECT, ELFBind.GLOBAL, ELFVis.DEFAULT, _, name) => SpecGlobal(name, size * 8, None, value) + }).toSet + if (mainAddress.isEmpty) { throw Exception(s"no ${config.mainProcedureName} function in symbol table") } - (externalFunctions, globalVariables, relocationOffsets, mainAddress.head) + (symbolTable, externalFunctions, globalVariables, relocationOffsets, mainAddress.head) } def visitRelocationTableExtFunc(ctx: RelocationTableContext): Set[ExternalFunction] = { @@ -50,12 +90,12 @@ object ReadELFLoader { } } - def visitSymbolTable(ctx: SymbolTableContext): Set[SpecGlobal] = { + + def visitSymbolTable(ctx: SymbolTableContext): List[ELFSymbol] = { if (ctx.symbolTableHeader.tableName.STRING.getText == ".symtab") { - val rows = ctx.symbolTableRow.asScala - rows.flatMap(r => visitSymbolTableRow(r)).toSet + ctx.symbolTableRow.asScala.map(getSymbolTableRow).toList } else { - Set() + List() } } @@ -76,17 +116,25 @@ object ReadELFLoader { } } - def visitSymbolTableRow(ctx: SymbolTableRowContext): Option[SpecGlobal] = { - if (ctx.entrytype.getText == "OBJECT" && ctx.bind.getText == "GLOBAL" && ctx.vis.getText == "DEFAULT") { - val name = ctx.name.getText - if (name.forall(allowedChars.contains)) { - Some(SpecGlobal(name, ctx.size.getText.toInt * 8, None, hexToBigInt(ctx.value.getText))) - } else { - None - } - } else { - None + def getSymbolTableRow(ctx: SymbolTableRowContext): ELFSymbol = { + val bind = ELFBind.valueOf(ctx.bind.getText) + val etype = ELFSymType.valueOf(ctx.entrytype.getText) + val size = ctx.size.getText.toInt + val name = ctx.name match { + case null => "" + case x => x.getText } + val value = BigInt(ctx.value.getText, 16) + val num = ctx.num.getText.toInt + val vis = ELFVis.valueOf(ctx.vis.getText) + + val ndx = (ctx.ndx.getText match { + case "ABS" => ELFNDX.ABS + case "UND" => ELFNDX.UND + case o => ELFNDX.Section(o.toInt) + }) + + ELFSymbol(num, value, size, etype, bind, vis, ndx, name) } def hexToBigInt(hex: String): BigInt = BigInt(hex, 16) diff --git a/src/main/scala/util/RunUtils.scala b/src/main/scala/util/RunUtils.scala index 11c7558eb..45c482b03 100644 --- a/src/main/scala/util/RunUtils.scala +++ b/src/main/scala/util/RunUtils.scala @@ -39,6 +39,7 @@ import scala.collection.mutable * transformation. */ case class IRContext( + symbols: List[ELFSymbol], externalFunctions: Set[ExternalFunction], globals: Set[SpecGlobal], globalOffsets: Map[BigInt, BigInt], @@ -73,13 +74,14 @@ object IRLoading { /** Create a context from just an IR program. */ def load(p: Program): IRContext = { - IRContext(Set.empty, Set.empty, Map.empty, IRLoading.loadSpecification(None, p, Set.empty), p) + IRContext(List.empty, Set.empty, Set.empty, Map.empty, IRLoading.loadSpecification(None, p, Set.empty), p) } /** Load a program from files using the provided configuration. */ def load(q: ILLoadingConfig): IRContext = { - val (externalFunctions, globals, globalOffsets, mainAddress) = IRLoading.loadReadELF(q.relfFile, q) + // TODO: this tuple is large, should be a case class + val (symbols, externalFunctions, globals, globalOffsets, mainAddress) = IRLoading.loadReadELF(q.relfFile, q) val program: Program = if (q.inputFile.endsWith(".adt")) { val bapProgram = loadBAP(q.inputFile) @@ -93,7 +95,7 @@ object IRLoading { val specification = IRLoading.loadSpecification(q.specFile, program, globals) - IRContext(externalFunctions, globals, globalOffsets, specification, program) + IRContext(symbols, externalFunctions, globals, globalOffsets, specification, program) } def loadBAP(fileName: String): BAPProgram = { @@ -151,7 +153,7 @@ object IRLoading { def loadReadELF( fileName: String, config: ILLoadingConfig - ): (Set[ExternalFunction], Set[SpecGlobal], Map[BigInt, BigInt], Int) = { + ): (List[ELFSymbol], Set[ExternalFunction], Set[SpecGlobal], Map[BigInt, BigInt], Int) = { val lexer = ReadELFLexer(CharStreams.fromFileName(fileName)) val tokens = CommonTokenStream(lexer) val parser = ReadELFParser(tokens) diff --git a/src/test/scala/IndirectCallsTests.scala b/src/test/scala/IndirectCallsTests.scala index fbc5a4362..c208e973b 100644 --- a/src/test/scala/IndirectCallsTests.scala +++ b/src/test/scala/IndirectCallsTests.scala @@ -44,7 +44,7 @@ class IndirectCallsTests extends AnyFunSuite with OneInstancePerTest with Before globals: Set[SpecGlobal] = Set.empty, globalOffsets: Map[BigInt, BigInt] = Map.empty): StaticAnalysisContext = { - val ctx = IRContext(externalFunctions, globals, globalOffsets, Specification(Set(), Map(), List(), List(), List(), Set()), program) + val ctx = IRContext(List.empty, externalFunctions, globals, globalOffsets, Specification(Set(), Map(), List(), List(), List(), Set()), program) StaticAnalysis.analyse(ctx, StaticAnalysisConfig(), 1) } diff --git a/src/test/scala/IrreducibleLoop.scala b/src/test/scala/IrreducibleLoop.scala index c7b0a4279..7742a0ee9 100644 --- a/src/test/scala/IrreducibleLoop.scala +++ b/src/test/scala/IrreducibleLoop.scala @@ -27,7 +27,7 @@ class IrreducibleLoop extends AnyFunSuite { def load(conf: ILLoadingConfig) : Program = { val bapProgram = IRLoading.loadBAP(conf.inputFile) - val (externalFunctions, globals, globalOffsets, mainAddress) = IRLoading.loadReadELF(conf.relfFile, conf) + val (_, externalFunctions, globals, globalOffsets, mainAddress) = IRLoading.loadReadELF(conf.relfFile, conf) val IRTranslator = BAPToIR(bapProgram, mainAddress) val IRProgram = IRTranslator.translate IRProgram diff --git a/src/test/scala/PointsToTest.scala b/src/test/scala/PointsToTest.scala index 9534053a3..def6d4a54 100644 --- a/src/test/scala/PointsToTest.scala +++ b/src/test/scala/PointsToTest.scala @@ -42,7 +42,7 @@ class PointsToTest extends AnyFunSuite with OneInstancePerTest with BeforeAndAft globals: Set[SpecGlobal] = Set.empty, globalOffsets: Map[BigInt, BigInt] = Map.empty): StaticAnalysisContext = { - val ctx = IRContext(externalFunctions, globals, globalOffsets, Specification(Set(), Map(), List(), List(), List(), Set()), program) + val ctx = IRContext(List.empty, externalFunctions, globals, globalOffsets, Specification(Set(), Map(), List(), List(), List(), Set()), program) StaticAnalysis.analyse(ctx, StaticAnalysisConfig(), 1) } diff --git a/src/test/scala/ir/InterpreterTests.scala b/src/test/scala/ir/InterpreterTests.scala index 966e37244..7bf72266b 100644 --- a/src/test/scala/ir/InterpreterTests.scala +++ b/src/test/scala/ir/InterpreterTests.scala @@ -68,7 +68,7 @@ class InterpreterTests extends AnyFunSuite with BeforeAndAfter { ) val bapProgram = loadBAP(loading.inputFile) - val (externalFunctions, globals, _, mainAddress) = loadReadELF(loading.relfFile, loading) + val (symbols, externalFunctions, globals, _, mainAddress) = loadReadELF(loading.relfFile, loading) val IRTranslator = BAPToIR(bapProgram, mainAddress) var IRProgram = IRTranslator.translate // IRProgram = ExternalRemover(externalFunctions.map(e => e.name)).visitProgram(IRProgram) From 57ca6b387aa4dd2fe637d3d43903768c75f0f984 Mon Sep 17 00:00:00 2001 From: Alistair Michael Date: Tue, 3 Sep 2024 11:14:57 +1000 Subject: [PATCH 30/62] update relf grammar --- src/main/antlr4/ReadELF.g4 | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/main/antlr4/ReadELF.g4 b/src/main/antlr4/ReadELF.g4 index 513b6340d..1903dce7a 100644 --- a/src/main/antlr4/ReadELF.g4 +++ b/src/main/antlr4/ReadELF.g4 @@ -31,7 +31,7 @@ symbolTableHeader : 'Symbol table' tableName 'contains' HEX 'entries:' NEWLINE 'Num:' 'Value' 'Size' 'Type' 'Bind' 'Vis' 'Ndx' 'Name' // Mainly a sanity check for the column order ; -symbolTableRow : HEX ':' value=HEX size=HEX entrytype=STRING bind=STRING vis=STRING (HEX | STRING) name=(HEX | STRING)? STRING? NEWLINE; +symbolTableRow : (num=HEX) ':' value=HEX size=HEX entrytype=STRING bind=STRING vis=STRING ndx=(HEX | STRING) name=(HEX | STRING)? STRING? NEWLINE; // symbolTableRow : HEX ':' HEX HEX symbolType bind vis ndx name? ; tableName : '\'' STRING '\'' ; From b27698107e5248e440c5637966754f694f41e577 Mon Sep 17 00:00:00 2001 From: Alistair Michael Date: Tue, 3 Sep 2024 13:36:54 +1000 Subject: [PATCH 31/62] init bss --- src/main/scala/ir/eval/InterpretBasilIR.scala | 104 ++++++++++++------ src/main/scala/ir/eval/InterpretTrace.scala | 15 --- src/main/scala/ir/eval/Interpreter.scala | 80 +++++--------- src/test/scala/IndircallDifferential.scala | 5 +- src/test/scala/ir/InterpreterTests.scala | 28 +++-- 5 files changed, 116 insertions(+), 116 deletions(-) diff --git a/src/main/scala/ir/eval/InterpretBasilIR.scala b/src/main/scala/ir/eval/InterpretBasilIR.scala index 4c2babfc8..655de30d4 100644 --- a/src/main/scala/ir/eval/InterpretBasilIR.scala +++ b/src/main/scala/ir/eval/InterpretBasilIR.scala @@ -13,7 +13,7 @@ import scala.annotation.tailrec import scala.collection.mutable import scala.collection.immutable import scala.util.control.Breaks.{break, breakable} - +import translating.ELFSymbol /** Abstraction for memload and variable lookup used by the expression evaluator. */ @@ -28,7 +28,12 @@ case class StVarLoader[S, F <: Effects[S, InterpreterError]](f: F) extends Loade })) } - override def loadMemory(m: Memory, addr: Expr, endian: Endian, size: Int): State[S, Option[Literal], InterpreterError] = { + override def loadMemory( + m: Memory, + addr: Expr, + endian: Endian, + size: Int + ): State[S, Option[Literal], InterpreterError] = { for { r <- addr match { case l: Literal if size == 1 => @@ -69,9 +74,9 @@ case object Eval { for { res <- evalExpr(f)(e) r <- State.pureE(res match { - case l: BitVecLiteral => Right(l) - case _ => Left(InterpreterError(Errored(s"Eval BV residual $e"))) - }) + case l: BitVecLiteral => Right(l) + case _ => Left(InterpreterError(Errored(s"Eval BV residual $e"))) + }) } yield (r) } @@ -79,9 +84,9 @@ case object Eval { for { res <- evalExpr(f)(e) r <- State.pureE(res match { - case l: IntLiteral => Right(l.value) - case _ => Left(InterpreterError(Errored(s"Eval Int residual $e"))) - }) + case l: IntLiteral => Right(l.value) + case _ => Left(InterpreterError(Errored(s"Eval Int residual $e"))) + }) } yield (r) } @@ -89,9 +94,9 @@ case object Eval { for { res <- evalExpr(f)(e) r <- State.pureE(res match { - case l: BoolLit => Right(l == TrueLiteral) - case _ => Left(InterpreterError(Errored(s"Eval Bool residual $e"))) - }) + case l: BoolLit => Right(l == TrueLiteral) + case _ => Left(InterpreterError(Errored(s"Eval Bool residual $e"))) + }) } yield (r) } @@ -103,8 +108,9 @@ case object Eval { f: T )(vname: String, addr: Scalar, endian: Endian, count: Int): State[S, List[BasilValue], InterpreterError] = { for { - _ <- if (count == 0) then State.setError(InterpreterError(Errored(s"Attempted fractional load"))) else State.pure(()) - keys <- State.mapM(((i:Int) => State.pureE(BasilValue.unsafeAdd(addr, i))), (0 until count)) + _ <- + if (count == 0) then State.setError(InterpreterError(Errored(s"Attempted fractional load"))) else State.pure(()) + keys <- State.mapM(((i: Int) => State.pureE(BasilValue.unsafeAdd(addr, i))), (0 until count)) values <- f.loadMem(vname, keys.toList) vals = endian match { case Endian.LittleEndian => values.reverse @@ -127,13 +133,24 @@ case object Eval { cells = size / valsize res <- load(f)(vname, addr, endian, cells) // actual load - bvs: List[BitVecLiteral] <- (State.mapM ((c : BasilValue) => c match { - case Scalar(bv @ BitVecLiteral(v, sz)) if sz == valsize => State.pure(bv) - case c => State.setError(InterpreterError(TypeError(s"Loaded value of type ${c.irType} did not match expected type bv$valsize"))) - },res)) + bvs: List[BitVecLiteral] <- ( + State.mapM( + (c: BasilValue) => + c match { + case Scalar(bv @ BitVecLiteral(v, sz)) if sz == valsize => State.pure(bv) + case c => + State.setError( + InterpreterError(TypeError(s"Loaded value of type ${c.irType} did not match expected type bv$valsize")) + ) + }, + res + ) + ) } yield (bvs.foldLeft(BitVecLiteral(0, 0))((acc, r) => eval.evalBVBinExpr(BVCONCAT, acc, r))) - def loadSingle[S, T <: Effects[S, InterpreterError]](f: T)(vname: String, addr: Scalar): State[S, BasilValue, InterpreterError] = { + def loadSingle[S, T <: Effects[S, InterpreterError]]( + f: T + )(vname: String, addr: Scalar): State[S, BasilValue, InterpreterError] = { for { m <- load(f)(vname, addr, Endian.LittleEndian, 1) } yield (m.head) @@ -152,7 +169,8 @@ case object Eval { ): State[S, Unit, InterpreterError] = for { mem <- f.loadVar(vname) x <- mem match { - case m @ MapValue(_, MapType(kt, vt)) if kt == addr.irType && values.forall(v => v.irType == vt) => State.pure((m, kt, vt)) + case m @ MapValue(_, MapType(kt, vt)) if kt == addr.irType && values.forall(v => v.irType == vt) => + State.pure((m, kt, vt)) case v => State.setError(InterpreterError(TypeError(s"Invalid map store operation to $vname : $v"))) } (mapval, keytype, valtype) = x @@ -198,14 +216,15 @@ case object Eval { s <- f.storeMem(vname, keys.zip(vs).toMap) } yield (s) - def storeSingle[S, E, T <: Effects[S, E]](f: T)(vname: String, addr: BasilValue, value: BasilValue): State[S, Unit, E] = { + def storeSingle[S, E, T <: Effects[S, E]]( + f: T + )(vname: String, addr: BasilValue, value: BasilValue): State[S, Unit, E] = { f.storeMem(vname, Map((addr -> value))) } } case object InterpFuns { - def initRelocTable[S, E, T <: Effects[S, E]](s: T)(p: Program, reladyn: Set[(BigInt, String)]): State[S, Unit, E] = { val data = reladyn.toList.flatMap(r => { @@ -213,7 +232,7 @@ case object InterpFuns { p.procedures.find(proc => proc.name == extfname).map(p => (offset, p)).toList }) - // TODO: will have to store + // TODO: will have to store // mem[rodata addr] = naddr // ghost-funtable[naddr] = FunPointer - to intrinsic function // We could also dynamic link against something like musl, for things like string.h @@ -221,23 +240,24 @@ case object InterpFuns { val fptrs = data.map((p) => { val (offset, proc) = p val addr = proc.address match { - case Some(x) => x - case None => /* println(s"No address for function ${proc.name} ${"%x".format(offset)}"); */ 0 + case Some(x) => x + case None => /* println(s"No address for function ${proc.name} ${"%x".format(offset)}"); */ 0 } // im guessing proc.address will be undefined and we will have to choose one for our intrinsic libc funcs - (offset, FunPointer(BitVecLiteral(addr, 64), proc.name, Run(DirectCall(proc)))) + (offset, FunPointer(BitVecLiteral(addr, 64), proc.name, Run(DirectCall(proc)))) }) - val stores = fptrs.map((p) =>{ + val stores = fptrs.map((p) => { val (offset, fptr) = p - Eval.storeSingle[S,E,T](s)( + Eval.storeSingle[S, E, T](s)( "mem", Scalar(BitVecLiteral(offset, 64)), fptr - )}) + ) + }) for { - _ <- State.sequence[S,Unit,E](State.pure(()), stores) + _ <- State.sequence[S, Unit, E](State.pure(()), stores) } yield () } @@ -403,7 +423,7 @@ case object InterpFuns { case Right(next) => { Logger.debug(s"eval $next") next match { - case Run(c) => interpret(f, State.execute(m, f.interpretOne)) + case Run(c) => interpret(f, State.execute(m, f.interpretOne)) case Stopped() => m case errorstop => m } @@ -417,10 +437,32 @@ case object InterpFuns { interpret(f, begin) } + def initBSS[S, T <: Effects[S, InterpreterError]](f: T)(p: IRContext): State[S, Unit, InterpreterError] = { + val bss = for { + first <- p.symbols.find(s => s.name == "__bss_start__").map(_.value) + last <- p.symbols.find(s => s.name == "__bss_end__").map(_.value) + r <- (if (first == last) then None else Some((first, (last - first) * 8))) + (addr, sz) = r + st = { + (rgn => Eval.storeBV(f)(rgn, Scalar(BitVecLiteral(addr, 64)), BitVecLiteral(0, sz.toInt), Endian.LittleEndian)) + } + + } yield (st) + + bss match { + case None => Logger.error("No BSS initialised"); State.pure(()) + case Some(init) => + for { + _ <- init("mem") + _ <- init("stack") + } yield () + } + } def interpretProg[S, T <: Effects[S, InterpreterError]](f: T)(p: IRContext, is: S): S = { val st = for { _ <- initialiseProgram(f)(p.program) + _ <- initBSS(f)(p) _ <- InterpFuns.initRelocTable(f)(p.program, p.externalFunctions.map(f => (f.offset, f.name))) } yield () val begin = State.execute(is, st) @@ -432,8 +474,6 @@ def interpret(IRProgram: Program): InterpreterState = { InterpFuns.interpretProg(NormalInterpreter)(IRProgram, InterpreterState()) } - def interpret(IRProgram: IRContext): InterpreterState = { InterpFuns.interpretProg(NormalInterpreter)(IRProgram, InterpreterState()) } - diff --git a/src/main/scala/ir/eval/InterpretTrace.scala b/src/main/scala/ir/eval/InterpretTrace.scala index 39ce9ac94..c9565d51e 100644 --- a/src/main/scala/ir/eval/InterpretTrace.scala +++ b/src/main/scala/ir/eval/InterpretTrace.scala @@ -34,11 +34,6 @@ case object Trace { case class TraceGen[E]() extends NopEffects[Trace, E] { /** Values are discarded by ProductInterpreter so do not matter */ - //def evalBV(e: Expr) = State.pure(BitVecLiteral(0,0)) - - //def evalInt(e: Expr) = State.pure(BigInt(0)) - - //def evalBool(e: Expr) = State.pure(false) override def loadVar(v: String) = for { s <- Trace.add(ExecEffect.LoadVar(v)) @@ -48,14 +43,6 @@ case class TraceGen[E]() extends NopEffects[Trace, E] { s <- Trace.add(ExecEffect.LoadMem(v, addrs)) } yield (List()) - //def evalAddrToProc(addr: Int) = for { - // s <- Trace.add(ExecEffect.FindProc(addr)) - //} yield (None) - - // def getNext = State.pure(Stopped()) - - // def setNext(c: ExecutionContinuation) = State.pure(()) - override def call(target: String, beginFrom: ExecutionContinuation, returnTo: ExecutionContinuation) = for { s <- Trace.add(ExecEffect.Call(target, beginFrom, returnTo)) } yield (()) @@ -72,7 +59,6 @@ case class TraceGen[E]() extends NopEffects[Trace, E] { s <- if (!vname.startsWith("ghost")) Trace.add(ExecEffect.StoreMem(vname, update)) else State.pure(()) } yield (()) - // def interpretOne = State.pure(()) } def tracingInterpreter = ProductInterpreter(NormalInterpreter, TraceGen()) @@ -86,4 +72,3 @@ def interpretTrace(p: IRContext) : (InterpreterState, Trace) = { InterpFuns.interpretProg(tracingInterpreter)(p, (InterpreterState(), Trace(List()))) } - diff --git a/src/main/scala/ir/eval/Interpreter.scala b/src/main/scala/ir/eval/Interpreter.scala index 041402969..dc7139d29 100644 --- a/src/main/scala/ir/eval/Interpreter.scala +++ b/src/main/scala/ir/eval/Interpreter.scala @@ -29,8 +29,6 @@ case class EvalError(val message: String = "") extends ExecutionContinuation /* failed to evaluate an expression to a concrete value */ case class MemoryError(val message: String = "") extends ExecutionContinuation /* An error to do with memory */ -// type InterpreterError = EscapedControlFlow | Errored | TypeError | EvalError | MemoryError - case class InterpreterError(continue: ExecutionContinuation) extends Exception() /* Concrete value type of the interpreter. */ @@ -42,13 +40,14 @@ case class Scalar(val value: Literal) extends BasilValue(value.getType) { } } -/** Slightly hacky way of mapping addresses to function calls within the interpreter dynamic state */ +/* Slightly hacky way of mapping addresses to function calls within the interpreter dynamic state */ case class FunPointer(val addr: BitVecLiteral, val name: String, val call: ExecutionContinuation) extends BasilValue(addr.getType) -// Erase the type of basil values and enforce the invariant that -// \exists i . \forall v \in value.keys , v.irType = i and -// \exists j . \forall v \in value.values, v.irType = j +/* Erase the type of basil values and enforce the invariant that + \exists i . \forall v \in value.keys , v.irType = i and + \exists j . \forall v \in value.values, v.irType = j + */ case class MapValue(val value: Map[BasilValue, BasilValue], override val irType: MapType) extends BasilValue(irType) { override def toString = s"MapValue : $irType" } @@ -69,19 +68,9 @@ case object BasilValue { case _ if vr == 0 => Right(l) case Scalar(IntLiteral(vl)) => Right(Scalar(IntLiteral(vl + vr))) case Scalar(b1: BitVecLiteral) => Right(Scalar(eval.evalBVBinExpr(BVADD, b1, BitVecLiteral(vr, b1.size)))) - case _ => Left(InterpreterError(TypeError(s"Operation add $vr undefined on $l"))) + case _ => Left(InterpreterError(TypeError(s"Operation add $vr undefined on $l"))) } } - - // def add(l: BasilValue, r: BasilValue): BasilValue = { - // (l, r) match { - // case (Scalar(IntLiteral(vl)), Scalar(IntLiteral(vr))) => Scalar(IntLiteral(vl + vr)) - // case (Scalar(b1: BitVecLiteral), Scalar(b2: BitVecLiteral)) => Scalar(eval.evalBVBinExpr(BVADD, b1, b2)) - // case (Scalar(b1: BoolLit), Scalar(b2: BoolLit)) => - // Scalar(if (b2.value || b2.value) then TrueLiteral else FalseLiteral) - // case _ => throw InterpreterError(TypeError(s"Operation add undefined on $l $r")) - // } - // } } /** Minimal language defining all state transitions in the interpreter, defined for the interpreter's concrete state T. @@ -133,9 +122,9 @@ trait NopEffects[T, E] extends Effects[T, E] { def storeMem(vname: String, update: Map[BasilValue, BasilValue]) = State.pure(()) } -/** -------------------------------------------------------------------------------- Definition of concrete state - * -------------------------------------------------------------------------------- - */ +/*-------------------------------------------------------------------------------- + * Definition of concrete state + *--------------------------------------------------------------------------------*/ type StackFrameID = String val globalFrame: StackFrameID = "GLOBAL" @@ -255,7 +244,6 @@ case class MemoryState( } } - /* Map variable accessing ; load and store operations */ def doLoad(vname: String, addr: List[BasilValue]): Either[InterpreterError, List[BasilValue]] = for { v <- findVar(vname) @@ -272,16 +260,6 @@ case class MemoryState( }) } yield (xs) - // def doLoad[S](vname: String, addr: List[BasilValue]): State[S, List[BasilValue], InterpreterError] = for { - // v <- doLoadOpt(vname, addr) - // r <- v match { - // case Some(vs) => vs - // case None => { - // throw InterpreterError(MemoryError(s"Read from uninitialised ")) - // } - // } - // } - /** typecheck and some fields of a map variable */ def doStore(vname: String, values: Map[BasilValue, BasilValue]): Either[InterpreterError, MemoryState] = for { // val (frame, mem) = findVar(vname) @@ -407,10 +385,11 @@ object NormalInterpreter extends Effects[InterpreterState, InterpreterError] { Logger.debug(s" eff : RETURN") modifyE((s: InterpreterState) => { s.callStack match { - case Nil => Right(s.copy(nextCmd = Stopped())) - case h :: tl => for { - ms <- s.memoryState.popStackFrame() - } yield (s.copy(nextCmd = h, callStack = tl, memoryState = ms)) + case Nil => Right(s.copy(nextCmd = Stopped())) + case h :: tl => + for { + ms <- s.memoryState.popStackFrame() + } yield (s.copy(nextCmd = h, callStack = tl, memoryState = ms)) } }) } @@ -420,31 +399,22 @@ object NormalInterpreter extends Effects[InterpreterState, InterpreterError] { State.modify((s: InterpreterState) => s.copy(memoryState = s.memoryState.defVar(v, scope, value))) } - def storeMem(vname: String, update: Map[BasilValue, BasilValue]) = + def storeMem(vname: String, update: Map[BasilValue, BasilValue]) = State.modifyE((s: InterpreterState) => { - Logger.debug(s" eff : STORE ${formatStore(vname, update)}") - for { - ms <- s.memoryState.doStore(vname, update) - } yield(s.copy(memoryState = ms)) - }) + Logger.debug(s" eff : STORE ${formatStore(vname, update)}") + for { + ms <- s.memoryState.doStore(vname, update) + } yield (s.copy(memoryState = ms)) + }) def interpretOne: State[InterpreterState, Unit, InterpreterError] = for { next <- getNext _ <- (next match { - case Run(c: Statement) => InterpFuns.interpretStatement(this)(c) - case Run(c: Jump) => InterpFuns.interpretJump(this)(c) - case Stopped() => State.pure(()) - case errorstop => State.pure(()) - }).flatMapE((e: InterpreterError) => setNext(e.continue)) - // } catch { - // case InterpreterError(e) => setNext(e) - // case e: java.lang.IllegalArgumentException => setNext(Errored(e.getStackTrace.take(5).mkString("\n"))) - // } + case Run(c: Statement) => InterpFuns.interpretStatement(this)(c) + case Run(c: Jump) => InterpFuns.interpretJump(this)(c) + case Stopped() => State.pure(()) + case errorstop => State.pure(()) + }).flatMapE((e: InterpreterError) => setNext(e.continue)) } yield () } - -// def interpretTrace(IRProgram: Program): TracingInterpreter = { -// val s: TracingInterpreter = InterpFuns.interpretProg(IRProgram, TracingInterpreter(InterpreterState(), List())) -// s -//e diff --git a/src/test/scala/IndircallDifferential.scala b/src/test/scala/IndircallDifferential.scala index 1617d5479..03ea9dbd7 100644 --- a/src/test/scala/IndircallDifferential.scala +++ b/src/test/scala/IndircallDifferential.scala @@ -33,12 +33,13 @@ class DifferentialIndirectCall extends AnyFunSuite { case e @ ExecEffect.LoadMem("mem", _) => e } } + println(traceInit.t.mkString("\n")) // println(traceInit.t.mkString("\n ")) assert(initialRes.nextCmd == Stopped()) assert(result.nextCmd == Stopped()) - assert(initialRes.memoryState.getGlobalVals == result.memoryState.getGlobalVals) - assert(initialRes.memoryState.getMem("mem") == result.memoryState.getMem("mem")) + // assert(initialRes.memoryState.diff(result.memoryState) == Map.empty) + assert(Set.empty == initialRes.memoryState.getMem("mem").toSet.diff(result.memoryState.getMem("mem").toSet)) assert(filterEvents(traceInit.t) == filterEvents(traceRes.t)) } diff --git a/src/test/scala/ir/InterpreterTests.scala b/src/test/scala/ir/InterpreterTests.scala index 7bf72266b..2b47f6e21 100644 --- a/src/test/scala/ir/InterpreterTests.scala +++ b/src/test/scala/ir/InterpreterTests.scala @@ -10,7 +10,7 @@ import specification.SpecGlobal import translating.BAPToIR import util.{LogLevel, Logger} import util.IRLoading.{loadBAP, loadReadELF} -import util.ILLoadingConfig +import util.{ILLoadingConfig, IRContext, IRLoading, IRTransform} // def initialMem(): MemoryState = { // val SP: BitVecLiteral = BitVecLiteral(4096 - 16, 64) @@ -58,8 +58,7 @@ class InterpreterTests extends AnyFunSuite with BeforeAndAfter { Logger.setLevel(LogLevel.ERROR) - def getProgram(name: String): (Program, Set[SpecGlobal]) = { - + def getProgram(name: String): IRContext = { val loading = ILLoadingConfig( inputFile = s"examples/$name/$name.adt", relfFile = s"examples/$name/$name.relf", @@ -67,24 +66,29 @@ class InterpreterTests extends AnyFunSuite with BeforeAndAfter { dumpIL = None ) - val bapProgram = loadBAP(loading.inputFile) - val (symbols, externalFunctions, globals, _, mainAddress) = loadReadELF(loading.relfFile, loading) - val IRTranslator = BAPToIR(bapProgram, mainAddress) - var IRProgram = IRTranslator.translate + val p = IRLoading.load(loading) + val ctx = IRTransform.doCleanup(p) + // val bapProgram = loadBAP(loading.inputFile) + // val (symbols, externalFunctions, globals, _, mainAddress) = loadReadELF(loading.relfFile, loading) + // val IRTranslator = BAPToIR(bapProgram, mainAddress) + // var IRProgram = IRTranslator.translate // IRProgram = ExternalRemover(externalFunctions.map(e => e.name)).visitProgram(IRProgram) // IRProgram = Renamer(Set("free")).visitProgram(IRProgram) //IRProgram.stripUnreachableFunctions() // val stackIdentification = StackSubstituter() // stackIdentification.visitProgram(IRProgram) - IRProgram.setModifies(Map()) + ctx.program.setModifies(Map()) + - (IRProgram, globals) + // (IRProgram, globals) + ctx } def testInterpret(name: String, expected: Map[String, Int]): Unit = { - val (program, globals) = getProgram(name) - val fstate = interpret(program) + val ctx = getProgram(name) + val fstate = interpret(ctx) val regs = fstate.memoryState.getGlobalVals + val globals = ctx.globals // Show interpreted result Logger.info("Registers:") @@ -326,7 +330,7 @@ class InterpreterTests extends AnyFunSuite with BeforeAndAfter { "y" -> ('b'.toInt), ) - val (program, globals) = getProgram("initialisation") + // val (program, globals) = getProgram("initialisation") // val watch = IRWalk.firstInProc((program.mainProcedure)).get // val globloads = globals.map(global => (global.name, MemoryLoad(SharedMemory("mem", 64, 8), BitVecLiteral(global.address, 64), Endian.LittleEndian, global.size))).toList From 888d360e870b41f3a803ce1321061784c4146879 Mon Sep 17 00:00:00 2001 From: Alistair Michael Date: Tue, 3 Sep 2024 14:50:10 +1000 Subject: [PATCH 32/62] cleanup --- src/main/scala/ir/eval/ExprEval.scala | 337 +++++++++--------- src/main/scala/ir/eval/InterpretBasilIR.scala | 38 +- src/main/scala/ir/eval/Interpreter.scala | 5 + src/main/scala/util/functional.scala | 10 +- ...ntial.scala => DifferentialAnalysis.scala} | 7 +- 5 files changed, 199 insertions(+), 198 deletions(-) rename src/test/scala/{IndircallDifferential.scala => DifferentialAnalysis.scala} (94%) diff --git a/src/main/scala/ir/eval/ExprEval.scala b/src/main/scala/ir/eval/ExprEval.scala index 413b4ad5e..e24355254 100644 --- a/src/main/scala/ir/eval/ExprEval.scala +++ b/src/main/scala/ir/eval/ExprEval.scala @@ -3,268 +3,259 @@ import ir.eval.BitVectorEval import util.functional.State import ir._ -/** - * We generalise the expression evaluator to a partial evaluator to simplify evaluating casts. - * - * - Program state is taken via a function from var -> value and for loads a function from (mem,addr,endian,size) -> value. - * - For conrete evaluators we prefer low-level representations (bool vs BoolLit) and wrap them at the expression eval level - * - Avoid using any default cases so we have some idea of complete coverage - * - */ +/** We generalise the expression evaluator to a partial evaluator to simplify evaluating casts. + * + * - Program state is taken via a function from var -> value and for loads a function from (mem,addr,endian,size) -> + * value. + * - For conrete evaluators we prefer low-level representations (bool vs BoolLit) and wrap them at the expression + * eval level + * - Avoid using any default cases so we have some idea of complete coverage + */ - -def evalBVBinExpr(b: BVBinOp, l:BitVecLiteral, r:BitVecLiteral): BitVecLiteral = { +def evalBVBinExpr(b: BVBinOp, l: BitVecLiteral, r: BitVecLiteral): BitVecLiteral = { b match { - case BVADD => BitVectorEval.smt_bvadd(l, r) - case BVSUB => BitVectorEval.smt_bvsub(l, r) - case BVMUL => BitVectorEval.smt_bvmul(l, r) - case BVUDIV => BitVectorEval.smt_bvudiv(l, r) - case BVSDIV => BitVectorEval.smt_bvsdiv(l, r) - case BVSREM => BitVectorEval.smt_bvsrem(l, r) - case BVUREM => BitVectorEval.smt_bvurem(l, r) - case BVSMOD => BitVectorEval.smt_bvsmod(l, r) - case BVAND => BitVectorEval.smt_bvand(l, r) - case BVOR => BitVectorEval.smt_bvxor(l, r) - case BVXOR => BitVectorEval.smt_bvxor(l, r) - case BVNAND => BitVectorEval.smt_bvnand(l, r) - case BVNOR => BitVectorEval.smt_bvnor(l, r) - case BVXNOR => BitVectorEval.smt_bvxnor(l, r) - case BVSHL => BitVectorEval.smt_bvshl(l, r) - case BVLSHR => BitVectorEval.smt_bvlshr(l, r) - case BVASHR => BitVectorEval.smt_bvashr(l, r) - case BVCOMP => BitVectorEval.smt_bvcomp(l, r) + case BVADD => BitVectorEval.smt_bvadd(l, r) + case BVSUB => BitVectorEval.smt_bvsub(l, r) + case BVMUL => BitVectorEval.smt_bvmul(l, r) + case BVUDIV => BitVectorEval.smt_bvudiv(l, r) + case BVSDIV => BitVectorEval.smt_bvsdiv(l, r) + case BVSREM => BitVectorEval.smt_bvsrem(l, r) + case BVUREM => BitVectorEval.smt_bvurem(l, r) + case BVSMOD => BitVectorEval.smt_bvsmod(l, r) + case BVAND => BitVectorEval.smt_bvand(l, r) + case BVOR => BitVectorEval.smt_bvxor(l, r) + case BVXOR => BitVectorEval.smt_bvxor(l, r) + case BVNAND => BitVectorEval.smt_bvnand(l, r) + case BVNOR => BitVectorEval.smt_bvnor(l, r) + case BVXNOR => BitVectorEval.smt_bvxnor(l, r) + case BVSHL => BitVectorEval.smt_bvshl(l, r) + case BVLSHR => BitVectorEval.smt_bvlshr(l, r) + case BVASHR => BitVectorEval.smt_bvashr(l, r) + case BVCOMP => BitVectorEval.smt_bvcomp(l, r) case BVCONCAT => BitVectorEval.smt_concat(l, r) - case BVULE => throw Exception("Did not expect logical op") - case BVULT => throw Exception("Did not expect logical op") - case BVUGT => throw Exception("Did not expect logical op") - case BVUGE => throw Exception("Did not expect logical op") - case BVSLT => throw Exception("Did not expect logical op") - case BVSLE => throw Exception("Did not expect logical op") - case BVSGT => throw Exception("Did not expect logical op") - case BVSGE => throw Exception("Did not expect logical op") - case BVEQ => throw Exception("Did not expect logical op") - case BVNEQ => throw Exception("Did not expect logical op") + case BVULE | BVULT | BVUGT | BVUGE | BVSLT | BVSLE | BVSGT | BVSGE | BVEQ | BVNEQ => + throw Exception("Did not expect logical op") } } -def evalBVLogBinExpr(b: BVBinOp, l: BitVecLiteral, r:BitVecLiteral) : Boolean = b match { +def evalBVLogBinExpr(b: BVBinOp, l: BitVecLiteral, r: BitVecLiteral): Boolean = b match { case BVULE => BitVectorEval.smt_bvule(l, r) case BVUGT => BitVectorEval.smt_bvult(l, r) case BVUGE => BitVectorEval.smt_bvuge(l, r) - case BVULT => BitVectorEval.smt_bvult(l, r) - case BVSLT => BitVectorEval.smt_bvslt(l, r) - case BVSLE => BitVectorEval.smt_bvsle(l, r) - case BVSGT => BitVectorEval.smt_bvsgt(l, r) - case BVSGE => BitVectorEval.smt_bvsge(l, r) + case BVULT => BitVectorEval.smt_bvult(l, r) + case BVSLT => BitVectorEval.smt_bvslt(l, r) + case BVSLE => BitVectorEval.smt_bvsle(l, r) + case BVSGT => BitVectorEval.smt_bvsgt(l, r) + case BVSGE => BitVectorEval.smt_bvsge(l, r) case BVEQ => BitVectorEval.smt_bveq(l, r) - case BVNEQ => BitVectorEval.smt_bvneq(l, r) - case BVADD => throw Exception("Did not expect non-logical op") - case BVSUB => throw Exception("Did not expect non-logical op") - case BVMUL => throw Exception("Did not expect non-logical op") - case BVUDIV => throw Exception("Did not expect non-logical op") - case BVSDIV => throw Exception("Did not expect non-logical op") - case BVSREM => throw Exception("Did not expect non-logical op") - case BVUREM => throw Exception("Did not expect non-logical op") - case BVSMOD => throw Exception("Did not expect non-logical op") - case BVAND => throw Exception("Did not expect non-logical op") - case BVOR => throw Exception("Did not expect non-logical op") - case BVXOR => throw Exception("Did not expect non-logical op") - case BVNAND => throw Exception("Did not expect non-logical op") - case BVNOR => throw Exception("Did not expect non-logical op") - case BVXNOR => throw Exception("Did not expect non-logical op") - case BVSHL => throw Exception("Did not expect non-logical op") - case BVLSHR => throw Exception("Did not expect non-logical op") - case BVASHR => throw Exception("Did not expect non-logical op") - case BVCOMP => throw Exception("Did not expect non-logical op") - case BVCONCAT => throw Exception("Did not expect non-logical op") + case BVNEQ => BitVectorEval.smt_bvneq(l, r) + case BVADD | BVSUB | BVMUL | BVUDIV | BVSDIV | BVSREM | BVUREM | BVSMOD | BVAND | BVOR | BVXOR | BVNAND | BVNOR | + BVXNOR | BVSHL | BVLSHR | BVASHR | BVCOMP | BVCONCAT => + throw Exception("Did not expect non-logical op") } -def evalIntLogBinExpr(b: IntBinOp, l:BigInt, r:BigInt) : Boolean = b match { - case IntEQ => l == r - case IntNEQ => l != r - case IntLT => l < r - case IntLE => l <= r - case IntGT => l > r - case IntGE => l >= r - case IntADD => throw Exception("Did not expect non-logical op") - case IntSUB => throw Exception("Did not expect non-logical op") - case IntMUL => throw Exception("Did not expect non-logical op") - case IntDIV => throw Exception("Did not expect non-logical op") - case IntMOD => throw Exception("Did not expect non-logical op") +def evalIntLogBinExpr(b: IntBinOp, l: BigInt, r: BigInt): Boolean = b match { + case IntEQ => l == r + case IntNEQ => l != r + case IntLT => l < r + case IntLE => l <= r + case IntGT => l > r + case IntGE => l >= r + case IntADD | IntSUB | IntMUL | IntDIV | IntMOD => throw Exception("Did not expect non-logical op") } -def evalIntBinExpr(b: IntBinOp, l:BigInt, r: BigInt): BigInt = b match { +def evalIntBinExpr(b: IntBinOp, l: BigInt, r: BigInt): BigInt = b match { case IntADD => l + r - case IntSUB => l - r + case IntSUB => l - r case IntMUL => l * r - case IntDIV => l / r - case IntMOD => l % r - case IntEQ => throw Exception("Did not expect logical op") + case IntDIV => l / r + case IntMOD => l % r + case IntEQ => throw Exception("Did not expect logical op") case IntNEQ => throw Exception("Did not expect logical op") - case IntLT => throw Exception("Did not expect logical op") - case IntLE => throw Exception("Did not expect logical op") - case IntGT => throw Exception("Did not expect logical op") - case IntGE => throw Exception("Did not expect logical op") + case IntLT => throw Exception("Did not expect logical op") + case IntLE => throw Exception("Did not expect logical op") + case IntGT => throw Exception("Did not expect logical op") + case IntGE => throw Exception("Did not expect logical op") } - -def evalBoolLogBinExpr(b: BoolBinOp, l:Boolean, r:Boolean) : Boolean = b match { - case BoolEQ => l == r - case BoolEQUIV => l == r - case BoolNEQ => l != r - case BoolAND => l && r - case BoolOR => l || r +def evalBoolLogBinExpr(b: BoolBinOp, l: Boolean, r: Boolean): Boolean = b match { + case BoolEQ => l == r + case BoolEQUIV => l == r + case BoolNEQ => l != r + case BoolAND => l && r + case BoolOR => l || r case BoolIMPLIES => l || (!r) } - -def evalUnOp(op: UnOp, body: Literal) : Expr = { +def evalUnOp(op: UnOp, body: Literal): Expr = { (body, op) match { - case (b: BitVecLiteral, BVNOT) => BitVectorEval.smt_bvnot(b) - case (b: BitVecLiteral, BVNEG) => BitVectorEval.smt_bvneg(b) - case (i: IntLiteral, IntNEG) => IntLiteral(-i.value) - case (FalseLiteral, BoolNOT) => TrueLiteral - case (TrueLiteral, BoolNOT) => FalseLiteral + case (b: BitVecLiteral, BVNOT) => BitVectorEval.smt_bvnot(b) + case (b: BitVecLiteral, BVNEG) => BitVectorEval.smt_bvneg(b) + case (i: IntLiteral, IntNEG) => IntLiteral(-i.value) + case (FalseLiteral, BoolNOT) => TrueLiteral + case (TrueLiteral, BoolNOT) => FalseLiteral } } - - -def evalIntExpr(exp: Expr, variableAssignment: Variable => Option[Literal], memory: (Memory, Expr, Endian, Int) => Option[Literal] = ((a,b,c,d) => None)): Either[Expr, BigInt] = { +def evalIntExpr( + exp: Expr, + variableAssignment: Variable => Option[Literal], + memory: (Memory, Expr, Endian, Int) => Option[Literal] = ((a, b, c, d) => None) +): Either[Expr, BigInt] = { partialEvalExpr(exp, variableAssignment, memory) match { case i: IntLiteral => Right(i.value) - case o => Left(o) + case o => Left(o) } } -def evalBVExpr(exp: Expr, variableAssignment: Variable => Option[Literal], memory: (Memory, Expr, Endian, Int) => Option[Literal] = ((a,b,c,d) => None)): Either[Expr, BitVecLiteral] = { +def evalBVExpr( + exp: Expr, + variableAssignment: Variable => Option[Literal], + memory: (Memory, Expr, Endian, Int) => Option[Literal] = ((a, b, c, d) => None) +): Either[Expr, BitVecLiteral] = { partialEvalExpr(exp, variableAssignment, memory) match { case b: BitVecLiteral => Right(b) - case o => Left(o) + case o => Left(o) } } -def evalLogExpr(exp: Expr, variableAssignment: Variable => Option[Literal], memory: (Memory, Expr, Endian, Int) => Option[Literal] = ((a,b,c, d) => None)): Either[Expr, Boolean] = { +def evalLogExpr( + exp: Expr, + variableAssignment: Variable => Option[Literal], + memory: (Memory, Expr, Endian, Int) => Option[Literal] = ((a, b, c, d) => None) +): Either[Expr, Boolean] = { partialEvalExpr(exp, variableAssignment, memory) match { - case TrueLiteral => Right(true) + case TrueLiteral => Right(true) case FalseLiteral => Right(false) - case o => Left(o) + case o => Left(o) } } -def evalExpr(exp: Expr, variableAssignment: Variable => Option[Literal], memory: (Memory, Expr, Endian, Int) => Option[Literal] = ((d, a,b,c) => None)): Option[Literal] = { +def evalExpr( + exp: Expr, + variableAssignment: Variable => Option[Literal], + memory: (Memory, Expr, Endian, Int) => Option[Literal] = ((d, a, b, c) => None) +): Option[Literal] = { partialEvalExpr match { case l: Literal => Some(l) - case _ => None + case _ => None } } - -/** - * typeclass defining variable and memory laoding from state S - */ +/** typeclass defining variable and memory laoding from state S + */ trait Loader[S, E] { - def getVariable(v: Variable) : State[S, Option[Literal], E] - def loadMemory(m: Memory, addr: Expr, endian: Endian, size: Int) : State[S, Option[Literal], E] = { + def getVariable(v: Variable): State[S, Option[Literal], E] + def loadMemory(m: Memory, addr: Expr, endian: Endian, size: Int): State[S, Option[Literal], E] = { State.pure(None) } } - def statePartialEvalExpr[S, E](l: Loader[S, E])(exp: Expr): State[S, Expr, E] = { val eval = statePartialEvalExpr(l) exp match { case f: UninterpretedFunction => State.pure(f) - case unOp: UnaryExpr => for { - body <- eval(unOp.arg) - } yield ( - body match { + case unOp: UnaryExpr => + for { + body <- eval(unOp.arg) + } yield (body match { case l: Literal => evalUnOp(unOp.op, l) - case o => UnaryExpr(unOp.op, body) + case o => UnaryExpr(unOp.op, body) }) - case binOp: BinaryExpr => for { - lhs <- eval(binOp.arg1) - rhs <- eval(binOp.arg2) - } yield ( - binOp.getType match { + case binOp: BinaryExpr => + for { + lhs <- eval(binOp.arg1) + rhs <- eval(binOp.arg2) + } yield (binOp.getType match { case m: MapType => binOp case b: BitVecType => { (binOp.op, lhs, rhs) match { case (o: BVBinOp, l: BitVecLiteral, r: BitVecLiteral) => evalBVBinExpr(o, l, r) - case _ => BinaryExpr(binOp.op, lhs, rhs) + case _ => BinaryExpr(binOp.op, lhs, rhs) } } case BoolType => { def bool2lit(b: Boolean) = if b then TrueLiteral else FalseLiteral (binOp.op, lhs, rhs) match { case (o: BVBinOp, l: BitVecLiteral, r: BitVecLiteral) => bool2lit(evalBVLogBinExpr(o, l, r)) - case (o: IntBinOp, l: IntLiteral , r: IntLiteral) => bool2lit(evalIntLogBinExpr(o, l.value, r.value)) - case (o: BoolBinOp, l: BoolLit, r: BoolLit) => bool2lit(evalBoolLogBinExpr(o, l.value, r.value)) - case _ => BinaryExpr(binOp.op, lhs, rhs) + case (o: IntBinOp, l: IntLiteral, r: IntLiteral) => bool2lit(evalIntLogBinExpr(o, l.value, r.value)) + case (o: BoolBinOp, l: BoolLit, r: BoolLit) => bool2lit(evalBoolLogBinExpr(o, l.value, r.value)) + case _ => BinaryExpr(binOp.op, lhs, rhs) } } case IntType => { (binOp.op, lhs, rhs) match { - case (o: IntBinOp, l: IntLiteral , r: IntLiteral) => IntLiteral(evalIntBinExpr(o, l.value, r.value)) - case _ => BinaryExpr(binOp.op, lhs, rhs) + case (o: IntBinOp, l: IntLiteral, r: IntLiteral) => IntLiteral(evalIntBinExpr(o, l.value, r.value)) + case _ => BinaryExpr(binOp.op, lhs, rhs) } } }) - case extend: ZeroExtend => for { - body <- eval(extend.body) - } yield (body match { - case b : BitVecLiteral => BitVectorEval.smt_zero_extend(extend.extension, b) - case o => extend.copy(body=o) - }) - case extend: SignExtend => for { - body <- eval(extend.body) - } yield (body match { - case b: BitVecLiteral => BitVectorEval.smt_sign_extend(extend.extension, b) - case o => extend.copy(body=o) - }) - case e: Extract => for { - body <- eval(e.body) - } yield (body match { - case b: BitVecLiteral => BitVectorEval.boogie_extract(e.end, e.start, b) - case o => e.copy(body=o) - }) - case r: Repeat => for { - body <- eval(r.body) + case extend: ZeroExtend => + for { + body <- eval(extend.body) + } yield (body match { + case b: BitVecLiteral => BitVectorEval.smt_zero_extend(extend.extension, b) + case o => extend.copy(body = o) + }) + case extend: SignExtend => + for { + body <- eval(extend.body) + } yield (body match { + case b: BitVecLiteral => BitVectorEval.smt_sign_extend(extend.extension, b) + case o => extend.copy(body = o) + }) + case e: Extract => + for { + body <- eval(e.body) + } yield (body match { + case b: BitVecLiteral => BitVectorEval.boogie_extract(e.end, e.start, b) + case o => e.copy(body = o) + }) + case r: Repeat => + for { + body <- eval(r.body) } yield (body match { case b: BitVecLiteral => { assert(r.repeats > 0) - if (r.repeats == 1) b + if (r.repeats == 1) b else { (2 to r.repeats).foldLeft(b)((acc, r) => BitVectorEval.smt_concat(acc, b)) } } - case o => r.copy(body=o) - }) - case variable: Variable => for { - v : Option[Literal] <- l.getVariable(variable) - } yield (v.getOrElse(variable)) - case ml: MemoryLoad => for { - addr <- eval(ml.index) - mem <- l.loadMemory(ml.mem, addr, ml.endian, ml.size) - } yield (mem.getOrElse(ml)) + case o => r.copy(body = o) + }) + case variable: Variable => + for { + v: Option[Literal] <- l.getVariable(variable) + } yield (v.getOrElse(variable)) + case ml: MemoryLoad => + for { + addr <- eval(ml.index) + mem <- l.loadMemory(ml.mem, addr, ml.endian, ml.size) + } yield (mem.getOrElse(ml)) case b: BitVecLiteral => State.pure(b) - case b: IntLiteral => State.pure(b) - case b: BoolLit => State.pure(b) + case b: IntLiteral => State.pure(b) + case b: BoolLit => State.pure(b) } } - -class StatelessLoader[E](getVar: Variable => Option[Literal], loadMem: (Memory, Expr, Endian, Int) => Option[Literal] = ((a,b,c,d) => None)) extends Loader[Unit, E] { - def getVariable(v: Variable) : State[Unit, Option[Literal], E] = State.pure(getVar(v)) - override def loadMemory(m: Memory, addr: Expr, endian: Endian, size: Int) : State[Unit, Option[Literal], E] = State.pure(loadMem(m, addr, endian, size)) +class StatelessLoader[E]( + getVar: Variable => Option[Literal], + loadMem: (Memory, Expr, Endian, Int) => Option[Literal] = ((a, b, c, d) => None) +) extends Loader[Unit, E] { + def getVariable(v: Variable): State[Unit, Option[Literal], E] = State.pure(getVar(v)) + override def loadMemory(m: Memory, addr: Expr, endian: Endian, size: Int): State[Unit, Option[Literal], E] = + State.pure(loadMem(m, addr, endian, size)) } - -def partialEvalExpr(exp: Expr, variableAssignment: Variable => Option[Literal], memory: (Memory, Expr, Endian, Int) => Option[Literal] = ((a,b,c,d) => None)): Expr = { +def partialEvalExpr( + exp: Expr, + variableAssignment: Variable => Option[Literal], + memory: (Memory, Expr, Endian, Int) => Option[Literal] = ((a, b, c, d) => None) +): Expr = { val l = StatelessLoader(variableAssignment, memory) State.evaluate((), statePartialEvalExpr(l)(exp)) match { case Right(e) => e - case Left(e) => throw Exception("Unable to evaluate expr : " + e.toString) + case Left(e) => throw Exception("Unable to evaluate expr : " + e.toString) } } - diff --git a/src/main/scala/ir/eval/InterpretBasilIR.scala b/src/main/scala/ir/eval/InterpretBasilIR.scala index 655de30d4..b764b4746 100644 --- a/src/main/scala/ir/eval/InterpretBasilIR.scala +++ b/src/main/scala/ir/eval/InterpretBasilIR.scala @@ -223,7 +223,22 @@ case object Eval { } } -case object InterpFuns { +object LibcIntrinsic { + + def putc[S, E, T <: Effects[S, E]](s: T)(f: BitVecLiteral, c: BitVecLiteral): State[S, Unit, E] = { + State.pure(()) + } + + def puts[S, E, T <: Effects[S, E]](s: T)(f: BitVecLiteral, str: BitVecLiteral): State[S, Unit, E] = { + State.pure(()) + } + + def printf[S, E, T <: Effects[S, E]](s: T)(f: BitVecLiteral, str: BitVecLiteral): State[S, Unit, E] = { + State.pure(()) + } +} + +object InterpFuns { def initRelocTable[S, E, T <: Effects[S, E]](s: T)(p: Program, reladyn: Set[(BigInt, String)]): State[S, Unit, E] = { @@ -243,7 +258,7 @@ case object InterpFuns { case Some(x) => x case None => /* println(s"No address for function ${proc.name} ${"%x".format(offset)}"); */ 0 } - // im guessing proc.address will be undefined and we will have to choose one for our intrinsic libc funcs + // proc.address will be undefined and we will have to choose one for our intrinsic libc funcs (offset, FunPointer(BitVecLiteral(addr, 64), proc.name, Run(DirectCall(proc)))) }) @@ -256,9 +271,7 @@ case object InterpFuns { ) }) - for { - _ <- State.sequence[S, Unit, E](State.pure(()), stores) - } yield () + State.sequence[S, Unit, E](State.pure(()), stores) } /** Functions which compile BASIL IR down to the minimal interpreter effects. @@ -450,21 +463,14 @@ case object InterpFuns { } yield (st) bss match { - case None => Logger.error("No BSS initialised"); State.pure(()) - case Some(init) => - for { - _ <- init("mem") - _ <- init("stack") - } yield () + case None => Logger.error("No BSS initialised"); State.pure(()) + case Some(init) => init("mem") >> init("stack") } } def interpretProg[S, T <: Effects[S, InterpreterError]](f: T)(p: IRContext, is: S): S = { - val st = for { - _ <- initialiseProgram(f)(p.program) - _ <- initBSS(f)(p) - _ <- InterpFuns.initRelocTable(f)(p.program, p.externalFunctions.map(f => (f.offset, f.name))) - } yield () + val st = (initialiseProgram(f)(p.program) >> initBSS(f)(p)) + >> InterpFuns.initRelocTable(f)(p.program, p.externalFunctions.map(f => (f.offset, f.name))) val begin = State.execute(is, st) interpret(f, begin) } diff --git a/src/main/scala/ir/eval/Interpreter.scala b/src/main/scala/ir/eval/Interpreter.scala index dc7139d29..ab716d101 100644 --- a/src/main/scala/ir/eval/Interpreter.scala +++ b/src/main/scala/ir/eval/Interpreter.scala @@ -418,3 +418,8 @@ object NormalInterpreter extends Effects[InterpreterState, InterpreterError] { } yield () } + +// def interpretTrace(IRProgram: Program): TracingInterpreter = { +// val s: TracingInterpreter = InterpFuns.interpretProg(IRProgram, TracingInterpreter(InterpreterState(), List())) +// s +//e diff --git a/src/main/scala/util/functional.scala b/src/main/scala/util/functional.scala index 7be824194..a9c9c73d1 100644 --- a/src/main/scala/util/functional.scala +++ b/src/main/scala/util/functional.scala @@ -7,6 +7,12 @@ case class State[S, A, E](f: S => (S, Either[E, A])) { def unit[A](a: A): State[S, A, E] = State(s => (s, Right(a))) + def >>(o: State[S,A,E]) = for { + _ <- this + _ <- o + } yield (()) + + def flatMap[B](f: A => State[S, B, E]): State[S, B, E] = State(s => { val (s2, a) = this.f(s) val r = a match { @@ -50,10 +56,6 @@ object State { case Left(e) => (s, Left(e)) }) def execute[S, A, E](s: S, c: State[S,A, E]) : S = c.f(s)._1 - // def evaluate[S, A, E](s: S, c: State[S,A, E]) : A = c.f(s)._2 match { - // case Right(r) => r - // case Left(l) => throw Exception(s"Evaluation error $l") - // } def evaluate[S, A, E](s: S, c: State[S,A, E]) : Either[E,A] = c.f(s)._2 def setError[S,A,E](e: E) : State[S,A,E] = State(s => (s, Left(e))) diff --git a/src/test/scala/IndircallDifferential.scala b/src/test/scala/DifferentialAnalysis.scala similarity index 94% rename from src/test/scala/IndircallDifferential.scala rename to src/test/scala/DifferentialAnalysis.scala index 03ea9dbd7..3de9450b5 100644 --- a/src/test/scala/IndircallDifferential.scala +++ b/src/test/scala/DifferentialAnalysis.scala @@ -18,7 +18,7 @@ import util.RunUtils.loadAndTranslate import scala.collection.mutable -class DifferentialIndirectCall extends AnyFunSuite { +class DifferentialAnalysis extends AnyFunSuite { Logger.setLevel(LogLevel.WARN) @@ -33,14 +33,11 @@ class DifferentialIndirectCall extends AnyFunSuite { case e @ ExecEffect.LoadMem("mem", _) => e } } - println(traceInit.t.mkString("\n")) - // println(traceInit.t.mkString("\n ")) assert(initialRes.nextCmd == Stopped()) assert(result.nextCmd == Stopped()) - // assert(initialRes.memoryState.diff(result.memoryState) == Map.empty) assert(Set.empty == initialRes.memoryState.getMem("mem").toSet.diff(result.memoryState.getMem("mem").toSet)) - assert(filterEvents(traceInit.t) == filterEvents(traceRes.t)) + assert(filterEvents(traceInit.t).mkString("\n") == filterEvents(traceRes.t).mkString("\n")) } def testProgram(testName: String, examplePath: String) = { From 0abf6542bb6ecfd611a97f3558a68e855ba8780e Mon Sep 17 00:00:00 2001 From: Alistair Michael Date: Tue, 3 Sep 2024 15:06:36 +1000 Subject: [PATCH 33/62] cleanup init trace --- src/main/scala/ir/eval/ExprEval.scala | 23 ++++++++----------- src/main/scala/ir/eval/InterpretBasilIR.scala | 15 ++++++++---- src/main/scala/ir/eval/InterpretTrace.scala | 4 +++- src/main/scala/ir/eval/Interpreter.scala | 5 ---- 4 files changed, 22 insertions(+), 25 deletions(-) diff --git a/src/main/scala/ir/eval/ExprEval.scala b/src/main/scala/ir/eval/ExprEval.scala index e24355254..65218caf2 100644 --- a/src/main/scala/ir/eval/ExprEval.scala +++ b/src/main/scala/ir/eval/ExprEval.scala @@ -34,7 +34,7 @@ def evalBVBinExpr(b: BVBinOp, l: BitVecLiteral, r: BitVecLiteral): BitVecLiteral case BVCOMP => BitVectorEval.smt_bvcomp(l, r) case BVCONCAT => BitVectorEval.smt_concat(l, r) case BVULE | BVULT | BVUGT | BVUGE | BVSLT | BVSLE | BVSGT | BVSGE | BVEQ | BVNEQ => - throw Exception("Did not expect logical op") + throw IllegalArgumentException("Did not expect logical op") } } @@ -51,7 +51,7 @@ def evalBVLogBinExpr(b: BVBinOp, l: BitVecLiteral, r: BitVecLiteral): Boolean = case BVNEQ => BitVectorEval.smt_bvneq(l, r) case BVADD | BVSUB | BVMUL | BVUDIV | BVSDIV | BVSREM | BVUREM | BVSMOD | BVAND | BVOR | BVXOR | BVNAND | BVNOR | BVXNOR | BVSHL | BVLSHR | BVASHR | BVCOMP | BVCONCAT => - throw Exception("Did not expect non-logical op") + throw IllegalArgumentException("Did not expect non-logical op") } def evalIntLogBinExpr(b: IntBinOp, l: BigInt, r: BigInt): Boolean = b match { @@ -61,21 +61,16 @@ def evalIntLogBinExpr(b: IntBinOp, l: BigInt, r: BigInt): Boolean = b match { case IntLE => l <= r case IntGT => l > r case IntGE => l >= r - case IntADD | IntSUB | IntMUL | IntDIV | IntMOD => throw Exception("Did not expect non-logical op") + case IntADD | IntSUB | IntMUL | IntDIV | IntMOD => throw IllegalArgumentException("Did not expect non-logical op") } def evalIntBinExpr(b: IntBinOp, l: BigInt, r: BigInt): BigInt = b match { - case IntADD => l + r - case IntSUB => l - r - case IntMUL => l * r - case IntDIV => l / r - case IntMOD => l % r - case IntEQ => throw Exception("Did not expect logical op") - case IntNEQ => throw Exception("Did not expect logical op") - case IntLT => throw Exception("Did not expect logical op") - case IntLE => throw Exception("Did not expect logical op") - case IntGT => throw Exception("Did not expect logical op") - case IntGE => throw Exception("Did not expect logical op") + case IntADD => l + r + case IntSUB => l - r + case IntMUL => l * r + case IntDIV => l / r + case IntMOD => l % r + case IntEQ | IntNEQ | IntLT | IntLE | IntGT | IntGE => throw IllegalArgumentException("Did not expect logical op") } def evalBoolLogBinExpr(b: BoolBinOp, l: Boolean, r: Boolean): Boolean = b match { diff --git a/src/main/scala/ir/eval/InterpretBasilIR.scala b/src/main/scala/ir/eval/InterpretBasilIR.scala index b764b4746..7750036ba 100644 --- a/src/main/scala/ir/eval/InterpretBasilIR.scala +++ b/src/main/scala/ir/eval/InterpretBasilIR.scala @@ -7,7 +7,7 @@ import util.Logger import util.functional.* import util.functional.State.* import boogie.Scope -import scala.collection.WithFilter +import collection.mutable.ArrayBuffer import scala.annotation.tailrec import scala.collection.mutable @@ -262,7 +262,8 @@ object InterpFuns { (offset, FunPointer(BitVecLiteral(addr, 64), proc.name, Run(DirectCall(proc)))) }) - val stores = fptrs.map((p) => { + // sort for deterministic trace + val stores = fptrs.sortBy(f => f._1).map((p) => { val (offset, fptr) = p Eval.storeSingle[S, E, T](s)( "mem", @@ -297,7 +298,7 @@ object InterpFuns { } def initialiseProgram[S, T <: Effects[S, InterpreterError]](f: T)(p: Program): State[S, Unit, InterpreterError] = { - def initMemory(mem: String, mems: Iterable[MemorySection]) = { + def initMemory(mem: String, mems: ArrayBuffer[MemorySection]) = { for { m <- State.sequence( State.pure(()), @@ -468,10 +469,14 @@ object InterpFuns { } } - def interpretProg[S, T <: Effects[S, InterpreterError]](f: T)(p: IRContext, is: S): S = { + def initProgState[S, T <: Effects[S, InterpreterError]](f: T)(p: IRContext, is: S): S = { val st = (initialiseProgram(f)(p.program) >> initBSS(f)(p)) >> InterpFuns.initRelocTable(f)(p.program, p.externalFunctions.map(f => (f.offset, f.name))) - val begin = State.execute(is, st) + State.execute(is, st) + } + + def interpretProg[S, T <: Effects[S, InterpreterError]](f: T)(p: IRContext, is: S): S = { + val begin = initProgState(f)(p, is) interpret(f, begin) } } diff --git a/src/main/scala/ir/eval/InterpretTrace.scala b/src/main/scala/ir/eval/InterpretTrace.scala index c9565d51e..413222753 100644 --- a/src/main/scala/ir/eval/InterpretTrace.scala +++ b/src/main/scala/ir/eval/InterpretTrace.scala @@ -69,6 +69,8 @@ def interpretTrace(p: Program) : (InterpreterState, Trace) = { def interpretTrace(p: IRContext) : (InterpreterState, Trace) = { - InterpFuns.interpretProg(tracingInterpreter)(p, (InterpreterState(), Trace(List()))) + val b = InterpFuns.initProgState(tracingInterpreter)(p, (InterpreterState(), Trace(List()))) + // Throw away the trace of program initialisation + InterpFuns.interpret(tracingInterpreter, (b._1, Trace(List()))) } diff --git a/src/main/scala/ir/eval/Interpreter.scala b/src/main/scala/ir/eval/Interpreter.scala index ab716d101..dc7139d29 100644 --- a/src/main/scala/ir/eval/Interpreter.scala +++ b/src/main/scala/ir/eval/Interpreter.scala @@ -418,8 +418,3 @@ object NormalInterpreter extends Effects[InterpreterState, InterpreterError] { } yield () } - -// def interpretTrace(IRProgram: Program): TracingInterpreter = { -// val s: TracingInterpreter = InterpFuns.interpretProg(IRProgram, TracingInterpreter(InterpreterState(), List())) -// s -//e From 9192ce37ee71e600c310baf294be7c897be04836 Mon Sep 17 00:00:00 2001 From: Alistair Michael Date: Tue, 3 Sep 2024 15:58:08 +1000 Subject: [PATCH 34/62] intrinsic stub and cleanup errors --- src/main/scala/ir/eval/InterpretBasilIR.scala | 99 +++++++++---------- .../scala/ir/eval/InterpretBreakpoints.scala | 2 +- src/main/scala/ir/eval/Interpreter.scala | 75 ++++++++------ 3 files changed, 90 insertions(+), 86 deletions(-) diff --git a/src/main/scala/ir/eval/InterpretBasilIR.scala b/src/main/scala/ir/eval/InterpretBasilIR.scala index 7750036ba..c31ac1c53 100644 --- a/src/main/scala/ir/eval/InterpretBasilIR.scala +++ b/src/main/scala/ir/eval/InterpretBasilIR.scala @@ -75,7 +75,7 @@ case object Eval { res <- evalExpr(f)(e) r <- State.pureE(res match { case l: BitVecLiteral => Right(l) - case _ => Left(InterpreterError(Errored(s"Eval BV residual $e"))) + case _ => Left((Errored(s"Eval BV residual $e"))) }) } yield (r) } @@ -85,7 +85,7 @@ case object Eval { res <- evalExpr(f)(e) r <- State.pureE(res match { case l: IntLiteral => Right(l.value) - case _ => Left(InterpreterError(Errored(s"Eval Int residual $e"))) + case _ => Left((Errored(s"Eval Int residual $e"))) }) } yield (r) } @@ -95,7 +95,7 @@ case object Eval { res <- evalExpr(f)(e) r <- State.pureE(res match { case l: BoolLit => Right(l == TrueLiteral) - case _ => Left(InterpreterError(Errored(s"Eval Bool residual $e"))) + case _ => Left((Errored(s"Eval Bool residual $e"))) }) } yield (r) } @@ -109,7 +109,7 @@ case object Eval { )(vname: String, addr: Scalar, endian: Endian, count: Int): State[S, List[BasilValue], InterpreterError] = { for { _ <- - if (count == 0) then State.setError(InterpreterError(Errored(s"Attempted fractional load"))) else State.pure(()) + if (count == 0) then State.setError((Errored(s"Attempted fractional load"))) else State.pure(()) keys <- State.mapM(((i: Int) => State.pureE(BasilValue.unsafeAdd(addr, i))), (0 until count)) values <- f.loadMem(vname, keys.toList) vals = endian match { @@ -126,7 +126,7 @@ case object Eval { mem <- f.loadVar(vname) x <- mem match { case mapv @ MapValue(_, MapType(_, BitVecType(sz))) => State.pure((sz, mapv)) - case _ => State.setError(InterpreterError(Errored("Trued to load-concat non bv"))) + case _ => State.setError((Errored("Trued to load-concat non bv"))) } (valsize, mapv) = x @@ -140,7 +140,7 @@ case object Eval { case Scalar(bv @ BitVecLiteral(v, sz)) if sz == valsize => State.pure(bv) case c => State.setError( - InterpreterError(TypeError(s"Loaded value of type ${c.irType} did not match expected type bv$valsize")) + TypeError(s"Loaded value of type ${c.irType} did not match expected type bv$valsize") ) }, res @@ -171,7 +171,7 @@ case object Eval { x <- mem match { case m @ MapValue(_, MapType(kt, vt)) if kt == addr.irType && values.forall(v => v.irType == vt) => State.pure((m, kt, vt)) - case v => State.setError(InterpreterError(TypeError(s"Invalid map store operation to $vname : $v"))) + case v => State.setError((TypeError(s"Invalid map store operation to $vname : $v"))) } (mapval, keytype, valtype) = x keys <- State.mapM((i: Int) => State.pureE(BasilValue.unsafeAdd(addr, i)), (0 until values.size)) @@ -190,19 +190,22 @@ case object Eval { endian: Endian ): State[S, Unit, InterpreterError] = for { mem <- f.loadVar(vname) - (mapval, vsize) = mem match { - case m @ MapValue(_, MapType(kt, BitVecType(size))) if kt == addr.irType => (m, size) + mr <- mem match { + case m @ MapValue(_, MapType(kt, BitVecType(size))) if kt == addr.irType => State.pure((m, size)) case v => - throw InterpreterError( + State.setError( TypeError( s"Invalid map store operation to $vname : ${v.irType} (expect [${addr.irType}] <- ${value.getType})" ) ) } + (mapval, vsize) = mr cells = value.size / vsize _ = { if (cells < 1) { - throw InterpreterError(MemoryError("Tried to execute fractional store")) + State.setError((MemoryError("Tried to execute fractional store"))) + } else { + State.pure(()) } } @@ -223,21 +226,6 @@ case object Eval { } } -object LibcIntrinsic { - - def putc[S, E, T <: Effects[S, E]](s: T)(f: BitVecLiteral, c: BitVecLiteral): State[S, Unit, E] = { - State.pure(()) - } - - def puts[S, E, T <: Effects[S, E]](s: T)(f: BitVecLiteral, str: BitVecLiteral): State[S, Unit, E] = { - State.pure(()) - } - - def printf[S, E, T <: Effects[S, E]](s: T)(f: BitVecLiteral, str: BitVecLiteral): State[S, Unit, E] = { - State.pure(()) - } -} - object InterpFuns { def initRelocTable[S, E, T <: Effects[S, E]](s: T)(p: Program, reladyn: Set[(BigInt, String)]): State[S, Unit, E] = { @@ -263,14 +251,16 @@ object InterpFuns { }) // sort for deterministic trace - val stores = fptrs.sortBy(f => f._1).map((p) => { - val (offset, fptr) = p - Eval.storeSingle[S, E, T](s)( - "mem", - Scalar(BitVecLiteral(offset, 64)), - fptr - ) - }) + val stores = fptrs + .sortBy(f => f._1) + .map((p) => { + val (offset, fptr) = p + Eval.storeSingle[S, E, T](s)( + "mem", + Scalar(BitVecLiteral(offset, 64)), + fptr + ) + }) State.sequence[S, Unit, E](State.pure(()), stores) } @@ -350,20 +340,23 @@ object InterpFuns { val assumes = gt.targets.flatMap(_.statements.headOption).collect { case a: Assume => a } - if (assumes.size != gt.targets.size) { - throw InterpreterError(Errored(s"Some goto target missing guard $gt")) - } for { + _ <- + if (assumes.size != gt.targets.size) { + State.setError((Errored(s"Some goto target missing guard $gt"))) + } else { + State.pure(()) + } chosen: List[Assume] <- filterM((a: Assume) => Eval.evalBool(f)(a.body), assumes) res <- chosen match { - case Nil => f.setNext(Errored(s"No jump target satisfied $gt")) + case Nil => State.setError(Errored(s"No jump target satisfied $gt")) case h :: Nil => f.setNext(Run(h)) - case h :: tl => f.setNext(Errored(s"More than one jump guard satisfied $gt")) + case h :: tl => State.setError(Errored(s"More than one jump guard satisfied $gt")) } } yield (res) case r: Return => f.doReturn() - case h: Unreachable => f.setNext(EscapedControlFlow(h)) + case h: Unreachable => State.setError(EscapedControlFlow(h)) } } @@ -398,7 +391,7 @@ object InterpFuns { b <- Eval.evalBool(f)(assume.body) n <- (if (!b) { - f.setNext(Errored(s"Assumption not satisfied: $assume")) + State.setError(Errored(s"Assumption not satisfied: $assume")) } else { f.setNext(Run(s.successor)) }) @@ -409,9 +402,10 @@ object InterpFuns { if (dc.target.entryBlock.isDefined) { val block = dc.target.entryBlock.get f.call(dc.target.name, Run(block.statements.headOption.getOrElse(block.jump)), Run(dc.successor)) + } else if (LibcIntrinsic.intrinsics.contains(dc.target.name)) { + f.call(dc.target.name, CallIntrinsic(dc.target.name), Run(dc.successor)) } else { - State.setError(InterpreterError(EscapedControlFlow(dc))) - //f.setNext(Run(dc.successor)) + State.setError(EscapedControlFlow(dc)) } } yield (n) case ic: IndirectCall => { @@ -423,7 +417,7 @@ object InterpFuns { fp <- f.evalAddrToProc(addr.value.toInt) _ <- fp match { case Some(fp) => f.call(fp.name, fp.call, Run(ic.successor)) - case none => State.setError(InterpreterError(EscapedControlFlow(ic))) + case none => State.setError(EscapedControlFlow(ic)) } } yield () } @@ -433,17 +427,14 @@ object InterpFuns { } def interpret[S, E, T <: Effects[S, E]](f: T, m: S): S = { - State.evaluate(m, f.getNext) match { - case Right(next) => { - Logger.debug(s"eval $next") - next match { - case Run(c) => interpret(f, State.execute(m, f.interpretOne)) - case Stopped() => m - case errorstop => m - } - } - case Left(err) => m + // run to fixed point + var o = m + var n = State.execute(o, f.interpretOne) + while (o != n) { + o = n + n = State.execute(o, f.interpretOne) } + n } def interpretProg[S, T <: Effects[S, InterpreterError]](f: T)(p: Program, is: S): S = { diff --git a/src/main/scala/ir/eval/InterpretBreakpoints.scala b/src/main/scala/ir/eval/InterpretBreakpoints.scala index b00622440..8121d139d 100644 --- a/src/main/scala/ir/eval/InterpretBreakpoints.scala +++ b/src/main/scala/ir/eval/InterpretBreakpoints.scala @@ -46,7 +46,7 @@ case class RememberBreakpoints[T, I <: Effects[T, InterpreterError]](val f: I, v ev <- doLeft(Eval.evalExpr(f)(e._2)) } yield (e._1, e._2, ev) , action.evalExprs)) - _ <- if action.stop then doLeft(f.setNext(Errored(s"Stopped at breakpoint ${name}"))) else doLeft(State.pure(())) + _ <- if action.stop then doLeft(State.setError(Errored(s"Stopped at breakpoint ${name}"))) else doLeft(State.pure(())) _ <- State.pure({ if (action.log) { val bpn = breakpoint.name diff --git a/src/main/scala/ir/eval/Interpreter.scala b/src/main/scala/ir/eval/Interpreter.scala index dc7139d29..04ced6d96 100644 --- a/src/main/scala/ir/eval/Interpreter.scala +++ b/src/main/scala/ir/eval/Interpreter.scala @@ -18,18 +18,18 @@ sealed trait ExecutionContinuation case class FailedAssertion(a: Assert) extends ExecutionContinuation case class Stopped() extends ExecutionContinuation /* normal program stop */ +case class ErrorStop(error: InterpreterError) extends ExecutionContinuation /* normal program stop */ case class Run(val next: Command) extends ExecutionContinuation /* continue by executing next command */ +case class CallIntrinsic(val name: String) extends ExecutionContinuation /* continue by executing next command */ +sealed trait InterpreterError case class EscapedControlFlow(val call: Jump | Call) - extends ExecutionContinuation /* controlflow has reached somewhere eunrecoverable */ - -case class Errored(val message: String = "") extends ExecutionContinuation -case class TypeError(val message: String = "") extends ExecutionContinuation /* type mismatch appeared */ + extends InterpreterError /* controlflow has reached somewhere eunrecoverable */ +case class Errored(val message: String = "") extends InterpreterError +case class TypeError(val message: String = "") extends InterpreterError /* type mismatch appeared */ case class EvalError(val message: String = "") - extends ExecutionContinuation /* failed to evaluate an expression to a concrete value */ -case class MemoryError(val message: String = "") extends ExecutionContinuation /* An error to do with memory */ - -case class InterpreterError(continue: ExecutionContinuation) extends Exception() + extends InterpreterError /* failed to evaluate an expression to a concrete value */ +case class MemoryError(val message: String = "") extends InterpreterError /* An error to do with memory */ /* Concrete value type of the interpreter. */ sealed trait BasilValue(val irType: IRType) @@ -68,7 +68,7 @@ case object BasilValue { case _ if vr == 0 => Right(l) case Scalar(IntLiteral(vl)) => Right(Scalar(IntLiteral(vl + vr))) case Scalar(b1: BitVecLiteral) => Right(Scalar(eval.evalBVBinExpr(BVADD, b1, BitVecLiteral(vr, b1.size)))) - case _ => Left(InterpreterError(TypeError(s"Operation add $vr undefined on $l"))) + case _ => Left((TypeError(s"Operation add $vr undefined on $l"))) } } } @@ -171,8 +171,8 @@ case class MemoryState( def popStackFrame(): Either[InterpreterError, MemoryState] = { val hv = activations match { - case Nil => Left(InterpreterError(Errored("No stack frame to pop"))) - case h :: Nil if h == globalFrame => Left(InterpreterError(Errored("tried to pop global scope"))) + case Nil => Left((Errored("No stack frame to pop"))) + case h :: Nil if h == globalFrame => Left((Errored("tried to pop global scope"))) case h :: tl => Right((h, tl)) } hv.map((hv) => { @@ -221,24 +221,20 @@ case class MemoryState( def findVar(name: String): Either[InterpreterError, (StackFrameID, BasilValue)] = { findVarOpt(name: String) .map(Right(_)) - .getOrElse(Left(InterpreterError(Errored(s"Access to undefined variable $name")))) + .getOrElse(Left((Errored(s"Access to undefined variable $name")))) } def getVarOpt(name: String): Option[BasilValue] = findVarOpt(name).map(_._2) def getVar(name: String): Either[InterpreterError, BasilValue] = { - getVarOpt(name).map(Right(_)).getOrElse(Left(InterpreterError(Errored(s"Access undefined variable $name")))) + getVarOpt(name).map(Right(_)).getOrElse(Left((Errored(s"Access undefined variable $name")))) } def getVar(v: Variable): Either[InterpreterError, BasilValue] = { val value = getVar(v.name) value match { case Right(dv: BasilValue) if v.getType != dv.irType => - Left( - InterpreterError( - Errored(s"Type mismatch on variable definition and load: defined ${dv.irType}, variable ${v.getType}") - ) - ) + Left(Errored(s"Type mismatch on variable definition and load: defined ${dv.irType}, variable ${v.getType}")) case Right(o) => Right(o) case o => o } @@ -249,14 +245,14 @@ case class MemoryState( v <- findVar(vname) mapv: MapValue <- v._2 match { case m @ MapValue(innerMap, ty) => Right(m) - case m => Left(InterpreterError(TypeError(s"Load from nonmap ${m.irType}"))) + case m => Left((TypeError(s"Load from nonmap ${m.irType}"))) } rs: List[Option[BasilValue]] = addr.map(k => mapv.value.get(k)) xs <- (if (rs.forall(_.isDefined)) { Right(rs.map(_.get)) } else { - Left(InterpreterError(MemoryError(s"Read from uninitialised $vname[${addr.head} .. ${addr.last}]"))) + Left((MemoryError(s"Read from uninitialised $vname[${addr.head} .. ${addr.last}]"))) }) } yield (xs) @@ -268,17 +264,15 @@ case class MemoryState( // val (mapval, keytype, valtype) = mapi <- mem match { case m @ MapValue(_, MapType(kt, vt)) => Right((m, kt, vt)) - case v => Left(InterpreterError(TypeError(s"Invalid map store operation to $vname : ${v.irType}"))) + case v => Left((TypeError(s"Invalid map store operation to $vname : ${v.irType}"))) } (mapval, keytype, valtype) = mapi checkTypes <- (values.find((k, v) => k.irType != keytype || v.irType != valtype)) match { case Some(v) => Left( - InterpreterError( - TypeError( - s"Invalid addr or value type (${v._1.irType}, ${v._2.irType}) does not match map type $vname : ($keytype, $valtype)" - ) + TypeError( + s"Invalid addr or value type (${v._1.irType}, ${v._2.irType}) does not match map type $vname : ($keytype, $valtype)" ) ) case None => Right(()) @@ -289,6 +283,25 @@ case class MemoryState( } yield (ms) } +object LibcIntrinsic { + + def putc[S, E, T <: Effects[S, E]](s: T): State[S, Unit, E] = { + s.doReturn() + } + + def puts[S, E, T <: Effects[S, E]](s: T): State[S, Unit, E] = { + s.doReturn() + } + + def printf[S, E, T <: Effects[S, E]](s: T): State[S, Unit, E] = { + s.doReturn() + } + + def intrinsics[S, E, T <: Effects[S, E]] = + Map[String, T => State[S, Unit, E]]("putc" -> putc, "puts" -> puts, "printf" -> printf) + +} + case class InterpreterState( val nextCmd: ExecutionContinuation = Stopped(), val callStack: List[ExecutionContinuation] = List.empty, @@ -367,7 +380,6 @@ object NormalInterpreter extends Effects[InterpreterState, InterpreterError] { /** effects * */ def setNext(c: ExecutionContinuation) = State.modify((s: InterpreterState) => { - // Logger.debug(s" eff : setNext $c") s.copy(nextCmd = c) }) @@ -410,11 +422,12 @@ object NormalInterpreter extends Effects[InterpreterState, InterpreterError] { def interpretOne: State[InterpreterState, Unit, InterpreterError] = for { next <- getNext _ <- (next match { - case Run(c: Statement) => InterpFuns.interpretStatement(this)(c) - case Run(c: Jump) => InterpFuns.interpretJump(this)(c) - case Stopped() => State.pure(()) - case errorstop => State.pure(()) - }).flatMapE((e: InterpreterError) => setNext(e.continue)) + case CallIntrinsic(tgt) => LibcIntrinsic.intrinsics(tgt)(this) + case Run(c: Statement) => InterpFuns.interpretStatement(this)(c) + case Run(c: Jump) => InterpFuns.interpretJump(this)(c) + case Stopped() => State.pure(()) + case ErrorStop(e) => State.pure(()) + }).flatMapE((e: InterpreterError) => setNext(ErrorStop(e))) } yield () } From 534235893e13e9e2921106b9ed77372032cf6aaf Mon Sep 17 00:00:00 2001 From: Alistair Michael Date: Tue, 3 Sep 2024 17:24:28 +1000 Subject: [PATCH 35/62] init relocation table --- src/main/scala/ir/eval/InterpretBasilIR.scala | 53 +++++++++++-------- src/main/scala/ir/eval/InterpretTrace.scala | 13 ++--- src/test/scala/DifferentialAnalysis.scala | 5 +- 3 files changed, 39 insertions(+), 32 deletions(-) diff --git a/src/main/scala/ir/eval/InterpretBasilIR.scala b/src/main/scala/ir/eval/InterpretBasilIR.scala index c31ac1c53..19485c816 100644 --- a/src/main/scala/ir/eval/InterpretBasilIR.scala +++ b/src/main/scala/ir/eval/InterpretBasilIR.scala @@ -228,26 +228,33 @@ case object Eval { object InterpFuns { - def initRelocTable[S, E, T <: Effects[S, E]](s: T)(p: Program, reladyn: Set[(BigInt, String)]): State[S, Unit, E] = { + def initRelocTable[S, T <: Effects[S, InterpreterError]](s: T)(ctx: IRContext): State[S, Unit, InterpreterError] = { - val data = reladyn.toList.flatMap(r => { - val (offset, extfname) = r - p.procedures.find(proc => proc.name == extfname).map(p => (offset, p)).toList - }) + val p = ctx.program - // TODO: will have to store - // mem[rodata addr] = naddr - // ghost-funtable[naddr] = FunPointer - to intrinsic function - // We could also dynamic link against something like musl, for things like string.h + val base = ctx.symbols.find(_.name == "__end__").get + var addr = base.value + var done = false + var x = List[(String, FunPointer)]() - val fptrs = data.map((p) => { - val (offset, proc) = p - val addr = proc.address match { - case Some(x) => x - case None => /* println(s"No address for function ${proc.name} ${"%x".format(offset)}"); */ 0 - } - // proc.address will be undefined and we will have to choose one for our intrinsic libc funcs - (offset, FunPointer(BitVecLiteral(addr, 64), proc.name, Run(DirectCall(proc)))) + def newAddr() : BigInt = { + addr += 8 + addr + } + + for ((fname, fun) <- LibcIntrinsic.intrinsics) { + val name = fname.takeWhile(c => c != '@') + println(name) + x = (name, FunPointer(BitVecLiteral(newAddr(), 64), name, CallIntrinsic(name))) :: x + } + + val intrinsics = x.toMap + + val procs = p.procedures.filter(proc => proc.address.isDefined) + + val fptrs = ctx.externalFunctions.toList.sortBy(_.name).flatMap(f => { + intrinsics.get(f.name).map(fp => (f.offset, fp)) + .orElse(procs.find(p => p.name == f.name).map(proc => (f.offset, FunPointer(BitVecLiteral(proc.address.getOrElse(newAddr().toInt), 64), proc.name, Run(DirectCall(proc)))))) }) // sort for deterministic trace @@ -255,14 +262,16 @@ object InterpFuns { .sortBy(f => f._1) .map((p) => { val (offset, fptr) = p - Eval.storeSingle[S, E, T](s)( + Eval.storeSingle(s)("ghost-funtable", Scalar(fptr.addr), fptr) + >> (Eval.storeBV(s)( "mem", Scalar(BitVecLiteral(offset, 64)), - fptr - ) + fptr.addr, + Endian.LittleEndian + )) }) - State.sequence[S, Unit, E](State.pure(()), stores) + State.sequence(State.pure(()), stores) } /** Functions which compile BASIL IR down to the minimal interpreter effects. @@ -462,7 +471,7 @@ object InterpFuns { def initProgState[S, T <: Effects[S, InterpreterError]](f: T)(p: IRContext, is: S): S = { val st = (initialiseProgram(f)(p.program) >> initBSS(f)(p)) - >> InterpFuns.initRelocTable(f)(p.program, p.externalFunctions.map(f => (f.offset, f.name))) + >> InterpFuns.initRelocTable(f)(p) State.execute(is, st) } diff --git a/src/main/scala/ir/eval/InterpretTrace.scala b/src/main/scala/ir/eval/InterpretTrace.scala index 413222753..5c568561a 100644 --- a/src/main/scala/ir/eval/InterpretTrace.scala +++ b/src/main/scala/ir/eval/InterpretTrace.scala @@ -35,9 +35,9 @@ case object Trace { case class TraceGen[E]() extends NopEffects[Trace, E] { /** Values are discarded by ProductInterpreter so do not matter */ - override def loadVar(v: String) = for { - s <- Trace.add(ExecEffect.LoadVar(v)) - } yield (Scalar(FalseLiteral)) + // override def loadVar(v: String) = for { + // s <- Trace.add(ExecEffect.LoadVar(v)) + // } yield (Scalar(FalseLiteral)) override def loadMem(v: String, addrs: List[BasilValue]) = for { s <- Trace.add(ExecEffect.LoadMem(v, addrs)) @@ -52,7 +52,7 @@ case class TraceGen[E]() extends NopEffects[Trace, E] { } yield (()) override def storeVar(v: String, scope: Scope, value: BasilValue) = for { - s <- Trace.add(ExecEffect.StoreVar(v, scope, value)) + s <- if (!v.startsWith("ghost")) Trace.add(ExecEffect.StoreVar(v, scope, value)) else State.pure(()) } yield (()) override def storeMem(vname: String, update: Map[BasilValue, BasilValue]) = for { @@ -67,10 +67,7 @@ def interpretTrace(p: Program) : (InterpreterState, Trace) = { InterpFuns.interpretProg(tracingInterpreter)(p, (InterpreterState(), Trace(List()))) } - def interpretTrace(p: IRContext) : (InterpreterState, Trace) = { - val b = InterpFuns.initProgState(tracingInterpreter)(p, (InterpreterState(), Trace(List()))) - // Throw away the trace of program initialisation - InterpFuns.interpret(tracingInterpreter, (b._1, Trace(List()))) + InterpFuns.interpretProg(tracingInterpreter)(p, (InterpreterState(), Trace(List()))) } diff --git a/src/test/scala/DifferentialAnalysis.scala b/src/test/scala/DifferentialAnalysis.scala index 3de9450b5..85e0bfe74 100644 --- a/src/test/scala/DifferentialAnalysis.scala +++ b/src/test/scala/DifferentialAnalysis.scala @@ -20,7 +20,7 @@ import scala.collection.mutable class DifferentialAnalysis extends AnyFunSuite { - Logger.setLevel(LogLevel.WARN) + Logger.setLevel(LogLevel.DEBUG) def diffTest(initial: IRContext, transformed: IRContext) = { val (initialRes,traceInit) = interpretTrace(initial) @@ -29,11 +29,12 @@ class DifferentialAnalysis extends AnyFunSuite { def filterEvents(trace: List[ExecEffect]) = { trace.collect { + case e @ ExecEffect.Call(_, _, _) => e case e @ ExecEffect.StoreMem("mem", _) => e case e @ ExecEffect.LoadMem("mem", _) => e } } - // println(traceInit.t.mkString("\n ")) + println((traceInit.t).mkString("\n")) assert(initialRes.nextCmd == Stopped()) assert(result.nextCmd == Stopped()) assert(Set.empty == initialRes.memoryState.getMem("mem").toSet.diff(result.memoryState.getMem("mem").toSet)) From 87b0f86d3632290c00a5a66a8df61a6b085ca18b Mon Sep 17 00:00:00 2001 From: Alistair Michael Date: Wed, 4 Sep 2024 11:32:00 +1000 Subject: [PATCH 36/62] pull stepper outside effects to fix interpreter composition again --- src/main/scala/ir/eval/InterpretBasilIR.scala | 233 +++++++++--------- .../scala/ir/eval/InterpretBreakpoints.scala | 4 +- src/main/scala/ir/eval/Interpreter.scala | 38 +-- .../scala/ir/eval/InterpreterProduct.scala | 17 +- src/main/scala/util/RunUtils.scala | 15 +- src/test/scala/DifferentialAnalysis.scala | 15 +- 6 files changed, 166 insertions(+), 156 deletions(-) diff --git a/src/main/scala/ir/eval/InterpretBasilIR.scala b/src/main/scala/ir/eval/InterpretBasilIR.scala index 19485c816..1c8c5a0c9 100644 --- a/src/main/scala/ir/eval/InterpretBasilIR.scala +++ b/src/main/scala/ir/eval/InterpretBasilIR.scala @@ -226,6 +226,116 @@ case object Eval { } } + +class BASILInterpreter[S](f: Effects[S, InterpreterError]) extends Interpreter[S, InterpreterError](f) { + + def interpretOne: State[S, Unit, InterpreterError] = for { + next <- f.getNext + _ <- (next match { + case CallIntrinsic(tgt) => LibcIntrinsic.intrinsics(tgt)(f) + case Run(c: Statement) => interpretStatement(f)(c) + case Run(c: Jump) => interpretJump(f)(c) + case Stopped() => State.pure(()) + case ErrorStop(e) => State.pure(()) + }).flatMapE((e: InterpreterError) => f.setNext(ErrorStop(e))) + } yield () + + def interpretJump[S, T <: Effects[S, InterpreterError]](f: T)(j: Jump): State[S, Unit, InterpreterError] = { + j match { + case gt: GoTo if gt.targets.size == 1 => { + f.setNext(Run(IRWalk.firstInBlock(gt.targets.head))) + } + case gt: GoTo => + val assumes = gt.targets.flatMap(_.statements.headOption).collect { case a: Assume => + a + } + for { + _ <- + if (assumes.size != gt.targets.size) { + State.setError((Errored(s"Some goto target missing guard $gt"))) + } else { + State.pure(()) + } + chosen: List[Assume] <- filterM((a: Assume) => Eval.evalBool(f)(a.body), assumes) + + res <- chosen match { + case Nil => State.setError(Errored(s"No jump target satisfied $gt")) + case h :: Nil => f.setNext(Run(h)) + case h :: tl => State.setError(Errored(s"More than one jump guard satisfied $gt")) + } + } yield (res) + case r: Return => f.doReturn() + case h: Unreachable => State.setError(EscapedControlFlow(h)) + } + } + + def interpretStatement[S, T <: Effects[S, InterpreterError]](f: T)(s: Statement): State[S, Unit, InterpreterError] = { + s match { + case assign: Assign => { + for { + rhs <- Eval.evalBV(f)(assign.rhs) + st <- f.storeVar(assign.lhs.name, assign.lhs.toBoogie.scope, Scalar(rhs)) + n <- f.setNext(Run(s.successor)) + } yield (st) + } + case assign: MemoryAssign => + for { + index: BitVecLiteral <- Eval.evalBV(f)(assign.index) + value: BitVecLiteral <- Eval.evalBV(f)(assign.value) + _ <- Eval.storeBV(f)(assign.mem.name, Scalar(index), value, assign.endian) + n <- f.setNext(Run(s.successor)) + } yield (n) + case assert: Assert => + for { + b <- Eval.evalBool(f)(assert.body) + n <- + (if (!b) then { + f.setNext(FailedAssertion(assert)) + } else { + f.setNext(Run(s.successor)) + }) + } yield (n) + case assume: Assume => + for { + b <- Eval.evalBool(f)(assume.body) + n <- + (if (!b) { + State.setError(Errored(s"Assumption not satisfied: $assume")) + } else { + f.setNext(Run(s.successor)) + }) + } yield (n) + case dc: DirectCall => + for { + n <- + if (dc.target.entryBlock.isDefined) { + val block = dc.target.entryBlock.get + f.call(dc.target.name, Run(block.statements.headOption.getOrElse(block.jump)), Run(dc.successor)) + } else if (LibcIntrinsic.intrinsics.contains(dc.target.name)) { + f.call(dc.target.name, CallIntrinsic(dc.target.name), Run(dc.successor)) + } else { + State.setError(EscapedControlFlow(dc)) + } + } yield (n) + case ic: IndirectCall => { + if (ic.target == Register("R30", 64)) { + f.doReturn() + } else { + for { + addr <- Eval.evalBV(f)(ic.target) + fp <- f.evalAddrToProc(addr.value.toInt) + _ <- fp match { + case Some(fp) => f.call(fp.name, fp.call, Run(ic.successor)) + case none => State.setError(EscapedControlFlow(ic)) + } + } yield () + } + } + case _: NOP => f.setNext(Run(s.successor)) + } + } +} + object InterpFuns { def initRelocTable[S, T <: Effects[S, InterpreterError]](s: T)(ctx: IRContext): State[S, Unit, InterpreterError] = { @@ -244,7 +354,6 @@ object InterpFuns { for ((fname, fun) <- LibcIntrinsic.intrinsics) { val name = fname.takeWhile(c => c != '@') - println(name) x = (name, FunPointer(BitVecLiteral(newAddr(), 64), name, CallIntrinsic(name))) :: x } @@ -340,117 +449,6 @@ object InterpFuns { } yield (r) } - def interpretJump[S, T <: Effects[S, InterpreterError]](f: T)(j: Jump): State[S, Unit, InterpreterError] = { - j match { - case gt: GoTo if gt.targets.size == 1 => { - f.setNext(Run(IRWalk.firstInBlock(gt.targets.head))) - } - case gt: GoTo => - val assumes = gt.targets.flatMap(_.statements.headOption).collect { case a: Assume => - a - } - for { - _ <- - if (assumes.size != gt.targets.size) { - State.setError((Errored(s"Some goto target missing guard $gt"))) - } else { - State.pure(()) - } - chosen: List[Assume] <- filterM((a: Assume) => Eval.evalBool(f)(a.body), assumes) - - res <- chosen match { - case Nil => State.setError(Errored(s"No jump target satisfied $gt")) - case h :: Nil => f.setNext(Run(h)) - case h :: tl => State.setError(Errored(s"More than one jump guard satisfied $gt")) - } - } yield (res) - case r: Return => f.doReturn() - case h: Unreachable => State.setError(EscapedControlFlow(h)) - } - } - - def interpretStatement[S, T <: Effects[S, InterpreterError]](f: T)(s: Statement): State[S, Unit, InterpreterError] = { - s match { - case assign: Assign => { - for { - rhs <- Eval.evalBV(f)(assign.rhs) - st <- f.storeVar(assign.lhs.name, assign.lhs.toBoogie.scope, Scalar(rhs)) - n <- f.setNext(Run(s.successor)) - } yield (st) - } - case assign: MemoryAssign => - for { - index: BitVecLiteral <- Eval.evalBV(f)(assign.index) - value: BitVecLiteral <- Eval.evalBV(f)(assign.value) - _ <- Eval.storeBV(f)(assign.mem.name, Scalar(index), value, assign.endian) - n <- f.setNext(Run(s.successor)) - } yield (n) - case assert: Assert => - for { - b <- Eval.evalBool(f)(assert.body) - n <- - (if (!b) then { - f.setNext(FailedAssertion(assert)) - } else { - f.setNext(Run(s.successor)) - }) - } yield (n) - case assume: Assume => - for { - b <- Eval.evalBool(f)(assume.body) - n <- - (if (!b) { - State.setError(Errored(s"Assumption not satisfied: $assume")) - } else { - f.setNext(Run(s.successor)) - }) - } yield (n) - case dc: DirectCall => - for { - n <- - if (dc.target.entryBlock.isDefined) { - val block = dc.target.entryBlock.get - f.call(dc.target.name, Run(block.statements.headOption.getOrElse(block.jump)), Run(dc.successor)) - } else if (LibcIntrinsic.intrinsics.contains(dc.target.name)) { - f.call(dc.target.name, CallIntrinsic(dc.target.name), Run(dc.successor)) - } else { - State.setError(EscapedControlFlow(dc)) - } - } yield (n) - case ic: IndirectCall => { - if (ic.target == Register("R30", 64)) { - f.doReturn() - } else { - for { - addr <- Eval.evalBV(f)(ic.target) - fp <- f.evalAddrToProc(addr.value.toInt) - _ <- fp match { - case Some(fp) => f.call(fp.name, fp.call, Run(ic.successor)) - case none => State.setError(EscapedControlFlow(ic)) - } - } yield () - } - } - case _: NOP => f.setNext(Run(s.successor)) - } - } - - def interpret[S, E, T <: Effects[S, E]](f: T, m: S): S = { - // run to fixed point - var o = m - var n = State.execute(o, f.interpretOne) - while (o != n) { - o = n - n = State.execute(o, f.interpretOne) - } - n - } - - def interpretProg[S, T <: Effects[S, InterpreterError]](f: T)(p: Program, is: S): S = { - val begin = State.execute(is, initialiseProgram(f)(p)) - interpret(f, begin) - } - def initBSS[S, T <: Effects[S, InterpreterError]](f: T)(p: IRContext): State[S, Unit, InterpreterError] = { val bss = for { first <- p.symbols.find(s => s.name == "__bss_start__").map(_.value) @@ -475,9 +473,18 @@ object InterpFuns { State.execute(is, st) } + /* Intialise from ELF and Interpret program */ def interpretProg[S, T <: Effects[S, InterpreterError]](f: T)(p: IRContext, is: S): S = { val begin = initProgState(f)(p, is) - interpret(f, begin) + val interp = BASILInterpreter(f) + interp.run(begin) + } + + /* Interpret IR program */ + def interpretProg[S, T <: Effects[S, InterpreterError]](f: T)(p: Program, is: S): S = { + val begin = State.execute(is, initialiseProgram(f)(p)) + val interp = BASILInterpreter(f) + interp.run(begin) } } diff --git a/src/main/scala/ir/eval/InterpretBreakpoints.scala b/src/main/scala/ir/eval/InterpretBreakpoints.scala index 8121d139d..38584add2 100644 --- a/src/main/scala/ir/eval/InterpretBreakpoints.scala +++ b/src/main/scala/ir/eval/InterpretBreakpoints.scala @@ -33,7 +33,7 @@ case class RememberBreakpoints[T, I <: Effects[T, InterpreterError]](val f: I, v }, breaks) } - override def interpretOne : State[(T, List[(BreakPoint, Option[T], List[(String, Expr, Expr)])]), Unit, InterpreterError] = for { + override def getNext: State[(T, List[(BreakPoint, Option[T], List[(String, Expr, Expr)])]), ExecutionContinuation, InterpreterError] = for { v : ExecutionContinuation <- doLeft(f.getNext) n <- v match { case Run(s) => for { @@ -68,7 +68,7 @@ case class RememberBreakpoints[T, I <: Effects[T, InterpreterError]](val f: I, v } yield () case _ => State.pure(()) } - } yield () + } yield (v) } diff --git a/src/main/scala/ir/eval/Interpreter.scala b/src/main/scala/ir/eval/Interpreter.scala index 04ced6d96..21ebb047a 100644 --- a/src/main/scala/ir/eval/Interpreter.scala +++ b/src/main/scala/ir/eval/Interpreter.scala @@ -76,9 +76,7 @@ case object BasilValue { /** Minimal language defining all state transitions in the interpreter, defined for the interpreter's concrete state T. */ trait Effects[T, E] { - - // perform an execution step - def interpretOne: State[T, Unit, E] + /* expression eval */ def loadVar(v: String): State[T, BasilValue, E] @@ -108,7 +106,6 @@ trait Effects[T, E] { } trait NopEffects[T, E] extends Effects[T, E] { - def interpretOne = State.pure(()) def loadVar(v: String) = State.pure(Scalar(FalseLiteral)) def loadMem(v: String, addrs: List[BasilValue]) = State.pure(List()) def evalAddrToProc(addr: Int) = State.pure(None) @@ -418,16 +415,29 @@ object NormalInterpreter extends Effects[InterpreterState, InterpreterError] { ms <- s.memoryState.doStore(vname, update) } yield (s.copy(memoryState = ms)) }) +} + +trait Interpreter[S, E](val f: Effects[S, E]) { + + def interpretOne: State[S, Unit, E] - def interpretOne: State[InterpreterState, Unit, InterpreterError] = for { - next <- getNext - _ <- (next match { - case CallIntrinsic(tgt) => LibcIntrinsic.intrinsics(tgt)(this) - case Run(c: Statement) => InterpFuns.interpretStatement(this)(c) - case Run(c: Jump) => InterpFuns.interpretJump(this)(c) - case Stopped() => State.pure(()) - case ErrorStop(e) => State.pure(()) - }).flatMapE((e: InterpreterError) => setNext(ErrorStop(e))) - } yield () + @tailrec + final def run(begin: S): S = { + val c = for { + _ <- interpretOne + x <- f.getNext + continue = x match { + case Stopped() | ErrorStop(_) => false + case _ => true + } + } yield (continue) + + val (fs,cont) = c.f(begin) + if (cont.contains(true)) then { + run(fs) + } else { + fs + } + } } diff --git a/src/main/scala/ir/eval/InterpreterProduct.scala b/src/main/scala/ir/eval/InterpreterProduct.scala index 7383e46d5..929c69425 100644 --- a/src/main/scala/ir/eval/InterpreterProduct.scala +++ b/src/main/scala/ir/eval/InterpreterProduct.scala @@ -16,27 +16,23 @@ import scala.util.control.Breaks.{break, breakable} def doLeft[L, T, V, E](f: State[L, V, E]) : State[(L, T), V, E] = for { - f <- State[(L, T), V, E]((s: (L, T)) => { + n <- State[(L, T), V, E]((s: (L, T)) => { val r = f.f(s._1) ((r._1, s._2), r._2) }) -} yield (f) +} yield (n) def doRight[L, T, V, E](f: State[T, V, E]) : State[(L, T), V, E] = for { - f <- State[(L, T), V, E]((s: (L, T)) => { + n <- State[(L, T), V, E]((s: (L, T)) => { val r = f.f(s._2) ((s._1, r._1), r._2) }) -} yield (f) +} yield (n) /** * Runs two interpreters "inner" and "before" simultaneously, returning the value from inner, and ignoring before */ case class ProductInterpreter[L, T, E](val inner: Effects[L, E], val before: Effects[T, E]) extends Effects[(L, T), E] { - def interpretOne = for { - n <- doRight(before.interpretOne) - f <- doLeft(inner.interpretOne) - } yield () def loadVar(v: String) = for { n <- doRight(before.loadVar(v)) @@ -88,11 +84,6 @@ case class ProductInterpreter[L, T, E](val inner: Effects[L, E], val before: Eff case class LayerInterpreter[L, T, E](val inner: Effects[L, E], val before: Effects[(L, T), E]) extends Effects[(L, T), E] { - def interpretOne = for { - n <- (before.interpretOne) - f <- doLeft(inner.interpretOne) - } yield () - def loadVar(v: String) = for { n <- (before.loadVar(v)) f <- doLeft(inner.loadVar(v)) diff --git a/src/main/scala/util/RunUtils.scala b/src/main/scala/util/RunUtils.scala index 45c482b03..03f208d98 100644 --- a/src/main/scala/util/RunUtils.scala +++ b/src/main/scala/util/RunUtils.scala @@ -213,20 +213,20 @@ object IRTransform { * add in modifies from the spec. */ def prepareForTranslation(config: ILLoadingConfig, ctx: IRContext): Unit = { - ctx.program.determineRelevantMemory(ctx.globalOffsets) + // ctx.program.determineRelevantMemory(ctx.globalOffsets) Logger.info("[!] Stripping unreachable") val before = ctx.program.procedures.size - // transforms.stripUnreachableFunctions(ctx.program, config.procedureTrimDepth) + //transforms.stripUnreachableFunctions(ctx.program, config.procedureTrimDepth) Logger.info( s"[!] Removed ${before - ctx.program.procedures.size} functions (${ctx.program.procedures.size} remaining)" ) - // val stackIdentification = StackSubstituter() - // stackIdentification.visitProgram(ctx.program) + //val stackIdentification = StackSubstituter() + //stackIdentification.visitProgram(ctx.program) val specModifies = ctx.specification.subroutines.map(s => s.name -> s.modifies).toMap - ctx.program.setModifies(specModifies) + // ctx.program.setModifies(specModifies) assert(invariant.singleCallBlockEnd(ctx.program)) } @@ -497,9 +497,8 @@ object RunUtils { def loadAndTranslate(q: BASILConfig): BASILResult = { Logger.info("[!] Loading Program") - val ctx = IRLoading.load(q.loading) - - IRTransform.doCleanup(ctx) + var ctx = IRLoading.load(q.loading) + ctx = IRTransform.doCleanup(ctx) q.loading.dumpIL.foreach(s => writeToFile(serialiseIL(ctx.program), s"$s-before-analysis.il")) val analysis = q.staticAnalysis.map(conf => staticAnalysis(conf, ctx)) diff --git a/src/test/scala/DifferentialAnalysis.scala b/src/test/scala/DifferentialAnalysis.scala index 85e0bfe74..b350e3f64 100644 --- a/src/test/scala/DifferentialAnalysis.scala +++ b/src/test/scala/DifferentialAnalysis.scala @@ -6,7 +6,7 @@ import ir.Endian.LittleEndian import org.scalatest.* import org.scalatest.funsuite.* import specification.* -import util.{BASILConfig, IRLoading, ILLoadingConfig, IRContext, RunUtils, StaticAnalysis, StaticAnalysisConfig, StaticAnalysisContext, BASILResult, Logger, LogLevel} +import util.{BASILConfig, IRLoading, ILLoadingConfig, IRContext, RunUtils, StaticAnalysis, StaticAnalysisConfig, StaticAnalysisContext, BASILResult, Logger, LogLevel, IRTransform} import ir.eval.{interpretTrace, interpret, ExecEffect, Stopped} @@ -20,7 +20,7 @@ import scala.collection.mutable class DifferentialAnalysis extends AnyFunSuite { - Logger.setLevel(LogLevel.DEBUG) + Logger.setLevel(LogLevel.ERROR) def diffTest(initial: IRContext, transformed: IRContext) = { val (initialRes,traceInit) = interpretTrace(initial) @@ -34,10 +34,11 @@ class DifferentialAnalysis extends AnyFunSuite { case e @ ExecEffect.LoadMem("mem", _) => e } } - println((traceInit.t).mkString("\n")) - assert(initialRes.nextCmd == Stopped()) assert(result.nextCmd == Stopped()) + assert(initialRes.nextCmd == Stopped()) assert(Set.empty == initialRes.memoryState.getMem("mem").toSet.diff(result.memoryState.getMem("mem").toSet)) + assert(traceInit.t.nonEmpty) + assert(traceRes.t.nonEmpty) assert(filterEvents(traceInit.t).mkString("\n") == filterEvents(traceRes.t).mkString("\n")) } @@ -61,9 +62,11 @@ class DifferentialAnalysis extends AnyFunSuite { ) - val program = loadAndTranslate(basilConfigNoAnalysis).ir + var ictx = IRLoading.load(basilConfigNoAnalysis.loading) + ictx = IRTransform.doCleanup(ictx) + val compare = loadAndTranslate(basilConfig).ir - diffTest(program, compare) + diffTest(ictx, compare) } From f5a05907d776477572daf4c34c8f5fda60828fe5 Mon Sep 17 00:00:00 2001 From: Alistair Michael Date: Wed, 4 Sep 2024 11:41:01 +1000 Subject: [PATCH 37/62] cleanup --- src/main/scala/ir/eval/InterpretBasilIR.scala | 7 +- .../scala/ir/eval/InterpretBreakpoints.scala | 128 +++++++++++------- src/main/scala/ir/eval/InterpretTrace.scala | 13 +- src/main/scala/ir/eval/Interpreter.scala | 3 +- .../scala/ir/eval/InterpreterProduct.scala | 36 +++-- 5 files changed, 104 insertions(+), 83 deletions(-) diff --git a/src/main/scala/ir/eval/InterpretBasilIR.scala b/src/main/scala/ir/eval/InterpretBasilIR.scala index 1c8c5a0c9..38b78c689 100644 --- a/src/main/scala/ir/eval/InterpretBasilIR.scala +++ b/src/main/scala/ir/eval/InterpretBasilIR.scala @@ -288,13 +288,12 @@ class BASILInterpreter[S](f: Effects[S, InterpreterError]) extends Interpreter[S case assert: Assert => for { b <- Eval.evalBool(f)(assert.body) - n <- - (if (!b) then { - f.setNext(FailedAssertion(assert)) + _ <- (if (!b) then { + State.setError(FailedAssertion(assert)) } else { f.setNext(Run(s.successor)) }) - } yield (n) + } yield () case assume: Assume => for { b <- Eval.evalBool(f)(assume.body) diff --git a/src/main/scala/ir/eval/InterpretBreakpoints.scala b/src/main/scala/ir/eval/InterpretBreakpoints.scala index 38584add2..9eb0e7b79 100644 --- a/src/main/scala/ir/eval/InterpretBreakpoints.scala +++ b/src/main/scala/ir/eval/InterpretBreakpoints.scala @@ -13,70 +13,98 @@ import scala.collection.mutable import scala.collection.immutable import scala.util.control.Breaks.{break, breakable} - enum BreakPointLoc: case CMD(c: Command) case CMDCond(c: Command, condition: Expr) -case class BreakPointAction(saveState: Boolean = true, stop: Boolean = false, evalExprs: List[(String,Expr)] = List(), log: Boolean = false) +case class BreakPointAction( + saveState: Boolean = true, + stop: Boolean = false, + evalExprs: List[(String, Expr)] = List(), + log: Boolean = false +) case class BreakPoint(name: String = "", location: BreakPointLoc, action: BreakPointAction) -case class RememberBreakpoints[T, I <: Effects[T, InterpreterError]](val f: I, val breaks: List[BreakPoint]) extends NopEffects[(T, List[(BreakPoint, Option[T], List[(String, Expr, Expr)])]), InterpreterError] { - +case class RememberBreakpoints[T, I <: Effects[T, InterpreterError]](val f: I, val breaks: List[BreakPoint]) + extends NopEffects[(T, List[(BreakPoint, Option[T], List[(String, Expr, Expr)])]), InterpreterError] { - def findBreaks[R](c: Command) : State[(T,R), List[BreakPoint], InterpreterError] = { - State.filterM(b => b.location match { - case BreakPointLoc.CMD(bc) if (bc == c) => State.pure(true) - case BreakPointLoc.CMDCond(bc, e) if bc == c => doLeft(Eval.evalBool(f)(e)) - case _ => State.pure(false) - }, breaks) + def findBreaks[R](c: Command): State[(T, R), List[BreakPoint], InterpreterError] = { + State.filterM( + b => + b.location match { + case BreakPointLoc.CMD(bc) if (bc == c) => State.pure(true) + case BreakPointLoc.CMDCond(bc, e) if bc == c => doLeft(Eval.evalBool(f)(e)) + case _ => State.pure(false) + }, + breaks + ) } - override def getNext: State[(T, List[(BreakPoint, Option[T], List[(String, Expr, Expr)])]), ExecutionContinuation, InterpreterError] = for { - v : ExecutionContinuation <- doLeft(f.getNext) - n <- v match { - case Run(s) => for { - breaks : List[BreakPoint] <- findBreaks(s) - res <- State.sequence[(T, List[(BreakPoint, Option[T], List[(String, Expr, Expr)])]), Unit, InterpreterError](State.pure(()), - breaks.map((breakpoint: BreakPoint) => (breakpoint match { - case breakpoint @ BreakPoint(name, stopcond, action) => (for { - saved <- doLeft(if action.saveState then State.getS[T, InterpreterError].map(s => Some(s)) else State.pure(None)) - evals <- (State.mapM((e:(String, Expr)) => for { - ev <- doLeft(Eval.evalExpr(f)(e._2)) - } yield (e._1, e._2, ev) - , action.evalExprs)) - _ <- if action.stop then doLeft(State.setError(Errored(s"Stopped at breakpoint ${name}"))) else doLeft(State.pure(())) - _ <- State.pure({ - if (action.log) { - val bpn = breakpoint.name - val bpcond = breakpoint.location match { - case BreakPointLoc.CMD(c) => c.toString - case BreakPointLoc.CMDCond(c, e) => s"$c when $e" - } - val saving = if action.saveState then " stashing state, " else "" - val stopping = if action.stop then " stopping. " else "" - val evalstr = evals.map(e => s"\n ${e._1} : eval(${e._2}) = ${e._3}").mkString("") - Logger.warn(s"Breakpoint $bpn@$bpcond.$saving$stopping$evalstr") - } - }) - _ <- State.modify ((istate:(T, List[(BreakPoint, Option[T], List[(String, Expr, Expr)])])) => - (istate._1, ((breakpoint, saved, evals)::istate._2))) - } yield () + override def getNext + : State[(T, List[(BreakPoint, Option[T], List[(String, Expr, Expr)])]), ExecutionContinuation, InterpreterError] = + for { + v: ExecutionContinuation <- doLeft(f.getNext) + n <- v match { + case Run(s) => + for { + breaks: List[BreakPoint] <- findBreaks(s) + res <- State + .sequence[(T, List[(BreakPoint, Option[T], List[(String, Expr, Expr)])]), Unit, InterpreterError]( + State.pure(()), + breaks.map((breakpoint: BreakPoint) => + (breakpoint match { + case breakpoint @ BreakPoint(name, stopcond, action) => ( + for { + saved <- doLeft( + if action.saveState then State.getS[T, InterpreterError].map(s => Some(s)) + else State.pure(None) + ) + evals <- (State.mapM( + (e: (String, Expr)) => + for { + ev <- doLeft(Eval.evalExpr(f)(e._2)) + } yield (e._1, e._2, ev), + action.evalExprs + )) + _ <- + if action.stop then doLeft(State.setError(Errored(s"Stopped at breakpoint ${name}"))) + else doLeft(State.pure(())) + _ <- State.pure({ + if (action.log) { + val bpn = breakpoint.name + val bpcond = breakpoint.location match { + case BreakPointLoc.CMD(c) => c.toString + case BreakPointLoc.CMDCond(c, e) => s"$c when $e" + } + val saving = if action.saveState then " stashing state, " else "" + val stopping = if action.stop then " stopping. " else "" + val evalstr = evals.map(e => s"\n ${e._1} : eval(${e._2}) = ${e._3}").mkString("") + Logger.warn(s"Breakpoint $bpn@$bpcond.$saving$stopping$evalstr") + } + }) + _ <- State.modify((istate: (T, List[(BreakPoint, Option[T], List[(String, Expr, Expr)])])) => + (istate._1, ((breakpoint, saved, evals) :: istate._2)) + ) + } yield () + ) + }) + ) ) - }))) - } yield () - case _ => State.pure(()) + } yield () + case _ => State.pure(()) } } yield (v) } - -def interpretWithBreakPoints[I](p: Program, breakpoints: List[BreakPoint], - innerInterpreter: Effects[I, InterpreterError], - innerInitialState: I) : (I, List[(BreakPoint, Option[I], List[(String, Expr, Expr)])]) = { - val interp = LayerInterpreter(innerInterpreter, RememberBreakpoints(innerInterpreter, breakpoints)) - val res = InterpFuns.interpretProg(interp)(p, (innerInitialState, List())) - res +def interpretWithBreakPoints[I]( + p: Program, + breakpoints: List[BreakPoint], + innerInterpreter: Effects[I, InterpreterError], + innerInitialState: I +): (I, List[(BreakPoint, Option[I], List[(String, Expr, Expr)])]) = { + val interp = LayerInterpreter(innerInterpreter, RememberBreakpoints(innerInterpreter, breakpoints)) + val res = InterpFuns.interpretProg(interp)(p, (innerInitialState, List())) + res } diff --git a/src/main/scala/ir/eval/InterpretTrace.scala b/src/main/scala/ir/eval/InterpretTrace.scala index 5c568561a..92af057b6 100644 --- a/src/main/scala/ir/eval/InterpretTrace.scala +++ b/src/main/scala/ir/eval/InterpretTrace.scala @@ -14,7 +14,6 @@ import scala.collection.mutable import scala.collection.immutable import scala.util.control.Breaks.{break, breakable} - enum ExecEffect: case Call(target: String, begin: ExecutionContinuation, returnTo: ExecutionContinuation) case Return @@ -27,19 +26,20 @@ enum ExecEffect: case class Trace(val t: List[ExecEffect]) case object Trace { - def add[E](e: ExecEffect) : State[Trace, Unit, E] = { - State.modify ((t: Trace) => Trace(t.t.appended(e))) + def add[E](e: ExecEffect): State[Trace, Unit, E] = { + State.modify((t: Trace) => Trace(t.t.appended(e))) } } case class TraceGen[E]() extends NopEffects[Trace, E] { + /** Values are discarded by ProductInterpreter so do not matter */ // override def loadVar(v: String) = for { // s <- Trace.add(ExecEffect.LoadVar(v)) // } yield (Scalar(FalseLiteral)) - override def loadMem(v: String, addrs: List[BasilValue]) = for { + override def loadMem(v: String, addrs: List[BasilValue]) = for { s <- Trace.add(ExecEffect.LoadMem(v, addrs)) } yield (List()) @@ -63,11 +63,10 @@ case class TraceGen[E]() extends NopEffects[Trace, E] { def tracingInterpreter = ProductInterpreter(NormalInterpreter, TraceGen()) -def interpretTrace(p: Program) : (InterpreterState, Trace) = { +def interpretTrace(p: Program): (InterpreterState, Trace) = { InterpFuns.interpretProg(tracingInterpreter)(p, (InterpreterState(), Trace(List()))) } -def interpretTrace(p: IRContext) : (InterpreterState, Trace) = { +def interpretTrace(p: IRContext): (InterpreterState, Trace) = { InterpFuns.interpretProg(tracingInterpreter)(p, (InterpreterState(), Trace(List()))) } - diff --git a/src/main/scala/ir/eval/Interpreter.scala b/src/main/scala/ir/eval/Interpreter.scala index 21ebb047a..9d3246fcc 100644 --- a/src/main/scala/ir/eval/Interpreter.scala +++ b/src/main/scala/ir/eval/Interpreter.scala @@ -15,14 +15,13 @@ import scala.util.control.Breaks.{break, breakable} /** Interpreter status type, either stopped, run next command or error */ sealed trait ExecutionContinuation -case class FailedAssertion(a: Assert) extends ExecutionContinuation - case class Stopped() extends ExecutionContinuation /* normal program stop */ case class ErrorStop(error: InterpreterError) extends ExecutionContinuation /* normal program stop */ case class Run(val next: Command) extends ExecutionContinuation /* continue by executing next command */ case class CallIntrinsic(val name: String) extends ExecutionContinuation /* continue by executing next command */ sealed trait InterpreterError +case class FailedAssertion(a: Assert) extends InterpreterError case class EscapedControlFlow(val call: Jump | Call) extends InterpreterError /* controlflow has reached somewhere eunrecoverable */ case class Errored(val message: String = "") extends InterpreterError diff --git a/src/main/scala/ir/eval/InterpreterProduct.scala b/src/main/scala/ir/eval/InterpreterProduct.scala index 929c69425..ae8e51252 100644 --- a/src/main/scala/ir/eval/InterpreterProduct.scala +++ b/src/main/scala/ir/eval/InterpreterProduct.scala @@ -1,4 +1,3 @@ - package ir.eval import ir._ import ir.eval.BitVectorEval.* @@ -14,24 +13,22 @@ import scala.collection.mutable import scala.collection.immutable import scala.util.control.Breaks.{break, breakable} - -def doLeft[L, T, V, E](f: State[L, V, E]) : State[(L, T), V, E] = for { +def doLeft[L, T, V, E](f: State[L, V, E]): State[(L, T), V, E] = for { n <- State[(L, T), V, E]((s: (L, T)) => { val r = f.f(s._1) ((r._1, s._2), r._2) }) } yield (n) -def doRight[L, T, V, E](f: State[T, V, E]) : State[(L, T), V, E] = for { +def doRight[L, T, V, E](f: State[T, V, E]): State[(L, T), V, E] = for { n <- State[(L, T), V, E]((s: (L, T)) => { val r = f.f(s._2) ((s._1, r._1), r._2) }) } yield (n) -/** - * Runs two interpreters "inner" and "before" simultaneously, returning the value from inner, and ignoring before - */ +/** Runs two interpreters "inner" and "before" simultaneously, returning the value from inner, and ignoring before + */ case class ProductInterpreter[L, T, E](val inner: Effects[L, E], val before: Effects[T, E]) extends Effects[(L, T), E] { def loadVar(v: String) = for { @@ -47,12 +44,12 @@ case class ProductInterpreter[L, T, E](val inner: Effects[L, E], val before: Eff def evalAddrToProc(addr: Int) = for { n <- doRight(before.evalAddrToProc(addr: Int)) f <- doLeft(inner.evalAddrToProc(addr)) - } yield(f) + } yield (f) def getNext = for { n <- doRight(before.getNext) f <- doLeft(inner.getNext) - } yield(f) + } yield (f) /** state effects */ def setNext(c: ExecutionContinuation) = for { @@ -73,16 +70,16 @@ case class ProductInterpreter[L, T, E](val inner: Effects[L, E], val before: Eff def storeVar(v: String, scope: Scope, value: BasilValue) = for { n <- doRight(before.storeVar(v, scope, value)) f <- doLeft(inner.storeVar(v, scope, value)) - } yield(f) + } yield (f) def storeMem(vname: String, update: Map[BasilValue, BasilValue]) = for { - n <- doRight(before.storeMem(vname,update)) + n <- doRight(before.storeMem(vname, update)) f <- doLeft(inner.storeMem(vname, update)) - } yield(f) + } yield (f) } - -case class LayerInterpreter[L, T, E](val inner: Effects[L, E], val before: Effects[(L, T), E]) extends Effects[(L, T), E] { +case class LayerInterpreter[L, T, E](val inner: Effects[L, E], val before: Effects[(L, T), E]) + extends Effects[(L, T), E] { def loadVar(v: String) = for { n <- (before.loadVar(v)) @@ -97,12 +94,12 @@ case class LayerInterpreter[L, T, E](val inner: Effects[L, E], val before: Effec def evalAddrToProc(addr: Int) = for { n <- (before.evalAddrToProc(addr: Int)) f <- doLeft(inner.evalAddrToProc(addr)) - } yield(f) + } yield (f) def getNext = for { n <- (before.getNext) f <- doLeft(inner.getNext) - } yield(f) + } yield (f) /** state effects */ def setNext(c: ExecutionContinuation) = for { @@ -123,11 +120,10 @@ case class LayerInterpreter[L, T, E](val inner: Effects[L, E], val before: Effec def storeVar(v: String, scope: Scope, value: BasilValue) = for { n <- (before.storeVar(v, scope, value)) f <- doLeft(inner.storeVar(v, scope, value)) - } yield(f) + } yield (f) def storeMem(vname: String, update: Map[BasilValue, BasilValue]) = for { - n <- (before.storeMem(vname,update)) + n <- (before.storeMem(vname, update)) f <- doLeft(inner.storeMem(vname, update)) - } yield(f) + } yield (f) } - From a6b58e8f1a05d87224b8ee46b79fd11e055bc3d5 Mon Sep 17 00:00:00 2001 From: Alistair Michael Date: Wed, 4 Sep 2024 12:16:30 +1000 Subject: [PATCH 38/62] cleanup --- .../scala/ir/eval/InterpretBreakpoints.scala | 17 +++++++++-------- src/main/scala/ir/eval/InterpretTrace.scala | 9 +++------ src/main/scala/util/RunUtils.scala | 2 +- 3 files changed, 13 insertions(+), 15 deletions(-) diff --git a/src/main/scala/ir/eval/InterpretBreakpoints.scala b/src/main/scala/ir/eval/InterpretBreakpoints.scala index 9eb0e7b79..5fb49c3e2 100644 --- a/src/main/scala/ir/eval/InterpretBreakpoints.scala +++ b/src/main/scala/ir/eval/InterpretBreakpoints.scala @@ -42,8 +42,8 @@ case class RememberBreakpoints[T, I <: Effects[T, InterpreterError]](val f: I, v } override def getNext - : State[(T, List[(BreakPoint, Option[T], List[(String, Expr, Expr)])]), ExecutionContinuation, InterpreterError] = - for { + : State[(T, List[(BreakPoint, Option[T], List[(String, Expr, Expr)])]), ExecutionContinuation, InterpreterError] = { + for { v: ExecutionContinuation <- doLeft(f.getNext) n <- v match { case Run(s) => @@ -67,22 +67,23 @@ case class RememberBreakpoints[T, I <: Effects[T, InterpreterError]](val f: I, v } yield (e._1, e._2, ev), action.evalExprs )) - _ <- - if action.stop then doLeft(State.setError(Errored(s"Stopped at breakpoint ${name}"))) - else doLeft(State.pure(())) _ <- State.pure({ if (action.log) { val bpn = breakpoint.name val bpcond = breakpoint.location match { - case BreakPointLoc.CMD(c) => c.toString - case BreakPointLoc.CMDCond(c, e) => s"$c when $e" + case BreakPointLoc.CMD(c) => s"${c.parent.label}:$c" + case BreakPointLoc.CMDCond(c, e) => s"${c.parent.label}:$c when $e" } val saving = if action.saveState then " stashing state, " else "" val stopping = if action.stop then " stopping. " else "" val evalstr = evals.map(e => s"\n ${e._1} : eval(${e._2}) = ${e._3}").mkString("") Logger.warn(s"Breakpoint $bpn@$bpcond.$saving$stopping$evalstr") + //println(s"Breakpoint $bpn@$bpcond.$saving$stopping$evalstr") } }) + _ <- + if action.stop then doLeft(State.setError(Errored(s"Stopped at breakpoint ${name}"))) + else doLeft(State.pure(())) _ <- State.modify((istate: (T, List[(BreakPoint, Option[T], List[(String, Expr, Expr)])])) => (istate._1, ((breakpoint, saved, evals) :: istate._2)) ) @@ -95,7 +96,7 @@ case class RememberBreakpoints[T, I <: Effects[T, InterpreterError]](val f: I, v case _ => State.pure(()) } } yield (v) - + } } def interpretWithBreakPoints[I]( diff --git a/src/main/scala/ir/eval/InterpretTrace.scala b/src/main/scala/ir/eval/InterpretTrace.scala index 92af057b6..5c1c963fd 100644 --- a/src/main/scala/ir/eval/InterpretTrace.scala +++ b/src/main/scala/ir/eval/InterpretTrace.scala @@ -34,11 +34,6 @@ case object Trace { case class TraceGen[E]() extends NopEffects[Trace, E] { /** Values are discarded by ProductInterpreter so do not matter */ - - // override def loadVar(v: String) = for { - // s <- Trace.add(ExecEffect.LoadVar(v)) - // } yield (Scalar(FalseLiteral)) - override def loadMem(v: String, addrs: List[BasilValue]) = for { s <- Trace.add(ExecEffect.LoadMem(v, addrs)) } yield (List()) @@ -68,5 +63,7 @@ def interpretTrace(p: Program): (InterpreterState, Trace) = { } def interpretTrace(p: IRContext): (InterpreterState, Trace) = { - InterpFuns.interpretProg(tracingInterpreter)(p, (InterpreterState(), Trace(List()))) + val begin = InterpFuns.initProgState(tracingInterpreter)(p, (InterpreterState(), Trace(List()))) + // throw away initialisation trace + BASILInterpreter(tracingInterpreter).run((begin._1, Trace(List()))) } diff --git a/src/main/scala/util/RunUtils.scala b/src/main/scala/util/RunUtils.scala index 03f208d98..d58d5b7cf 100644 --- a/src/main/scala/util/RunUtils.scala +++ b/src/main/scala/util/RunUtils.scala @@ -217,7 +217,7 @@ object IRTransform { Logger.info("[!] Stripping unreachable") val before = ctx.program.procedures.size - //transforms.stripUnreachableFunctions(ctx.program, config.procedureTrimDepth) + transforms.stripUnreachableFunctions(ctx.program, config.procedureTrimDepth) Logger.info( s"[!] Removed ${before - ctx.program.procedures.size} functions (${ctx.program.procedures.size} remaining)" ) From dd64c149ecb7e77e80d0ea98d153493718622a45 Mon Sep 17 00:00:00 2001 From: Alistair Michael Date: Wed, 4 Sep 2024 14:23:09 +1000 Subject: [PATCH 39/62] constprop test with interpreter --- .../{UtilMethods.scala => ExprSSAEval.scala} | 0 src/main/scala/ir/dsl/DSL.scala | 9 ++ .../scala/ir/eval/InterpretBreakpoints.scala | 21 ++- src/test/scala/DifferentialAnalysis.scala | 31 ++-- src/test/scala/InterpretTestConstProp.scala | 137 ++++++++++++++++++ src/test/scala/ir/InterpreterTests.scala | 35 +---- 6 files changed, 177 insertions(+), 56 deletions(-) rename src/main/scala/analysis/{UtilMethods.scala => ExprSSAEval.scala} (100%) create mode 100644 src/test/scala/InterpretTestConstProp.scala diff --git a/src/main/scala/analysis/UtilMethods.scala b/src/main/scala/analysis/ExprSSAEval.scala similarity index 100% rename from src/main/scala/analysis/UtilMethods.scala rename to src/main/scala/analysis/ExprSSAEval.scala diff --git a/src/main/scala/ir/dsl/DSL.scala b/src/main/scala/ir/dsl/DSL.scala index 2d892a444..c55e48f88 100644 --- a/src/main/scala/ir/dsl/DSL.scala +++ b/src/main/scala/ir/dsl/DSL.scala @@ -17,6 +17,13 @@ val R30: Register = Register("R30", 64) val R31: Register = Register("R31", 64) +def exprEq(l: Expr, r: Expr) : Expr = (l, r) match { + case (l, r) if l.getType != r.getType => FalseLiteral + case (l, r) if l.getType == BoolType => BinaryExpr(BoolEQ, l, r) + case (l, r) if l.getType.isInstanceOf[BitVecType] => BinaryExpr(BVEQ, l, r) + case (l, r) if l.getType == IntType => BinaryExpr(IntEQ, l, r) + case _ => FalseLiteral +} def bv32(i: Int): BitVecLiteral = BitVecLiteral(i, 32) @@ -26,6 +33,8 @@ def bv8(i: Int): BitVecLiteral = BitVecLiteral(i, 8) def bv16(i: Int): BitVecLiteral = BitVecLiteral(i, 16) +def R(i: Int): Register = Register(s"R$i", 64) + case class DelayNameResolve(ident: String) { def resolveProc(prog: Program): Option[Procedure] = prog.collectFirst { case b: Procedure if b.name == ident => b diff --git a/src/main/scala/ir/eval/InterpretBreakpoints.scala b/src/main/scala/ir/eval/InterpretBreakpoints.scala index 5fb49c3e2..831c85c3d 100644 --- a/src/main/scala/ir/eval/InterpretBreakpoints.scala +++ b/src/main/scala/ir/eval/InterpretBreakpoints.scala @@ -3,6 +3,7 @@ import ir._ import ir.eval.BitVectorEval.* import ir.* import util.Logger +import util.IRContext import util.functional.* import util.functional.State.* import boogie.Scope @@ -41,9 +42,12 @@ case class RememberBreakpoints[T, I <: Effects[T, InterpreterError]](val f: I, v ) } - override def getNext - : State[(T, List[(BreakPoint, Option[T], List[(String, Expr, Expr)])]), ExecutionContinuation, InterpreterError] = { - for { + override def getNext: State[ + (T, List[(BreakPoint, Option[T], List[(String, Expr, Expr)])]), + ExecutionContinuation, + InterpreterError + ] = { + for { v: ExecutionContinuation <- doLeft(f.getNext) n <- v match { case Run(s) => @@ -99,6 +103,17 @@ case class RememberBreakpoints[T, I <: Effects[T, InterpreterError]](val f: I, v } } +def interpretWithBreakPoints[I]( + p: IRContext, + breakpoints: List[BreakPoint], + innerInterpreter: Effects[I, InterpreterError], + innerInitialState: I +): (I, List[(BreakPoint, Option[I], List[(String, Expr, Expr)])]) = { + val interp = LayerInterpreter(innerInterpreter, RememberBreakpoints(innerInterpreter, breakpoints)) + val res = InterpFuns.interpretProg(interp)(p, (innerInitialState, List())) + res +} + def interpretWithBreakPoints[I]( p: Program, breakpoints: List[BreakPoint], diff --git a/src/test/scala/DifferentialAnalysis.scala b/src/test/scala/DifferentialAnalysis.scala index b350e3f64..8c40fac56 100644 --- a/src/test/scala/DifferentialAnalysis.scala +++ b/src/test/scala/DifferentialAnalysis.scala @@ -34,8 +34,10 @@ class DifferentialAnalysis extends AnyFunSuite { case e @ ExecEffect.LoadMem("mem", _) => e } } - assert(result.nextCmd == Stopped()) + + Logger.info(traceInit.t.map(_.toString.take(80)).mkString("\n")) assert(initialRes.nextCmd == Stopped()) + assert(result.nextCmd == Stopped()) assert(Set.empty == initialRes.memoryState.getMem("mem").toSet.diff(result.memoryState.getMem("mem").toSet)) assert(traceInit.t.nonEmpty) assert(traceRes.t.nonEmpty) @@ -43,31 +45,20 @@ class DifferentialAnalysis extends AnyFunSuite { } def testProgram(testName: String, examplePath: String) = { - val basilConfig = BASILConfig( - loading = ILLoadingConfig(inputFile = examplePath + testName + ".adt", - relfFile = examplePath + testName + ".relf", - dumpIL = None, - ), - outputPrefix = "basil-test", - staticAnalysis = Some(StaticAnalysisConfig(None, None, None)), - ) - val basilConfigNoAnalysis = BASILConfig( - loading = ILLoadingConfig(inputFile = examplePath + testName + ".adt", - relfFile = examplePath + testName + ".relf", - dumpIL = None, - ), - outputPrefix = "basil-test", - staticAnalysis = None, + val loading = ILLoadingConfig(inputFile = examplePath + testName + ".adt", + relfFile = examplePath + testName + ".relf", + dumpIL = None, ) - - var ictx = IRLoading.load(basilConfigNoAnalysis.loading) + var ictx = IRLoading.load(loading) ictx = IRTransform.doCleanup(ictx) - val compare = loadAndTranslate(basilConfig).ir - diffTest(ictx, compare) + var comparectx = IRLoading.load(loading) + comparectx = IRTransform.doCleanup(ictx) + val analysisres = RunUtils.staticAnalysis(StaticAnalysisConfig(None, None, None), comparectx) + diffTest(ictx, comparectx) } test("indirect_call_example") { diff --git a/src/test/scala/InterpretTestConstProp.scala b/src/test/scala/InterpretTestConstProp.scala new file mode 100644 index 000000000..c1fc83e5c --- /dev/null +++ b/src/test/scala/InterpretTestConstProp.scala @@ -0,0 +1,137 @@ +import ir.* +import ir.eval.* +import analysis.* +import java.io.{BufferedWriter, File, FileWriter} +import ir.Endian.LittleEndian +import org.scalatest.* +import org.scalatest.funsuite.* +import specification.* +import util.{BASILConfig, IRLoading, ILLoadingConfig, IRContext, RunUtils, StaticAnalysis, StaticAnalysisConfig, StaticAnalysisContext, BASILResult, Logger, LogLevel, IRTransform} +import ir.eval.{interpretTrace, interpret, ExecEffect, Stopped} +import ir.dsl + + +import java.io.IOException +import java.nio.file.* +import java.nio.file.attribute.BasicFileAttributes +import ir.dsl.* +import util.RunUtils.loadAndTranslate + +import scala.collection.mutable + +class ConstPropInterpreterValidate extends AnyFunSuite { + + Logger.setLevel(LogLevel.ERROR) + + def testInterpretConstProp(testName: String, examplePath: String) = { + val loading = ILLoadingConfig(inputFile = examplePath + testName + ".adt", + relfFile = examplePath + testName + ".relf", + dumpIL = None, + ) + + var ictx = IRLoading.load(loading) + ictx = IRTransform.doCleanup(ictx) + val analysisres = RunUtils.staticAnalysis(StaticAnalysisConfig(None, None, None), ictx) + + val breaks : List[BreakPoint] = analysisres.constPropResult.collect { + // convert analysis result to a list of breakpoints, each which evaluates an expression describing + // the invariant inferred by the analysis (the assignment of registers) at a corresponding program point + + case (command: Command, v) => { + val expectedPredicates : List[(String, Expr)] = v.toList.map(r => { + val (variable, value) = r + val assertion = value match { + case Top => TrueLiteral + case Bottom => FalseLiteral /* unreachable */ + case FlatEl(value) => BinaryExpr(BVEQ, variable, value) + } + (variable.name, assertion) + }) + BreakPoint(location=BreakPointLoc.CMD(command), BreakPointAction(saveState=false,evalExprs=expectedPredicates)) + } + }.toList + + assert(breaks.nonEmpty) + + // run the interpreter evaluating the analysis result at each command with a breakpoint + val interpretResult = interpretWithBreakPoints(ictx, breaks.toList, NormalInterpreter, InterpreterState()) + val breakres : List[(BreakPoint, _, List[(String, Expr, Expr)])] = interpretResult._2 + assert(interpretResult._1.nextCmd == Stopped()) + assert(breakres.nonEmpty) + + // assert all the collected breakpoint watches have evaluated to true + for (b <- breakres) { + val (_, _, evaluatedexprs) = b + evaluatedexprs.forall(c => { + val (n, before, evaled) = c + evaled == TrueLiteral + }) + } + } + + test("indirect_call_example") { + val testName = "indirect_call" + val examplePath = System.getProperty("user.dir") + s"/examples/$testName/" + testInterpretConstProp(testName, examplePath) + } + + test("indirect_call_gcc_example") { + val testName = "indirect_call" + val examplePath = System.getProperty("user.dir") + s"/src/test/correct/$testName/gcc/" + testInterpretConstProp(testName, examplePath) + } + + test("indirect_call_clang_example") { + val testName = "indirect_call" + val examplePath = System.getProperty("user.dir") + s"/src/test/correct/$testName/clang/" + testInterpretConstProp(testName, examplePath) + } + + test("jumptable2_example") { + val testName = "jumptable2" + val examplePath = System.getProperty("user.dir") + s"/examples/$testName/" + testInterpretConstProp(testName, examplePath) + } + + test("jumptable2_gcc_example") { + val testName = "jumptable2" + val examplePath = System.getProperty("user.dir") + s"/src/test/correct/$testName/gcc/" + testInterpretConstProp(testName, examplePath) + } + + test("jumptable2_clang_example") { + val testName = "jumptable2" + val examplePath = System.getProperty("user.dir") + s"/src/test/correct/$testName/clang/" + testInterpretConstProp(testName, examplePath) + } + + test("functionpointer_example") { + val testName = "functionpointer" + val examplePath = System.getProperty("user.dir") + s"/examples/$testName/" + testInterpretConstProp(testName, examplePath) + } + + test("functionpointer_gcc_example") { + val testName = "functionpointer" + val examplePath = System.getProperty("user.dir") + s"/src/test/correct/$testName/gcc/" + testInterpretConstProp(testName, examplePath) + } + + test("functionpointer_clang_example") { + val testName = "functionpointer" + val examplePath = System.getProperty("user.dir") + s"/src/test/correct/$testName/clang/" + testInterpretConstProp(testName, examplePath) + } + + test("secret_write_clang") { + val testName = "secret_write" + val examplePath = System.getProperty("user.dir") + s"/src/test/correct/$testName/clang/" + testInterpretConstProp(testName, examplePath) + } + + test("secret_write_gcc") { + val testName = "secret_write" + val examplePath = System.getProperty("user.dir") + s"/src/test/correct/$testName/gcc/" + testInterpretConstProp(testName, examplePath) + } +} diff --git a/src/test/scala/ir/InterpreterTests.scala b/src/test/scala/ir/InterpreterTests.scala index 2b47f6e21..934f17e04 100644 --- a/src/test/scala/ir/InterpreterTests.scala +++ b/src/test/scala/ir/InterpreterTests.scala @@ -12,26 +12,9 @@ import util.{LogLevel, Logger} import util.IRLoading.{loadBAP, loadReadELF} import util.{ILLoadingConfig, IRContext, IRLoading, IRTransform} -// def initialMem(): MemoryState = { -// val SP: BitVecLiteral = BitVecLiteral(4096 - 16, 64) -// val FP: BitVecLiteral = BitVecLiteral(4096 - 16, 64) -// val LR: BitVecLiteral = BitVecLiteral(BigInt("FF", 16), 64) -// -// MemoryState() -// .setVar(globalFrame, "mem", MapValue(Map.empty, MapType(BitVecType(64), BitVecType(8)))) -// .setVar(globalFrame, "stack", MapValue(Map.empty, MapType(BitVecType(64), BitVecType(8)))) -// .setVar(globalFrame, "R31", Scalar(SP)) -// .setVar(globalFrame, "R29", Scalar(FP)) -// .setVar(globalFrame, "R30", Scalar(LR)) -// } - - -// def initialMem() = InterpFuns.initialState(InterpreterState(), List()) def load(s: InterpreterState, global: SpecGlobal) : Option[BitVecLiteral] = { val f = NormalInterpreter - // i.getMemory(global.address.toInt, global.size, Endian.LittleEndian, i.mems) - // m.evalBV("mem", BitVecLiteral(64, global.address), Endian.LittleEndian, global.size) // i.getMemory(global.address.toInt, global.size, Endian.LittleEndian, i.mems) try { State.evaluate(s, Eval.evalBV(f)(MemoryLoad(SharedMemory("mem", 64, 8), BitVecLiteral(global.address, 64), Endian.LittleEndian, global.size))) match { @@ -53,9 +36,7 @@ def mems[E, T <: Effects[T, E]](m: MemoryState) : Map[BigInt, BitVecLiteral] = { class InterpreterTests extends AnyFunSuite with BeforeAndAfter { - // var i: Interpreter = Interpreter() - // Logger.setLevel(LogLevel.DEBUG) - Logger.setLevel(LogLevel.ERROR) + Logger.setLevel(LogLevel.INFO) def getProgram(name: String): IRContext = { @@ -78,9 +59,6 @@ class InterpreterTests extends AnyFunSuite with BeforeAndAfter { // val stackIdentification = StackSubstituter() // stackIdentification.visitProgram(IRProgram) ctx.program.setModifies(Map()) - - - // (IRProgram, globals) ctx } @@ -330,13 +308,6 @@ class InterpreterTests extends AnyFunSuite with BeforeAndAfter { "y" -> ('b'.toInt), ) - // val (program, globals) = getProgram("initialisation") - - // val watch = IRWalk.firstInProc((program.mainProcedure)).get - // val globloads = globals.map(global => (global.name, MemoryLoad(SharedMemory("mem", 64, 8), BitVecLiteral(global.address, 64), Endian.LittleEndian, global.size))).toList - // val bp = BreakPoint("beginproc", BreakPointLoc.CMD(watch), BreakPointAction(false, false, globloads, true)) - // val res = interpretWithBreakPoints(program, List(bp), NormalInterpreter, InterpreterState()) - testInterpret("initialisation", expected) } @@ -457,7 +428,6 @@ class InterpreterTests extends AnyFunSuite with BeforeAndAfter { val r = interpretTrace(fib) assert(r._1.nextCmd == Stopped()) - info(r._2.t.mkString("\n")) // Show interpreted result // @@ -465,14 +435,13 @@ class InterpreterTests extends AnyFunSuite with BeforeAndAfter { test("fib breakpoints") { - Logger.setLevel(LogLevel.WARN) val fib = fibonacciProg(8) val watch = IRWalk.firstInProc((fib.procedures.find(_.name == "fib")).get).get val bp = BreakPoint("Fibentry", BreakPointLoc.CMDCond(watch, BinaryExpr(BVEQ, BitVecLiteral(5, 64), Register("R0", 64))), BreakPointAction(true, true, List(("R0", Register("R0", 64))), true)) + // val bp2 = BreakPoint("Fibentry", BreakPointLoc.CMD(watch), BreakPointAction(true, false, List(("R0", Register("R0", 64))), true)) // val interp = LayerInterpreter(NormalInterpreter, RememberBreakpoints(NormalInterpreter, List(bp))) // val res = InterpFuns.interpretProg(interp)(fib, (InterpreterState(), List())) val res = interpretWithBreakPoints(fib, List(bp), NormalInterpreter, InterpreterState()) - println(res) } From 8be5cb6492165ec51fffde6a33a12ee625e507bc Mon Sep 17 00:00:00 2001 From: Alistair Michael Date: Wed, 4 Sep 2024 15:23:16 +1000 Subject: [PATCH 40/62] interpreter docs --- docs/development/interpreter.md | 132 +++++++++++++++++++++++++++++ src/main/scala/util/RunUtils.scala | 10 ++- 2 files changed, 141 insertions(+), 1 deletion(-) create mode 100644 docs/development/interpreter.md diff --git a/docs/development/interpreter.md b/docs/development/interpreter.md new file mode 100644 index 000000000..dcb561a00 --- /dev/null +++ b/docs/development/interpreter.md @@ -0,0 +1,132 @@ +# BASIL IR Interpreter + +The interpreter is designed for testing, debugging, and validation of static analyses and code transforms. +This page describes first how it can be used for this purpose, and secondly its design. + +## Basic Usage + +The interpreter can be invoked from the command line, via the interpret flag, by default this prints a trace and checks that the interpreter +exited on a non-error stop state. + +```shell +./mill run -i src/test/correct/indirect_call/gcc/indirect_call.adt -r src/test/correct/indirect_call/gcc/indirect_call.relf --interpret +[INFO] Interpreter Trace: + StoreVar(#5,Local,0xfd0:bv64) + StoreMem(mem,HashMap(0xfd0:bv64 -> 0xf0:bv8, 0xfd6:bv64 -> 0x0:bv8, 0xfd2:bv64 -> 0x0:bv8, 0xfd3:bv64 -> 0x0:bv8, 0xfd4:bv64 -> 0x0:bv8, 0xfd7:bv64 -> 0x0:bv8, 0xfd5:bv64 -> 0x0:bv8, 0xfd1 +... +``` + +The `--verbose` flag can also be used, which may print interpreter trace events as they are executed, but not this may not correspond to the actual +execution trace, and contain additional events not corresponding to the program. +E.g. this shows the memory intialisation events that precede the program execution. This is mainly useful for debugging the interpreter. + +### Testing with Interpreter + +The interpreter is invoked with `interpret(p: IRContext)` to interpret normally and return an `InterpreterState` object +containing the final state. + +#### Traces + +There is also, `interpretTrace(p: IRContext)` which returns a tuple of `(InterpreterState, Trace(t: List[ExecEffect]))`, +where the second argument contains a list of all the events generated by the interpreter in order. +This is useful for asserting a stronger equivalence between program executions, but in most cases events describing "unobservable" +behaviour, such as register accesses should be filtered out from this list before comparison. + +To see an example of this used to validate the constant prop analysis see [/src/test/scala/DifferentialAnalysis.scala](../../src/test/scala/DifferentialAnalysis.scala). + +#### BreakPoints + +Finally `interpretWithBreakPoints(p: IRContext, List[BreakPoint], interpreter, initialState)` is used to +run an interpreter, and perform additional work at specific code points. For example, this may be invoked such as: + +```scala +val watch = IRWalk.firstInProc((program.procedures.find(_.name == "main")).get).get +val bp = BreakPoint("entrypoint", BreakPointLoc.CMD(watch), BreakPointAction(saveState=true, stop=true, evalExprs=List(("R0", Register("R0", 64))), log=true)) +val res = interpretWithBreakPoints(program, List(bp), NormalInterpreter, InterpreterState()) +``` + +The structure of a breakpoint is as follows: + +```scala +case class BreakPoint(name: String = "", location: BreakPointLoc, action: BreakPointAction) + +// the place to perform the breakpoint action +enum BreakPointLoc: + case CMD(c: Command) // at a command c + case CMDCond(c: Command, condition: Expr) // at a command c, when condition evaluates to TrueLiteral + +// describes what to do when the breakpoint is triggered +case class BreakPointAction( + saveState: Boolean = true, // stash the state of the interpreter + stop: Boolean = false, // stop the interpreter with an error state + evalExprs: List[(String, Expr)] = List(), // Evaluate the rhs of the list of expressions, and stash them (lhs is an arbitrary human-readable name) + log: Boolean = false // Print a log message about passing the breakpoint describing the results of this action +) +``` + +To see an example of this used to validate the constant prop analysis see [/src/test/scala/InterpretTestConstProp.scala](../../src/test/scala/InterpretTestConstProp.scala). + +## Implementation / Code Structure + +### Summary + +- [Bitvector.scala](../../src/main/scala/ir/eval/Bitvector.scala) + - Evaluation of bitvector operations, throws `IllegalArgumentException` on violation of contract + (e.g negative divisor, type mismatch) +- [ExprEval.scala] (../../src/main/scala/ir/eval/ExprEval.scala) + - Evaluation of expressions, defined in terms of partial evaluation down to a Literal +- [Interpreter.scala](../../src/main/scala/ir/eval/Interpreter.scala) + - Definition of core `Effects[S, E]` and `Interpreter[S, E]` types describing state transitions in + the interpreter + - Instantiation/definition of `Effects` for concrete state `InterpreterState` +- [InterpreterProduct.scala](../../src/main/scala/ir/eval/InterpreterProduct.scala) + - Definition of product and layering composition of generic `Effects[S, E]`s interpreters +- [InterpretBasilIR.scala](../../src/main/scala/ir/eval/InterpretBasilIR.scala) + - Definition of ELF initialisation, and the interpreter for BASIL IR, using a generic + `Effects` instance and concrete state. +- [InterpretBreakpoints.scala](../../src/main/scala/ir/eval/InterpretBreakpoints.scala) + - Definition of a generic interpreter with a breakpoint checker layered on top +- [InterpretTrace.scala](../../src/main/scala/ir/eval/InterpretTrace.scala) + - Definition of a generic interpreter which records a trace of calls to the `Effects[]` instance. + +### Explanation + +The interpreter is structured for compositionality, at its core is the `Effects[S, E]` type, defined in [Interpreter.scala](../../src/main/scala/ir/eval/Interpreter.scala). +This type defines a small set of functions which describe all the possible state transformations, over a concrete state `S`, and error type E (always `InterpreterError` in practice). + +This is implemented using the state Monad, `State[S,V,E]` where `S` is the state, `V` the value, and `E` the error type. +This is a flattened `State[S, Either[E]]`, defined in [util/functional.scala](../../src/main/scala/util/functional.scala). +`Effects` methods return delayed computations, functions from an input state (`S`) to a resulting state and a value (`(S, Either[E, V])`). +These are sequenced using `flatMap` (monad bind), or the `for{} yield()` syntax sugar for flatMap. + +This `Effects[S, E]` is instantiated for a given concrete state, the main example of which is `NormalInterpreter <: Effects[InterpreterState, InterpreterError]`, +also defined in `Interpreter.scala`. The memory component of the state is abstracted further into the `MemoryState` object. + +The actual execution of code is defined on top of this, in the `Interpreter[S, E]` type, which takes an instance of the `Effects` by parameter, +and defines both the small step (`interpretOne`) over on instruction, and the fixed point to termination from some in initial state in `run()`. +The fact that the stepping is defined outside the effects is important, as it allows concrete states, and state transitions over them to be +composed somewhat arbitrarily, and the interpretatation of the language compiled down to calls to resulting instance of `Effects`. + +This is defined in [InterpretBasilIR.scala](../../src/main/scala/ir/eval/InterpretBasilIR.scala). `BASILInterpreter` defines an +`Interpreter` over an arbitrary instance of `Effects[S, InterpreterError]`, encoding BASIL IR commands as effects. +This file also contains definitions of the initial memory state setup of the interpreter, based on the ELF sections and symbol table. + +### Composition of interpreters + +There are two ways to compose `Effects`, product and layer. Both produce an instance of `Effects[(L, R), E]`, +where `L` and `R` are the concrete state types of the two Effects being composed. + +Product runs the two effects, over two different concrete state types, simultaneously without interaction. + +Layer runs the `before` effect first, and passes its state to the `inner` effect whose value is returned. + +```scala +case class ProductInterpreter[L, T, E](val inner: Effects[L, E], val before: Effects[T, E]) extends Effects[(L, T), E] { +case class LayerInterpreter[L, T, E](val inner: Effects[L, E], val before: Effects[(L, T), E]) +``` + +Examples of using these are in the `interpretTrace` and `interpretWithBreakPoints` interpreters respectively. + + + + diff --git a/src/main/scala/util/RunUtils.scala b/src/main/scala/util/RunUtils.scala index d58d5b7cf..11bc4f016 100644 --- a/src/main/scala/util/RunUtils.scala +++ b/src/main/scala/util/RunUtils.scala @@ -5,6 +5,7 @@ import com.grammatech.gtirb.proto.IR.IR import com.grammatech.gtirb.proto.Module.Module import com.grammatech.gtirb.proto.Section.Section import spray.json.* +import ir.eval import gtirb.* import scala.collection.mutable.ListBuffer import scala.collection.mutable.ArrayBuffer @@ -506,7 +507,14 @@ object RunUtils { if (q.runInterpret) { // val interpreter = eval.Interpreter() - eval.interpret(ctx.program) + val fs = eval.interpretTrace(ctx) + Logger.info("Interpreter Trace:\n" + fs._2.t.mkString("\n")) + val stopState = fs._1.nextCmd + if (stopState != eval.Stopped()) { + Logger.error(s"Interpreter exited with $stopState") + } else { + Logger.info("Interpreter stopped normally.") + } } IRTransform.prepareForTranslation(q.loading, ctx) From d2a248638c5471bb5b565824f7c3d2263860f573 Mon Sep 17 00:00:00 2001 From: Alistair Michael Date: Wed, 4 Sep 2024 15:38:22 +1000 Subject: [PATCH 41/62] fix list --- docs/development/interpreter.md | 28 +++++++++++++++++----------- 1 file changed, 17 insertions(+), 11 deletions(-) diff --git a/docs/development/interpreter.md b/docs/development/interpreter.md index dcb561a00..7c0897bf1 100644 --- a/docs/development/interpreter.md +++ b/docs/development/interpreter.md @@ -71,28 +71,28 @@ To see an example of this used to validate the constant prop analysis see [/src/ ### Summary - [Bitvector.scala](../../src/main/scala/ir/eval/Bitvector.scala) - - Evaluation of bitvector operations, throws `IllegalArgumentException` on violation of contract - (e.g negative divisor, type mismatch) + - Evaluation of bitvector operations, throws `IllegalArgumentException` on violation of contract + (e.g negative divisor, type mismatch) - [ExprEval.scala] (../../src/main/scala/ir/eval/ExprEval.scala) - - Evaluation of expressions, defined in terms of partial evaluation down to a Literal + - Evaluation of expressions, defined in terms of partial evaluation down to a Literal - [Interpreter.scala](../../src/main/scala/ir/eval/Interpreter.scala) - - Definition of core `Effects[S, E]` and `Interpreter[S, E]` types describing state transitions in - the interpreter - - Instantiation/definition of `Effects` for concrete state `InterpreterState` + - Definition of core `Effects[S, E]` and `Interpreter[S, E]` types describing state transitions in + the interpreter + - Instantiation/definition of `Effects` for concrete state `InterpreterState` - [InterpreterProduct.scala](../../src/main/scala/ir/eval/InterpreterProduct.scala) - - Definition of product and layering composition of generic `Effects[S, E]`s interpreters + - Definition of product and layering composition of generic `Effects[S, E]`s interpreters - [InterpretBasilIR.scala](../../src/main/scala/ir/eval/InterpretBasilIR.scala) - - Definition of ELF initialisation, and the interpreter for BASIL IR, using a generic + - Definition of ELF initialisation, and the interpreter for BASIL IR, using a generic `Effects` instance and concrete state. - [InterpretBreakpoints.scala](../../src/main/scala/ir/eval/InterpretBreakpoints.scala) - - Definition of a generic interpreter with a breakpoint checker layered on top + - Definition of a generic interpreter with a breakpoint checker layered on top - [InterpretTrace.scala](../../src/main/scala/ir/eval/InterpretTrace.scala) - - Definition of a generic interpreter which records a trace of calls to the `Effects[]` instance. + - Definition of a generic interpreter which records a trace of calls to the `Effects[]` instance. ### Explanation The interpreter is structured for compositionality, at its core is the `Effects[S, E]` type, defined in [Interpreter.scala](../../src/main/scala/ir/eval/Interpreter.scala). -This type defines a small set of functions which describe all the possible state transformations, over a concrete state `S`, and error type E (always `InterpreterError` in practice). +This type defines a small set of functions which describe all the possible state transformations, over a concrete state `S`, and error type `E` (always `InterpreterError` in practice). This is implemented using the state Monad, `State[S,V,E]` where `S` is the state, `V` the value, and `E` the error type. This is a flattened `State[S, Either[E]]`, defined in [util/functional.scala](../../src/main/scala/util/functional.scala). @@ -127,6 +127,12 @@ case class LayerInterpreter[L, T, E](val inner: Effects[L, E], val before: Effec Examples of using these are in the `interpretTrace` and `interpretWithBreakPoints` interpreters respectively. +Note, this only works by the aforementioned requirement that all effect calls come from outside the `Effects[]` +instance itself. In the simple case, the `Interpreter` instance is the only object calling `Effects`. +This means, `Effects` triggered by an inner `Effects[]` instance do not flow back to the `ProductInterpreter`, +but only appear from when `Interpreter` above the `ProductInterpreter` interprets the program via effect calls. +For this reason if, for example, `NormalInterpreter` makes effect calls they will not appear in a trace emitted by `interptretTrace`. + From d06e268fa57ca14d5618abdec71007e6f3eed189 Mon Sep 17 00:00:00 2001 From: Alistair Michael Date: Wed, 4 Sep 2024 15:51:02 +1000 Subject: [PATCH 42/62] paragraph --- docs/development/interpreter.md | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/docs/development/interpreter.md b/docs/development/interpreter.md index 7c0897bf1..a24738d41 100644 --- a/docs/development/interpreter.md +++ b/docs/development/interpreter.md @@ -73,7 +73,7 @@ To see an example of this used to validate the constant prop analysis see [/src/ - [Bitvector.scala](../../src/main/scala/ir/eval/Bitvector.scala) - Evaluation of bitvector operations, throws `IllegalArgumentException` on violation of contract (e.g negative divisor, type mismatch) -- [ExprEval.scala] (../../src/main/scala/ir/eval/ExprEval.scala) +- [ExprEval.scala](../../src/main/scala/ir/eval/ExprEval.scala) - Evaluation of expressions, defined in terms of partial evaluation down to a Literal - [Interpreter.scala](../../src/main/scala/ir/eval/Interpreter.scala) - Definition of core `Effects[S, E]` and `Interpreter[S, E]` types describing state transitions in @@ -82,8 +82,9 @@ To see an example of this used to validate the constant prop analysis see [/src/ - [InterpreterProduct.scala](../../src/main/scala/ir/eval/InterpreterProduct.scala) - Definition of product and layering composition of generic `Effects[S, E]`s interpreters - [InterpretBasilIR.scala](../../src/main/scala/ir/eval/InterpretBasilIR.scala) - - Definition of ELF initialisation, and the interpreter for BASIL IR, using a generic - `Effects` instance and concrete state. + - Definition of `Eval` object defining expression evaluation in terms of `Effects[S, InterpreterError]` + - Definition of `Interpreter` instance for BASIL IR, using a generic `Effects` instance and concrete state. + - Definition of ELF initialisation in terms of generic `Effects[S, InterpreterError]` - [InterpretBreakpoints.scala](../../src/main/scala/ir/eval/InterpretBreakpoints.scala) - Definition of a generic interpreter with a breakpoint checker layered on top - [InterpretTrace.scala](../../src/main/scala/ir/eval/InterpretTrace.scala) From 4ae39976ebdec854851dca6264e190c9b546cca7 Mon Sep 17 00:00:00 2001 From: Alistair Michael Date: Wed, 4 Sep 2024 16:33:02 +1000 Subject: [PATCH 43/62] trap eval exceptions to monadic errors --- src/main/scala/ir/eval/ExprEval.scala | 13 +- src/main/scala/ir/eval/InterpretBasilIR.scala | 2 +- src/main/scala/util/functional.scala | 16 ++ src/test/scala/ir/InterpreterTests.scala | 218 +++++++++--------- 4 files changed, 133 insertions(+), 116 deletions(-) diff --git a/src/main/scala/ir/eval/ExprEval.scala b/src/main/scala/ir/eval/ExprEval.scala index 65218caf2..53e79d11b 100644 --- a/src/main/scala/ir/eval/ExprEval.scala +++ b/src/main/scala/ir/eval/ExprEval.scala @@ -146,9 +146,9 @@ trait Loader[S, E] { } } -def statePartialEvalExpr[S, E](l: Loader[S, E])(exp: Expr): State[S, Expr, E] = { +def statePartialEvalExpr[S](l: Loader[S, InterpreterError])(exp: Expr): State[S, Expr, InterpreterError] = { val eval = statePartialEvalExpr(l) - exp match { + val ns = exp match { case f: UninterpretedFunction => State.pure(f) case unOp: UnaryExpr => for { @@ -232,6 +232,13 @@ def statePartialEvalExpr[S, E](l: Loader[S, E])(exp: Expr): State[S, Expr, E] = case b: IntLiteral => State.pure(b) case b: BoolLit => State.pure(b) } + State.protect( + () => ns, + { case e => + Errored(e.toString) + }: PartialFunction[Exception, InterpreterError] + ) + } class StatelessLoader[E]( @@ -248,7 +255,7 @@ def partialEvalExpr( variableAssignment: Variable => Option[Literal], memory: (Memory, Expr, Endian, Int) => Option[Literal] = ((a, b, c, d) => None) ): Expr = { - val l = StatelessLoader(variableAssignment, memory) + val l = StatelessLoader[InterpreterError](variableAssignment, memory) State.evaluate((), statePartialEvalExpr(l)(exp)) match { case Right(e) => e case Left(e) => throw Exception("Unable to evaluate expr : " + e.toString) diff --git a/src/main/scala/ir/eval/InterpretBasilIR.scala b/src/main/scala/ir/eval/InterpretBasilIR.scala index 38b78c689..7ffd4875a 100644 --- a/src/main/scala/ir/eval/InterpretBasilIR.scala +++ b/src/main/scala/ir/eval/InterpretBasilIR.scala @@ -66,7 +66,7 @@ case object Eval { def evalExpr[S, T <: Effects[S, InterpreterError]](f: T)(e: Expr): State[S, Expr, InterpreterError] = { val ldr = StVarLoader[S, T](f) for { - res <- ir.eval.statePartialEvalExpr[S, InterpreterError](ldr)(e) + res <- ir.eval.statePartialEvalExpr[S](ldr)(e) } yield (res) } diff --git a/src/main/scala/util/functional.scala b/src/main/scala/util/functional.scala index a9c9c73d1..4db438646 100644 --- a/src/main/scala/util/functional.scala +++ b/src/main/scala/util/functional.scala @@ -78,6 +78,22 @@ object State { xs.foldRight(pure(List[B]()))((b,acc) => acc.flatMap(c => m(b).map(v => v::c))) } + def protect[S, V, E](f : () => State[S, V, E], fnly: PartialFunction[Exception, E]) : State[S, V, E] = { + State((s: S) => try { + f().f(s) + } catch { + case e: Exception if fnly.isDefinedAt(e) => (s, Left(fnly(e))) + }) + } + + def protectPure[S,V,E](f : () => V, fnly : PartialFunction[Exception, E]) : State[S, V, E] = { + State((s: S) => try { + (s, Right(f())) + } catch { + case e: Exception if fnly.isDefinedAt(e) => (s, Left(fnly(e))) + }) + } + } def protect[T](x: () => T, fnly: PartialFunction[Exception, T]): T = { diff --git a/src/test/scala/ir/InterpreterTests.scala b/src/test/scala/ir/InterpreterTests.scala index 934f17e04..89ac86984 100644 --- a/src/test/scala/ir/InterpreterTests.scala +++ b/src/test/scala/ir/InterpreterTests.scala @@ -10,35 +10,32 @@ import specification.SpecGlobal import translating.BAPToIR import util.{LogLevel, Logger} import util.IRLoading.{loadBAP, loadReadELF} -import util.{ILLoadingConfig, IRContext, IRLoading, IRTransform} +import util.{ILLoadingConfig, IRContext, IRLoading, IRTransform} - -def load(s: InterpreterState, global: SpecGlobal) : Option[BitVecLiteral] = { +def load(s: InterpreterState, global: SpecGlobal): Option[BitVecLiteral] = { val f = NormalInterpreter - - try { - State.evaluate(s, Eval.evalBV(f)(MemoryLoad(SharedMemory("mem", 64, 8), BitVecLiteral(global.address, 64), Endian.LittleEndian, global.size))) match { - case Right(e) => Some(e) - case Left(e) => { - None - } + State.evaluate( + s, + Eval.evalBV(f)( + MemoryLoad(SharedMemory("mem", 64, 8), BitVecLiteral(global.address, 64), Endian.LittleEndian, global.size) + ) + ) match { + case Right(e) => Some(e) + case Left(e) => { + None } - } catch { - case e : InterpreterError => None } } - -def mems[E, T <: Effects[T, E]](m: MemoryState) : Map[BigInt, BitVecLiteral] = { - m.getMem("mem").map((k,v) => k.value -> v) +def mems[E, T <: Effects[T, E]](m: MemoryState): Map[BigInt, BitVecLiteral] = { + m.getMem("mem").map((k, v) => k.value -> v) } class InterpreterTests extends AnyFunSuite with BeforeAndAfter { Logger.setLevel(LogLevel.INFO) - def getProgram(name: String): IRContext = { val loading = ILLoadingConfig( inputFile = s"examples/$name/$name.adt", @@ -82,10 +79,8 @@ class InterpreterTests extends AnyFunSuite with BeforeAndAfter { } // Test expected value - val actual : Map[String, Int] = expected.flatMap ( (name, expected) => - globals.find(_.name == name).flatMap(global => - load(fstate, global).map(gv => name -> gv.value.toInt) - ) + val actual: Map[String, Int] = expected.flatMap((name, expected) => + globals.find(_.name == name).flatMap(global => load(fstate, global).map(gv => name -> gv.value.toInt)) ) assert(fstate.nextCmd == Stopped()) assert(expected == actual) @@ -101,63 +96,62 @@ class InterpreterTests extends AnyFunSuite with BeforeAndAfter { assert(s.memoryState.getVarOpt("R31").isDefined) assert(s.memoryState.getVarOpt("R29").isDefined) - } test("var load store") { - val s = for { - s <- InterpFuns.initialState(NormalInterpreter) - v <- NormalInterpreter.loadVar("R31") - } yield (v) - val l = State.evaluate(InterpreterState(), s) + val s = for { + s <- InterpFuns.initialState(NormalInterpreter) + v <- NormalInterpreter.loadVar("R31") + } yield (v) + val l = State.evaluate(InterpreterState(), s) - assert(l == Right(Scalar(BitVecLiteral(4096 - 16, 64)))) + assert(l == Right(Scalar(BitVecLiteral(4096 - 16, 64)))) + + } + + test("Store = Load LittleEndian") { + val ts = List( + BitVecLiteral(BigInt("0D", 16), 8), + BitVecLiteral(BigInt("0C", 16), 8), + BitVecLiteral(BigInt("0B", 16), 8), + BitVecLiteral(BigInt("0A", 16), 8) + ) + + val loader = StVarLoader(NormalInterpreter) + + val s = for { + _ <- InterpFuns.initialState(NormalInterpreter) + _ <- Eval.store(NormalInterpreter)("mem", Scalar(BitVecLiteral(0, 64)), ts.map(Scalar(_)), Endian.LittleEndian) + r <- Eval.loadBV(NormalInterpreter)("mem", Scalar(BitVecLiteral(0, 64)), Endian.LittleEndian, 32) + } yield (r) + val expected: BitVecLiteral = BitVecLiteral(BigInt("0D0C0B0A", 16), 32) + val actual = State.evaluate(InterpreterState(), s) + assert(actual == Right(expected)) } - test("Store = Load LittleEndian") { - val ts = List( - BitVecLiteral(BigInt("0D", 16), 8), - BitVecLiteral(BigInt("0C", 16), 8), - BitVecLiteral(BigInt("0B", 16), 8), - BitVecLiteral(BigInt("0A", 16), 8)) - - val loader = StVarLoader(NormalInterpreter) - - val s = for { - _ <- InterpFuns.initialState(NormalInterpreter) - _ <- Eval.store(NormalInterpreter)("mem", Scalar(BitVecLiteral(0, 64)), ts.map(Scalar(_)), Endian.LittleEndian) - r <- Eval.loadBV(NormalInterpreter)("mem", Scalar(BitVecLiteral(0, 64)), Endian.LittleEndian, 32) - } yield(r) - val expected: BitVecLiteral = BitVecLiteral(BigInt("0D0C0B0A", 16), 32) - val actual = State.evaluate(InterpreterState(), s) - assert(actual == Right(expected)) - - - } - // test("store bv = loadbv le") { // val expected: BitVecLiteral = BitVecLiteral(BigInt("0D0C0B0A", 16), 32) // val s2 = Eval.storeBV(initialMem(), "mem", Scalar(BitVecLiteral(0, 64)), expected, Endian.LittleEndian) // val actual2: BitVecLiteral = Eval.loadBV(s2, "mem", Scalar(BitVecLiteral(0, 64)), Endian.LittleEndian, 32) // assert(actual2 == expected) // } -// -// +// +// // test("Store = Load BigEndian") { // val ts = List( // BitVecLiteral(BigInt("0D", 16), 8), // BitVecLiteral(BigInt("0C", 16), 8), // BitVecLiteral(BigInt("0B", 16), 8), // BitVecLiteral(BigInt("0A", 16), 8)) -// +// // val s = Eval.store(initialMem(), "mem", Scalar(BitVecLiteral(0, 64)), ts.map(Scalar(_)), Endian.LittleEndian) // val expected: BitVecLiteral = BitVecLiteral(BigInt("0A0B0C0D", 16), 32) // val actual: BitVecLiteral = Eval.loadBV(s, "mem", Scalar(BitVecLiteral(0, 64)), Endian.BigEndian , 32) // assert(actual == expected) -// -// +// +// // } -// +// // test("getMemory in LittleEndian") { // val ts = List((BitVecLiteral(0, 64), BitVecLiteral(BigInt("0D", 16), 8)), // (BitVecLiteral(1, 64) , BitVecLiteral(BigInt("0C", 16), 8)), @@ -166,24 +160,24 @@ class InterpreterTests extends AnyFunSuite with BeforeAndAfter { // val s = ts.foldLeft(initialMem())((m, v) => Eval.storeSingle(m, "mem", Scalar(v._1), Scalar(v._2))) // // val s = initialMem().store("mem") // // val r = s.loadBV("mem", BitVecLiteral(0, 64)) -// +// // val expected: BitVecLiteral = BitVecLiteral(BigInt("0A0B0C0D", 16), 32) -// +// // // def loadBV(vname: String, addr: Scalar, endian: Endian, size: Int): BitVecLiteral = { // val actual: BitVecLiteral = Eval.loadBV(s, "mem", Scalar(BitVecLiteral(0, 64)), Endian.LittleEndian, 32) // assert(actual == expected) // } -// -// +// +// // test("StoreBV = LoadBV LE ") { // val expected: BitVecLiteral = BitVecLiteral(BigInt("0A0B0C0D", 16), 32) -// +// // val s = Eval.storeBV(initialMem(), "mem", Scalar(BitVecLiteral(0, 64)), expected, Endian.LittleEndian) // val actual: BitVecLiteral = Eval.loadBV(s, "mem", Scalar(BitVecLiteral(0, 64)), Endian.LittleEndian, 32) // println(s"${actual.value.toInt.toHexString} == ${expected.value.toInt.toHexString}") // assert(actual == expected) // } -// +// // // test("getMemory in BigEndian") { // // i.mems(0) = BitVecLiteral(BigInt("0A", 16), 8) // // i.mems(1) = BitVecLiteral(BigInt("0B", 16), 8) @@ -193,7 +187,7 @@ class InterpreterTests extends AnyFunSuite with BeforeAndAfter { // // val actual: BitVecLiteral = i.getMemory(0, 32, Endian.BigEndian, i.mems) // // assert(actual == expected) // // } -// +// // // test("setMemory in LittleEndian") { // // i.mems(0) = BitVecLiteral(BigInt("FF", 16), 8) // // i.mems(1) = BitVecLiteral(BigInt("FF", 16), 8) @@ -204,7 +198,7 @@ class InterpreterTests extends AnyFunSuite with BeforeAndAfter { // // val actual: BitVecLiteral = i.getMemory(0, 32, Endian.LittleEndian, i.mems) // // assert(actual == expected) // // } -// +// // // test("setMemory in BigEndian") { // // i.mems(0) = BitVecLiteral(BigInt("FF", 16), 8) // // i.mems(1) = BitVecLiteral(BigInt("FF", 16), 8) @@ -215,7 +209,7 @@ class InterpreterTests extends AnyFunSuite with BeforeAndAfter { // // val actual: BitVecLiteral = i.getMemory(0, 32, Endian.BigEndian, i.mems) // // assert(actual == expected) // // } -// +// test("basic_arrays_read") { val expected = Map( "arr" -> 0 @@ -283,7 +277,6 @@ class InterpreterTests extends AnyFunSuite with BeforeAndAfter { testInterpret("indirect_call_outparam", expected) } - test("ifglobal") { val expected = Map( "x" -> 1 @@ -299,16 +292,14 @@ class InterpreterTests extends AnyFunSuite with BeforeAndAfter { testInterpret("cjump", expected) } - test("initialisation") { // Logger.setLevel(LogLevel.WARN) val expected = Map( "x" -> 6, - "y" -> ('b'.toInt), + "y" -> ('b'.toInt) ) - testInterpret("initialisation", expected) } @@ -326,8 +317,7 @@ class InterpreterTests extends AnyFunSuite with BeforeAndAfter { testInterpret("no_interference_update_y", expected) } - - def fib(n: Int) : Int = { + def fib(n: Int): Int = { n match { case 0 => 0 case 1 => 1 @@ -337,37 +327,29 @@ class InterpreterTests extends AnyFunSuite with BeforeAndAfter { def fibonacciProg(n: Int) = { prog( - proc("begin", - block("entry", - Assign(R8, Register("R31", 64)), - Assign(R0, bv64(n)), - directCall("fib"), - goto("done") - ), - block("done", - Assert(BinaryExpr(BVEQ, R0, bv64(fib(n)))), - ret - )), - proc("fib", + proc( + "begin", + block("entry", Assign(R8, Register("R31", 64)), Assign(R0, bv64(n)), directCall("fib"), goto("done")), + block("done", Assert(BinaryExpr(BVEQ, R0, bv64(fib(n)))), ret) + ), + proc( + "fib", block("base", goto("base1", "base2", "dofib")), - block("base1", - Assume(BinaryExpr(BVEQ, R0, bv64(0))), - ret), - block("base2", - Assume(BinaryExpr(BVEQ, R0, bv64(1))), - ret), - block("dofib", + block("base1", Assume(BinaryExpr(BVEQ, R0, bv64(0))), ret), + block("base2", Assume(BinaryExpr(BVEQ, R0, bv64(1))), ret), + block( + "dofib", Assume(BinaryExpr(BoolAND, BinaryExpr(BVNEQ, R0, bv64(0)), BinaryExpr(BVNEQ, R0, bv64(1)))), // R8 stack pointer preserved across calls - Assign(R7, BinaryExpr(BVADD, R8, bv64(8))), + Assign(R7, BinaryExpr(BVADD, R8, bv64(8))), MemoryAssign(stack, R7, R8, Endian.LittleEndian, 64), // sp Assign(R8, R7), - Assign(R8, BinaryExpr(BVADD, R8, bv64(8))), // sp + 8 + Assign(R8, BinaryExpr(BVADD, R8, bv64(8))), // sp + 8 MemoryAssign(stack, R8, R0, Endian.LittleEndian, 64), // [sp + 8] = arg0 Assign(R0, BinaryExpr(BVSUB, R0, bv64(1))), directCall("fib"), Assign(R2, R8), // sp + 8 - Assign(R8, BinaryExpr(BVADD, R8, bv64(8))), // sp + 16 + Assign(R8, BinaryExpr(BVADD, R8, bv64(8))), // sp + 16 MemoryAssign(stack, R8, R0, Endian.LittleEndian, 64), // [sp + 16] = r1 Assign(R0, MemoryLoad(stack, R2, Endian.LittleEndian, 64)), // [sp + 8] Assign(R0, BinaryExpr(BVSUB, R0, bv64(2))), @@ -376,9 +358,9 @@ class InterpreterTests extends AnyFunSuite with BeforeAndAfter { Assign(R0, BinaryExpr(BVADD, R0, R2)), Assign(R8, MemoryLoad(stack, BinaryExpr(BVSUB, R8, bv64(16)), Endian.LittleEndian, 64)), ret - ) ) ) + ) } test("fibonacci") { @@ -410,40 +392,52 @@ class InterpreterTests extends AnyFunSuite with BeforeAndAfter { val ir = interpret(prog) val it = intt.elapsed() - res = (i,native,it)::res + res = (i, native, it) :: res println(s"${res.head}") } - println(("fib number,native time,interp time"::(res.map(x => s"${x._1},${x._2},${x._3}"))).mkString("\n")) + println(("fib number,native time,interp time" :: (res.map(x => s"${x._1},${x._2},${x._3}"))).mkString("\n")) } + test("fibonacci Trace") { + + val fib = fibonacciProg(8) + + val r = interpretTrace(fib) + assert(r._1.nextCmd == Stopped()) + // Show interpreted result + // - test("fibonacci Trace") { + } - val fib = fibonacciProg(8) + test("fib breakpoints") { - val r = interpretTrace(fib) - - assert(r._1.nextCmd == Stopped()) - // Show interpreted result - // - - } + val fib = fibonacciProg(8) + val watch = IRWalk.firstInProc((fib.procedures.find(_.name == "fib")).get).get + val bp = BreakPoint( + "Fibentry", + BreakPointLoc.CMDCond(watch, BinaryExpr(BVEQ, BitVecLiteral(5, 64), Register("R0", 64))), + BreakPointAction(true, true, List(("R0", Register("R0", 64))), true) + ) + // val bp2 = BreakPoint("Fibentry", BreakPointLoc.CMD(watch), BreakPointAction(true, false, List(("R0", Register("R0", 64))), true)) + // val interp = LayerInterpreter(NormalInterpreter, RememberBreakpoints(NormalInterpreter, List(bp))) + // val res = InterpFuns.interpretProg(interp)(fib, (InterpreterState(), List())) + val res = interpretWithBreakPoints(fib, List(bp), NormalInterpreter, InterpreterState()) + } - test("fib breakpoints") { + test("Capture IllegalArg") { - val fib = fibonacciProg(8) - val watch = IRWalk.firstInProc((fib.procedures.find(_.name == "fib")).get).get - val bp = BreakPoint("Fibentry", BreakPointLoc.CMDCond(watch, BinaryExpr(BVEQ, BitVecLiteral(5, 64), Register("R0", 64))), BreakPointAction(true, true, List(("R0", Register("R0", 64))), true)) - // val bp2 = BreakPoint("Fibentry", BreakPointLoc.CMD(watch), BreakPointAction(true, false, List(("R0", Register("R0", 64))), true)) - // val interp = LayerInterpreter(NormalInterpreter, RememberBreakpoints(NormalInterpreter, List(bp))) - // val res = InterpFuns.interpretProg(interp)(fib, (InterpreterState(), List())) - val res = interpretWithBreakPoints(fib, List(bp), NormalInterpreter, InterpreterState()) + val tp = prog( + proc("begin", block("shouldfail", Assign(R0, ZeroExtend(-1, BitVecLiteral(0, 64))), ret)) + ) + val ir = interpret(tp) + println(ir) + assert(ir.nextCmd.isInstanceOf[ErrorStop]) - } + } } From 4fc8ed2720b9d6aae9d2e06c957af201d63951d1 Mon Sep 17 00:00:00 2001 From: Alistair Michael Date: Wed, 4 Sep 2024 17:44:50 +1000 Subject: [PATCH 44/62] improve interpretOne --- src/main/scala/ir/eval/InterpretBasilIR.scala | 77 ++++++++++++------- .../scala/ir/eval/InterpretBreakpoints.scala | 3 +- src/main/scala/ir/eval/Interpreter.scala | 20 ++--- src/test/scala/ir/InterpreterTests.scala | 7 +- 4 files changed, 64 insertions(+), 43 deletions(-) diff --git a/src/main/scala/ir/eval/InterpretBasilIR.scala b/src/main/scala/ir/eval/InterpretBasilIR.scala index 7ffd4875a..b483431a0 100644 --- a/src/main/scala/ir/eval/InterpretBasilIR.scala +++ b/src/main/scala/ir/eval/InterpretBasilIR.scala @@ -226,19 +226,22 @@ case object Eval { } } - class BASILInterpreter[S](f: Effects[S, InterpreterError]) extends Interpreter[S, InterpreterError](f) { - def interpretOne: State[S, Unit, InterpreterError] = for { + def interpretOne: State[S, Boolean, InterpreterError] = for { next <- f.getNext - _ <- (next match { - case CallIntrinsic(tgt) => LibcIntrinsic.intrinsics(tgt)(f) - case Run(c: Statement) => interpretStatement(f)(c) - case Run(c: Jump) => interpretJump(f)(c) - case Stopped() => State.pure(()) - case ErrorStop(e) => State.pure(()) - }).flatMapE((e: InterpreterError) => f.setNext(ErrorStop(e))) - } yield () + _ <- State.pure(Logger.debug(s"$next")) + r: Boolean <- (next match { + case CallIntrinsic(tgt) => LibcIntrinsic.intrinsics(tgt)(f).map(_ => true) + case Run(c: Statement) => interpretStatement(f)(c).map(_ => true) + case Run(c: Jump) => interpretJump(f)(c).map(_ => true) + case Stopped() => State.pure(false) + case ErrorStop(e) => State.pure(false) + }) + .flatMapE((e: InterpreterError) => { + f.setNext(ErrorStop(e)).map(_ => false) + }) + } yield (r) def interpretJump[S, T <: Effects[S, InterpreterError]](f: T)(j: Jump): State[S, Unit, InterpreterError] = { j match { @@ -288,7 +291,8 @@ class BASILInterpreter[S](f: Effects[S, InterpreterError]) extends Interpreter[S case assert: Assert => for { b <- Eval.evalBool(f)(assert.body) - _ <- (if (!b) then { + _ <- + (if (!b) then { State.setError(FailedAssertion(assert)) } else { f.setNext(Run(s.successor)) @@ -346,7 +350,7 @@ object InterpFuns { var done = false var x = List[(String, FunPointer)]() - def newAddr() : BigInt = { + def newAddr(): BigInt = { addr += 8 addr } @@ -360,10 +364,27 @@ object InterpFuns { val procs = p.procedures.filter(proc => proc.address.isDefined) - val fptrs = ctx.externalFunctions.toList.sortBy(_.name).flatMap(f => { - intrinsics.get(f.name).map(fp => (f.offset, fp)) - .orElse(procs.find(p => p.name == f.name).map(proc => (f.offset, FunPointer(BitVecLiteral(proc.address.getOrElse(newAddr().toInt), 64), proc.name, Run(DirectCall(proc)))))) - }) + val fptrs = ctx.externalFunctions.toList + .sortBy(_.name) + .flatMap(f => { + intrinsics + .get(f.name) + .map(fp => (f.offset, fp)) + .orElse( + procs + .find(p => p.name == f.name) + .map(proc => + ( + f.offset, + FunPointer( + BitVecLiteral(proc.address.getOrElse(newAddr().toInt), 64), + proc.name, + Run(DirectCall(proc)) + ) + ) + ) + ) + }) // sort for deterministic trace val stores = fptrs @@ -371,12 +392,12 @@ object InterpFuns { .map((p) => { val (offset, fptr) = p Eval.storeSingle(s)("ghost-funtable", Scalar(fptr.addr), fptr) - >> (Eval.storeBV(s)( - "mem", - Scalar(BitVecLiteral(offset, 64)), - fptr.addr, - Endian.LittleEndian - )) + >> (Eval.storeBV(s)( + "mem", + Scalar(BitVecLiteral(offset, 64)), + fptr.addr, + Endian.LittleEndian + )) }) State.sequence(State.pure(()), stores) @@ -388,11 +409,12 @@ object InterpFuns { */ def initialState[S, E, T <: Effects[S, E]](s: T): State[S, Unit, E] = { - val SP: BitVecLiteral = BitVecLiteral(4096 - 16, 64) - val FP: BitVecLiteral = BitVecLiteral(4096 - 16, 64) - val LR: BitVecLiteral = BitVecLiteral(BigInt("FF", 16), 64) + val SP: BitVecLiteral = BitVecLiteral(0x78000000, 64) + val FP: BitVecLiteral = SP + val LR: BitVecLiteral = BitVecLiteral(BigInt("78000000", 16), 64) for { + h <- State.pure(Logger.debug("DEFINE MEMORY REGIONS")) h <- s.storeVar("ghost-funtable", Scope.Global, MapValue(Map.empty, MapType(BitVecType(64), BitVecType(64)))) h <- s.storeVar("mem", Scope.Global, MapValue(Map.empty, MapType(BitVecType(64), BitVecType(8)))) i <- s.storeVar("stack", Scope.Global, MapValue(Map.empty, MapType(BitVecType(64), BitVecType(8)))) @@ -410,7 +432,7 @@ object InterpFuns { m <- State.sequence( State.pure(()), mems - .filter(m => m.address != 0) + .filter(m => m.address != 0 && m.bytes.size != 0) .map(memory => Eval.store(f)( mem, @@ -426,7 +448,7 @@ object InterpFuns { for { d <- initialState(f) funs <- State.sequence( - State.pure(()), + State.pure(Logger.debug("INITIALISE FUNCTION ADDRESSES")), p.procedures .filter(p => p.blocks.nonEmpty && p.address.isDefined) .map((proc: Procedure) => @@ -437,6 +459,7 @@ object InterpFuns { ) ) ) + _ <- State.pure(Logger.debug("INITIALISE MEMORY SECTIONS")) mem <- initMemory("mem", p.initialMemory) mem <- initMemory("stack", p.initialMemory) mem <- initMemory("mem", p.readOnlyMemory) diff --git a/src/main/scala/ir/eval/InterpretBreakpoints.scala b/src/main/scala/ir/eval/InterpretBreakpoints.scala index 831c85c3d..e64d65506 100644 --- a/src/main/scala/ir/eval/InterpretBreakpoints.scala +++ b/src/main/scala/ir/eval/InterpretBreakpoints.scala @@ -82,11 +82,10 @@ case class RememberBreakpoints[T, I <: Effects[T, InterpreterError]](val f: I, v val stopping = if action.stop then " stopping. " else "" val evalstr = evals.map(e => s"\n ${e._1} : eval(${e._2}) = ${e._3}").mkString("") Logger.warn(s"Breakpoint $bpn@$bpcond.$saving$stopping$evalstr") - //println(s"Breakpoint $bpn@$bpcond.$saving$stopping$evalstr") } }) _ <- - if action.stop then doLeft(State.setError(Errored(s"Stopped at breakpoint ${name}"))) + if action.stop then doLeft(f.setNext(ErrorStop(Errored(s"Stopped at breakpoint ${name}")))) else doLeft(State.pure(())) _ <- State.modify((istate: (T, List[(BreakPoint, Option[T], List[(String, Expr, Expr)])])) => (istate._1, ((breakpoint, saved, evals) :: istate._2)) diff --git a/src/main/scala/ir/eval/Interpreter.scala b/src/main/scala/ir/eval/Interpreter.scala index 9d3246fcc..70ba4d8ee 100644 --- a/src/main/scala/ir/eval/Interpreter.scala +++ b/src/main/scala/ir/eval/Interpreter.scala @@ -255,6 +255,8 @@ case class MemoryState( /** typecheck and some fields of a map variable */ def doStore(vname: String, values: Map[BasilValue, BasilValue]): Either[InterpreterError, MemoryState] = for { // val (frame, mem) = findVar(vname) + + _ <- if (values.size == 0) then Left(MemoryError("Tried to store size 0")) else Right(()) v <- findVar(vname) (frame, mem) = v // val (mapval, keytype, valtype) = @@ -356,7 +358,7 @@ object NormalInterpreter extends Effects[InterpreterState, InterpreterError] { rs match { case Some(_, l) => { val vs = Scalar(l.foldLeft(BitVecLiteral(0, 0))((acc, r) => eval.evalBVBinExpr(BVCONCAT, acc, r))).toString - s"$varname[${ks.head._1}] := $vs" + s"$varname[${ks.headOption.map(_._1).getOrElse("null")}] := $vs" } case None if ks.length < 8 => s"$varname[${ks.map(_._1).mkString(",")}] := ${ks.map(_._2).mkString(",")}" case None => s"$varname[${ks.map(_._1).take(8).mkString(",")}...] := ${ks.map(_._2).take(8).mkString(", ")}... " @@ -418,20 +420,14 @@ object NormalInterpreter extends Effects[InterpreterState, InterpreterError] { trait Interpreter[S, E](val f: Effects[S, E]) { - def interpretOne: State[S, Unit, E] + /* + * Returns value deciding whether to continue. + */ + def interpretOne: State[S, Boolean, E] @tailrec final def run(begin: S): S = { - val c = for { - _ <- interpretOne - x <- f.getNext - continue = x match { - case Stopped() | ErrorStop(_) => false - case _ => true - } - } yield (continue) - - val (fs,cont) = c.f(begin) + val (fs,cont) = interpretOne.f(begin) if (cont.contains(true)) then { run(fs) diff --git a/src/test/scala/ir/InterpreterTests.scala b/src/test/scala/ir/InterpreterTests.scala index 89ac86984..bd12b5818 100644 --- a/src/test/scala/ir/InterpreterTests.scala +++ b/src/test/scala/ir/InterpreterTests.scala @@ -34,7 +34,7 @@ def mems[E, T <: Effects[T, E]](m: MemoryState): Map[BigInt, BitVecLiteral] = { class InterpreterTests extends AnyFunSuite with BeforeAndAfter { - Logger.setLevel(LogLevel.INFO) + Logger.setLevel(LogLevel.WARN) def getProgram(name: String): IRContext = { val loading = ILLoadingConfig( @@ -415,6 +415,7 @@ class InterpreterTests extends AnyFunSuite with BeforeAndAfter { test("fib breakpoints") { + Logger.setLevel(LogLevel.INFO) val fib = fibonacciProg(8) val watch = IRWalk.firstInProc((fib.procedures.find(_.name == "fib")).get).get val bp = BreakPoint( @@ -422,10 +423,12 @@ class InterpreterTests extends AnyFunSuite with BeforeAndAfter { BreakPointLoc.CMDCond(watch, BinaryExpr(BVEQ, BitVecLiteral(5, 64), Register("R0", 64))), BreakPointAction(true, true, List(("R0", Register("R0", 64))), true) ) - // val bp2 = BreakPoint("Fibentry", BreakPointLoc.CMD(watch), BreakPointAction(true, false, List(("R0", Register("R0", 64))), true)) + val bp2 = BreakPoint("Fibentry", BreakPointLoc.CMD(watch), BreakPointAction(true, true , List(("R0", Register("R0", 64))), true)) // val interp = LayerInterpreter(NormalInterpreter, RememberBreakpoints(NormalInterpreter, List(bp))) // val res = InterpFuns.interpretProg(interp)(fib, (InterpreterState(), List())) val res = interpretWithBreakPoints(fib, List(bp), NormalInterpreter, InterpreterState()) + assert(res._1.nextCmd.isInstanceOf[ErrorStop]) + assert(res._2.nonEmpty) } test("Capture IllegalArg") { From 8f966cedb52704094973bed02c7f6a6165886f07 Mon Sep 17 00:00:00 2001 From: Alistair Michael Date: Wed, 4 Sep 2024 17:45:01 +1000 Subject: [PATCH 45/62] notes on initialisation --- docs/development/interpreter.md | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/docs/development/interpreter.md b/docs/development/interpreter.md index a24738d41..0fb9713f5 100644 --- a/docs/development/interpreter.md +++ b/docs/development/interpreter.md @@ -14,6 +14,7 @@ exited on a non-error stop state. StoreVar(#5,Local,0xfd0:bv64) StoreMem(mem,HashMap(0xfd0:bv64 -> 0xf0:bv8, 0xfd6:bv64 -> 0x0:bv8, 0xfd2:bv64 -> 0x0:bv8, 0xfd3:bv64 -> 0x0:bv8, 0xfd4:bv64 -> 0x0:bv8, 0xfd7:bv64 -> 0x0:bv8, 0xfd5:bv64 -> 0x0:bv8, 0xfd1 ... +[INFO] Interpreter stopped normally. ``` The `--verbose` flag can also be used, which may print interpreter trace events as they are executed, but not this may not correspond to the actual @@ -134,6 +135,27 @@ This means, `Effects` triggered by an inner `Effects[]` instance do not flow bac but only appear from when `Interpreter` above the `ProductInterpreter` interprets the program via effect calls. For this reason if, for example, `NormalInterpreter` makes effect calls they will not appear in a trace emitted by `interptretTrace`. +### Note on memory space initialisation +Most of the interpret functions are overloaded such that there is a version taking a program `interpret(p: Program)`, +and a version taking `IRContext`. The variant taking IRContext uses the ELF symbol information to initialise the +memory before interpretation. If you are interpreting a real program (i.e. not a synthetic example created through +the DSL), this is most likely required. + +We initialise: + +- The general interpreter state, stack and memory regions, stack pointer, a symbolic mapping from addresses functions +- The initial and readonly memory sections stored in Program +- The `.bss` section to zero +- The relocation table. Each listed offset is stored an address to either a real procedure in the program, or a + location storing a symbolic function pointer to an intrinsic function. + +- `.bss` is generally the top of the initialised data, the ELF symbol `__bss_end__` being equal to the symbol `__end__`. + Above this we can somewhat choose arbitrarily where to put things, usually the heap is above, followed by + dynamically linked symbols, then the stack. There is currently no stack overflow checking, or heap implemented in the + interpreter. +- Unfortunately these details are defined by the load-time linker and the system's linker script, and it is hard to find a good description + of their behaviour. Some details are described here https://refspecs.linuxfoundation.org/elf/elf.pdf, and here + https://dl.acm.org/doi/abs/10.1145/2983990.2983996. From f00b44b806ce385468d714653ac78f3ea77cc1c6 Mon Sep 17 00:00:00 2001 From: Alistair Michael Date: Wed, 4 Sep 2024 18:00:53 +1000 Subject: [PATCH 46/62] simplify invoc funcs --- docs/development/interpreter.md | 6 +++--- .../scala/ir/eval/InterpretBreakpoints.scala | 9 +++++++++ src/main/scala/ir/eval/InterpretTrace.scala | 20 +++++++++++++------ src/main/scala/ir/eval/Interpreter.scala | 4 ++-- src/test/scala/ir/InterpreterTests.scala | 2 -- 5 files changed, 28 insertions(+), 13 deletions(-) diff --git a/docs/development/interpreter.md b/docs/development/interpreter.md index 0fb9713f5..6e0764457 100644 --- a/docs/development/interpreter.md +++ b/docs/development/interpreter.md @@ -37,13 +37,13 @@ To see an example of this used to validate the constant prop analysis see [/src/ #### BreakPoints -Finally `interpretWithBreakPoints(p: IRContext, List[BreakPoint], interpreter, initialState)` is used to -run an interpreter, and perform additional work at specific code points. For example, this may be invoked such as: +Finally `interpretBreakPoints(p: IRContext, breakpoints: List[BreakPoint])` is used to +run an interpreter and perform additional actions at specified code points. For example, this may be invoked such as: ```scala val watch = IRWalk.firstInProc((program.procedures.find(_.name == "main")).get).get val bp = BreakPoint("entrypoint", BreakPointLoc.CMD(watch), BreakPointAction(saveState=true, stop=true, evalExprs=List(("R0", Register("R0", 64))), log=true)) -val res = interpretWithBreakPoints(program, List(bp), NormalInterpreter, InterpreterState()) +val res = interpretBreakPoints(program, List(bp)) ``` The structure of a breakpoint is as follows: diff --git a/src/main/scala/ir/eval/InterpretBreakpoints.scala b/src/main/scala/ir/eval/InterpretBreakpoints.scala index e64d65506..c8613dada 100644 --- a/src/main/scala/ir/eval/InterpretBreakpoints.scala +++ b/src/main/scala/ir/eval/InterpretBreakpoints.scala @@ -123,3 +123,12 @@ def interpretWithBreakPoints[I]( val res = InterpFuns.interpretProg(interp)(p, (innerInitialState, List())) res } + +def interpretBreakPoints(p: IRContext, breakpoints: List[BreakPoint]) = { + interpretWithBreakPoints(p, breakpoints, NormalInterpreter, InterpreterState()) +} + + +def interpretBreakPoints(p: Program, breakpoints: List[BreakPoint]) = { + interpretWithBreakPoints(p, breakpoints, NormalInterpreter, InterpreterState()) +} diff --git a/src/main/scala/ir/eval/InterpretTrace.scala b/src/main/scala/ir/eval/InterpretTrace.scala index 5c1c963fd..dbdd6e0b8 100644 --- a/src/main/scala/ir/eval/InterpretTrace.scala +++ b/src/main/scala/ir/eval/InterpretTrace.scala @@ -56,14 +56,22 @@ case class TraceGen[E]() extends NopEffects[Trace, E] { } -def tracingInterpreter = ProductInterpreter(NormalInterpreter, TraceGen()) - -def interpretTrace(p: Program): (InterpreterState, Trace) = { - InterpFuns.interpretProg(tracingInterpreter)(p, (InterpreterState(), Trace(List()))) +def interpretWithTrace[I](p: Program, innerInterpreter: Effects[I, InterpreterError], innerInitialState: I): (I, Trace) = { + val tracingInterpreter = ProductInterpreter(innerInterpreter, TraceGen()) + InterpFuns.interpretProg(tracingInterpreter)(p, (innerInitialState, Trace(List()))) } -def interpretTrace(p: IRContext): (InterpreterState, Trace) = { - val begin = InterpFuns.initProgState(tracingInterpreter)(p, (InterpreterState(), Trace(List()))) +def interpretWithTrace[I](p: IRContext, innerInterpreter: Effects[I, InterpreterError], innerInitialState: I): (I, Trace) = { + val tracingInterpreter = ProductInterpreter(innerInterpreter, TraceGen()) + val begin = InterpFuns.initProgState(tracingInterpreter)(p, (innerInitialState, Trace(List()))) // throw away initialisation trace BASILInterpreter(tracingInterpreter).run((begin._1, Trace(List()))) } + +def interpretTrace(p: Program) = { + interpretWithTrace(p, NormalInterpreter, InterpreterState()) +} + +def interpretTrace(p: IRContext) = { + interpretWithTrace(p, NormalInterpreter, InterpreterState()) +} diff --git a/src/main/scala/ir/eval/Interpreter.scala b/src/main/scala/ir/eval/Interpreter.scala index 70ba4d8ee..108eb34e2 100644 --- a/src/main/scala/ir/eval/Interpreter.scala +++ b/src/main/scala/ir/eval/Interpreter.scala @@ -39,11 +39,11 @@ case class Scalar(val value: Literal) extends BasilValue(value.getType) { } } -/* Slightly hacky way of mapping addresses to function calls within the interpreter dynamic state */ +/* Abstract callable function address */ case class FunPointer(val addr: BitVecLiteral, val name: String, val call: ExecutionContinuation) extends BasilValue(addr.getType) -/* Erase the type of basil values and enforce the invariant that +/* We erase the type of basil values and enforce the invariant that \exists i . \forall v \in value.keys , v.irType = i and \exists j . \forall v \in value.values, v.irType = j */ diff --git a/src/test/scala/ir/InterpreterTests.scala b/src/test/scala/ir/InterpreterTests.scala index bd12b5818..f128702f6 100644 --- a/src/test/scala/ir/InterpreterTests.scala +++ b/src/test/scala/ir/InterpreterTests.scala @@ -424,8 +424,6 @@ class InterpreterTests extends AnyFunSuite with BeforeAndAfter { BreakPointAction(true, true, List(("R0", Register("R0", 64))), true) ) val bp2 = BreakPoint("Fibentry", BreakPointLoc.CMD(watch), BreakPointAction(true, true , List(("R0", Register("R0", 64))), true)) - // val interp = LayerInterpreter(NormalInterpreter, RememberBreakpoints(NormalInterpreter, List(bp))) - // val res = InterpFuns.interpretProg(interp)(fib, (InterpreterState(), List())) val res = interpretWithBreakPoints(fib, List(bp), NormalInterpreter, InterpreterState()) assert(res._1.nextCmd.isInstanceOf[ErrorStop]) assert(res._2.nonEmpty) From ceef23351e1b751c736a7ddd1425350331494bb4 Mon Sep 17 00:00:00 2001 From: Alistair Michael Date: Thu, 5 Sep 2024 10:18:15 +1000 Subject: [PATCH 47/62] note missing features --- docs/development/interpreter.md | 39 +++++++++++++++++++++++++++------ docs/development/readme.md | 1 + docs/readme.md | 1 + 3 files changed, 34 insertions(+), 7 deletions(-) diff --git a/docs/development/interpreter.md b/docs/development/interpreter.md index 6e0764457..2564d7f89 100644 --- a/docs/development/interpreter.md +++ b/docs/development/interpreter.md @@ -76,6 +76,7 @@ To see an example of this used to validate the constant prop analysis see [/src/ (e.g negative divisor, type mismatch) - [ExprEval.scala](../../src/main/scala/ir/eval/ExprEval.scala) - Evaluation of expressions, defined in terms of partial evaluation down to a Literal + - This can also be used to evaluate expressions in static analyses, by passing a function to query variable assignments and memory state from the value domain. - [Interpreter.scala](../../src/main/scala/ir/eval/Interpreter.scala) - Definition of core `Effects[S, E]` and `Interpreter[S, E]` types describing state transitions in the interpreter @@ -150,12 +151,36 @@ We initialise: - The relocation table. Each listed offset is stored an address to either a real procedure in the program, or a location storing a symbolic function pointer to an intrinsic function. -- `.bss` is generally the top of the initialised data, the ELF symbol `__bss_end__` being equal to the symbol `__end__`. - Above this we can somewhat choose arbitrarily where to put things, usually the heap is above, followed by - dynamically linked symbols, then the stack. There is currently no stack overflow checking, or heap implemented in the - interpreter. -- Unfortunately these details are defined by the load-time linker and the system's linker script, and it is hard to find a good description - of their behaviour. Some details are described here https://refspecs.linuxfoundation.org/elf/elf.pdf, and here - https://dl.acm.org/doi/abs/10.1145/2983990.2983996. +`.bss` is generally the top of the initialised data, the ELF symbol `__bss_end__` being equal to the symbol `__end__`. +Above this we can somewhat choose arbitrarily where to put things, usually the heap is above, followed by +dynamically linked symbols, then the stack. There is currently no stack overflow checking, or heap implemented in the +interpreter. + +Unfortunately these details are defined by the load-time linker and the system's linker script, and it is hard to find a good description +of their behaviour. Some details are described here https://refspecs.linuxfoundation.org/elf/elf.pdf, and here +https://dl.acm.org/doi/abs/10.1145/2983990.2983996. + +### Missing features + +There is functionality to implement external function calls via intrinsics written in Scala code, but currently only +basic printf style functions are implemented as no-ops. These can be extended to use a file IO abstraction, where +a memory region is created for each file (e.g. stdout), with a variable to keep track of the current write-point +such that a file write operation stores to the write-point address, and increments it by the size of the store. + +Importantly, an implementation of malloc() and free() is needed, which can implement a simple greedy allocation +algorithm. + +Despite the presence of procedure parameters in the current IR, they are not used for by the boogie translation and +are hence similarly ignored in the interpreter. + +The interpreter's immutable state representation is motivated by the ability to easily implement a sound approach +to non-determinism, e.g. to implement GoTos with guessing and rollback rather than look-ahead. This is more +useful for checking specification constructs than executing real programs. + +Finally, the trace does not clearly distinguish internal vs external calls, or observable +and non-observable behaviour. + +While the interpreter semantics supports memory regions, we do not initialise the memory regions (or the initial memory state) +based on those present in the program, we assume a flat `mem` memory model, possibly with `stack` as well. diff --git a/docs/development/readme.md b/docs/development/readme.md index 9b7de0be3..d0b7fbb75 100644 --- a/docs/development/readme.md +++ b/docs/development/readme.md @@ -5,6 +5,7 @@ - [tool-installation](tool-installation.md) Guide to lifter, etc. tool installation - [scala](scala.md) Advice on Scala programming. - [cfg](cfg.md) Explanation of the old CFG datastructure +- [interpreter](interpreter.md) Explanation of IR interpreter ## Scala diff --git a/docs/readme.md b/docs/readme.md index f7b331b23..3bb6905aa 100644 --- a/docs/readme.md +++ b/docs/readme.md @@ -12,6 +12,7 @@ To get started on development, see [development](development). - [editor-setup](development/editor-setup.md) Guide to basil development in IDEs - [tool-installation](development/tool-installation.md) Guide to lifter, etc. tool installation - [cfg](development/cfg.md) Explanation of the old CFG datastructure + - [interpreter](development/interpreter.md) Explanation of IR interpreter - [basil-ir](basil-ir.md) explanation of BASIL's intermediate representation - [compiler-explorer](compiler-explorer.md) guide to the compiler explorer basil interface - [il-cfg](il-cfg.md) explanation of the IL cfg iterator design From cff6c970584815b036bd132cb148ac959aea2cc9 Mon Sep 17 00:00:00 2001 From: Alistair Michael Date: Thu, 5 Sep 2024 10:37:04 +1000 Subject: [PATCH 48/62] run through all system tests --- src/main/scala/util/functional.scala | 4 +- src/test/scala/DifferentialAnalysis.scala | 64 ++++++++++------------- src/test/scala/SystemTests.scala | 22 +------- src/test/scala/test_util/TestUtil.scala | 22 ++++++++ 4 files changed, 55 insertions(+), 57 deletions(-) diff --git a/src/main/scala/util/functional.scala b/src/main/scala/util/functional.scala index 4db438646..975220686 100644 --- a/src/main/scala/util/functional.scala +++ b/src/main/scala/util/functional.scala @@ -9,8 +9,8 @@ case class State[S, A, E](f: S => (S, Either[E, A])) { def >>(o: State[S,A,E]) = for { _ <- this - _ <- o - } yield (()) + x <- o + } yield (x) def flatMap[B](f: A => State[S, B, E]): State[S, B, E] = State(s => { diff --git a/src/test/scala/DifferentialAnalysis.scala b/src/test/scala/DifferentialAnalysis.scala index 8c40fac56..758dc2221 100644 --- a/src/test/scala/DifferentialAnalysis.scala +++ b/src/test/scala/DifferentialAnalysis.scala @@ -8,6 +8,7 @@ import org.scalatest.funsuite.* import specification.* import util.{BASILConfig, IRLoading, ILLoadingConfig, IRContext, RunUtils, StaticAnalysis, StaticAnalysisConfig, StaticAnalysisContext, BASILResult, Logger, LogLevel, IRTransform} import ir.eval.{interpretTrace, interpret, ExecEffect, Stopped} +import test_util.* import java.io.IOException @@ -44,9 +45,9 @@ class DifferentialAnalysis extends AnyFunSuite { assert(filterEvents(traceInit.t).mkString("\n") == filterEvents(traceRes.t).mkString("\n")) } - def testProgram(testName: String, examplePath: String) = { + def testProgram(testName: String, examplePath: String, suffix: String =".adt") = { - val loading = ILLoadingConfig(inputFile = examplePath + testName + ".adt", + val loading = ILLoadingConfig(inputFile = examplePath + testName + suffix, relfFile = examplePath + testName + ".relf", dumpIL = None, ) @@ -67,17 +68,6 @@ class DifferentialAnalysis extends AnyFunSuite { testProgram(testName, examplePath) } - test("indirect_call_gcc_example") { - val testName = "indirect_call" - val examplePath = System.getProperty("user.dir") + s"/src/test/correct/$testName/gcc/" - testProgram(testName, examplePath) - } - - test("indirect_call_clang_example") { - val testName = "indirect_call" - val examplePath = System.getProperty("user.dir") + s"/src/test/correct/$testName/clang/" - testProgram(testName, examplePath) - } test("jumptable2_example") { val testName = "jumptable2" @@ -85,17 +75,6 @@ class DifferentialAnalysis extends AnyFunSuite { testProgram(testName, examplePath) } - test("jumptable2_gcc_example") { - val testName = "jumptable2" - val examplePath = System.getProperty("user.dir") + s"/src/test/correct/$testName/gcc/" - testProgram(testName, examplePath) - } - - test("jumptable2_clang_example") { - val testName = "jumptable2" - val examplePath = System.getProperty("user.dir") + s"/src/test/correct/$testName/clang/" - testProgram(testName, examplePath) - } test("jumptable_example") { val testName = "jumptable" @@ -109,17 +88,6 @@ class DifferentialAnalysis extends AnyFunSuite { testProgram(testName, examplePath) } - test("functionpointer_gcc_example") { - val testName = "functionpointer" - val examplePath = System.getProperty("user.dir") + s"/src/test/correct/$testName/gcc/" - testProgram(testName, examplePath) - } - - test("functionpointer_clang_example") { - val testName = "functionpointer" - val examplePath = System.getProperty("user.dir") + s"/src/test/correct/$testName/clang/" - testProgram(testName, examplePath) - } test("function_got_example") { @@ -127,4 +95,30 @@ class DifferentialAnalysis extends AnyFunSuite { val examplePath = System.getProperty("user.dir") + s"/examples/$testName/" testProgram(testName, examplePath) } + + + def runSystemTests(): Unit = { + + val path = System.getProperty("user.dir") + s"/src/test/correct/" + val programs: Array[String] = getSubdirectories(path) + + // get all variations of each program + for (p <- programs) { + val programPath = path + "/" + p + val variations = getSubdirectories(programPath) + println(variations.mkString("\n")) + variations.foreach(variation => { + test("analysis_differential:" + p + "/" + variation + ":BAP") { + testProgram(p, path + "/" + p + "/" + variation + "/", suffix=".adt") + } + test("analysis_differential:" + p + "/" + variation + ":GTIRB") { + testProgram(p, path + "/" + p + "/" + variation + "/", suffix=".gts") + } + } + ) + } + } + + + runSystemTests() } diff --git a/src/test/scala/SystemTests.scala b/src/test/scala/SystemTests.scala index 80a0ce4fc..c18d5e40a 100644 --- a/src/test/scala/SystemTests.scala +++ b/src/test/scala/SystemTests.scala @@ -1,5 +1,6 @@ import org.scalatest.funsuite.AnyFunSuite import util.{Logger, PerformanceTimer} +import test_util.* import Numeric.Implicits.* import java.io.{BufferedWriter, File, FileWriter} @@ -215,25 +216,6 @@ trait SystemTests extends AnyFunSuite { true } - /** @param directoryName - * of the parent directory - * @return - * the names all subdirectories of the given parent directory - */ - def getSubdirectories(directoryName: String): Array[String] = { - Option(File(directoryName).listFiles(_.isDirectory)) match { - case None => throw java.io.IOException(s"failed to read directory '$directoryName'") - case Some(subdirs) => subdirs.map(_.getName) - } - } - - def log(text: String, path: String): Unit = { - val writer = BufferedWriter(FileWriter(path, false)) - writer.write(text) - writer.flush() - writer.close() - } - } class SystemTestsBAP extends SystemTests { @@ -343,4 +325,4 @@ def loadHisto() = { val timeValues = res("verifyTime").map(_.toDouble) val histo = histogram(50, Some(800.0, 1000.0))(timeValues.toSeq) println(histoToSvg("test histogram", 500, 300, histo, 800.0, 1000.0)) -} \ No newline at end of file +} diff --git a/src/test/scala/test_util/TestUtil.scala b/src/test/scala/test_util/TestUtil.scala index da2eb7c56..abb243ab0 100644 --- a/src/test/scala/test_util/TestUtil.scala +++ b/src/test/scala/test_util/TestUtil.scala @@ -1,4 +1,5 @@ package test_util +import java.io.{BufferedWriter, File, FileWriter} import ir.{Block, Procedure, Program} import util.{BASILConfig, BASILResult, BoogieGeneratorConfig, ILLoadingConfig, RunUtils, StaticAnalysisConfig} @@ -40,3 +41,24 @@ trait TestUtil { ) } } + + +/** @param directoryName + * of the parent directory + * @return + * the names all subdirectories of the given parent directory + */ +def getSubdirectories(directoryName: String): Array[String] = { + Option(File(directoryName).listFiles(_.isDirectory)) match { + case None => throw java.io.IOException(s"failed to read directory '$directoryName'") + case Some(subdirs) => subdirs.map(_.getName) + } +} + +def log(text: String, path: String): Unit = { + val writer = BufferedWriter(FileWriter(path, false)) + writer.write(text) + writer.flush() + writer.close() +} + From 822c12778e511b62eebae961b20d4972376cffd6 Mon Sep 17 00:00:00 2001 From: Alistair Michael Date: Thu, 5 Sep 2024 11:24:00 +1000 Subject: [PATCH 49/62] add resource limit --- src/main/scala/ir/eval/InterpretBasilIR.scala | 29 ++++++----- src/main/scala/ir/eval/InterpretRLimit.scala | 52 +++++++++++++++++++ src/main/scala/ir/eval/InterpretTrace.scala | 5 +- src/test/scala/DifferentialAnalysis.scala | 14 +++-- 4 files changed, 82 insertions(+), 18 deletions(-) create mode 100644 src/main/scala/ir/eval/InterpretRLimit.scala diff --git a/src/main/scala/ir/eval/InterpretBasilIR.scala b/src/main/scala/ir/eval/InterpretBasilIR.scala index b483431a0..2878a3608 100644 --- a/src/main/scala/ir/eval/InterpretBasilIR.scala +++ b/src/main/scala/ir/eval/InterpretBasilIR.scala @@ -228,20 +228,23 @@ case object Eval { class BASILInterpreter[S](f: Effects[S, InterpreterError]) extends Interpreter[S, InterpreterError](f) { - def interpretOne: State[S, Boolean, InterpreterError] = for { - next <- f.getNext - _ <- State.pure(Logger.debug(s"$next")) - r: Boolean <- (next match { - case CallIntrinsic(tgt) => LibcIntrinsic.intrinsics(tgt)(f).map(_ => true) - case Run(c: Statement) => interpretStatement(f)(c).map(_ => true) - case Run(c: Jump) => interpretJump(f)(c).map(_ => true) - case Stopped() => State.pure(false) - case ErrorStop(e) => State.pure(false) - }) - .flatMapE((e: InterpreterError) => { - f.setNext(ErrorStop(e)).map(_ => false) + def interpretOne: State[S, Boolean, InterpreterError] = { + val next = for { + next <- f.getNext + _ <- State.pure(Logger.debug(s"$next")) + r: Boolean <- (next match { + case CallIntrinsic(tgt) => LibcIntrinsic.intrinsics(tgt)(f).map(_ => true) + case Run(c: Statement) => interpretStatement(f)(c).map(_ => true) + case Run(c: Jump) => interpretJump(f)(c).map(_ => true) + case Stopped() => State.pure(false) + case ErrorStop(e) => State.pure(false) }) - } yield (r) + } yield (r) + + next.flatMapE((e: InterpreterError) => { + f.setNext(ErrorStop(e)).map(_ => false) + }) + } def interpretJump[S, T <: Effects[S, InterpreterError]](f: T)(j: Jump): State[S, Unit, InterpreterError] = { j match { diff --git a/src/main/scala/ir/eval/InterpretRLimit.scala b/src/main/scala/ir/eval/InterpretRLimit.scala new file mode 100644 index 000000000..cdaecccc1 --- /dev/null +++ b/src/main/scala/ir/eval/InterpretRLimit.scala @@ -0,0 +1,52 @@ + +package ir.eval +import ir._ +import ir.eval.BitVectorEval.* +import ir.* +import util.IRContext +import util.Logger +import util.functional.* +import util.functional.State.* +import boogie.Scope +import scala.collection.WithFilter + +import scala.annotation.tailrec +import scala.collection.mutable +import scala.collection.immutable +import scala.util.control.Breaks.{break, breakable} + + +case class EffectsRLimit[T, E, I <: Effects[T, InterpreterError]](val limit: Int) extends NopEffects[(T, Int), InterpreterError] { + + override def getNext :State[(T, Int), ExecutionContinuation, InterpreterError] = { + for { + c : (T, Int) <- State.getS + (is, resources) = c + _ <- if (resources >= limit && limit >= 0) { + State.setError(Errored(s"Resource limit $limit reached")) + } else { + State.modify ((s : (T, Int)) => (s._1, s._2 + 1)) + } + } yield (Stopped()) // thrown away by LayerInterpreter + } +} + +def interpretWithRLimit[I](p: Program, instructionLimit: Int, innerInterpreter: Effects[I, InterpreterError], innerInitialState: I): I = { + val rlimitInterpreter = LayerInterpreter(innerInterpreter, EffectsRLimit(instructionLimit)) + InterpFuns.interpretProg(rlimitInterpreter)(p, (innerInitialState, 0))._1 +} + +def interpretWithRLimit[I](p: IRContext, instructionLimit: Int, innerInterpreter: Effects[I, InterpreterError], innerInitialState: I): I = { + val rlimitInterpreter = LayerInterpreter(innerInterpreter, EffectsRLimit(instructionLimit)) + val begin = InterpFuns.initProgState(rlimitInterpreter)(p, (innerInitialState, 0)) + // throw away initialisation trace + BASILInterpreter(rlimitInterpreter).run((begin._1, 0))._1 +} + +def interpretRLimit(p: Program, instructionLimit: Int) : InterpreterState = { + interpretWithRLimit(p, instructionLimit, NormalInterpreter, InterpreterState()) +} + +def interpretRLimit(p: IRContext, instructionLimit: Int) : InterpreterState = { + interpretWithRLimit(p, instructionLimit, NormalInterpreter, InterpreterState()) +} diff --git a/src/main/scala/ir/eval/InterpretTrace.scala b/src/main/scala/ir/eval/InterpretTrace.scala index dbdd6e0b8..89b0fc976 100644 --- a/src/main/scala/ir/eval/InterpretTrace.scala +++ b/src/main/scala/ir/eval/InterpretTrace.scala @@ -56,9 +56,10 @@ case class TraceGen[E]() extends NopEffects[Trace, E] { } +def tracingInterpreter[I, E](innerInterpreter: Effects[I, E]) = ProductInterpreter(innerInterpreter, TraceGen()) + def interpretWithTrace[I](p: Program, innerInterpreter: Effects[I, InterpreterError], innerInitialState: I): (I, Trace) = { - val tracingInterpreter = ProductInterpreter(innerInterpreter, TraceGen()) - InterpFuns.interpretProg(tracingInterpreter)(p, (innerInitialState, Trace(List()))) + InterpFuns.interpretProg(tracingInterpreter(innerInterpreter))(p, (innerInitialState, Trace(List()))) } def interpretWithTrace[I](p: IRContext, innerInterpreter: Effects[I, InterpreterError], innerInitialState: I): (I, Trace) = { diff --git a/src/test/scala/DifferentialAnalysis.scala b/src/test/scala/DifferentialAnalysis.scala index 758dc2221..966073ab4 100644 --- a/src/test/scala/DifferentialAnalysis.scala +++ b/src/test/scala/DifferentialAnalysis.scala @@ -24,8 +24,17 @@ class DifferentialAnalysis extends AnyFunSuite { Logger.setLevel(LogLevel.ERROR) def diffTest(initial: IRContext, transformed: IRContext) = { - val (initialRes,traceInit) = interpretTrace(initial) - val (result,traceRes) = interpretTrace(transformed) + + val instructionLimit = 100000 + + def interp(p: IRContext) : (InterpreterState, Trace) = { + val interpreter = LayerInterpreter(tracingInterpreter(NormalInterpreter), EffectsRLimit(instructionLimit)) + val initialState = InterpFuns.initProgState(NormalInterpreter)(p, InterpreterState()) + BASILInterpreter(interpreter).run((initialState, Trace(List())), 0)._1 + } + + val (initialRes,traceInit) = interp(initial) + val (result,traceRes) = interp(transformed) def filterEvents(trace: List[ExecEffect]) = { @@ -106,7 +115,6 @@ class DifferentialAnalysis extends AnyFunSuite { for (p <- programs) { val programPath = path + "/" + p val variations = getSubdirectories(programPath) - println(variations.mkString("\n")) variations.foreach(variation => { test("analysis_differential:" + p + "/" + variation + ":BAP") { testProgram(p, path + "/" + p + "/" + variation + "/", suffix=".adt") From 9b8f71563ca0d4c41acbd7f922108b24a9138254 Mon Sep 17 00:00:00 2001 From: Alistair Michael Date: Thu, 5 Sep 2024 11:27:30 +1000 Subject: [PATCH 50/62] doc resource limit --- docs/development/interpreter.md | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/docs/development/interpreter.md b/docs/development/interpreter.md index 2564d7f89..8d6b19169 100644 --- a/docs/development/interpreter.md +++ b/docs/development/interpreter.md @@ -67,6 +67,26 @@ case class BreakPointAction( To see an example of this used to validate the constant prop analysis see [/src/test/scala/InterpretTestConstProp.scala](../../src/test/scala/InterpretTestConstProp.scala). +### Resource Limit + +This kills the interpreter in an error state once a specified instruction count is reached. + +It can be used simply with the function `interptretRLimit`, this automatically ignores the initialisation instructions. + +```scala +def interpretRLimit(p: IRContext, instructionLimit: Int) : InterpreterState +``` + +It can also be combined with other interpreters as shown: + +```scala +def interp(p: IRContext, instructionLimit: Int) : (InterpreterState, Trace) = { + val interpreter = LayerInterpreter(tracingInterpreter(NormalInterpreter), EffectsRLimit(instructionLimit)) + val initialState = InterpFuns.initProgState(NormalInterpreter)(p, InterpreterState()) + BASILInterpreter(interpreter).run((initialState, Trace(List())), 0)._1 +} +``` + ## Implementation / Code Structure ### Summary From 648a41d8c12fae6d8b9c7825006d0233bea81eee Mon Sep 17 00:00:00 2001 From: Alistair Michael Date: Thu, 5 Sep 2024 11:29:46 +1000 Subject: [PATCH 51/62] tweak doc --- docs/development/interpreter.md | 39 +++++++++++++++------------------ 1 file changed, 18 insertions(+), 21 deletions(-) diff --git a/docs/development/interpreter.md b/docs/development/interpreter.md index 8d6b19169..a2854bb73 100644 --- a/docs/development/interpreter.md +++ b/docs/development/interpreter.md @@ -69,7 +69,7 @@ To see an example of this used to validate the constant prop analysis see [/src/ ### Resource Limit -This kills the interpreter in an error state once a specified instruction count is reached. +This kills the interpreter in an error state once a specified instruction count is reached, to avoid the interpreter running forever on infinite loops. It can be used simply with the function `interptretRLimit`, this automatically ignores the initialisation instructions. @@ -109,6 +109,8 @@ def interp(p: IRContext, instructionLimit: Int) : (InterpreterState, Trace) = { - Definition of ELF initialisation in terms of generic `Effects[S, InterpreterError]` - [InterpretBreakpoints.scala](../../src/main/scala/ir/eval/InterpretBreakpoints.scala) - Definition of a generic interpreter with a breakpoint checker layered on top +- [interpretRLimit.scala](../../src/main/scala/ir/eval/InterpretRLimit.scala) + - Definition of layered interpreter which terminates after a specified cycle count - [InterpretTrace.scala](../../src/main/scala/ir/eval/InterpretTrace.scala) - Definition of a generic interpreter which records a trace of calls to the `Effects[]` instance. @@ -182,25 +184,20 @@ https://dl.acm.org/doi/abs/10.1145/2983990.2983996. ### Missing features -There is functionality to implement external function calls via intrinsics written in Scala code, but currently only -basic printf style functions are implemented as no-ops. These can be extended to use a file IO abstraction, where -a memory region is created for each file (e.g. stdout), with a variable to keep track of the current write-point -such that a file write operation stores to the write-point address, and increments it by the size of the store. - -Importantly, an implementation of malloc() and free() is needed, which can implement a simple greedy allocation -algorithm. - -Despite the presence of procedure parameters in the current IR, they are not used for by the boogie translation and -are hence similarly ignored in the interpreter. - -The interpreter's immutable state representation is motivated by the ability to easily implement a sound approach -to non-determinism, e.g. to implement GoTos with guessing and rollback rather than look-ahead. This is more -useful for checking specification constructs than executing real programs. - -Finally, the trace does not clearly distinguish internal vs external calls, or observable -and non-observable behaviour. - -While the interpreter semantics supports memory regions, we do not initialise the memory regions (or the initial memory state) -based on those present in the program, we assume a flat `mem` memory model, possibly with `stack` as well. +- There is functionality to implement external function calls via intrinsics written in Scala code, but currently only + basic printf style functions are implemented as no-ops. These can be extended to use a file IO abstraction, where + a memory region is created for each file (e.g. stdout), with a variable to keep track of the current write-point + such that a file write operation stores to the write-point address, and increments it by the size of the store. + Importantly, an implementation of malloc() and free() is needed, which can implement a simple greedy allocation + algorithm. +- Despite the presence of procedure parameters in the current IR, they are not used for by the boogie translation and + are hence similarly ignored in the interpreter. +- The interpreter's immutable state representation is motivated by the ability to easily implement a sound approach + to non-determinism, e.g. to implement GoTos with guessing and rollback rather than look-ahead. This is more + useful for checking specification constructs than executing real programs, so is not yet implemented. +- The trace does not clearly distinguish internal vs external calls, or observable + and non-observable behaviour. +- While the interpreter semantics supports memory regions, we do not initialise the memory regions (or the initial memory state) + based on those present in the program, we simply assume a flat `mem` and `stack` memory partitioning. From dfa7209d4408e9cf90cfc7bec977f9091faefb60 Mon Sep 17 00:00:00 2001 From: Alistair Michael Date: Thu, 5 Sep 2024 13:34:15 +1000 Subject: [PATCH 52/62] fix --- src/main/scala/util/RunUtils.scala | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/main/scala/util/RunUtils.scala b/src/main/scala/util/RunUtils.scala index 11bc4f016..a4f7f06ea 100644 --- a/src/main/scala/util/RunUtils.scala +++ b/src/main/scala/util/RunUtils.scala @@ -214,7 +214,7 @@ object IRTransform { * add in modifies from the spec. */ def prepareForTranslation(config: ILLoadingConfig, ctx: IRContext): Unit = { - // ctx.program.determineRelevantMemory(ctx.globalOffsets) + ctx.program.determineRelevantMemory(ctx.globalOffsets) Logger.info("[!] Stripping unreachable") val before = ctx.program.procedures.size @@ -223,11 +223,11 @@ object IRTransform { s"[!] Removed ${before - ctx.program.procedures.size} functions (${ctx.program.procedures.size} remaining)" ) - //val stackIdentification = StackSubstituter() - //stackIdentification.visitProgram(ctx.program) + val stackIdentification = StackSubstituter() + stackIdentification.visitProgram(ctx.program) val specModifies = ctx.specification.subroutines.map(s => s.name -> s.modifies).toMap - // ctx.program.setModifies(specModifies) + ctx.program.setModifies(specModifies) assert(invariant.singleCallBlockEnd(ctx.program)) } @@ -506,7 +506,6 @@ object RunUtils { q.loading.dumpIL.foreach(s => writeToFile(serialiseIL(ctx.program), s"$s-after-analysis.il")) if (q.runInterpret) { - // val interpreter = eval.Interpreter() val fs = eval.interpretTrace(ctx) Logger.info("Interpreter Trace:\n" + fs._2.t.mkString("\n")) val stopState = fs._1.nextCmd From 38f871c9ef6e56890c2bb0afdb85789115723056 Mon Sep 17 00:00:00 2001 From: Alistair Michael Date: Thu, 5 Sep 2024 15:38:59 +1000 Subject: [PATCH 53/62] tweak interpretrlimit --- src/main/scala/ir/eval/InterpretRLimit.scala | 12 ++++++------ src/test/scala/ir/InterpreterTests.scala | 10 +++++----- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/src/main/scala/ir/eval/InterpretRLimit.scala b/src/main/scala/ir/eval/InterpretRLimit.scala index cdaecccc1..a381e5d8f 100644 --- a/src/main/scala/ir/eval/InterpretRLimit.scala +++ b/src/main/scala/ir/eval/InterpretRLimit.scala @@ -31,22 +31,22 @@ case class EffectsRLimit[T, E, I <: Effects[T, InterpreterError]](val limit: Int } } -def interpretWithRLimit[I](p: Program, instructionLimit: Int, innerInterpreter: Effects[I, InterpreterError], innerInitialState: I): I = { +def interpretWithRLimit[I](p: Program, instructionLimit: Int, innerInterpreter: Effects[I, InterpreterError], innerInitialState: I): (I, Int) = { val rlimitInterpreter = LayerInterpreter(innerInterpreter, EffectsRLimit(instructionLimit)) - InterpFuns.interpretProg(rlimitInterpreter)(p, (innerInitialState, 0))._1 + InterpFuns.interpretProg(rlimitInterpreter)(p, (innerInitialState, 0)) } -def interpretWithRLimit[I](p: IRContext, instructionLimit: Int, innerInterpreter: Effects[I, InterpreterError], innerInitialState: I): I = { +def interpretWithRLimit[I](p: IRContext, instructionLimit: Int, innerInterpreter: Effects[I, InterpreterError], innerInitialState: I): (I, Int) = { val rlimitInterpreter = LayerInterpreter(innerInterpreter, EffectsRLimit(instructionLimit)) val begin = InterpFuns.initProgState(rlimitInterpreter)(p, (innerInitialState, 0)) // throw away initialisation trace - BASILInterpreter(rlimitInterpreter).run((begin._1, 0))._1 + BASILInterpreter(rlimitInterpreter).run((begin._1, 0)) } -def interpretRLimit(p: Program, instructionLimit: Int) : InterpreterState = { +def interpretRLimit(p: Program, instructionLimit: Int) : (InterpreterState, Int) = { interpretWithRLimit(p, instructionLimit, NormalInterpreter, InterpreterState()) } -def interpretRLimit(p: IRContext, instructionLimit: Int) : InterpreterState = { +def interpretRLimit(p: IRContext, instructionLimit: Int) : (InterpreterState, Int) = { interpretWithRLimit(p, instructionLimit, NormalInterpreter, InterpreterState()) } diff --git a/src/test/scala/ir/InterpreterTests.scala b/src/test/scala/ir/InterpreterTests.scala index f128702f6..e4cc8dbf7 100644 --- a/src/test/scala/ir/InterpreterTests.scala +++ b/src/test/scala/ir/InterpreterTests.scala @@ -379,9 +379,9 @@ class InterpreterTests extends AnyFunSuite with BeforeAndAfter { test("fibonaccistress") { Logger.setLevel(LogLevel.ERROR) - var res = List[(Int, Double, Double)]() + var res = List[(Int, Double, Double, Int)]() - for (i <- 0 to 12) { + for (i <- 0 to 30) { val prog = fibonacciProg(i) val t = PerformanceTimer("native") @@ -389,15 +389,15 @@ class InterpreterTests extends AnyFunSuite with BeforeAndAfter { val native = t.elapsed() val intt = PerformanceTimer("interp") - val ir = interpret(prog) + val ir = interpretRLimit(prog, 100000000) val it = intt.elapsed() - res = (i, native, it) :: res + res = (i, native, it, ir._2) :: res println(s"${res.head}") } - println(("fib number,native time,interp time" :: (res.map(x => s"${x._1},${x._2},${x._3}"))).mkString("\n")) + println(("fib number,native time,interp time,interp cycles" :: (res.map(x => s"${x._1},${x._2},${x._3},${x._4}"))).mkString("\n")) } From 6559180436452efdc560f99d70d58a8cb2d2bd86 Mon Sep 17 00:00:00 2001 From: Alistair Michael Date: Tue, 10 Sep 2024 09:37:44 +1000 Subject: [PATCH 54/62] basic malloc implementation --- src/main/scala/ir/eval/InterpretBasilIR.scala | 7 +- src/main/scala/ir/eval/Interpreter.scala | 64 ++++++++++++++++++- .../scala/ir/eval/InterpreterProduct.scala | 10 +++ src/test/scala/DifferentialAnalysis.scala | 2 +- 4 files changed, 79 insertions(+), 4 deletions(-) diff --git a/src/main/scala/ir/eval/InterpretBasilIR.scala b/src/main/scala/ir/eval/InterpretBasilIR.scala index 2878a3608..9e84335e9 100644 --- a/src/main/scala/ir/eval/InterpretBasilIR.scala +++ b/src/main/scala/ir/eval/InterpretBasilIR.scala @@ -403,7 +403,12 @@ object InterpFuns { )) }) - State.sequence(State.pure(()), stores) + + for { + _ <- State.sequence(State.pure(()), stores) + malloc_top = BitVecLiteral(newAddr() + 1024, 64) + _ <- s.storeVar("ghost_malloc_top", Scope.Global, Scalar(malloc_top)) + } yield (()) } /** Functions which compile BASIL IR down to the minimal interpreter effects. diff --git a/src/main/scala/ir/eval/Interpreter.scala b/src/main/scala/ir/eval/Interpreter.scala index 108eb34e2..4389dac84 100644 --- a/src/main/scala/ir/eval/Interpreter.scala +++ b/src/main/scala/ir/eval/Interpreter.scala @@ -70,6 +70,15 @@ case object BasilValue { case _ => Left((TypeError(s"Operation add $vr undefined on $l"))) } } + + def add[S, E](l: BasilValue, r: BasilValue): Either[InterpreterError, BasilValue] = { + (l,r) match { + case (Scalar(IntLiteral(vl)), Scalar(IntLiteral(vr))) => Right(Scalar(IntLiteral(vl + vr))) + case (Scalar(b1: BitVecLiteral), Scalar(b2: BitVecLiteral)) => Right(Scalar(eval.evalBVBinExpr(BVADD, b1, b2))) + case _ => Left((TypeError(s"Operation add undefined $l + $r"))) + } + } + } /** Minimal language defining all state transitions in the interpreter, defined for the interpreter's concrete state T. @@ -97,6 +106,8 @@ trait Effects[T, E] { */ def call(target: String, beginFrom: ExecutionContinuation, returnTo: ExecutionContinuation): State[T, Unit, E] + def callIntrinsic(name: String, args: List[BasilValue]) : State[T, Unit, E] + def doReturn(): State[T, Unit, E] def storeVar(v: String, scope: Scope, value: BasilValue): State[T, Unit, E] @@ -112,6 +123,7 @@ trait NopEffects[T, E] extends Effects[T, E] { def setNext(c: ExecutionContinuation) = State.pure(()) def call(target: String, beginFrom: ExecutionContinuation, returnTo: ExecutionContinuation) = State.pure(()) + def callIntrinsic(name: String, args: List[BasilValue]) = State.pure(()) def doReturn() = State.pure(()) def storeVar(v: String, scope: Scope, value: BasilValue) = State.pure(()) @@ -295,8 +307,33 @@ object LibcIntrinsic { s.doReturn() } - def intrinsics[S, E, T <: Effects[S, E]] = - Map[String, T => State[S, Unit, E]]("putc" -> putc, "puts" -> puts, "printf" -> printf) + def malloc[S, E, T <: Effects[S, E]](s: T): State[S, Unit, E] = for { + size <- s.loadVar("R0") + res <- s.callIntrinsic("malloc", List(size)) + _ <- s.doReturn() + } yield (()) + + def free[S, E, T <: Effects[S, E]](s: T): State[S, Unit, E] = for { + ptr <- s.loadVar("R0") + res <- s.callIntrinsic("free", List(ptr)) + _ <- s.doReturn() + } yield (()) + + + def calloc[S, T <: Effects[S, InterpreterError]](s: T): State[S, Unit, InterpreterError] = for { + size <- s.loadVar("R0") + res <- s.callIntrinsic("malloc", List(size)) + ptr <- s.loadVar("R0") + isize <- size match { + case Scalar(b: BitVecLiteral) => State.pure(b.value * 8) + case _ => State.setError(Errored("programmer error")) + } + cl <- Eval.storeBV(s)("mem", ptr, BitVecLiteral(0, isize.toInt), Endian.LittleEndian) + _ <- s.doReturn() + } yield (()) + + def intrinsics[S, T <: Effects[S, InterpreterError]] = + Map[String, T => State[S, Unit, InterpreterError]]("putc" -> putc, "puts" -> puts, "printf" -> printf, "malloc" -> malloc, "free" -> free, "#free" -> free, "calloc" -> calloc) } @@ -310,6 +347,29 @@ case class InterpreterState( */ object NormalInterpreter extends Effects[InterpreterState, InterpreterError] { + def callIntrinsic(name: String, args: List[BasilValue]) = { + name match { + case "malloc" => { + for { + size <- (args.headOption match { + case Some(x @ Scalar(_: BitVecLiteral)) => State.pure(x) + case Some(Scalar(x: IntLiteral)) => State.pure(Scalar(BitVecLiteral(x.value, 64))) + case _ => State.setError(Errored("illegal prim arg")) + }) + x <- loadVar("ghost_malloc_top") + x_gap <- State.pureE(BasilValue.unsafeAdd(x, 128)) // put a gap around allocations to catch buffer overflows + x_end <- State.pureE(BasilValue.add(x_gap, size)) + _ <- storeVar("ghost_malloc_top", Scope.Global, x_end) + _ <- storeVar("R0", Scope.Global, x_gap) + } yield (()) + } + case "free" => { + State.pure(()) + } + case _ => State.setError(Errored(s"Call undefined intrinsic $name")) + } + } + def loadVar(v: String) = { State.getE((s: InterpreterState) => { s.memoryState.getVar(v) diff --git a/src/main/scala/ir/eval/InterpreterProduct.scala b/src/main/scala/ir/eval/InterpreterProduct.scala index ae8e51252..37e1820a4 100644 --- a/src/main/scala/ir/eval/InterpreterProduct.scala +++ b/src/main/scala/ir/eval/InterpreterProduct.scala @@ -62,6 +62,11 @@ case class ProductInterpreter[L, T, E](val inner: Effects[L, E], val before: Eff f <- doLeft(inner.call(target, beginFrom, returnTo)) } yield (f) + def callIntrinsic(name: String, args: List[BasilValue]) = for { + n <- doRight(before.callIntrinsic(name, args)) + f <- doLeft(inner.callIntrinsic(name, args)) + } yield (f) + def doReturn() = for { n <- doRight(before.doReturn()) f <- doLeft(inner.doReturn()) @@ -112,6 +117,11 @@ case class LayerInterpreter[L, T, E](val inner: Effects[L, E], val before: Effec f <- doLeft(inner.call(target, beginFrom, returnTo)) } yield (f) + def callIntrinsic(name: String, args: List[BasilValue]) = for { + n <- before.callIntrinsic(name, args) + f <- doLeft(inner.callIntrinsic(name, args)) + } yield (f) + def doReturn() = for { n <- (before.doReturn()) f <- doLeft(inner.doReturn()) diff --git a/src/test/scala/DifferentialAnalysis.scala b/src/test/scala/DifferentialAnalysis.scala index 966073ab4..7fd865957 100644 --- a/src/test/scala/DifferentialAnalysis.scala +++ b/src/test/scala/DifferentialAnalysis.scala @@ -25,7 +25,7 @@ class DifferentialAnalysis extends AnyFunSuite { def diffTest(initial: IRContext, transformed: IRContext) = { - val instructionLimit = 100000 + val instructionLimit = 1000000 def interp(p: IRContext) : (InterpreterState, Trace) = { val interpreter = LayerInterpreter(tracingInterpreter(NormalInterpreter), EffectsRLimit(instructionLimit)) From c6e938d34396a32ab341014865cce171903058b7 Mon Sep 17 00:00:00 2001 From: Alistair Michael Date: Tue, 10 Sep 2024 17:35:01 +1000 Subject: [PATCH 55/62] implement printf --- src/main/scala/ir/eval/InterpretBasilIR.scala | 93 +++++--- src/main/scala/ir/eval/Interpreter.scala | 201 +++++++++++++----- src/test/scala/DifferentialAnalysis.scala | 19 +- 3 files changed, 220 insertions(+), 93 deletions(-) diff --git a/src/main/scala/ir/eval/InterpretBasilIR.scala b/src/main/scala/ir/eval/InterpretBasilIR.scala index 9e84335e9..62b626f5d 100644 --- a/src/main/scala/ir/eval/InterpretBasilIR.scala +++ b/src/main/scala/ir/eval/InterpretBasilIR.scala @@ -125,8 +125,8 @@ case object Eval { )(vname: String, addr: Scalar, endian: Endian, size: Int): State[S, BitVecLiteral, InterpreterError] = for { mem <- f.loadVar(vname) x <- mem match { - case mapv @ MapValue(_, MapType(_, BitVecType(sz))) => State.pure((sz, mapv)) - case _ => State.setError((Errored("Trued to load-concat non bv"))) + case mapv @ BasilMapValue(_, MapType(_, BitVecType(sz))) => State.pure((sz, mapv)) + case _ => State.setError((Errored("Trued to load-concat non bv"))) } (valsize, mapv) = x @@ -169,7 +169,8 @@ case object Eval { ): State[S, Unit, InterpreterError] = for { mem <- f.loadVar(vname) x <- mem match { - case m @ MapValue(_, MapType(kt, vt)) if kt == addr.irType && values.forall(v => v.irType == vt) => + case m @ BasilMapValue(_, MapType(kt, vt)) + if Some(kt) == addr.irType && values.forall(v => v.irType == Some(vt)) => State.pure((m, kt, vt)) case v => State.setError((TypeError(s"Invalid map store operation to $vname : $v"))) } @@ -191,7 +192,7 @@ case object Eval { ): State[S, Unit, InterpreterError] = for { mem <- f.loadVar(vname) mr <- mem match { - case m @ MapValue(_, MapType(kt, BitVecType(size))) if kt == addr.irType => State.pure((m, size)) + case m @ BasilMapValue(_, MapType(kt, BitVecType(size))) if Some(kt) == addr.irType => State.pure((m, size)) case v => State.setError( TypeError( @@ -224,6 +225,28 @@ case object Eval { )(vname: String, addr: BasilValue, value: BasilValue): State[S, Unit, E] = { f.storeMem(vname, Map((addr -> value))) } + + /** Helper functions * */ + + def getNullTerminatedString[S, T <: Effects[S, InterpreterError]](f: T) + (rgn: String, src: BasilValue, acc: List[BitVecLiteral] = List()): State[S, List[BitVecLiteral], InterpreterError] = + for { + srv: BitVecLiteral <- src match { + case Scalar(b: BitVecLiteral) => State.pure(b) + case _ => State.setError(Errored(s"Not pointer : $src")) + } + c <- f.loadMem(rgn, List(src)) + res <- c.head match { + case Scalar(BitVecLiteral(0, 8)) => State.pure(acc) + case Scalar(b: BitVecLiteral) => { + for { + nsrc <- State.pureE(BasilValue.unsafeAdd(src, 1)) + r <- getNullTerminatedString(f)(rgn, nsrc, acc.appended(b)) + } yield (r) + } + case _ => State.setError(Errored(s"not byte $c")) + } + } yield (res) } class BASILInterpreter[S](f: Effects[S, InterpreterError]) extends Interpreter[S, InterpreterError](f) { @@ -233,11 +256,11 @@ class BASILInterpreter[S](f: Effects[S, InterpreterError]) extends Interpreter[S next <- f.getNext _ <- State.pure(Logger.debug(s"$next")) r: Boolean <- (next match { - case CallIntrinsic(tgt) => LibcIntrinsic.intrinsics(tgt)(f).map(_ => true) - case Run(c: Statement) => interpretStatement(f)(c).map(_ => true) - case Run(c: Jump) => interpretJump(f)(c).map(_ => true) - case Stopped() => State.pure(false) - case ErrorStop(e) => State.pure(false) + case Intrinsic(tgt) => LibcIntrinsic.intrinsics(tgt)(f).map(_ => true) + case Run(c: Statement) => interpretStatement(f)(c).map(_ => true) + case Run(c: Jump) => interpretJump(f)(c).map(_ => true) + case Stopped() => State.pure(false) + case ErrorStop(e) => State.pure(false) }) } yield (r) @@ -311,18 +334,16 @@ class BASILInterpreter[S](f: Effects[S, InterpreterError]) extends Interpreter[S f.setNext(Run(s.successor)) }) } yield (n) - case dc: DirectCall => - for { - n <- - if (dc.target.entryBlock.isDefined) { - val block = dc.target.entryBlock.get - f.call(dc.target.name, Run(block.statements.headOption.getOrElse(block.jump)), Run(dc.successor)) - } else if (LibcIntrinsic.intrinsics.contains(dc.target.name)) { - f.call(dc.target.name, CallIntrinsic(dc.target.name), Run(dc.successor)) - } else { - State.setError(EscapedControlFlow(dc)) - } - } yield (n) + case dc: DirectCall => { + if (dc.target.entryBlock.isDefined) { + val block = dc.target.entryBlock.get + f.call(dc.target.name, Run(block.statements.headOption.getOrElse(block.jump)), Run(dc.successor)) + } else if (LibcIntrinsic.intrinsics.contains(dc.target.name)) { + f.call(dc.target.name, Intrinsic(dc.target.name), Run(dc.successor)) + } else { + State.setError(EscapedControlFlow(dc)) + } + } case ic: IndirectCall => { if (ic.target == Register("R30", 64)) { f.doReturn() @@ -360,7 +381,7 @@ object InterpFuns { for ((fname, fun) <- LibcIntrinsic.intrinsics) { val name = fname.takeWhile(c => c != '@') - x = (name, FunPointer(BitVecLiteral(newAddr(), 64), name, CallIntrinsic(name))) :: x + x = (name, FunPointer(BitVecLiteral(newAddr(), 64), name, Intrinsic(name))) :: x } val intrinsics = x.toMap @@ -403,12 +424,11 @@ object InterpFuns { )) }) - - for { - _ <- State.sequence(State.pure(()), stores) - malloc_top = BitVecLiteral(newAddr() + 1024, 64) - _ <- s.storeVar("ghost_malloc_top", Scope.Global, Scalar(malloc_top)) - } yield (()) + for { + _ <- State.sequence(State.pure(()), stores) + malloc_top = BitVecLiteral(newAddr() + 1024, 64) + _ <- s.storeVar("ghost_malloc_top", Scope.Global, Scalar(malloc_top)) + } yield (()) } /** Functions which compile BASIL IR down to the minimal interpreter effects. @@ -423,14 +443,19 @@ object InterpFuns { for { h <- State.pure(Logger.debug("DEFINE MEMORY REGIONS")) - h <- s.storeVar("ghost-funtable", Scope.Global, MapValue(Map.empty, MapType(BitVecType(64), BitVecType(64)))) - h <- s.storeVar("mem", Scope.Global, MapValue(Map.empty, MapType(BitVecType(64), BitVecType(8)))) - i <- s.storeVar("stack", Scope.Global, MapValue(Map.empty, MapType(BitVecType(64), BitVecType(8)))) + h <- s.storeVar("mem", Scope.Global, BasilMapValue(Map.empty, MapType(BitVecType(64), BitVecType(8)))) + i <- s.storeVar("stack", Scope.Global, BasilMapValue(Map.empty, MapType(BitVecType(64), BitVecType(8)))) j <- s.storeVar("R31", Scope.Global, Scalar(SP)) k <- s.storeVar("R29", Scope.Global, Scalar(FP)) l <- s.storeVar("R30", Scope.Global, Scalar(LR)) l <- s.storeVar("R0", Scope.Global, Scalar(BitVecLiteral(0, 64))) l <- s.storeVar("R1", Scope.Global, Scalar(BitVecLiteral(0, 64))) + _ <- s.storeVar("ghost-funtable", Scope.Global, BasilMapValue(Map.empty, MapType(BitVecType(64), BitVecType(64)))) + _ <- s.storeVar("ghost-file-bookkeeping", Scope.Global, GenMapValue(Map.empty)) + _ <- s.storeVar("ghost-fd-mapping", Scope.Global, GenMapValue(Map.empty)) + _ <- s.storeMem("ghost-file-bookkeeping", Map(Symbol("$$filecount") -> Scalar(BitVecLiteral(0, 64)))) + _ <- s.callIntrinsic("fopen", List(Symbol("stderr"))) + _ <- s.callIntrinsic("fopen", List(Symbol("stdout"))) } yield (l) } @@ -500,7 +525,11 @@ object InterpFuns { def initProgState[S, T <: Effects[S, InterpreterError]](f: T)(p: IRContext, is: S): S = { val st = (initialiseProgram(f)(p.program) >> initBSS(f)(p)) >> InterpFuns.initRelocTable(f)(p) - State.execute(is, st) + val (fs, v) = st.f(is) + v match { + case Right(r) => fs + case Left(e) => throw Exception(s"Init failed $e") + } } /* Intialise from ELF and Interpret program */ diff --git a/src/main/scala/ir/eval/Interpreter.scala b/src/main/scala/ir/eval/Interpreter.scala index 4389dac84..964b6d3f8 100644 --- a/src/main/scala/ir/eval/Interpreter.scala +++ b/src/main/scala/ir/eval/Interpreter.scala @@ -16,9 +16,9 @@ import scala.util.control.Breaks.{break, breakable} */ sealed trait ExecutionContinuation case class Stopped() extends ExecutionContinuation /* normal program stop */ -case class ErrorStop(error: InterpreterError) extends ExecutionContinuation /* normal program stop */ +case class ErrorStop(error: InterpreterError) extends ExecutionContinuation /* program stop in error state */ case class Run(val next: Command) extends ExecutionContinuation /* continue by executing next command */ -case class CallIntrinsic(val name: String) extends ExecutionContinuation /* continue by executing next command */ +case class Intrinsic(val name: String) extends ExecutionContinuation /* a named intrinsic instruction */ sealed trait InterpreterError case class FailedAssertion(a: Assert) extends InterpreterError @@ -31,8 +31,8 @@ case class EvalError(val message: String = "") case class MemoryError(val message: String = "") extends InterpreterError /* An error to do with memory */ /* Concrete value type of the interpreter. */ -sealed trait BasilValue(val irType: IRType) -case class Scalar(val value: Literal) extends BasilValue(value.getType) { +sealed trait BasilValue(val irType: Option[IRType]) +case class Scalar(val value: Literal) extends BasilValue(Some(value.getType)) { override def toString = value match { case b: BitVecLiteral => "0x%x:bv%d".format(b.value, b.size) case c => c.toString @@ -41,16 +41,28 @@ case class Scalar(val value: Literal) extends BasilValue(value.getType) { /* Abstract callable function address */ case class FunPointer(val addr: BitVecLiteral, val name: String, val call: ExecutionContinuation) - extends BasilValue(addr.getType) + extends BasilValue(Some(addr.getType)) + +sealed trait MapValue { + def value: Map[BasilValue, BasilValue] +} /* We erase the type of basil values and enforce the invariant that \exists i . \forall v \in value.keys , v.irType = i and \exists j . \forall v \in value.values, v.irType = j */ -case class MapValue(val value: Map[BasilValue, BasilValue], override val irType: MapType) extends BasilValue(irType) { +case class BasilMapValue(val value: Map[BasilValue, BasilValue], val mapType: MapType) + extends MapValue + with BasilValue(Some(mapType)) { override def toString = s"MapValue : $irType" } +case class GenMapValue(val value: Map[BasilValue, BasilValue]) extends BasilValue(None) with MapValue { + override def toString = s"GenMapValue : $irType" +} + +case class Symbol(val value: String) extends BasilValue(None) + case object BasilValue { def size(v: IRType): Int = { @@ -60,7 +72,12 @@ case object BasilValue { } } - def size(v: BasilValue): Int = size(v.irType) + def toBV[S, E](l: BasilValue): Either[InterpreterError, BitVecLiteral] = { + l match { + case Scalar(b1: BitVecLiteral) => Right(b1) + case _ => Left((TypeError(s"Not a bitvector add $l"))) + } + } def unsafeAdd[S, E](l: BasilValue, vr: Int): Either[InterpreterError, BasilValue] = { l match { @@ -72,10 +89,10 @@ case object BasilValue { } def add[S, E](l: BasilValue, r: BasilValue): Either[InterpreterError, BasilValue] = { - (l,r) match { - case (Scalar(IntLiteral(vl)), Scalar(IntLiteral(vr))) => Right(Scalar(IntLiteral(vl + vr))) + (l, r) match { + case (Scalar(IntLiteral(vl)), Scalar(IntLiteral(vr))) => Right(Scalar(IntLiteral(vl + vr))) case (Scalar(b1: BitVecLiteral), Scalar(b2: BitVecLiteral)) => Right(Scalar(eval.evalBVBinExpr(BVADD, b1, b2))) - case _ => Left((TypeError(s"Operation add undefined $l + $r"))) + case _ => Left((TypeError(s"Operation add undefined $l + $r"))) } } @@ -106,7 +123,7 @@ trait Effects[T, E] { */ def call(target: String, beginFrom: ExecutionContinuation, returnTo: ExecutionContinuation): State[T, Unit, E] - def callIntrinsic(name: String, args: List[BasilValue]) : State[T, Unit, E] + def callIntrinsic(name: String, args: List[BasilValue]): State[T, Option[BasilValue], E] def doReturn(): State[T, Unit, E] @@ -123,7 +140,7 @@ trait NopEffects[T, E] extends Effects[T, E] { def setNext(c: ExecutionContinuation) = State.pure(()) def call(target: String, beginFrom: ExecutionContinuation, returnTo: ExecutionContinuation) = State.pure(()) - def callIntrinsic(name: String, args: List[BasilValue]) = State.pure(()) + def callIntrinsic(name: String, args: List[BasilValue]) = State.pure(None) def doReturn() = State.pure(()) def storeVar(v: String, scope: Scope, value: BasilValue) = State.pure(()) @@ -157,7 +174,7 @@ case class MemoryState( def getMem(name: String): Map[BitVecLiteral, BitVecLiteral] = { stackFrames(globalFrame)(name) match { - case MapValue(innerMap, MapType(BitVecType(ks), BitVecType(vs))) => { + case BasilMapValue(innerMap, MapType(BitVecType(ks), BitVecType(vs))) => { def unwrap(v: BasilValue): BitVecLiteral = v match { case Scalar(b: BitVecLiteral) => b case v => throw Exception(s"Failed to convert map value to bitvector: $v (interpreter type error somewhere)") @@ -241,7 +258,7 @@ case class MemoryState( def getVar(v: Variable): Either[InterpreterError, BasilValue] = { val value = getVar(v.name) value match { - case Right(dv: BasilValue) if v.getType != dv.irType => + case Right(dv: BasilValue) if Some(v.getType) != dv.irType => Left(Errored(s"Type mismatch on variable definition and load: defined ${dv.irType}, variable ${v.getType}")) case Right(o) => Right(o) case o => o @@ -252,8 +269,8 @@ case class MemoryState( def doLoad(vname: String, addr: List[BasilValue]): Either[InterpreterError, List[BasilValue]] = for { v <- findVar(vname) mapv: MapValue <- v._2 match { - case m @ MapValue(innerMap, ty) => Right(m) - case m => Left((TypeError(s"Load from nonmap ${m.irType}"))) + case m: MapValue => Right(m) + case m => Left((TypeError(s"Load from nonmap ${m.irType}"))) } rs: List[Option[BasilValue]] = addr.map(k => mapv.value.get(k)) xs <- @@ -266,46 +283,59 @@ case class MemoryState( /** typecheck and some fields of a map variable */ def doStore(vname: String, values: Map[BasilValue, BasilValue]): Either[InterpreterError, MemoryState] = for { - // val (frame, mem) = findVar(vname) - + _ <- if (values.size == 0) then Left(MemoryError("Tried to store size 0")) else Right(()) v <- findVar(vname) (frame, mem) = v - // val (mapval, keytype, valtype) = - mapi <- mem match { - case m @ MapValue(_, MapType(kt, vt)) => Right((m, kt, vt)) - case v => Left((TypeError(s"Invalid map store operation to $vname : ${v.irType}"))) - } - (mapval, keytype, valtype) = mapi - - checkTypes <- (values.find((k, v) => k.irType != keytype || v.irType != valtype)) match { - case Some(v) => - Left( - TypeError( - s"Invalid addr or value type (${v._1.irType}, ${v._2.irType}) does not match map type $vname : ($keytype, $valtype)" - ) - ) - case None => Right(()) + mapval <- mem match { + case m @ BasilMapValue(_, MapType(kt, vt)) => + for { + m <- (values.find((k, v) => k.irType != Some(kt) || v.irType != Some(vt))) match { + case Some(v) => + Left( + TypeError( + s"Invalid addr or value type (${v._1.irType}, ${v._2.irType}) does not match map type $vname : ($kt, $vt)" + ) + ) + case None => Right(m) + } + nm = BasilMapValue(m.value ++ values, m.mapType) + } yield (nm) + case m @ GenMapValue(_) => { + Right(GenMapValue(m.value ++ values)) + } + case v => Left((TypeError(s"Invalid map store operation to $vname : ${v.irType}"))) } - nmap = MapValue(mapval.value ++ values, mapval.irType) - ms <- Right(setVar(frame, vname, nmap)) + ms <- Right(setVar(frame, vname, mapval)) } yield (ms) } object LibcIntrinsic { - def putc[S, E, T <: Effects[S, E]](s: T): State[S, Unit, E] = { - s.doReturn() - } + def putc[S, E, T <: Effects[S, E]](s: T): State[S, Unit, E] = for { + c <- s.loadVar("R0") + _ <- s.callIntrinsic("putc", List(c)) + _ <- s.doReturn() + } yield (()) - def puts[S, E, T <: Effects[S, E]](s: T): State[S, Unit, E] = { - s.doReturn() - } + def puts[S, E, T <: Effects[S, E]](s: T): State[S, Unit, E] = for { + dstr <- s.loadVar("R0") + _ <- s.callIntrinsic("puts", List(dstr)) + _ <- s.doReturn() + } yield (()) - def printf[S, E, T <: Effects[S, E]](s: T): State[S, Unit, E] = { - s.doReturn() - } + def strlen[S, E, T <: Effects[S, E]](s: T): State[S, Unit, E] = for { + dstr <- s.loadVar("R0") + _ <- s.callIntrinsic("strlen", List(dstr)) + _ <- s.doReturn() + } yield (()) + + def printf[S, E, T <: Effects[S, E]](s: T): State[S, Unit, E] = for { + dstr <- s.loadVar("R0") + _ <- s.callIntrinsic("print", List(dstr)) + _ <- s.doReturn() + } yield (()) def malloc[S, E, T <: Effects[S, E]](s: T): State[S, Unit, E] = for { size <- s.loadVar("R0") @@ -319,21 +349,28 @@ object LibcIntrinsic { _ <- s.doReturn() } yield (()) - def calloc[S, T <: Effects[S, InterpreterError]](s: T): State[S, Unit, InterpreterError] = for { size <- s.loadVar("R0") res <- s.callIntrinsic("malloc", List(size)) ptr <- s.loadVar("R0") isize <- size match { case Scalar(b: BitVecLiteral) => State.pure(b.value * 8) - case _ => State.setError(Errored("programmer error")) + case _ => State.setError(Errored("programmer error")) } cl <- Eval.storeBV(s)("mem", ptr, BitVecLiteral(0, isize.toInt), Endian.LittleEndian) _ <- s.doReturn() } yield (()) def intrinsics[S, T <: Effects[S, InterpreterError]] = - Map[String, T => State[S, Unit, InterpreterError]]("putc" -> putc, "puts" -> puts, "printf" -> printf, "malloc" -> malloc, "free" -> free, "#free" -> free, "calloc" -> calloc) + Map[String, T => State[S, Unit, InterpreterError]]( + "putc" -> putc, + "puts" -> puts, + "printf" -> printf, + "malloc" -> malloc, + "free" -> free, + "#free" -> free, + "calloc" -> calloc + ) } @@ -347,25 +384,79 @@ case class InterpreterState( */ object NormalInterpreter extends Effects[InterpreterState, InterpreterError] { - def callIntrinsic(name: String, args: List[BasilValue]) = { + def putc(arg: BasilValue): State[InterpreterState, Option[BasilValue], InterpreterError] = { + for { + addr <- loadMem("ghost-file-bookkeeping", List(Symbol("stdout-ptr"))) + byte <- State.pureE(BasilValue.toBV(arg)) + c <- Eval.evalBV(this)(Extract(8, 0, byte)) + _ <- storeMem("stdout", Map(addr.head -> Scalar(c))) + naddr <- State.pureE(BasilValue.unsafeAdd(addr.head, 1)) + _ <- storeMem("ghost-file-bookkeeping", Map(Symbol("stdout-ptr") -> naddr)) + } yield (None) + } + + def fopen(file: BasilValue): State[InterpreterState, Option[BasilValue], InterpreterError] = { + for { + fname <- file match { + case Symbol(name) => State.pure(name) + case _ => State.setError(Errored("Intrinsic fopen open not given filename")) + } + _ <- storeVar(fname, Scope.Global, BasilMapValue(Map.empty, MapType(BitVecType(64), BitVecType(8)))) + filecount <- loadMem("ghost-file-bookkeeping", List(Symbol("$$filecount"))) + nfilecount <- State.pureE(BasilValue.unsafeAdd(filecount.head, 1)) + _ <- storeMem("ghost-file-bookkeeping", Map(Symbol("$$filecount") -> nfilecount)) + _ <- storeMem("ghost-file-bookkeeping", Map(Symbol(fname + "-ptr") -> Scalar(BitVecLiteral(0, 64)))) + _ <- storeMem("ghost-fd-mapping", Map(nfilecount -> Symbol(fname + "-ptr"))) + _ <- storeVar("R0", Scope.Global, nfilecount) + } yield (Some(nfilecount)) + } + + def print(strptr: BasilValue): State[InterpreterState, Option[BasilValue], InterpreterError] = { + for { + str <- Eval.getNullTerminatedString(this)("mem", strptr) + baseptr: List[BasilValue] <- loadMem("ghost-file-bookkeeping", List(Symbol("stdout-ptr"))) + offs: List[BasilValue] <- State.mapM( + ((i: Int) => State.pureE(BasilValue.unsafeAdd(baseptr.head, i))), + (0 until (str.size + 1)) + ) + _ <- storeMem("stdout", offs.zip(str.map(Scalar(_))).toMap) + naddr <- State.pureE(BasilValue.unsafeAdd(baseptr.head, str.size)) + _ <- storeMem("ghost-file-bookkeeping", Map(Symbol("stdout-ptr") -> naddr)) + } yield (None) + } + + def callIntrinsic( + name: String, + args: List[BasilValue] + ): State[InterpreterState, Option[BasilValue], InterpreterError] = { name match { + case "free" => { + State.pure(None) + } case "malloc" => { for { size <- (args.headOption match { case Some(x @ Scalar(_: BitVecLiteral)) => State.pure(x) - case Some(Scalar(x: IntLiteral)) => State.pure(Scalar(BitVecLiteral(x.value, 64))) - case _ => State.setError(Errored("illegal prim arg")) + case Some(Scalar(x: IntLiteral)) => State.pure(Scalar(BitVecLiteral(x.value, 64))) + case _ => State.setError(Errored("illegal prim arg")) }) x <- loadVar("ghost_malloc_top") x_gap <- State.pureE(BasilValue.unsafeAdd(x, 128)) // put a gap around allocations to catch buffer overflows x_end <- State.pureE(BasilValue.add(x_gap, size)) _ <- storeVar("ghost_malloc_top", Scope.Global, x_end) _ <- storeVar("R0", Scope.Global, x_gap) - } yield (()) - } - case "free" => { - State.pure(()) + } yield (Some(x_gap)) } + case "fopen" => fopen(args.head) + case "putc" => putc(args.head) + case "strlen" => + for { + str <- Eval.getNullTerminatedString(this)("mem", args.head) + r = Scalar(BitVecLiteral(str.length, 64)) + _ <- storeVar("R0", Scope.Global, r) + } yield (Some(r)) + case "print" => print(args.head) + case "puts" => print(args.head) >> putc(Scalar(BitVecLiteral('\n'.toInt, 64))) case _ => State.setError(Errored(s"Call undefined intrinsic $name")) } } @@ -487,7 +578,7 @@ trait Interpreter[S, E](val f: Effects[S, E]) { @tailrec final def run(begin: S): S = { - val (fs,cont) = interpretOne.f(begin) + val (fs, cont) = interpretOne.f(begin) if (cont.contains(true)) then { run(fs) diff --git a/src/test/scala/DifferentialAnalysis.scala b/src/test/scala/DifferentialAnalysis.scala index 7fd865957..3ef3c7db1 100644 --- a/src/test/scala/DifferentialAnalysis.scala +++ b/src/test/scala/DifferentialAnalysis.scala @@ -21,7 +21,7 @@ import scala.collection.mutable class DifferentialAnalysis extends AnyFunSuite { - Logger.setLevel(LogLevel.ERROR) + Logger.setLevel(LogLevel.WARN) def diffTest(initial: IRContext, transformed: IRContext) = { @@ -30,13 +30,15 @@ class DifferentialAnalysis extends AnyFunSuite { def interp(p: IRContext) : (InterpreterState, Trace) = { val interpreter = LayerInterpreter(tracingInterpreter(NormalInterpreter), EffectsRLimit(instructionLimit)) val initialState = InterpFuns.initProgState(NormalInterpreter)(p, InterpreterState()) - BASILInterpreter(interpreter).run((initialState, Trace(List())), 0)._1 + //Logger.setLevel(LogLevel.DEBUG) + val r = BASILInterpreter(interpreter).run((initialState, Trace(List())), 0)._1 + //Logger.setLevel(LogLevel.WARN) + r } val (initialRes,traceInit) = interp(initial) val (result,traceRes) = interp(transformed) - def filterEvents(trace: List[ExecEffect]) = { trace.collect { case e @ ExecEffect.Call(_, _, _) => e @@ -46,6 +48,11 @@ class DifferentialAnalysis extends AnyFunSuite { } Logger.info(traceInit.t.map(_.toString.take(80)).mkString("\n")) + val initstdout = initialRes.memoryState.getMem("stdout").toList.sortBy(_._1.value).map(_._2.value.toChar).mkString("") + val comparstdout = result.memoryState.getMem("stdout").toList.sortBy(_._1.value).map(_._2.value.toChar).mkString("") + info("STDOUT: \"" + initstdout + "\"") + // Logger.info(initialRes.memoryState.getMem("stderr").toList.sortBy(_._1.value).map(_._2).mkString("")) + assert(initstdout == comparstdout) assert(initialRes.nextCmd == Stopped()) assert(result.nextCmd == Stopped()) assert(Set.empty == initialRes.memoryState.getMem("mem").toSet.diff(result.memoryState.getMem("mem").toSet)) @@ -119,9 +126,9 @@ class DifferentialAnalysis extends AnyFunSuite { test("analysis_differential:" + p + "/" + variation + ":BAP") { testProgram(p, path + "/" + p + "/" + variation + "/", suffix=".adt") } - test("analysis_differential:" + p + "/" + variation + ":GTIRB") { - testProgram(p, path + "/" + p + "/" + variation + "/", suffix=".gts") - } + //test("analysis_differential:" + p + "/" + variation + ":GTIRB") { + // testProgram(p, path + "/" + p + "/" + variation + "/", suffix=".gts") + //} } ) } From 285099b31fea5592194dfcf45f38a7b0faf742c6 Mon Sep 17 00:00:00 2001 From: Alistair Michael Date: Tue, 10 Sep 2024 18:05:51 +1000 Subject: [PATCH 56/62] cleanup intrins --- src/main/scala/ir/eval/InterpretBasilIR.scala | 10 +- src/main/scala/ir/eval/Interpreter.scala | 167 +++++++++--------- 2 files changed, 85 insertions(+), 92 deletions(-) diff --git a/src/main/scala/ir/eval/InterpretBasilIR.scala b/src/main/scala/ir/eval/InterpretBasilIR.scala index 62b626f5d..98f1a9f25 100644 --- a/src/main/scala/ir/eval/InterpretBasilIR.scala +++ b/src/main/scala/ir/eval/InterpretBasilIR.scala @@ -228,6 +228,10 @@ case object Eval { /** Helper functions * */ + /** + * Load all memory cells from pointer until reaching cell containing 0. + * Ptr -> List[Bitvector] + */ def getNullTerminatedString[S, T <: Effects[S, InterpreterError]](f: T) (rgn: String, src: BasilValue, acc: List[BitVecLiteral] = List()): State[S, List[BitVecLiteral], InterpreterError] = for { @@ -451,11 +455,7 @@ object InterpFuns { l <- s.storeVar("R0", Scope.Global, Scalar(BitVecLiteral(0, 64))) l <- s.storeVar("R1", Scope.Global, Scalar(BitVecLiteral(0, 64))) _ <- s.storeVar("ghost-funtable", Scope.Global, BasilMapValue(Map.empty, MapType(BitVecType(64), BitVecType(64)))) - _ <- s.storeVar("ghost-file-bookkeeping", Scope.Global, GenMapValue(Map.empty)) - _ <- s.storeVar("ghost-fd-mapping", Scope.Global, GenMapValue(Map.empty)) - _ <- s.storeMem("ghost-file-bookkeeping", Map(Symbol("$$filecount") -> Scalar(BitVecLiteral(0, 64)))) - _ <- s.callIntrinsic("fopen", List(Symbol("stderr"))) - _ <- s.callIntrinsic("fopen", List(Symbol("stdout"))) + _ <- IntrinsicImpl.initFileGhostRegions(s) } yield (l) } diff --git a/src/main/scala/ir/eval/Interpreter.scala b/src/main/scala/ir/eval/Interpreter.scala index 964b6d3f8..b57bb5fc1 100644 --- a/src/main/scala/ir/eval/Interpreter.scala +++ b/src/main/scala/ir/eval/Interpreter.scala @@ -313,46 +313,22 @@ case class MemoryState( object LibcIntrinsic { - def putc[S, E, T <: Effects[S, E]](s: T): State[S, Unit, E] = for { - c <- s.loadVar("R0") - _ <- s.callIntrinsic("putc", List(c)) - _ <- s.doReturn() - } yield (()) - - def puts[S, E, T <: Effects[S, E]](s: T): State[S, Unit, E] = for { - dstr <- s.loadVar("R0") - _ <- s.callIntrinsic("puts", List(dstr)) - _ <- s.doReturn() - } yield (()) - - def strlen[S, E, T <: Effects[S, E]](s: T): State[S, Unit, E] = for { - dstr <- s.loadVar("R0") - _ <- s.callIntrinsic("strlen", List(dstr)) - _ <- s.doReturn() - } yield (()) - - def printf[S, E, T <: Effects[S, E]](s: T): State[S, Unit, E] = for { - dstr <- s.loadVar("R0") - _ <- s.callIntrinsic("print", List(dstr)) - _ <- s.doReturn() - } yield (()) - - def malloc[S, E, T <: Effects[S, E]](s: T): State[S, Unit, E] = for { - size <- s.loadVar("R0") - res <- s.callIntrinsic("malloc", List(size)) - _ <- s.doReturn() - } yield (()) + /** + * Part of the intrinsics implementation that lives above the Effects interface + * (i.e. we are defining the observable part of the intrinsics behaviour) + */ - def free[S, E, T <: Effects[S, E]](s: T): State[S, Unit, E] = for { - ptr <- s.loadVar("R0") - res <- s.callIntrinsic("free", List(ptr)) + def singleArg[S, E, T <: Effects[S, E]](name: String)(s: T): State[S, Unit, E] = for { + c <- s.loadVar("R0") + res <- s.callIntrinsic(name, List(c)) + _ <- if res.isDefined then s.storeVar("R0", Scope.Global, res.get) else State.pure(()) _ <- s.doReturn() } yield (()) def calloc[S, T <: Effects[S, InterpreterError]](s: T): State[S, Unit, InterpreterError] = for { size <- s.loadVar("R0") res <- s.callIntrinsic("malloc", List(size)) - ptr <- s.loadVar("R0") + ptr = res.get isize <- size match { case Scalar(b: BitVecLiteral) => State.pure(b.value * 8) case _ => State.setError(Errored("programmer error")) @@ -363,101 +339,118 @@ object LibcIntrinsic { def intrinsics[S, T <: Effects[S, InterpreterError]] = Map[String, T => State[S, Unit, InterpreterError]]( - "putc" -> putc, - "puts" -> puts, - "printf" -> printf, - "malloc" -> malloc, - "free" -> free, - "#free" -> free, + "putc" -> singleArg("putc"), + "puts" -> singleArg("puts"), + "printf" -> singleArg("print"), + "malloc" -> singleArg("malloc"), + "free" -> singleArg("free"), + "#free" -> singleArg("free"), "calloc" -> calloc ) } -case class InterpreterState( - val nextCmd: ExecutionContinuation = Stopped(), - val callStack: List[ExecutionContinuation] = List.empty, - val memoryState: MemoryState = MemoryState() -) +object IntrinsicImpl { -/** Implementation of Effects for InterpreterState concrete state representation. - */ -object NormalInterpreter extends Effects[InterpreterState, InterpreterError] { + /** state initialisation for file modelling */ + def initFileGhostRegions[S, E, T <: Effects[S, E]](f: T): State[S, Unit, E] = for { + _ <- f.storeVar("ghost-file-bookkeeping", Scope.Global, GenMapValue(Map.empty)) + _ <- f.storeVar("ghost-fd-mapping", Scope.Global, GenMapValue(Map.empty)) + _ <- f.storeMem("ghost-file-bookkeeping", Map(Symbol("$$filecount") -> Scalar(BitVecLiteral(0, 64)))) + _ <- f.callIntrinsic("fopen", List(Symbol("stderr"))) + _ <- f.callIntrinsic("fopen", List(Symbol("stdout"))) + } yield (()) - def putc(arg: BasilValue): State[InterpreterState, Option[BasilValue], InterpreterError] = { + /** Intrinsics defined over arbitrary effects + * + * We call these from Effects[T, E] rather than the Interpreter so their implementation does not appear in the trace. + */ + def putc[S, T <: Effects[S, InterpreterError]](f: T)(arg: BasilValue): State[S, Option[BasilValue], InterpreterError] = { for { - addr <- loadMem("ghost-file-bookkeeping", List(Symbol("stdout-ptr"))) + addr <- f.loadMem("ghost-file-bookkeeping", List(Symbol("stdout-ptr"))) byte <- State.pureE(BasilValue.toBV(arg)) - c <- Eval.evalBV(this)(Extract(8, 0, byte)) - _ <- storeMem("stdout", Map(addr.head -> Scalar(c))) + c <- Eval.evalBV(f)(Extract(8, 0, byte)) + _ <- f.storeMem("stdout", Map(addr.head -> Scalar(c))) naddr <- State.pureE(BasilValue.unsafeAdd(addr.head, 1)) - _ <- storeMem("ghost-file-bookkeeping", Map(Symbol("stdout-ptr") -> naddr)) + _ <- f.storeMem("ghost-file-bookkeeping", Map(Symbol("stdout-ptr") -> naddr)) } yield (None) } - def fopen(file: BasilValue): State[InterpreterState, Option[BasilValue], InterpreterError] = { + def fopen[S, T <: Effects[S, InterpreterError]](f: T)(file: BasilValue): State[S, Option[BasilValue], InterpreterError] = { for { fname <- file match { case Symbol(name) => State.pure(name) case _ => State.setError(Errored("Intrinsic fopen open not given filename")) } - _ <- storeVar(fname, Scope.Global, BasilMapValue(Map.empty, MapType(BitVecType(64), BitVecType(8)))) - filecount <- loadMem("ghost-file-bookkeeping", List(Symbol("$$filecount"))) + _ <- f.storeVar(fname, Scope.Global, BasilMapValue(Map.empty, MapType(BitVecType(64), BitVecType(8)))) + filecount <- f.loadMem("ghost-file-bookkeeping", List(Symbol("$$filecount"))) + _ <- f.storeMem("ghost-file-bookkeeping", Map(Symbol(fname + "-ptr") -> Scalar(BitVecLiteral(0, 64)))) + _ <- f.storeMem("ghost-fd-mapping", Map(filecount.head -> Symbol(fname + "-ptr"))) + _ <- f.storeVar("R0", Scope.Global, filecount.head) nfilecount <- State.pureE(BasilValue.unsafeAdd(filecount.head, 1)) - _ <- storeMem("ghost-file-bookkeeping", Map(Symbol("$$filecount") -> nfilecount)) - _ <- storeMem("ghost-file-bookkeeping", Map(Symbol(fname + "-ptr") -> Scalar(BitVecLiteral(0, 64)))) - _ <- storeMem("ghost-fd-mapping", Map(nfilecount -> Symbol(fname + "-ptr"))) - _ <- storeVar("R0", Scope.Global, nfilecount) - } yield (Some(nfilecount)) + _ <- f.storeMem("ghost-file-bookkeeping", Map(Symbol("$$filecount") -> nfilecount)) + } yield (Some(filecount.head)) } - def print(strptr: BasilValue): State[InterpreterState, Option[BasilValue], InterpreterError] = { + def print[S, T <: Effects[S, InterpreterError]](f: T)(strptr: BasilValue): State[S, Option[BasilValue], InterpreterError] = { for { - str <- Eval.getNullTerminatedString(this)("mem", strptr) - baseptr: List[BasilValue] <- loadMem("ghost-file-bookkeeping", List(Symbol("stdout-ptr"))) + str <- Eval.getNullTerminatedString(f)("mem", strptr) + baseptr: List[BasilValue] <- f.loadMem("ghost-file-bookkeeping", List(Symbol("stdout-ptr"))) offs: List[BasilValue] <- State.mapM( ((i: Int) => State.pureE(BasilValue.unsafeAdd(baseptr.head, i))), (0 until (str.size + 1)) ) - _ <- storeMem("stdout", offs.zip(str.map(Scalar(_))).toMap) + _ <- f.storeMem("stdout", offs.zip(str.map(Scalar(_))).toMap) naddr <- State.pureE(BasilValue.unsafeAdd(baseptr.head, str.size)) - _ <- storeMem("ghost-file-bookkeeping", Map(Symbol("stdout-ptr") -> naddr)) + _ <- f.storeMem("ghost-file-bookkeeping", Map(Symbol("stdout-ptr") -> naddr)) } yield (None) } + def malloc[S, T <: Effects[S, InterpreterError]](f: T)(size: BasilValue): State[S, Option[BasilValue], InterpreterError] = { + for { + size <- (size match { + case (x @ Scalar(_: BitVecLiteral)) => State.pure(x) + case (Scalar(x: IntLiteral)) => State.pure(Scalar(BitVecLiteral(x.value, 64))) + case _ => State.setError(Errored("illegal prim arg")) + }) + x <- f.loadVar("ghost_malloc_top") + x_gap <- State.pureE(BasilValue.unsafeAdd(x, 128)) // put a gap around allocations to catch buffer overflows + x_end <- State.pureE(BasilValue.add(x_gap, size)) + _ <- f.storeVar("ghost_malloc_top", Scope.Global, x_end) + _ <- f.storeVar("R0", Scope.Global, x_gap) + } yield (Some(x_gap)) + } +} + +case class InterpreterState( + val nextCmd: ExecutionContinuation = Stopped(), + val callStack: List[ExecutionContinuation] = List.empty, + val memoryState: MemoryState = MemoryState() +) + +/** Implementation of Effects for InterpreterState concrete state representation. + */ +object NormalInterpreter extends Effects[InterpreterState, InterpreterError] { + + def callIntrinsic( name: String, args: List[BasilValue] ): State[InterpreterState, Option[BasilValue], InterpreterError] = { name match { - case "free" => { - State.pure(None) - } - case "malloc" => { - for { - size <- (args.headOption match { - case Some(x @ Scalar(_: BitVecLiteral)) => State.pure(x) - case Some(Scalar(x: IntLiteral)) => State.pure(Scalar(BitVecLiteral(x.value, 64))) - case _ => State.setError(Errored("illegal prim arg")) - }) - x <- loadVar("ghost_malloc_top") - x_gap <- State.pureE(BasilValue.unsafeAdd(x, 128)) // put a gap around allocations to catch buffer overflows - x_end <- State.pureE(BasilValue.add(x_gap, size)) - _ <- storeVar("ghost_malloc_top", Scope.Global, x_end) - _ <- storeVar("R0", Scope.Global, x_gap) - } yield (Some(x_gap)) - } - case "fopen" => fopen(args.head) - case "putc" => putc(args.head) + case "free" => State.pure(None) + case "malloc" => IntrinsicImpl.malloc(this)(args.head) + case "fopen" => IntrinsicImpl.fopen(this)(args.head) + case "putc" => IntrinsicImpl.putc(this)(args.head) case "strlen" => for { str <- Eval.getNullTerminatedString(this)("mem", args.head) r = Scalar(BitVecLiteral(str.length, 64)) _ <- storeVar("R0", Scope.Global, r) } yield (Some(r)) - case "print" => print(args.head) - case "puts" => print(args.head) >> putc(Scalar(BitVecLiteral('\n'.toInt, 64))) - case _ => State.setError(Errored(s"Call undefined intrinsic $name")) + case "print" => IntrinsicImpl.print(this)(args.head) + case "puts" => IntrinsicImpl.print(this)(args.head) >> IntrinsicImpl.putc(this)(Scalar(BitVecLiteral('\n'.toInt, 64))) + case _ => State.setError(Errored(s"Call undefined intrinsic $name")) } } From aac7f5ce7f0d88634f8e16e3c6d63390352c9c12 Mon Sep 17 00:00:00 2001 From: Alistair Michael Date: Mon, 23 Sep 2024 16:04:32 +1000 Subject: [PATCH 57/62] cleanup --- build.sc | 1 + src/main/scala/ir/eval/ExprEval.scala | 1 + .../scala/ir/eval/InterpretBreakpoints.scala | 2 +- src/main/scala/ir/eval/Interpreter.scala | 20 +-- .../scala/ir/eval/InterpreterProduct.scala | 4 +- src/test/scala/ir/InterpreterTests.scala | 116 ++---------------- 6 files changed, 26 insertions(+), 118 deletions(-) diff --git a/build.sc b/build.sc index 59223e66d..1298673a8 100644 --- a/build.sc +++ b/build.sc @@ -22,6 +22,7 @@ object basil extends RootModule with ScalaModule with antlr.AntlrModule with Sca def scalaPBVersion = "0.11.15" + def scalacOptions = Seq("-deprecation", "-unchecked", "-feature") def mainClass = Some("Main") diff --git a/src/main/scala/ir/eval/ExprEval.scala b/src/main/scala/ir/eval/ExprEval.scala index 53e79d11b..34b9e3b40 100644 --- a/src/main/scala/ir/eval/ExprEval.scala +++ b/src/main/scala/ir/eval/ExprEval.scala @@ -89,6 +89,7 @@ def evalUnOp(op: UnOp, body: Literal): Expr = { case (i: IntLiteral, IntNEG) => IntLiteral(-i.value) case (FalseLiteral, BoolNOT) => TrueLiteral case (TrueLiteral, BoolNOT) => FalseLiteral + case (_, _) => throw Exception(s"Unreachable ${(body, op)}") } } diff --git a/src/main/scala/ir/eval/InterpretBreakpoints.scala b/src/main/scala/ir/eval/InterpretBreakpoints.scala index c8613dada..ced1eecb8 100644 --- a/src/main/scala/ir/eval/InterpretBreakpoints.scala +++ b/src/main/scala/ir/eval/InterpretBreakpoints.scala @@ -27,7 +27,7 @@ case class BreakPointAction( case class BreakPoint(name: String = "", location: BreakPointLoc, action: BreakPointAction) -case class RememberBreakpoints[T, I <: Effects[T, InterpreterError]](val f: I, val breaks: List[BreakPoint]) +case class RememberBreakpoints[T, I <: Effects[T, InterpreterError]](f: I, breaks: List[BreakPoint]) extends NopEffects[(T, List[(BreakPoint, Option[T], List[(String, Expr, Expr)])]), InterpreterError] { def findBreaks[R](c: Command): State[(T, R), List[BreakPoint], InterpreterError] = { diff --git a/src/main/scala/ir/eval/Interpreter.scala b/src/main/scala/ir/eval/Interpreter.scala index b57bb5fc1..6f948269f 100644 --- a/src/main/scala/ir/eval/Interpreter.scala +++ b/src/main/scala/ir/eval/Interpreter.scala @@ -17,22 +17,22 @@ import scala.util.control.Breaks.{break, breakable} sealed trait ExecutionContinuation case class Stopped() extends ExecutionContinuation /* normal program stop */ case class ErrorStop(error: InterpreterError) extends ExecutionContinuation /* program stop in error state */ -case class Run(val next: Command) extends ExecutionContinuation /* continue by executing next command */ -case class Intrinsic(val name: String) extends ExecutionContinuation /* a named intrinsic instruction */ +case class Run(next: Command) extends ExecutionContinuation /* continue by executing next command */ +case class Intrinsic(name: String) extends ExecutionContinuation /* a named intrinsic instruction */ sealed trait InterpreterError case class FailedAssertion(a: Assert) extends InterpreterError -case class EscapedControlFlow(val call: Jump | Call) +case class EscapedControlFlow(call: Jump | Call) extends InterpreterError /* controlflow has reached somewhere eunrecoverable */ -case class Errored(val message: String = "") extends InterpreterError -case class TypeError(val message: String = "") extends InterpreterError /* type mismatch appeared */ -case class EvalError(val message: String = "") +case class Errored(message: String = "") extends InterpreterError +case class TypeError(message: String = "") extends InterpreterError /* type mismatch appeared */ +case class EvalError(message: String = "") extends InterpreterError /* failed to evaluate an expression to a concrete value */ -case class MemoryError(val message: String = "") extends InterpreterError /* An error to do with memory */ +case class MemoryError(message: String = "") extends InterpreterError /* An error to do with memory */ /* Concrete value type of the interpreter. */ sealed trait BasilValue(val irType: Option[IRType]) -case class Scalar(val value: Literal) extends BasilValue(Some(value.getType)) { +case class Scalar(value: Literal) extends BasilValue(Some(value.getType)) { override def toString = value match { case b: BitVecLiteral => "0x%x:bv%d".format(b.value, b.size) case c => c.toString @@ -40,7 +40,7 @@ case class Scalar(val value: Literal) extends BasilValue(Some(value.getType)) { } /* Abstract callable function address */ -case class FunPointer(val addr: BitVecLiteral, val name: String, val call: ExecutionContinuation) +case class FunPointer(addr: BitVecLiteral, name: String, call: ExecutionContinuation) extends BasilValue(Some(addr.getType)) sealed trait MapValue { @@ -51,7 +51,7 @@ sealed trait MapValue { \exists i . \forall v \in value.keys , v.irType = i and \exists j . \forall v \in value.values, v.irType = j */ -case class BasilMapValue(val value: Map[BasilValue, BasilValue], val mapType: MapType) +case class BasilMapValue(value: Map[BasilValue, BasilValue], mapType: MapType) extends MapValue with BasilValue(Some(mapType)) { override def toString = s"MapValue : $irType" diff --git a/src/main/scala/ir/eval/InterpreterProduct.scala b/src/main/scala/ir/eval/InterpreterProduct.scala index 37e1820a4..ccf4fb682 100644 --- a/src/main/scala/ir/eval/InterpreterProduct.scala +++ b/src/main/scala/ir/eval/InterpreterProduct.scala @@ -29,7 +29,7 @@ def doRight[L, T, V, E](f: State[T, V, E]): State[(L, T), V, E] = for { /** Runs two interpreters "inner" and "before" simultaneously, returning the value from inner, and ignoring before */ -case class ProductInterpreter[L, T, E](val inner: Effects[L, E], val before: Effects[T, E]) extends Effects[(L, T), E] { +case class ProductInterpreter[L, T, E](inner: Effects[L, E], before: Effects[T, E]) extends Effects[(L, T), E] { def loadVar(v: String) = for { n <- doRight(before.loadVar(v)) @@ -83,7 +83,7 @@ case class ProductInterpreter[L, T, E](val inner: Effects[L, E], val before: Eff } yield (f) } -case class LayerInterpreter[L, T, E](val inner: Effects[L, E], val before: Effects[(L, T), E]) +case class LayerInterpreter[L, T, E](inner: Effects[L, E], before: Effects[(L, T), E]) extends Effects[(L, T), E] { def loadVar(v: String) = for { diff --git a/src/test/scala/ir/InterpreterTests.scala b/src/test/scala/ir/InterpreterTests.scala index e4cc8dbf7..9fa6fbafe 100644 --- a/src/test/scala/ir/InterpreterTests.scala +++ b/src/test/scala/ir/InterpreterTests.scala @@ -3,6 +3,7 @@ package ir import util.PerformanceTimer import util.functional._ import ir.eval._ +import boogie.Scope import ir.dsl._ import org.scalatest.funsuite.AnyFunSuite import org.scalatest.BeforeAndAfter @@ -37,9 +38,10 @@ class InterpreterTests extends AnyFunSuite with BeforeAndAfter { Logger.setLevel(LogLevel.WARN) def getProgram(name: String): IRContext = { + val compiler = "gcc" val loading = ILLoadingConfig( - inputFile = s"examples/$name/$name.adt", - relfFile = s"examples/$name/$name.relf", + inputFile = s"src/test/correct/$name/$compiler/$name.adt", + relfFile = s"src/test/correct/$name/$compiler/$name.relf", specFile = None, dumpIL = None ) @@ -97,15 +99,16 @@ class InterpreterTests extends AnyFunSuite with BeforeAndAfter { assert(s.memoryState.getVarOpt("R29").isDefined) } + test("var load store") { val s = for { s <- InterpFuns.initialState(NormalInterpreter) - v <- NormalInterpreter.loadVar("R31") + v <- NormalInterpreter.storeVar("R1", Scope.Global, Scalar(BitVecLiteral(1024, 64))) + v <- NormalInterpreter.loadVar("R1") } yield (v) val l = State.evaluate(InterpreterState(), s) - assert(l == Right(Scalar(BitVecLiteral(4096 - 16, 64)))) - + assert(l == Right(Scalar(BitVecLiteral(1024, 64)))) } test("Store = Load LittleEndian") { @@ -129,87 +132,6 @@ class InterpreterTests extends AnyFunSuite with BeforeAndAfter { } -// test("store bv = loadbv le") { -// val expected: BitVecLiteral = BitVecLiteral(BigInt("0D0C0B0A", 16), 32) -// val s2 = Eval.storeBV(initialMem(), "mem", Scalar(BitVecLiteral(0, 64)), expected, Endian.LittleEndian) -// val actual2: BitVecLiteral = Eval.loadBV(s2, "mem", Scalar(BitVecLiteral(0, 64)), Endian.LittleEndian, 32) -// assert(actual2 == expected) -// } -// -// -// test("Store = Load BigEndian") { -// val ts = List( -// BitVecLiteral(BigInt("0D", 16), 8), -// BitVecLiteral(BigInt("0C", 16), 8), -// BitVecLiteral(BigInt("0B", 16), 8), -// BitVecLiteral(BigInt("0A", 16), 8)) -// -// val s = Eval.store(initialMem(), "mem", Scalar(BitVecLiteral(0, 64)), ts.map(Scalar(_)), Endian.LittleEndian) -// val expected: BitVecLiteral = BitVecLiteral(BigInt("0A0B0C0D", 16), 32) -// val actual: BitVecLiteral = Eval.loadBV(s, "mem", Scalar(BitVecLiteral(0, 64)), Endian.BigEndian , 32) -// assert(actual == expected) -// -// -// } -// -// test("getMemory in LittleEndian") { -// val ts = List((BitVecLiteral(0, 64), BitVecLiteral(BigInt("0D", 16), 8)), -// (BitVecLiteral(1, 64) , BitVecLiteral(BigInt("0C", 16), 8)), -// (BitVecLiteral(2, 64) , BitVecLiteral(BigInt("0B", 16), 8)), -// (BitVecLiteral(3, 64) , BitVecLiteral(BigInt("0A", 16), 8))) -// val s = ts.foldLeft(initialMem())((m, v) => Eval.storeSingle(m, "mem", Scalar(v._1), Scalar(v._2))) -// // val s = initialMem().store("mem") -// // val r = s.loadBV("mem", BitVecLiteral(0, 64)) -// -// val expected: BitVecLiteral = BitVecLiteral(BigInt("0A0B0C0D", 16), 32) -// -// // def loadBV(vname: String, addr: Scalar, endian: Endian, size: Int): BitVecLiteral = { -// val actual: BitVecLiteral = Eval.loadBV(s, "mem", Scalar(BitVecLiteral(0, 64)), Endian.LittleEndian, 32) -// assert(actual == expected) -// } -// -// -// test("StoreBV = LoadBV LE ") { -// val expected: BitVecLiteral = BitVecLiteral(BigInt("0A0B0C0D", 16), 32) -// -// val s = Eval.storeBV(initialMem(), "mem", Scalar(BitVecLiteral(0, 64)), expected, Endian.LittleEndian) -// val actual: BitVecLiteral = Eval.loadBV(s, "mem", Scalar(BitVecLiteral(0, 64)), Endian.LittleEndian, 32) -// println(s"${actual.value.toInt.toHexString} == ${expected.value.toInt.toHexString}") -// assert(actual == expected) -// } -// -// // test("getMemory in BigEndian") { -// // i.mems(0) = BitVecLiteral(BigInt("0A", 16), 8) -// // i.mems(1) = BitVecLiteral(BigInt("0B", 16), 8) -// // i.mems(2) = BitVecLiteral(BigInt("0C", 16), 8) -// // i.mems(3) = BitVecLiteral(BigInt("0D", 16), 8) -// // val expected: BitVecLiteral = BitVecLiteral(BigInt("0A0B0C0D", 16), 32) -// // val actual: BitVecLiteral = i.getMemory(0, 32, Endian.BigEndian, i.mems) -// // assert(actual == expected) -// // } -// -// // test("setMemory in LittleEndian") { -// // i.mems(0) = BitVecLiteral(BigInt("FF", 16), 8) -// // i.mems(1) = BitVecLiteral(BigInt("FF", 16), 8) -// // i.mems(2) = BitVecLiteral(BigInt("FF", 16), 8) -// // i.mems(3) = BitVecLiteral(BigInt("FF", 16), 8) -// // val expected: BitVecLiteral = BitVecLiteral(BigInt("0A0B0C0D", 16), 32) -// // i.setMemory(0, 32, Endian.LittleEndian, expected, i.mems) -// // val actual: BitVecLiteral = i.getMemory(0, 32, Endian.LittleEndian, i.mems) -// // assert(actual == expected) -// // } -// -// // test("setMemory in BigEndian") { -// // i.mems(0) = BitVecLiteral(BigInt("FF", 16), 8) -// // i.mems(1) = BitVecLiteral(BigInt("FF", 16), 8) -// // i.mems(2) = BitVecLiteral(BigInt("FF", 16), 8) -// // i.mems(3) = BitVecLiteral(BigInt("FF", 16), 8) -// // val expected: BitVecLiteral = BitVecLiteral(BigInt("0A0B0C0D", 16), 32) -// // i.setMemory(0, 32, Endian.BigEndian, expected, i.mems) -// // val actual: BitVecLiteral = i.getMemory(0, 32, Endian.BigEndian, i.mems) -// // assert(actual == expected) -// // } -// test("basic_arrays_read") { val expected = Map( "arr" -> 0 @@ -231,21 +153,7 @@ class InterpreterTests extends AnyFunSuite with BeforeAndAfter { testInterpret("basic_assign_increment", expected) } - test("basic_loop_loop") { - val expected = Map( - "x" -> 10 - ) - testInterpret("basic_loop_loop", expected) - } - test("basicassign") { - val expected = Map( - "x" -> 0, - "z" -> 0, - "secret" -> 0 - ) - testInterpret("basicassign", expected) - } test("function") { val expected = Map( @@ -274,7 +182,7 @@ class InterpreterTests extends AnyFunSuite with BeforeAndAfter { test("indirect_call") { val expected = Map[String, Int]() - testInterpret("indirect_call_outparam", expected) + testInterpret("indirect_call", expected) } test("ifglobal") { @@ -381,7 +289,7 @@ class InterpreterTests extends AnyFunSuite with BeforeAndAfter { Logger.setLevel(LogLevel.ERROR) var res = List[(Int, Double, Double, Int)]() - for (i <- 0 to 30) { + for (i <- 0 to 20) { val prog = fibonacciProg(i) val t = PerformanceTimer("native") @@ -394,10 +302,9 @@ class InterpreterTests extends AnyFunSuite with BeforeAndAfter { res = (i, native, it, ir._2) :: res - println(s"${res.head}") } - println(("fib number,native time,interp time,interp cycles" :: (res.map(x => s"${x._1},${x._2},${x._3},${x._4}"))).mkString("\n")) + info(("fibonacci runtime table:\nFibNumber,ScalaRunTime,interpreterRunTime,instructionCycleCount" :: (res.map(x => s"${x._1},${x._2},${x._3},${x._4}"))).mkString("\n")) } @@ -436,7 +343,6 @@ class InterpreterTests extends AnyFunSuite with BeforeAndAfter { ) val ir = interpret(tp) - println(ir) assert(ir.nextCmd.isInstanceOf[ErrorStop]) } From 46ff66a12686ded86aa208b38cffa84af97d2660 Mon Sep 17 00:00:00 2001 From: Alistair Michael Date: Mon, 23 Sep 2024 17:24:34 +1000 Subject: [PATCH 58/62] cleanup --- src/main/scala/ir/eval/{Bitvector.scala => BitVectorEval.scala} | 0 src/test/scala/ir/IRTest.scala | 2 +- src/test/scala/ir/InterpreterTests.scala | 2 +- 3 files changed, 2 insertions(+), 2 deletions(-) rename src/main/scala/ir/eval/{Bitvector.scala => BitVectorEval.scala} (100%) diff --git a/src/main/scala/ir/eval/Bitvector.scala b/src/main/scala/ir/eval/BitVectorEval.scala similarity index 100% rename from src/main/scala/ir/eval/Bitvector.scala rename to src/main/scala/ir/eval/BitVectorEval.scala diff --git a/src/test/scala/ir/IRTest.scala b/src/test/scala/ir/IRTest.scala index 855c86d98..854389d90 100644 --- a/src/test/scala/ir/IRTest.scala +++ b/src/test/scala/ir/IRTest.scala @@ -135,8 +135,8 @@ class IRTest extends AnyFunSuite { assert(1 == aftercallGotos.count(b => IntraProcIRCursor.pred(b).contains(blocks("l_main_1").jump))) assert(1 == aftercallGotos.count(b => IntraProcIRCursor.succ(b).contains(blocks("l_main_1").jump match { case GoTo(targets, _) => targets.head + case _ => throw Exception("unreachable") }))) - } test("addblocks") { diff --git a/src/test/scala/ir/InterpreterTests.scala b/src/test/scala/ir/InterpreterTests.scala index ee1328b51..912fc6c37 100644 --- a/src/test/scala/ir/InterpreterTests.scala +++ b/src/test/scala/ir/InterpreterTests.scala @@ -299,7 +299,7 @@ class InterpreterTests extends AnyFunSuite with BeforeAndAfter { val ir = interpretRLimit(prog, 100000000) val it = intt.elapsed() - res = (i, native, it, ir._2) :: res + res = (i, native.toDouble, it.toDouble, ir._2) :: res } From 1a8067c689d49730fe2278b721abc9f0012e904d Mon Sep 17 00:00:00 2001 From: l-kent Date: Fri, 8 Nov 2024 14:10:48 +1000 Subject: [PATCH 59/62] make compiler options consistent --- build.sbt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/build.sbt b/build.sbt index f1e6339ff..bef063b83 100644 --- a/build.sbt +++ b/build.sbt @@ -29,7 +29,7 @@ lazy val root = project libraryDependencies += "org.scalameta" %% "munit" % "0.7.29" % Test ) -scalacOptions ++= Seq("-deprecation", "-feature") +scalacOptions ++= Seq("-deprecation", "-unchecked", "-feature") Compile / PB.targets := Seq( scalapb.gen() -> (Compile / sourceManaged).value / "scalapb" From 32bf4395975f73b67476d591a9d9837443bb9ce3 Mon Sep 17 00:00:00 2001 From: l-kent Date: Fri, 8 Nov 2024 14:27:51 +1000 Subject: [PATCH 60/62] fix use of deprecated examples --- src/test/scala/InterpretTestConstProp.scala | 95 +++++++++------------ 1 file changed, 39 insertions(+), 56 deletions(-) diff --git a/src/test/scala/InterpretTestConstProp.scala b/src/test/scala/InterpretTestConstProp.scala index d3a13afae..0abc7dcd6 100644 --- a/src/test/scala/InterpretTestConstProp.scala +++ b/src/test/scala/InterpretTestConstProp.scala @@ -23,9 +23,12 @@ class ConstPropInterpreterValidate extends AnyFunSuite { Logger.setLevel(LogLevel.ERROR) - def testInterpretConstProp(testName: String, examplePath: String) = { - val loading = ILLoadingConfig(inputFile = examplePath + testName + ".adt", - relfFile = examplePath + testName + ".relf", + def testInterpretConstProp(name: String, variation: String, path: String): Unit = { + val directoryPath = path + name + "/" + val variationPath = directoryPath + variation + "/" + name + val loading = ILLoadingConfig( + inputFile = variationPath + ".adt", + relfFile = variationPath + ".relf", dumpIL = None, ) @@ -33,12 +36,12 @@ class ConstPropInterpreterValidate extends AnyFunSuite { ictx = IRTransform.doCleanup(ictx) val analysisres = RunUtils.staticAnalysis(StaticAnalysisConfig(None, None, None), ictx) - val breaks : List[BreakPoint] = analysisres.constPropResult.collect { + val breaks: List[BreakPoint] = analysisres.constPropResult.collect { // convert analysis result to a list of breakpoints, each which evaluates an expression describing // the invariant inferred by the analysis (the assignment of registers) at a corresponding program point - case (command: Command, v) => { - val expectedPredicates : List[(String, Expr)] = v.toList.map(r => { + case (command: Command, v) => + val expectedPredicates: List[(String, Expr)] = v.toList.map { r => val (variable, value) = r val assertion = value match { case Top => TrueLiteral @@ -46,92 +49,72 @@ class ConstPropInterpreterValidate extends AnyFunSuite { case FlatEl(value) => BinaryExpr(BVEQ, variable, value) } (variable.name, assertion) - }) - BreakPoint(location=BreakPointLoc.CMD(command), BreakPointAction(saveState=false,evalExprs=expectedPredicates)) - } + } + BreakPoint( + location = BreakPointLoc.CMD(command), + BreakPointAction(saveState = false, evalExprs = expectedPredicates) + ) }.toList assert(breaks.nonEmpty) // run the interpreter evaluating the analysis result at each command with a breakpoint val interpretResult = interpretWithBreakPoints(ictx, breaks.toList, NormalInterpreter, InterpreterState()) - val breakres : List[(BreakPoint, _, List[(String, Expr, Expr)])] = interpretResult._2 + val breakres: List[(BreakPoint, _, List[(String, Expr, Expr)])] = interpretResult._2 assert(interpretResult._1.nextCmd == Stopped()) assert(breakres.nonEmpty) // assert all the collected breakpoint watches have evaluated to true for (b <- breakres) { val (_, _, evaluatedexprs) = b - evaluatedexprs.forall(c => { + evaluatedexprs.forall { c => val (n, before, evaled) = c evaled == TrueLiteral - }) + } } } - test("indirect_call_example") { - val testName = "indirect_call" - val examplePath = System.getProperty("user.dir") + s"/examples/$testName/" - testInterpretConstProp(testName, examplePath) + test("indirect_call/gcc_pic:BAP") { + testInterpretConstProp("indirect_call", "gcc_pic", "./src/test/indirect_calls/") } - test("indirect_call_gcc_example") { - val testName = "indirect_call" - val examplePath = System.getProperty("user.dir") + s"/src/test/correct/$testName/gcc/" - testInterpretConstProp(testName, examplePath) + test("indirect_call/gcc:BAP") { + testInterpretConstProp("indirect_call", "gcc", "./src/test/indirect_calls/") } - test("indirect_call_clang_example") { - val testName = "indirect_call" - val examplePath = System.getProperty("user.dir") + s"/src/test/correct/$testName/clang/" - testInterpretConstProp(testName, examplePath) + test("indirect_call/clang:BAP") { + testInterpretConstProp("indirect_call", "clang", "./src/test/indirect_calls/") } - test("jumptable2_example") { - val testName = "jumptable2" - val examplePath = System.getProperty("user.dir") + s"/examples/$testName/" - testInterpretConstProp(testName, examplePath) + test("jumptable2/gcc_pic:BAP") { + testInterpretConstProp("jumptable2", "gcc_pic", "./src/test/indirect_calls/") } - test("jumptable2_gcc_example") { - val testName = "jumptable2" - val examplePath = System.getProperty("user.dir") + s"/src/test/correct/$testName/gcc/" - testInterpretConstProp(testName, examplePath) + test("jumptable2/gcc:BAP") { + testInterpretConstProp("jumptable2", "gcc", "./src/test/indirect_calls/") } - test("jumptable2_clang_example") { - val testName = "jumptable2" - val examplePath = System.getProperty("user.dir") + s"/src/test/correct/$testName/clang/" - testInterpretConstProp(testName, examplePath) + test("jumptable2/clang:BAP") { + testInterpretConstProp("jumptable2", "clang", "./src/test/indirect_calls/") } - test("functionpointer_example") { - val testName = "functionpointer" - val examplePath = System.getProperty("user.dir") + s"/examples/$testName/" - testInterpretConstProp(testName, examplePath) + test("functionpointer/gcc_pic:BAP") { + testInterpretConstProp("functionpointer", "gcc_pic", "./src/test/indirect_calls/") } - test("functionpointer_gcc_example") { - val testName = "functionpointer" - val examplePath = System.getProperty("user.dir") + s"/src/test/correct/$testName/gcc/" - testInterpretConstProp(testName, examplePath) + test("functionpointer/gcc:BAP") { + testInterpretConstProp("functionpointer", "gcc", "./src/test/indirect_calls/") } - test("functionpointer_clang_example") { - val testName = "functionpointer" - val examplePath = System.getProperty("user.dir") + s"/src/test/correct/$testName/clang/" - testInterpretConstProp(testName, examplePath) + test("functionpointer/clang:BAP") { + testInterpretConstProp("functionpointer", "clang", "./src/test/indirect_calls/") } - test("secret_write_clang") { - val testName = "secret_write" - val examplePath = System.getProperty("user.dir") + s"/src/test/correct/$testName/clang/" - testInterpretConstProp(testName, examplePath) + test("secret_write/clang:BAP") { + testInterpretConstProp("secret_write", "clang", "./src/test/correct/") } - test("secret_write_gcc") { - val testName = "secret_write" - val examplePath = System.getProperty("user.dir") + s"/src/test/correct/$testName/gcc/" - testInterpretConstProp(testName, examplePath) + test("secret_write/gcc:BAP") { + testInterpretConstProp("secret_write", "gcc", "./src/test/correct/") } } From 9cc3c31b0139cea288105e79e9f61bbbfd0e7792 Mon Sep 17 00:00:00 2001 From: l-kent Date: Fri, 8 Nov 2024 16:03:33 +1000 Subject: [PATCH 61/62] fix use of deprecated examples --- src/test/scala/DifferentialAnalysis.scala | 72 +++++++++------------ src/test/scala/InterpretTestConstProp.scala | 3 +- src/test/scala/ir/InterpreterTests.scala | 34 +++++----- 3 files changed, 48 insertions(+), 61 deletions(-) diff --git a/src/test/scala/DifferentialAnalysis.scala b/src/test/scala/DifferentialAnalysis.scala index 99b270aae..6062d2db5 100644 --- a/src/test/scala/DifferentialAnalysis.scala +++ b/src/test/scala/DifferentialAnalysis.scala @@ -7,7 +7,7 @@ import org.scalatest.funsuite.* import specification.* import util.{BASILConfig, IRLoading, ILLoadingConfig, IRContext, RunUtils, StaticAnalysis, StaticAnalysisConfig, StaticAnalysisContext, BASILResult, Logger, LogLevel, IRTransform} import ir.eval.* -import test_util.* +import test_util.BASILTest.getSubdirectories import java.io.IOException import java.nio.file.* @@ -29,7 +29,7 @@ class DifferentialAnalysis extends AnyFunSuite { val interpreter = LayerInterpreter(tracingInterpreter(NormalInterpreter), EffectsRLimit(instructionLimit)) val initialState = InterpFuns.initProgState(NormalInterpreter)(p, InterpreterState()) //Logger.setLevel(LogLevel.DEBUG) - val r = BASILInterpreter(interpreter).run((initialState, Trace(List())), 0)._1 + val (r, _) = BASILInterpreter(interpreter).run((initialState, Trace(List())), 0) //Logger.setLevel(LogLevel.WARN) r } @@ -46,22 +46,23 @@ class DifferentialAnalysis extends AnyFunSuite { } Logger.info(traceInit.t.map(_.toString.take(80)).mkString("\n")) - val initstdout = initialRes.memoryState.getMem("stdout").toList.sortBy(_._1.value).map(_._2.value.toChar).mkString("") - val comparstdout = result.memoryState.getMem("stdout").toList.sortBy(_._1.value).map(_._2.value.toChar).mkString("") - info("STDOUT: \"" + initstdout + "\"") - // Logger.info(initialRes.memoryState.getMem("stderr").toList.sortBy(_._1.value).map(_._2).mkString("")) + val initstdout = initialRes.memoryState.getMem("stdout") + val comparstdout = result.memoryState.getMem("stdout") + val text = initstdout.toList.sortBy(_._1.value).map(_._2.value.toChar).mkString("") + info("STDOUT: \"" + text + "\"") assert(initstdout == comparstdout) assert(initialRes.nextCmd == Stopped()) assert(result.nextCmd == Stopped()) - assert(Set.empty == initialRes.memoryState.getMem("mem").toSet.diff(result.memoryState.getMem("mem").toSet)) assert(traceInit.t.nonEmpty) assert(traceRes.t.nonEmpty) assert(filterEvents(traceInit.t).mkString("\n") == filterEvents(traceRes.t).mkString("\n")) } - def testProgram(testName: String, examplePath: String, suffix: String = ".adt"): Unit = { - val loading = ILLoadingConfig(inputFile = examplePath + testName + suffix, - relfFile = examplePath + testName + ".relf", + def testProgram(name: String, variation: String, path: String): Unit = { + val variationPath = path + name + "/" + variation + "/" + name + val loading = ILLoadingConfig( + inputFile = variationPath + ".adt", + relfFile = variationPath + ".relf", dumpIL = None, ) @@ -75,54 +76,41 @@ class DifferentialAnalysis extends AnyFunSuite { diffTest(ictx, comparectx) } - test("indirect_call_example") { - val testName = "indirect_call" - val examplePath = System.getProperty("user.dir") + s"/examples/$testName/" - testProgram(testName, examplePath) + test("indirect_calls/indirect_call/gcc_pic:BAP") { + testProgram("indirect_call", "gcc_pic", "./src/test/indirect_calls/") } - test("jumptable2_example") { - val testName = "jumptable2" - val examplePath = System.getProperty("user.dir") + s"/examples/$testName/" - testProgram(testName, examplePath) + test("indirect_calls/jumptable2/gcc_pic:BAP") { + testProgram("jumptable2", "gcc_pic", "./src/test/indirect_calls/") } - test("jumptable_example") { - val testName = "jumptable" - val examplePath = System.getProperty("user.dir") + s"/examples/$testName/" - testProgram(testName, examplePath) + test("indirect_calls/jumptable/gcc:BAP") { + testProgram("jumptable", "gcc", "./src/test/indirect_calls/") } - test("functionpointer_example") { - val testName = "functionpointer" - val examplePath = System.getProperty("user.dir") + s"/examples/$testName/" - testProgram(testName, examplePath) + test("functionpointer/gcc_pic:BAP") { + testProgram("functionpointer", "gcc_pic", "./src/test/indirect_calls/") } - test("function_got_example") { - val testName = "function_got" - val examplePath = System.getProperty("user.dir") + s"/examples/$testName/" - testProgram(testName, examplePath) - } - - def runSystemTests(): Unit = { + def runTests(): Unit = { val path = System.getProperty("user.dir") + s"/src/test/correct/" - val programs: Array[String] = BASILTest.getSubdirectories(path) + val programs = getSubdirectories(path) // get all variations of each program for (p <- programs) { val programPath = path + "/" + p - val variations = BASILTest.getSubdirectories(programPath) - variations.foreach { variation => - test("analysis_differential:" + p + "/" + variation + ":BAP") { - testProgram(p, path + "/" + p + "/" + variation + "/", suffix = ".adt") + val variations = getSubdirectories(programPath) + variations.foreach { t => + val variationPath = programPath + "/" + t + "/" + p + val inputPath = variationPath + ".adt" + if (File(inputPath).exists) { + test("correct" + "/" + p + "/" + t + ":BAP") { + testProgram(p, t, path) + } } - //test("analysis_differential:" + p + "/" + variation + ":GTIRB") { - // testProgram(p, path + "/" + p + "/" + variation + "/", suffix=".gts") - //} } } } - runSystemTests() + runTests() } diff --git a/src/test/scala/InterpretTestConstProp.scala b/src/test/scala/InterpretTestConstProp.scala index 0abc7dcd6..49f4fcd48 100644 --- a/src/test/scala/InterpretTestConstProp.scala +++ b/src/test/scala/InterpretTestConstProp.scala @@ -24,8 +24,7 @@ class ConstPropInterpreterValidate extends AnyFunSuite { Logger.setLevel(LogLevel.ERROR) def testInterpretConstProp(name: String, variation: String, path: String): Unit = { - val directoryPath = path + name + "/" - val variationPath = directoryPath + variation + "/" + name + val variationPath = path + name + "/" + variation + "/" + name val loading = ILLoadingConfig( inputFile = variationPath + ".adt", relfFile = variationPath + ".relf", diff --git a/src/test/scala/ir/InterpreterTests.scala b/src/test/scala/ir/InterpreterTests.scala index c8f7a6b20..6d206bfaf 100644 --- a/src/test/scala/ir/InterpreterTests.scala +++ b/src/test/scala/ir/InterpreterTests.scala @@ -37,11 +37,11 @@ class InterpreterTests extends AnyFunSuite with BeforeAndAfter { Logger.setLevel(LogLevel.WARN) - def getProgram(name: String): IRContext = { + def getProgram(name: String, folder: String): IRContext = { val compiler = "gcc" val loading = ILLoadingConfig( - inputFile = s"src/test/correct/$name/$compiler/$name.adt", - relfFile = s"src/test/correct/$name/$compiler/$name.relf", + inputFile = s"src/test/$folder/$name/$compiler/$name.adt", + relfFile = s"src/test/$folder/$name/$compiler/$name.relf", specFile = None, dumpIL = None ) @@ -61,8 +61,8 @@ class InterpreterTests extends AnyFunSuite with BeforeAndAfter { ctx } - def testInterpret(name: String, expected: Map[String, Int]): Unit = { - val ctx = getProgram(name) + def testInterpret(name: String, folder: String, expected: Map[String, Int]): Unit = { + val ctx = getProgram(name, folder) val fstate = interpret(ctx) val regs = fstate.memoryState.getGlobalVals val globals = ctx.globals @@ -136,21 +136,21 @@ class InterpreterTests extends AnyFunSuite with BeforeAndAfter { val expected = Map( "arr" -> 0 ) - testInterpret("basic_arrays_read", expected) + testInterpret("basic_arrays_read", "correct", expected) } test("basic_assign_assign") { val expected = Map( "x" -> 5 ) - testInterpret("basic_assign_assign", expected) + testInterpret("basic_assign_assign", "correct", expected) } test("basic_assign_increment") { val expected = Map( "x" -> 1 ) - testInterpret("basic_assign_increment", expected) + testInterpret("basic_assign_increment", "correct", expected) } @@ -159,7 +159,7 @@ class InterpreterTests extends AnyFunSuite with BeforeAndAfter { "x" -> 1, "y" -> 2 ) - testInterpret("function", expected) + testInterpret("function", "correct", expected) } test("function1") { @@ -167,7 +167,7 @@ class InterpreterTests extends AnyFunSuite with BeforeAndAfter { "x" -> 1, "y" -> 1410065515 // 10000000107 % 2147483648 = 1410065515 ) - testInterpret("function1", expected) + testInterpret("function1", "correct", expected) } test("secret_write") { @@ -176,19 +176,19 @@ class InterpreterTests extends AnyFunSuite with BeforeAndAfter { "x" -> 0, "secret" -> 0 ) - testInterpret("secret_write", expected) + testInterpret("secret_write", "correct", expected) } test("indirect_call") { val expected = Map[String, Int]() - testInterpret("indirect_call", expected) + testInterpret("indirect_call", "indirect_calls", expected) } test("ifglobal") { val expected = Map( "x" -> 1 ) - testInterpret("ifglobal", expected) + testInterpret("ifglobal", "correct", expected) } test("cjump") { @@ -196,7 +196,7 @@ class InterpreterTests extends AnyFunSuite with BeforeAndAfter { "x" -> 1, "y" -> 3 ) - testInterpret("cjump", expected) + testInterpret("cjump", "correct", expected) } test("initialisation") { @@ -207,21 +207,21 @@ class InterpreterTests extends AnyFunSuite with BeforeAndAfter { "y" -> ('b'.toInt) ) - testInterpret("initialisation", expected) + testInterpret("initialisation", "correct", expected) } test("no_interference_update_x") { val expected = Map( "x" -> 1 ) - testInterpret("no_interference_update_x", expected) + testInterpret("no_interference_update_x", "correct", expected) } test("no_interference_update_y") { val expected = Map( "y" -> 1 ) - testInterpret("no_interference_update_y", expected) + testInterpret("no_interference_update_y", "correct", expected) } def fib(n: Int): Int = { From fe3678facf0395022fb005811195bd2a4f7bd22c Mon Sep 17 00:00:00 2001 From: l-kent Date: Mon, 11 Nov 2024 09:54:33 +1000 Subject: [PATCH 62/62] give State filename consistent with Package, formatting, clean up imports --- src/main/scala/ir/eval/ExprEval.scala | 2 +- src/main/scala/ir/eval/InterpretBasilIR.scala | 128 +++++------ .../scala/ir/eval/InterpretBreakpoints.scala | 93 ++++---- src/main/scala/ir/eval/InterpretRLimit.scala | 29 +-- src/main/scala/ir/eval/InterpretTrace.scala | 39 ++-- src/main/scala/ir/eval/Interpreter.scala | 213 ++++++++---------- .../scala/ir/eval/InterpreterProduct.scala | 119 +++++----- .../State.scala} | 44 ++-- 8 files changed, 302 insertions(+), 365 deletions(-) rename src/main/scala/util/{functional.scala => functional/State.scala} (59%) diff --git a/src/main/scala/ir/eval/ExprEval.scala b/src/main/scala/ir/eval/ExprEval.scala index 34b9e3b40..6a9fb8829 100644 --- a/src/main/scala/ir/eval/ExprEval.scala +++ b/src/main/scala/ir/eval/ExprEval.scala @@ -1,7 +1,7 @@ package ir.eval import ir.eval.BitVectorEval import util.functional.State -import ir._ +import ir.* /** We generalise the expression evaluator to a partial evaluator to simplify evaluating casts. * diff --git a/src/main/scala/ir/eval/InterpretBasilIR.scala b/src/main/scala/ir/eval/InterpretBasilIR.scala index a544e127a..6554b71c9 100644 --- a/src/main/scala/ir/eval/InterpretBasilIR.scala +++ b/src/main/scala/ir/eval/InterpretBasilIR.scala @@ -1,19 +1,9 @@ package ir.eval -import ir._ -import ir.eval.BitVectorEval.* import ir.* import util.IRContext import util.Logger import util.functional.* -import util.functional.State.* import boogie.Scope -import collection.mutable.ArrayBuffer - -import scala.annotation.tailrec -import scala.collection.mutable -import scala.collection.immutable -import scala.util.control.Breaks.{break, breakable} -import translating.ELFSymbol /** Abstraction for memload and variable lookup used by the expression evaluator. */ @@ -22,10 +12,12 @@ case class StVarLoader[S, F <: Effects[S, InterpreterError]](f: F) extends Loade def getVariable(v: Variable): State[S, Option[Literal], InterpreterError] = { for { v <- f.loadVar(v.name) - } yield ((v match { - case Scalar(l) => Some(l) - case _ => None - })) + } yield { + v match { + case Scalar(l) => Some(l) + case _ => None + } + } } override def loadMemory( @@ -46,9 +38,9 @@ case class StVarLoader[S, F <: Effects[S, InterpreterError]](f: F) extends Loade } ) case l: Literal => Eval.loadBV(f)(m.name, Scalar(l), endian, size).map(Some(_)) - case _ => get((s: S) => None) + case _ => State.get((s: S) => None) } - } yield (r) + } yield r } } @@ -67,7 +59,7 @@ case object Eval { val ldr = StVarLoader[S, T](f) for { res <- ir.eval.statePartialEvalExpr[S](ldr)(e) - } yield (res) + } yield res } def evalBV[S, T <: Effects[S, InterpreterError]](f: T)(e: Expr): State[S, BitVecLiteral, InterpreterError] = { @@ -77,7 +69,7 @@ case object Eval { case l: BitVecLiteral => Right(l) case _ => Left((Errored(s"Eval BV residual $e"))) }) - } yield (r) + } yield r } def evalInt[S, T <: Effects[S, InterpreterError]](f: T)(e: Expr): State[S, BigInt, InterpreterError] = { @@ -87,7 +79,7 @@ case object Eval { case l: IntLiteral => Right(l.value) case _ => Left((Errored(s"Eval Int residual $e"))) }) - } yield (r) + } yield r } def evalBool[S, T <: Effects[S, InterpreterError]](f: T)(e: Expr): State[S, Boolean, InterpreterError] = { @@ -97,7 +89,7 @@ case object Eval { case l: BoolLit => Right(l == TrueLiteral) case _ => Left((Errored(s"Eval Bool residual $e"))) }) - } yield (r) + } yield r } /*--------------------------------------------------------------------------------*/ @@ -109,8 +101,8 @@ case object Eval { )(vname: String, addr: Scalar, endian: Endian, count: Int): State[S, List[BasilValue], InterpreterError] = { for { _ <- - if (count == 0) then State.setError((Errored(s"Attempted fractional load"))) else State.pure(()) - keys <- State.mapM(((i: Int) => State.pureE(BasilValue.unsafeAdd(addr, i))), (0 until count)) + if count == 0 then State.setError((Errored(s"Attempted fractional load"))) else State.pure(()) + keys <- State.mapM((i: Int) => State.pureE(BasilValue.unsafeAdd(addr, i)), 0 until count) values <- f.loadMem(vname, keys.toList) vals = endian match { case Endian.LittleEndian => values.reverse @@ -133,7 +125,7 @@ case object Eval { cells = size / valsize res <- load(f)(vname, addr, endian, cells) // actual load - bvs: List[BitVecLiteral] <- ( + bvs: List[BitVecLiteral] <- State.mapM( (c: BasilValue) => c match { @@ -145,15 +137,18 @@ case object Eval { }, res ) - ) - } yield (bvs.foldLeft(BitVecLiteral(0, 0))((acc, r) => eval.evalBVBinExpr(BVCONCAT, acc, r))) + } yield { + bvs.foldLeft(BitVecLiteral(0, 0))((acc, r) => eval.evalBVBinExpr(BVCONCAT, acc, r)) + } def loadSingle[S, T <: Effects[S, InterpreterError]]( f: T )(vname: String, addr: Scalar): State[S, BasilValue, InterpreterError] = { for { m <- load(f)(vname, addr, Endian.LittleEndian, 1) - } yield (m.head) + } yield { + m.head + } } /*--------------------------------------------------------------------------------*/ @@ -181,7 +176,7 @@ case object Eval { case Endian.BigEndian => values } x <- f.storeMem(vname, keys.zip(vals).toMap) - } yield (x) + } yield x /** Extract bitvec to bytes and store bytes */ def storeBV[S, T <: Effects[S, InterpreterError]](f: T)( @@ -218,7 +213,7 @@ case object Eval { keys <- State.mapM((i: Int) => State.pureE(BasilValue.unsafeAdd(addr, i)), (0 until cells)) s <- f.storeMem(vname, keys.zip(vs).toMap) - } yield (s) + } yield s def storeSingle[S, E, T <: Effects[S, E]]( f: T @@ -246,11 +241,11 @@ case object Eval { for { nsrc <- State.pureE(BasilValue.unsafeAdd(src, 1)) r <- getNullTerminatedString(f)(rgn, nsrc, acc.appended(b)) - } yield (r) + } yield r } case _ => State.setError(Errored(s"not byte $c")) } - } yield (res) + } yield res } class BASILInterpreter[S](f: Effects[S, InterpreterError]) extends Interpreter[S, InterpreterError](f) { @@ -259,18 +254,18 @@ class BASILInterpreter[S](f: Effects[S, InterpreterError]) extends Interpreter[S val next = for { next <- f.getNext _ <- State.pure(Logger.debug(s"$next")) - r: Boolean <- (next match { + r: Boolean <- next match { case Intrinsic(tgt) => LibcIntrinsic.intrinsics(tgt)(f).map(_ => true) case Run(c: Statement) => interpretStatement(f)(c).map(_ => true) case Run(c: Jump) => interpretJump(f)(c).map(_ => true) case Stopped() => State.pure(false) case ErrorStop(e) => State.pure(false) - }) - } yield (r) + } + } yield r - next.flatMapE((e: InterpreterError) => { + next.flatMapE { (e: InterpreterError) => f.setNext(ErrorStop(e)).map(_ => false) - }) + } } def interpretJump[S, T <: Effects[S, InterpreterError]](f: T)(j: Jump): State[S, Unit, InterpreterError] = { @@ -289,14 +284,14 @@ class BASILInterpreter[S](f: Effects[S, InterpreterError]) extends Interpreter[S } else { State.pure(()) } - chosen: List[Assume] <- filterM((a: Assume) => Eval.evalBool(f)(a.body), assumes) + chosen: List[Assume] <- State.filterM((a: Assume) => Eval.evalBool(f)(a.body), assumes) res <- chosen match { case Nil => State.setError(Errored(s"No jump target satisfied $gt")) case h :: Nil => f.setNext(Run(h)) case h :: tl => State.setError(Errored(s"More than one jump guard satisfied $gt")) } - } yield (res) + } yield res case r: Return => f.doReturn() case h: Unreachable => State.setError(EscapedControlFlow(h)) } @@ -309,7 +304,7 @@ class BASILInterpreter[S](f: Effects[S, InterpreterError]) extends Interpreter[S rhs <- Eval.evalBV(f)(assign.rhs) st <- f.storeVar(assign.lhs.name, assign.lhs.toBoogie.scope, Scalar(rhs)) n <- f.setNext(Run(s.successor)) - } yield (st) + } yield st } case assign: MemoryAssign => for { @@ -317,28 +312,28 @@ class BASILInterpreter[S](f: Effects[S, InterpreterError]) extends Interpreter[S value: BitVecLiteral <- Eval.evalBV(f)(assign.value) _ <- Eval.storeBV(f)(assign.mem.name, Scalar(index), value, assign.endian) n <- f.setNext(Run(s.successor)) - } yield (n) + } yield n case assert: Assert => for { b <- Eval.evalBool(f)(assert.body) _ <- - (if (!b) then { - State.setError(FailedAssertion(assert)) - } else { - f.setNext(Run(s.successor)) - }) + if (!b) { + State.setError(FailedAssertion(assert)) + } else { + f.setNext(Run(s.successor)) + } } yield () case assume: Assume => for { b <- Eval.evalBool(f)(assume.body) n <- - (if (!b) { + if (!b) { State.setError(Errored(s"Assumption not satisfied: $assume")) } else { f.setNext(Run(s.successor)) - }) - } yield (n) - case dc: DirectCall => { + } + } yield n + case dc: DirectCall => if (dc.target.entryBlock.isDefined) { val block = dc.target.entryBlock.get f.call(dc.target.name, Run(block.statements.headOption.getOrElse(block.jump)), Run(dc.successor)) @@ -347,8 +342,7 @@ class BASILInterpreter[S](f: Effects[S, InterpreterError]) extends Interpreter[S } else { State.setError(EscapedControlFlow(dc)) } - } - case ic: IndirectCall => { + case ic: IndirectCall => if (ic.target == Register("R30", 64)) { f.doReturn() } else { @@ -357,11 +351,10 @@ class BASILInterpreter[S](f: Effects[S, InterpreterError]) extends Interpreter[S fp <- f.evalAddrToProc(addr.value.toInt) _ <- fp match { case Some(fp) => f.call(fp.name, fp.call, Run(ic.successor)) - case none => State.setError(EscapedControlFlow(ic)) + case None => State.setError(EscapedControlFlow(ic)) } } yield () } - } case _: NOP => f.setNext(Run(s.successor)) } } @@ -416,23 +409,23 @@ object InterpFuns { // sort for deterministic trace val stores = fptrs - .sortBy(f => f._1) - .map((p) => { + .sortBy(f => f(0)) + .map { p => val (offset, fptr) = p Eval.storeSingle(s)("ghost-funtable", Scalar(fptr.addr), fptr) - >> (Eval.storeBV(s)( + >> Eval.storeBV(s)( "mem", Scalar(BitVecLiteral(offset, 64)), fptr.addr, Endian.LittleEndian - )) - }) + ) + } for { _ <- State.sequence(State.pure(()), stores) malloc_top = BitVecLiteral(newAddr() + 1024, 64) _ <- s.storeVar("ghost_malloc_top", Scope.Global, Scalar(malloc_top)) - } yield (()) + } yield () } /** Functions which compile BASIL IR down to the minimal interpreter effects. @@ -456,7 +449,7 @@ object InterpFuns { l <- s.storeVar("R1", Scope.Global, Scalar(BitVecLiteral(0, 64))) _ <- s.storeVar("ghost-funtable", Scope.Global, BasilMapValue(Map.empty, MapType(BitVecType(64), BitVecType(64)))) _ <- IntrinsicImpl.initFileGhostRegions(s) - } yield (l) + } yield l } def initialiseProgram[S, T <: Effects[S, InterpreterError]](f: T)(p: Program): State[S, Unit, InterpreterError] = { @@ -465,15 +458,15 @@ object InterpFuns { m <- State.sequence( State.pure(()), mems - .filter(m => m.address != 0 && m.bytes.size != 0) - .map(memory => + .filter(m => m.address != 0 && m.bytes.nonEmpty) + .map { memory => Eval.store(f)( mem, Scalar(BitVecLiteral(memory.address, 64)), memory.bytes.toList.map(Scalar(_)), Endian.BigEndian ) - ) + } ) } yield () } @@ -484,13 +477,13 @@ object InterpFuns { State.pure(Logger.debug("INITIALISE FUNCTION ADDRESSES")), p.procedures .filter(p => p.blocks.nonEmpty && p.address.isDefined) - .map((proc: Procedure) => + .map { (proc: Procedure) => Eval.storeSingle(f)( "ghost-funtable", Scalar(BitVecLiteral(proc.address.get, 64)), FunPointer(BitVecLiteral(proc.address.get, 64), proc.name, Run(IRWalk.firstInBlock(proc.entryBlock.get))) ) - ) + } ) _ <- State.pure(Logger.debug("INITIALISE MEMORY SECTIONS")) mem <- initMemory("mem", p.initialMemory.values) @@ -499,20 +492,19 @@ object InterpFuns { p.mainProcedure } r <- f.call(mainfun.name, Run(IRWalk.firstInBlock(mainfun.entryBlock.get)), Stopped()) - } yield (r) + } yield r } def initBSS[S, T <: Effects[S, InterpreterError]](f: T)(p: IRContext): State[S, Unit, InterpreterError] = { val bss = for { first <- p.symbols.find(s => s.name == "__bss_start__").map(_.value) last <- p.symbols.find(s => s.name == "__bss_end__").map(_.value) - r <- (if (first == last) then None else Some((first, (last - first) * 8))) + r <- if first == last then None else Some((first, (last - first) * 8)) (addr, sz) = r st = { (rgn => Eval.storeBV(f)(rgn, Scalar(BitVecLiteral(addr, 64)), BitVecLiteral(0, sz.toInt), Endian.LittleEndian)) } - - } yield (st) + } yield st bss match { case None => Logger.error("No BSS initialised"); State.pure(()) diff --git a/src/main/scala/ir/eval/InterpretBreakpoints.scala b/src/main/scala/ir/eval/InterpretBreakpoints.scala index ced1eecb8..1620e1b77 100644 --- a/src/main/scala/ir/eval/InterpretBreakpoints.scala +++ b/src/main/scala/ir/eval/InterpretBreakpoints.scala @@ -1,28 +1,18 @@ package ir.eval -import ir._ -import ir.eval.BitVectorEval.* import ir.* import util.Logger import util.IRContext import util.functional.* -import util.functional.State.* -import boogie.Scope -import scala.collection.WithFilter - -import scala.annotation.tailrec -import scala.collection.mutable -import scala.collection.immutable -import scala.util.control.Breaks.{break, breakable} enum BreakPointLoc: case CMD(c: Command) case CMDCond(c: Command, condition: Expr) case class BreakPointAction( - saveState: Boolean = true, - stop: Boolean = false, - evalExprs: List[(String, Expr)] = List(), - log: Boolean = false + saveState: Boolean = true, + stop: Boolean = false, + evalExprs: List[(String, Expr)] = List(), + log: Boolean = false ) case class BreakPoint(name: String = "", location: BreakPointLoc, action: BreakPointAction) @@ -34,7 +24,7 @@ case class RememberBreakpoints[T, I <: Effects[T, InterpreterError]](f: I, break State.filterM( b => b.location match { - case BreakPointLoc.CMD(bc) if (bc == c) => State.pure(true) + case BreakPointLoc.CMD(bc) if bc == c => State.pure(true) case BreakPointLoc.CMDCond(bc, e) if bc == c => doLeft(Eval.evalBool(f)(e)) case _ => State.pure(false) }, @@ -56,49 +46,46 @@ case class RememberBreakpoints[T, I <: Effects[T, InterpreterError]](f: I, break res <- State .sequence[(T, List[(BreakPoint, Option[T], List[(String, Expr, Expr)])]), Unit, InterpreterError]( State.pure(()), - breaks.map((breakpoint: BreakPoint) => - (breakpoint match { - case breakpoint @ BreakPoint(name, stopcond, action) => ( - for { - saved <- doLeft( - if action.saveState then State.getS[T, InterpreterError].map(s => Some(s)) - else State.pure(None) - ) - evals <- (State.mapM( - (e: (String, Expr)) => - for { - ev <- doLeft(Eval.evalExpr(f)(e._2)) - } yield (e._1, e._2, ev), - action.evalExprs - )) - _ <- State.pure({ - if (action.log) { - val bpn = breakpoint.name - val bpcond = breakpoint.location match { - case BreakPointLoc.CMD(c) => s"${c.parent.label}:$c" - case BreakPointLoc.CMDCond(c, e) => s"${c.parent.label}:$c when $e" - } - val saving = if action.saveState then " stashing state, " else "" - val stopping = if action.stop then " stopping. " else "" - val evalstr = evals.map(e => s"\n ${e._1} : eval(${e._2}) = ${e._3}").mkString("") - Logger.warn(s"Breakpoint $bpn@$bpcond.$saving$stopping$evalstr") + breaks.map { + case breakpoint @ BreakPoint(name, stopcond, action) => + for { + saved <- doLeft( + if action.saveState then State.getS[T, InterpreterError].map(s => Some(s)) + else State.pure(None) + ) + evals <- State.mapM( + (e: (String, Expr)) => + for { + ev <- doLeft(Eval.evalExpr(f)(e(1))) + } yield (e(0), e(1), ev), + action.evalExprs + ) + _ <- State.pure({ + if (action.log) { + val bpn = breakpoint.name + val bpcond = breakpoint.location match { + case BreakPointLoc.CMD(c) => s"${c.parent.label}:$c" + case BreakPointLoc.CMDCond(c, e) => s"${c.parent.label}:$c when $e" } - }) - _ <- - if action.stop then doLeft(f.setNext(ErrorStop(Errored(s"Stopped at breakpoint ${name}")))) - else doLeft(State.pure(())) - _ <- State.modify((istate: (T, List[(BreakPoint, Option[T], List[(String, Expr, Expr)])])) => - (istate._1, ((breakpoint, saved, evals) :: istate._2)) - ) - } yield () - ) - }) - ) + val saving = if action.saveState then " stashing state, " else "" + val stopping = if action.stop then " stopping. " else "" + val evalstr = evals.map(e => s"\n ${e(0)} : eval(${e(1)}) = ${e(2)}").mkString("") + Logger.warn(s"Breakpoint $bpn@$bpcond.$saving$stopping$evalstr") + } + }) + _ <- + if action.stop then doLeft(f.setNext(ErrorStop(Errored(s"Stopped at breakpoint ${name}")))) + else doLeft(State.pure(())) + _ <- State.modify((istate: (T, List[(BreakPoint, Option[T], List[(String, Expr, Expr)])])) => + (istate(0), ((breakpoint, saved, evals) :: istate(1))) + ) + } yield () + } ) } yield () case _ => State.pure(()) } - } yield (v) + } yield v } } diff --git a/src/main/scala/ir/eval/InterpretRLimit.scala b/src/main/scala/ir/eval/InterpretRLimit.scala index a381e5d8f..8bc8e5568 100644 --- a/src/main/scala/ir/eval/InterpretRLimit.scala +++ b/src/main/scala/ir/eval/InterpretRLimit.scala @@ -1,33 +1,22 @@ package ir.eval -import ir._ -import ir.eval.BitVectorEval.* import ir.* import util.IRContext import util.Logger import util.functional.* -import util.functional.State.* -import boogie.Scope -import scala.collection.WithFilter -import scala.annotation.tailrec -import scala.collection.mutable -import scala.collection.immutable -import scala.util.control.Breaks.{break, breakable} +case class EffectsRLimit[T, E, I <: Effects[T, InterpreterError]](limit: Int) extends NopEffects[(T, Int), InterpreterError] { - -case class EffectsRLimit[T, E, I <: Effects[T, InterpreterError]](val limit: Int) extends NopEffects[(T, Int), InterpreterError] { - - override def getNext :State[(T, Int), ExecutionContinuation, InterpreterError] = { + override def getNext: State[(T, Int), ExecutionContinuation, InterpreterError] = { for { - c : (T, Int) <- State.getS + c: (T, Int) <- State.getS (is, resources) = c _ <- if (resources >= limit && limit >= 0) { State.setError(Errored(s"Resource limit $limit reached")) } else { - State.modify ((s : (T, Int)) => (s._1, s._2 + 1)) + State.modify((s: (T, Int)) => (s(0), s(1) + 1)) } - } yield (Stopped()) // thrown away by LayerInterpreter + } yield Stopped() // thrown away by LayerInterpreter } } @@ -38,15 +27,15 @@ def interpretWithRLimit[I](p: Program, instructionLimit: Int, innerInterpreter: def interpretWithRLimit[I](p: IRContext, instructionLimit: Int, innerInterpreter: Effects[I, InterpreterError], innerInitialState: I): (I, Int) = { val rlimitInterpreter = LayerInterpreter(innerInterpreter, EffectsRLimit(instructionLimit)) - val begin = InterpFuns.initProgState(rlimitInterpreter)(p, (innerInitialState, 0)) + val (begin, _) = InterpFuns.initProgState(rlimitInterpreter)(p, (innerInitialState, 0)) // throw away initialisation trace - BASILInterpreter(rlimitInterpreter).run((begin._1, 0)) + BASILInterpreter(rlimitInterpreter).run((begin, 0)) } -def interpretRLimit(p: Program, instructionLimit: Int) : (InterpreterState, Int) = { +def interpretRLimit(p: Program, instructionLimit: Int): (InterpreterState, Int) = { interpretWithRLimit(p, instructionLimit, NormalInterpreter, InterpreterState()) } -def interpretRLimit(p: IRContext, instructionLimit: Int) : (InterpreterState, Int) = { +def interpretRLimit(p: IRContext, instructionLimit: Int): (InterpreterState, Int) = { interpretWithRLimit(p, instructionLimit, NormalInterpreter, InterpreterState()) } diff --git a/src/main/scala/ir/eval/InterpretTrace.scala b/src/main/scala/ir/eval/InterpretTrace.scala index 89b0fc976..7087463c3 100644 --- a/src/main/scala/ir/eval/InterpretTrace.scala +++ b/src/main/scala/ir/eval/InterpretTrace.scala @@ -1,18 +1,9 @@ package ir.eval -import ir._ -import ir.eval.BitVectorEval.* import ir.* import util.IRContext import util.Logger import util.functional.* -import util.functional.State.* import boogie.Scope -import scala.collection.WithFilter - -import scala.annotation.tailrec -import scala.collection.mutable -import scala.collection.immutable -import scala.util.control.Breaks.{break, breakable} enum ExecEffect: case Call(target: String, begin: ExecutionContinuation, returnTo: ExecutionContinuation) @@ -23,7 +14,7 @@ enum ExecEffect: case LoadMem(vname: String, addrs: List[BasilValue]) case FindProc(addr: Int) -case class Trace(val t: List[ExecEffect]) +case class Trace(t: List[ExecEffect]) case object Trace { def add[E](e: ExecEffect): State[Trace, Unit, E] = { @@ -34,25 +25,25 @@ case object Trace { case class TraceGen[E]() extends NopEffects[Trace, E] { /** Values are discarded by ProductInterpreter so do not matter */ - override def loadMem(v: String, addrs: List[BasilValue]) = for { + override def loadMem(v: String, addrs: List[BasilValue]): State[Trace, List[BasilValue], E] = for { s <- Trace.add(ExecEffect.LoadMem(v, addrs)) - } yield (List()) + } yield List() - override def call(target: String, beginFrom: ExecutionContinuation, returnTo: ExecutionContinuation) = for { + override def call(target: String, beginFrom: ExecutionContinuation, returnTo: ExecutionContinuation): State[Trace, Unit, E] = for { s <- Trace.add(ExecEffect.Call(target, beginFrom, returnTo)) - } yield (()) + } yield () - override def doReturn() = for { + override def doReturn(): State[Trace, Unit, E] = for { s <- Trace.add(ExecEffect.Return) - } yield (()) + } yield () - override def storeVar(v: String, scope: Scope, value: BasilValue) = for { - s <- if (!v.startsWith("ghost")) Trace.add(ExecEffect.StoreVar(v, scope, value)) else State.pure(()) - } yield (()) + override def storeVar(v: String, scope: Scope, value: BasilValue): State[Trace, Unit, E] = for { + s <- if !v.startsWith("ghost") then Trace.add(ExecEffect.StoreVar(v, scope, value)) else State.pure(()) + } yield () - override def storeMem(vname: String, update: Map[BasilValue, BasilValue]) = for { - s <- if (!vname.startsWith("ghost")) Trace.add(ExecEffect.StoreMem(vname, update)) else State.pure(()) - } yield (()) + override def storeMem(vname: String, update: Map[BasilValue, BasilValue]): State[Trace, Unit, E] = for { + s <- if !vname.startsWith("ghost") then Trace.add(ExecEffect.StoreMem(vname, update)) else State.pure(()) + } yield () } @@ -64,9 +55,9 @@ def interpretWithTrace[I](p: Program, innerInterpreter: Effects[I, InterpreterEr def interpretWithTrace[I](p: IRContext, innerInterpreter: Effects[I, InterpreterError], innerInitialState: I): (I, Trace) = { val tracingInterpreter = ProductInterpreter(innerInterpreter, TraceGen()) - val begin = InterpFuns.initProgState(tracingInterpreter)(p, (innerInitialState, Trace(List()))) + val (begin, _) = InterpFuns.initProgState(tracingInterpreter)(p, (innerInitialState, Trace(List()))) // throw away initialisation trace - BASILInterpreter(tracingInterpreter).run((begin._1, Trace(List()))) + BASILInterpreter(tracingInterpreter).run((begin, Trace(List()))) } def interpretTrace(p: Program) = { diff --git a/src/main/scala/ir/eval/Interpreter.scala b/src/main/scala/ir/eval/Interpreter.scala index 6f948269f..6d42b5cc3 100644 --- a/src/main/scala/ir/eval/Interpreter.scala +++ b/src/main/scala/ir/eval/Interpreter.scala @@ -1,16 +1,10 @@ package ir.eval -import ir.eval.BitVectorEval.* import ir.* import util.Logger import util.functional.* -import util.functional.State.* import boogie.Scope -import scala.collection.WithFilter import scala.annotation.tailrec -import scala.collection.mutable -import scala.collection.immutable -import scala.util.control.Breaks.{break, breakable} /** Interpreter status type, either stopped, run next command or error */ @@ -33,7 +27,7 @@ case class MemoryError(message: String = "") extends InterpreterError /* An erro /* Concrete value type of the interpreter. */ sealed trait BasilValue(val irType: Option[IRType]) case class Scalar(value: Literal) extends BasilValue(Some(value.getType)) { - override def toString = value match { + override def toString: _root_.java.lang.String = value match { case b: BitVecLiteral => "0x%x:bv%d".format(b.value, b.size) case c => c.toString } @@ -57,11 +51,11 @@ case class BasilMapValue(value: Map[BasilValue, BasilValue], mapType: MapType) override def toString = s"MapValue : $irType" } -case class GenMapValue(val value: Map[BasilValue, BasilValue]) extends BasilValue(None) with MapValue { +case class GenMapValue(value: Map[BasilValue, BasilValue]) extends BasilValue(None) with MapValue { override def toString = s"GenMapValue : $irType" } -case class Symbol(val value: String) extends BasilValue(None) +case class Symbol(value: String) extends BasilValue(None) case object BasilValue { @@ -75,7 +69,7 @@ case object BasilValue { def toBV[S, E](l: BasilValue): Either[InterpreterError, BitVecLiteral] = { l match { case Scalar(b1: BitVecLiteral) => Right(b1) - case _ => Left((TypeError(s"Not a bitvector add $l"))) + case _ => Left(TypeError(s"Not a bitvector add $l")) } } @@ -84,7 +78,7 @@ case object BasilValue { case _ if vr == 0 => Right(l) case Scalar(IntLiteral(vl)) => Right(Scalar(IntLiteral(vl + vr))) case Scalar(b1: BitVecLiteral) => Right(Scalar(eval.evalBVBinExpr(BVADD, b1, BitVecLiteral(vr, b1.size)))) - case _ => Left((TypeError(s"Operation add $vr undefined on $l"))) + case _ => Left(TypeError(s"Operation add $vr undefined on $l")) } } @@ -92,7 +86,7 @@ case object BasilValue { (l, r) match { case (Scalar(IntLiteral(vl)), Scalar(IntLiteral(vr))) => Right(Scalar(IntLiteral(vl + vr))) case (Scalar(b1: BitVecLiteral), Scalar(b2: BitVecLiteral)) => Right(Scalar(eval.evalBVBinExpr(BVADD, b1, b2))) - case _ => Left((TypeError(s"Operation add undefined $l + $r"))) + case _ => Left(TypeError(s"Operation add undefined $l + $r")) } } @@ -133,18 +127,18 @@ trait Effects[T, E] { } trait NopEffects[T, E] extends Effects[T, E] { - def loadVar(v: String) = State.pure(Scalar(FalseLiteral)) - def loadMem(v: String, addrs: List[BasilValue]) = State.pure(List()) - def evalAddrToProc(addr: Int) = State.pure(None) - def getNext = State.pure(Stopped()) - def setNext(c: ExecutionContinuation) = State.pure(()) - - def call(target: String, beginFrom: ExecutionContinuation, returnTo: ExecutionContinuation) = State.pure(()) - def callIntrinsic(name: String, args: List[BasilValue]) = State.pure(None) - def doReturn() = State.pure(()) - - def storeVar(v: String, scope: Scope, value: BasilValue) = State.pure(()) - def storeMem(vname: String, update: Map[BasilValue, BasilValue]) = State.pure(()) + def loadVar(v: String): State[T, BasilValue, E] = State.pure(Scalar(FalseLiteral)) + def loadMem(v: String, addrs: List[BasilValue]): State[T, List[BasilValue], E] = State.pure(List()) + def evalAddrToProc(addr: Int): State[T, Option[FunPointer], E] = State.pure(None) + def getNext: State[T, ExecutionContinuation, E] = State.pure(Stopped()) + def setNext(c: ExecutionContinuation): State[T, Unit, E] = State.pure(()) + + def call(target: String, beginFrom: ExecutionContinuation, returnTo: ExecutionContinuation): State[T, Unit, E] = State.pure(()) + def callIntrinsic(name: String, args: List[BasilValue]): State[T, Option[BasilValue], E] = State.pure(None) + def doReturn(): State[T, Unit, E] = State.pure(()) + + def storeVar(v: String, scope: Scope, value: BasilValue): State[T, Unit, E] = State.pure(()) + def storeMem(vname: String, update: Map[BasilValue, BasilValue]): State[T, Unit, E] = State.pure(()) } /*-------------------------------------------------------------------------------- @@ -159,28 +153,27 @@ case class MemoryState( * - activations is the call stack, the top of which indicates the current stackFrame. * - activationCount: (procedurename -> int) is used to create uniquely-named stackframes. */ - val stackFrames: Map[StackFrameID, Map[String, BasilValue]] = Map((globalFrame -> Map.empty)), - val activations: List[StackFrameID] = List.empty, - val activationCount: Map[String, Int] = Map.empty.withDefault(_ => 0) + stackFrames: Map[StackFrameID, Map[String, BasilValue]] = Map((globalFrame -> Map.empty)), + activations: List[StackFrameID] = List.empty, + activationCount: Map[String, Int] = Map.empty.withDefault(_ => 0) ) { /** Debug return useful values * */ def getGlobalVals: Map[String, BitVecLiteral] = { - stackFrames(globalFrame).collect { case (k, Scalar(b: BitVecLiteral)) => - k -> b + stackFrames(globalFrame).collect { + case (k, Scalar(b: BitVecLiteral)) => k -> b } } def getMem(name: String): Map[BitVecLiteral, BitVecLiteral] = { stackFrames(globalFrame)(name) match { - case BasilMapValue(innerMap, MapType(BitVecType(ks), BitVecType(vs))) => { + case BasilMapValue(innerMap, MapType(BitVecType(ks), BitVecType(vs))) => def unwrap(v: BasilValue): BitVecLiteral = v match { case Scalar(b: BitVecLiteral) => b case v => throw Exception(s"Failed to convert map value to bitvector: $v (interpreter type error somewhere)") } innerMap.map((k, v) => unwrap(k) -> unwrap(v)) - } case v => throw Exception(s"$name not a bitvec map variable: ${v.irType}") } } @@ -200,11 +193,11 @@ case class MemoryState( case h :: Nil if h == globalFrame => Left((Errored("tried to pop global scope"))) case h :: tl => Right((h, tl)) } - hv.map((hv) => { + hv.map { hv => val (frame, remactivs) = hv val frames = stackFrames.removed(frame) MemoryState(frames, remactivs, activationCount) - }) + } } /* Variable retrieval and setting */ @@ -217,7 +210,7 @@ case class MemoryState( /* Find variable definition scope and set it in the correct frame */ def setVar(v: String, value: BasilValue): MemoryState = { - val frame = findVarOpt(v).map(_._1).getOrElse(activations.head) + val frame = findVarOpt(v).map(_(0)).getOrElse(activations.head) setVar(frame, v, value) } @@ -235,24 +228,24 @@ case class MemoryState( /* Lookup the value of a variable */ def findVarOpt(name: String): Option[(StackFrameID, BasilValue)] = { val searchScopes = globalFrame :: activations.headOption.toList - searchScopes.foldRight(None: Option[(StackFrameID, BasilValue)])((r, acc) => + searchScopes.foldRight(None: Option[(StackFrameID, BasilValue)]) { (r, acc) => acc match { case None => stackFrames(r).get(name).map(v => (r, v)) - case s => s + case s => s } - ) + } } def findVar(name: String): Either[InterpreterError, (StackFrameID, BasilValue)] = { findVarOpt(name: String) .map(Right(_)) - .getOrElse(Left((Errored(s"Access to undefined variable $name")))) + .getOrElse(Left(Errored(s"Access to undefined variable $name"))) } - def getVarOpt(name: String): Option[BasilValue] = findVarOpt(name).map(_._2) + def getVarOpt(name: String): Option[BasilValue] = findVarOpt(name).map(_(1)) def getVar(name: String): Either[InterpreterError, BasilValue] = { - getVarOpt(name).map(Right(_)).getOrElse(Left((Errored(s"Access undefined variable $name")))) + getVarOpt(name).map(Right(_)).getOrElse(Left(Errored(s"Access undefined variable $name"))) } def getVar(v: Variable): Either[InterpreterError, BasilValue] = { @@ -268,47 +261,45 @@ case class MemoryState( /* Map variable accessing ; load and store operations */ def doLoad(vname: String, addr: List[BasilValue]): Either[InterpreterError, List[BasilValue]] = for { v <- findVar(vname) - mapv: MapValue <- v._2 match { + mapv: MapValue <- v(1) match { case m: MapValue => Right(m) - case m => Left((TypeError(s"Load from nonmap ${m.irType}"))) + case m => Left(TypeError(s"Load from nonmap ${m.irType}")) } rs: List[Option[BasilValue]] = addr.map(k => mapv.value.get(k)) - xs <- - (if (rs.forall(_.isDefined)) { - Right(rs.map(_.get)) - } else { - Left((MemoryError(s"Read from uninitialised $vname[${addr.head} .. ${addr.last}]"))) - }) - } yield (xs) + xs <- + if (rs.forall(_.isDefined)) { + Right(rs.map(_.get)) + } else { + Left(MemoryError(s"Read from uninitialised $vname[${addr.head} .. ${addr.last}]")) + } + } yield xs /** typecheck and some fields of a map variable */ def doStore(vname: String, values: Map[BasilValue, BasilValue]): Either[InterpreterError, MemoryState] = for { - - _ <- if (values.size == 0) then Left(MemoryError("Tried to store size 0")) else Right(()) + _ <- if values.isEmpty then Left(MemoryError("Tried to store size 0")) else Right(()) v <- findVar(vname) (frame, mem) = v mapval <- mem match { case m @ BasilMapValue(_, MapType(kt, vt)) => for { - m <- (values.find((k, v) => k.irType != Some(kt) || v.irType != Some(vt))) match { + m <- values.find((k, v) => k.irType != Some(kt) || v.irType != Some(vt)) match { case Some(v) => Left( TypeError( - s"Invalid addr or value type (${v._1.irType}, ${v._2.irType}) does not match map type $vname : ($kt, $vt)" + s"Invalid addr or value type (${v(0).irType}, ${v(1).irType}) does not match map type $vname : ($kt, $vt)" ) ) case None => Right(m) } nm = BasilMapValue(m.value ++ values, m.mapType) - } yield (nm) - case m @ GenMapValue(_) => { + } yield nm + case m @ GenMapValue(_) => Right(GenMapValue(m.value ++ values)) - } - case v => Left((TypeError(s"Invalid map store operation to $vname : ${v.irType}"))) + case v => Left(TypeError(s"Invalid map store operation to $vname : ${v.irType}")) } ms <- Right(setVar(frame, vname, mapval)) - } yield (ms) + } yield ms } object LibcIntrinsic { @@ -323,7 +314,7 @@ object LibcIntrinsic { res <- s.callIntrinsic(name, List(c)) _ <- if res.isDefined then s.storeVar("R0", Scope.Global, res.get) else State.pure(()) _ <- s.doReturn() - } yield (()) + } yield () def calloc[S, T <: Effects[S, InterpreterError]](s: T): State[S, Unit, InterpreterError] = for { size <- s.loadVar("R0") @@ -335,9 +326,9 @@ object LibcIntrinsic { } cl <- Eval.storeBV(s)("mem", ptr, BitVecLiteral(0, isize.toInt), Endian.LittleEndian) _ <- s.doReturn() - } yield (()) + } yield () - def intrinsics[S, T <: Effects[S, InterpreterError]] = + def intrinsics[S, T <: Effects[S, InterpreterError]]: Map[String, T => State[S, Unit, InterpreterError]] = Map[String, T => State[S, Unit, InterpreterError]]( "putc" -> singleArg("putc"), "puts" -> singleArg("puts"), @@ -359,7 +350,7 @@ object IntrinsicImpl { _ <- f.storeMem("ghost-file-bookkeeping", Map(Symbol("$$filecount") -> Scalar(BitVecLiteral(0, 64)))) _ <- f.callIntrinsic("fopen", List(Symbol("stderr"))) _ <- f.callIntrinsic("fopen", List(Symbol("stdout"))) - } yield (()) + } yield () /** Intrinsics defined over arbitrary effects * @@ -373,7 +364,7 @@ object IntrinsicImpl { _ <- f.storeMem("stdout", Map(addr.head -> Scalar(c))) naddr <- State.pureE(BasilValue.unsafeAdd(addr.head, 1)) _ <- f.storeMem("ghost-file-bookkeeping", Map(Symbol("stdout-ptr") -> naddr)) - } yield (None) + } yield None } def fopen[S, T <: Effects[S, InterpreterError]](f: T)(file: BasilValue): State[S, Option[BasilValue], InterpreterError] = { @@ -389,7 +380,7 @@ object IntrinsicImpl { _ <- f.storeVar("R0", Scope.Global, filecount.head) nfilecount <- State.pureE(BasilValue.unsafeAdd(filecount.head, 1)) _ <- f.storeMem("ghost-file-bookkeeping", Map(Symbol("$$filecount") -> nfilecount)) - } yield (Some(filecount.head)) + } yield Some(filecount.head) } def print[S, T <: Effects[S, InterpreterError]](f: T)(strptr: BasilValue): State[S, Option[BasilValue], InterpreterError] = { @@ -403,36 +394,34 @@ object IntrinsicImpl { _ <- f.storeMem("stdout", offs.zip(str.map(Scalar(_))).toMap) naddr <- State.pureE(BasilValue.unsafeAdd(baseptr.head, str.size)) _ <- f.storeMem("ghost-file-bookkeeping", Map(Symbol("stdout-ptr") -> naddr)) - } yield (None) + } yield None } def malloc[S, T <: Effects[S, InterpreterError]](f: T)(size: BasilValue): State[S, Option[BasilValue], InterpreterError] = { for { size <- (size match { - case (x @ Scalar(_: BitVecLiteral)) => State.pure(x) - case (Scalar(x: IntLiteral)) => State.pure(Scalar(BitVecLiteral(x.value, 64))) - case _ => State.setError(Errored("illegal prim arg")) + case x @ Scalar(_: BitVecLiteral) => State.pure(x) + case Scalar(x: IntLiteral) => State.pure(Scalar(BitVecLiteral(x.value, 64))) + case _ => State.setError(Errored("illegal prim arg")) }) x <- f.loadVar("ghost_malloc_top") x_gap <- State.pureE(BasilValue.unsafeAdd(x, 128)) // put a gap around allocations to catch buffer overflows x_end <- State.pureE(BasilValue.add(x_gap, size)) _ <- f.storeVar("ghost_malloc_top", Scope.Global, x_end) _ <- f.storeVar("R0", Scope.Global, x_gap) - } yield (Some(x_gap)) + } yield Some(x_gap) } } case class InterpreterState( - val nextCmd: ExecutionContinuation = Stopped(), - val callStack: List[ExecutionContinuation] = List.empty, - val memoryState: MemoryState = MemoryState() + nextCmd: ExecutionContinuation = Stopped(), + callStack: List[ExecutionContinuation] = List.empty, + memoryState: MemoryState = MemoryState() ) /** Implementation of Effects for InterpreterState concrete state representation. */ object NormalInterpreter extends Effects[InterpreterState, InterpreterError] { - - def callIntrinsic( name: String, args: List[BasilValue] @@ -447,70 +436,68 @@ object NormalInterpreter extends Effects[InterpreterState, InterpreterError] { str <- Eval.getNullTerminatedString(this)("mem", args.head) r = Scalar(BitVecLiteral(str.length, 64)) _ <- storeVar("R0", Scope.Global, r) - } yield (Some(r)) + } yield Some(r) case "print" => IntrinsicImpl.print(this)(args.head) case "puts" => IntrinsicImpl.print(this)(args.head) >> IntrinsicImpl.putc(this)(Scalar(BitVecLiteral('\n'.toInt, 64))) case _ => State.setError(Errored(s"Call undefined intrinsic $name")) } } - def loadVar(v: String) = { + def loadVar(v: String): State[InterpreterState, BasilValue, InterpreterError] = { State.getE((s: InterpreterState) => { s.memoryState.getVar(v) }) } - def evalAddrToProc(addr: Int) = + def evalAddrToProc(addr: Int): State[InterpreterState, Option[FunPointer], InterpreterError] = Logger.debug(s" eff : FIND PROC $addr") for { - res: List[BasilValue] <- getE((s: InterpreterState) => + res: List[BasilValue] <- State.getE((s: InterpreterState) => s.memoryState.doLoad("ghost-funtable", List(Scalar(BitVecLiteral(addr, 64)))) ) } yield { res match { - case ((f: FunPointer) :: Nil) => Some(f) + case (f: FunPointer) :: Nil => Some(f) case _ => None } } def formatStore(varname: String, update: Map[BasilValue, BasilValue]): String = { - val ks = update.toList.sortWith((x, y) => { + val ks = update.toList.sortWith { (x, y) => def conv(v: BasilValue): BigInt = v match { - case (Scalar(b: BitVecLiteral)) => b.value - case (Scalar(b: IntLiteral)) => b.value - case _ => BigInt(0) + case Scalar(b: BitVecLiteral) => b.value + case Scalar(b: IntLiteral) => b.value + case _ => BigInt(0) } - conv(x._1) <= conv(y._1) - }) + conv(x(0)) <= conv(y(0)) + } - val rs = ks.foldLeft(Some((None, List[BitVecLiteral]())): Option[(Option[BigInt], List[BitVecLiteral])])((acc, v) => - v match { - case (Scalar(bv: BitVecLiteral), Scalar(bv2: BitVecLiteral)) => { - acc match { - case None => None - case Some(None, l) => Some(Some(bv.value), bv2 :: l) - case Some(Some(v), l) if bv.value == v + 1 => Some(Some(bv.value), bv2 :: l) - case Some(Some(v), l) => { - None + val rs = ks.foldLeft(Some((None, List[BitVecLiteral]())): Option[(Option[BigInt], List[BitVecLiteral])]) { + (acc, v) => + v match { + case (Scalar(bv: BitVecLiteral), Scalar(bv2: BitVecLiteral)) => + acc match { + case None => None + case Some(None, l) => Some(Some(bv.value), bv2 :: l) + case Some(Some(v), l) if bv.value == v + 1 => Some(Some(bv.value), bv2 :: l) + case Some(Some(v), l) => + None } - } + case (bv, bv2) => None } - case (bv, bv2) => None - } - ) + } rs match { - case Some(_, l) => { + case Some(_, l) => val vs = Scalar(l.foldLeft(BitVecLiteral(0, 0))((acc, r) => eval.evalBVBinExpr(BVCONCAT, acc, r))).toString - s"$varname[${ks.headOption.map(_._1).getOrElse("null")}] := $vs" - } - case None if ks.length < 8 => s"$varname[${ks.map(_._1).mkString(",")}] := ${ks.map(_._2).mkString(",")}" - case None => s"$varname[${ks.map(_._1).take(8).mkString(",")}...] := ${ks.map(_._2).take(8).mkString(", ")}... " + s"$varname[${ks.headOption.map(_(0)).getOrElse("null")}] := $vs" + case None if ks.length < 8 => s"$varname[${ks.map(_(0)).mkString(",")}] := ${ks.map(_(1)).mkString(",")}" + case None => s"$varname[${ks.map(_(0)).take(8).mkString(",")}...] := ${ks.map(_(1)).take(8).mkString(", ")}... " } } - def loadMem(v: String, addrs: List[BasilValue]) = { + def loadMem(v: String, addrs: List[BasilValue]): State[InterpreterState, List[BasilValue], InterpreterError] = { State.getE((s: InterpreterState) => { val r = s.memoryState.doLoad(v, addrs) Logger.debug(s" eff : LOAD ${addrs.head} x ${addrs.size}") @@ -518,15 +505,15 @@ object NormalInterpreter extends Effects[InterpreterState, InterpreterError] { }) } - def getNext = State.get((s: InterpreterState) => s.nextCmd) + def getNext: State[InterpreterState, ExecutionContinuation, InterpreterError] = State.get((s: InterpreterState) => s.nextCmd) /** effects * */ - def setNext(c: ExecutionContinuation) = State.modify((s: InterpreterState) => { + def setNext(c: ExecutionContinuation): State[InterpreterState, Unit, InterpreterError] = State.modify((s: InterpreterState) => { s.copy(nextCmd = c) }) - def call(target: String, beginFrom: ExecutionContinuation, returnTo: ExecutionContinuation) = - modify((s: InterpreterState) => { + def call(target: String, beginFrom: ExecutionContinuation, returnTo: ExecutionContinuation): State[InterpreterState, Unit, InterpreterError] = + State.modify((s: InterpreterState) => { Logger.debug(s" eff : CALL $target") s.copy( nextCmd = beginFrom, @@ -535,9 +522,9 @@ object NormalInterpreter extends Effects[InterpreterState, InterpreterError] { ) }) - def doReturn() = { + def doReturn(): State[InterpreterState, Unit, InterpreterError] = { Logger.debug(s" eff : RETURN") - modifyE((s: InterpreterState) => { + State.modifyE((s: InterpreterState) => { s.callStack match { case Nil => Right(s.copy(nextCmd = Stopped())) case h :: tl => @@ -553,12 +540,12 @@ object NormalInterpreter extends Effects[InterpreterState, InterpreterError] { State.modify((s: InterpreterState) => s.copy(memoryState = s.memoryState.defVar(v, scope, value))) } - def storeMem(vname: String, update: Map[BasilValue, BasilValue]) = + def storeMem(vname: String, update: Map[BasilValue, BasilValue]): State[InterpreterState, Unit, InterpreterError] = State.modifyE((s: InterpreterState) => { Logger.debug(s" eff : STORE ${formatStore(vname, update)}") for { ms <- s.memoryState.doStore(vname, update) - } yield (s.copy(memoryState = ms)) + } yield s.copy(memoryState = ms) }) } @@ -573,7 +560,7 @@ trait Interpreter[S, E](val f: Effects[S, E]) { final def run(begin: S): S = { val (fs, cont) = interpretOne.f(begin) - if (cont.contains(true)) then { + if (cont.contains(true)) { run(fs) } else { fs diff --git a/src/main/scala/ir/eval/InterpreterProduct.scala b/src/main/scala/ir/eval/InterpreterProduct.scala index ccf4fb682..18841a9c3 100644 --- a/src/main/scala/ir/eval/InterpreterProduct.scala +++ b/src/main/scala/ir/eval/InterpreterProduct.scala @@ -1,139 +1,130 @@ package ir.eval -import ir._ -import ir.eval.BitVectorEval.* import ir.* import util.Logger import util.functional.* -import util.functional.State.* import boogie.Scope -import scala.collection.WithFilter - -import scala.annotation.tailrec -import scala.collection.mutable -import scala.collection.immutable -import scala.util.control.Breaks.{break, breakable} def doLeft[L, T, V, E](f: State[L, V, E]): State[(L, T), V, E] = for { n <- State[(L, T), V, E]((s: (L, T)) => { - val r = f.f(s._1) - ((r._1, s._2), r._2) + val r = f.f(s(0)) + ((r(0), s(1)), r(1)) }) -} yield (n) +} yield n def doRight[L, T, V, E](f: State[T, V, E]): State[(L, T), V, E] = for { n <- State[(L, T), V, E]((s: (L, T)) => { - val r = f.f(s._2) - ((s._1, r._1), r._2) + val r = f.f(s(1)) + ((s(0), r(0)), r(1)) }) -} yield (n) +} yield n /** Runs two interpreters "inner" and "before" simultaneously, returning the value from inner, and ignoring before */ case class ProductInterpreter[L, T, E](inner: Effects[L, E], before: Effects[T, E]) extends Effects[(L, T), E] { - def loadVar(v: String) = for { + def loadVar(v: String): State[(L, T), BasilValue, E] = for { n <- doRight(before.loadVar(v)) f <- doLeft(inner.loadVar(v)) - } yield (f) + } yield f - def loadMem(v: String, addrs: List[BasilValue]) = for { + def loadMem(v: String, addrs: List[BasilValue]): State[(L, T), List[BasilValue], E] = for { n <- doRight(before.loadMem(v, addrs)) f <- doLeft(inner.loadMem(v, addrs)) - } yield (f) + } yield f - def evalAddrToProc(addr: Int) = for { + def evalAddrToProc(addr: Int): State[(L, T), Option[FunPointer], E] = for { n <- doRight(before.evalAddrToProc(addr: Int)) f <- doLeft(inner.evalAddrToProc(addr)) - } yield (f) + } yield f - def getNext = for { + def getNext: State[(L, T), ExecutionContinuation, E] = for { n <- doRight(before.getNext) f <- doLeft(inner.getNext) - } yield (f) + } yield f /** state effects */ - def setNext(c: ExecutionContinuation) = for { + def setNext(c: ExecutionContinuation): State[(L, T), Unit, E] = for { n <- doRight(before.setNext(c)) f <- doLeft(inner.setNext(c)) - } yield (f) + } yield f - def call(target: String, beginFrom: ExecutionContinuation, returnTo: ExecutionContinuation) = for { + def call(target: String, beginFrom: ExecutionContinuation, returnTo: ExecutionContinuation): State[(L, T), Unit, E] = for { n <- doRight(before.call(target, beginFrom, returnTo)) f <- doLeft(inner.call(target, beginFrom, returnTo)) - } yield (f) + } yield f - def callIntrinsic(name: String, args: List[BasilValue]) = for { + def callIntrinsic(name: String, args: List[BasilValue]): State[(L, T), Option[BasilValue], E] = for { n <- doRight(before.callIntrinsic(name, args)) f <- doLeft(inner.callIntrinsic(name, args)) - } yield (f) + } yield f - def doReturn() = for { + def doReturn(): State[(L, T), Unit, E] = for { n <- doRight(before.doReturn()) f <- doLeft(inner.doReturn()) - } yield (f) + } yield f - def storeVar(v: String, scope: Scope, value: BasilValue) = for { + def storeVar(v: String, scope: Scope, value: BasilValue): State[(L, T), Unit, E] = for { n <- doRight(before.storeVar(v, scope, value)) f <- doLeft(inner.storeVar(v, scope, value)) - } yield (f) + } yield f - def storeMem(vname: String, update: Map[BasilValue, BasilValue]) = for { + def storeMem(vname: String, update: Map[BasilValue, BasilValue]): State[(L, T), Unit, E] = for { n <- doRight(before.storeMem(vname, update)) f <- doLeft(inner.storeMem(vname, update)) - } yield (f) + } yield f } case class LayerInterpreter[L, T, E](inner: Effects[L, E], before: Effects[(L, T), E]) extends Effects[(L, T), E] { - def loadVar(v: String) = for { - n <- (before.loadVar(v)) + def loadVar(v: String): State[(L, T), BasilValue, E] = for { + n <- before.loadVar(v) f <- doLeft(inner.loadVar(v)) - } yield (f) + } yield f - def loadMem(v: String, addrs: List[BasilValue]) = for { - n <- (before.loadMem(v, addrs)) + def loadMem(v: String, addrs: List[BasilValue]): State[(L, T), List[BasilValue], E] = for { + n <- before.loadMem(v, addrs) f <- doLeft(inner.loadMem(v, addrs)) - } yield (f) + } yield f - def evalAddrToProc(addr: Int) = for { - n <- (before.evalAddrToProc(addr: Int)) + def evalAddrToProc(addr: Int): State[(L, T), Option[FunPointer], E] = for { + n <- before.evalAddrToProc(addr) f <- doLeft(inner.evalAddrToProc(addr)) - } yield (f) + } yield f - def getNext = for { - n <- (before.getNext) + def getNext: State[(L, T), ExecutionContinuation, E] = for { + n <- before.getNext f <- doLeft(inner.getNext) - } yield (f) + } yield f /** state effects */ - def setNext(c: ExecutionContinuation) = for { - n <- (before.setNext(c)) + def setNext(c: ExecutionContinuation): State[(L, T), Unit, E] = for { + n <- before.setNext(c) f <- doLeft(inner.setNext(c)) - } yield (f) + } yield f - def call(target: String, beginFrom: ExecutionContinuation, returnTo: ExecutionContinuation) = for { - n <- (before.call(target, beginFrom, returnTo)) + def call(target: String, beginFrom: ExecutionContinuation, returnTo: ExecutionContinuation): State[(L, T), Unit, E] = for { + n <- before.call(target, beginFrom, returnTo) f <- doLeft(inner.call(target, beginFrom, returnTo)) - } yield (f) + } yield f - def callIntrinsic(name: String, args: List[BasilValue]) = for { + def callIntrinsic(name: String, args: List[BasilValue]): State[(L, T), Option[BasilValue], E] = for { n <- before.callIntrinsic(name, args) f <- doLeft(inner.callIntrinsic(name, args)) - } yield (f) + } yield f - def doReturn() = for { - n <- (before.doReturn()) + def doReturn(): State[(L, T), Unit, E] = for { + n <- before.doReturn() f <- doLeft(inner.doReturn()) - } yield (f) + } yield f - def storeVar(v: String, scope: Scope, value: BasilValue) = for { - n <- (before.storeVar(v, scope, value)) + def storeVar(v: String, scope: Scope, value: BasilValue): State[(L, T), Unit, E] = for { + n <- before.storeVar(v, scope, value) f <- doLeft(inner.storeVar(v, scope, value)) - } yield (f) + } yield f - def storeMem(vname: String, update: Map[BasilValue, BasilValue]) = for { - n <- (before.storeMem(vname, update)) + def storeMem(vname: String, update: Map[BasilValue, BasilValue]): State[(L, T), Unit, E] = for { + n <- before.storeMem(vname, update) f <- doLeft(inner.storeMem(vname, update)) - } yield (f) + } yield f } diff --git a/src/main/scala/util/functional.scala b/src/main/scala/util/functional/State.scala similarity index 59% rename from src/main/scala/util/functional.scala rename to src/main/scala/util/functional/State.scala index 975220686..fed5f6329 100644 --- a/src/main/scala/util/functional.scala +++ b/src/main/scala/util/functional/State.scala @@ -10,19 +10,17 @@ case class State[S, A, E](f: S => (S, Either[E, A])) { def >>(o: State[S,A,E]) = for { _ <- this x <- o - } yield (x) - + } yield x def flatMap[B](f: A => State[S, B, E]): State[S, B, E] = State(s => { val (s2, a) = this.f(s) - val r = a match { + val r = a match { case Left(l) => (s2, Left(l)) case Right(a) => f(a).f(s2) } r }) - def map[B](f: A => B): State[S, B, E] = { State(s => { val (s2, a) = this.f(s) @@ -48,37 +46,39 @@ case class State[S, A, E](f: S => (S, Either[E, A])) { object State { def get[S, A, E](f: S => A) : State[S, A, E] = State(s => (s, Right(f(s)))) def getE[S, A, E](f: S => Either[E,A]) : State[S, A, E] = State(s => (s, f(s))) - def getS[S,E] : State[S,S,E] = State((s:S) => (s,Right(s))) - def putS[S,E](s: S) : State[S,Unit,E] = State((_) => (s,Right(()))) - def modify[S, E](f: S => S) : State[S, Unit, E] = State(s => (f(s), Right(()))) - def modifyE[S, E](f: S => Either[E, S]) : State[S, Unit, E] = State(s => f(s) match { + def getS[S,E]: State[S,S,E] = State((s:S) => (s,Right(s))) + def putS[S,E](s: S): State[S,Unit,E] = State(_ => (s,Right(()))) + def modify[S, E](f: S => S): State[S, Unit, E] = State(s => (f(s), Right(()))) + def modifyE[S, E](f: S => Either[E, S]): State[S, Unit, E] = State(s => f(s) match { case Right(ns) => (ns, Right(())) case Left(e) => (s, Left(e)) }) - def execute[S, A, E](s: S, c: State[S,A, E]) : S = c.f(s)._1 - def evaluate[S, A, E](s: S, c: State[S,A, E]) : Either[E,A] = c.f(s)._2 + def execute[S, A, E](s: S, c: State[S,A, E]): S = c.f(s)._1 + def evaluate[S, A, E](s: S, c: State[S,A, E]): Either[E,A] = c.f(s)._2 - def setError[S,A,E](e: E) : State[S,A,E] = State(s => (s, Left(e))) + def setError[S,A,E](e: E): State[S,A,E] = State(s => (s, Left(e))) - def pure[S, A, E](a: A) : State[S, A, E] = State((s:S) => (s, Right(a))) - def pureE[S, A, E](a: Either[E, A]) : State[S, A, E] = State((s:S) => (s, a)) + def pure[S, A, E](a: A): State[S, A, E] = State((s:S) => (s, Right(a))) + def pureE[S, A, E](a: Either[E, A]): State[S, A, E] = State((s:S) => (s, a)) - def sequence[S, V, E](ident: State[S,V, E], xs: Iterable[State[S,V, E]]) : State[S, V, E] = { - xs.foldRight(ident)((l,r) => for { - x <- l - y <- r - } yield(y)) + def sequence[S, V, E](ident: State[S,V, E], xs: Iterable[State[S,V, E]]): State[S, V, E] = { + xs.foldRight(ident) { + (l, r) => for { + x <- l + y <- r + } yield y + } } - def filterM[A, S, E](m : (A => State[S, Boolean, E]), xs: Iterable[A]): State[S, List[A], E] = { + def filterM[A, S, E](m: (A => State[S, Boolean, E]), xs: Iterable[A]): State[S, List[A], E] = { xs.foldRight(pure(List[A]()))((b,acc) => acc.flatMap(c => m(b).map(v => if v then b::c else c))) } - def mapM[A, B, S, E](m : (A => State[S, B, E]), xs: Iterable[A]): State[S, List[B], E] = { + def mapM[A, B, S, E](m: (A => State[S, B, E]), xs: Iterable[A]): State[S, List[B], E] = { xs.foldRight(pure(List[B]()))((b,acc) => acc.flatMap(c => m(b).map(v => v::c))) } - def protect[S, V, E](f : () => State[S, V, E], fnly: PartialFunction[Exception, E]) : State[S, V, E] = { + def protect[S, V, E](f: () => State[S, V, E], fnly: PartialFunction[Exception, E]): State[S, V, E] = { State((s: S) => try { f().f(s) } catch { @@ -86,7 +86,7 @@ object State { }) } - def protectPure[S,V,E](f : () => V, fnly : PartialFunction[Exception, E]) : State[S, V, E] = { + def protectPure[S,V,E](f: () => V, fnly: PartialFunction[Exception, E]): State[S, V, E] = { State((s: S) => try { (s, Right(f())) } catch {