From 4586ce26329779734ba8312be164e5178f1af168 Mon Sep 17 00:00:00 2001 From: l-kent <56100168+l-kent@users.noreply.github.com> Date: Wed, 13 Nov 2024 11:53:50 +1000 Subject: [PATCH] MemoryLoad as Statement (#269) * make MemoryLoad a type of Statement, some general cleanup * consolidate constant propagation variations * fix renamed parser * fix deprecated override (this is still inelegant but it will have to do) * add labels to MemoryLoad * make DSA behaviour consistent with previous * fix issues --------- Co-authored-by: l-kent --- src/main/antlr4/{Semantics.g4 => ASLp.g4} | 2 +- src/main/scala/analysis/ANR.scala | 38 ++- src/main/scala/analysis/Analysis.scala | 183 +----------- .../scala/analysis/BasicIRConstProp.scala | 75 ----- .../scala/analysis/ConstantPropagation.scala | 192 ++++++++++++ .../scala/analysis/GlobalRegionAnalysis.scala | 30 +- .../analysis/InterLiveVarsAnalysis.scala | 54 ++-- .../InterprocSteensgaardAnalysis.scala | 36 ++- .../analysis/IntraLiveVarsAnalysis.scala | 23 +- .../scala/analysis/MemoryRegionAnalysis.scala | 71 +++-- src/main/scala/analysis/RNA.scala | 43 ++- .../ReachingDefinitionsAnalysis.scala | 30 +- src/main/scala/analysis/ReachingDefs.scala | 8 +- .../scala/analysis/RegToMemAnalysis.scala | 72 ----- src/main/scala/analysis/RegionInjector.scala | 47 +-- .../scala/analysis/SummaryGenerator.scala | 46 ++- src/main/scala/analysis/TaintAnalysis.scala | 66 ++--- src/main/scala/analysis/UtilMethods.scala | 27 +- src/main/scala/analysis/VSA.scala | 65 ++-- .../analysis/VariableDependencyAnalysis.scala | 52 ++-- src/main/scala/analysis/WriteToAnalysis.scala | 27 +- .../data_structure_analysis/Graph.scala | 80 ++--- .../data_structure_analysis/LocalPhase.scala | 69 ++--- .../SymbolicAddressAnalysis.scala | 14 +- .../data_structure_analysis/Utility.scala | 1 - .../scala/analysis/solvers/IDESolver.scala | 67 +++-- src/main/scala/bap/BAPExpr.scala | 107 +------ src/main/scala/bap/BAPProgram.scala | 11 +- src/main/scala/bap/BAPStatement.scala 
| 14 +- src/main/scala/ir/Expr.scala | 23 -- src/main/scala/ir/Interpreter.scala | 36 ++- src/main/scala/ir/Statement.scala | 43 ++- src/main/scala/ir/Visitor.scala | 53 ++-- src/main/scala/ir/cilvisitor/CILVisitor.scala | 39 ++- src/main/scala/translating/BAPToIR.scala | 168 ++++++++++- ...emanticsLoader.scala => GTIRBLoader.scala} | 280 ++++++++++-------- src/main/scala/translating/GTIRBToIR.scala | 14 +- src/main/scala/translating/ILtoIL.scala | 34 +-- src/main/scala/translating/IRToBoogie.scala | 33 ++- src/main/scala/util/RunUtils.scala | 96 +++--- .../scala/DataStructureAnalysisTest.scala | 38 +-- src/test/scala/LiveVarsAnalysisTests.scala | 46 +-- src/test/scala/PointsToTest.scala | 28 +- src/test/scala/TaintAnalysisTests.scala | 22 +- src/test/scala/ir/CILVisitorTest.scala | 10 +- src/test/scala/ir/IRTest.scala | 42 +-- src/test/scala/ir/SingleCallInvariant.scala | 18 +- 47 files changed, 1268 insertions(+), 1305 deletions(-) rename src/main/antlr4/{Semantics.g4 => ASLp.g4} (99%) delete mode 100644 src/main/scala/analysis/BasicIRConstProp.scala create mode 100644 src/main/scala/analysis/ConstantPropagation.scala delete mode 100644 src/main/scala/analysis/RegToMemAnalysis.scala rename src/main/scala/translating/{SemanticsLoader.scala => GTIRBLoader.scala} (68%) diff --git a/src/main/antlr4/Semantics.g4 b/src/main/antlr4/ASLp.g4 similarity index 99% rename from src/main/antlr4/Semantics.g4 rename to src/main/antlr4/ASLp.g4 index 821111783..372f4ced5 100644 --- a/src/main/antlr4/Semantics.g4 +++ b/src/main/antlr4/ASLp.g4 @@ -1,4 +1,4 @@ -grammar Semantics; +grammar ASLp; // See aslp/libASL/asl.ott for reference grammar Bap-ali-plugin/asli_lifer.ml may also be useful for // visitors diff --git a/src/main/scala/analysis/ANR.scala b/src/main/scala/analysis/ANR.scala index 196ef8634..c318bbaee 100644 --- a/src/main/scala/analysis/ANR.scala +++ b/src/main/scala/analysis/ANR.scala @@ -7,7 +7,7 @@ import scala.collection.immutable /** * Calculates the set of 
variables that are not read after being written up to that point in the program. - * Useful for detecting dead stores, constants and if what variables are passed as parameters in a function call. + * Useful for detecting dead stores, constants and which variables are passed as parameters in a function call. */ trait ANRAnalysis(program: Program) { @@ -26,35 +26,41 @@ trait ANRAnalysis(program: Program) { /** Default implementation of eval. */ def eval(cmd: Command, s: Set[Variable]): Set[Variable] = { - var m = s cmd match { case assume: Assume => - m.diff(assume.body.variables) + s.diff(assume.body.variables) case assert: Assert => - m.diff(assert.body.variables) - case memoryAssign: MemoryAssign => - m.diff(memoryAssign.index.variables) + s.diff(assert.body.variables) + case memoryStore: MemoryStore => + s.diff(memoryStore.index.variables) case indirectCall: IndirectCall => - m - indirectCall.target - case assign: Assign => - m = m.diff(assign.rhs.variables) - if ignoreRegions.contains(assign.lhs) then m else m + assign.lhs + s - indirectCall.target + case assign: LocalAssign => + val m = s.diff(assign.rhs.variables) + if (ignoreRegions.contains(assign.lhs)) { + m + } else { + m + assign.lhs + } + case memoryLoad: MemoryLoad => + val m = s.diff(memoryLoad.index.variables) + if (ignoreRegions.contains(memoryLoad.lhs)) { + m + } else { + m + memoryLoad.lhs + } case _ => - m + s } } /** Transfer function for state lattice elements. */ - def localTransfer(n: CFGPosition, s: Set[Variable]): Set[Variable] = n match { + def transfer(n: CFGPosition, s: Set[Variable]): Set[Variable] = n match { case cmd: Command => eval(cmd, s) case _ => s // ignore other kinds of nodes } - - /** Transfer function for state lattice elements. 
- */ - def transfer(n: CFGPosition, s: Set[Variable]): Set[Variable] = localTransfer(n, s) } class ANRAnalysisSolver(program: Program) extends ANRAnalysis(program) diff --git a/src/main/scala/analysis/Analysis.scala b/src/main/scala/analysis/Analysis.scala index 969bbc2e3..9a77d4032 100644 --- a/src/main/scala/analysis/Analysis.scala +++ b/src/main/scala/analysis/Analysis.scala @@ -1,14 +1,5 @@ package analysis -import ir.* -import analysis.solvers.* - -import scala.collection.mutable.{ArrayBuffer, HashMap, ListBuffer} -import java.io.{File, PrintWriter} -import scala.collection.mutable -import scala.collection.immutable -import util.Logger - /** Trait for program analyses. * * @tparam R @@ -18,176 +9,4 @@ trait Analysis[+R]: /** Performs the analysis and returns the result. */ - def analyze(): R - -/** Base class for value analysis with simple (non-lifted) lattice. - */ -trait ConstantPropagation(val program: Program) { - /** The lattice of abstract states. - */ - - val valuelattice: ConstantPropagationLattice = ConstantPropagationLattice() - - val statelattice: MapLattice[Variable, FlatElement[BitVecLiteral], ConstantPropagationLattice] = MapLattice(valuelattice) - - /** Default implementation of eval. 
- */ - def eval(exp: Expr, env: Map[Variable, FlatElement[BitVecLiteral]]): FlatElement[BitVecLiteral] = { - import valuelattice._ - exp match { - case id: Variable => env(id) - case n: BitVecLiteral => bv(n) - case ze: ZeroExtend => zero_extend(ze.extension, eval(ze.body, env)) - case se: SignExtend => sign_extend(se.extension, eval(se.body, env)) - case e: Extract => extract(e.end, e.start, eval(e.body, env)) - case bin: BinaryExpr => - val left = eval(bin.arg1, env) - val right = eval(bin.arg2, env) - bin.op match { - case BVADD => bvadd(left, right) - case BVSUB => bvsub(left, right) - case BVMUL => bvmul(left, right) - case BVUDIV => bvudiv(left, right) - case BVSDIV => bvsdiv(left, right) - case BVSREM => bvsrem(left, right) - case BVUREM => bvurem(left, right) - case BVSMOD => bvsmod(left, right) - case BVAND => bvand(left, right) - case BVOR => bvor(left, right) - case BVXOR => bvxor(left, right) - case BVNAND => bvnand(left, right) - case BVNOR => bvnor(left, right) - case BVXNOR => bvxnor(left, right) - case BVSHL => bvshl(left, right) - case BVLSHR => bvlshr(left, right) - case BVASHR => bvashr(left, right) - case BVCOMP => bvcomp(left, right) - case BVCONCAT => concat(left, right) - } - case un: UnaryExpr => - val arg = eval(un.arg, env) - un.op match { - case BVNOT => bvnot(arg) - case BVNEG => bvneg(arg) - } - case _ => valuelattice.top - } - } - - - /** Transfer function for state lattice elements. - */ - def localTransfer(n: CFGPosition, s: Map[Variable, FlatElement[BitVecLiteral]]): Map[Variable, FlatElement[BitVecLiteral]] = { - n match { - // assignments - case la: Assign => - s + (la.lhs -> eval(la.rhs, s)) - // all others: like no-ops - case _ => s - } - } - - /** The analysis lattice. 
- */ - val lattice: MapLattice[CFGPosition, Map[Variable, FlatElement[BitVecLiteral]], MapLattice[Variable, FlatElement[BitVecLiteral], ConstantPropagationLattice]] = MapLattice(statelattice) - - val domain: Set[CFGPosition] = Set.empty ++ program - - /** Transfer function for state lattice elements. (Same as `localTransfer` for simple value analysis.) - */ - def transfer(n: CFGPosition, s: Map[Variable, FlatElement[BitVecLiteral]]): Map[Variable, FlatElement[BitVecLiteral]] = localTransfer(n, s) -} - -class ConstantPropagationSolver(program: Program) extends ConstantPropagation(program) - with SimplePushDownWorklistFixpointSolver[CFGPosition, Map[Variable, FlatElement[BitVecLiteral]], MapLattice[Variable, FlatElement[BitVecLiteral], ConstantPropagationLattice]] - with IRInterproceduralForwardDependencies - with Analysis[Map[CFGPosition, Map[Variable, FlatElement[BitVecLiteral]]]] - - -/** Base class for value analysis with simple (non-lifted) lattice. - */ -trait ConstantPropagationWithSSA(val program: Program, val reachingDefs: Map[CFGPosition, (Map[Variable, Set[Assign]], Map[Variable, Set[Assign]])]) { - /** The lattice of abstract states. - */ - - val valuelattice: ConstantPropagationLatticeWithSSA = ConstantPropagationLatticeWithSSA() - - val statelattice: MapLattice[RegisterWrapperEqualSets, Set[BitVecLiteral], ConstantPropagationLatticeWithSSA] = MapLattice(valuelattice) - - /** Default implementation of eval. 
- */ - def eval(exp: Expr, env: Map[RegisterWrapperEqualSets, Set[BitVecLiteral]], n: CFGPosition): Set[BitVecLiteral] = { - import valuelattice._ - exp match { - case id: Variable => env(RegisterWrapperEqualSets(id, getUse(id, n, reachingDefs))) - case n: BitVecLiteral => bv(n) - case ze: ZeroExtend => zero_extend(ze.extension, eval(ze.body, env, n)) - case se: SignExtend => sign_extend(se.extension, eval(se.body, env, n)) - case e: Extract => extract(e.end, e.start, eval(e.body, env, n)) - case bin: BinaryExpr => - val left = eval(bin.arg1, env, n) - val right = eval(bin.arg2, env, n) - bin.op match { - case BVADD => bvadd(left, right) - case BVSUB => bvsub(left, right) - case BVMUL => bvmul(left, right) - case BVUDIV => bvudiv(left, right) - case BVSDIV => bvsdiv(left, right) - case BVSREM => bvsrem(left, right) - case BVUREM => bvurem(left, right) - case BVSMOD => bvsmod(left, right) - case BVAND => bvand(left, right) - case BVOR => bvor(left, right) - case BVXOR => bvxor(left, right) - case BVNAND => bvnand(left, right) - case BVNOR => bvnor(left, right) - case BVXNOR => bvxnor(left, right) - case BVSHL => bvshl(left, right) - case BVLSHR => bvlshr(left, right) - case BVASHR => bvashr(left, right) - case BVCOMP => bvcomp(left, right) - case BVCONCAT => concat(left, right) - } - - case un: UnaryExpr => - val arg = eval(un.arg, env, n) - un.op match { - case BVNOT => bvnot(arg) - case BVNEG => bvneg(arg) - } - - case _ => Set.empty - } - } - - /** Transfer function for state lattice elements. 
- */ - def localTransfer(n: CFGPosition, s: Map[RegisterWrapperEqualSets, Set[BitVecLiteral]]): Map[RegisterWrapperEqualSets, Set[BitVecLiteral]] = - n match { - case a: Assign => - val lhsWrappers = s.collect { - case (k, v) if RegisterVariableWrapper(k.variable, k.assigns) == RegisterVariableWrapper(a.lhs, getDefinition(a.lhs, a, reachingDefs)) => (k, v) - } - if (lhsWrappers.nonEmpty) { - s ++ lhsWrappers.map((k, v) => (k, v.union(eval(a.rhs, s, a)))) - } else { - s + (RegisterWrapperEqualSets(a.lhs, getDefinition(a.lhs, a, reachingDefs)) -> eval(a.rhs, s, n)) - } - case _ => s - } - - /** The analysis lattice. - */ - val lattice: MapLattice[CFGPosition, Map[RegisterWrapperEqualSets, Set[BitVecLiteral]], MapLattice[RegisterWrapperEqualSets, Set[BitVecLiteral], ConstantPropagationLatticeWithSSA]] = MapLattice(statelattice) - - val domain: Set[CFGPosition] = Set.empty ++ program - - /** Transfer function for state lattice elements. (Same as `localTransfer` for simple value analysis.) - */ - def transfer(n: CFGPosition, s: Map[RegisterWrapperEqualSets, Set[BitVecLiteral]]): Map[RegisterWrapperEqualSets, Set[BitVecLiteral]] = localTransfer(n, s) -} - -class ConstantPropagationSolverWithSSA(program: Program, reachingDefs: Map[CFGPosition, (Map[Variable, Set[Assign]], Map[Variable, Set[Assign]])]) extends ConstantPropagationWithSSA(program, reachingDefs) - with SimplePushDownWorklistFixpointSolver[CFGPosition, Map[RegisterWrapperEqualSets, Set[BitVecLiteral]], MapLattice[RegisterWrapperEqualSets, Set[BitVecLiteral], ConstantPropagationLatticeWithSSA]] - with IRInterproceduralForwardDependencies - with Analysis[Map[CFGPosition, Map[RegisterWrapperEqualSets, Set[BitVecLiteral]]]] + def analyze(): R \ No newline at end of file diff --git a/src/main/scala/analysis/BasicIRConstProp.scala b/src/main/scala/analysis/BasicIRConstProp.scala deleted file mode 100644 index 36a6d72d8..000000000 --- a/src/main/scala/analysis/BasicIRConstProp.scala +++ /dev/null @@ -1,75 +0,0 @@ 
-package analysis -import ir.* -import analysis.solvers.* - -trait ILValueAnalysisMisc: - val valuelattice: ConstantPropagationLattice = ConstantPropagationLattice() - val statelattice: MapLattice[Variable, FlatElement[BitVecLiteral], ConstantPropagationLattice] = MapLattice(valuelattice) - - def eval(exp: Expr, env: Map[Variable, FlatElement[BitVecLiteral]]): FlatElement[BitVecLiteral] = - import valuelattice._ - exp match - case id: Variable => env(id) - case n: BitVecLiteral => bv(n) - case ze: ZeroExtend => zero_extend(ze.extension, eval(ze.body, env)) - case se: SignExtend => sign_extend(se.extension, eval(se.body, env)) - case e: Extract => extract(e.end, e.start, eval(e.body, env)) - case bin: BinaryExpr => - val left = eval(bin.arg1, env) - val right = eval(bin.arg2, env) - bin.op match - case BVADD => bvadd(left, right) - case BVSUB => bvsub(left, right) - case BVMUL => bvmul(left, right) - case BVUDIV => bvudiv(left, right) - case BVSDIV => bvsdiv(left, right) - case BVSREM => bvsrem(left, right) - case BVUREM => bvurem(left, right) - case BVSMOD => bvsmod(left, right) - case BVAND => bvand(left, right) - case BVOR => bvor(left, right) - case BVXOR => bvxor(left, right) - case BVNAND => bvnand(left, right) - case BVNOR => bvnor(left, right) - case BVXNOR => bvxnor(left, right) - case BVSHL => bvshl(left, right) - case BVLSHR => bvlshr(left, right) - case BVASHR => bvashr(left, right) - case BVCOMP => bvcomp(left, right) - case BVCONCAT => concat(left, right) - - case un: UnaryExpr => - val arg = eval(un.arg, env) - - un.op match - case BVNOT => bvnot(arg) - case BVNEG => bvneg(arg) - - case _ => valuelattice.top - - private final val callerPreservedRegisters = Set("R0", "R1", "R2", "R3", "R4", "R5", "R6", "R7", "R8", "R9", "R10", - "R11", "R12", "R13", "R14", "R15", "R16", "R17", "R18", "R30") - - /** Transfer function for state lattice elements. 
- */ - def localTransfer(n: CFGPosition, s: Map[Variable, FlatElement[BitVecLiteral]]): Map[Variable, FlatElement[BitVecLiteral]] = - n match - case la: Assign => - s + (la.lhs -> eval(la.rhs, s)) - case c: Call => s ++ callerPreservedRegisters.filter(reg => s.keys.exists(_.name == reg)).map(n => Register(n, 64) -> statelattice.sublattice.top).toMap - case _ => s - - - -object IRSimpleValueAnalysis: - - class Solver(prog: Program) extends ILValueAnalysisMisc - with IRIntraproceduralForwardDependencies - with Analysis[Map[CFGPosition, Map[Variable, FlatElement[BitVecLiteral]]]] - with SimplePushDownWorklistFixpointSolver[CFGPosition, Map[Variable, FlatElement[BitVecLiteral]], MapLattice[Variable, FlatElement[BitVecLiteral], ConstantPropagationLattice]]: - /* Worklist initial set */ - //override val lattice: MapLattice[CFGPosition, statelattice.type] = MapLattice(statelattice) - override val lattice: MapLattice[CFGPosition, Map[Variable, FlatElement[BitVecLiteral]], MapLattice[Variable, FlatElement[BitVecLiteral], ConstantPropagationLattice]] = MapLattice(statelattice) - - override val domain: Set[CFGPosition] = computeDomain(IntraProcIRCursor, prog.procedures).toSet - def transfer(n: CFGPosition, s: Map[Variable, FlatElement[BitVecLiteral]]): Map[Variable, FlatElement[BitVecLiteral]] = localTransfer(n, s) diff --git a/src/main/scala/analysis/ConstantPropagation.scala b/src/main/scala/analysis/ConstantPropagation.scala new file mode 100644 index 000000000..a5f6d2d00 --- /dev/null +++ b/src/main/scala/analysis/ConstantPropagation.scala @@ -0,0 +1,192 @@ +package analysis +import ir.* +import analysis.solvers.* + +trait ConstantPropagation { + val valuelattice: ConstantPropagationLattice = ConstantPropagationLattice() + val statelattice: MapLattice[Variable, FlatElement[BitVecLiteral], ConstantPropagationLattice] = MapLattice(valuelattice) + val lattice: MapLattice[CFGPosition, Map[Variable, FlatElement[BitVecLiteral]], MapLattice[Variable, FlatElement[BitVecLiteral], 
ConstantPropagationLattice]] = MapLattice(statelattice) + + def eval(exp: Expr, env: Map[Variable, FlatElement[BitVecLiteral]]): FlatElement[BitVecLiteral] = { + import valuelattice.* + exp match { + case id: Variable => env(id) + case n: BitVecLiteral => bv(n) + case ze: ZeroExtend => zero_extend(ze.extension, eval(ze.body, env)) + case se: SignExtend => sign_extend(se.extension, eval(se.body, env)) + case e: Extract => extract(e.end, e.start, eval(e.body, env)) + case bin: BinaryExpr => + val left = eval(bin.arg1, env) + val right = eval(bin.arg2, env) + bin.op match { + case BVADD => bvadd(left, right) + case BVSUB => bvsub(left, right) + case BVMUL => bvmul(left, right) + case BVUDIV => bvudiv(left, right) + case BVSDIV => bvsdiv(left, right) + case BVSREM => bvsrem(left, right) + case BVUREM => bvurem(left, right) + case BVSMOD => bvsmod(left, right) + case BVAND => bvand(left, right) + case BVOR => bvor(left, right) + case BVXOR => bvxor(left, right) + case BVNAND => bvnand(left, right) + case BVNOR => bvnor(left, right) + case BVXNOR => bvxnor(left, right) + case BVSHL => bvshl(left, right) + case BVLSHR => bvlshr(left, right) + case BVASHR => bvashr(left, right) + case BVCOMP => bvcomp(left, right) + case BVCONCAT => concat(left, right) + } + case un: UnaryExpr => + val arg = eval(un.arg, env) + un.op match { + case BVNOT => bvnot(arg) + case BVNEG => bvneg(arg) + } + case _ => valuelattice.top + } + } +} + +class IntraProcConstantPropagation(prog: Program) extends ConstantPropagation +with IRIntraproceduralForwardDependencies +with Analysis[Map[CFGPosition, Map[Variable, FlatElement[BitVecLiteral]]]] +with SimplePushDownWorklistFixpointSolver[CFGPosition, Map[Variable, FlatElement[BitVecLiteral]], MapLattice[Variable, FlatElement[BitVecLiteral], ConstantPropagationLattice]] { + override val domain: Set[CFGPosition] = computeDomain(IntraProcIRCursor, prog.procedures).toSet + + private final val callerPreservedRegisters: Set[Variable] = Set("R0", "R1", "R2", 
"R3", "R4", "R5", "R6", "R7", "R8", "R9", "R10", + "R11", "R12", "R13", "R14", "R15", "R16", "R17", "R18", "R30").map(n => Register(n, 64)) + + def transfer(n: CFGPosition, s: Map[Variable, FlatElement[BitVecLiteral]]): Map[Variable, FlatElement[BitVecLiteral]] = { + n match { + case la: LocalAssign => + s + (la.lhs -> eval(la.rhs, s)) + case l: MemoryLoad => + s + (l.lhs -> valuelattice.top) + case _: Call => s.map { (k, v) => + if (callerPreservedRegisters.contains(k)) { + (k, valuelattice.top) + } else { + (k, v) + } + } + case _ => s + } + } +} + +class InterProcConstantPropagation(val program: Program) extends ConstantPropagation +with SimplePushDownWorklistFixpointSolver[CFGPosition, Map[Variable, FlatElement[BitVecLiteral]], MapLattice[Variable, FlatElement[BitVecLiteral], ConstantPropagationLattice]] +with IRInterproceduralForwardDependencies +with Analysis[Map[CFGPosition, Map[Variable, FlatElement[BitVecLiteral]]]] { + + def transfer(n: CFGPosition, s: Map[Variable, FlatElement[BitVecLiteral]]): Map[Variable, FlatElement[BitVecLiteral]] = { + n match { + // assignments + case la: LocalAssign => + s + (la.lhs -> eval(la.rhs, s)) + case load: MemoryLoad => + s + (load.lhs -> valuelattice.top) + // all others: like no-ops + case _ => s + } + } + + override val domain: Set[CFGPosition] = Set.empty ++ program +} + +/** Base class for value analysis with simple (non-lifted) lattice. + */ +trait ConstantPropagationWithSSA(val program: Program, val reachingDefs: Map[CFGPosition, (Map[Variable, Set[Assign]], Map[Variable, Set[Assign]])]) { + /** The lattice of abstract states. + */ + + val valuelattice: ConstantPropagationLatticeWithSSA = ConstantPropagationLatticeWithSSA() + + val statelattice: MapLattice[RegisterWrapperEqualSets, Set[BitVecLiteral], ConstantPropagationLatticeWithSSA] = MapLattice(valuelattice) + + /** Default implementation of eval. 
+ */ + def eval(exp: Expr, env: Map[RegisterWrapperEqualSets, Set[BitVecLiteral]], n: CFGPosition): Set[BitVecLiteral] = { + import valuelattice.* + exp match { + case id: Variable => env(RegisterWrapperEqualSets(id, getUse(id, n, reachingDefs))) + case n: BitVecLiteral => bv(n) + case ze: ZeroExtend => zero_extend(ze.extension, eval(ze.body, env, n)) + case se: SignExtend => sign_extend(se.extension, eval(se.body, env, n)) + case e: Extract => extract(e.end, e.start, eval(e.body, env, n)) + case bin: BinaryExpr => + val left = eval(bin.arg1, env, n) + val right = eval(bin.arg2, env, n) + bin.op match { + case BVADD => bvadd(left, right) + case BVSUB => bvsub(left, right) + case BVMUL => bvmul(left, right) + case BVUDIV => bvudiv(left, right) + case BVSDIV => bvsdiv(left, right) + case BVSREM => bvsrem(left, right) + case BVUREM => bvurem(left, right) + case BVSMOD => bvsmod(left, right) + case BVAND => bvand(left, right) + case BVOR => bvor(left, right) + case BVXOR => bvxor(left, right) + case BVNAND => bvnand(left, right) + case BVNOR => bvnor(left, right) + case BVXNOR => bvxnor(left, right) + case BVSHL => bvshl(left, right) + case BVLSHR => bvlshr(left, right) + case BVASHR => bvashr(left, right) + case BVCOMP => bvcomp(left, right) + case BVCONCAT => concat(left, right) + } + + case un: UnaryExpr => + val arg = eval(un.arg, env, n) + un.op match { + case BVNOT => bvnot(arg) + case BVNEG => bvneg(arg) + } + + case _ => Set.empty + } + } + + /** Transfer function for state lattice elements. 
+ */ + def transfer(n: CFGPosition, s: Map[RegisterWrapperEqualSets, Set[BitVecLiteral]]): Map[RegisterWrapperEqualSets, Set[BitVecLiteral]] = + n match { + case a: LocalAssign => + val lhsWrappers = s.collect { + case (k, v) if RegisterVariableWrapper(k.variable, k.assigns) == RegisterVariableWrapper(a.lhs, getDefinition(a.lhs, a, reachingDefs)) => (k, v) + } + if (lhsWrappers.nonEmpty) { + s ++ lhsWrappers.map((k, v) => (k, v.union(eval(a.rhs, s, a)))) + } else { + s + (RegisterWrapperEqualSets(a.lhs, getDefinition(a.lhs, a, reachingDefs)) -> eval(a.rhs, s, n)) + } + case l: MemoryLoad => + val lhsWrappers = s.collect { + case (k, v) if RegisterVariableWrapper(k.variable, k.assigns) == RegisterVariableWrapper(l.lhs, getDefinition(l.lhs, l, reachingDefs)) => (k, v) + } + if (lhsWrappers.nonEmpty) { + s ++ lhsWrappers + } else { + s + (RegisterWrapperEqualSets(l.lhs, getDefinition(l.lhs, l, reachingDefs)) -> Set()) + } + + case _ => s + } + + /** The analysis lattice. + */ + val lattice: MapLattice[CFGPosition, Map[RegisterWrapperEqualSets, Set[BitVecLiteral]], MapLattice[RegisterWrapperEqualSets, Set[BitVecLiteral], ConstantPropagationLatticeWithSSA]] = MapLattice(statelattice) + + val domain: Set[CFGPosition] = Set.empty ++ program +} + +class ConstantPropagationSolverWithSSA(program: Program, reachingDefs: Map[CFGPosition, (Map[Variable, Set[Assign]], Map[Variable, Set[Assign]])]) extends ConstantPropagationWithSSA(program, reachingDefs) + with SimplePushDownWorklistFixpointSolver[CFGPosition, Map[RegisterWrapperEqualSets, Set[BitVecLiteral]], MapLattice[RegisterWrapperEqualSets, Set[BitVecLiteral], ConstantPropagationLatticeWithSSA]] + with IRInterproceduralForwardDependencies + with Analysis[Map[CFGPosition, Map[RegisterWrapperEqualSets, Set[BitVecLiteral]]]] diff --git a/src/main/scala/analysis/GlobalRegionAnalysis.scala b/src/main/scala/analysis/GlobalRegionAnalysis.scala index 47b6b8ddb..2157b4d80 100644 --- 
a/src/main/scala/analysis/GlobalRegionAnalysis.scala +++ b/src/main/scala/analysis/GlobalRegionAnalysis.scala @@ -88,7 +88,6 @@ trait GlobalRegionAnalysis(val program: Program, } else { Set() } - case _: MemoryLoad => ??? case _: UninterpretedFunction => Set.empty case variable: Variable => val ctx = getUse(variable, n, reachingDefs) @@ -169,16 +168,13 @@ trait GlobalRegionAnalysis(val program: Program, */ def localTransfer(n: CFGPosition, s: Set[DataRegion]): Set[DataRegion] = { n match { - case memAssign: MemoryAssign => - checkIfDefined(evalMemLoadToGlobal(memAssign.index, memAssign.size, memAssign), n) - case assign: Assign => - val unwrapped = unwrapExpr(assign.rhs) - if (unwrapped.isDefined) { - checkIfDefined(evalMemLoadToGlobal(unwrapped.get.index, unwrapped.get.size, assign, loadOp = true), n) - } else { - // this is a constant but we need to check if it is a data region - checkIfDefined(evalMemLoadToGlobal(assign.rhs, 1, assign), n) - } + case store: MemoryStore => + checkIfDefined(evalMemLoadToGlobal(store.index, store.size, store), n) + case load: MemoryLoad => + checkIfDefined(evalMemLoadToGlobal(load.index, load.size, load, loadOp = true), n) + case assign: LocalAssign => + // this is a constant but we need to check if it is a data region + checkIfDefined(evalMemLoadToGlobal(assign.rhs, 1, assign), n) case _ => Set() } @@ -188,12 +184,12 @@ trait GlobalRegionAnalysis(val program: Program, } class GlobalRegionAnalysisSolver( - program: Program, - domain: Set[CFGPosition], - constantProp: Map[CFGPosition, Map[Variable, FlatElement[BitVecLiteral]]], - reachingDefs: Map[CFGPosition, (Map[Variable, Set[Assign]], Map[Variable, Set[Assign]])], - mmm: MemoryModelMap, - vsaResult: Map[CFGPosition, LiftedElement[Map[Variable | MemoryRegion, Set[Value]]]] + program: Program, + domain: Set[CFGPosition], + constantProp: Map[CFGPosition, Map[Variable, FlatElement[BitVecLiteral]]], + reachingDefs: Map[CFGPosition, (Map[Variable, Set[Assign]], Map[Variable, 
Set[Assign]])], + mmm: MemoryModelMap, + vsaResult: Map[CFGPosition, LiftedElement[Map[Variable | MemoryRegion, Set[Value]]]] ) extends GlobalRegionAnalysis(program, domain, constantProp, reachingDefs, mmm, vsaResult) with IRIntraproceduralForwardDependencies with Analysis[Map[CFGPosition, Set[DataRegion]]] diff --git a/src/main/scala/analysis/InterLiveVarsAnalysis.scala b/src/main/scala/analysis/InterLiveVarsAnalysis.scala index cbe3076c6..ef9173a65 100644 --- a/src/main/scala/analysis/InterLiveVarsAnalysis.scala +++ b/src/main/scala/analysis/InterLiveVarsAnalysis.scala @@ -1,7 +1,7 @@ package analysis import analysis.solvers.BackwardIDESolver -import ir.{Assert, Assume, Block, GoTo, CFGPosition, Command, DirectCall, IndirectCall, Assign, MemoryAssign, Unreachable, Return, Procedure, Program, Variable, toShortString} +import ir.{Assert, LocalAssign, Assume, CFGPosition, Command, DirectCall, IndirectCall, MemoryLoad, MemoryStore, Procedure, Program, Return, Variable} /** * Micro-transfer-functions for LiveVar analysis @@ -28,54 +28,68 @@ trait LiveVarsAnalysisFunctions extends BackwardIDEAnalysis[Variable, TwoElement } def edgesCallToAfterCall(call: Command, aftercall: DirectCall)(d: DL): Map[DL, EdgeFunction[TwoElement]] = { - d match - case Left(value) => Map() // maps all variables before the call to bottom + d match { + case Left(_) => Map() // maps all variables before the call to bottom case Right(_) => Map(d -> IdEdge()) + } } def edgesOther(n: CFGPosition)(d: DL): Map[DL, EdgeFunction[TwoElement]] = { - n match - case Assign(variable, expr, _) => // (s - variable) ++ expr.variables - d match + n match { + case LocalAssign(variable, expr, _) => // (s - variable) ++ expr.variables + d match { case Left(value) => if value == variable then Map() else Map(d -> IdEdge()) - - case Right(_) => expr.variables.foldLeft(Map[DL, EdgeFunction[TwoElement]](d -> IdEdge())) { + case Right(_) => + expr.variables.foldLeft(Map[DL, EdgeFunction[TwoElement]](d -> IdEdge())) { + 
(mp, expVar) => mp + (Left(expVar) -> ConstEdge(TwoElementTop)) + } + } + case MemoryLoad(lhs, _, index, _, _, _) => + d match { + case Left(value) => + if value == lhs then + Map() + else + Map(d -> IdEdge()) + case Right(_) => index.variables.foldLeft(Map[DL, EdgeFunction[TwoElement]](d -> IdEdge())) { (mp, expVar) => mp + (Left(expVar) -> ConstEdge(TwoElementTop)) } - - case MemoryAssign(_, index, value, _, _, _) => // s ++ store.index.variables ++ store.value.variables - d match - case Left(value) => Map(d -> IdEdge()) + } + case MemoryStore(_, index, value, _, _, _) => // s ++ store.index.variables ++ store.value.variables + d match { + case Left(_) => Map(d -> IdEdge()) case Right(_) => (index.variables ++ value.variables).foldLeft(Map[DL, EdgeFunction[TwoElement]](d -> IdEdge())) { (mp, storVar) => mp + (Left(storVar) -> ConstEdge(TwoElementTop)) } - + } case Assume(expr, _, _, _) => // s ++ expr.variables - d match - case Left(value) => Map(d -> IdEdge()) + d match { + case Left(_) => Map(d -> IdEdge()) case Right(_) => expr.variables.foldLeft(Map(d -> IdEdge()): Map[DL, EdgeFunction[TwoElement]]) { (mp, expVar) => mp + (Left(expVar) -> ConstEdge(TwoElementTop)) } - + } case Assert(expr, _, _) => // s ++ expr.variables - d match - case Left(value) => Map(d -> IdEdge()) + d match { + case Left(_) => Map(d -> IdEdge()) case Right(_) => expr.variables.foldLeft(Map[DL, EdgeFunction[TwoElement]](d -> IdEdge())) { (mp, expVar) => mp + (Left(expVar) -> ConstEdge(TwoElementTop)) } + } case IndirectCall(variable, _) => - d match + d match { case Left(value) => if value != variable then Map(d -> IdEdge()) else Map() case Right(_) => Map(d -> IdEdge(), Left(variable) -> ConstEdge(TwoElementTop)) + } case _ => Map(d -> IdEdge()) - + } } } diff --git a/src/main/scala/analysis/InterprocSteensgaardAnalysis.scala b/src/main/scala/analysis/InterprocSteensgaardAnalysis.scala index c05f83d60..ea34505f1 100644 --- a/src/main/scala/analysis/InterprocSteensgaardAnalysis.scala 
+++ b/src/main/scala/analysis/InterprocSteensgaardAnalysis.scala @@ -28,10 +28,10 @@ case class RegisterWrapperEqualSets(variable: Variable, assigns: Set[Assign]) * expression node in the AST. It is implemented using [[analysis.solvers.UnionFindSolver]]. */ class InterprocSteensgaardAnalysis( - domain: Set[CFGPosition], - mmm: MemoryModelMap, - reachingDefs: Map[CFGPosition, (Map[Variable, Set[Assign]], Map[Variable, Set[Assign]])], - vsaResult: Map[CFGPosition, LiftedElement[Map[Variable | MemoryRegion, Set[Value]]]]) extends Analysis[Any] { + domain: Set[CFGPosition], + mmm: MemoryModelMap, + reachingDefs: Map[CFGPosition, (Map[Variable, Set[Assign]], Map[Variable, Set[Assign]])], + vsaResult: Map[CFGPosition, LiftedElement[Map[Variable | MemoryRegion, Set[Value]]]]) extends Analysis[Any] { val solver: UnionFindSolver[StTerm] = UnionFindSolver() @@ -80,7 +80,8 @@ class InterprocSteensgaardAnalysis( val alloc = mmm.nodeToRegion(directCall).head val defs = getDefinition(mallocVariable, directCall, reachingDefs) unify(IdentifierVariable(RegisterWrapperEqualSets(mallocVariable, defs)), PointerRef(AllocVariable(alloc))) - case assign: Assign => + case assign: LocalAssign => + // TODO: unsound val unwrapped = unwrapExprToVar(assign.rhs) if (unwrapped.isDefined) { // X1 = X2: [[X1]] = [[X2]] @@ -97,11 +98,11 @@ class InterprocSteensgaardAnalysis( } unify(IdentifierVariable(RegisterWrapperEqualSets(X1, getDefinition(X1, assign, reachingDefs))), alpha) } - case memoryAssign: MemoryAssign => + case memoryStore: MemoryStore => // *X1 = X2: [[X1]] = ↑a ^ [[X2]] = a where a is a fresh term variable val X1_star = mmm.nodeToRegion(node) - // TODO: This is risky as it tries to coerce every value to a region (needed for functionpointer example) - val unwrapped = unwrapExprToVar(memoryAssign.value) + // TODO: This is not sound + val unwrapped = unwrapExprToVar(memoryStore.value) if (unwrapped.isDefined) { val X2 = unwrapped.get val X2_regions: Set[MemoryRegion] = 
vsaApproximation(X2, node) @@ -115,6 +116,25 @@ class InterprocSteensgaardAnalysis( unify(ExpressionVariable(x), alpha) } } + case memoryLoad: MemoryLoad => + // TODO: unsound + val unwrapped = unwrapExprToVar(memoryLoad.index) + if (unwrapped.isDefined) { + // X1 = X2: [[X1]] = [[X2]] + val X1 = memoryLoad.lhs + val X2 = unwrapped.get + unify(IdentifierVariable(RegisterWrapperEqualSets(X1, getDefinition(X1, memoryLoad, reachingDefs))), IdentifierVariable(RegisterWrapperEqualSets(X2, getUse(X2, memoryLoad, reachingDefs)))) + } else { + // X1 = *X2: [[X2]] = ↑a ^ [[X1]] = a where a is a fresh term variable + val X1 = memoryLoad.lhs + val X2_star = mmm.nodeToRegion(node) + val alpha = FreshVariable() + X2_star.foreach { x => + unify(PointerRef(alpha), ExpressionVariable(x)) + } + unify(IdentifierVariable(RegisterWrapperEqualSets(X1, getDefinition(X1, memoryLoad, reachingDefs))), alpha) + } + case _ => // do nothing TODO: Maybe LocalVar too? } } diff --git a/src/main/scala/analysis/IntraLiveVarsAnalysis.scala b/src/main/scala/analysis/IntraLiveVarsAnalysis.scala index a576b27fb..1271aa720 100644 --- a/src/main/scala/analysis/IntraLiveVarsAnalysis.scala +++ b/src/main/scala/analysis/IntraLiveVarsAnalysis.scala @@ -1,28 +1,31 @@ package analysis import analysis.solvers.SimpleWorklistFixpointSolver -import ir.{Assert, Assume, Block, CFGPosition, Call, DirectCall, GoTo, IndirectCall, Jump, Assign, MemoryAssign, NOP, Procedure, Program, Statement, Variable, Return, Unreachable} +import ir.{Assert, Assume, Block, CFGPosition, Call, DirectCall, GoTo, IndirectCall, Jump, LocalAssign, MemoryLoad, MemoryStore, Procedure, Program, Statement, Variable, Return, Unreachable} -abstract class LivenessAnalysis(program: Program) extends Analysis[Any]: +abstract class LivenessAnalysis(program: Program) extends Analysis[Any] { val lattice: MapLattice[CFGPosition, Set[Variable], PowersetLattice[Variable]] = MapLattice(PowersetLattice()) val domain: Set[CFGPosition] = Set.empty ++ program 
def transfer(n: CFGPosition, s: Set[Variable]): Set[Variable] = { n match { - case p: Procedure => s - case b: Block => s - case Assign(variable, expr, _) => (s - variable) ++ expr.variables - case MemoryAssign(_, index, value, _, _, _) => s ++ index.variables ++ value.variables + case _: Procedure => s + case _: Block => s + case LocalAssign(variable, expr, _) => (s - variable) ++ expr.variables + case MemoryStore(_, index, value, _, _, _) => s ++ index.variables ++ value.variables + case MemoryLoad(lhs, _, index, _, _, _) => (s - lhs) ++ index.variables case Assume(expr, _, _, _) => s ++ expr.variables case Assert(expr, _, _) => s ++ expr.variables case IndirectCall(variable, _) => s + variable - case c: DirectCall => s - case g: GoTo => s - case r: Return => s - case r: Unreachable => s + case _: DirectCall => s + case _: GoTo => s + case _: Return => s + case _: Unreachable => s case _ => ??? } } +} + class IntraLiveVarsAnalysis(program: Program) extends LivenessAnalysis(program) diff --git a/src/main/scala/analysis/MemoryRegionAnalysis.scala b/src/main/scala/analysis/MemoryRegionAnalysis.scala index 5f8e07560..d40957842 100644 --- a/src/main/scala/analysis/MemoryRegionAnalysis.scala +++ b/src/main/scala/analysis/MemoryRegionAnalysis.scala @@ -1,6 +1,6 @@ package analysis -import analysis.BitVectorEval.isNegative +import analysis.BitVectorEval.bv2SignedInt import analysis.solvers.SimpleWorklistFixpointSolver import ir.* import util.Logger @@ -62,7 +62,7 @@ trait MemoryRegionAnalysis(val program: Program, Logger.debug("Stack detection") Logger.debug(spList) stmt match { - case assign: Assign => + case assign: LocalAssign => if (spList.contains(assign.rhs)) { // add lhs to spList spList.addOne(assign.lhs) @@ -104,10 +104,11 @@ trait MemoryRegionAnalysis(val program: Program, evaluateExpression(binExpr.arg2, constantProp(n)) match { case Some(b: BitVecLiteral) => val ctx = getUse(variable, n, reachingDefs) - for { - i <- ctx - stackRegion <- eval(i.rhs, Set.empty, 
i, subAccess) - } yield { + val stackRegions = ctx.flatMap { + case l: LocalAssign => eval(l.rhs, l, subAccess) + case m: MemoryLoad => eval(m.index, m, m.size) + } + for (stackRegion <- stackRegions) yield { val nextOffset = bitVectorOpToBigIntOp(binExpr.op, stackRegion.start, b.value) poolMaster(nextOffset, IRWalk.procedure(n), subAccess) } @@ -115,7 +116,7 @@ trait MemoryRegionAnalysis(val program: Program, Set() } case _ => - eval(binExpr, Set.empty, n, subAccess) + eval(binExpr, n, subAccess) } reducedRegions } @@ -126,7 +127,10 @@ trait MemoryRegionAnalysis(val program: Program, // TODO: nicer way to deal with loops (a variable is being incremented in a loop) val regions = ctx.flatMap { i => if (i != n) { - eval(i.rhs, Set.empty, i, subAccess) + i match { + case l: LocalAssign => eval(l.rhs, l, subAccess) + case m: MemoryLoad => eval(m.index, m, m.size) + } } else { Set() } @@ -134,7 +138,7 @@ trait MemoryRegionAnalysis(val program: Program, regions } - def eval(exp: Expr, env: Set[StackRegion], n: Command, subAccess: BigInt): Set[StackRegion] = { + def eval(exp: Expr, n: Command, subAccess: BigInt): Set[StackRegion] = { if (graResult(n).nonEmpty) { Set.empty // skip global memory regions } else { @@ -143,7 +147,7 @@ trait MemoryRegionAnalysis(val program: Program, if (spList.contains(binOp.arg1)) { evaluateExpression(binOp.arg2, constantProp(n)) match { case Some(b: BitVecLiteral) => - val negB = if isNegative(b) then b.value - BigInt(2).pow(b.size) else b.value + val negB = bv2SignedInt(b) Set(poolMaster(negB, IRWalk.procedure(n), subAccess)) case None => Set.empty } @@ -161,12 +165,10 @@ trait MemoryRegionAnalysis(val program: Program, case variable: Variable => evaluateExpression(variable, constantProp(n)) match { case Some(b: BitVecLiteral) => - eval(b, env, n, subAccess) + eval(b, n, subAccess) case _ => reducibleVariable(variable, n, subAccess) } - case memoryLoad: MemoryLoad => - eval(memoryLoad.index, env, n, memoryLoad.size) // ignore case where it 
could be a global region (loaded later in MMM from relf) case _: BitVecLiteral => Set.empty @@ -202,7 +204,7 @@ trait MemoryRegionAnalysis(val program: Program, if (directCall.target.name == "malloc") { evaluateExpression(mallocVariable, constantProp(n)) match { case Some(b: BitVecLiteral) => - val negB = if isNegative(b) then b.value - BigInt(2).pow(b.size) else b.value + val negB = bv2SignedInt(b) val (name, start) = nextMallocCount(negB) val newHeapRegion = HeapRegion(name, start, negB, IRWalk.procedure(n)) addReturnHeap(directCall, newHeapRegion) @@ -218,25 +220,18 @@ trait MemoryRegionAnalysis(val program: Program, } else { s } - case memAssign: MemoryAssign => - val result = eval(memAssign.index, s, memAssign, memAssign.size) + case memAssign: MemoryStore => + val result = eval(memAssign.index, memAssign, memAssign.size) // if (result.size > 1) { // //throw new Exception(s"Memory load resulted in multiple regions ${result} for mem load $memoryLoad") // addMergableRegions(result) // } result - case assign: Assign => + case assign: LocalAssign => stackDetection(assign) - val unwrapped = unwrapExpr(assign.rhs) - if (unwrapped.isDefined) { - eval(unwrapped.get.index, s, assign, unwrapped.get.size) - // if (result.size > 1) { - // //throw new Exception(s"Memory load resulted in multiple regions ${result} for mem load $memoryLoad") - // addMergableRegions(result) - // } - } else { - Set() - } + Set() + case load: MemoryLoad => + eval(load.index, load, load.size) case _ => s } @@ -244,17 +239,17 @@ trait MemoryRegionAnalysis(val program: Program, } class MemoryRegionAnalysisSolver( - program: Program, - domain: Set[CFGPosition], - globals: Map[BigInt, String], - globalOffsets: Map[BigInt, BigInt], - subroutines: Map[BigInt, String], - constantProp: Map[CFGPosition, Map[Variable, FlatElement[BitVecLiteral]]], - ANRResult: Map[CFGPosition, Set[Variable]], - RNAResult: Map[CFGPosition, Set[Variable]], - reachingDefs: Map[CFGPosition, (Map[Variable, Set[Assign]], 
Map[Variable, Set[Assign]])], - graResult: Map[CFGPosition, Set[DataRegion]], - mmm: MemoryModelMap + program: Program, + domain: Set[CFGPosition], + globals: Map[BigInt, String], + globalOffsets: Map[BigInt, BigInt], + subroutines: Map[BigInt, String], + constantProp: Map[CFGPosition, Map[Variable, FlatElement[BitVecLiteral]]], + ANRResult: Map[CFGPosition, Set[Variable]], + RNAResult: Map[CFGPosition, Set[Variable]], + reachingDefs: Map[CFGPosition, (Map[Variable, Set[Assign]], Map[Variable, Set[Assign]])], + graResult: Map[CFGPosition, Set[DataRegion]], + mmm: MemoryModelMap ) extends MemoryRegionAnalysis(program, domain, globals, globalOffsets, subroutines, constantProp, ANRResult, RNAResult, reachingDefs, graResult, mmm) with IRIntraproceduralForwardDependencies with Analysis[Map[CFGPosition, Set[StackRegion]]] diff --git a/src/main/scala/analysis/RNA.scala b/src/main/scala/analysis/RNA.scala index b58e74dd9..e15dbe914 100644 --- a/src/main/scala/analysis/RNA.scala +++ b/src/main/scala/analysis/RNA.scala @@ -22,47 +22,44 @@ trait RNAAnalysis(program: Program) { private val linkRegister = Register("R30", 64) private val framePointer = Register("R29", 64) - private val ignoreRegions: Set[Expr] = Set(linkRegister, framePointer, stackPointer) + private val ignoreRegions: Set[Variable] = Set(linkRegister, framePointer, stackPointer) - /** Default implementation of eval. 
- */ def eval(cmd: Command, s: Set[Variable]): Set[Variable] = { - var m = s cmd match { case assume: Assume => - m.union(assume.body.variables.filter(!ignoreRegions.contains(_))) + s ++ (assume.body.variables -- ignoreRegions) case assert: Assert => - m.union(assert.body.variables.filter(!ignoreRegions.contains(_))) - case memoryAssign: MemoryAssign => - m.union(memoryAssign.index.variables.filter(!ignoreRegions.contains(_))) + s ++ (assert.body.variables -- ignoreRegions) + case memoryStore: MemoryStore => + s ++ (memoryStore.index.variables -- ignoreRegions) case indirectCall: IndirectCall => - if (ignoreRegions.contains(indirectCall.target)) return m - m + indirectCall.target - case assign: Assign => - m = m - assign.lhs - m.union(assign.rhs.variables.filter(!ignoreRegions.contains(_))) + if (ignoreRegions.contains(indirectCall.target)) { + s + } else { + s + indirectCall.target + } + case assign: LocalAssign => + val m = s - assign.lhs + m ++ (assign.rhs.variables -- ignoreRegions) + case memoryLoad: MemoryLoad => + val m = s - memoryLoad.lhs + m ++ (memoryLoad.index.variables -- ignoreRegions) case _ => - m + s } } /** Transfer function for state lattice elements. */ - def localTransfer(n: CFGPosition, s: Set[Variable]): Set[Variable] = n match { + def transfer(n: CFGPosition, s: Set[Variable]): Set[Variable] = n match { case cmd: Command => eval(cmd, s) case _ => s // ignore other kinds of nodes } - /** Transfer function for state lattice elements. 
- */ - def transfer(n: CFGPosition, s: Set[Variable]): Set[Variable] = localTransfer(n, s) } -class RNAAnalysisSolver( - program: Program, -) extends RNAAnalysis(program) +class RNAAnalysisSolver(program: Program) extends RNAAnalysis(program) with IRIntraproceduralBackwardDependencies with Analysis[Map[CFGPosition, Set[Variable]]] - with SimpleWorklistFixpointSolver[CFGPosition, Set[Variable], PowersetLattice[Variable]] { -} \ No newline at end of file + with SimpleWorklistFixpointSolver[CFGPosition, Set[Variable], PowersetLattice[Variable]] \ No newline at end of file diff --git a/src/main/scala/analysis/ReachingDefinitionsAnalysis.scala b/src/main/scala/analysis/ReachingDefinitionsAnalysis.scala index 93aa4ad6e..07fe443d2 100644 --- a/src/main/scala/analysis/ReachingDefinitionsAnalysis.scala +++ b/src/main/scala/analysis/ReachingDefinitionsAnalysis.scala @@ -21,16 +21,12 @@ trait ReachingDefinitionsAnalysis(program: Program) { val domain: Set[CFGPosition] = Set.empty ++ program - def transfer(n: CFGPosition, s: (Map[Variable, Set[Assign]], Map[Variable, Set[Assign]])): (Map[Variable, Set[Assign]], Map[Variable, Set[Assign]]) = - localTransfer(n, s) - - def localTransfer( - n: CFGPosition, - s: (Map[Variable, Set[Assign]], Map[Variable, Set[Assign]]) - ): (Map[Variable, Set[Assign]], Map[Variable, Set[Assign]]) = n match { - case cmd: Command => - eval(cmd, s) - case _ => s + def transfer(n: CFGPosition, s: (Map[Variable, Set[Assign]], Map[Variable, Set[Assign]])): (Map[Variable, Set[Assign]], Map[Variable, Set[Assign]]) = { + n match { + case cmd: Command => + eval(cmd, s) + case _ => s + } } private def transformUses(vars: Set[Variable], s: (Map[Variable, Set[Assign]], Map[Variable, Set[Assign]])): (Map[Variable, Set[Assign]], Map[Variable, Set[Assign]]) = { @@ -42,7 +38,7 @@ trait ReachingDefinitionsAnalysis(program: Program) { def eval(cmd: Command, s: (Map[Variable, Set[Assign]], Map[Variable, Set[Assign]])): (Map[Variable, Set[Assign]], Map[Variable, 
Set[Assign]]) = cmd match { - case assign: Assign => + case assign: LocalAssign => // do the rhs first (should reset the values for this node to the empty set) // for each variable in the rhs, find the definitions from the lattice lhs and add them to the lattice rhs // for lhs, addOrReplace the definition @@ -55,8 +51,16 @@ trait ReachingDefinitionsAnalysis(program: Program) { (s(0) + (lhs -> Set(assign)), rhsUseDefs) case assert: Assert => transformUses(assert.body.variables, s) - case memoryAssign: MemoryAssign => - transformUses(memoryAssign.index.variables ++ memoryAssign.value.variables, s) + case memoryStore: MemoryStore => + transformUses(memoryStore.index.variables ++ memoryStore.value.variables, s) + case memoryLoad: MemoryLoad => + val lhs = memoryLoad.lhs + val rhs = memoryLoad.index.variables + val rhsUseDefs: Map[Variable, Set[Assign]] = rhs.foldLeft(Map.empty[Variable, Set[Assign]]) { + case (acc, v) => + acc + (v -> s(0)(v)) + } + (s(0) + (lhs -> Set(memoryLoad)), rhsUseDefs) case assume: Assume => transformUses(assume.body.variables, s) case indirectCall: IndirectCall => diff --git a/src/main/scala/analysis/ReachingDefs.scala b/src/main/scala/analysis/ReachingDefs.scala index 5a2bb2f2a..f60e9ffc6 100644 --- a/src/main/scala/analysis/ReachingDefs.scala +++ b/src/main/scala/analysis/ReachingDefs.scala @@ -1,7 +1,7 @@ package analysis import analysis.solvers.SimplePushDownWorklistFixpointSolver -import ir.{Assert, Assume, BitVecType, CFGPosition, Call, DirectCall, Expr, GoTo, IndirectCall, InterProcIRCursor, IntraProcIRCursor, Assign, MemoryAssign, NOP, Procedure, Program, Register, Variable, computeDomain} +import ir.{LocalAssign, CFGPosition, DirectCall, IntraProcIRCursor, MemoryLoad, Procedure, Program, Register, Variable, computeDomain} abstract class ReachingDefs(program: Program, writesTo: Map[Procedure, Set[Register]]) extends Analysis[Map[CFGPosition, Map[Variable, Set[CFGPosition]]]] { @@ -11,8 +11,10 @@ abstract class ReachingDefs(program: 
Program, writesTo: Map[Procedure, Set[Regis def transfer(n: CFGPosition, s: Map[Variable, Set[CFGPosition]]): Map[Variable, Set[CFGPosition]] = { n match { - case loc: Assign => + case loc: LocalAssign => s + (loc.lhs -> Set(n)) + case load: MemoryLoad => + s + (load.lhs -> Set(n)) case DirectCall(target, _) if target.name == "malloc" => s + (mallocRegister -> Set(n)) case DirectCall(target, _) if writesTo.contains(target) => @@ -27,6 +29,6 @@ abstract class ReachingDefs(program: Program, writesTo: Map[Procedure, Set[Regis } -class ReachingDefsAnalysis(program: Program, writesTo: Map[Procedure, Set[Register]]) extends ReachingDefs(program, writesTo), IRIntraproceduralForwardDependencies, +class ReachingDefsAnalysis(program: Program, writesTo: Map[Procedure, Set[Register]]) extends ReachingDefs(program, writesTo), IRIntraproceduralForwardDependencies, SimplePushDownWorklistFixpointSolver[CFGPosition, Map[Variable, Set[CFGPosition]], MapLattice[Variable, Set[CFGPosition], PowersetLattice[CFGPosition]]] diff --git a/src/main/scala/analysis/RegToMemAnalysis.scala b/src/main/scala/analysis/RegToMemAnalysis.scala deleted file mode 100644 index 1afde6f4d..000000000 --- a/src/main/scala/analysis/RegToMemAnalysis.scala +++ /dev/null @@ -1,72 +0,0 @@ -package analysis - -import ir.{MemoryLoad, *} -import analysis.solvers.* -import util.Logger - -import scala.collection.immutable - -/** - * Collects all the memory loads and the expressions that are assigned to a register but cannot be evaluated. - * - * Tracks: - * R_x = MemoryLoad[Base + Offset] - * R_x = Base + Offset - * - * Both in which constant propagation mark as TOP which is not useful. 
- */ -trait RegionAccessesAnalysis(program: Program, constantProp: Map[CFGPosition, Map[Variable, FlatElement[BitVecLiteral]]], reachingDefs: Map[CFGPosition, (Map[Variable, Set[Assign]], Map[Variable, Set[Assign]])]) { - - val mapLattice: MapLattice[RegisterVariableWrapper, FlatElement[Expr], FlatLattice[Expr]] = MapLattice(FlatLattice[_root_.ir.Expr]()) - - val lattice: MapLattice[CFGPosition, Map[RegisterVariableWrapper, FlatElement[Expr]], MapLattice[RegisterVariableWrapper, FlatElement[Expr], FlatLattice[Expr]]] = MapLattice(mapLattice) - - val domain: Set[CFGPosition] = program.toSet - - val first: Set[CFGPosition] = program.procedures.toSet - - /** Default implementation of eval. - */ - def eval(cmd: Statement, constants: Map[Variable, FlatElement[BitVecLiteral]], s: Map[RegisterVariableWrapper, FlatElement[Expr]]): Map[RegisterVariableWrapper, FlatElement[Expr]] = { - cmd match { - case assign: Assign => - assign.rhs match { - case memoryLoad: MemoryLoad => - s + (RegisterVariableWrapper(assign.lhs, getDefinition(assign.lhs, cmd, reachingDefs)) -> FlatEl(memoryLoad)) - case binaryExpr: BinaryExpr => - if (evaluateExpression(binaryExpr.arg1, constants).isEmpty) { // approximates Base + Offset - Logger.debug(s"Approximating $assign in $binaryExpr") - Logger.debug(s"Reaching defs: ${reachingDefs(cmd)}") - s + (RegisterVariableWrapper(assign.lhs, getDefinition(assign.lhs, cmd, reachingDefs)) -> FlatEl(binaryExpr)) - } else { - s - } - case _ => s - } - case _ => - s - } - } - - /** Transfer function for state lattice elements. - */ - def localTransfer(n: CFGPosition, s: Map[RegisterVariableWrapper, FlatElement[Expr]]): Map[RegisterVariableWrapper, FlatElement[Expr]] = n match { - case cmd: Statement => - eval(cmd, constantProp(cmd), s) - case _ => s // ignore other kinds of nodes - } - - /** Transfer function for state lattice elements. 
- */ - def transfer(n: CFGPosition, s: Map[RegisterVariableWrapper, FlatElement[Expr]]): Map[RegisterVariableWrapper, FlatElement[Expr]] = localTransfer(n, s) -} - -class RegionAccessesAnalysisSolver( - program: Program, - constantProp: Map[CFGPosition, Map[Variable, FlatElement[BitVecLiteral]]], - reachingDefs: Map[CFGPosition, (Map[Variable, Set[Assign]], Map[Variable, Set[Assign]])], - ) extends RegionAccessesAnalysis(program, constantProp, reachingDefs) - with IRInterproceduralForwardDependencies - with Analysis[Map[CFGPosition, Map[RegisterVariableWrapper, FlatElement[Expr]]]] - with SimpleWorklistFixpointSolver[CFGPosition, Map[RegisterVariableWrapper, FlatElement[Expr]], MapLattice[RegisterVariableWrapper, FlatElement[Expr], FlatLattice[Expr]]] { -} diff --git a/src/main/scala/analysis/RegionInjector.scala b/src/main/scala/analysis/RegionInjector.scala index 3cc414f06..77f1bd89c 100644 --- a/src/main/scala/analysis/RegionInjector.scala +++ b/src/main/scala/analysis/RegionInjector.scala @@ -15,7 +15,6 @@ class MergedRegion(var name: String, val subregions: mutable.Set[MemoryRegion]) class RegionInjector(program: Program, mmm: MemoryModelMap) { private val accessToRegion = mutable.Map[Statement, Set[MemoryRegion]]() - private val loadToMemory = mutable.Map[Statement, Memory]() val mergedRegions: mutable.Map[MemoryRegion, MergedRegion] = mutable.Map() def nodeVisitor(): Unit = { @@ -87,13 +86,13 @@ class RegionInjector(program: Program, mmm: MemoryModelMap) { val mergedRegion = mergedRegions(regionsHead) access match { - case store: MemoryAssign => + case store: MemoryStore => val newMemory = replaceMemory(store.mem, regionsHead, mergedRegion) store.mem = newMemory + case load: MemoryLoad => + val newMemory = replaceMemory(load.mem, regionsHead, mergedRegion) + load.mem = newMemory case _ => - val newMemory = replaceMemory(loadToMemory(access), regionsHead, mergedRegion) - val renamer = RegionRenamer(newMemory) - renamer.visitStatement(access) } } @@ -115,43 
+114,13 @@ class RegionInjector(program: Program, mmm: MemoryModelMap) { mmm.getStack(n) ++ mmm.getData(n) } - def visitExpr(expr: Expr, cmd: Statement): Unit = { - expr match { - case Extract(_, _, body) => - visitExpr(body, cmd) - case UninterpretedFunction(_, params, _) => - params.foreach { - p => visitExpr(p, cmd) - } - case Repeat(_, body) => - visitExpr(body, cmd) - case ZeroExtend(_, body) => - visitExpr(body, cmd) - case SignExtend(_, body) => - visitExpr(body, cmd) - case UnaryExpr(_, arg) => - visitExpr(arg, cmd) - case BinaryExpr(_, arg1, arg2) => - visitExpr(arg1, cmd) - visitExpr(arg2, cmd) - case m: MemoryLoad => - val regions = statementToRegions(cmd) - accessToRegion(cmd) = regions - loadToMemory(cmd) = m.mem - case _ => - } - } - def visitStatement(n: Statement): Unit = n match { - case assign: Assign => - visitExpr(assign.rhs, assign) - case m: MemoryAssign => + case m: MemoryStore => val regions = statementToRegions(m) accessToRegion(m) = regions - case assert: Assert => - visitExpr(assert.body, assert) - case assume: Assume => - visitExpr(assume.body, assume) + case m: MemoryLoad => + val regions = statementToRegions(n) + accessToRegion(n) = regions case _ => // ignore other kinds of nodes } diff --git a/src/main/scala/analysis/SummaryGenerator.scala b/src/main/scala/analysis/SummaryGenerator.scala index 18a5b4205..1a878e21d 100644 --- a/src/main/scala/analysis/SummaryGenerator.scala +++ b/src/main/scala/analysis/SummaryGenerator.scala @@ -24,35 +24,31 @@ private trait RNATaintableAnalysis( private val linkRegister = Register("R30", 64) private val framePointer = Register("R29", 64) - private val ignoreRegions: Set[Expr] = Set(linkRegister, framePointer, stackPointer) + private val ignoreRegions: Set[Variable] = Set(linkRegister, framePointer, stackPointer) def eval(cmd: Command, s: Set[Taintable]): Set[Taintable] = { - var m = s - val exprs = cmd match { + cmd match { case assume: Assume => - Set(assume.body) + s ++ assume.body.variables -- 
ignoreRegions case assert: Assert => - Set(assert.body) - case memoryAssign: MemoryAssign => - m = m -- getMemoryVariable(cmd, memoryAssign.mem, memoryAssign.index, memoryAssign.size, constProp, globals) - Set(memoryAssign.index, memoryAssign.value) + s ++ assert.body.variables -- ignoreRegions + case memoryStore: MemoryStore => + val m = s -- getMemoryVariable(cmd, memoryStore.mem, memoryStore.index, memoryStore.size, constProp, globals) + m ++ memoryStore.index.variables ++ memoryStore.value.variables -- ignoreRegions case indirectCall: IndirectCall => - if (ignoreRegions.contains(indirectCall.target)) return m - Set(indirectCall.target) - case assign: Assign => - m = m - assign.lhs - Set(assign.rhs) - case _ => return m - } - - exprs.foldLeft(m) { - (m, expr) => { - val vars = expr.variables.filter(!ignoreRegions.contains(_)).map { v => v: Taintable } - val memvars: Set[Taintable] = expr.loads.flatMap { - l => getMemoryVariable(cmd, l.mem, l.index, l.size, constProp, globals) + if (ignoreRegions.contains(indirectCall.target)) { + s + } else { + s + indirectCall.target -- ignoreRegions } - m.union(vars).union(memvars) - } + case assign: LocalAssign => + val m = s - assign.lhs + m ++ assign.rhs.variables -- ignoreRegions + case memoryLoad: MemoryLoad => + val m = s - memoryLoad.lhs + val memvar = getMemoryVariable(cmd, memoryLoad.mem, memoryLoad.index, memoryLoad.size, constProp, globals) + m ++ memvar ++ memoryLoad.index.variables -- ignoreRegions + case _ => s } } @@ -96,10 +92,10 @@ class SummaryGenerator( private def toGamma(variable: Taintable): Option[BExpr] = { variable match { case variable: Register => Some(variable.toGamma) - case variable: LocalVar => None + case _: LocalVar => None case variable: GlobalVariable => Some(variable.toGamma) //case variable: LocalStackVariable => None - case variable: UnknownMemory => Some(FalseBLiteral) + case _: UnknownMemory => Some(FalseBLiteral) } } diff --git a/src/main/scala/analysis/TaintAnalysis.scala 
b/src/main/scala/analysis/TaintAnalysis.scala index d51eda374..d18e27d48 100644 --- a/src/main/scala/analysis/TaintAnalysis.scala +++ b/src/main/scala/analysis/TaintAnalysis.scala @@ -15,7 +15,7 @@ type Taintable = Variable | GlobalVariable /*| LocalStackVariable*/ | UnknownMem /** * A global variable in memory. */ -case class GlobalVariable(val mem: Memory, val address: BitVecLiteral, val size: Int, val identifier: String) { +case class GlobalVariable(mem: Memory, address: BitVecLiteral, size: Int, identifier: String) { override def toString(): String = { s"GlobalVariable($mem, $identifier, $size, $address)" } @@ -59,7 +59,10 @@ case class UnknownMemory() { } def getMemoryVariable( - n: CFGPosition, mem: Memory, expression: Expr, size: Int, + n: CFGPosition, + mem: Memory, + expression: Expr, + size: Int, constProp: Map[CFGPosition, Map[Variable, FlatElement[BitVecLiteral]]], globals: Map[BigInt, String], ): Option[GlobalVariable/*| LocalStackVariable*/] = { @@ -69,25 +72,23 @@ def getMemoryVariable( expression match { // TODO assumes stack var accesses are all of the form R31 + n, or just R31, when in reality they could be more complex. - case BinaryExpr(BVADD, arg1, arg2) if arg1 == stackPointer => { + case BinaryExpr(BVADD, arg1, arg2) if arg1 == stackPointer => evaluateExpression(arg2, constProp(n)) match // TODO This assumes that all stack variables are initialized local variables, which is not necessarily the case. // If a stack address is read, without being assigned a value in this procedure, it will be // assumed untainted, when in reality it may be UnknownMemory. 
//case Some(addr) => Some(LocalStackVariable(addr, size)) - case Some(addr) => None + case Some(_) => None case None => None - } //case v: Variable if v == stackPointer => Some(LocalStackVariable(BitVecLiteral(0, 64), size)) case v: Variable if v == stackPointer => None - case _ => { + case _ => // TOOD check that the global access has the right size evaluateExpression(expression, constProp(n)) match case Some(addr) => globals.get(addr.value) match case Some(global) => Some(GlobalVariable(mem, addr, size, global)) case None => None case None => None - } } } @@ -100,8 +101,6 @@ trait TaintAnalysisFunctions( val edgelattice = EdgeFunctionLattice(valuelattice) import edgelattice.{IdEdge, ConstEdge} - private val stackPointer = Register("R31", 64) - def edgesCallToEntry(call: DirectCall, entry: Procedure)(d: DL): Map[DL, EdgeFunction[TwoElement]] = { Map(d -> IdEdge()) } @@ -115,39 +114,34 @@ trait TaintAnalysisFunctions( } def edgesOther(n: CFGPosition)(d: DL): Map[DL, EdgeFunction[TwoElement]] = { - def containsValue(expression: Expr, value: Taintable): Boolean = { - value match { - case (v: Variable) => expression.variables.contains(v) - case v => { - expression.loads.map { - load => getMemoryVariable(n, load.mem, load.index, load.size, constProp, globals).getOrElse(UnknownMemory()) - }.contains(v) - } - } - } - (n match { - case Assign(variable, expression, _) => { + case LocalAssign(variable, expression, _) => d match { - case Left(v) => { - if containsValue(expression, v) then Map(d -> IdEdge(), Left(variable) -> IdEdge()) - else if v == variable then Map() - else Map(d -> IdEdge()) - } + case Left(v: Variable) => + if (expression.variables.contains(v)) { + Map(d -> IdEdge(), Left(variable) -> IdEdge()) + } else if (v == variable) { + Map() + } else { + Map(d -> IdEdge()) + } case _ => Map(d -> IdEdge()) } - } - case MemoryAssign(mem, index, expression, _, size, _) => { - val variable = getMemoryVariable(n, mem, index, size, constProp, 
globals).getOrElse(UnknownMemory()) + case MemoryStore(mem, index, expression, _, size, _) => + val variable: Taintable = getMemoryVariable(n, mem, index, size, constProp, globals).getOrElse(UnknownMemory()) d match { - case Left(v) => { - if containsValue(expression, v) then Map(d -> IdEdge(), Left(variable) -> IdEdge()) - else if variable == v && v != UnknownMemory() then Map() - else Map(d -> IdEdge()) - } - case Right(_) => Map(d -> IdEdge()) + case Left(v: Variable) if expression.variables.contains(v) => Map(d -> IdEdge(), Left(variable) -> IdEdge()) + case Left(v: GlobalVariable) if variable == v => Map() + case _ => Map(d -> IdEdge()) + } + case MemoryLoad(lhs, mem, index, _, size, _) => + val memoryVariable: Taintable = getMemoryVariable(n, mem, index, size, constProp, globals).getOrElse(UnknownMemory()) + d match { + case Left(v: Variable) if index.variables.contains(v) => Map(d -> IdEdge(), Left(lhs) -> IdEdge()) + case Left(v: Variable) if v == lhs => Map() + case Left(v: Taintable) if memoryVariable == v => Map(d -> IdEdge(), Left(lhs) -> IdEdge()) + case _ => Map(d -> IdEdge()) } - } case _ => Map(d -> IdEdge()) }) ++ ( d match diff --git a/src/main/scala/analysis/UtilMethods.scala b/src/main/scala/analysis/UtilMethods.scala index 2f74b65e8..be32662f8 100644 --- a/src/main/scala/analysis/UtilMethods.scala +++ b/src/main/scala/analysis/UtilMethods.scala @@ -81,9 +81,7 @@ def evaluateExpressionWithSSA(exp: Expr, constantPropResult: Map[RegisterWrapper } def applySingle(op: BitVecLiteral => BitVecLiteral, a: Set[BitVecLiteral]): Set[BitVecLiteral] = { - val res = for { - x <- a - } yield op(x) + val res = for (x <- a) yield op(x) res } @@ -136,9 +134,8 @@ def evaluateExpressionWithSSA(exp: Expr, constantPropResult: Map[RegisterWrapper Logger.debug("getUse: " + getUse(variable, n, reachingDefs)) constantPropResult(RegisterWrapperEqualSets(variable, getUse(variable, n, reachingDefs))) case b: BitVecLiteral => Set(b) - case Repeat(repeats, body) => 
evaluateExpressionWithSSA(body, constantPropResult, n, reachingDefs) - case MemoryLoad(mem, index, endian, size) => Set.empty - case UninterpretedFunction(name, params, returnType) => Set.empty + case Repeat(_, body) => evaluateExpressionWithSSA(body, constantPropResult, n, reachingDefs) + case _: UninterpretedFunction => Set.empty case _ => throw RuntimeException("ERROR: CASE NOT HANDLED: " + exp + "\n") } } @@ -153,23 +150,6 @@ def getUse(variable: Variable, node: CFGPosition, reachingDefs: Map[CFGPosition, out.getOrElse(variable, Set()) } -def unwrapExpr(expr: Expr): Option[MemoryLoad] = { - expr match { - case e: Extract => unwrapExpr(e.body) - case e: SignExtend => unwrapExpr(e.body) - case e: ZeroExtend => unwrapExpr(e.body) - case repeat: Repeat => unwrapExpr(repeat.body) - case unaryExpr: UnaryExpr => unwrapExpr(unaryExpr.arg) - case binaryExpr: BinaryExpr => // TODO: incorrect - unwrapExpr(binaryExpr.arg1) - unwrapExpr(binaryExpr.arg2) - case memoryLoad: MemoryLoad => - Some(memoryLoad) - case _ => - None - } -} - def unwrapExprToVar(expr: Expr): Option[Variable] = { expr match { case variable: Variable => @@ -182,7 +162,6 @@ def unwrapExprToVar(expr: Expr): Option[Variable] = { case binaryExpr: BinaryExpr => // TODO: incorrect unwrapExprToVar(binaryExpr.arg1) unwrapExprToVar(binaryExpr.arg2) - case memoryLoad: MemoryLoad => unwrapExprToVar(memoryLoad.index) case _ => None } diff --git a/src/main/scala/analysis/VSA.scala b/src/main/scala/analysis/VSA.scala index 37d503238..4a9f11a14 100644 --- a/src/main/scala/analysis/VSA.scala +++ b/src/main/scala/analysis/VSA.scala @@ -10,8 +10,7 @@ import scala.collection.immutable import util.Logger /** ValueSets are PowerSet of possible values */ -trait Value { -} +trait Value case class AddressValue(region: MemoryRegion) extends Value { override def toString: String = "Address(" + region + ")" @@ -43,26 +42,27 @@ trait ValueSetAnalysis(program: Program, /** Default implementation of eval. 
*/ - def eval(cmd: Command, s: Map[Variable | MemoryRegion, Set[Value]], n: CFGPosition): Map[Variable | MemoryRegion, Set[Value]] = { - cmd match + def eval(cmd: Command, s: Map[Variable | MemoryRegion, Set[Value]]): Map[Variable | MemoryRegion, Set[Value]] = { + cmd match { case directCall: DirectCall if directCall.target.name == "malloc" => - val regions = mmm.nodeToRegion(n) + val regions = mmm.nodeToRegion(cmd) // malloc variable s + (mallocVariable -> regions.map(r => AddressValue(r))) - case localAssign: Assign => - val regions = mmm.nodeToRegion(n) + case localAssign: LocalAssign => + val regions = mmm.nodeToRegion(cmd) if (regions.nonEmpty) { s + (localAssign.lhs -> regions.map(r => AddressValue(r))) } else { - evaluateExpression(localAssign.rhs, constantProp(n)) match { + evaluateExpression(localAssign.rhs, constantProp(cmd)) match { case Some(bitVecLiteral: BitVecLiteral) => val possibleData = canCoerceIntoDataRegion(bitVecLiteral, 1) - if (possibleData.isDefined) { - s + (localAssign.lhs -> Set(AddressValue(possibleData.get))) - } else { - s + (localAssign.lhs -> Set(LiteralValue(bitVecLiteral))) - } + if (possibleData.isDefined) { + s + (localAssign.lhs -> Set(AddressValue(possibleData.get))) + } else { + s + (localAssign.lhs -> Set(LiteralValue(bitVecLiteral))) + } case None => + // TODO this is not at all sound val unwrapValue = unwrapExprToVar(localAssign.rhs) unwrapValue match { case Some(v: Variable) => @@ -73,33 +73,50 @@ trait ValueSetAnalysis(program: Program, } } } - case memAssign: MemoryAssign => - val regions = mmm.nodeToRegion(n) - evaluateExpression(memAssign.value, constantProp(n)) match { + case load: MemoryLoad => + val regions = mmm.nodeToRegion(cmd) + if (regions.nonEmpty) { + s + (load.lhs -> regions.map(r => AddressValue(r))) + } else { + // TODO this is blatantly incorrect but maintaining current functionality to start + val unwrapValue = unwrapExprToVar(load.index) + unwrapValue match { + case Some(v: Variable) => + s + (load.lhs 
-> s(v)) + case None => + Logger.debug(s"Too Complex: ${load.index}") // do nothing + s + } + } + case store: MemoryStore => + val regions = mmm.nodeToRegion(cmd) + evaluateExpression(store.value, constantProp(cmd)) match { case Some(bitVecLiteral: BitVecLiteral) => - val possibleData = canCoerceIntoDataRegion(bitVecLiteral, memAssign.size) + val possibleData = canCoerceIntoDataRegion(bitVecLiteral, store.size) if (possibleData.isDefined) { s ++ regions.map(r => r -> Set(AddressValue(possibleData.get))) } else { s ++ regions.map(r => r -> Set(LiteralValue(bitVecLiteral))) } case None => - val unwrapValue = unwrapExprToVar(memAssign.value) + // TODO: unsound + val unwrapValue = unwrapExprToVar(store.value) unwrapValue match { case Some(v: Variable) => s ++ regions.map(r => r -> s(v)) case None => - Logger.debug(s"Too Complex: $memAssign.value") // do nothing + Logger.debug(s"Too Complex: $store.value") // do nothing s } } case _ => s + } } - /** Transfer function for state lattice elements. + /** Transfer function for state lattice elements. (Same as `localTransfer` for simple value analysis.) */ - def localTransfer(n: CFGPosition, s: Map[Variable | MemoryRegion, Set[Value]]): Map[Variable | MemoryRegion, Set[Value]] = { + def transferUnlifted(n: CFGPosition, s: Map[Variable | MemoryRegion, Set[Value]]): Map[Variable | MemoryRegion, Set[Value]] = { n match { case p: Procedure => mmm.pushContext(p.name) @@ -108,15 +125,11 @@ trait ValueSetAnalysis(program: Program, mmm.popContext() s case command: Command => - eval(command, s, n) + eval(command, s) case _ => s } } - - /** Transfer function for state lattice elements. (Same as `localTransfer` for simple value analysis.) 
- */ - def transferUnlifted(n: CFGPosition, s: Map[Variable | MemoryRegion, Set[Value]]): Map[Variable | MemoryRegion, Set[Value]] = localTransfer(n, s) } class ValueSetAnalysisSolver( diff --git a/src/main/scala/analysis/VariableDependencyAnalysis.scala b/src/main/scala/analysis/VariableDependencyAnalysis.scala index ed1ed1eaf..643bf5b6f 100644 --- a/src/main/scala/analysis/VariableDependencyAnalysis.scala +++ b/src/main/scala/analysis/VariableDependencyAnalysis.scala @@ -37,46 +37,50 @@ trait ProcVariableDependencyAnalysisFunctions( def edgesCallToAfterCall(call: DirectCall, aftercall: Command)(d: DL): Map[DL, EdgeFunction[Set[Taintable]]] = { d match { - case Left(v) => varDepsSummaries.get(call.target).flatMap(_.get(v).map( _.foldLeft(Map[DL, EdgeFunction[Set[Taintable]]]()) { - (m, d) => m + (Left(d) -> IdEdge()) - })).getOrElse(Map()) + case Left(v) => + varDepsSummaries.get(call.target).flatMap { + _.get(v).map { + _.foldLeft(Map[DL, EdgeFunction[Set[Taintable]]]()) { + (m, d) => m + (Left(d) -> IdEdge()) + } + } + }.getOrElse(Map()) case Right(_) => Map(d -> IdEdge()) } } def edgesOther(n: CFGPosition)(d: DL): Map[DL, EdgeFunction[Set[Taintable]]] = { - def getVars(expression: Expr): Set[Taintable] = { - expression.variables.map { v => v: Taintable } ++ - expression.loads.map { l => getMemoryVariable(n, l.mem, l.index, l.size, constProp, globals).getOrElse(UnknownMemory()) } - } - if n == procedure then d match { // At the start of the procedure, no variables should depend on anything but themselves. 
case Left(_) => Map() - case Right(_) => { + case Right(_) => variables.foldLeft(Map(d -> IdEdge())) { (m: Map[DL, EdgeFunction[Set[Taintable]]], v) => m + (Left(v) -> ConstEdge(Set(v))) } - } } else n match { - case Assign(assigned, expression, _) => { - val vars = getVars(expression) -- ignoredRegisters + case LocalAssign(assigned, expression, _) => + val vars = expression.variables -- ignoredRegisters d match { - case Left(v) if vars.contains(v) => Map(d -> IdEdge(), Left(assigned) -> IdEdge()) - case Left(v) if v == assigned => Map() + case Left(v: Variable) if vars.contains(v) => Map(d -> IdEdge(), Left(assigned) -> IdEdge()) + case Left(v: Variable) if v == assigned => Map() case _ => Map(d -> IdEdge()) } - } - case MemoryAssign(mem, index, expression, _, size, _) => { - val assigned = getMemoryVariable(n, mem, index, size, constProp, globals).getOrElse(UnknownMemory()) - - val vars = getVars(expression) -- ignoredRegisters + case MemoryStore(mem, index, expression, _, size, _) => + val assigned: Taintable = getMemoryVariable(n, mem, index, size, constProp, globals).getOrElse(UnknownMemory()) + val vars = expression.variables -- ignoredRegisters d match { - case Left(v) if vars.contains(v) => Map(d -> IdEdge(), Left(assigned) -> IdEdge()) - case Left(v) if v == assigned && v != UnknownMemory() => Map() + case Left(v: Variable) if vars.contains(v) => Map(d -> IdEdge(), Left(assigned) -> IdEdge()) + case Left(v: GlobalVariable) if v == assigned => Map() + case _ => Map(d -> IdEdge()) + } + case MemoryLoad(lhs, mem, index, _, size, _) => + val memoryVariable: Taintable = getMemoryVariable(n, mem, index, size, constProp, globals).getOrElse(UnknownMemory()) + val vars: Set[Taintable] = Set(memoryVariable) ++ index.variables -- ignoredRegisters + d match { + case Left(v) if vars.contains(v) => Map(d -> IdEdge(), Left(lhs) -> IdEdge()) + case Left(v) if v == lhs => Map() case _ => Map(d -> IdEdge()) } - } case _ => Map(d -> IdEdge()) } } @@ -97,7 +101,7 @@ class 
ProcVariableDependencyAnalysis( { override def phase2Init: Set[Taintable] = Set(Register("R0", 64)) - override val startNode: CFGPosition = procedure + override def start: CFGPosition = procedure } class VariableDependencyAnalysis( @@ -107,7 +111,7 @@ class VariableDependencyAnalysis( constProp: Map[CFGPosition, Map[Variable, FlatElement[BitVecLiteral]]], scc: mutable.ListBuffer[mutable.Set[Procedure]], ) { - val varDepVariables: Set[analysis.Taintable] = 0.to(28).map { n => + val varDepVariables: Set[Taintable] = 0.to(28).map { n => Register(s"R$n", 64) }.toSet ++ specGlobals.map { g => analysis.GlobalVariable(dsl.mem, BitVecLiteral(g.address, 64), g.size, g.name) diff --git a/src/main/scala/analysis/WriteToAnalysis.scala b/src/main/scala/analysis/WriteToAnalysis.scala index 9e3139297..cc4e3eba2 100644 --- a/src/main/scala/analysis/WriteToAnalysis.scala +++ b/src/main/scala/analysis/WriteToAnalysis.scala @@ -1,6 +1,6 @@ package analysis -import ir.{Assert, Assume, BitVecType, Call, DirectCall, GoTo, Assign, MemoryAssign, NOP, Procedure, Program, Register} +import ir.{DirectCall, LocalAssign, MemoryLoad, MemoryStore, Procedure, Program, Register} import scala.collection.mutable @@ -24,18 +24,19 @@ class WriteToAnalysis(program: Program) extends Analysis[Map[Procedure, Set[Regi writesTo(proc) else val writtenTo: mutable.Set[Register] = mutable.Set() - proc.blocks.foreach( - block => - block.statements.foreach { - case Assign(variable: Register, _, _) if paramRegisters.contains(variable) => - writtenTo.add(variable) - case DirectCall(target, _) if target.name == "malloc" => - writtenTo.add(mallocRegister) - case DirectCall(target, _) if program.procedures.contains(target) => - writtenTo.addAll(getWritesTos(target)) - case _ => - } - ) + proc.blocks.foreach { block => + block.statements.foreach { + case LocalAssign(variable: Register, _, _) if paramRegisters.contains(variable) => + writtenTo.add(variable) + case MemoryLoad(lhs: Register, _, _, _, _, _) if 
paramRegisters.contains(lhs) => + writtenTo.add(lhs) + case DirectCall(target, _) if target.name == "malloc" => + writtenTo.add(mallocRegister) + case DirectCall(target, _) if program.procedures.contains(target) => + writtenTo.addAll(getWritesTos(target)) + case _ => + } + } writesTo.update(proc, writtenTo.toSet) writesTo(proc) diff --git a/src/main/scala/analysis/data_structure_analysis/Graph.scala b/src/main/scala/analysis/data_structure_analysis/Graph.scala index 1829bbeb3..a7cae0af4 100644 --- a/src/main/scala/analysis/data_structure_analysis/Graph.scala +++ b/src/main/scala/analysis/data_structure_analysis/Graph.scala @@ -48,36 +48,37 @@ class Graph(val proc: Procedure, // collect all stack access and their maximum accessed size // BigInt is the offset of the stack position and Int is it's size - private val stackAccesses: Map[BigInt, Int] = computeDomain(IntraProcIRCursor, Set(proc)).toSeq.sortBy(_.toShortString).foldLeft(Map[BigInt, Int]()) { - (results, pos) => - pos match - case Assign(_, expr, _) => - expr match - case MemoryLoad(_, index, _, size) => - visitStackAccess(pos, index, size).foldLeft(results) { - (res, access) => - if !res.contains(access.offset) || (res.getOrElse(access.offset, -1) < access.size) then - res + (access.offset -> access.size) - else - res - } - case _ => - visitStackAccess(pos, expr, 0).foldLeft(results) { - (res, access) => - if !res.contains(access.offset) || (res.getOrElse(access.offset, -1) < access.size) then - res + (access.offset -> access.size) - else - res - } - case MemoryAssign(_, index, _, _, size, _) => - visitStackAccess(pos, index, size).foldLeft(results) { - (res, access) => - if !res.contains(access.offset) || (res.getOrElse(access.offset, -1) < access.size) then - res + (access.offset -> access.size) - else - res - } - case _ => results + private val stackAccesses: Map[BigInt, Int] = { + computeDomain(IntraProcIRCursor, Set(proc)).toSeq.sortBy(_.toShortString).foldLeft(Map[BigInt, Int]()) { + (results, pos) => 
+ pos match { + case LocalAssign(_, expr, _) => + visitStackAccess(pos, expr, 0).foldLeft(results) { + (res, access) => + if !res.contains(access.offset) || (res.getOrElse(access.offset, -1) < access.size) then + res + (access.offset -> access.size) + else + res + } + case MemoryStore(_, index, _, _, size, _) => + visitStackAccess(pos, index, size).foldLeft(results) { + (res, access) => + if !res.contains(access.offset) || (res.getOrElse(access.offset, -1) < access.size) then + res + (access.offset -> access.size) + else + res + } + case MemoryLoad(_, _, index, _, size, _) => + visitStackAccess(pos, index, size).foldLeft(results) { + (res, access) => + if !res.contains(access.offset) || (res.getOrElse(access.offset, -1) < access.size) then + res + (access.offset -> access.size) + else + res + } + case _ => results + } + } } private case class StackAccess(offset: BigInt, size: Int) @@ -271,9 +272,9 @@ class Graph(val proc: Procedure, if (varName.startsWith("#")) { varName = s"LocalVar_${varName.drop(1)}" } - structs.append(DotStruct(s"SSA_${id}_${varName}", s"SSA_${pos}_${varName}", None, false)) + structs.append(DotStruct(s"SSA_${id}_$varName", s"SSA_${pos}_$varName", None, false)) val value = find(slice) - arrows.append(StructArrow(DotStructElement(s"SSA_${id}_${varName}", None), DotStructElement(value.node.id.toString, Some(value.cell.offset.toString)), value.internalOffset.toString)) + arrows.append(StructArrow(DotStructElement(s"SSA_${id}_$varName", None), DotStructElement(value.node.id.toString, Some(value.cell.offset.toString)), value.internalOffset.toString)) } } @@ -601,7 +602,7 @@ class Graph(val proc: Procedure, val varToCell = mutable.Map[CFGPosition, mutable.Map[Variable, Slice]]() val domain = computeDomain(IntraProcIRCursor, Set(proc)) domain.foreach { - case pos @ Assign(variable, value, _) => + case pos @ LocalAssign(variable, value, _) => value.variables.foreach { v => if (isFormal(pos, v)) { val node = Node(Some(this)) @@ -612,6 +613,17 @@ class 
Graph(val proc: Procedure, } val node = Node(Some(this)) varToCell(pos) = mutable.Map(variable -> Slice(node.cells(0), 0)) + case pos @ MemoryLoad(lhs, _, index, _, _, _) => + index.variables.foreach { v => + if (isFormal(pos, v)) { + val node = Node(Some(this)) + node.flags.incomplete = true + nodes.add(node) + formals.update(v, Slice(node.cells(0), 0)) + } + } + val node = Node(Some(this)) + varToCell(pos) = mutable.Map(lhs -> Slice(node.cells(0), 0)) case pos @ DirectCall(target, _) if target.name == "malloc" => val node = Node(Some(this)) varToCell(pos) = mutable.Map(mallocRegister -> Slice(node.cells(0), 0)) @@ -622,7 +634,7 @@ class Graph(val proc: Procedure, result(variable) = Slice(node.cells(0), 0) } varToCell(pos) = result - case pos @ MemoryAssign(_, _, expr, _, _, _) => + case pos @ MemoryStore(_, _, expr, _, _, _) => unwrapPaddingAndSlicing(expr) match { case value: Variable => if (isFormal(pos, value)) { diff --git a/src/main/scala/analysis/data_structure_analysis/LocalPhase.scala b/src/main/scala/analysis/data_structure_analysis/LocalPhase.scala index 9d9dac1f7..cd975bbb4 100644 --- a/src/main/scala/analysis/data_structure_analysis/LocalPhase.scala +++ b/src/main/scala/analysis/data_structure_analysis/LocalPhase.scala @@ -144,7 +144,7 @@ class LocalPhase(proc: Procedure, * Handles unification for instructions of the form R_x = R_y [+ offset] where R_y is a pointer and [+ offset] is optional * @param position the cfg position being visited (note this might be a local assign of the form R_x = R_y [+ offset] * or it might be memory load/store where the index is of the form R_y [+ offset] - * @param lhs Ev(R_x) if position is local assign or a cell from an empty node if R_y [+ offset] is the index of a memoryAssign + * @param lhs Ev(R_x) if position is local assign or a cell from an empty node if R_y [+ offset] is the index of a memoryStore * @param rhs R_y, reachingDefs(position)(R_y) can be used to find the set of SSA variables that may define R_x * 
@param pointee if false, the position is local pointer arithmetic therefore Ev(R_y [+ offset]) is merged with lhs * else, the position is a memory read/write therefore E(Ev(R_y [+ offset])) is merged with lhs @@ -245,11 +245,11 @@ class LocalPhase(proc: Procedure, val returnArgument = graph.varToCell(n)(variable) graph.mergeCells(graph.adjust(returnArgument), graph.adjust(slice)) } - case Assign(variable, rhs, _) => + case LocalAssign(variable, rhs, _) => val expr: Expr = unwrapPaddingAndSlicing(rhs) val lhsCell = graph.adjust(graph.varToCell(n)(variable)) - var global = isGlobal(rhs, n) - var stack = isStack(rhs, n) + val global = isGlobal(rhs, n) + val stack = isStack(rhs, n) if global.isDefined then // Rx = global address graph.mergeCells(lhsCell, global.get) else if stack.isDefined then // Rx = stack address @@ -271,38 +271,39 @@ class LocalPhase(proc: Procedure, // Rx = Ry merge corresponding cells to Rx and Ry case arg: Variable /*if varToSym.contains(n) && varToSym(n).contains(arg)*/ => visitPointerArithmeticOperation(n, lhsCell, arg, 0) - - case MemoryLoad(_, index, _, size) => // Rx = Mem[Ry], merge Rx and pointee of Ry (E(Ry)) - assert(size % 8 == 0) - val byteSize = size/8 - lhsCell.node.get.flags.read = true - global = isGlobal(index, n, byteSize) - stack = isStack(index, n) - if global.isDefined then - graph.mergeCells(lhsCell,graph.adjust(graph.find(global.get).getPointee)) - else if stack.isDefined then - graph.mergeCells(lhsCell, graph.adjust(graph.find(stack.get).getPointee)) - else - index match - case BinaryExpr(op, arg1: Variable, arg2) if op.equals(BVADD) => - evaluateExpression(arg2, constProp(n)) match - case Some(v) => -// assert(varToSym(n).contains(arg1)) - val offset = v.value - visitPointerArithmeticOperation(n, lhsCell, arg1, byteSize, true, offset) - case None => -// assert(varToSym(n).contains(arg1)) - // collapse the result -// visitPointerArithmeticOperation(n, lhsCell, arg1, byteSize, true, 0, true) - 
unsupportedPointerArithmeticOperation(n, index,Node(Some(graph)).cells(0)) - case arg: Variable => -// assert(varToSym(n).contains(arg)) - visitPointerArithmeticOperation(n, lhsCell, arg, byteSize, true) - case _ => ??? case _ => unsupportedPointerArithmeticOperation(n, expr, lhsCell) - - case MemoryAssign(_, ind, expr, _, size, _) => + + case MemoryLoad(lhs, _, index, _, size, _) => // Rx = Mem[Ry], merge Rx and pointee of Ry (E(Ry)) + val indexUnwrapped = unwrapPaddingAndSlicing(index) + val lhsCell = graph.adjust(graph.varToCell(n)(lhs)) + assert(size % 8 == 0) + val byteSize = size / 8 + lhsCell.node.get.flags.read = true + val global = isGlobal(indexUnwrapped, n, byteSize) + val stack = isStack(indexUnwrapped, n) + if global.isDefined then + graph.mergeCells(lhsCell, graph.adjust(graph.find(global.get).getPointee)) + else if stack.isDefined then + graph.mergeCells(lhsCell, graph.adjust(graph.find(stack.get).getPointee)) + else + indexUnwrapped match + case BinaryExpr(op, arg1: Variable, arg2) if op.equals(BVADD) => + evaluateExpression(arg2, constProp(n)) match + case Some(v) => + // assert(varToSym(n).contains(arg1)) + val offset = v.value + visitPointerArithmeticOperation(n, lhsCell, arg1, byteSize, true, offset) + case None => + // assert(varToSym(n).contains(arg1)) + // collapse the result + // visitPointerArithmeticOperation(n, lhsCell, arg1, byteSize, true, 0, true) + unsupportedPointerArithmeticOperation(n, indexUnwrapped, Node(Some(graph)).cells(0)) + case arg: Variable => + // assert(varToSym(n).contains(arg)) + visitPointerArithmeticOperation(n, lhsCell, arg, byteSize, true) + case _ => ??? 
+ case MemoryStore(_, ind, expr, _, size, _) => val unwrapped = unwrapPaddingAndSlicing(expr) unwrapped match { // Mem[Ry] = Rx diff --git a/src/main/scala/analysis/data_structure_analysis/SymbolicAddressAnalysis.scala b/src/main/scala/analysis/data_structure_analysis/SymbolicAddressAnalysis.scala index 52cb23639..a66b7f616 100644 --- a/src/main/scala/analysis/data_structure_analysis/SymbolicAddressAnalysis.scala +++ b/src/main/scala/analysis/data_structure_analysis/SymbolicAddressAnalysis.scala @@ -89,7 +89,7 @@ trait SymbolicAddressFunctions(constProp: Map[CFGPosition, Map[Variable, FlatEle def edgesOther(n: CFGPosition)(d: DL): Map[DL, EdgeFunction[TwoElement]] = n match - case Assign(variable, rhs, _) => + case LocalAssign(variable, rhs, _) => val expr = unwrapPaddingAndSlicing(rhs) expr match case BinaryExpr(op, arg1: Variable, arg2) if op.equals(BVADD) => @@ -115,7 +115,7 @@ trait SymbolicAddressFunctions(constProp: Map[CFGPosition, Map[Variable, FlatEle case Left(value) if value.accessor == variable => Map() case _ => Map(d -> IdEdge()) case None => Map(d -> IdEdge()) - case arg:Variable => + case arg: Variable => d match case Left(value) if value.accessor == arg => val result: Map[DL, EdgeFunction[TwoElement]] = Map(Left(SymbolicAddress(variable, value.symbolicBase, value.offset)) -> ConstEdge(TwoElementTop)) @@ -125,15 +125,15 @@ trait SymbolicAddressFunctions(constProp: Map[CFGPosition, Map[Variable, FlatEle result case Left(value) if value.accessor == variable => Map() case _ => Map(d -> IdEdge()) - case _: MemoryLoad => - d match - case Left(value) if value.accessor == variable => Map() - case Left(_) => Map(d -> IdEdge()) - case Right(_) => Map(d -> IdEdge(), Left(SymbolicAddress(variable, UnknownLocation(nextunknownCount, IRWalk.procedure(n)), 0)) -> ConstEdge(TwoElementTop)) case _ => d match case Left(value) if value.accessor == variable => Map() case _ => Map(d -> IdEdge()) + case MemoryLoad(lhs, _, _, _, _, _) => + d match + case Left(value) if 
value.accessor == lhs => Map() + case Left(_) => Map(d -> IdEdge()) + case Right(_) => Map(d -> IdEdge(), Left(SymbolicAddress(lhs, UnknownLocation(nextunknownCount, IRWalk.procedure(n)), 0)) -> ConstEdge(TwoElementTop)) case DirectCall(target, _) if target.name == "malloc" => d match case Left(value) if value.accessor == mallocVariable => Map() diff --git a/src/main/scala/analysis/data_structure_analysis/Utility.scala b/src/main/scala/analysis/data_structure_analysis/Utility.scala index 002c1ec12..6d89db0d7 100644 --- a/src/main/scala/analysis/data_structure_analysis/Utility.scala +++ b/src/main/scala/analysis/data_structure_analysis/Utility.scala @@ -246,7 +246,6 @@ def unwrapPaddingAndSlicing(expr: Expr): Expr = case SignExtend(extension, body) => SignExtend(extension, unwrapPaddingAndSlicing(body)) case UnaryExpr(op, arg) => UnaryExpr(op, arg) case BinaryExpr(op, arg1, arg2) => BinaryExpr(op, unwrapPaddingAndSlicing(arg1), unwrapPaddingAndSlicing(arg2)) - case MemoryLoad(mem, index, endian, size) => MemoryLoad(mem, unwrapPaddingAndSlicing(index), endian, size) case variable: Variable => variable case Extract(_, _, body) /*if start == 0 && end == 32*/ => unwrapPaddingAndSlicing(body) // this may make it unsound case ZeroExtend(_, body) => unwrapPaddingAndSlicing(body) diff --git a/src/main/scala/analysis/solvers/IDESolver.scala b/src/main/scala/analysis/solvers/IDESolver.scala index c057b98f0..4b60e2e16 100644 --- a/src/main/scala/analysis/solvers/IDESolver.scala +++ b/src/main/scala/analysis/solvers/IDESolver.scala @@ -24,7 +24,8 @@ abstract class IDESolver[E <: Procedure | Command, EE <: Procedure | Command, C protected def isExit(exit: CFGPosition): Boolean protected def getAfterCalls(exit: EE): Set[R] - def phase2Init = valuelattice.top + def phase2Init: T = valuelattice.top + def start: CFGPosition = startNode /** * Phase 1 of the IDE algorithm. 
@@ -32,10 +33,10 @@ abstract class IDESolver[E <: Procedure | Command, EE <: Procedure | Command, C * The original version of the algorithm uses summary edges from call nodes to after-call nodes * instead of `callJumpCache` and `exitJumpCache`. */ - private class Phase1(val program: Program) extends InitializingPushDownWorklistFixpointSolver[(CFGPosition, DL, DL), EdgeFunction[T], EdgeFunctionLattice[T, L]] { + private class Phase1 extends InitializingPushDownWorklistFixpointSolver[(CFGPosition, DL, DL), EdgeFunction[T], EdgeFunctionLattice[T, L]] { val lattice: MapLattice[(CFGPosition, DL, DL), EdgeFunction[T], EdgeFunctionLattice[T, L]] = MapLattice(edgelattice) - val first: Set[(CFGPosition, DL, DL)] = Set((startNode, Right(Lambda()), Right(Lambda()))) + val first: Set[(CFGPosition, DL, DL)] = Set((start, Right(Lambda()), Right(Lambda()))) /** * callJumpCache(funentry, d1, call)(d3) returns the composition of the edges (call.funentry, d3) -> (call, *) -> (funentry, d1). @@ -77,7 +78,6 @@ abstract class IDESolver[E <: Procedure | Command, EE <: Procedure | Command, C } } - def process(n: (CFGPosition, DL, DL)): Unit = { val (position, d1, d2) = n val e1 = x(n) @@ -142,14 +142,14 @@ abstract class IDESolver[E <: Procedure | Command, EE <: Procedure | Command, C * Performs a forward dataflow analysis using the decomposed lattice and the micro-transformers. * The original RHS version of IDE uses jump functions for all nodes, not only at exits, but the analysis result and complexity is the same. 
*/ - private class Phase2(val program: Program, val phase1: Phase1) extends InitializingPushDownWorklistFixpointSolver[(CFGPosition, DL), T, L]: + private class Phase2(val phase1: Phase1) extends InitializingPushDownWorklistFixpointSolver[(CFGPosition, DL), T, L] { val lattice: MapLattice[(CFGPosition, DL), T, L] = MapLattice(valuelattice) - val first: Set[(CFGPosition, DL)] = Set((startNode, Right(Lambda()))) + val first: Set[(CFGPosition, DL)] = Set((start, Right(Lambda()))) /** - * Function summaries from phase 1. - * Built when first invoked. - */ + * Function summaries from phase 1. + * Built when first invoked. + */ lazy val summaries: mutable.Map[Procedure, mutable.Map[DL, mutable.Map[DL, EdgeFunction[T]]]] = phase1.summaries() def init: T = phase2Init @@ -188,21 +188,23 @@ abstract class IDESolver[E <: Procedure | Command, EE <: Procedure | Command, C val restructuredlattice: MapLattice[CFGPosition, Map[D, T], MapLattice[D, T, L]] = MapLattice(MapLattice(valuelattice)) /** - * Restructures the analysis output to match `restructuredlattice`. - */ - def restructure(y: lattice.Element): restructuredlattice.Element = + * Restructures the analysis output to match `restructuredlattice`. 
+ */ + def restructure(y: lattice.Element): restructuredlattice.Element = { y.foldLeft(Map[CFGPosition, Map[D, valuelattice.Element]]()) { case (acc, ((n, dl), e)) => dl match { case Left(d) => acc + (n -> (acc.getOrElse(n, Map[D, valuelattice.Element]()) + (d -> e))) case _ => acc } } + } + } def analyze(): Map[CFGPosition, Map[D, T]] = { if (program.mainProcedure.blocks.nonEmpty && program.mainProcedure.returnBlock.isDefined && program.mainProcedure.entryBlock.isDefined) { - val phase1 = Phase1(program) + val phase1 = Phase1() phase1.analyze() - val phase2 = Phase2(program, phase1) + val phase2 = Phase2(phase1) phase2.restructure(phase2.analyze()) } else { Logger.warn(s"Disabling IDE solver tests due to external main procedure: ${program.mainProcedure.name}") @@ -224,7 +226,7 @@ abstract class ForwardIDESolver[D, T, L <: Lattice[T]](program: Program) protected def returnToCall(ret: Command): DirectCall = ret match { case ret: Statement => ret.parent.statements.getPrev(ret).asInstanceOf[DirectCall] - case r: Jump => ret.parent.statements.last.asInstanceOf[DirectCall] + case _: Jump => ret.parent.statements.last.asInstanceOf[DirectCall] } protected def getCallee(call: DirectCall): Procedure = { @@ -232,16 +234,20 @@ abstract class ForwardIDESolver[D, T, L <: Lattice[T]](program: Program) call.target } - protected def isCall(call: CFGPosition): Boolean = - call match - case directCall: DirectCall if (!directCall.successor.isInstanceOf[Unreachable] && directCall.target.returnBlock.isDefined && directCall.target.entryBlock.isDefined) => true + protected def isCall(call: CFGPosition): Boolean = { + call match { + case directCall: DirectCall if !directCall.successor.isInstanceOf[Unreachable] && directCall.target.returnBlock.isDefined && directCall.target.entryBlock.isDefined => true case _ => false + } + } - protected def isExit(exit: CFGPosition): Boolean = - exit match + protected def isExit(exit: CFGPosition): Boolean = { + exit match { // only looking at functions 
with statements - case command: Return => true + case _: Return => true case _ => false + } + } protected def getAfterCalls(exit: Return): Set[Command] = InterProcIRCursor.succ(exit).filter(_.isInstanceOf[Command]).map(_.asInstanceOf[Command]) @@ -258,7 +264,7 @@ abstract class BackwardIDESolver[D, T, L <: Lattice[T]](program: Program) protected def callToReturn(call: Command): DirectCall = { IRWalk.prevCommandInBlock(call) match { - case Some(x : DirectCall) => x + case Some(x: DirectCall) => x case p => throw Exception(s"Not a return/aftercall node $call .... prev = $p") } } @@ -271,22 +277,25 @@ abstract class BackwardIDESolver[D, T, L <: Lattice[T]](program: Program) procCalled.returnBlock.getOrElse(throw Exception(s"No return node for procedure ${procCalled}")).jump.asInstanceOf[Return] } - protected def isCall(call: CFGPosition): Boolean = - call match - case c: Unreachable => false /* don't process non-returning calls */ - case c : Command => { + protected def isCall(call: CFGPosition): Boolean = { + call match { + case _: Unreachable => false /* don't process non-returning calls */ + case c: Command => val call = IRWalk.prevCommandInBlock(c) call match { case Some(d: DirectCall) if d.target.returnBlock.isDefined => true case _ => false } - } case _ => false + } + } - protected def isExit(exit: CFGPosition): Boolean = - exit match + protected def isExit(exit: CFGPosition): Boolean = { + exit match { case procedure: Procedure => procedure.blocks.nonEmpty case _ => false + } + } protected def getAfterCalls(exit: Procedure): Set[DirectCall] = exit.incomingCalls().toSet } diff --git a/src/main/scala/bap/BAPExpr.scala b/src/main/scala/bap/BAPExpr.scala index 7408328dd..0f2591e3d 100644 --- a/src/main/scala/bap/BAPExpr.scala +++ b/src/main/scala/bap/BAPExpr.scala @@ -5,8 +5,6 @@ import ir._ /** Expression */ trait BAPExpr { - def toIR: Expr - /* * The size of output of the given expression. 
* @@ -18,8 +16,6 @@ trait BAPExpr { /** Concatenation of two bitvectors */ case class BAPConcat(left: BAPExpr, right: BAPExpr) extends BAPExpr { - def toIR: BinaryExpr = BinaryExpr(BVCONCAT, left.toIR, right.toIR) - override val size: Int = left.size + right.size } @@ -28,14 +24,6 @@ case class BAPConcat(left: BAPExpr, right: BAPExpr) extends BAPExpr { case class BAPSignedExtend(width: Int, body: BAPExpr) extends BAPExpr { override val size: Int = width - - override def toIR: Expr = { - if (width > body.size) { - SignExtend(width - body.size, body.toIR) - } else { - BAPExtract(width - 1, 0, body).toIR - } - } } /** Unsigned extend - pad in BIL @@ -43,15 +31,6 @@ case class BAPSignedExtend(width: Int, body: BAPExpr) extends BAPExpr { case class BAPUnsignedExtend(width: Int, body: BAPExpr) extends BAPExpr { override val size: Int = width - - override def toIR: Expr = { - if (width > body.size) { - ZeroExtend(width - body.size, body.toIR) - } else { - BAPExtract(width - 1, 0, body).toIR - } - - } } /** Extracts the bits from firstInt to secondInt (inclusive) from variable. @@ -61,19 +40,6 @@ case class BAPExtract(high: Int, low: Int, body: BAPExpr) extends BAPExpr { // + 1 as extracts are inclusive (e.g. 
[31:0] has 32 bits) override val size: Int = high - low + 1 - - override def toIR: Expr = { - val bodySize = body.size - if (size > bodySize) { - if (low == 0) { - ZeroExtend(size - bodySize, body.toIR) - } else { - Extract(high + 1, low, ZeroExtend(size - bodySize, body.toIR)) - } - } else { - Extract(high + 1, low, body.toIR) - } - } } case object BAPHighCast { @@ -90,19 +56,12 @@ case class BAPLiteral(value: BigInt, size: Int) extends BAPExpr { /** Value of literal */ override def toString: String = s"${value}bv$size" - - override def toIR: BitVecLiteral = BitVecLiteral(value, size) } /** Unary operator */ case class BAPUnOp(operator: BAPUnOperator, exp: BAPExpr) extends BAPExpr { override val size: Int = exp.size - - override def toIR: UnaryExpr = operator match { - case NOT => UnaryExpr(BVNOT, exp.toIR) - case NEG => UnaryExpr(BVNEG, exp.toIR) - } } sealed trait BAPUnOperator(op: String) { @@ -126,46 +85,6 @@ case class BAPBinOp(operator: BAPBinOperator, lhs: BAPExpr, rhs: BAPExpr) extend case EQ | NEQ | LT | LE | SLT | SLE => 1 case _ => lhs.size } - - override def toIR: Expr = operator match { - case PLUS => BinaryExpr(BVADD, lhs.toIR, rhs.toIR) - case MINUS => BinaryExpr(BVSUB, lhs.toIR, rhs.toIR) - case TIMES => BinaryExpr(BVMUL, lhs.toIR, rhs.toIR) - case DIVIDE => BinaryExpr(BVUDIV, lhs.toIR, rhs.toIR) - case SDIVIDE => BinaryExpr(BVSDIV, lhs.toIR, rhs.toIR) - // counterintuitive but correct according to BAP source - case MOD => BinaryExpr(BVSREM, lhs.toIR, rhs.toIR) - // counterintuitive but correct according to BAP source - case SMOD => BinaryExpr(BVUREM, lhs.toIR, rhs.toIR) - case LSHIFT => // BAP says caring about this case is necessary? 
- if (lhs.size == rhs.size) { - BinaryExpr(BVSHL, lhs.toIR, rhs.toIR) - } else { - BinaryExpr(BVSHL, lhs.toIR, ZeroExtend(lhs.size - rhs.size, rhs.toIR)) - } - case RSHIFT => - if (lhs.size == rhs.size) { - BinaryExpr(BVLSHR, lhs.toIR, rhs.toIR) - } else { - BinaryExpr(BVLSHR, lhs.toIR, ZeroExtend(lhs.size - rhs.size, rhs.toIR)) - } - case ARSHIFT => - if (lhs.size == rhs.size) { - BinaryExpr(BVASHR, lhs.toIR, rhs.toIR) - } else { - BinaryExpr(BVASHR, lhs.toIR, ZeroExtend(lhs.size - rhs.size, rhs.toIR)) - } - case AND => BinaryExpr(BVAND, lhs.toIR, rhs.toIR) - case OR => BinaryExpr(BVOR, lhs.toIR, rhs.toIR) - case XOR => BinaryExpr(BVXOR, lhs.toIR, rhs.toIR) - case EQ => BinaryExpr(BVCOMP, lhs.toIR, rhs.toIR) - case NEQ => UnaryExpr(BVNOT, BinaryExpr(BVCOMP, lhs.toIR, rhs.toIR)) - case LT => BinaryExpr(BVULT, lhs.toIR, rhs.toIR) - case LE => BinaryExpr(BVULE, lhs.toIR, rhs.toIR) - case SLT => BinaryExpr(BVSLT, lhs.toIR, rhs.toIR) - case SLE => BinaryExpr(BVSLE, lhs.toIR, rhs.toIR) - } - } sealed trait BAPBinOperator(op: String) { @@ -216,39 +135,23 @@ case object LE extends BAPBinOperator("LE") case object SLT extends BAPBinOperator("SLT") case object SLE extends BAPBinOperator("SLE") -trait BAPVariable extends BAPExpr - -trait BAPVar extends BAPVariable { +trait BAPVar extends BAPExpr { val name: String override val size: Int override def toString: String = name - override def toIR: Variable } +case class BAPRegister(override val name: String, override val size: Int) extends BAPVar -case class BAPRegister(override val name: String, override val size: Int) extends BAPVar { - override def toIR: Register = Register(s"$name", size) -} - -case class BAPLocalVar(override val name: String, override val size: Int) extends BAPVar { - override def toIR: LocalVar = LocalVar(s"$name", BitVecType(size)) -} +case class BAPLocalVar(override val name: String, override val size: Int) extends BAPVar /** A load from memory at location exp */ -case class BAPMemAccess(memory: 
BAPMemory, index: BAPExpr, endian: Endian, override val size: Int) extends BAPVariable { +case class BAPMemAccess(memory: BAPMemory, index: BAPExpr, endian: Endian, override val size: Int) extends BAPExpr { override def toString: String = s"${memory.name}[$index]" - override def toIR: MemoryLoad = { - MemoryLoad(memory.toIRMemory, index.toIR, endian, size) - } } -case class BAPMemory(name: String, addressSize: Int, valueSize: Int) extends BAPVariable { - override val size: Int = valueSize // should reconsider - override def toIR: Expr = ??? // should not encounter - def toIRMemory: Memory = SharedMemory(name, addressSize, valueSize) -} +case class BAPMemory(name: String, addressSize: Int, valueSize: Int) case class BAPStore(memory: BAPMemory, index: BAPExpr, value: BAPExpr, endian: Endian, size: Int) extends BAPExpr { - override def toIR: Expr = ??? // should not encounter override def toString: String = s"${memory.name}[$index] := $value" } diff --git a/src/main/scala/bap/BAPProgram.scala b/src/main/scala/bap/BAPProgram.scala index c586f5980..ed1f65400 100644 --- a/src/main/scala/bap/BAPProgram.scala +++ b/src/main/scala/bap/BAPProgram.scala @@ -45,15 +45,6 @@ case class BAPBlock(label: String, address: Option[BigInt], statements: List[BAP } -case class BAPParameter(name: String, size: Int, value: BAPVar) { - def toIR: Parameter = { - val register = value.toIR - register match { - case r: Register => Parameter(name, size, r) - case _ => throw Exception(s"subroutine parameter $this refers to non-register variable $value") - } - - } -} +case class BAPParameter(name: String, size: Int, value: BAPVar) case class BAPMemorySection(name: String, address: BigInt, size: Int, bytes: Seq[BAPLiteral]) diff --git a/src/main/scala/bap/BAPStatement.scala b/src/main/scala/bap/BAPStatement.scala index 09b77f60d..2f9afe34d 100644 --- a/src/main/scala/bap/BAPStatement.scala +++ b/src/main/scala/bap/BAPStatement.scala @@ -23,14 +23,12 @@ case class BAPGoTo(target: String, condition: 
BAPExpr, override val line: String sealed trait BAPStatement -sealed trait BAPAssign(lhs: BAPVariable, rhs: BAPExpr, line: String, instruction: String) extends BAPStatement { - override def toString: String = String.format("%s := %s;", lhs, rhs) -} - /** Memory store */ -case class BAPMemAssign(lhs: BAPMemory, rhs: BAPStore, line: String, instruction: String, address: Option[BigInt] = None) - extends BAPAssign(lhs, rhs, line, instruction) +case class BAPMemAssign(lhs: BAPMemory, rhs: BAPStore, line: String, instruction: String, address: Option[BigInt] = None) extends BAPStatement { + override def toString: String = String.format("%s := %s;", lhs, rhs) +} -case class BAPLocalAssign(lhs: BAPVar, rhs: BAPExpr, line: String, instruction: String, address: Option[BigInt] = None) - extends BAPAssign(lhs, rhs, line, instruction) +case class BAPLocalAssign(lhs: BAPVar, rhs: BAPExpr, line: String, instruction: String, address: Option[BigInt] = None) extends BAPStatement { + override def toString: String = String.format("%s := %s;", lhs, rhs) +} diff --git a/src/main/scala/ir/Expr.scala b/src/main/scala/ir/Expr.scala index 03579c2cf..d250ca9fc 100644 --- a/src/main/scala/ir/Expr.scala +++ b/src/main/scala/ir/Expr.scala @@ -5,7 +5,6 @@ import scala.collection.mutable sealed trait Expr { def toBoogie: BExpr - def loads: Set[MemoryLoad] = Set() def getType: IRType def gammas: Set[Variable] = Set() // variables not including those inside a load's index def variables: Set[Variable] = Set() @@ -54,7 +53,6 @@ case class Extract(end: Int, start: Int, body: Expr) extends Expr { override def getType: BitVecType = BitVecType(end - start) override def toString: String = s"$body[$end:$start]" override def acceptVisit(visitor: Visitor): Expr = visitor.visitExtract(this) - override def loads: Set[MemoryLoad] = body.loads } case class Repeat(repeats: Int, body: Expr) extends Expr { @@ -68,7 +66,6 @@ case class Repeat(repeats: Int, body: Expr) extends Expr { } override def toString: String = 
s"Repeat($repeats, $body)" override def acceptVisit(visitor: Visitor): Expr = visitor.visitRepeat(this) - override def loads: Set[MemoryLoad] = body.loads } case class ZeroExtend(extension: Int, body: Expr) extends Expr { @@ -82,7 +79,6 @@ case class ZeroExtend(extension: Int, body: Expr) extends Expr { } override def toString: String = s"ZeroExtend($extension, $body)" override def acceptVisit(visitor: Visitor): Expr = visitor.visitZeroExtend(this) - override def loads: Set[MemoryLoad] = body.loads } case class SignExtend(extension: Int, body: Expr) extends Expr { @@ -96,14 +92,12 @@ case class SignExtend(extension: Int, body: Expr) extends Expr { } override def toString: String = s"SignExtend($extension, $body)" override def acceptVisit(visitor: Visitor): Expr = visitor.visitSignExtend(this) - override def loads: Set[MemoryLoad] = body.loads } case class UnaryExpr(op: UnOp, arg: Expr) extends Expr { override def toBoogie: BExpr = UnaryBExpr(op, arg.toBoogie) override def gammas: Set[Variable] = arg.gammas override def variables: Set[Variable] = arg.variables - override def loads: Set[MemoryLoad] = arg.loads override def getType: IRType = (op, arg.getType) match { case (_: BoolUnOp, BoolType) => BoolType case (_: BVUnOp, bv: BitVecType) => bv @@ -152,7 +146,6 @@ case class BinaryExpr(op: BinOp, arg1: Expr, arg2: Expr) extends Expr { override def toBoogie: BExpr = BinaryBExpr(op, arg1.toBoogie, arg2.toBoogie) override def gammas: Set[Variable] = arg1.gammas ++ arg2.gammas override def variables: Set[Variable] = arg1.variables ++ arg2.variables - override def loads: Set[MemoryLoad] = arg1.loads ++ arg2.loads override def getType: IRType = (op, arg1.getType, arg2.getType) match { case (_: BoolBinOp, BoolType, BoolType) => BoolType case (binOp: BVBinOp, bv1: BitVecType, bv2: BitVecType) => @@ -292,22 +285,6 @@ enum Endian { case BigEndian } -case class MemoryLoad(mem: Memory, index: Expr, endian: Endian, size: Int) extends Expr { - override def toBoogie: BMemoryLoad = 
BMemoryLoad(mem.toBoogie, index.toBoogie, endian, size) - def toGamma(LArgs: List[BMapVar]): BExpr = mem match { - case m: StackMemory => - GammaLoad(m.toGamma, index.toBoogie, size, size / m.valueSize) - case m: SharedMemory => - BinaryBExpr(BoolOR, GammaLoad(m.toGamma, index.toBoogie, size, size / m.valueSize), L(LArgs, index.toBoogie)) - } - override def variables: Set[Variable] = index.variables - override def gammas: Set[Variable] = Set() - override def loads: Set[MemoryLoad] = Set(this) - override def getType: IRType = BitVecType(size) - override def toString: String = s"MemoryLoad($mem, $index, $endian, $size)" - override def acceptVisit(visitor: Visitor): Expr = visitor.visitMemoryLoad(this) -} - case class UninterpretedFunction(name: String, params: Seq[Expr], returnType: IRType) extends Expr { override def getType: IRType = returnType override def toBoogie: BFunctionCall = BFunctionCall(name, params.map(_.toBoogie).toList, returnType.toBoogie, true) diff --git a/src/main/scala/ir/Interpreter.scala b/src/main/scala/ir/Interpreter.scala index 53ef40c2d..204a3fda7 100644 --- a/src/main/scala/ir/Interpreter.scala +++ b/src/main/scala/ir/Interpreter.scala @@ -85,11 +85,6 @@ class Interpreter() { case BVNOT => smt_bvnot(arg) } - case ml: MemoryLoad => - Logger.debug(s"\t$ml") - val index: Int = eval(ml.index, env).value.toInt - getMemory(index, ml.size, ml.endian, mems) - case u: UninterpretedFunction => Logger.debug(s"\t$u") ??? 
@@ -259,26 +254,35 @@ class Interpreter() { private def interpretStatement(s: Statement): Unit = { Logger.debug(s"statement[$s]:") s match { - case assign: Assign => + case assign: LocalAssign => Logger.debug(s"LocalAssign ${assign.lhs} = ${assign.rhs}") val evalRight = eval(assign.rhs, regs) Logger.debug(s"LocalAssign ${assign.lhs} := 0x${evalRight.value.toString(16)}[u${evalRight.size}]\n") regs += (assign.lhs -> evalRight) - case assign: MemoryAssign => - Logger.debug(s"MemoryAssign ${assign.mem}[${assign.index}] = ${assign.value}") + case store: MemoryStore => + Logger.debug(s"MemoryStore ${store.mem}[${store.index}] = ${store.value}") - val index: Int = eval(assign.index, regs).value.toInt - val value: BitVecLiteral = eval(assign.value, regs) - Logger.debug(s"\tMemoryStore(mem:${assign.mem}, index:0x${index.toHexString}, value:0x${ - value.value - .toString(16) - }[u${value.size}], size:${assign.size})") + val index: Int = eval(store.index, regs).value.toInt + val value: BitVecLiteral = eval(store.value, regs) + Logger.debug(s"\tMemoryStore(mem:${store.mem}, index:0x${index.toHexString}, value:0x${ + value.value.toString(16) + }[u${value.size}], size:${store.size})") - val evalStore = setMemory(index, assign.size, assign.endian, value, mems) + val evalStore = setMemory(index, store.size, store.endian, value, mems) evalStore match { case BitVecLiteral(value, size) => - Logger.debug(s"MemoryAssign ${assign.mem} := 0x${value.toString(16)}[u$size]\n") + Logger.debug(s"MemoryStore ${store.mem} := 0x${value.toString(16)}[u$size]\n") + } + case load: MemoryLoad => + Logger.debug(s"MemoryLoad ${load.lhs} = ${load.mem}[${load.index}]") + val index: Int = eval(load.index, regs).value.toInt + Logger.debug(s"MemoryLoad ${load.lhs} := ${load.mem}[0x${index.toHexString}[u${load.size}]\n") + val evalLoad = getMemory(index, load.size, load.endian, mems) + regs += (load.lhs -> evalLoad) + evalLoad match { + case BitVecLiteral(value, size) => + Logger.debug(s"MemoryLoad
${load.lhs} := 0x${value.toString(16)}[u$size]\n") } case _ : NOP => () case assert: Assert => diff --git a/src/main/scala/ir/Statement.scala b/src/main/scala/ir/Statement.scala index ce49bc82e..c1862a9ff 100644 --- a/src/main/scala/ir/Statement.scala +++ b/src/main/scala/ir/Statement.scala @@ -23,33 +23,47 @@ sealed trait Statement extends Command, IntrusiveListElement[Statement] { def acceptVisit(visitor: Visitor): Statement = throw new Exception( "visitor " + visitor + " unimplemented for: " + this ) - def successor: Command = parent.statements.nextOption(this).getOrElse(parent.jump) +} +sealed trait Assign extends Statement { + var lhs: Variable } -// invariant: rhs contains at most one MemoryLoad -class Assign(var lhs: Variable, var rhs: Expr, override val label: Option[String] = None) extends Statement { +class LocalAssign(var lhs: Variable, var rhs: Expr, override val label: Option[String] = None) extends Assign { override def modifies: Set[Global] = lhs match { case r: Register => Set(r) case _ => Set() } override def toString: String = s"$labelStr$lhs := $rhs" - override def acceptVisit(visitor: Visitor): Statement = visitor.visitAssign(this) + override def acceptVisit(visitor: Visitor): Statement = visitor.visitLocalAssign(this) } -object Assign: - def unapply(l: Assign): Option[(Variable, Expr, Option[String])] = Some(l.lhs, l.rhs, l.label) +object LocalAssign: + def unapply(l: LocalAssign): Option[(Variable, Expr, Option[String])] = Some(l.lhs, l.rhs, l.label) -// invariant: index and value do not contain MemoryLoads -class MemoryAssign(var mem: Memory, var index: Expr, var value: Expr, var endian: Endian, var size: Int, override val label: Option[String] = None) extends Statement { +class MemoryStore(var mem: Memory, var index: Expr, var value: Expr, var endian: Endian, var size: Int, override val label: Option[String] = None) extends Statement { override def modifies: Set[Global] = Set(mem) override def toString: String = s"$labelStr$mem[$index] := 
MemoryStore($value, $endian, $size)" - override def acceptVisit(visitor: Visitor): Statement = visitor.visitMemoryAssign(this) + override def acceptVisit(visitor: Visitor): Statement = visitor.visitMemoryStore(this) } -object MemoryAssign: - def unapply(m: MemoryAssign): Option[(Memory, Expr, Expr, Endian, Int, Option[String])] = Some(m.mem, m.index, m.value, m.endian, m.size, m.label) +object MemoryStore { + def unapply(m: MemoryStore): Option[(Memory, Expr, Expr, Endian, Int, Option[String])] = Some(m.mem, m.index, m.value, m.endian, m.size, m.label) +} + +class MemoryLoad(var lhs: Variable, var mem: Memory, var index: Expr, var endian: Endian, var size: Int, override val label: Option[String] = None) extends Assign { + override def modifies: Set[Global] = lhs match { + case r: Register => Set(r) + case _ => Set() + } + override def toString: String = s"$labelStr$lhs := MemoryLoad($mem, $index, $endian, $size)" + override def acceptVisit(visitor: Visitor): Statement = visitor.visitMemoryLoad(this) +} + +object MemoryLoad { + def unapply(m: MemoryLoad): Option[(Variable, Memory, Expr, Endian, Int, Option[String])] = Some(m.lhs, m.mem, m.index, m.endian, m.size, m.label) +} class NOP(override val label: Option[String] = None) extends Statement { override def toString: String = s"NOP $labelStr" @@ -87,10 +101,17 @@ class Unreachable(override val label: Option[String] = None) extends Jump { override def acceptVisit(visitor: Visitor): Jump = this } +object Unreachable { + def unapply(u: Unreachable): Option[Option[String]] = Some(u.label) +} + class Return(override val label: Option[String] = None) extends Jump { override def acceptVisit(visitor: Visitor): Jump = this } +object Return { + def unapply(r: Return): Option[Option[String]] = Some(r.label) +} class GoTo private (private val _targets: mutable.LinkedHashSet[Block], override val label: Option[String]) extends Jump { diff --git a/src/main/scala/ir/Visitor.scala b/src/main/scala/ir/Visitor.scala index 
0b88f2a4c..c649dc4f3 100644 --- a/src/main/scala/ir/Visitor.scala +++ b/src/main/scala/ir/Visitor.scala @@ -10,19 +10,27 @@ abstract class Visitor { def visitStatement(node: Statement): Statement = node.acceptVisit(this) - def visitAssign(node: Assign): Statement = { + def visitLocalAssign(node: LocalAssign): Statement = { node.lhs = visitVariable(node.lhs) node.rhs = visitExpr(node.rhs) node } - def visitMemoryAssign(node: MemoryAssign): Statement = { + def visitMemoryStore(node: MemoryStore): Statement = { node.mem = visitMemory(node.mem) node.index = visitExpr(node.index) node.value = visitExpr(node.value) node } + def visitMemoryLoad(node: MemoryLoad): Statement = { + node.lhs = visitVariable(node.lhs) + node.mem = visitMemory(node.mem) + node.index = visitExpr(node.index) + node + } + + def visitAssume(node: Assume): Statement = { node.body = visitExpr(node.body) node @@ -110,10 +118,6 @@ abstract class Visitor { node.copy(arg1 = visitExpr(node.arg1), arg2 = visitExpr(node.arg2)) } - def visitMemoryLoad(node: MemoryLoad): Expr = { - node.copy(mem = visitMemory(node.mem), index = visitExpr(node.index)) - } - def visitMemory(node: Memory): Memory = node.acceptVisit(this) def visitStackMemory(node: StackMemory): Memory = node @@ -166,25 +170,26 @@ abstract class ReadOnlyVisitor extends Visitor { node } - override def visitMemoryLoad(node: MemoryLoad): Expr = { - visitMemory(node.mem) - visitExpr(node.index) - node - } - - override def visitAssign(node: Assign): Statement = { + override def visitLocalAssign(node: LocalAssign): Statement = { visitVariable(node.lhs) visitExpr(node.rhs) node } - override def visitMemoryAssign(node: MemoryAssign): Statement = { + override def visitMemoryStore(node: MemoryStore): Statement = { visitMemory(node.mem) visitExpr(node.index) visitExpr(node.value) node } + override def visitMemoryLoad(node: MemoryLoad): Statement = { + visitVariable(node.lhs) + visitMemory(node.mem) + visitExpr(node.index) + node + } + override def 
visitAssume(node: Assume): Statement = { visitExpr(node.body) node @@ -307,17 +312,18 @@ class StackSubstituter extends IntraproceduralControlFlowVisitor { override def visitMemoryLoad(node: MemoryLoad): MemoryLoad = { // replace mem with stack in load if index contains stack references val loadStackRefs = node.index.variables.intersect(stackRefs) + if (loadStackRefs.nonEmpty) { - node.copy(mem = stackMemory) - } else { - node + node.mem = stackMemory + } + if (stackRefs.contains(node.lhs) && node.lhs != stackPointer) { + stackRefs.remove(node.lhs) } - } - override def visitAssign(node: Assign): Statement = { - node.lhs = visitVariable(node.lhs) - node.rhs = visitExpr(node.rhs) + node + } + override def visitLocalAssign(node: LocalAssign): Statement = { // update stack references val variableVisitor = VariablesWithoutStoresLoads() variableVisitor.visitExpr(node.rhs) @@ -331,7 +337,7 @@ class StackSubstituter extends IntraproceduralControlFlowVisitor { node } - override def visitMemoryAssign(node: MemoryAssign): Statement = { + override def visitMemoryStore(node: MemoryStore): Statement = { val indexStackRefs = node.index.variables.intersect(stackRefs) if (indexStackRefs.nonEmpty) { node.mem = stackMemory @@ -421,7 +427,7 @@ class ExternalRemover(external: Set[String]) extends Visitor { } } -/** Gives variables that are not contained within a MemoryStore or MemoryLoad +/** Gives variables that are not contained within a MemoryStore or the rhs of a MemoryLoad * */ class VariablesWithoutStoresLoads extends ReadOnlyVisitor { val variables: mutable.Set[Variable] = mutable.Set() @@ -437,6 +443,7 @@ class VariablesWithoutStoresLoads extends ReadOnlyVisitor { } override def visitMemoryLoad(node: MemoryLoad): MemoryLoad = { + visitVariable(node.lhs) node } diff --git a/src/main/scala/ir/cilvisitor/CILVisitor.scala b/src/main/scala/ir/cilvisitor/CILVisitor.scala index 5583b12da..372b500de 100644 --- a/src/main/scala/ir/cilvisitor/CILVisitor.scala +++ 
b/src/main/scala/ir/cilvisitor/CILVisitor.scala @@ -35,7 +35,7 @@ trait CILVisitor: def leave_scope(outparam: ArrayBuffer[Parameter]): Unit = () -def doVisitList[T](v: CILVisitor, a: VisitAction[List[T]], n: T, continue: (T) => T): List[T] = { +def doVisitList[T](v: CILVisitor, a: VisitAction[List[T]], n: T, continue: T => T): List[T] = { a match { case SkipChildren() => List(n) case ChangeTo(z) => z @@ -44,7 +44,7 @@ def doVisitList[T](v: CILVisitor, a: VisitAction[List[T]], n: T, continue: (T) = } } -def doVisit[T](v: CILVisitor, a: VisitAction[T], n: T, continue: (T) => T): T = { +def doVisit[T](v: CILVisitor, a: VisitAction[T], n: T, continue: T => T): T = { a match { case SkipChildren() => n case DoChildren() => continue(n) @@ -56,31 +56,30 @@ def doVisit[T](v: CILVisitor, a: VisitAction[T], n: T, continue: (T) => T): T = class CILVisitorImpl(val v: CILVisitor) { def visit_parameters(p: ArrayBuffer[Parameter]): ArrayBuffer[Parameter] = { - doVisit(v, v.vparams(p), p, (n) => n) + doVisit(v, v.vparams(p), p, n => n) } def visit_var(n: Variable): Variable = { - doVisit(v, v.vvar(n), n, (n) => n) + doVisit(v, v.vvar(n), n, n => n) } def visit_mem(n: Memory): Memory = { - doVisit(v, v.vmem(n), n, (n) => n) + doVisit(v, v.vmem(n), n, n => n) } def visit_jump(j: Jump): Jump = { - doVisit(v, v.vjump(j), j, (j) => j) + doVisit(v, v.vjump(j), j, j => j) } def visit_fallthrough(j: Option[GoTo]): Option[GoTo] = { - doVisit(v, v.vfallthrough(j), j, (j) => j) + doVisit(v, v.vfallthrough(j), j, j => j) } def visit_expr(n: Expr): Expr = { def continue(n: Expr): Expr = n match { case n: Literal => n - case MemoryLoad(mem, index, endian, size) => MemoryLoad(visit_mem(mem), visit_expr(index), endian, size) case Extract(end, start, arg) => Extract(end, start, visit_expr(arg)) case Repeat(repeats, arg) => Repeat(repeats, visit_expr(arg)) case ZeroExtend(bits, arg) => ZeroExtend(bits, visit_expr(arg)) @@ -96,29 +95,29 @@ class CILVisitorImpl(val v: CILVisitor) { def visit_stmt(s: 
Statement): List[Statement] = { def continue(n: Statement) = n match { case d: DirectCall => d - case i: IndirectCall => { + case i: IndirectCall => i.target = visit_var(i.target) i - } - case m: MemoryAssign => { + case m: MemoryStore => m.mem = visit_mem(m.mem) m.index = visit_expr(m.index) m.value = visit_expr(m.value) m - } - case m: Assign => { + case m: MemoryLoad => + m.mem = visit_mem(m.mem) + m.index = visit_expr(m.index) + m.lhs = visit_var(m.lhs) + m + case m: LocalAssign => m.rhs = visit_expr(m.rhs) m.lhs = visit_var(m.lhs) m - } - case s: Assert => { + case s: Assert => s.body = visit_expr(s.body) s - } - case s: Assume => { + case s: Assume => s.body = visit_expr(s.body) s - } case n: NOP => n } doVisitList(v, v.vstmt(s), s, continue) @@ -126,7 +125,7 @@ class CILVisitorImpl(val v: CILVisitor) { def visit_block(b: Block): Block = { def continue(b: Block) = { - b.statements.foreach(s => { + b.statements.foreach { s => val r = visit_stmt(s) r match { case Nil => b.statements.remove(s) @@ -134,7 +133,7 @@ class CILVisitorImpl(val v: CILVisitor) { b.statements.replace(s, n) b.statements.insertAllAfter(Some(n), tl) } - }) + } b.replaceJump(visit_jump(b.jump)) b } diff --git a/src/main/scala/translating/BAPToIR.scala b/src/main/scala/translating/BAPToIR.scala index 908978046..9793022cb 100644 --- a/src/main/scala/translating/BAPToIR.scala +++ b/src/main/scala/translating/BAPToIR.scala @@ -16,6 +16,8 @@ class BAPToIR(var program: BAPProgram, mainAddress: BigInt) { private val nameToProcedure: mutable.Map[String, Procedure] = mutable.Map() private val labelToBlock: mutable.Map[String, Block] = mutable.Map() + private var loadCounter: Int = 0 + def translate: Program = { var mainProcedure: Option[Procedure] = None val procedures: ArrayBuffer[Procedure] = ArrayBuffer() @@ -30,10 +32,10 @@ class BAPToIR(var program: BAPProgram, mainAddress: BigInt) { labelToBlock.addOne(b.label, block) } for (p <- s.in) { - procedure.in.append(p.toIR) + 
procedure.in.append(translateParameter(p)) } for (p <- s.out) { - procedure.out.append(p.toIR) + procedure.out.append(translateParameter(p)) } if (s.address.get == mainAddress) { mainProcedure = Some(procedure) @@ -47,7 +49,10 @@ class BAPToIR(var program: BAPProgram, mainAddress: BigInt) { for (b <- s.blocks) { val block = labelToBlock(b.label) for (st <- b.statements) { - block.statements.append(translate(st)) + val statements = translateStatement(st) + for (s <- statements) { + block.statements.append(s) + } } val (call, jump, newBlocks) = translate(b.jumps, block) procedure.addBlocks(newBlocks) @@ -68,7 +73,7 @@ class BAPToIR(var program: BAPProgram, mainAddress: BigInt) { val bytes = if (m.name == ".bss" && m.bytes.isEmpty) { for (_ <- 0 until m.size) yield BitVecLiteral(0, 8) } else { - m.bytes.map(_.toIR) + m.bytes.map(translateLiteral) } val readOnly = m.name == ".rodata" || m.name == ".got" // crude heuristic memorySections.addOne(m.address, MemorySection(m.name, m.address, m.size, bytes, readOnly, None)) @@ -77,17 +82,146 @@ class BAPToIR(var program: BAPProgram, mainAddress: BigInt) { Program(procedures, mainProcedure.get, memorySections) } - private def translate(s: BAPStatement) = s match { + private def translateStatement(s: BAPStatement): Seq[Statement] = s match { case b: BAPMemAssign => - val mem = b.lhs.toIRMemory - if (mem != b.rhs.memory.toIRMemory) { + val mem = translateMemory(b.lhs) + if (mem != translateMemory(b.rhs.memory)) { throw Exception(s"$b has conflicting lhs ${b.lhs} and rhs ${b.rhs.memory}") } - MemoryAssign(mem, b.rhs.index.toIR, b.rhs.value.toIR, b.rhs.endian, b.rhs.size, Some(b.line)) + Seq(MemoryStore(mem, translateExprOnly(b.rhs.index), translateExprOnly(b.rhs.value), b.rhs.endian, b.rhs.size, Some(b.line))) case b: BAPLocalAssign => - Assign(b.lhs.toIR, b.rhs.toIR, Some(b.line)) + val lhs = translateVar(b.lhs) + val (rhs, load) = translateExpr(b.rhs) + if (load.isDefined) { + val loadWithLabel = MemoryLoad(load.get.lhs, 
load.get.mem, load.get.index, load.get.endian, load.get.size, Some(b.line + "$0")) + val assign = LocalAssign(lhs, rhs, Some(b.line + "$1")) + Seq(loadWithLabel, assign) + } else { + val assign = LocalAssign(lhs, rhs, Some(b.line)) + Seq(assign) + } + } + + private def translateExpr(e: BAPExpr): (Expr, Option[MemoryLoad]) = e match { + case b @ BAPConcat(left, right) => + val (arg0, load0) = translateExpr(left) + val (arg1, load1) = translateExpr(right) + (load0, load1) match { + case (Some(load), None) => (BinaryExpr(BVCONCAT, arg0, arg1), Some(load)) + case (None, Some(load)) => (BinaryExpr(BVCONCAT, arg0, arg1), Some(load)) + case (None, None) => (BinaryExpr(BVCONCAT, arg0, arg1), None) + case (Some(_), Some(_)) => throw Exception(s"$b contains multiple loads") + } + case BAPSignedExtend(width, body) => + if (width > body.size) { + val (irBody, load) = translateExpr(body) + val se = SignExtend(width - body.size, irBody) + (se, load) + } else { + translateExpr(BAPExtract(width - 1, 0, body)) + } + case BAPUnsignedExtend(width, body) => + if (width > body.size) { + val (irBody, load) = translateExpr(body) + val ze = ZeroExtend(width - body.size, irBody) + (ze, load) + } else { + translateExpr(BAPExtract(width - 1, 0, body)) + } + case b @ BAPExtract(high, low, body) => + val bodySize = body.size + val (irBody, load) = translateExpr(body) + val extract = if (b.size > bodySize) { + if (low == 0) { + ZeroExtend(b.size - bodySize, irBody) + } else { + Extract(high + 1, low, ZeroExtend(b.size - bodySize, irBody)) + } + } else { + Extract(high + 1, low, irBody) + } + (extract, load) + case literal: BAPLiteral => (translateLiteral(literal), None) + case BAPUnOp(operator, exp) => operator match { + case NOT => (UnaryExpr(BVNOT, translateExprOnly(exp)), None) + case NEG => (UnaryExpr(BVNEG, translateExprOnly(exp)), None) + } + case BAPBinOp(operator, lhs, rhs) => operator match { + case PLUS => (BinaryExpr(BVADD, translateExprOnly(lhs), translateExprOnly(rhs)), None) + 
case MINUS => (BinaryExpr(BVSUB, translateExprOnly(lhs), translateExprOnly(rhs)), None) + case TIMES => (BinaryExpr(BVMUL, translateExprOnly(lhs), translateExprOnly(rhs)), None) + case DIVIDE => (BinaryExpr(BVUDIV, translateExprOnly(lhs), translateExprOnly(rhs)), None) + case SDIVIDE => (BinaryExpr(BVSDIV, translateExprOnly(lhs), translateExprOnly(rhs)), None) + // counterintuitive but correct according to BAP source + case MOD => (BinaryExpr(BVSREM, translateExprOnly(lhs), translateExprOnly(rhs)), None) + // counterintuitive but correct according to BAP source + case SMOD => (BinaryExpr(BVUREM, translateExprOnly(lhs), translateExprOnly(rhs)), None) + case LSHIFT => // BAP says caring about this case is necessary? + if (lhs.size == rhs.size) { + (BinaryExpr(BVSHL, translateExprOnly(lhs), translateExprOnly(rhs)), None) + } else { + (BinaryExpr(BVSHL, translateExprOnly(lhs), ZeroExtend(lhs.size - rhs.size, translateExprOnly(rhs))), None) + } + case RSHIFT => + if (lhs.size == rhs.size) { + (BinaryExpr(BVLSHR, translateExprOnly(lhs), translateExprOnly(rhs)), None) + } else { + (BinaryExpr(BVLSHR, translateExprOnly(lhs), ZeroExtend(lhs.size - rhs.size, translateExprOnly(rhs))), None) + } + case ARSHIFT => + if (lhs.size == rhs.size) { + (BinaryExpr(BVASHR, translateExprOnly(lhs), translateExprOnly(rhs)), None) + } else { + (BinaryExpr(BVASHR, translateExprOnly(lhs), ZeroExtend(lhs.size - rhs.size, translateExprOnly(rhs))), None) + } + case AND => (BinaryExpr(BVAND, translateExprOnly(lhs), translateExprOnly(rhs)), None) + case OR => (BinaryExpr(BVOR, translateExprOnly(lhs), translateExprOnly(rhs)), None) + case XOR => (BinaryExpr(BVXOR, translateExprOnly(lhs), translateExprOnly(rhs)), None) + case EQ => (BinaryExpr(BVCOMP, translateExprOnly(lhs), translateExprOnly(rhs)), None) + case NEQ => (UnaryExpr(BVNOT, BinaryExpr(BVCOMP, translateExprOnly(lhs), translateExprOnly(rhs))), None) + case LT => (BinaryExpr(BVULT, translateExprOnly(lhs), translateExprOnly(rhs)), None) + 
case LE => (BinaryExpr(BVULE, translateExprOnly(lhs), translateExprOnly(rhs)), None) + case SLT => (BinaryExpr(BVSLT, translateExprOnly(lhs), translateExprOnly(rhs)), None) + case SLE => (BinaryExpr(BVSLE, translateExprOnly(lhs), translateExprOnly(rhs)), None) + } + case b: BAPVar => (translateVar(b), None) + case BAPMemAccess(memory, index, endian, size) => + val temp = LocalVar("$load$" + loadCounter, BitVecType(size)) + loadCounter += 1 + val load = MemoryLoad(temp, translateMemory(memory), translateExprOnly(index), endian, size, None) + (temp, Some(load)) + } + + private def translateExprOnly(e: BAPExpr) = { + val (expr, load) = translateExpr(e) + if (load.isDefined) { + throw Exception(s"unexpected load in $e") + } + expr + } + + private def translateVar(variable: BAPVar): Variable = variable match { + case BAPRegister(name, size) => Register(name, size) + case BAPLocalVar(name, size) => LocalVar(name, BitVecType(size)) + } + + private def translateMemory(memory: BAPMemory): Memory = { + SharedMemory(memory.name, memory.addressSize, memory.valueSize) } + private def translateParameter(parameter: BAPParameter): Parameter = { + val register = translateExprOnly(parameter.value) + register match { + case r: Register => Parameter(parameter.name, parameter.size, r) + case _ => throw Exception(s"subroutine parameter $parameter refers to non-register variable ${parameter.value}") + } + } + + private def translateLiteral(literal: BAPLiteral) = { + BitVecLiteral(literal.value, literal.size) + } + + /** * Translates a list of jumps from BAP into a single Jump at the IR level by moving any conditions on jumps to * Assume statements in new blocks @@ -112,7 +246,9 @@ class BAPToIR(var program: BAPProgram, mainAddress: BigInt) { // condition is true and previous conditions existing means this condition // is actually that all previous conditions are false val conditionsIR = conditions.map(c => convertConditionBool(c, true)) - val condition =
conditionsIR.tail.foldLeft(conditionsIR.head)((ands: Expr, next: Expr) => BinaryExpr(BoolAND, next, ands)) + val condition = conditionsIR.tail.foldLeft(conditionsIR.head) { + (ands: Expr, next: Expr) => BinaryExpr(BoolAND, next, ands) + } val newBlock = newBlockCondition(block, target, condition) newBlocks.append(newBlock) targets.append(newBlock) @@ -127,7 +263,9 @@ class BAPToIR(var program: BAPProgram, mainAddress: BigInt) { // if this is not the first condition, then we need to need to add // that all previous conditions are false val conditionsIR = conditions.map(c => convertConditionBool(c, true)) - conditionsIR.tail.foldLeft(currentCondition)((ands: Expr, next: Expr) => BinaryExpr(BoolAND, next, ands)) + conditionsIR.tail.foldLeft(currentCondition) { + (ands: Expr, next: Expr) => BinaryExpr(BoolAND, next, ands) + } } val newBlock = newBlockCondition(block, target, condition) newBlocks.append(newBlock) @@ -142,11 +280,11 @@ class BAPToIR(var program: BAPProgram, mainAddress: BigInt) { jumps.head match { case b: BAPDirectCall => val call = Some(DirectCall(nameToProcedure(b.target),Some(b.line))) - val ft = (b.returnTarget.map(t => labelToBlock(t))).map(x => GoTo(Set(x))).getOrElse(Unreachable()) + val ft = b.returnTarget.map(t => labelToBlock(t)).map(x => GoTo(Set(x))).getOrElse(Unreachable()) (call, ft, ArrayBuffer()) case b: BAPIndirectCall => - val call = IndirectCall(b.target.toIR, Some(b.line)) - val ft = (b.returnTarget.map(t => labelToBlock(t))).map(x => GoTo(Set(x))).getOrElse(Unreachable()) + val call = IndirectCall(translateVar(b.target), Some(b.line)) + val ft = b.returnTarget.map(t => labelToBlock(t)).map(x => GoTo(Set(x))).getOrElse(Unreachable()) (Some(call), ft, ArrayBuffer()) case b: BAPGoTo => val target = labelToBlock(b.target) @@ -173,7 +311,7 @@ class BAPToIR(var program: BAPProgram, mainAddress: BigInt) { * if necessary. 
* */ private def convertConditionBool(expr: BAPExpr, negative: Boolean): Expr = { - val e = expr.toIR + val e = translateExprOnly(expr) e.getType match { case BitVecType(s) => if (negative) { diff --git a/src/main/scala/translating/SemanticsLoader.scala b/src/main/scala/translating/GTIRBLoader.scala similarity index 68% rename from src/main/scala/translating/SemanticsLoader.scala rename to src/main/scala/translating/GTIRBLoader.scala index e858a9e85..836ed5de0 100644 --- a/src/main/scala/translating/SemanticsLoader.scala +++ b/src/main/scala/translating/GTIRBLoader.scala @@ -1,5 +1,5 @@ package translating -import Parsers.SemanticsParser.* +import Parsers.ASLpParser.* import com.google.protobuf.ByteString import Parsers.* @@ -13,14 +13,15 @@ import scala.collection.mutable.ArrayBuffer import com.grammatech.gtirb.proto.Module.ByteOrder.LittleEndian import util.Logger -class SemanticsLoader(parserMap: immutable.Map[String, Array[Array[StmtContext]]]) { +class GTIRBLoader(parserMap: immutable.Map[String, Array[Array[StmtContext]]]) { private val constMap = mutable.Map[String, IRType]() private val varMap = mutable.Map[String, IRType]() private var instructionCount = 0 private var blockCount = 0 + private var loadCounter = 0 - val opcodeSize = 4 + private val opcodeSize = 4 def visitBlock(blockUUID: ByteString, blockCountIn: Int, blockAddress: Option[BigInt]): ArrayBuffer[Statement] = { blockCount = blockCountIn @@ -39,34 +40,31 @@ class SemanticsLoader(parserMap: immutable.Map[String, Array[Array[StmtContext]] val instructionAddress = a + (opcodeSize * instructionCount) instructionAddress.toString + "$" + i } - - val statement = visitStmt(s, label) - if (statement.isDefined) { - statements.append(statement.get) - } + + statements.appendAll(visitStmt(s, label)) } instructionCount += 1 } statements } - private def visitStmt(ctx: StmtContext, label: Option[String] = None): Option[Statement] = { + private def visitStmt(ctx: StmtContext, label: Option[String] = None): 
Seq[Statement] = { ctx match { case a: AssignContext => visitAssign(a, label) case c: ConstDeclContext => visitConstDecl(c, label) case v: VarDeclContext => visitVarDecl(v, label) case v: VarDeclsNoInitContext => visitVarDeclsNoInit(v) - None - case a: AssertContext => visitAssert(a, label) - case t: TCallContext => visitTCall(t, label) - case i: IfContext => visitIf(i, label) - case t: ThrowContext => Some(visitThrow(t, label)) + Seq() + case a: AssertContext => visitAssert(a, label).toSeq + case t: TCallContext => visitTCall(t, label).toSeq + case i: IfContext => visitIf(i, label).toSeq + case t: ThrowContext => Seq(visitThrow(t, label)) } } - private def visitAssert(ctx: AssertContext, label: Option[String] = None): Option[Assert] = { - val expr = visitExpr(ctx.expr) + private def visitAssert(ctx: AssertContext, label: Option[String] = None): Option[Statement] = { + val expr = visitExprOnly(ctx.expr) if (expr.isDefined) { Some(Assert(expr.get, None, label)) } else { @@ -90,8 +88,8 @@ class SemanticsLoader(parserMap: immutable.Map[String, Array[Array[StmtContext]] checkArgs(function, 1, 4, typeArgs.size, args.size, ctx.getText) val mem = SharedMemory("mem", 64, 8) // yanked from BAP val size = parseInt(typeArgs.head) * 8 - val index = visitExpr(args.head) - val value = visitExpr(args(3)) + val index = visitExprOnly(args.head) + val value = visitExprOnly(args(3)) val otherSize = parseInt(args(1)) * 8 val accessType = parseInt(args(2)) // AccType enum in ASLi, not very relevant to us if (size != otherSize) { @@ -100,12 +98,12 @@ class SemanticsLoader(parserMap: immutable.Map[String, Array[Array[StmtContext]] // LittleEndian is an assumption if (index.isDefined && value.isDefined) { - Some(MemoryAssign(mem, index.get, value.get, Endian.LittleEndian, size.toInt, label)) + Some(MemoryStore(mem, index.get, value.get, Endian.LittleEndian, size.toInt, label)) } else { None } case "unsupported_opcode.0" => { - val op = args.headOption.flatMap(visitExpr) match { + val op = 
args.headOption.flatMap(visitExprOnly) match { case Some(IntLiteral(s)) => Some("%08x".format(s)) case c => c.map(_.toString) } @@ -130,7 +128,7 @@ class SemanticsLoader(parserMap: immutable.Map[String, Array[Array[StmtContext]] } private def visitIf(ctx: IfContext, label: Option[String] = None): Option[TempIf] = { - val condition = visitExpr(ctx.cond) + val condition = visitExprOnly(ctx.cond) val thenStmts = ctx.thenStmts.stmt.asScala.flatMap(visitStmt(_, label)) val elseStmts = Option(ctx.elseStmts) match { @@ -151,35 +149,64 @@ class SemanticsLoader(parserMap: immutable.Map[String, Array[Array[StmtContext]] varMap ++= newVars } - private def visitVarDecl(ctx: VarDeclContext, label: Option[String] = None): Option[Assign] = { + private def visitVarDecl(ctx: VarDeclContext, label: Option[String] = None): Seq[Statement] = { val ty = visitType(ctx.`type`()) val name = visitIdent(ctx.lvar) varMap += (name -> ty) - val expr = visitExpr(ctx.expr()) - expr.map(Assign(LocalVar(name, ty), _, label)) + val (expr, load) = visitExpr(ctx.expr) + if (expr.isDefined) { + if (load.isDefined) { + val loadWithLabel = MemoryLoad(load.get.lhs, load.get.mem, load.get.index, load.get.endian, load.get.size, label.map(_ + "$0")) + val assign = LocalAssign(LocalVar(name, ty), expr.get, label.map(_ + "$1")) + Seq(loadWithLabel, assign) + } else { + val assign = LocalAssign(LocalVar(name, ty), expr.get, label) + Seq(assign) + } + } else { + Seq() + } } - private def visitAssign(ctx: AssignContext, label: Option[String] = None): Option[Assign] = { + private def visitAssign(ctx: AssignContext, label: Option[String] = None): Seq[Statement] = { val lhs = visitLexpr(ctx.lexpr) - val rhs = visitExpr(ctx.expr) - lhs.zip(rhs).map((lhs, rhs) => Assign(lhs, rhs, label)) + val (rhs, load) = visitExpr(ctx.expr) + if (lhs.isDefined && rhs.isDefined) { + if (load.isDefined) { + val loadWithLabel = MemoryLoad(load.get.lhs, load.get.mem, load.get.index, load.get.endian, load.get.size, label.map(_ + "$0")) 
+ val assign = LocalAssign(lhs.get, rhs.get, label.map(_ + "$1")) + Seq(loadWithLabel, assign) + } else { + val assign = LocalAssign(lhs.get, rhs.get, label) + Seq(assign) + } + } else { + Seq() + } } - private def visitConstDecl(ctx: ConstDeclContext, label: Option[String] = None): Option[Assign] = { + private def visitConstDecl(ctx: ConstDeclContext, label: Option[String] = None): Seq[Statement] = { val ty = visitType(ctx.`type`()) val name = visitIdent(ctx.lvar) constMap += (name -> ty) - val expr = visitExpr(ctx.expr) + val (expr, load) = visitExpr(ctx.expr) if (expr.isDefined) { - Some(Assign(LocalVar(name + "$" + blockCount + "$" + instructionCount, ty), expr.get, label)) + if (load.isDefined) { + val loadWithLabel = MemoryLoad(load.get.lhs, load.get.mem, load.get.index, load.get.endian, load.get.size, label.map(_ + "$0")) + val assign = LocalAssign(LocalVar(name + "$" + blockCount + "$" + instructionCount, ty), expr.get, label.map(_ + "$1")) + Seq(loadWithLabel, assign) + } else { + val assign = LocalAssign(LocalVar(name + "$" + blockCount + "$" + instructionCount, ty), expr.get, label) + Seq(assign) + } } else { - None + Seq() } } private def visitType(ctx: TypeContext): IRType = { - ctx match + ctx match { case e: TypeBitsContext => BitVecType(parseInt(e.size).toInt) case r: TypeRegisterContext => // this is a special register - not the same as a register in the IR @@ -191,21 +218,31 @@ class SemanticsLoader(parserMap: immutable.Map[String, Array[Array[StmtContext]] case _ => throw Exception(s"unknown type ${ctx.getText}") } case _ => throw Exception(s"unknown type ${ctx.getText}") + } } - private def visitExpr(ctx: ExprContext): Option[Expr] = { + private def visitExpr(ctx: ExprContext): (Option[Expr], Option[MemoryLoad]) = { ctx match { - case e: ExprVarContext => visitExprVar(e) + case e: ExprVarContext => (visitExprVar(e), None) case e: ExprTApplyContext => visitExprTApply(e) case e: ExprSlicesContext => visitExprSlices(e) - case e: ExprFieldContext => 
Some(visitExprField(e)) - case e: ExprArrayContext => Some(visitExprArray(e)) - case e: ExprLitIntContext => Some(IntLiteral(parseInt(e))) - case e: ExprLitBitsContext => Some(visitExprLitBits(e)) + case e: ExprFieldContext => (Some(visitExprField(e)), None) + case e: ExprArrayContext => (Some(visitExprArray(e)), None) + case e: ExprLitIntContext => (Some(IntLiteral(parseInt(e))), None) + case e: ExprLitBitsContext => (Some(visitExprLitBits(e)), None) } } - private def visitExprVar(ctx: ExprVarContext): Option[Expr] = { + private def visitExprOnly(ctx: ExprContext): Option[Expr] = { + val (expr, load) = visitExpr(ctx) + if (load.isDefined) { + throw Exception("") + } else { + expr + } + } + + private def visitExprVar(ctx: ExprVarContext): Option[Expr] = { val name = visitIdent(ctx.ident) name match { case n if constMap.contains(n) => Some(LocalVar(n + "$" + blockCount + "$" + instructionCount, constMap(n))) @@ -225,7 +262,7 @@ class SemanticsLoader(parserMap: immutable.Map[String, Array[Array[StmtContext]] } } - private def visitExprTApply(ctx: ExprTApplyContext): Option[Expr] = { + private def visitExprTApply(ctx: ExprTApplyContext): (Option[Expr], Option[MemoryLoad]) = { val function = visitIdent(ctx.ident) val typeArgs: mutable.Buffer[ExprContext] = Option(ctx.tes) match { @@ -241,7 +278,7 @@ class SemanticsLoader(parserMap: immutable.Map[String, Array[Array[StmtContext]] case "Mem.read.0" => checkArgs(function, 1, 3, typeArgs.size, args.size, ctx.getText) val mem = SharedMemory("mem", 64, 8) - val index = visitExpr(args.head) + val index = visitExprOnly(args.head) // can't have load inside load val size = parseInt(typeArgs.head) * 8 val otherSize = parseInt(args(1)) * 8 @@ -250,112 +287,112 @@ class SemanticsLoader(parserMap: immutable.Map[String, Array[Array[StmtContext]] throw Exception(s"inconsistent size parameters in Mem.read.0: ${ctx.getText}") } + val temp = LocalVar("$load" + loadCounter, BitVecType(size.toInt)) + loadCounter += 1 + if (index.isDefined) 
{ // LittleEndian is assumed - Some(MemoryLoad(mem, index.get, Endian.LittleEndian, size.toInt)) + (Some(temp), Some(MemoryLoad(temp, mem, index.get, Endian.LittleEndian, size.toInt, None))) } else { - None + (None, None) } case "cvt_bool_bv.0" => checkArgs(function, 0, 1, typeArgs.size, args.size, ctx.getText) - val expr = visitExpr(args.head) - if (expr.isDefined) { - val e = expr.get - e match { - case b: BinaryExpr if b.op == BVEQ => Some(BinaryExpr(BVCOMP, b.arg1, b.arg2)) - case FalseLiteral => Some(BitVecLiteral(0, 1)) - case TrueLiteral => Some(BitVecLiteral(1, 1)) - case _ => throw Exception(s"unhandled conversion from bool to bitvector: ${ctx.getText}") - } - } else { - None + val expr = visitExprOnly(args.head) + val result = expr.map { + case b: BinaryExpr if b.op == BVEQ => BinaryExpr(BVCOMP, b.arg1, b.arg2) + case FalseLiteral => BitVecLiteral(0, 1) + case TrueLiteral => BitVecLiteral(1, 1) + case _ => throw Exception(s"unhandled conversion from bool to bitvector: ${ctx.getText}") } - - case "not_bool.0" => resolveUnaryOp(BoolNOT, function, 0, typeArgs, args, ctx.getText) - case "eq_enum.0" => resolveBinaryOp(BoolEQ, function, 0, typeArgs, args, ctx.getText) - case "or_bool.0" => resolveBinaryOp(BoolOR, function, 0, typeArgs, args, ctx.getText) - case "and_bool.0" => resolveBinaryOp(BoolAND, function, 0, typeArgs, args, ctx.getText) - - case "not_bits.0" => resolveUnaryOp(BVNOT, function, 1, typeArgs, args, ctx.getText) - case "or_bits.0" => resolveBinaryOp(BVOR, function, 1, typeArgs, args, ctx.getText) - case "and_bits.0" => resolveBinaryOp(BVAND, function, 1, typeArgs, args, ctx.getText) - case "eor_bits.0" => resolveBinaryOp(BVXOR, function, 1, typeArgs, args, ctx.getText) - case "eq_bits.0" => resolveBinaryOp(BVEQ, function, 1, typeArgs, args, ctx.getText) - case "add_bits.0" => resolveBinaryOp(BVADD, function, 1, typeArgs, args, ctx.getText) - case "sub_bits.0" => resolveBinaryOp(BVSUB, function, 1, typeArgs, args, ctx.getText) - case 
"mul_bits.0" => resolveBinaryOp(BVMUL, function, 1, typeArgs, args, ctx.getText) - case "sdiv_bits.0" => resolveBinaryOp(BVSDIV, function, 1, typeArgs, args, ctx.getText) - - case "slt_bits.0" => resolveBinaryOp(BVSLT, function, 1, typeArgs, args, ctx.getText) - case "sle_bits.0" => resolveBinaryOp(BVSLE, function, 1, typeArgs, args, ctx.getText) - - case "lsl_bits.0" => resolveBitShiftOp(BVSHL, function, typeArgs, args, ctx.getText) - case "lsr_bits.0" => resolveBitShiftOp(BVLSHR, function, typeArgs, args, ctx.getText) - case "asr_bits.0" => resolveBitShiftOp(BVASHR, function, typeArgs, args, ctx.getText) + (result, None) + + case "not_bool.0" => (resolveUnaryOp(BoolNOT, function, 0, typeArgs, args, ctx.getText), None) + case "eq_enum.0" => (resolveBinaryOp(BoolEQ, function, 0, typeArgs, args, ctx.getText), None) + case "or_bool.0" => (resolveBinaryOp(BoolOR, function, 0, typeArgs, args, ctx.getText), None) + case "and_bool.0" => (resolveBinaryOp(BoolAND, function, 0, typeArgs, args, ctx.getText), None) + + case "not_bits.0" => (resolveUnaryOp(BVNOT, function, 1, typeArgs, args, ctx.getText), None) + case "or_bits.0" => (resolveBinaryOp(BVOR, function, 1, typeArgs, args, ctx.getText), None) + case "and_bits.0" => (resolveBinaryOp(BVAND, function, 1, typeArgs, args, ctx.getText), None) + case "eor_bits.0" => (resolveBinaryOp(BVXOR, function, 1, typeArgs, args, ctx.getText), None) + case "eq_bits.0" => (resolveBinaryOp(BVEQ, function, 1, typeArgs, args, ctx.getText), None) + case "add_bits.0" => (resolveBinaryOp(BVADD, function, 1, typeArgs, args, ctx.getText), None) + case "sub_bits.0" => (resolveBinaryOp(BVSUB, function, 1, typeArgs, args, ctx.getText), None) + case "mul_bits.0" => (resolveBinaryOp(BVMUL, function, 1, typeArgs, args, ctx.getText), None) + case "sdiv_bits.0" => (resolveBinaryOp(BVSDIV, function, 1, typeArgs, args, ctx.getText), None) + + case "slt_bits.0" => (resolveBinaryOp(BVSLT, function, 1, typeArgs, args, ctx.getText), None) + case 
"sle_bits.0" => (resolveBinaryOp(BVSLE, function, 1, typeArgs, args, ctx.getText), None) + + case "lsl_bits.0" => (resolveBitShiftOp(BVSHL, function, typeArgs, args, ctx.getText), None) + case "lsr_bits.0" => (resolveBitShiftOp(BVLSHR, function, typeArgs, args, ctx.getText), None) + case "asr_bits.0" => (resolveBitShiftOp(BVASHR, function, typeArgs, args, ctx.getText), None) case "append_bits.0" => - resolveBinaryOp(BVCONCAT, function, 2, typeArgs, args, ctx.getText) + (resolveBinaryOp(BVCONCAT, function, 2, typeArgs, args, ctx.getText), None) case "replicate_bits.0" => checkArgs(function, 2, 2, typeArgs.size, args.size, ctx.getText) val oldSize = parseInt(typeArgs(0)) val replications = parseInt(typeArgs(1)).toInt - val arg0 = visitExpr(args(0)) + // memory loads shouldn't appear here? + val arg0 = visitExprOnly(args(0)) val arg1 = parseInt(args(1)) val newSize = oldSize * replications if (arg1 != replications) { Exception(s"inconsistent size parameters in replicate_bits.0: ${ctx.getText}") } if (arg0.isDefined) { - Some(Repeat(replications, arg0.get)) + (Some(Repeat(replications, arg0.get)), None) } else { - None + (None, None) } case "ZeroExtend.0" => checkArgs(function, 2, 2, typeArgs.size, args.size, ctx.getText) val oldSize = parseInt(typeArgs(0)) val newSize = parseInt(typeArgs(1)) - val arg0 = visitExpr(args(0)) + val (arg0, load) = visitExpr(args(0)) val arg1 = parseInt(args(1)) if (arg1 != newSize) { Exception(s"inconsistent size parameters in ZeroExtend.0: ${ctx.getText}") } if (arg0.isDefined) { - Some(ZeroExtend((newSize - oldSize).toInt, arg0.get)) + (Some(ZeroExtend((newSize - oldSize).toInt, arg0.get)), load) } else { - None + (None, None) } case "SignExtend.0" => checkArgs(function, 2, 2, typeArgs.size, args.size, ctx.getText) val oldSize = parseInt(typeArgs(0)) val newSize = parseInt(typeArgs(1)) - val arg0 = visitExpr(args(0)) + val (arg0, load) = visitExpr(args(0)) val arg1 = parseInt(args(1)) if (arg1 != newSize) { Exception(s"inconsistent size 
parameters in SignExtend.0: ${ctx.getText}") } if (arg0.isDefined) { - Some(SignExtend((newSize - oldSize).toInt, arg0.get)) + (Some(SignExtend((newSize - oldSize).toInt, arg0.get)), load) } else { - None + (None, None) } case "FPCompareGT.0" | "FPCompareGE.0" | "FPCompareEQ.0" => checkArgs(function, 1, 3, typeArgs.size, args.size, ctx.getText) val name = function.stripSuffix(".0") val size = parseInt(typeArgs(0)) - val argsIR = args.flatMap(visitExpr).toSeq - Some(UninterpretedFunction(name + "$" + size, argsIR, BoolType)) + val argsIR = args.flatMap(visitExprOnly).toSeq + (Some(UninterpretedFunction(name + "$" + size, argsIR, BoolType)), None) case "FPAdd.0" | "FPMul.0" | "FPDiv.0" | "FPMulX.0" | "FPMax.0" | "FPMin.0" | "FPMaxNum.0" | "FPMinNum.0" | "FPSub.0" => checkArgs(function, 1, 3, typeArgs.size, args.size, ctx.getText) val name = function.stripSuffix(".0") val size = parseInt(typeArgs(0)).toInt - val argsIR = args.flatMap(visitExpr).toSeq - Some(UninterpretedFunction(name + "$" + size, argsIR, BitVecType(size))) + val argsIR = args.flatMap(visitExprOnly).toSeq + (Some(UninterpretedFunction(name + "$" + size, argsIR, BitVecType(size))), None) case "FPMulAddH.0" | "FPMulAdd.0" | "FPRoundInt.0" | @@ -363,31 +400,31 @@ class SemanticsLoader(parserMap: immutable.Map[String, Array[Array[StmtContext]] checkArgs(function, 1, 4, typeArgs.size, args.size, ctx.getText) val name = function.stripSuffix(".0") val size = parseInt(typeArgs(0)).toInt - val argsIR = args.flatMap(visitExpr).toSeq - Some(UninterpretedFunction(name + "$" + size, argsIR, BitVecType(size))) + val argsIR = args.flatMap(visitExprOnly).toSeq + (Some(UninterpretedFunction(name + "$" + size, argsIR, BitVecType(size))), None) case "FPRecpX.0" | "FPSqrt.0" | "FPRecipEstimate.0" | "FPRSqrtStepFused.0" | "FPRecipStepFused.0" => checkArgs(function, 1, 2, typeArgs.size, args.size, ctx.getText) val name = function.stripSuffix(".0") val size = parseInt(typeArgs(0)).toInt - val argsIR = 
args.flatMap(visitExpr).toSeq - Some(UninterpretedFunction(name + "$" + size, argsIR, BitVecType(size))) + val argsIR = args.flatMap(visitExprOnly).toSeq + (Some(UninterpretedFunction(name + "$" + size, argsIR, BitVecType(size))), None) case "FPCompare.0" => checkArgs(function, 1, 4, typeArgs.size, args.size, ctx.getText) val name = function.stripSuffix(".0") val size = parseInt(typeArgs(0)) - val argsIR = args.flatMap(visitExpr).toSeq - Some(UninterpretedFunction(name + "$" + size, argsIR, BitVecType(4))) + val argsIR = args.flatMap(visitExprOnly).toSeq + (Some(UninterpretedFunction(name + "$" + size, argsIR, BitVecType(4))), None) case "FPConvert.0" => checkArgs(function, 2, 3, typeArgs.size, args.size, ctx.getText) val name = function.stripSuffix(".0") val outSize = parseInt(typeArgs(0)).toInt val inSize = parseInt(typeArgs(1)) - val argsIR = args.flatMap(visitExpr).toSeq - Some(UninterpretedFunction(name + "$" + outSize + "$" + inSize, argsIR, BitVecType(outSize))) + val argsIR = args.flatMap(visitExprOnly).toSeq + (Some(UninterpretedFunction(name + "$" + outSize + "$" + inSize, argsIR, BitVecType(outSize))), None) case "FPToFixed.0" => checkArgs(function, 2, 5, typeArgs.size, args.size, ctx.getText) @@ -395,8 +432,8 @@ class SemanticsLoader(parserMap: immutable.Map[String, Array[Array[StmtContext]] val outSize = parseInt(typeArgs(0)).toInt val inSize = parseInt(typeArgs(1)) // need to specifically handle the integer parameter - val argsIR = args.flatMap(visitExpr).toSeq - Some(UninterpretedFunction(name + "$" + outSize + "$" + inSize, argsIR, BitVecType(outSize))) + val argsIR = args.flatMap(visitExprOnly).toSeq + (Some(UninterpretedFunction(name + "$" + outSize + "$" + inSize, argsIR, BitVecType(outSize))), None) case "FixedToFP.0" => checkArgs(function, 2, 5, typeArgs.size, args.size, ctx.getText) @@ -404,28 +441,28 @@ class SemanticsLoader(parserMap: immutable.Map[String, Array[Array[StmtContext]] val inSize = parseInt(typeArgs(0)) val outSize = 
parseInt(typeArgs(1)).toInt // need to specifically handle the integer parameter - val argsIR = args.flatMap(visitExpr).toSeq - Some(UninterpretedFunction(name + "$" + outSize + "$" + inSize, argsIR, BitVecType(outSize))) + val argsIR = args.flatMap(visitExprOnly).toSeq + (Some(UninterpretedFunction(name + "$" + outSize + "$" + inSize, argsIR, BitVecType(outSize))), None) case "FPConvertBF.0" => checkArgs(function, 0, 3, typeArgs.size, args.size, ctx.getText) val name = function.stripSuffix(".0") - val argsIR = args.flatMap(visitExpr).toSeq - Some(UninterpretedFunction(name, argsIR, BitVecType(32))) + val argsIR = args.flatMap(visitExprOnly).toSeq + (Some(UninterpretedFunction(name, argsIR, BitVecType(32))), None) case "FPToFixedJS_impl.0" => checkArgs(function, 2, 3, typeArgs.size, args.size, ctx.getText) val name = function.stripSuffix(".0") val inSize = parseInt(typeArgs(0)) val outSize = parseInt(typeArgs(1)).toInt - val argsIR = args.flatMap(visitExpr).toSeq - Some(UninterpretedFunction(name + "$" + outSize + "$" + inSize, argsIR, BitVecType(outSize))) + val argsIR = args.flatMap(visitExprOnly).toSeq + (Some(UninterpretedFunction(name + "$" + outSize + "$" + inSize, argsIR, BitVecType(outSize))), None) case "BFAdd.0" | "BFMul.0" => checkArgs(function, 0, 2, typeArgs.size, args.size, ctx.getText) val name = function.stripSuffix(".0") - val argsIR = args.flatMap(visitExpr).toSeq - Some(UninterpretedFunction(name, argsIR, BitVecType(32))) + val argsIR = args.flatMap(visitExprOnly).toSeq + (Some(UninterpretedFunction(name, argsIR, BitVecType(32))), None) case _ => // known ASLp methods not yet handled: @@ -434,7 +471,7 @@ class SemanticsLoader(parserMap: immutable.Map[String, Array[Array[StmtContext]] // and will require some research into their semantics // AtomicStart, AtomicEnd - can't model as uninterpreted functions, requires modelling atomic section Logger.debug(s"unidentified call to $function: ${ctx.getText}") - None + (None, None) } } @@ -448,8 +485,9 @@ 
class SemanticsLoader(parserMap: immutable.Map[String, Array[Array[StmtContext]] ): Option[BinaryExpr] = { checkArgs(function, typeArgsExpected, 2, typeArgs.size, args.size, token) // we don't currently check the size for BV ops which is the type arg - val arg0 = visitExpr(args(0)) - val arg1 = visitExpr(args(1)) + // memory loads shouldn't appear inside binary operations? + val arg0 = visitExprOnly(args(0)) + val arg1 = visitExprOnly(args(1)) if (arg0.isDefined && arg1.isDefined) { Some(BinaryExpr(operator, arg0.get, arg1.get)) } else { @@ -466,7 +504,8 @@ class SemanticsLoader(parserMap: immutable.Map[String, Array[Array[StmtContext]] ): Option[UnaryExpr] = { checkArgs(function, typeArgsExpected, 1, typeArgs.size, args.size, token) // we don't currently check the size for BV ops which is the type arg - val arg = visitExpr(args.head) + // memory loads shouldn't appear inside unary operations? + val arg = visitExprOnly(args.head) if (arg.isDefined) { Some(UnaryExpr(operator, arg.get)) } else { @@ -483,8 +522,9 @@ class SemanticsLoader(parserMap: immutable.Map[String, Array[Array[StmtContext]] checkArgs(function, 2, 2, typeArgs.size, args.size, token) val size0 = parseInt(typeArgs(0)) val size1 = parseInt(typeArgs(1)) - val arg0 = visitExpr(args(0)) - val arg1 = visitExpr(args(1)) + val arg0 = visitExprOnly(args(0)) + val arg1 = visitExprOnly(args(1)) + // memory loads shouldn't appear inside bitshifts? 
if (arg0.isDefined && arg1.isDefined) { if (size0 == size1) { Some(BinaryExpr(operator, arg0.get, arg1.get)) @@ -496,18 +536,18 @@ class SemanticsLoader(parserMap: immutable.Map[String, Array[Array[StmtContext]] } } - private def visitExprSlices(ctx: ExprSlicesContext): Option[Extract] = { + private def visitExprSlices(ctx: ExprSlicesContext): (Option[Extract], Option[MemoryLoad]) = { val slices = ctx.slices.slice().asScala if (slices.size != 1) { // need to determine the semantics for this case throw Exception(s"currently unable to handle Expr_Slices that contains more than one slice: ${ctx.getText}") } val (hi, lo) = visitSliceContext(slices.head) - val expr = visitExpr(ctx.expr) + val (expr, load) = visitExpr(ctx.expr) if (expr.isDefined) { - Some(Extract(hi, lo, expr.get)) + (Some(Extract(hi, lo, expr.get)), load) } else { - None + (None, None) } } @@ -524,7 +564,7 @@ class SemanticsLoader(parserMap: immutable.Map[String, Array[Array[StmtContext]] } } - private def visitExprField(ctx: ExprFieldContext): Register = { + private def visitExprField(ctx: ExprFieldContext): Register = { val name = ctx.expr match { case e: ExprVarContext => visitIdent(e.ident) case _ => throw Exception(s"expected ${ctx.getText} to have an Expr_Var as first parameter") @@ -534,7 +574,7 @@ class SemanticsLoader(parserMap: immutable.Map[String, Array[Array[StmtContext]] resolveFieldExpr(name, field) } - private def visitExprArray(ctx: ExprArrayContext): Register = { + private def visitExprArray(ctx: ExprArrayContext): Register = { val name = ctx.array match { case e: ExprVarContext => visitIdent(e.ident) case _ => throw Exception(s"expected ${ctx.getText} to have an Expr_Var as first parameter") diff --git a/src/main/scala/translating/GTIRBToIR.scala b/src/main/scala/translating/GTIRBToIR.scala index bf8a26ad9..31049706c 100644 --- a/src/main/scala/translating/GTIRBToIR.scala +++ b/src/main/scala/translating/GTIRBToIR.scala @@ -7,7 +7,7 @@ import com.grammatech.gtirb.proto.CFG.Edge 
import com.grammatech.gtirb.proto.CFG.EdgeLabel import com.grammatech.gtirb.proto.Module.Module import com.grammatech.gtirb.proto.Symbol.Symbol -import Parsers.SemanticsParser.* +import Parsers.ASLpParser.* import gtirb.* import ir.* @@ -156,7 +156,7 @@ class GTIRBToIR(mods: Seq[Module], parserMap: immutable.Map[String, Array[Array[ // maybe good to sort blocks by address around here? - val semanticsLoader = SemanticsLoader(parserMap) + val semanticsLoader = GTIRBLoader(parserMap) for ((functionUUID, blockUUIDs) <- functionBlocks) { val procedure = uuidToProcedure(functionUUID) @@ -228,7 +228,7 @@ class GTIRBToIR(mods: Seq[Module], parserMap: immutable.Map[String, Array[Array[ private def removePCAssign(block: Block): Option[String] = { block.statements.last match { - case last @ Assign(lhs: Register, _, _) if lhs.name == "_PC" => + case last @ LocalAssign(lhs: Register, _, _) if lhs.name == "_PC" => val label = last.label block.statements.remove(last) label @@ -238,7 +238,7 @@ class GTIRBToIR(mods: Seq[Module], parserMap: immutable.Map[String, Array[Array[ private def getPCTarget(block: Block): Register = { block.statements.last match { - case Assign(lhs: Register, rhs: Register, _) if lhs.name == "_PC" => rhs + case LocalAssign(lhs: Register, rhs: Register, _) if lhs.name == "_PC" => rhs case _ => throw Exception(s"expected block ${block.label} to have a program counter assignment at its end") } } @@ -373,8 +373,8 @@ class GTIRBToIR(mods: Seq[Module], parserMap: immutable.Map[String, Array[Array[ // need to copy jump as it can't have multiple parents val jumpCopy = currentBlock.jump match { case GoTo(targets, label) => GoTo(targets, label) - case h: Unreachable => Unreachable() - case r: Return => Return() + case Unreachable(label) => Unreachable(label) + case Return(label) => Return(label) case _ => throw Exception("this shouldn't be reachable") } trueBlock.replaceJump(currentBlock.jump) @@ -397,7 +397,7 @@ class GTIRBToIR(mods: Seq[Module], parserMap: 
immutable.Map[String, Array[Array[ if (proxySymbols.isEmpty) { // indirect call with no further information val target = block.statements.last match { - case Assign(lhs: Register, rhs: Register, _) if lhs.name == "_PC" => rhs + case LocalAssign(lhs: Register, rhs: Register, _) if lhs.name == "_PC" => rhs case _ => throw Exception(s"no assignment to program counter found before indirect call in block ${block.label}") } val label = block.statements.last.label diff --git a/src/main/scala/translating/ILtoIL.scala b/src/main/scala/translating/ILtoIL.scala index 856b18934..9e17aee9e 100644 --- a/src/main/scala/translating/ILtoIL.scala +++ b/src/main/scala/translating/ILtoIL.scala @@ -1,5 +1,5 @@ package translating -import ir._ +import ir.* private class ILSerialiser extends ReadOnlyVisitor { var program: StringBuilder = StringBuilder() @@ -32,7 +32,7 @@ private class ILSerialiser extends ReadOnlyVisitor { override def visitStatement(node: Statement): Statement = node.acceptVisit(this) - override def visitAssign(node: Assign): Statement = { + override def visitLocalAssign(node: LocalAssign): Statement = { program ++= "LocalAssign(" visitVariable(node.lhs) program ++= " := " @@ -41,8 +41,8 @@ private class ILSerialiser extends ReadOnlyVisitor { node } - override def visitMemoryAssign(node: MemoryAssign): Statement = { - program ++= "MemoryAssign(" + override def visitMemoryStore(node: MemoryStore): Statement = { + program ++= "MemoryStore(" visitMemory(node.mem) program ++= "[" visitExpr(node.index) @@ -53,6 +53,17 @@ private class ILSerialiser extends ReadOnlyVisitor { node } + override def visitMemoryLoad(node: MemoryLoad): Statement = { + program ++= "MemoryLoad(" + visitVariable(node.lhs) + program ++= " := " + visitMemory(node.mem) + program ++= ", [" + visitExpr(node.index) + program ++= "])" + node + } + override def visitAssert(node: Assert): Statement = { program ++= "Assert(" visitExpr(node.body) @@ -63,14 +74,13 @@ private class ILSerialiser extends 
ReadOnlyVisitor { override def visitJump(node: Jump): Jump = { node match { case j: GoTo => program ++= s"goTo(${j.targets.map(_.label).mkString(", ")})" - case h: Unreachable => program ++= "halt" - case h: Return => program ++= "return" + case _: Unreachable => program ++= "halt" + case _: Return => program ++= "return" } node } - override def visitGoTo(node: GoTo): GoTo = { program ++= "GoTo(" program ++= node.targets.map(blockIdentifier).mkString(", ") @@ -78,7 +88,6 @@ private class ILSerialiser extends ReadOnlyVisitor { node } - override def visitDirectCall(node: DirectCall): Statement = { program ++= "DirectCall(" program ++= procedureIdentifier(node.target) @@ -213,15 +222,6 @@ private class ILSerialiser extends ReadOnlyVisitor { node } - override def visitMemoryLoad(node: MemoryLoad): Expr = { - program ++= "MemoryLoad(" - visitMemory(node.mem) - program ++= ", [" - visitExpr(node.index) - program ++= "])" - node - } - override def visitMemory(node: Memory): Memory = { program ++= "Memory(" program ++= s"\"${node.name}\", ${node.addressSize}, ${node.valueSize})" diff --git a/src/main/scala/translating/IRToBoogie.scala b/src/main/scala/translating/IRToBoogie.scala index 1c917d038..ec9049334 100644 --- a/src/main/scala/translating/IRToBoogie.scala +++ b/src/main/scala/translating/IRToBoogie.scala @@ -733,15 +733,15 @@ class IRToBoogie(var program: Program, var spec: Specification, var thread: Opti def translate(s: Statement): List[BCmd] = s match { case d: Call => translate(d) case _: NOP => List.empty - case m: MemoryAssign => + case m: MemoryStore => val lhs = m.mem.toBoogie val rhs = BMemoryStore(m.mem.toBoogie, m.index.toBoogie, m.value.toBoogie, m.endian, m.size) val lhsGamma = m.mem.toGamma val rhsGamma = GammaStore(m.mem.toGamma, m.index.toBoogie, exprToGamma(m.value), m.size, m.size / m.mem.valueSize) val store = AssignCmd(List(lhs, lhsGamma), List(rhs, rhsGamma)) val stateSplit = s match { - case MemoryAssign(_, _, _, _, _, Some(label)) => 
List(captureStateStatement(s"$label")) - case Assign(_, _, Some(label)) => List(captureStateStatement(s"$label")) + case MemoryStore(_, _, _, _, _, Some(label)) => List(captureStateStatement(s"$label")) + case LocalAssign(_, _, Some(label)) => List(captureStateStatement(s"$label")) case _ => List.empty } m.mem match { @@ -801,22 +801,29 @@ class IRToBoogie(var program: Program, var spec: Specification, var thread: Opti } (List(rely, gammaValueCheck) ++ oldAssigns ++ oldGammaAssigns :+ store) ++ secureUpdate ++ guaranteeChecks ++ stateSplit } - case l: Assign => + case l: LocalAssign => val lhs = l.lhs.toBoogie val rhs = l.rhs.toBoogie val lhsGamma = l.lhs.toGamma val rhsGamma = exprToGamma(l.rhs) - val assign = AssignCmd(List(lhs, lhsGamma), List(rhs, rhsGamma)) - val loads = l.rhs.loads - if (loads.size > 1) { - throw Exception(s"$l contains multiple loads") + List(AssignCmd(List(lhs, lhsGamma), List(rhs, rhsGamma))) + case m: MemoryLoad => + val lhs = m.lhs.toBoogie + val lhsGamma = m.lhs.toGamma + val rhs = BMemoryLoad(m.mem.toBoogie, m.index.toBoogie, m.endian, m.size) + val rhsGamma = m.mem match { + case s: StackMemory => + GammaLoad(s.toGamma, m.index.toBoogie, m.size, m.size / s.valueSize) + case s: SharedMemory => + val boogieIndex = m.index.toBoogie + BinaryBExpr(BoolOR, GammaLoad(s.toGamma, boogieIndex, m.size, m.size / s.valueSize), L(LArgs, boogieIndex)) } - // add rely call if assignment contains a non-stack load - loads.headOption match { - case Some(MemoryLoad(SharedMemory(_, _, _), _, _, _)) => + val assign = AssignCmd(List(lhs, lhsGamma), List(rhs, rhsGamma)) + // add rely call if is a non-stack load + m.mem match { + case _: SharedMemory => List(BProcedureCall("rely"), assign) case _ => - // load is a stack load or doesn't exist List(assign) } case a: Assert => @@ -828,7 +835,7 @@ class IRToBoogie(var program: Program, var spec: Specification, var thread: Opti } def exprToGamma(e: Expr): BExpr = { - val gammaVars: Set[BExpr] = 
e.gammas.map(_.toGamma) ++ e.loads.map(_.toGamma(LArgs)) + val gammaVars: Set[BExpr] = e.gammas.map(_.toGamma) if (gammaVars.isEmpty) { TrueBLiteral } else if (gammaVars.size == 1) { diff --git a/src/main/scala/util/RunUtils.scala b/src/main/scala/util/RunUtils.scala index 0dec78187..24b5c9741 100644 --- a/src/main/scala/util/RunUtils.scala +++ b/src/main/scala/util/RunUtils.scala @@ -17,7 +17,7 @@ import ir.* import boogie.* import specification.* import Parsers.* -import Parsers.SemanticsParser.* +import Parsers.ASLpParser.* import analysis.data_structure_analysis.{DataStructureAnalysis, Graph, SymbolicAddress, SymbolicAddressAnalysis} import org.antlr.v4.runtime.tree.ParseTreeWalker import org.antlr.v4.runtime.BailErrorStrategy @@ -52,21 +52,21 @@ case class IRContext( /** Stores the results of the static analyses. */ case class StaticAnalysisContext( - constPropResult: Map[CFGPosition, Map[Variable, FlatElement[BitVecLiteral]]], - IRconstPropResult: Map[CFGPosition, Map[Variable, FlatElement[BitVecLiteral]]], - memoryRegionResult: Map[CFGPosition, Set[StackRegion]], - vsaResult: Map[CFGPosition, LiftedElement[Map[Variable | MemoryRegion, Set[Value]]]], - interLiveVarsResults: Map[CFGPosition, Map[Variable, TwoElement]], - paramResults: Map[Procedure, Set[Variable]], - steensgaardResults: Map[RegisterWrapperEqualSets, Set[RegisterWrapperEqualSets | MemoryRegion]], - mmmResults: MemoryModelMap, - reachingDefs: Map[CFGPosition, (Map[Variable, Set[Assign]], Map[Variable, Set[Assign]])], - varDepsSummaries: Map[Procedure, Map[Taintable, Set[Taintable]]], - regionInjector: Option[RegionInjector], - symbolicAddresses: Map[CFGPosition, Map[SymbolicAddress, TwoElement]], - localDSA: Map[Procedure, Graph], - bottomUpDSA: Map[Procedure, Graph], - topDownDSA: Map[Procedure, Graph] + intraProcConstProp: Map[CFGPosition, Map[Variable, FlatElement[BitVecLiteral]]], + interProcConstProp: Map[CFGPosition, Map[Variable, FlatElement[BitVecLiteral]]], + memoryRegionResult: 
Map[CFGPosition, Set[StackRegion]], + vsaResult: Map[CFGPosition, LiftedElement[Map[Variable | MemoryRegion, Set[Value]]]], + interLiveVarsResults: Map[CFGPosition, Map[Variable, TwoElement]], + paramResults: Map[Procedure, Set[Variable]], + steensgaardResults: Map[RegisterWrapperEqualSets, Set[RegisterWrapperEqualSets | MemoryRegion]], + mmmResults: MemoryModelMap, + reachingDefs: Map[CFGPosition, (Map[Variable, Set[Assign]], Map[Variable, Set[Assign]])], + varDepsSummaries: Map[Procedure, Map[Taintable, Set[Taintable]]], + regionInjector: Option[RegionInjector], + symbolicAddresses: Map[CFGPosition, Map[SymbolicAddress, TwoElement]], + localDSA: Map[Procedure, Graph], + bottomUpDSA: Map[Procedure, Graph], + topDownDSA: Map[Procedure, Graph] ) /** Results of the main program execution. @@ -123,9 +123,9 @@ object IRLoading { val semantics = mods.map(_.auxData("ast").data.toStringUtf8.parseJson.convertTo[Map[String, Array[Array[String]]]]) def parse_insn(line: String): StmtContext = { - val semanticsLexer = SemanticsLexer(CharStreams.fromString(line)) - val tokens = CommonTokenStream(semanticsLexer) - val parser = SemanticsParser(tokens) + val lexer = ASLpLexer(CharStreams.fromString(line)) + val tokens = CommonTokenStream(lexer) + val parser = ASLpParser(tokens) parser.setErrorHandler(BailErrorStrategy()) parser.setBuildParseTree(true) @@ -342,40 +342,40 @@ object StaticAnalysis { val RNASolver = RNAAnalysisSolver(IRProgram) val RNAResult = RNASolver.analyze() - Logger.debug("[!] Running Constant Propagation") - val constPropSolver = ConstantPropagationSolver(IRProgram) - val constPropResult: Map[CFGPosition, Map[Variable, FlatElement[BitVecLiteral]]] = constPropSolver.analyze() + Logger.debug("[!] 
Running Inter-procedural Constant Propagation") + val interProcConstProp = InterProcConstantPropagation(IRProgram) + val interProcConstPropResult: Map[CFGPosition, Map[Variable, FlatElement[BitVecLiteral]]] = interProcConstProp.analyze() - config.analysisResultsPath.foreach(s => - writeToFile(printAnalysisResults(IRProgram, constPropResult), s"${s}OGconstprop$iteration.txt") - ) + config.analysisResultsPath.foreach { s => + writeToFile(printAnalysisResults(IRProgram, interProcConstPropResult), s"${s}OGconstprop$iteration.txt") + } Logger.debug("[!] Variable dependency summaries") val scc = stronglyConnectedComponents(CallGraph, List(IRProgram.mainProcedure)) val specGlobalAddresses = ctx.specification.globals.map(s => s.address -> s.name).toMap - val varDepsSummaries = VariableDependencyAnalysis(IRProgram, ctx.specification.globals, specGlobalAddresses, constPropResult, scc).analyze() + val varDepsSummaries = VariableDependencyAnalysis(IRProgram, ctx.specification.globals, specGlobalAddresses, interProcConstPropResult, scc).analyze() - val ilcpsolver = IRSimpleValueAnalysis.Solver(IRProgram) - val newCPResult: Map[CFGPosition, Map[Variable, FlatElement[BitVecLiteral]]] = ilcpsolver.analyze() + val intraProcConstProp = IntraProcConstantPropagation(IRProgram) + val intraProcConstPropResult: Map[CFGPosition, Map[Variable, FlatElement[BitVecLiteral]]] = intraProcConstProp.analyze() - config.analysisResultsPath.foreach(s => - writeToFile(printAnalysisResults(IRProgram, newCPResult), s"${s}_new_ir_constprop$iteration.txt") - ) + config.analysisResultsPath.foreach { s => + writeToFile(printAnalysisResults(IRProgram, intraProcConstPropResult), s"${s}_new_ir_constprop$iteration.txt") + } - config.analysisDotPath.foreach(f => { + config.analysisDotPath.foreach { f => val dumpdomain = computeDomain[CFGPosition, CFGPosition](InterProcIRCursor, IRProgram.procedures) writeToFile(toDot(dumpdomain, InterProcIRCursor, Map.empty), s"${f}_new_ir_intercfg$iteration.dot") - }) + } val 
reachingDefinitionsAnalysisSolver = InterprocReachingDefinitionsAnalysisSolver(IRProgram) val reachingDefinitionsAnalysisResults = reachingDefinitionsAnalysisSolver.analyze() - config.analysisDotPath.foreach(s => { + config.analysisDotPath.foreach { s => writeToFile( toDot(IRProgram, IRProgram.filter(_.isInstanceOf[Command]).map(b => b -> reachingDefinitionsAnalysisResults(b).toString).toMap, true), s"${s}_reachingDefinitions$iteration.dot" ) - }) + } val mmm = MemoryModelMap(globalOffsets) mmm.preLoadGlobals(mergedSubroutines, globalAddresses, globalSizes) @@ -387,14 +387,14 @@ object StaticAnalysis { } Logger.debug("[!] Running GRA") - val graSolver = GlobalRegionAnalysisSolver(IRProgram, domain.toSet, constPropResult, reachingDefinitionsAnalysisResults, mmm, previousVSAResults) + val graSolver = GlobalRegionAnalysisSolver(IRProgram, domain.toSet, interProcConstPropResult, reachingDefinitionsAnalysisResults, mmm, previousVSAResults) val graResult = graSolver.analyze() Logger.debug("[!] 
Running MRA") - val mraSolver = MemoryRegionAnalysisSolver(IRProgram, domain.toSet, globalAddresses, globalOffsets, mergedSubroutines, constPropResult, ANRResult, RNAResult, reachingDefinitionsAnalysisResults, graResult, mmm) + val mraSolver = MemoryRegionAnalysisSolver(IRProgram, domain.toSet, globalAddresses, globalOffsets, mergedSubroutines, interProcConstPropResult, ANRResult, RNAResult, reachingDefinitionsAnalysisResults, graResult, mmm) val mraResult = mraSolver.analyze() - config.analysisDotPath.foreach(s => { + config.analysisDotPath.foreach { s => writeToFile(dotCallGraph(IRProgram), s"${s}_callgraph$iteration.dot") writeToFile( dotBlockGraph(IRProgram, IRProgram.filter(_.isInstanceOf[Block]).map(b => b -> b.toString).toMap), @@ -402,7 +402,7 @@ object StaticAnalysis { ) writeToFile( - toDot(IRProgram, IRProgram.filter(_.isInstanceOf[Command]).map(b => b -> newCPResult(b).toString).toMap), + toDot(IRProgram, IRProgram.filter(_.isInstanceOf[Command]).map(b => b -> intraProcConstPropResult(b).toString).toMap), s"${s}_new_ir_constprop$iteration.dot" ) @@ -415,7 +415,7 @@ object StaticAnalysis { toDot(IRProgram, IRProgram.filter(_.isInstanceOf[Command]).map(b => b -> graResult(b).toString).toMap), s"${s}_GRA$iteration.dot" ) - }) + } Logger.debug("[!] Running MMM") mmm.convertMemoryRegions(mraSolver.procedureToStackRegions, mraSolver.procedureToHeapRegions, mraResult, mraSolver.procedureToSharedRegions, graSolver.getDataMap, graResult) @@ -427,15 +427,15 @@ object StaticAnalysis { val steensgaardResults = steensgaardSolver.pointsTo() Logger.debug("[!] 
Running VSA") - val vsaSolver = ValueSetAnalysisSolver(IRProgram, mmm, constPropResult) + val vsaSolver = ValueSetAnalysisSolver(IRProgram, mmm, interProcConstPropResult) val vsaResult: Map[CFGPosition, LiftedElement[Map[Variable | MemoryRegion, Set[Value]]]] = vsaSolver.analyze() - config.analysisDotPath.foreach(s => { + config.analysisDotPath.foreach { s => writeToFile( toDot(IRProgram, IRProgram.filter(_.isInstanceOf[Command]).map(b => b -> vsaResult(b).toString).toMap), s"${s}_VSA$iteration.dot" ) - }) + } Logger.debug("[!] Injecting regions") val regionInjector = if (config.memoryRegions) { @@ -450,8 +450,8 @@ object StaticAnalysis { val interLiveVarsResults: Map[CFGPosition, Map[Variable, TwoElement]] = InterLiveVarsAnalysis(IRProgram).analyze() StaticAnalysisContext( - constPropResult = constPropResult, - IRconstPropResult = newCPResult, + intraProcConstProp = intraProcConstPropResult, + interProcConstProp = interProcConstPropResult, memoryRegionResult = mraResult, vsaResult = vsaResult, interLiveVarsResults = interLiveVarsResults, @@ -595,7 +595,7 @@ object RunUtils { Logger.debug("[!] Generating Procedure Summaries") if (config.summariseProcedures) { - IRTransform.generateProcedureSummaries(ctx, ctx.program, result.constPropResult, result.varDepsSummaries) + IRTransform.generateProcedureSummaries(ctx, ctx.program, result.interProcConstProp, result.varDepsSummaries) } if (modified) { @@ -619,7 +619,7 @@ object RunUtils { Logger.debug("[!] Running Symbolic Access Analysis") val symResults: Map[CFGPosition, Map[SymbolicAddress, TwoElement]] = - SymbolicAddressAnalysis(ctx.program, analysisResult.last.IRconstPropResult).analyze() + SymbolicAddressAnalysis(ctx.program, analysisResult.last.intraProcConstProp).analyze() config.analysisDotPath.foreach { s => val labels = symResults.map { (k, v) => k -> v.toString } writeToFile(toDot(ctx.program, labels), s"${s}_saa.dot") @@ -627,7 +627,7 @@ object RunUtils { Logger.debug("[!]
Running DSA Analysis") val symbolTableEntries: Set[SymbolTableEntry] = ctx.globals ++ ctx.funcEntries - val dsa = DataStructureAnalysis(ctx.program, symResults, analysisResult.last.IRconstPropResult, symbolTableEntries, ctx.globalOffsets, ctx.externalFunctions, reachingDefs, writesTo, analysisResult.last.paramResults) + val dsa = DataStructureAnalysis(ctx.program, symResults, analysisResult.last.intraProcConstProp, symbolTableEntries, ctx.globalOffsets, ctx.externalFunctions, reachingDefs, writesTo, analysisResult.last.paramResults) dsa.analyze() config.analysisDotPath.foreach { s => diff --git a/src/test/scala/DataStructureAnalysisTest.scala b/src/test/scala/DataStructureAnalysisTest.scala index bdedf8cfe..80d848821 100644 --- a/src/test/scala/DataStructureAnalysisTest.scala +++ b/src/test/scala/DataStructureAnalysisTest.scala @@ -246,15 +246,15 @@ class DataStructureAnalysisTest extends AnyFunSuite { test("internal merge") { // this is an internal merge (two cells of the same node overlap and are merged together) val mem = SharedMemory("mem", 64, 8) - val locAssign1 = Assign(R6, BinaryExpr(BVADD, R0, BitVecLiteral(4, 64)), Some("00001")) - val locAssign2 = Assign(R7, BinaryExpr(BVADD, R0, BitVecLiteral(5, 64)), Some("00002")) + val locAssign1 = LocalAssign(R6, BinaryExpr(BVADD, R0, BitVecLiteral(4, 64)), Some("00001")) + val locAssign2 = LocalAssign(R7, BinaryExpr(BVADD, R0, BitVecLiteral(5, 64)), Some("00002")) val program = prog( proc("main", block("operations", locAssign1, // R6 = R0 + 4 locAssign2, // R7 = R0 + 5 - MemoryAssign(mem, R7, R1, Endian.BigEndian, 64, Some("00003")), // *R7 = R1, (*R6 + 1) = R1 - MemoryAssign(mem, R6, R2, Endian.BigEndian, 64, Some("00004")), // *R6 = R2 + MemoryStore(mem, R7, R1, Endian.BigEndian, 64, Some("00003")), // *R7 = R1, (*R6 + 1) = R1 + MemoryStore(mem, R6, R2, Endian.BigEndian, 64, Some("00004")), // *R6 = R2 ret ) ) @@ -282,17 +282,17 @@ class DataStructureAnalysisTest extends AnyFunSuite { test("offsetting from middle
of cell to a new cell") { val mem = SharedMemory("mem", 64, 8) - val locAssign1 = Assign(R6, BinaryExpr(BVADD, R0, BitVecLiteral(4, 64)), Some("00001")) - val locAssign2 = Assign(R7, BinaryExpr(BVADD, R0, BitVecLiteral(5, 64)), Some("00002")) - val locAssign3 = Assign(R5, BinaryExpr(BVADD, R7, BitVecLiteral(8, 64)), Some("00005")) + val locAssign1 = LocalAssign(R6, BinaryExpr(BVADD, R0, BitVecLiteral(4, 64)), Some("00001")) + val locAssign2 = LocalAssign(R7, BinaryExpr(BVADD, R0, BitVecLiteral(5, 64)), Some("00002")) + val locAssign3 = LocalAssign(R5, BinaryExpr(BVADD, R7, BitVecLiteral(8, 64)), Some("00005")) val program = prog( proc("main", block("operations", locAssign1, // R6 = R0 + 4 locAssign2, // R7 = R0 + 5 - MemoryAssign(mem, R7, R1, Endian.BigEndian, 64, Some("00003")), - MemoryAssign(mem, R6, R2, Endian.BigEndian, 64, Some("00004")), + MemoryStore(mem, R7, R1, Endian.BigEndian, 64, Some("00003")), + MemoryStore(mem, R6, R2, Endian.BigEndian, 64, Some("00004")), locAssign3, // R5 = R7 + 8 ret ) @@ -309,17 +309,17 @@ class DataStructureAnalysisTest extends AnyFunSuite { // similar to above except instead of creating new cell the last assign // points R5's cell at an internal offset of 8 val mem = SharedMemory("mem", 64, 8) - val locAssign1 = Assign(R6, BinaryExpr(BVADD, R0, BitVecLiteral(4, 64)), Some("00001")) - val locAssign2 = Assign(R7, BinaryExpr(BVADD, R0, BitVecLiteral(5, 64)), Some("00002")) - val locAssign3 = Assign(R5, BinaryExpr(BVADD, R7, BitVecLiteral(7, 64)), Some("00005")) + val locAssign1 = LocalAssign(R6, BinaryExpr(BVADD, R0, BitVecLiteral(4, 64)), Some("00001")) + val locAssign2 = LocalAssign(R7, BinaryExpr(BVADD, R0, BitVecLiteral(5, 64)), Some("00002")) + val locAssign3 = LocalAssign(R5, BinaryExpr(BVADD, R7, BitVecLiteral(7, 64)), Some("00005")) val program = prog( proc("main", block("operations", locAssign1, locAssign2, - MemoryAssign(mem, R7, R1, Endian.BigEndian, 64, Some("00003")), - MemoryAssign(mem, R6, R2, Endian.BigEndian, 64, 
Some("00004")), + MemoryStore(mem, R7, R1, Endian.BigEndian, 64, Some("00003")), + MemoryStore(mem, R6, R2, Endian.BigEndian, 64, Some("00004")), locAssign3, ret ) @@ -341,9 +341,9 @@ class DataStructureAnalysisTest extends AnyFunSuite { test("internal offset transfer") { // this is a test to check assignments transfer internal offset of slices. val mem = SharedMemory("mem", 64, 8) - val locAssign1 = Assign(R6, BinaryExpr(BVADD, R0, BitVecLiteral(4, 64)), Some("00001")) - val locAssign2 = Assign(R7, BinaryExpr(BVADD, R0, BitVecLiteral(5, 64)), Some("00002")) - val locAssign3 = Assign(R5, R7, Some("00005")) + val locAssign1 = LocalAssign(R6, BinaryExpr(BVADD, R0, BitVecLiteral(4, 64)), Some("00001")) + val locAssign2 = LocalAssign(R7, BinaryExpr(BVADD, R0, BitVecLiteral(5, 64)), Some("00002")) + val locAssign3 = LocalAssign(R5, R7, Some("00005")) val program = prog( proc("main", @@ -351,8 +351,8 @@ class DataStructureAnalysisTest extends AnyFunSuite { // Assign(R0, MemoryLoad(mem, R0, BigEndian, 0), Some("00000")), locAssign1, locAssign2, - MemoryAssign(mem, R7, R1, Endian.BigEndian, 64, Some("00003")), - MemoryAssign(mem, R6, R2, Endian.BigEndian, 64, Some("00004")), + MemoryStore(mem, R7, R1, Endian.BigEndian, 64, Some("00003")), + MemoryStore(mem, R6, R2, Endian.BigEndian, 64, Some("00004")), locAssign3, ret ) diff --git a/src/test/scala/LiveVarsAnalysisTests.scala b/src/test/scala/LiveVarsAnalysisTests.scala index 762fd395b..bd7a5ed8e 100644 --- a/src/test/scala/LiveVarsAnalysisTests.scala +++ b/src/test/scala/LiveVarsAnalysisTests.scala @@ -1,6 +1,6 @@ import analysis.{InterLiveVarsAnalysis, TwoElementTop} import ir.dsl.* -import ir.{BitVecLiteral, BitVecType, dsl, Assign, LocalVar, Program, Register, Statement, Variable, transforms, cilvisitor, Procedure} +import ir.{BitVecLiteral, BitVecType, dsl, LocalAssign, LocalVar, Program, Register, Statement, Variable, transforms, cilvisitor, Procedure} import util.{Logger, LogLevel} import 
org.scalatest.funsuite.AnyFunSuite import test_util.BASILTest @@ -30,10 +30,10 @@ class LiveVarsAnalysisTests extends AnyFunSuite, BASILTest { def differentCalleesBothLive(): Unit = { val constant1 = bv64(1) - val r0ConstantAssign = Assign(R0, constant1, Some("00001")) - val r1ConstantAssign = Assign(R1, constant1, Some("00002")) - val r2r0Assign = Assign(R2, R0, Some("00003")) - val r2r1Assign = Assign(R2, R1, Some("00004")) + val r0ConstantAssign = LocalAssign(R0, constant1, Some("00001")) + val r1ConstantAssign = LocalAssign(R1, constant1, Some("00002")) + val r2r0Assign = LocalAssign(R2, R0, Some("00003")) + val r2r1Assign = LocalAssign(R2, R1, Some("00004")) val program: Program = prog( proc("main", @@ -70,11 +70,11 @@ class LiveVarsAnalysisTests extends AnyFunSuite, BASILTest { def differentCalleesOneAlive(): Unit = { val constant1 = bv64(1) - val r0ConstantAssign = Assign(R0, constant1, Some("00001")) - val r1ConstantAssign = Assign(R1, constant1, Some("00002")) - val r2r0Assign = Assign(R2, R0, Some("00003")) - val r2r1Assign = Assign(R2, R1, Some("00004")) - val r1Reassign = Assign(R1, BitVecLiteral(2, 64), Some("00005")) + val r0ConstantAssign = LocalAssign(R0, constant1, Some("00001")) + val r1ConstantAssign = LocalAssign(R1, constant1, Some("00002")) + val r2r0Assign = LocalAssign(R2, R0, Some("00003")) + val r2r1Assign = LocalAssign(R2, R1, Some("00004")) + val r1Reassign = LocalAssign(R1, BitVecLiteral(2, 64), Some("00005")) val program: Program = prog( proc("main", @@ -108,9 +108,9 @@ class LiveVarsAnalysisTests extends AnyFunSuite, BASILTest { def twoCallers(): Unit = { val constant1 = bv64(1) - val r0ConstantAssign = Assign(R0, constant1, Some("00001")) - val r1Assign = Assign(R0, R1, Some("00002")) - val r2Assign = Assign(R0, R2, Some("00003")) + val r0ConstantAssign = LocalAssign(R0, constant1, Some("00001")) + val r1Assign = LocalAssign(R0, R1, Some("00002")) + val r2Assign = LocalAssign(R0, R2, Some("00003")) val program = prog( proc("main", @@ 
-129,7 +129,7 @@ class LiveVarsAnalysisTests extends AnyFunSuite, BASILTest { createSimpleProc("callee3", Seq(r2Assign)), proc("wrapper1", block("wrapper1_first_call", - Assign(R1, constant1), + LocalAssign(R1, constant1), directCall("callee"), goto("wrapper1_second_call") ), @@ -140,7 +140,7 @@ class LiveVarsAnalysisTests extends AnyFunSuite, BASILTest { ), proc("wrapper2", block("wrapper2_first_call", - Assign(R2, constant1), + LocalAssign(R2, constant1), directCall("callee"), goto("wrapper2_second_call") ), block("wrapper2_second_call", @@ -167,11 +167,11 @@ class LiveVarsAnalysisTests extends AnyFunSuite, BASILTest { directCall("killer"), goto("aftercall") ), block("aftercall", - Assign(R0, R1), + LocalAssign(R0, R1), ret ) ), - createSimpleProc("killer", Seq(Assign(R1, bv64(1)))) + createSimpleProc("killer", Seq(LocalAssign(R1, bv64(1)))) ) cilvisitor.visit_prog(transforms.ReplaceReturns(), program) @@ -186,8 +186,8 @@ class LiveVarsAnalysisTests extends AnyFunSuite, BASILTest { } def simpleBranch(): Unit = { - val r1Assign = Assign(R0, R1, Some("00001")) - val r2Assign = Assign(R0, R2, Some("00002")) + val r1Assign = LocalAssign(R0, R1, Some("00001")) + val r2Assign = LocalAssign(R0, R2, Some("00002")) val program : Program = prog( proc( @@ -228,11 +228,11 @@ class LiveVarsAnalysisTests extends AnyFunSuite, BASILTest { proc("main", block( "lmain", - Assign(R0, R1), + LocalAssign(R0, R1), directCall("main"), goto("return") ), block("return", - Assign(R0, R2), + LocalAssign(R0, R2), ret ) ) @@ -251,7 +251,7 @@ class LiveVarsAnalysisTests extends AnyFunSuite, BASILTest { val program: Program = prog( proc("main", block("lmain", - Assign(R0, R1), + LocalAssign(R0, R1), goto("recursion", "non-recursion") ), block( @@ -259,7 +259,7 @@ class LiveVarsAnalysisTests extends AnyFunSuite, BASILTest { directCall("main"), goto("assign") ), block("assign", - Assign(R0, R2), + LocalAssign(R0, R2), goto("return") ), block( diff --git a/src/test/scala/PointsToTest.scala 
b/src/test/scala/PointsToTest.scala index d4b148d41..4b267d7e8 100644 --- a/src/test/scala/PointsToTest.scala +++ b/src/test/scala/PointsToTest.scala @@ -30,11 +30,11 @@ class PointsToTest extends AnyFunSuite with OneInstancePerTest { var program: Program = prog( proc("main", block("0x0", - Assign(R6, R31), + LocalAssign(R6, R31), goto("0x1") ), block("0x1", - MemoryAssign(mem, BinaryExpr(BVADD, R6, bv64(4)), bv64(10), LittleEndian, 64), + MemoryStore(mem, BinaryExpr(BVADD, R6, bv64(4)), bv64(10), LittleEndian, 64), goto("returntarget") ), block("returntarget", @@ -60,8 +60,8 @@ class PointsToTest extends AnyFunSuite with OneInstancePerTest { var program: Program = prog( proc("main", block("0x0", - Assign(R1, MemoryLoad(mem, BinaryExpr(BVADD, R31, bv64(6)), LittleEndian, 64)), - Assign(R3, MemoryLoad(mem, BinaryExpr(BVADD, R31, bv64(4)), LittleEndian, 64)), + MemoryLoad(R1, mem, BinaryExpr(BVADD, R31, bv64(6)), LittleEndian, 64), + MemoryLoad(R3, mem, BinaryExpr(BVADD, R31, bv64(4)), LittleEndian, 64), goto("0x1") ), block("0x1", @@ -135,8 +135,8 @@ class PointsToTest extends AnyFunSuite with OneInstancePerTest { val program: Program = prog( proc("main", block("0x0", - Assign(R0, MemoryLoad(mem, BinaryExpr(BVADD, R31, bv64(6)), LittleEndian, 64)), - Assign(R1, BinaryExpr(BVADD, R31, bv64(10))), + MemoryLoad(R0, mem, BinaryExpr(BVADD, R31, bv64(6)), LittleEndian, 64), + LocalAssign(R1, BinaryExpr(BVADD, R31, bv64(10))), goto("0x1") ), block("0x1", @@ -148,8 +148,8 @@ class PointsToTest extends AnyFunSuite with OneInstancePerTest { ), proc("p2", block("l_p2", - Assign(R3, R0), - Assign(R2, MemoryLoad(mem, R1, LittleEndian, 64)), + LocalAssign(R3, R0), + MemoryLoad(R2, mem, R1, LittleEndian, 64), goto("l_p2_1"), ), block("l_p2_1", @@ -184,8 +184,8 @@ class PointsToTest extends AnyFunSuite with OneInstancePerTest { val program: Program = prog( proc("main", block("0x0", - Assign(R0, MemoryLoad(mem, BinaryExpr(BVADD, R31, bv64(6)), LittleEndian, 64)), - Assign(R1, 
BinaryExpr(BVADD, R31, bv64(10))), + MemoryLoad(R0, mem, BinaryExpr(BVADD, R31, bv64(6)), LittleEndian, 64), + LocalAssign(R1, BinaryExpr(BVADD, R31, bv64(10))), goto("0x1") ), block("0x1", @@ -197,8 +197,8 @@ class PointsToTest extends AnyFunSuite with OneInstancePerTest { ), proc("foo", block("l_foo", - Assign(R0, MemoryLoad(mem, BinaryExpr(BVADD, R31, bv64(6)), LittleEndian, 64)), - Assign(R1, BinaryExpr(BVADD, R31, bv64(10))), + MemoryLoad(R0, mem, BinaryExpr(BVADD, R31, bv64(6)), LittleEndian, 64), + LocalAssign(R1, BinaryExpr(BVADD, R31, bv64(10))), directCall("p2"), goto("l_foo_1") ), block("l_foo_1", @@ -207,8 +207,8 @@ class PointsToTest extends AnyFunSuite with OneInstancePerTest { ), proc("p2", block("l_p2", - Assign(R3, R0), - Assign(R2, MemoryLoad(mem, R1, LittleEndian, 64)), + LocalAssign(R3, R0), + MemoryLoad(R2, mem, R1, LittleEndian, 64), goto("l_p2_1"), ), block("l_p2_1", diff --git a/src/test/scala/TaintAnalysisTests.scala b/src/test/scala/TaintAnalysisTests.scala index 35e7da1ab..2ffeb3215 100644 --- a/src/test/scala/TaintAnalysisTests.scala +++ b/src/test/scala/TaintAnalysisTests.scala @@ -7,12 +7,12 @@ import test_util.BASILTest class TaintAnalysisTests extends AnyFunSuite, BASILTest { def getTaintAnalysisResults(program: Program, taint: Map[CFGPosition, Set[Taintable]]): Map[CFGPosition, Set[Taintable]] = { - val constPropResults = ConstantPropagationSolver(program).analyze() + val constPropResults = InterProcConstantPropagation(program).analyze() TaintAnalysis(program, Map(), constPropResults, taint).analyze().map { (c, m) => (c, m.map { (v, _) => v }.toSet)} } def getVarDepResults(program: Program, procedure: Procedure): Map[CFGPosition, Map[Taintable, Set[Taintable]]] = { - val constPropResults = ConstantPropagationSolver(program).analyze() + val constPropResults = InterProcConstantPropagation(program).analyze() val variables = registers ProcVariableDependencyAnalysis(program, variables, Map(), constPropResults, Map(), procedure).analyze() 
} @@ -31,7 +31,7 @@ class TaintAnalysisTests extends AnyFunSuite, BASILTest { ), proc("f", block("assign", - Assign(R0, bv64(2), None), + LocalAssign(R0, bv64(2), None), goto("returnBlock"), ), block("returnBlock", @@ -65,7 +65,7 @@ class TaintAnalysisTests extends AnyFunSuite, BASILTest { ), proc("f", block("assign", - Assign(R0, BinaryExpr(BVADD, R0, R1), None), + LocalAssign(R0, BinaryExpr(BVADD, R0, R1), None), goto("returnBlock"), ), block("returnBlock", @@ -102,11 +102,11 @@ class TaintAnalysisTests extends AnyFunSuite, BASILTest { goto("a", "b"), ), block("a", - Assign(R0, R1, None), + LocalAssign(R0, R1, None), goto("returnBlock"), ), block("b", - Assign(R0, R2, None), + LocalAssign(R0, R2, None), goto("returnBlock"), ), block("returnBlock", @@ -143,12 +143,12 @@ class TaintAnalysisTests extends AnyFunSuite, BASILTest { goto("a", "b"), ), block("a", - Assign(R1, R1, None), + LocalAssign(R1, R1, None), directCall("g"), goto("returnBlock"), ), block("b", - Assign(R1, R2, None), + LocalAssign(R1, R2, None), directCall("g"), goto("returnBlock"), ), @@ -158,7 +158,7 @@ class TaintAnalysisTests extends AnyFunSuite, BASILTest { ), proc("g", block("body", - Assign(R0, R1, None), + LocalAssign(R0, R1, None), goto("returnBlock"), ), block("returnBlock", @@ -195,11 +195,11 @@ class TaintAnalysisTests extends AnyFunSuite, BASILTest { goto("a", "b"), ), block("a", - Assign(R0, BinaryExpr(BVADD, R0, R1), None), + LocalAssign(R0, BinaryExpr(BVADD, R0, R1), None), goto("branch"), ), block("b", - Assign(R0, R2, None), + LocalAssign(R0, R2, None), goto("returnBlock"), ), block("returnBlock", diff --git a/src/test/scala/ir/CILVisitorTest.scala b/src/test/scala/ir/CILVisitorTest.scala index f06528cbf..49eae1cb5 100644 --- a/src/test/scala/ir/CILVisitorTest.scala +++ b/src/test/scala/ir/CILVisitorTest.scala @@ -40,7 +40,7 @@ class AddGammas extends CILVisitor { override def vstmt(s: Statement) = { s match { - case a: Assign => ChangeTo(List(a, Assign(gamma_v(a.lhs), 
gamma_e(a.rhs)))) + case a: LocalAssign => ChangeTo(List(a, LocalAssign(gamma_v(a.lhs), gamma_e(a.rhs)))) case _ => SkipChildren() } @@ -82,10 +82,10 @@ class CILVisitorTest extends AnyFunSuite { val program: Program = prog( proc( "main", - block("0x0", Assign(getRegister("R6"), getRegister("R31")), goto("0x1")), + block("0x0", LocalAssign(getRegister("R6"), getRegister("R31")), goto("0x1")), block( "0x1", - MemoryAssign(mem, BinaryExpr(BVADD, getRegister("R6"), bv64(4)), bv64(10), Endian.LittleEndian, 64), + MemoryStore(mem, BinaryExpr(BVADD, getRegister("R6"), bv64(4)), bv64(10), Endian.LittleEndian, 64), goto("returntarget") ), block("returntarget", ret) @@ -123,10 +123,10 @@ class CILVisitorTest extends AnyFunSuite { val program: Program = prog( proc( "main", - block("0x0", Assign(getRegister("R6"), getRegister("R31")), goto("0x1")), + block("0x0", LocalAssign(getRegister("R6"), getRegister("R31")), goto("0x1")), block( "0x1", - MemoryAssign(mem, BinaryExpr(BVADD, getRegister("R6"), bv64(4)), bv64(10), Endian.LittleEndian, 64), + MemoryStore(mem, BinaryExpr(BVADD, getRegister("R6"), bv64(4)), bv64(10), Endian.LittleEndian, 64), goto("returntarget") ), block("returntarget", ret) diff --git a/src/test/scala/ir/IRTest.scala b/src/test/scala/ir/IRTest.scala index a1592f9eb..f156a83c0 100644 --- a/src/test/scala/ir/IRTest.scala +++ b/src/test/scala/ir/IRTest.scala @@ -95,12 +95,12 @@ class IRTest extends AnyFunSuite { val p = prog( proc("main", block("l_main", - Assign(R0, bv64(10)), - Assign(R1, bv64(10)), + LocalAssign(R0, bv64(10)), + LocalAssign(R1, bv64(10)), goto("newblock") ), block("l_main_1", - Assign(R0, bv64(22)), + LocalAssign(R0, bv64(22)), directCall("p2"), goto("returntarget") ), @@ -109,7 +109,7 @@ class IRTest extends AnyFunSuite { ) ), proc("p2", - block("l_p2", Assign(R0, bv64(10)), goto("l_p2_1")), + block("l_p2", LocalAssign(R0, bv64(10)), goto("l_p2_1")), block("l_p2_1", ret) ) ) @@ -154,15 +154,15 @@ class IRTest extends AnyFunSuite { ) val b2 
= block("newblock2", - Assign(R0, bv64(22)), - Assign(R0, bv64(22)), - Assign(R0, bv64(22)), + LocalAssign(R0, bv64(22)), + LocalAssign(R0, bv64(22)), + LocalAssign(R0, bv64(22)), goto("lmain2") ).resolve(p) val b1 = block("newblock1", - Assign(R0, bv64(22)), - Assign(R0, bv64(22)), - Assign(R0, bv64(22)), + LocalAssign(R0, bv64(22)), + LocalAssign(R0, bv64(22)), + LocalAssign(R0, bv64(22)), goto("lmain2") ).resolve(p) @@ -191,16 +191,16 @@ class IRTest extends AnyFunSuite { ) val b1 = block("newblock2", - Assign(R0, bv64(22)), - Assign(R0, bv64(22)), - Assign(R0, bv64(22)), + LocalAssign(R0, bv64(22)), + LocalAssign(R0, bv64(22)), + LocalAssign(R0, bv64(22)), directCall("main"), unreachable ).resolve(p) val b2 = block("newblock1", - Assign(R0, bv64(22)), - Assign(R0, bv64(22)), - Assign(R0, bv64(22)), + LocalAssign(R0, bv64(22)), + LocalAssign(R0, bv64(22)), + LocalAssign(R0, bv64(22)), ret ).resolve(p) @@ -218,7 +218,7 @@ class IRTest extends AnyFunSuite { assert(called.incomingCalls().isEmpty) val b3 = block("newblock3", - Assign(R0, bv64(22)), + LocalAssign(R0, bv64(22)), directCall("called"), unreachable ).resolve(p) @@ -254,8 +254,8 @@ class IRTest extends AnyFunSuite { val p = prog( proc("main", block("l_main", - Assign(R0, bv64(10)), - Assign(R1, bv64(10)), + LocalAssign(R0, bv64(10)), + LocalAssign(R1, bv64(10)), goto("returntarget") ), block("returntarget", @@ -285,13 +285,13 @@ class IRTest extends AnyFunSuite { val p = prog( proc("p1", block("b1", - Assign(R0, bv64(10)), + LocalAssign(R0, bv64(10)), ret ) ), proc("main", block("l_main", - Assign(R0, bv64(10)), + LocalAssign(R0, bv64(10)), directCall("p1"), goto("returntarget") ), block("returntarget", diff --git a/src/test/scala/ir/SingleCallInvariant.scala b/src/test/scala/ir/SingleCallInvariant.scala index d8efb6fc2..4e0061424 100644 --- a/src/test/scala/ir/SingleCallInvariant.scala +++ b/src/test/scala/ir/SingleCallInvariant.scala @@ -10,13 +10,13 @@ class InvariantTest extends AnyFunSuite { var 
program: Program = prog( proc("main", block("first_call", - Assign(R0, bv64(10)), - Assign(R1, bv64(10)), + LocalAssign(R0, bv64(10)), + LocalAssign(R1, bv64(10)), directCall("callee1"), ret ), block("second_call", - Assign(R0, bv64(10)), + LocalAssign(R0, bv64(10)), directCall("callee2"), ret ), @@ -35,14 +35,14 @@ class InvariantTest extends AnyFunSuite { var program: Program = prog( proc("main", block("first_call", - Assign(R0, bv64(10)), + LocalAssign(R0, bv64(10)), directCall("callee2"), - Assign(R1, bv64(10)), + LocalAssign(R1, bv64(10)), directCall("callee1"), ret ), block("second_call", - Assign(R0, bv64(10)), + LocalAssign(R0, bv64(10)), ret ), block("returnBlock", @@ -60,13 +60,13 @@ class InvariantTest extends AnyFunSuite { var program: Program = prog( proc("main", block("first_call", - Assign(R0, bv64(10)), - Assign(R1, bv64(10)), + LocalAssign(R0, bv64(10)), + LocalAssign(R1, bv64(10)), ret ), block("second_call", directCall("callee2"), - Assign(R0, bv64(10)), + LocalAssign(R0, bv64(10)), ret ), block("returnBlock",