diff --git a/src/main/scala/analysis/IntraLiveVarsAnalysis.scala b/src/main/scala/analysis/IntraLiveVarsAnalysis.scala index 29f29e00c..f17d91590 100644 --- a/src/main/scala/analysis/IntraLiveVarsAnalysis.scala +++ b/src/main/scala/analysis/IntraLiveVarsAnalysis.scala @@ -11,16 +11,15 @@ abstract class LivenessAnalysis(program: Program) extends Analysis[Any]: n match { case p: Procedure => s case b: Block => s - case Assign(variable, expr, _) => (s - variable) ++ expr.variables - case MemoryAssign(_, index, value, _, _, _) => s ++ index.variables ++ value.variables - case Assume(expr, _, _, _) => s ++ expr.variables - case Assert(expr, _, _) => s ++ expr.variables - case IndirectCall(variable, _) => s + variable - case c: DirectCall => s -- c.outParams.map(_._2) ++ c.actualParams.flatMap(_._2.variables) + case a : Assign => (s - a.lhs) ++ a.rhs.variables + case m: MemoryAssign => s ++ m.index.variables ++ m.value.variables + case a : Assume => s ++ a.body.variables + case a : Assert => s ++ a.body.variables + case i : IndirectCall => s + i.target + case c: DirectCall => (s -- c.outParams.map(_._2)) ++ c.actualParams.flatMap(_._2.variables) case g: GoTo => s case r: Return => s ++ r.outParams.flatMap(_._2.variables) case r: Unreachable => s - case _ => ??? } } diff --git a/src/main/scala/boogie/BExpr.scala b/src/main/scala/boogie/BExpr.scala index c00985361..3d3d00e52 100644 --- a/src/main/scala/boogie/BExpr.scala +++ b/src/main/scala/boogie/BExpr.scala @@ -359,7 +359,7 @@ case class BinaryBExpr(op: BinOp, arg1: BExpr, arg2: BExpr) extends BExpr { } else { throw new Exception(s"bitvector size mismatch: $arg1, $arg2") } - case BVULT | BVULE | BVUGT | BVUGE | BVSLT | BVSLE | BVSGT | BVSGE => + case BVULT | BVULE | BVUGT | BVUGE | BVSLT | BVSLE | BVSGT | BVSGE | BVSADDO => if (bv1.size == bv2.size) { BoolBType } else { diff --git a/src/main/scala/cfg_visualiser/DotTools.scala b/src/main/scala/cfg_visualiser/DotTools.scala index ab754d152..75f7b55ee 100644 --- a/src/main/scala/cfg_visualiser/DotTools.scala +++ b/src/main/scala/cfg_visualiser/DotTools.scala @@ -57,7 +57,7 @@ class DotNode(val id: String, val label: String) extends DotElement { override def toString: String = toDotString def toDotString: String = - s"\"$id\"" + "[label=\"" + wrap(label, 100) + "\", shape=\"box\", fontname=\"Mono\"]" + s"\"$id\"" + "[label=\"" + wrap(label, 100) + "\", shape=\"box\", fontname=\"Mono\", fontsize=\"5\"]" } @@ -140,5 +140,6 @@ class DotGraph(val title: String, val nodes: Iterable[DotNode], val edges: Itera override def toString: String = toDotString - def toDotString: String = "digraph " + title + " {\n" + (nodes ++ edges).foldLeft("")((str, elm) => str + elm.toDotString + "\n") + "}" + val graph = "graph [ fontsize=18 ];" + def toDotString: String = "digraph " + title + " {\n" + graph + "\n" + (nodes ++ edges).foldLeft("")((str, elm) => str + elm.toDotString + "\n") + "}" } diff --git a/src/main/scala/ir/Expr.scala b/src/main/scala/ir/Expr.scala index 3e09a7802..b572a52eb 100644 --- a/src/main/scala/ir/Expr.scala +++ b/src/main/scala/ir/Expr.scala @@ -21,6 +21,8 @@ sealed trait Expr { def gammas: Set[Expr] = Set() def variables: Set[Variable] = Set() def acceptVisit(visitor: Visitor): Expr = throw new Exception("visitor " + visitor + " unimplemented for: " + this) + + lazy val variablesCached = variables } @@ -195,7 +197,7 @@ case class BinaryExpr(op: BinOp, arg1: Expr, arg2: Expr) extends Expr { } else { throw new Exception("bitvector size mismatch") } - case BVULT | BVULE | BVUGT | BVUGE | BVSLT | BVSLE | BVSGT | BVSGE => + case BVULT | BVULE | BVUGT | BVUGE | BVSLT | BVSLE | BVSGT | BVSGE | BVSADDO => if (bv1.size == bv2.size) { BoolType } else { @@ -210,7 +212,7 @@ case class BinaryExpr(op: BinOp, arg1: Expr, arg2: Expr) extends Expr { case IntEQ | IntNEQ | IntLT | IntLE | IntGT | IntGE => BoolType } case _ => - throw new Exception("type mismatch, operator " + op + " type doesn't match args: (" + arg1 + ", " + arg2 + ")") + throw new Exception("type mismatch, operator " + op.getClass.getSimpleName + s" type doesn't match args: (" + arg1 + ", " + arg2 + ")") } private def inSize = arg1.getType match { @@ -255,6 +257,7 @@ sealed trait BVBinOp(op: String) extends BinOp { def opName = op } +case object BVSADDO extends BVBinOp("saddo") case object BVAND extends BVBinOp("and") case object BVOR extends BVBinOp("or") case object BVADD extends BVBinOp("add") diff --git a/src/main/scala/ir/IRCursor.scala b/src/main/scala/ir/IRCursor.scala index 21944d4f1..165e946f1 100644 --- a/src/main/scala/ir/IRCursor.scala +++ b/src/main/scala/ir/IRCursor.scala @@ -289,7 +289,7 @@ def toDot[T <: CFGPosition]( def nodeText(node: CFGPosition): String = { var text = node match { - case s: Block => f"[Block] ${s.label}" + case s: Block => f"[Block] (prec ${s.rpoOrder}) ${s.label}" case s => s.toString } if (labels.contains(node)) { diff --git a/src/main/scala/ir/Statement.scala b/src/main/scala/ir/Statement.scala index cb236a4fc..4fd94fac0 100644 --- a/src/main/scala/ir/Statement.scala +++ b/src/main/scala/ir/Statement.scala @@ -156,7 +156,7 @@ class DirectCall(val target: Procedure, case None => Set() } */ def calls: Set[Procedure] = Set(target) - override def toString: String = s"${labelStr}${outParams.mkString(",")} := DirectCall(${target.name})(${actualParams.mkString(",")})" + override def toString: String = s"${labelStr}${outParams.map(_._2.name).mkString(",")} := DirectCall(${target.name})(${actualParams.map(_._2).mkString(",")})" override def acceptVisit(visitor: Visitor): Statement = visitor.visitDirectCall(this) override def linkParent(p: Block): Unit = { diff --git a/src/main/scala/ir/cilvisitor/CILVisitor.scala b/src/main/scala/ir/cilvisitor/CILVisitor.scala index 89facf80e..a44672324 100644 --- a/src/main/scala/ir/cilvisitor/CILVisitor.scala +++ b/src/main/scala/ir/cilvisitor/CILVisitor.scala @@ -80,18 +80,46 @@ class CILVisitorImpl(val v: CILVisitor) { } - def visit_expr(n: Expr): Expr = { + + def visit_expr(n: Expr): Expr = { def continue(n: Expr): Expr = n match { - case n: Literal => n - case MemoryLoad(mem, index, endian, size) => MemoryLoad(visit_mem(mem), visit_expr(index), endian, size) - case Extract(end, start, arg) => Extract(end, start, visit_expr(arg)) - case Repeat(repeats, arg) => Repeat(repeats, visit_expr(arg)) - case ZeroExtend(bits, arg) => ZeroExtend(bits, visit_expr(arg)) - case SignExtend(bits, arg) => SignExtend(bits, visit_expr(arg)) - case BinaryExpr(op, arg, arg2) => BinaryExpr(op, visit_expr(arg), visit_expr(arg2)) - case UnaryExpr(op, arg) => UnaryExpr(op, visit_expr(arg)) - case v: Variable => visit_rvar(v) - case UninterpretedFunction(n, params, rt) => UninterpretedFunction(n, params.map(visit_expr), rt) + case n: Literal => n + case MemoryLoad(mem, index, endian, size) => { + val nmem = visit_mem(mem) + val nind = visit_expr(index) + if ((nmem ne mem) || (nind ne index)) MemoryLoad(visit_mem(mem), visit_expr(index), endian, size) else n + } + case Extract(end, start, arg) => { + val narg = visit_expr(arg) + if (narg ne arg) Extract(end, start, narg) else n + } + case Repeat(repeats, arg) => { + val narg = visit_expr(arg) + if (narg ne arg) Repeat(repeats, arg) else n + } + case ZeroExtend(bits, arg) => { + val narg = visit_expr(arg) + if (narg ne arg) ZeroExtend(bits, narg) else n + } + case SignExtend(bits, arg) => { + val narg = visit_expr(arg) + if (narg ne arg) SignExtend(bits, narg) else n + } + case BinaryExpr(op, arg, arg2) => { + val narg1 = visit_expr(arg) + val narg2 = visit_expr(arg2) + if ((narg1 ne arg) || (narg2 ne arg2)) BinaryExpr(op, narg1, narg2) else n + } + case UnaryExpr(op, arg) => { + val narg = visit_expr(arg) + if (narg ne arg) UnaryExpr(op, narg) else n + } + case v: Variable => visit_rvar(v) + case UninterpretedFunction(name, params, rt) => { + val nparams = params.map(visit_expr) + val updated = (params.zip(nparams).map((a, b) => a ne b)).contains(true) + if (updated) UninterpretedFunction(name, nparams, rt) else n + } } doVisit(v, v.vexpr(n), n, continue) } diff --git a/src/main/scala/ir/eval/ExprEval.scala b/src/main/scala/ir/eval/ExprEval.scala index 0ad0a55ca..22ee56370 100644 --- a/src/main/scala/ir/eval/ExprEval.scala +++ b/src/main/scala/ir/eval/ExprEval.scala @@ -147,6 +147,84 @@ trait Loader[S, E] { } } +def fastPartialEvalExpr(exp: Expr): Expr = { + /* + * Ignore substitutions and parital eval + */ + exp match { + case f: UninterpretedFunction => f + case unOp: UnaryExpr => { + unOp.arg match { + case l: Literal => evalUnOp(unOp.op, l) + case o => UnaryExpr(unOp.op, o) + } + } + case binOp: BinaryExpr => + val lhs = binOp.arg1 + val rhs = binOp.arg2 + binOp.getType match { + case m: MapType => binOp + case b: BitVecType => { + (binOp.op, lhs, rhs) match { + case (o: BVBinOp, l: BitVecLiteral, r: BitVecLiteral) => evalBVBinExpr(o, l, r) + case _ => BinaryExpr(binOp.op, lhs, rhs) + } + } + case BoolType => { + def bool2lit(b: Boolean) = if b then TrueLiteral else FalseLiteral + (binOp.op, lhs, rhs) match { + case (o: BVBinOp, l: BitVecLiteral, r: BitVecLiteral) => bool2lit(evalBVLogBinExpr(o, l, r)) + case (o: IntBinOp, l: IntLiteral, r: IntLiteral) => bool2lit(evalIntLogBinExpr(o, l.value, r.value)) + case (o: BoolBinOp, l: BoolLit, r: BoolLit) => bool2lit(evalBoolLogBinExpr(o, l.value, r.value)) + case _ => BinaryExpr(binOp.op, lhs, rhs) + } + } + case IntType => { + (binOp.op, lhs, rhs) match { + case (o: IntBinOp, l: IntLiteral, r: IntLiteral) => IntLiteral(evalIntBinExpr(o, l.value, r.value)) + case _ => BinaryExpr(binOp.op, lhs, rhs) + } + } + } + case extend: ZeroExtend => + val body = extend.body + (body match { + case b: BitVecLiteral => BitVectorEval.smt_zero_extend(extend.extension, b) + case o => extend.copy(body = o) + }) + case extend: SignExtend => + val body = extend.body + body match { + case b: BitVecLiteral => BitVectorEval.smt_sign_extend(extend.extension, b) + case o => extend.copy(body = o) + } + case e: Extract => + val body = e.body + body match { + case b: BitVecLiteral => BitVectorEval.boogie_extract(e.end, e.start, b) + case o => e.copy(body = o) + } + case r: Repeat => + val body = r.body + body match { + case b: BitVecLiteral => { + assert(r.repeats > 0) + if (r.repeats == 1) b + else { + (2 to r.repeats).foldLeft(b)((acc, r) => BitVectorEval.smt_concat(acc, b)) + } + } + case o => r.copy(body = o) + } + case variable: Variable => variable + case ml: MemoryLoad => + val addr = ml.index + ml.copy(index= addr) + case b: Literal => b + } +} + + def statePartialEvalExpr[S](l: Loader[S, InterpreterError])(exp: Expr): State[S, Expr, InterpreterError] = { val eval = statePartialEvalExpr(l) val ns = exp match { diff --git a/src/main/scala/ir/eval/SimplifyExpr.scala b/src/main/scala/ir/eval/SimplifyExpr.scala index 99eb22d3d..bd837b4e0 100644 --- a/src/main/scala/ir/eval/SimplifyExpr.scala +++ b/src/main/scala/ir/eval/SimplifyExpr.scala @@ -9,19 +9,17 @@ import ir.cilvisitor.* val assocOps: Set[BinOp] = Set(BVADD, BVMUL, BVOR, BVAND, BVEQ, BoolAND, BoolEQ, BoolOR, BoolEQUIV, BoolEQ, IntADD, IntMUL, IntEQ) - -object AlgebraicSimplifications extends CILVisitor { - override def vexpr(e: Expr) = ChangeDoChildrenPost(eval.simplifyExprFixpoint(e), eval.simplifyExprFixpoint) +object AlgebraicSimplifications extends CILVisitor { + override def vexpr(e: Expr) = ChangeDoChildrenPost(e, eval.simplifyExprFixpoint) def apply(e: Expr) = { visit_expr(this, e) } } - object SimplifyValidation { var traceLog = mutable.LinkedHashSet[(Expr, Expr)]() - var validate : Boolean = false + var validate: Boolean = false def makeValidation(writer: BufferedWriter) = { @@ -31,7 +29,7 @@ object SimplifyValidation { case BitVecType(sz) => BinaryExpr(BVEQ, a, b) case IntType => BinaryExpr(IntEQ, a, b) case BoolType => BinaryExpr(BoolEQ, a, b) - case m: MapType => ??? + case m: MapType => ??? } } @@ -51,28 +49,30 @@ object SimplifyValidation { } } - } - def simplifyExprFixpoint(e: Expr): Expr = { val begin = e var pe = e var count = 0 - var ne = simplifyExpr(pe) - while (ne != pe && count < 5) { + var ne = e + var changedAny = false + var changed = true + while (changed) { + val (x, didAnything) = simplifyExpr(ne) + changed = didAnything + changedAny = changedAny || changed count += 1 - pe = ne - ne = simplifyExpr(pe) + ne = ir.eval.fastPartialEvalExpr(x) } - if (ne != pe) { + if (changed) { Logger.error(s"stopping simp before fixed point: there is likely a simplificatinon loop: $pe !=== $ne") } if (SimplifyValidation.validate) { // normalise to deduplicate log entries val normer = VarNameNormalise() - val a = visit_expr(normer, begin) - val b = visit_expr(normer, ne) + val a = visit_expr(normer, begin) + val b = visit_expr(normer, ne) SimplifyValidation.traceLog.add((a, b)) } @@ -83,10 +83,9 @@ class VarNameNormalise() extends CILVisitor { var count = 1 val assigned = mutable.Map[Variable, Variable]() - def rename(v: Variable, newName: String) = { v match { - case LocalVar(n, t) => LocalVar(newName, t) + case LocalVar(n, t) => LocalVar(newName, t) case Register(n, sz) => Register(newName, sz) } } @@ -103,7 +102,6 @@ class VarNameNormalise() extends CILVisitor { } } - def apply(e: Expr) = { count = 1 assigned.clear() @@ -114,9 +112,7 @@ class VarNameNormalise() extends CILVisitor { } } - - -def simplifyExpr(e: Expr): Expr = { +def simplifyExpr(e: Expr): (Expr, Boolean) = { // println((0 until indent).map(" ").mkString("") + e) def bool2bv1(e: Expr) = { @@ -133,19 +129,7 @@ def simplifyExpr(e: Expr): Expr = { BVUGE -> BVUGT, BVSLE -> BVSLT, BVULE -> BVULT - ) - - - def isNegBV(e: Expr) = { - simplifyExpr(e).getType match { - case BitVecType(sz) => { - BinaryExpr(BVSLT, e, BitVecLiteral(0, sz)) - } - case _ => ??? - } - } - - val e2 = ir.eval.partialEvalExpr(e, (v) => None) + ) def pushExtend(e: Expr, extend: Expr => Expr): Expr = { e match { @@ -170,7 +154,8 @@ def simplifyExpr(e: Expr): Expr = { /** Apply the rewrite rules once. Note some rules expect a canonical form produced by other rules, and hence this is * more effective when applied iteratively until a fixed point. */ - val simped = e2 match { + var didAnything = true + val simped = e match { // constant folding // const + (expr + const) -> expr + (const + const) @@ -204,7 +189,6 @@ def simplifyExpr(e: Expr): Expr = { UnaryExpr(BoolNOT, BinaryExpr(BVEQ, (e1), (e2))) case BinaryExpr(BVNEQ, e1, e2) => UnaryExpr(BoolNOT, BinaryExpr(BVEQ, (e1), (e2))) - case BinaryExpr(op, BinaryExpr(op1, a, b: Literal), BinaryExpr(op2, c, d: Literal)) if !a.isInstanceOf[Literal] && !c.isInstanceOf[Literal] && assocOps.contains(op) && op == op1 && op == op2 => @@ -269,6 +253,10 @@ def simplifyExpr(e: Expr): Expr = { if ir.eval.BitVectorEval.isNegative(y) => BinaryExpr(BVEQ, x, UnaryExpr(BVNEG, y)) + case BinaryExpr(BVCONCAT, BitVecLiteral(0, sz), expr) => ZeroExtend(sz, expr) + case BinaryExpr(BVCONCAT, expr, BitVecLiteral(0, sz)) if (BigInt(2).pow(sz + size(expr).get) > sz) => + BinaryExpr(BVSHL, ZeroExtend(sz, expr), BitVecLiteral(sz, sz + size(expr).get)) + /* COMPARISON FLAG HANDLING * * We quite precisely pattern match ASLp's output for C and V, @@ -362,7 +350,7 @@ def simplifyExpr(e: Expr): Expr = { && AlgebraicSimplifications(x2) == AlgebraicSimplifications(ZeroExtend(exts, x1)) && AlgebraicSimplifications(y2) == AlgebraicSimplifications(ZeroExtend(exts, y1)) => { // C not Set - AlgebraicSimplifications(UnaryExpr(BoolNOT, BinaryExpr(BVUGE, x1, UnaryExpr(BVNEG, y1)))) + UnaryExpr(BoolNOT, BinaryExpr(BVUGE, x1, UnaryExpr(BVNEG, y1))) } case BinaryExpr( @@ -392,7 +380,7 @@ def simplifyExpr(e: Expr): Expr = { /* generic comparison simplification */ - // weak to strict inequality + // weak to strict inequality // x >= 0 && x != 0 ===> x > 0 case BinaryExpr(BoolAND, BinaryExpr(op, lhs, BitVecLiteral(0, sz)), UnaryExpr(BoolNOT, rhs)) if size(lhs).isDefined && (AlgebraicSimplifications(BinaryExpr(BVEQ, lhs, BitVecLiteral(0, size(lhs).get))) == rhs) && ineqToStrict.contains(op) => { @@ -407,7 +395,7 @@ def simplifyExpr(e: Expr): Expr = { BinaryExpr(ineqToStrict(op), lhs, rhs) } case BinaryExpr(BoolAND, BinaryExpr(op, lhs, rhs), UnaryExpr(BoolNOT, BinaryExpr(BVEQ, BinaryExpr(BVADD, lhs2, rhs2), BitVecLiteral(0, _)))) - if rhs == rhs2 && AlgebraicSimplifications(lhs) == AlgebraicSimplifications(UnaryExpr(BVNEG, lhs2)) && ineqToStrict.contains(op) => { + if rhs == rhs2 && (AlgebraicSimplifications(lhs) == AlgebraicSimplifications(UnaryExpr(BVNEG, lhs2))) && ineqToStrict.contains(op) => { BinaryExpr(ineqToStrict(op), lhs, rhs) } @@ -459,6 +447,18 @@ def simplifyExpr(e: Expr): Expr = { case _ => throw Exception("Type error (should be unreachable)") } case BinaryExpr(BoolEQ, x, FalseLiteral) => UnaryExpr(BoolNOT, x) + + // redundant extends + // extract extended zero part + case Extract(ed, bg, ZeroExtend(x, expr)) if (bg > size(expr).get) => BitVecLiteral(0, ed - bg) + // extract below extend + case Extract(ed, bg, ZeroExtend(x, expr)) if (bg < size(expr).get) && (ed < size(expr).get) => Extract(ed, bg, expr) + case Extract(ed, bg, SignExtend(x, expr)) if (bg < size(expr).get) && (ed < size(expr).get) => Extract(ed, bg, expr) + + case BinaryExpr(BVEQ, ZeroExtend(sz, expr), BitVecLiteral(0, sz2)) => BinaryExpr(BVEQ, expr, BitVecLiteral(0, size(expr).get)) + + + // double negation case UnaryExpr(BVNOT, UnaryExpr(BVNOT, body)) => (body) case UnaryExpr(BVNEG, UnaryExpr(BVNEG, body)) => (body) @@ -483,15 +483,12 @@ def simplifyExpr(e: Expr): Expr = { ) => BinaryExpr(BVEQ, (body), BitVecLiteral(0, 1)) - case BinaryExpr(BVSUB, x: Expr, y: BitVecLiteral) => BinaryExpr(BVADD, x, UnaryExpr(BVNEG, y)) - case BinaryExpr(BVAND, l, r) if l == r && l.loads.isEmpty => (l) - case BinaryExpr(op, x, y) => BinaryExpr(op, (x), (y)) - case r => r + case BinaryExpr(BVSUB, x: Expr, y: BitVecLiteral) => BinaryExpr(BVADD, x, UnaryExpr(BVNEG, y)) + case r => { + didAnything = false + r + } } - if (simped != e) { - // println(s"old $e -> ") - // println(s" $simped") - } - simped + (simped, didAnything) } diff --git a/src/main/scala/ir/invariant/BackwardsCFGMatchesForwards.scala b/src/main/scala/ir/invariant/CFGCorrect.scala similarity index 100% rename from src/main/scala/ir/invariant/BackwardsCFGMatchesForwards.scala rename to src/main/scala/ir/invariant/CFGCorrect.scala diff --git a/src/main/scala/ir/transforms/DynamicSingleAssignment.scala b/src/main/scala/ir/transforms/DynamicSingleAssignment.scala index 3a84ef732..5a009ad39 100644 --- a/src/main/scala/ir/transforms/DynamicSingleAssignment.scala +++ b/src/main/scala/ir/transforms/DynamicSingleAssignment.scala @@ -2,273 +2,541 @@ package ir.transforms import util.Logger import ir.cilvisitor.* +import translating.* import ir.* import scala.collection.mutable import analysis._ -object DynamicSingleAssignment { - /* TODO: improve using liveness */ +val phiAssignLabel = Some("phi") + +/** This transforms the program by no-op copies and renaming local variable indices to establish the property that + * + * \forall variables v, forall uses of v : u, No subset of definitions of v defines the use u. + */ + +class OnePassDSA( + liveVarsCheck: Option[Map[CFGPosition, Set[Variable]]] = None + /** Check our (faster) live var result against the TIP sovler solution */ +) { + + val liveVarsDom = transforms.IntraLiveVarsDomain() + val liveVarsSolver = transforms.worklistSolver(liveVarsDom) + + case class BlockState( + renamesBefore: mutable.Map[Variable, Int] = mutable.Map[Variable, Int](), + renamesAfter: mutable.Map[Variable, Int] = mutable.Map[Variable, Int](), + var filled: Boolean = false, /* have given local value numbering */ + var completed: Boolean = false, /* have filled and processed all incoming */ + var isPhi: Boolean = false /* begins filled */ + ) + + def renameLHS(c: Command, variable: Variable, index: Int) = { + c match { + case s: Statement => visit_stmt(StmtRenamer(Map((variable -> index)), Map()), s) + case j: Jump => visit_jump(StmtRenamer(Map((variable -> index)), Map()), j) + } + } + + def renameRHS(c: Command, variable: Variable, index: Int) = { + c match { + case s: Statement => visit_stmt(StmtRenamer(Map(), Map((variable -> index))), s) + case j: Jump => visit_jump(StmtRenamer(Map(), Map((variable -> index))), j) + } + } + + def appendAssign(b: Block, s: Assign) = { + // maintain call end of lock invariant + if (b.statements.size > 0 && b.statements.last.isInstanceOf[Call]) { + b.statements.insertBefore(b.statements.last, s) + } else { + b.statements.append(s) + } + } + + def withDefault(_st: mutable.Map[Block, BlockState])(b: Block) = { + if _st.contains(b) then _st(b) + else { + _st(b) = BlockState() + _st(b) + } + } - def applyTransform(program: Program, liveVars: Map[CFGPosition, Set[Variable]]) = { + def localProcessBlock( + state: mutable.Map[Block, BlockState], + count: mutable.Map[Variable, Int], + block: Block + ): Unit = { + def st(b: Block) = withDefault(state)(b) - case class DSARes(renames: Map[Variable, Int] = Map().withDefaultValue(-1)) // -1 means no rename + var renames = st(block).renamesBefore - def addIndex(v: Variable, idx: Int) = { - if (idx != -1) { - v match { - case Register(n, sz) => { - throw Exception("Should not SSA registers") - Register(n + "_" + idx, sz) + for (s <- block.statements) { + visit_stmt(StmtRenamer(Map(), renames.toMap), s) + s match { + case a @ Assign(lhs: LocalVar, _, _) => { + count(lhs) = count(lhs) + 1 + renameLHS(a, lhs, count(lhs)) + renames = renames + (lhs -> count(lhs)) + + } + case d: DirectCall => { + val vars = d.outParams.map(_._2).toList + for (lhs <- vars) { + count(lhs) = count(lhs) + 1 + renameLHS(d, lhs, count(lhs)) + renames = renames + (lhs -> count(lhs)) } - case v @ LocalVar(n, t) => LocalVar(v.varName, t, idx) } - } else { - v + case _ => () } } - class StmtRenamer(renamesL: Map[Variable, Int] = Map(), renames: Map[Variable, Int] = Map()) extends CILVisitor { - override def vrvar(v: Variable) = v match { - case v if renames.contains(v) && renames(v) != -1 => ChangeTo(addIndex(v, renames(v))) - case _ => DoChildren() - } + visit_jump(StmtRenamer(Map(), renames.toMap), block.jump) + st(block).renamesAfter.addAll(renames) + st(block).filled = true - override def vlvar(v: Variable) = v match { - case v if renamesL.contains(v) && renamesL(v) != -1 => ChangeTo(addIndex(v, renamesL(v))) - case _ => DoChildren() - } - } + } - def appendAssign(b: Block, s: Assign) = { - // maintain call end of lock invariant - if (b.statements.size > 0 && b.statements.last.isInstanceOf[Call]) { - b.statements.insertBefore(b.statements.last, s) - } else { - b.statements.append(s) - } + def applyTransform(p: Program): Unit = { + for (proc <- p.procedures) { + applyTransform(proc) } + } - def njoin(st: Map[Block, DSARes], blocks: Iterable[Block] /* incoming blocks */) = { - require(blocks.size >= 2) - val rs = blocks.map(st(_)) - val allRenamedVars = rs.flatMap(_.renames.keySet) - val renames = allRenamedVars.map { - case v => { - // for branches which have different renamings of variables, we know the larger index renaming is so far unused - // on the branches with smaller indexes, so we add a copy to these branches renaming to the largest index - val maxrename = rs.map(_.renames(v)).foldLeft(-1)(Integer.max) - v -> maxrename - } - }.toMap - - DSARes(renames.withDefaultValue(-1)) + def createBlockBetween(b1: Block, b2: Block, label: String = "_phi_"): Block = { + require(b1.nextBlocks.toSet.contains(b2)) + val nb = Block(b1.label + label + b2.label) + b1.parent.addBlocks(nb) + b1.jump match { + case g: GoTo => { + g.addTarget(nb) + g.removeTarget(b2) + } + case _ => ??? } + nb.replaceJump(GoTo(b2)) + nb + } - def fixJoins( - st: Map[Block, DSARes], - p: Procedure, - lhss: Map[Command, Map[Variable, Int]], - rhss: Map[Command, Map[Variable, Int]] - ): (Map[Command, Map[Variable, Int]], Map[Command, Map[Variable, Int]]) = { - /** - * Add copies to blocks who have two incoming blocks with different renamings present to unify the renamings into the - * subsequent block. We will expect that one of the incoming branches has a reassignment equal to this blocks entry point - * reassignment. All other branches should have a smaller SSA index, and allow an additional copy to be added ot unify them - * with the larger assignment. - * - * some consideration is required when adding copies to keep the call-end-of-block invariant - * - * TODO: we can avoid adding copies for variables not live at the join - */ - var lhs = lhss - var rhs = rhss - for (b <- p.blocks) { - val incoming = b.prevBlocks - val all = incoming.flatMap(st(_).renames.keySet) - val outgoing = rhss(IRWalk.firstInBlock(b)) - - // fix renames for all variables when when all incoming block renames rhs don't match - // the expected rename at this block rhs AND the variable is still live at this block - // by adding the appropriate copy - for (v <- all) { - val outgoingRename = outgoing(v) - for (b <- incoming) { - if (st(b).renames(v) != outgoingRename && outgoingRename != -1 && (liveVars(b).contains(v))) { - b.statements.lastOption match { - case Some(d: DirectCall) if d.outParams.toSet.map(_._2).contains(v) => { - // if there is a call on this block assigning the variable, update its outparam's ssa index - lhs = lhs + (d -> (lhs.get(d).getOrElse(Map()) + (v -> st(b).renames(v)))) + def fixPredecessors( + _st: mutable.Map[Block, BlockState], + count: mutable.Map[Variable, Int], + liveBefore: mutable.Map[Block, Set[Variable]], + block: Block + ) = { + def state(b: Block) = withDefault(_st)(b) + + val preds = block.prevBlocks.toList + val toJoin = preds.filter(state(_).filled) + assert(!(toJoin.isEmpty && preds.nonEmpty), s"should always have at least one processed predecessor ${preds}") + + { + val definedVars = toJoin.flatMap(state(_).renamesAfter.keySet).toSet.intersect(liveBefore(block)) + val toUnify = definedVars + .map(v => v -> toJoin.map(state(_).renamesAfter.get(v).getOrElse(0))) + .filter((v, rns) => { + rns.toList match { + case Nil => false + case h :: Nil => false + // if there is no renaming such that all the incoming renames agree + // then we create a new copy + case h :: tl => + tl.foldLeft(Some(h): Option[Int])((acc, rn) => + acc match { + case Some(v) if v == rn => Some(v) + case _ => None } - case c => { - // otherwise add an assignment to the block at the end or immediately before the call, and - // update the ssa index of the in-parameters to the call - val assign = Assign(v, v, Some("appended")) - appendAssign(b, assign) - - lhs = lhs + (assign -> (lhs.get(assign).getOrElse(Map()) + (v -> outgoingRename))) - rhs = rhs + (assign -> (rhs.get(assign).getOrElse(Map()) + (v -> st(b).renames(v)))) - c match { - case Some(call: Call) => { - rhs = rhs + (call -> (rhs.get(call).getOrElse(Map()) + (v -> outgoingRename))) - } - case _ => () - } - } - } - } + ).isEmpty + } + }) + + if (toUnify.nonEmpty) { + val blocks = toJoin.map(b => b -> createBlockBetween(b, block, "_phi_back_")).toMap + for (v <- toUnify.map(_._1)) { + count(v) = count(v) + 1 + // new index for new copy of v (definition added to all incoming edges) + + for (b <- toJoin) { + val nb = blocks(b) + assert(state(b).filled) + state(nb).renamesBefore.addAll(state(b).renamesAfter) + + val assign = Assign(v, v, Some("phiback")) + state(nb).renamesAfter(v) = count(v) + appendAssign(nb, assign) + renameLHS(assign, v, state(nb).renamesAfter(v)) + renameRHS(assign, v, state(nb).renamesBefore.get(v).getOrElse(0)) + state(block).renamesBefore.addAll(state(nb).renamesAfter) } + + } + } else { + // all our incoming are equal, or we have only one predecessor etc + for (b <- toJoin) { + state(block).renamesBefore.addAll(state(b).renamesAfter) } } - (lhs, rhs) } + // set completed + if (toJoin.size == preds.size) { + state(block).completed = true + } else { + state(block).completed = false + } + + } - def processCollect( - i: DSARes, - b: Block, - count: mutable.Map[Variable, Int], - lhs: Map[Command, Map[Variable, Int]], - rhs: Map[Command, Map[Variable, Int]] - ): (DSARes, Map[Command, Map[Variable, Int]], Map[Command, Map[Variable, Int]]) = { - var r = i - var lh = lhs - var rh = rhs - // push the renames through a block, incrementing the ssa index when we see a new assignment - - for (s <- b.statements) { - rh = rh.updated(s, rh(s) ++ r.renames) - val renames: Map[Variable, Int] = s match { - case a: Assign => { - if ((lh.get(a).flatMap(_.get(a.lhs))).map(_ == -1).getOrElse(true)) { - count(a.lhs) = (count(a.lhs) + 1) - Map(a.lhs -> count(a.lhs)) - } else { - Map(a.lhs -> lh(a)(a.lhs)) - } + def fixSuccessors( + _st: mutable.Map[Block, BlockState], + count: mutable.Map[Variable, Int], + liveBefore: mutable.Map[Block, Set[Variable]], + liveAfter: mutable.Map[Block, Set[Variable]], + block: Block + ) = { + def state(b: Block) = withDefault(_st)(b) + + val next = block.nextBlocks.toList + val anyNextFilled = next.exists(state(_).filled) + val anyNextJoin = next.exists(_.prevBlocks.size > 1) + for (b <- next) { + val definedVars = state(block).renamesAfter.keySet.intersect(liveAfter(block)) + + if (definedVars.size > 0 && (anyNextFilled || anyNextJoin)) { + val nb = createBlockBetween(block, b, "_phi_") + + state(nb).renamesBefore.addAll(state(block).renamesAfter) + if (state(b).filled) { + // if filled we have chosen an incoming rename + state(nb).renamesAfter.addAll(state(b).renamesBefore) + } + + for (v <- definedVars) { + if (!state(nb).renamesAfter.contains(v)) { + count(v) = count(v) + 1 + state(nb).renamesAfter(v) = count(v) } - case a: DirectCall => { - (a.outParams - .map(_._2)) - .map(l => { - if (lh.get(s).flatMap(_.get(l)).map(_ == -1).getOrElse(true)) { - count(l) = (count(l) + 1) - (l -> count(l)) - } else { - (l -> lh(s)(l)) - } - }) - .toMap + if (state(nb).renamesBefore(v) != state(nb).renamesAfter(v)) { + val assign = Assign(v, v, phiAssignLabel) + appendAssign(nb, assign) + renameLHS(assign, v, state(nb).renamesAfter(v)) + renameRHS(assign, v, state(nb).renamesBefore(v)) } - case _ => Map() } - lh = lh.updated(s, lh(s) ++ renames) - r = r.copy(renames = r.renames ++ renames) + state(nb).filled = true + state(nb).isPhi = true + liveBefore(nb) = liveBefore(b) + liveAfter(nb) = liveBefore(b) } - rh = rh.updated(b.jump, rh(b.jump) ++ r.renames) - r = r.copy(renames = r.renames ++ rh(b.jump)) + } + } - (r, lh, rh) + def visitBlock( + _st: mutable.Map[Block, BlockState], + count: mutable.Map[Variable, Int], + liveBefore: mutable.Map[Block, Set[Variable]], + liveAfter: mutable.Map[Block, Set[Variable]], + block: Block + ) = { + + /** VisitBlock: + * + * 1. for all complete incoming + * - if there is an incoming rename that is not equal across the predecessors + * - add back phi block to each incoming edge + * - create a fresh copy of each non-uniform renamed variable + * - add copies to each phi block unify the incoming rename with the nominated new rename 2. add local value + * numbering to this block 3. if all predecessors are filled, mark this complete, otherwise mark this + * filled. 4. for all successors + * - if marked filled, add phi block to unify our outgoing with its incoming + * - if > 1 total predecessors for each unmarked , add phi block to nominate a new copy 5. for all successors, if + * all predecessors are now filled, mark complete + */ + + def state(b: Block) = withDefault(_st)(b) + + // add copies on incoming edges to make the outgoing rename of all filled precesessors + // the same, and define `block`'s incoming rename + fixPredecessors(_st, count, liveBefore, block) + + // apply local value numbering to `block` + var seenBefore = true + if (!(state(block).filled)) { + localProcessBlock(_st, count, block) + state(block).filled = true + assert(state(block).filled) + seenBefore = false } - def renameAll(b: Block, lhs: Map[Command, Map[Variable, Int]], rhs: Map[Command, Map[Variable, Int]]) = { - for (s <- b.statements) { - visit_stmt(StmtRenamer(lhs.get(s).getOrElse(Map()), rhs.get(s).getOrElse(Map())), s) + // mark successors complete which are completed as a result of processing `block` + for (b <- block.nextBlocks) { + if (b.prevBlocks.forall(state(_).filled)) { + state(b).completed = true } - val s = b.jump - visit_jump(StmtRenamer(lhs.get(s).getOrElse(Map()), rhs.get(s).getOrElse(Map())), s) } - def visitProc(p: Procedure) = { - /* - * visit in weak topological order, collect renames and copies of ssa variables then apply transform - * we need to visit loops twice to handle joins - */ - - val worklist = mutable.PriorityQueue[Block]()(Ordering.by(b => b.rpoOrder)) - worklist.addAll(p.blocks) - var seen = Set[Block]() - val count = mutable.Map[Variable, Int]().withDefaultValue(0) + // add outgoing copies for e.g. loop headers + fixSuccessors(_st, count, liveBefore, liveAfter, block) + } - // ssa index to rename lvars and rvars respectively - var lhs = Map[Command, Map[Variable, Int]]().withDefaultValue(Map().withDefaultValue(-1)) - var rhs = Map[Command, Map[Variable, Int]]().withDefaultValue(Map().withDefaultValue(-1)) + def applyTransform(p: Procedure): Unit = { + val _st = mutable.Map[Block, BlockState]() + // ensure order is defined + p.entryBlock.map(reversePostOrder) + + val (liveBeforeIn, liveAfterIn) = liveVarsSolver.solveProc(p, backwards = true) + val liveBefore = mutable.Map.from(liveBeforeIn) + val liveAfter = mutable.Map.from(liveAfterIn) + + for (lvResult <- liveVarsCheck) { + // check live vars has equal result to TIP solver + for ((b, s) <- liveBefore) { + assert( + lvResult(b) == s, + s"LiveVars unequal ${b.label}: ${lvResult(b)} == $s (differing ${lvResult(b).diff(s)})" + ) + } + } - // the variable renaming at the end of each block - var st = Map[Block, DSARes]().withDefaultValue(DSARes()) + val worklist = mutable.PriorityQueue[Block]()(Ordering.by(b => b.rpoOrder)) + worklist.addAll(p.blocks) + var seen = Set[Block]() + val count = mutable.Map[Variable, Int]().withDefaultValue(0) + while (worklist.nonEmpty) { while (worklist.nonEmpty) { - val b = worklist.dequeue + val block = worklist.dequeue + assert(worklist.headOption.map(_.rpoOrder < block.rpoOrder).getOrElse(true)) - if (b.prevBlocks.forall(st.contains(_)) && !b.prevBlocks.forall(b => seen.contains(b))) { - worklist.addAll(b.prevBlocks) - worklist.enqueue(b) - } else { - val prev = if (b.prevBlocks.size > 1 && b.prevBlocks.forall(b => st.contains(b))) { - // if we have a (possibly incomplete) entry for the incoming blocks we join them - njoin(st, b.prevBlocks) - } else if (b.incomingJumps.size == 1) { - st(b.incomingJumps.head.parent) - } else { - DSARes() - } - val (processed, nlhs, nrhs) = processCollect(prev, b, count, lhs, rhs) - if (st(b) != processed || nlhs != lhs || nrhs != rhs) { - lhs = nlhs - rhs = nrhs - worklist.addAll(b.nextBlocks) - st = st.updated(b, processed) - } - } - seen += b + visitBlock(_st, count, liveBefore, liveAfter, block) } - val (lhss, rhss) = fixJoins(st, p, lhs, rhs) - lhs = lhss - rhs = rhss + } - for (b <- p.blocks) { - renameAll(b, lhs, rhs) - } + // fix up rpo index of added phi blocks + p.entryBlock.map(reversePostOrder) + } +} + +class StmtRenamer(renamesL: Map[Variable, Int] = Map(), renames: Map[Variable, Int] = Map()) extends CILVisitor { + + private def addIndex(v: Variable, idx: Int) = { + assert(idx != -1) + v match { + case Register(n, sz) => { + throw Exception("Should not SSA registers") + Register(n + "_" + idx, sz) + } + case v: LocalVar => LocalVar(v.varName, v.irType, idx) } + } - applyRPO(program) - program.procedures.foreach(visitProc) + override def vrvar(v: Variable) = v match { + case v if renames.contains(v) && renames(v) != -1 => ChangeTo(addIndex(v, renames(v))) + case _ => DoChildren() + } + + override def vlvar(v: Variable) = v match { + case v if renamesL.contains(v) && renamesL(v) != -1 => ChangeTo(addIndex(v, renamesL(v))) + case _ => DoChildren() } } +def rdDSAProperty(p: Procedure): Boolean = { + /* + * Check the DSA property using a reaching definitions analysis. + * DSA Property: Every use of a variable v has every definition of v as a reaching definition + * / no strict subset of defintions of v defines any use of v, forall v. + */ + val defs: Map[Variable, Set[Assign | DirectCall]] = p + .flatMap { + case a: Assign => Seq((a.lhs, (a: Assign | DirectCall))) + case a: DirectCall => a.outParams.map(_._2).map((l: Variable) => (l, (a: Assign | DirectCall))).toSeq + case _ => Seq() + } + .groupBy(_._1) + .map((v, vs) => (v, vs.map(_._2).toSet)) + + Logger.debug(s"Reaching defs ${p.name}") + val reachingDefs = basicReachingDefs(p) + Logger.debug(s"Reaching defs ${p.name} DONE") + + class CheckDSAProperty( + defs: Map[Variable, Set[Assign | DirectCall]], + reaching: Map[Command, Map[Variable, Set[Assign | DirectCall]]] + ) extends CILVisitor { + var passed = true + var stmt: Command = null + val violations = mutable.HashSet[(Command, Variable)]() + + override def vrvar(v: Variable) = { + val allDefs = defs.get(v).toSet.flatten + val reachDefs = reachingDefs(stmt).get(v).toSet.flatten + + val check = allDefs == reachDefs + if (!check) { + val vil = (stmt, v) + if (!violations.contains(vil)) { + violations.add(vil) + // Logger.error(s"DSA Property violated on $v at $stmt @ ${stmt.parent.parent.name}::${stmt.parent.label}\n\t ${allDefs.diff(reachDefs)} defs not reached") + Logger.error( + s"DSA Property violated on $v at $stmt @ ${stmt.parent.parent.name}::${stmt.parent.label}\n\t ${allDefs + .diff(reachDefs)} defs not reached\n\t${reachDefs}" + ) + } + } + passed = passed && check + SkipChildren() + } + override def vstmt(v: Statement) = { + stmt = v + DoChildren() + } + override def vjump(j: Jump) = { + stmt = j + DoChildren() + } + } -def undoDSA(p: Procedure) : Unit = { - /** - * This just naively removes the indices from variables, some analyses may require dsa form to maintain correctness - * (e.g. bitvector width removal) - * - * You can only do this if the ssa index order matches the flow order, and all have the same type. - */ - visit_proc(RevIndices, p) - visit_proc(RemoveCopy, p) + val vis = CheckDSAProperty(defs, reachingDefs) + visit_proc(vis, p) + if (vis.passed) { + Logger.debug(s"${p.name} DSA check OK") + } + vis.passed +} - object RevIndices extends CILVisitor { - override def vlvar(v: Variable) = v match { - case l: LocalVar => ChangeTo(LocalVar(l.varName, l.irType)) - case o => SkipChildren() - } - override def vrvar(v: Variable) = v match { - case l: LocalVar => ChangeTo(LocalVar(l.varName, l.irType)) - case o => SkipChildren() +object DSAPropCheck { + // check the property that no strict subset of definitions dominates a use + // + // This attempts to generate an SMT proof of this, by encoding the reaching definitions in SMT2 + // Likely this does not work. + + def getUses(s: CFGPosition): Set[Variable] = { + s match { + case a: Assign => a.rhs.variables + case a: DirectCall => a.actualParams.flatMap(_._2.variables).toSet + case a: Return => a.outParams.flatMap(_._2.variables).toSet + case a: IndirectCall => Set(a.target) + case a: Assert => a.body.variables + case a: Assume => a.body.variables + case _ => Set() } } - object RemoveCopy extends CILVisitor { - override def vstmt(s: Statement) = s match { - case Assign(LocalVar(n1, t1), LocalVar(n2, t2), _) if n1 == n2 => ChangeTo(List()) - case o => SkipChildren() + def getDefinitions(s: CFGPosition): Set[Variable] = { + s match { + case a: Assign => Set(a.lhs) + case a: DirectCall => a.outParams.map(_._2).toSet + case _ => Set() } } -} -def undoDSA(p: Program) : Unit = { - for (proc <- p.procedures) { - undoDSA(proc) + def emitProof(proc: Procedure) = { + + var fresh = 0 + val vartoIdx = mutable.Map[Variable, Int]() + val nodeToIdx = mutable.Map[CFGPosition, Int]() + + val blockcfg = proc.blocks.flatMap(b => { + val pred = IRWalk.lastInBlock(b) + b.statements.map(s => (s, s.successor)) ++ + b.nextBlocks + .map(IRWalk.firstInBlock) + .map(succ => { + (pred, succ) + }) + }) + + def getVarIdx(n: Variable): Int = { + if (vartoIdx.contains(n)) { + vartoIdx(n) + } else { + fresh += 1 + vartoIdx(n) = fresh + vartoIdx(n) + } + } + + def getGraphNode(p: CFGPosition): Int = { + // val idx = getVarIdx(v) // put in map + val n = p + if (nodeToIdx.contains(n)) { + nodeToIdx(n) + } else { + fresh += 1 + nodeToIdx(n) = fresh + nodeToIdx(n) + } + } + + def defTruePredicate(name: String, args: List[IRType]) = + list( + sym("declare-fun"), + sym(name), + Sexp.Slist(args.map(i => BasilIRToSMT2.basilTypeToSMTType(i)).toList), + BasilIRToSMT2.basilTypeToSMTType(BoolType) + ) + + def dominates[T](pred: Sexp[T], succ: Sexp[T]) = { + list(sym("assert"), list(sym("dominates"), pred, succ)) + } + def notDominates[T](pred: Sexp[T], succ: Sexp[T]) = { + list(sym("assert"), list(sym("not"), list(sym("dominates"), pred, succ))) + } + + val edges = blockcfg.map((p, s) => ((getGraphNode(p)), (getGraphNode(s)))).toSet + val nodes = edges.flatMap((a, b) => Seq(a, b)) + + val doms = nodes.flatMap(nn1 => { + nodes.flatMap(nn2 => { + val (n1, n2) = (BasilIRToSMT2.int2smt(nn1), BasilIRToSMT2.int2smt(nn2)) + if (edges.contains((nn1, nn2))) { + Seq(dominates(n1, n2)) + } else { + Seq() + // notDominates(n1, n2) + } + }) + }) + + val res: List[Sexp[_]] = List( + defTruePredicate("isUse", List(IntType, IntType)), // node, var + defTruePredicate("isDef", List(IntType, IntType)), // node, var + defTruePredicate("dominates", List(IntType, IntType)) // node, node + ) ++ doms + + val written = res.map(Sexp.print) + + val vars = nodeToIdx.flatMap((p, nodeID) => { + getUses(p).map(vn => { + val v = getVarIdx(vn) + list(sym("assert"), list(sym("isUse"), BasilIRToSMT2.int2smt(nodeID), BasilIRToSMT2.int2smt(v))) + }) + ++ + getDefinitions(p).map(vn => { + val v = getVarIdx(vn) + list(sym("assert"), list(sym("isUse"), BasilIRToSMT2.int2smt(nodeID), BasilIRToSMT2.int2smt(v))) + }) + }) + val axioms = List( + "(assert (forall ((x Int) (y Int) (z Int)) (implies (and (dominates x y) (dominates y z)) (dominates x z))))", + "(assert (not (forall ((n Int) (v Int) (v2 Int)) (and (isDef n v) (isUse n v2)))))", + // check of dsa property + """(declare-fun defines (Int Int) Bool) +(assert (forall ((d Int) (u Int) (i Int) (v Int)) (implies (and (isUse u v) (isDef d v) (dominates d u) (not (and (dominates d i) (dominates i u) (isDef i v)))) (defines d u)))) +(assert (not (forall ((usenode Int) (variable Int)) + (implies (isUse usenode variable) + (forall ((defnode Int)) + (implies (isDef defnode variable) + (defines defnode usenode))))))) + """ + ) + + util.writeToFile( + (written ++ vars.map(Sexp.print) ++ axioms).mkString("\n") + "\n(check-sat)", + s"proofs/${proc.name}-dsa-graph.smt2" + ) } -} +} diff --git a/src/main/scala/ir/transforms/ProcedureParameters.scala b/src/main/scala/ir/transforms/ProcedureParameters.scala index d861ea5f9..fa64f7c26 100644 --- a/src/main/scala/ir/transforms/ProcedureParameters.scala +++ b/src/main/scala/ir/transforms/ProcedureParameters.scala @@ -8,6 +8,24 @@ import specification.Specification import analysis.{TwoElement, TwoElementTop, TwoElementBottom} import ir.CallGraph +case class FunSig(inArgs: List[Register], outArgs: List[Register]) + +def R(n: Int) = { + Register(s"R$n", 64) +} + +val builtinSigs : Map[String, FunSig] = Map( + "#free" -> FunSig(List(R(0)), List(R(0))), + "malloc" -> FunSig(List(R(0)), List(R(0))), + "strlen" -> FunSig(List(R(0)), List(R(0))), + "strchr" -> FunSig(List(R(0), R(1)), List(R(0))), + "strlcpy" -> FunSig(List(R(0), R(1), R(2)), List(R(0))), + "strlcat" -> FunSig(List(R(0), R(1), R(2)), List(R(0))) + ) + +def fnsigToBinding(f: FunSig) = (f.inArgs.map(a => LocalVar(a.name + "_in", a.getType) -> LocalVar(a.name, a.getType)), + f.outArgs.map(a => LocalVar(a.name + "_out", a.getType) -> LocalVar(a.name, a.getType))) + def liftProcedureCallAbstraction(ctx: util.IRContext): util.IRContext = { val liveVars = @@ -103,6 +121,9 @@ class SetFormalParams( override def vproc(p: Procedure) = { if (externalFunctions.contains(p.name)) { + + + p.formalInParam = mutable.SortedSet.from(externalIn.map(_._1)) p.formalOutParam = mutable.SortedSet.from(externalOut.map(_._1)) p.inParamDefaultBinding = immutable.SortedMap.from(externalIn) @@ -293,22 +314,22 @@ class SetActualParams( override def vstmt(s: Statement) = { currStmt = Some(s) s match { - case d: DirectCall if !externalFunctions.contains(d.target.name) => { - // we have changed the parameter-passed variable to locals so we have LocalVar(n) -> LocalVar(n) - for (binding <- inBinding.get(d.target)) { - if (externalFunctions.contains(d.target.name)) { - d.actualParams = SortedMap.from(binding) - } else { - d.actualParams = d.actualParams ++ SortedMap.from(binding) - } - } - for (binding <- outBinding.get(d.target)) { - d.outParams = SortedMap.from(binding) - } - } - case d: DirectCall /* if external */ => { - d.actualParams = SortedMap.from(externalIn) - d.outParams = SortedMap.from(externalOut) + // case d: DirectCall if !externalFunctions.contains(d.target.name) => { + // // we have changed the parameter-passed variable to locals so we have LocalVar(n) -> LocalVar(n) + // for (binding <- inBinding.get(d.target)) { + // if (externalFunctions.contains(d.target.name)) { + // d.actualParams = SortedMap.from(binding) + // } else { + // d.actualParams = d.actualParams ++ SortedMap.from(binding) + // } + // } + // for (binding <- outBinding.get(d.target)) { + // d.outParams = SortedMap.from(binding) + // } + // } + case d: DirectCall => { + d.actualParams = SortedMap.from(d.target.inParamDefaultBinding) + d.outParams = SortedMap.from(d.target.outParamDefaultBinding) } case _ => () } @@ -318,9 +339,7 @@ class SetActualParams( override def vjump(j: Jump) = { j match { case r: Return => { - for (binding <- outBinding.get(r.parent.parent)) { - r.outParams = SortedMap.from(binding) - } + r.outParams = SortedMap.from(r.parent.parent.outParamDefaultBinding) DoChildren() } case _ => DoChildren() diff --git a/src/main/scala/ir/transforms/Simp.scala b/src/main/scala/ir/transforms/Simp.scala index e24b4f257..1b698c2ec 100644 --- a/src/main/scala/ir/transforms/Simp.scala +++ b/src/main/scala/ir/transforms/Simp.scala @@ -13,15 +13,136 @@ import scala.util.{Failure, Success} import ExecutionContext.Implicits.global trait AbstractDomain[L] { - def join(a: L, b: L): L - def widen(a: L, b: L): L = join(a, b) + def join(a: L, b: L, pos: Block): L + def widen(a: L, b: L, pos: Block): L = join(a, b, pos) def narrow(a: L, b: L): L = a - def transfer(a: L, b: Statement): L + def transfer(a: L, b: Command): L + def transferBlockFwd(a: L, b: Block): L = { + transfer(b.statements.foldLeft(a)(transfer), b.jump) + } + def transferBlockBwd(a: L, b: Block): L = { + b.statements.toList.reverse.foldLeft(transfer(a, b.jump))(transfer) + } def top: L def bot: L } + +def getLiveVars(p: Procedure) = { + val liveVarsDom = IntraLiveVarsDomain() + val liveVarsSolver = worklistSolver(liveVarsDom) + liveVarsSolver.solveProc(p, backwards=true) +} + + +def basicReachingDefs(p: Procedure): Map[Command, Map[Variable, Set[Assign | DirectCall]]] = { + val (beforeLive, afterLive) = getLiveVars(p) + val dom = DefUseDomain(beforeLive) + val solver = worklistSolver(dom) + // type rtype = Map[Block, Map[Variable, Set[Assign | DirectCall]]] + val (beforeRes, afterRes) = solver.solveProc(p) + + + val merged: Map[Command, Map[Variable, Set[Assign | DirectCall]]] = + beforeRes.flatMap((block, sts) => { + val b = Seq(IRWalk.firstInBlock(block) -> sts) + val stmts = + if (block.statements.nonEmpty) then + (block.statements.toList: List[Command]).zip(block.statements.toList.tail ++ List(block.jump)) + else List() + val transferred = stmts + .foldLeft((sts, List[(Command, Map[Variable, Set[Assign | DirectCall]])]()))((st, s) => { + // map successor to transferred predecessor + val x = dom.transfer(st._1, s._1) + (x, (s._2 -> x) :: st._2) + }) + ._2.toMap + b ++ transferred + }) + .toMap + merged +} + +case class DefUse(defined: Map[Variable, Assign | DirectCall]) + +// map v -> definitions reached here +class DefUseDomain(liveBefore: Map[Block, Set[Variable]]) extends AbstractDomain[Map[Variable, Set[Assign | DirectCall]]] { + // TODO: cull values using liveness + + override def transfer(s: Map[Variable, Set[Assign | DirectCall]], b: Command) = { + b match { + case a: Assign => s.updated(a.lhs, Set(a)) + case d: DirectCall => d.outParams.map(_._2).foldLeft(s)((s, r) => s.updated(r, Set(d))) + case _ => s + } + } + override def top = ??? + def bot = Map[Variable, Set[Assign | DirectCall]]() + def join(l: Map[Variable, Set[Assign | DirectCall]], r: Map[Variable, Set[Assign | DirectCall]], pos: Block) = { + l.keySet + .union(r.keySet).filter(k => liveBefore(pos).contains(k)) + .map(k => { + k -> (l.get(k).getOrElse(Set()) ++ r.get(k).getOrElse(Set())) + }) + .toMap + } + +} + +class CollectingDomain[T, D <: AbstractDomain[T]](d: D) extends AbstractDomain[(T, Map[Command, T])] { + def bot = (d.bot, Map()) + def top = ??? + + def join(l: (T, Map[Command, T]), r: (T, Map[Command, T]), pos: Block): (T, Map[Command, T]) = { + val nr = d.join(l._1, r._1, pos) + ( + nr, + (l._2.keySet ++ r._2.keySet) + .map((s: Command) => { + ((l._2.get(s), r._2.get(s)) match { + case (Some(l), Some(r)) if l == r => Seq(s -> l) + case (Some(l), Some(r)) if l != r => Seq() + case (Some(l), None) => Seq(s -> l) + case (None, Some(l)) => Seq(s -> l) + case _ => ??? + }) + }) + .flatten + .toMap + ) + } + + override def transfer(s: (T, Map[Command, T]), c: Command) = { + val u = d.transfer(s._1, c) + (u, s._2.updated(c, u)) + } +} + +trait PowerSetDomain[T] extends AbstractDomain[Set[T]] { + def bot = Set() + def top = ??? + def join(a: Set[T], b: Set[T], pos: Block) = a.union(b) +} + +class IntraLiveVarsDomain extends PowerSetDomain[Variable] { + // expected backwards + + def transfer(s: Set[Variable], a: Command): Set[Variable] = { + a match { + case a: Assign => (s - a.lhs) ++ a.rhs.variables + case m: MemoryAssign => s ++ m.index.variables ++ m.value.variables + case a: Assume => s ++ a.body.variables + case a: Assert => s ++ a.body.variables + case i: IndirectCall => s + i.target + case c: DirectCall => (s -- c.outParams.map(_._2)) ++ c.actualParams.flatMap(_._2.variables) + case g: GoTo => s + case r: Return => s ++ r.outParams.flatMap(_._2.variables) + case r: Unreachable => s + } + } +} + object MakeLocalsBlockUnique extends CILVisitor { var blockLabel: String = "" @@ -129,7 +250,7 @@ def removeSlices(p: Procedure): Unit = { } enum HighZeroBits: - case Bits(n: Int) // (i) and (ii) hold; the n highest bits are redundant + case Bits(n: Int) // most significant bit that is accessed (and all below) case False // property is false case Bot // don't know anything @@ -142,10 +263,12 @@ def removeSlices(p: Procedure): Unit = { val unifiedAssignments = ufsolver .unifications() - .map { case (v: LVTerm, rvs) => - v.v -> (rvs.map { case LVTerm(rv) => - rv - }).toSet + .map { + case (v: LVTerm, rvs) => + v.v -> (rvs.map { case LVTerm(rv) => + rv + }).toSet + case _ => ??? } .map((repr: LocalVar, elems: Set[LocalVar]) => repr -> elems.flatMap(assignments(_).filter(_ match { @@ -156,79 +279,115 @@ def removeSlices(p: Procedure): Unit = { ) // try and find a single extension size for all rhs of assignments to all variables in the assigned equality class - val varHighZeroBits: Map[LocalVar, HighZeroBits] = assignments.map((v, assigns) => - // note: this overapproximates on x := y when x and y may both be smaller than their declared size - val allRHSExtended = assigns.foldLeft(HighZeroBits.Bot: HighZeroBits)((e, assign) => - (e, assign.rhs) match { - case (HighZeroBits.Bot, ZeroExtend(i, lhs)) => HighZeroBits.Bits(i) - case (b @ HighZeroBits.Bits(ei), ZeroExtend(i, _)) if i == ei => b - case (b @ HighZeroBits.Bits(ei), ZeroExtend(i, _)) if i != ei => HighZeroBits.False - case (HighZeroBits.False, _) => HighZeroBits.False - case (_, other) => HighZeroBits.False - } - ) - (v, allRHSExtended) - ) - - val varsWithExtend: Map[LocalVar, HighZeroBits] = assignments - .map((lhs, _) => { - // map all lhs to the result for their representative - val rep = ufsolver.find(LVTerm(lhs)) match { - case LVTerm(r) => r - } - lhs -> varHighZeroBits.get(rep) - }) - .collect { case (l, Some(x)) /* remove anything we have no information on */ => - (l, x) - } + // val varHighZeroBits: Map[LocalVar, HighZeroBits] = assignments.map((v, assigns) => + // // note: this overapproximates on x := y when x and y may both be smaller than their declared size + // val allRHSExtended = assigns.foldLeft(HighZeroBits.Bot: HighZeroBits)((e, assign) => + // (e, assign.rhs) match { + // case (HighZeroBits.Bot, ZeroExtend(i, lhs)) => HighZeroBits.Bits(i) + // case (b @ HighZeroBits.Bits(ei), ZeroExtend(i, _)) if i == ei => b + // case (b @ HighZeroBits.Bits(ei), ZeroExtend(i, _)) if i != ei => HighZeroBits.False + // case (HighZeroBits.False, _) => HighZeroBits.False + // case (_, other) => HighZeroBits.False + // } + // ) + // (v, allRHSExtended) + // ) + + //val varsWithExtend: Map[LocalVar, HighZeroBits] = assignments + // .map((lhs, _) => { + // // map all lhs to the result for their representative + // val rep = ufsolver.find(LVTerm(lhs)) match { + // case LVTerm(r) => r + // case _ => ??? + // } + // lhs -> varHighZeroBits.get(rep) + // }) + // .collect { case (l, Some(x)) /* remove anything we have no information on */ => + // (l, x) + // } class CheckUsesHaveExtend() extends CILVisitor { val result: mutable.HashMap[LocalVar, HighZeroBits] = mutable.HashMap[LocalVar, HighZeroBits]() + override def vrvar(v: Variable) = { + v match { + case v: LocalVar => { + result(v) = HighZeroBits.False + } + case _ => () + } + SkipChildren() + } + override def vexpr(v: Expr) = { v match { case Extract(i, 0, v: LocalVar) - if size(v).isDefined && result.get(v).contains(HighZeroBits.Bits(size(v).get - i)) => - SkipChildren() - case v: LocalVar => { - result.remove(v) + if size(v).isDefined && ((!result.contains(v)) || result.get(v).contains(HighZeroBits.Bot)) => { + result(v) = HighZeroBits.Bits(i) SkipChildren() } + case Extract(i, 0, v: LocalVar) if size(v).isDefined && result.get(v).contains(HighZeroBits.Bits(i)) => + SkipChildren() case _ => DoChildren() } } - def apply(assignHighZeroBits: Map[LocalVar, HighZeroBits])(p: Procedure): Map[LocalVar, HighZeroBits] = { + def apply(p: Procedure): Map[LocalVar, HighZeroBits] = { result.clear() - result.addAll(assignHighZeroBits) visit_proc(this, p) result.toMap } } - val toSmallen = CheckUsesHaveExtend()(varsWithExtend)(p).collect { case (v, HighZeroBits.Bits(x)) => + val toSmallen = CheckUsesHaveExtend()(p).collect { case (v, HighZeroBits.Bits(x)) => v -> x }.toMap class ReplaceAlwaysSlicedVars(varHighZeroBits: Map[LocalVar, Int]) extends CILVisitor { + var formals = Set[LocalVar]() override def vexpr(v: Expr) = { v match { - case Extract(i, 0, v: LocalVar) if size(v).isDefined && varHighZeroBits.contains(v) => { - ChangeTo(LocalVar(v.name, BitVecType(size(v).get - varHighZeroBits(v)))) + case Extract(i, 0, v: LocalVar) + if size(v).isDefined && !(formals.contains(v)) && varHighZeroBits.contains(v) => { + assert(varHighZeroBits(v) == i) + ChangeTo(LocalVar(v.name, BitVecType(varHighZeroBits(v)))) } case _ => DoChildren() } } + override def vproc(p: Procedure) = { + formals = p.formalInParam.toSet ++ p.formalOutParam.toSet + DoChildren() + } + override def vstmt(s: Statement) = { s match { + case a @ Assign(lhs: LocalVar, SignExtend(sz, rhs), _) + if size(lhs).isDefined && varHighZeroBits.get(lhs).contains(size(rhs).get) && !(formals.contains(lhs)) => { + // assert(varHighZeroBits(lhs) == sz) + val varsize = varHighZeroBits(lhs) + a.lhs = LocalVar(lhs.name, BitVecType(varsize)) + a.rhs = rhs + assert(size(a.lhs).get == size(a.rhs).get) + DoChildren() + } case a @ Assign(lhs: LocalVar, ZeroExtend(sz, rhs), _) - if size(lhs).isDefined && varHighZeroBits.contains(lhs) => { - assert(varHighZeroBits(lhs) == sz) - a.lhs = LocalVar(lhs.name, BitVecType(size(lhs).get - varHighZeroBits(lhs))) + if size(lhs).isDefined && varHighZeroBits.get(lhs).contains(size(rhs).get) && !(formals.contains(lhs)) => { + val varsize = varHighZeroBits(lhs) + a.lhs = LocalVar(lhs.name, BitVecType(varsize)) a.rhs = rhs + assert(size(a.lhs).get == size(a.rhs).get) + DoChildren() + } + case a @ Assign(lhs: LocalVar, rhs, _) + if size(lhs).isDefined && varHighZeroBits.contains(lhs) && !(formals.contains(lhs)) => { + // promote extract to the definition + a.lhs = LocalVar(lhs.name, BitVecType(varHighZeroBits(lhs))) + a.rhs = Extract(varHighZeroBits(lhs), 0, rhs) + assert(size(a.lhs).get == size(a.rhs).get, s"${size(a.lhs).get} != ${size(a.rhs).get}") DoChildren() } case _ => DoChildren() @@ -249,14 +408,15 @@ def getRedundantAssignments(procedure: Procedure): Set[Assign] = { enum VS: case Bot case Assigned(definition: Set[Assign]) - case Read + case Read(definition: Set[Assign], uses: Set[CFGPosition]) def joinVS(a: VS, b: VS) = { (a, b) match { case (VS.Bot, o) => o case (o, VS.Bot) => o - case (VS.Read, _) => VS.Read - case (_, VS.Read) => VS.Read + case (VS.Read(d, u), VS.Read(d1, u1)) => VS.Read(d ++ d1, u ++ u1) + case (VS.Assigned(d), VS.Read(d1, u1)) => VS.Read(d ++ d1, u1) + case (VS.Read(d1, u1), VS.Assigned(d)) => VS.Read(d ++ d1, u1) case (VS.Assigned(d1), VS.Assigned(d2)) => VS.Assigned(d1 ++ d2) } } @@ -268,43 +428,42 @@ def getRedundantAssignments(procedure: Procedure): Set[Assign] = { case a: Assign => { assignedNotRead(a.lhs) = joinVS(assignedNotRead(a.lhs), VS.Assigned(Set(a))) a.rhs.variables.foreach(v => { - assignedNotRead(v) = VS.Read + assignedNotRead(v) = joinVS(assignedNotRead(v), VS.Read(Set(), Set(a))) }) } case m: MemoryAssign => { m.index.variables.foreach(v => { - assignedNotRead(v) = VS.Read + assignedNotRead(v) = joinVS(assignedNotRead(v), VS.Read(Set(), Set(m))) }) m.value.variables.foreach(v => { - assignedNotRead(v) = VS.Read + assignedNotRead(v) = joinVS(assignedNotRead(v), VS.Read(Set(), Set(m))) }) } case m: IndirectCall => { - assignedNotRead(m.target) = VS.Read + assignedNotRead(m.target) = joinVS(assignedNotRead(m.target), VS.Read(Set(), Set(m))) } case m: Assert => { m.body.variables.foreach(v => { - assignedNotRead(v) = VS.Read + assignedNotRead(v) = joinVS(assignedNotRead(v), VS.Read(Set(), Set(m))) }) } case m: Assume => { for (v <- m.body.variables) { - assignedNotRead(v) = VS.Read + assignedNotRead(v) = joinVS(assignedNotRead(v), VS.Read(Set(), Set(m))) } } case c: DirectCall => { c.actualParams .flatMap(_._2.variables) .foreach(v => { - assignedNotRead(v) = VS.Read + assignedNotRead(v) = joinVS(assignedNotRead(v), VS.Read(Set(), Set(c))) }) - } case p: Return => { p.outParams .flatMap(_._2.variables) .foreach(v => { - assignedNotRead(v) = VS.Read + assignedNotRead(v) = joinVS(assignedNotRead(v), VS.Read(Set(), Set(p))) }) } case p: GoTo => () @@ -315,7 +474,41 @@ def getRedundantAssignments(procedure: Procedure): Set[Assign] = { } } - val r = assignedNotRead + var toRemove = assignedNotRead + var removeOld = toRemove + + // def remove(a: Assign): Boolean = { + // var removed : Boolean = false + // toRemove = toRemove.map((v, s) => + // v -> { + // s match { + // case VS.Read(defs, uses) if uses.size == 1 && uses.contains(a) => { + // removed = true + // VS.Assigned(defs) + // } + // case VS.Read(defs, uses) if uses.contains(a) => { + // removed = true + // VS.Read(defs, uses - a) + // } + // case o => o + // } + // } + // ) + // removed + // } + + // while ({ + // removeOld = toRemove + // val removed = removeOld.map((v, s) => { + // s match { + // case VS.Assigned(definition) => definition.map(remove).foldLeft(false)((x,y) => x || y) + // case _ => false + // } + // }) + // removed.exists(x => x) + // }) {} + + val r = toRemove .collect { case (v, VS.Assigned(d)) => d } @@ -345,14 +538,15 @@ class CleanupAssignments() extends CILVisitor { def copypropTransform(p: Procedure) = { val t = util.PerformanceTimer(s"simplify ${p.name} (${p.blocks.size} blocks)") - val dom = ConstCopyProp() - val solver = worklistSolver(dom) + // val dom = ConstCopyProp() + // val solver = worklistSolver(dom) // Logger.info(s"${p.name} ExprComplexity ${ExprComplexity()(p)}") - val result = solver.solveProc(p) + // val result = solver.solveProc(p, true).withDefaultValue(dom.bot) + val result = DSACopyProp(p) val solve = t.checkPoint("Solve CopyProp") - val vis = Simplify(result.withDefaultValue(dom.bot)) + val vis = Simplify(result) visit_proc(vis, p) val xf = t.checkPoint("transform") // Logger.info(s" ${p.name} after transform expr complexity ${ExprComplexity()(p)}") @@ -364,15 +558,21 @@ def copypropTransform(p: Procedure) = { visit_proc(AlgebraicSimplifications, p) visit_proc(AlgebraicSimplifications, p) visit_proc(AlgebraicSimplifications, p) - visit_proc(AlgebraicSimplifications, p) - visit_proc(AlgebraicSimplifications, p) - visit_proc(AlgebraicSimplifications, p) - visit_proc(AlgebraicSimplifications, p) - visit_proc(AlgebraicSimplifications, p) - visit_proc(AlgebraicSimplifications, p) - visit_proc(AlgebraicSimplifications, p) + // visit_proc(AlgebraicSimplifications, p) + // visit_proc(AlgebraicSimplifications, p) + // visit_proc(AlgebraicSimplifications, p) + // visit_proc(AlgebraicSimplifications, p) + // visit_proc(AlgebraicSimplifications, p) + // visit_proc(AlgebraicSimplifications, p) + // visit_proc(AlgebraicSimplifications, p) + // visit_proc(AlgebraicSimplifications, p) // Logger.info(s" ${p.name} after simp expr complexity ${ExprComplexity()(p)}") val sipm = t.checkPoint("algebraic simp") + + // Logger.info("[!] Simplify :: RemoveSlices") + removeSlices(p) + visit_proc(AlgebraicSimplifications, p) + } def doCopyPropTransform(p: Program) = { @@ -406,18 +606,29 @@ def doCopyPropTransform(p: Program) = { // cleanup visit_prog(CleanupAssignments(), p) - val toremove = p.collect { - case b: Block if b.statements.size == 0 && b.prevBlocks.size == 1 && b.nextBlocks.size == 1 => b + for (proc <- p.procedures) { + val blocks = proc.blocks.toList + for (b <- blocks) { + b match { + case b: Block if b.statements.size == 0 && b.prevBlocks.size == 1 && b.nextBlocks.size == 1 => { + val p = b.prevBlocks.head + val n = b.nextBlocks.head + p.jump match { + case g: GoTo => { + g.addTarget(n) + g.removeTarget(b) + } + case _ => ??? + } + b.replaceJump(Unreachable()) + b.parent.removeBlocks(b) + } + case _ => () + } + } } Logger.info("[!] Simplify :: Merge empty blocks") - for (b <- toremove) { - val p = b.prevBlocks.head - val n = b.nextBlocks.head - p.replaceJump((GoTo(n))) - b.parent.removeBlocks(b) - } - } @@ -448,8 +659,8 @@ def applyRPO(p: Program) = { class worklistSolver[L, A <: AbstractDomain[L]](domain: A) { - def solveProc(p: Procedure) = { - solve(p.blocks, Set(), Set()) + def solveProc(p: Procedure, backwards: Boolean = false) = { + solve(p.blocks, Set(), Set(), backwards) } def solveProg( @@ -474,7 +685,7 @@ class worklistSolver[L, A <: AbstractDomain[L]](domain: A) { work .map((prog, x) => try { - (prog, Await.result(x, 10000.millis)) + (prog, Await.result(x, 10000.millis)._2) } catch { case t: Exception => { Logger.error(s"${prog.name} : $t") @@ -489,18 +700,30 @@ class worklistSolver[L, A <: AbstractDomain[L]](domain: A) { def solve( initial: IterableOnce[Block], widenpoints: Set[Block], // set of loop heads - narrowpoints: Set[Block] // set of conditions - ): Map[Block, L] = { - val saved: mutable.HashMap[Block, L] = mutable.HashMap() + narrowpoints: Set[Block], // set of conditions + backwards: Boolean = false + ): (Map[Block, L], Map[Block, L]) = { + val savedAfter: mutable.HashMap[Block, L] = mutable.HashMap() + val savedBefore: mutable.HashMap[Block, L] = mutable.HashMap() val saveCount: mutable.HashMap[Block, Int] = mutable.HashMap() - val worklist = mutable.PriorityQueue[Block]()(Ordering.by(b => b.rpoOrder)) + val worklist = { + if (backwards) { + mutable.PriorityQueue[Block]()(Ordering.by(b => -b.rpoOrder)) + } else { + mutable.PriorityQueue[Block]()(Ordering.by(b => b.rpoOrder)) + } + } worklist.addAll(initial) - var x = domain.bot + def successors(b: Block) = if backwards then b.prevBlocks else b.nextBlocks + def predecessors(b: Block) = if backwards then b.nextBlocks else b.prevBlocks + while (worklist.nonEmpty) { val b = worklist.dequeue - while (worklist.nonEmpty && (worklist.head.rpoOrder >= b.rpoOrder)) do { + while (worklist.nonEmpty && (if backwards then (worklist.head.rpoOrder <= b.rpoOrder) + else (worklist.head.rpoOrder >= b.rpoOrder))) + do { // drop rest of blocks with same priority val m = worklist.dequeue() assert( @@ -509,50 +732,42 @@ class worklistSolver[L, A <: AbstractDomain[L]](domain: A) { ) } - def bs(b: Block): List[Block] = { - var blocks = mutable.LinkedHashSet[Block]() - var thisBlock = b - while ({ - blocks.add(thisBlock) - - if (thisBlock.nextBlocks.size == 1) { - thisBlock = thisBlock.nextBlocks.head - blocks.contains(thisBlock) - } else { - false - } - }) {} - blocks.toList - } - val prev = saved.get(b) - x = b.prevBlocks.flatMap(ib => saved.get(ib).toList).foldLeft(x)(domain.join) - saved(b) = x - // val todo = bs(b) + val prev = savedAfter.get(b) + val x = { + predecessors(b).toList.flatMap(b => savedAfter.get(b).toList) match { + case Nil => domain.bot + case h :: Nil => h + case h :: tl => tl.foldLeft(h)((acc, nb) => domain.join(acc, nb, b)) + } + } + savedBefore(b) = x val todo = List(b) - val lastBlock = todo.last - def xf_block(x: L, b: Block) = { - saved(b) = b.statements.foldLeft(x)(domain.transfer) - saved(b) - } + val lastBlock = b // todo.last var nx = todo.foldLeft(x)((x, b) => { - saved(b) = xf_block(x, b) - saved(b) + savedBefore(b) = x + if (backwards) { + val ojmp = domain.transfer(x, b.jump) + savedAfter(b) = b.statements.toList.reverse.foldLeft(ojmp)(domain.transfer) + } else { + val stmts = b.statements.foldLeft(x)(domain.transfer) + savedAfter(b) = domain.transfer(stmts, b.jump) + } + savedAfter(b) }) - saved(lastBlock) = nx + savedAfter(lastBlock) = nx saveCount(lastBlock) = saveCount.get(lastBlock).getOrElse(0) + 1 if (!prev.contains(nx)) then { - if (saveCount(lastBlock) == 50) { + if (saveCount(lastBlock) >= 50) { Logger.warn(s"Large join count on block ${lastBlock.label}, no fix point? (-v for mor info)") Logger.debug(lastBlock.label + " ==> " + x) Logger.debug(lastBlock.label + " <== " + nx) } - worklist.addAll(lastBlock.nextBlocks) + worklist.addAll(successors(lastBlock)) } - x = nx } - saved.toMap + if backwards then (savedAfter.toMap, savedBefore.toMap) else (savedBefore.toMap, savedAfter.toMap) } } @@ -569,34 +784,12 @@ case class CCP( ) object CCP { + def toSubstitutions(c: CCP): Map[Variable, Expr] = { c.state.collect { case (v, CopyProp.Prop(e, _)) => v -> e } } -} - -class ConstCopyProp() extends AbstractDomain[CCP] { - private final val callClobbers = (0 to 30).map("R" + _).map(c => Register(c, 64)) - - def top: CCP = CCP(Map().withDefaultValue(CopyProp.Bot)) - def bot: CCP = CCP(Map().withDefaultValue(CopyProp.Clobbered)) - - override def join(l: CCP, r: CCP): CCP = { - val ks = l.state.keySet.intersect(r.state.keySet) - val merged = ks.map(v => - (v -> - ((l.state(v), r.state(v)) match { - case (l, CopyProp.Bot) => l - case (CopyProp.Bot, r) => r - case (c @ CopyProp.Clobbered, _) => c - case (_, c @ CopyProp.Clobbered) => c - case (p1 @ CopyProp.Prop(e1, deps1), p2 @ CopyProp.Prop(e2, deps2)) if (p1 == p2) => p1 - case (_, _) => CopyProp.Clobbered - })) - ) - CCP(merged.toMap) - } def clobberFull(c: CCP, l: Variable) = { val p = clobber(c, l) @@ -615,16 +808,126 @@ class ConstCopyProp() extends AbstractDomain[CCP] { .withDefaultValue(CopyProp.Bot) ) } +} + +def DSACopyProp(p: Procedure): Map[Variable, Expr] = { + + case class PropState(val e: Expr, val deps: Set[Variable], var clobbered: Boolean, var useCount: Int) + val state = mutable.HashMap[Variable, PropState]() + var poisoned = false + + def clobberFull(c: mutable.HashMap[Variable, PropState], l: Variable): Unit = { + if (c.contains(l)) { + c(l).clobbered = true + } else { + c(l) = PropState(FalseLiteral, Set(), true, 0) + } + } + + def clobberDeps(c: mutable.HashMap[Variable, PropState], l: Variable): Unit = { + val toclobber = c.filter(_.isInstanceOf[CopyProp.Prop]).filter(_.asInstanceOf[CopyProp.Prop].deps.contains(l)) + for ((v, e) <- toclobber) { + c(v).clobbered = true + } + } - override def transfer(c: CCP, s: Statement): CCP = { + def transfer(c: mutable.HashMap[Variable, PropState], s: Statement): Unit = { + // val callClobbers = ((0 to 7) ++ (19 to 30)).map("R" + _).map(c => Register(c, 64)) s match { - case m: MemoryAssign => { - // c.copy(exprs = c.exprs.filterNot((k, v) => v.expr.loads.nonEmpty)) - c - } case Assign(l, r, lb) => { if (r.loads.size > 0) { + // c.copy(state = c.state + (l -> CopyProp.Clobbered)) clobberFull(c, l) + } else { + val evaled = r + val rhsDeps = evaled.variables.toSet + val existing = c.get(l) + + existing match { + case None => { + c(l) = PropState(evaled, rhsDeps, false, 0) + } + case Some(ps) if ps.clobbered => { + () + } + case Some(ps) if ps.e != evaled => { + clobberFull(c, l) + } + case _ => { + // ps.e == evaled and have prop + } + } + + for (v <- rhsDeps) { + if (state.contains(v)) { + state(v).useCount += 1 + } + } + + } + } + case x: DirectCall => { + val lhs = x.outParams.map(_._2) + for (l <- lhs) { + clobberFull(c, l) + } + } + case x: IndirectCall => { + for ((i, v) <- c) { + v.clobbered = true + } + } + case _ => () + } + } + + // sort by precedence + val worklist = mutable.PriorityQueue[Block]()(Ordering.by(_.rpoOrder)) + worklist.addAll(p.blocks) + + while (worklist.nonEmpty) { + val b: Block = worklist.dequeue + + for (l <- b.statements) { + transfer(state, l) + } + } + + val res = state.collect { + case (v, c) if !c.clobbered => v -> c.e + }.toMap + res +} + +class ConstCopyProp() extends AbstractDomain[CCP] { + private final val callClobbers = ((0 to 7) ++ (19 to 30)).map("R" + _).map(c => Register(c, 64)) + + def top: CCP = CCP(Map().withDefaultValue(CopyProp.Clobbered)) + def bot: CCP = CCP(Map().withDefaultValue(CopyProp.Bot)) + + override def join(l: CCP, r: CCP, pos: Block): CCP = { + // val ks = l.state.keySet.intersect(r.state.keySet) + val ks = l.state.keySet ++ (r.state.keySet) + + val merged = ks.map(v => + (v -> + ((l.state.get(v).getOrElse(CopyProp.Clobbered), r.state.get(v).getOrElse(CopyProp.Clobbered)) match { + case (l, CopyProp.Bot) => l + case (CopyProp.Bot, r) => r + case (c @ CopyProp.Clobbered, _) => c + case (_, c @ CopyProp.Clobbered) => c + case (p1 @ CopyProp.Prop(e1, deps1), p2 @ CopyProp.Prop(e2, deps2)) if (p1 == p2) => p1 + case (_, _) => CopyProp.Clobbered + })) + ) + CCP(merged.toMap) + } + + override def transfer(c: CCP, s: Command): CCP = { + s match { + case Assign(l, r, lb) => { + if (r.loads.size > 0) { + CCP.clobberFull(c, l) } else { val consts = c.state.collect { case (k, CopyProp.Prop(c, deps)) if deps.isEmpty => k -> c @@ -633,22 +936,27 @@ class ConstCopyProp() extends AbstractDomain[CCP] { val rhsDeps = evaled.variables.toSet val existing = c.state.get(l).getOrElse(CopyProp.Bot) - val ns = existing match { - case CopyProp.Bot => CopyProp.Prop(evaled, rhsDeps) // not seen yet - case CopyProp.Prop(e, _) => CopyProp.Prop(evaled, rhsDeps) - case _ => CopyProp.Clobbered // our expr value has changed + existing match { + case CopyProp.Bot => { + c.copy(state = c.state + (l -> CopyProp.Prop(evaled, rhsDeps))) // not seen yet + } + case CopyProp.Prop(e, _) => { + val p = CCP.clobber(c, l) + p.copy(state = p.state + (l -> CopyProp.Prop(evaled, rhsDeps))) // not seen yet + } + case _ => { + CCP.clobberFull(c, l) + } } - val p = c.copy(state = c.state + (l -> ns)) - clobber(p, l) } } case x: DirectCall => { val lhs = x.outParams.map(_._2) - lhs.foldLeft(c)(clobberFull) + lhs.foldLeft(c)(CCP.clobberFull) } case x: IndirectCall => { val toClob = callClobbers - toClob.foldLeft(c)(clobberFull) + toClob.foldLeft(c)(CCP.clobberFull) } case _ => c } @@ -720,8 +1028,9 @@ class Substitute( } class Simplify( - val res: Map[Block, CCP], - val initialBlock: Block = null + val res: Map[Variable, Expr], + val initialBlock: Block = null, + val absdom: Option[ConstCopyProp] = None /* flow sensitive */ ) extends CILVisitor { var madeAnyChange = false @@ -731,7 +1040,7 @@ class Simplify( override def vexpr(e: Expr) = { val threshold = 500 val variables = e.variables.toSet - val subst = Substitute(CCP.toSubstitutions(res(block)), true, threshold) + val subst = Substitute(res, true, threshold) val result = subst(e).getOrElse(e) if (subst.complexity > threshold) { val bl = s"${block.parent.name}::${block.label}" @@ -744,7 +1053,8 @@ class Simplify( } override def vblock(b: Block) = { - block = b + block = b DoChildren() } + } diff --git a/src/main/scala/translating/GTIRBToIR.scala b/src/main/scala/translating/GTIRBToIR.scala index f675631a2..74ba8b8eb 100644 --- a/src/main/scala/translating/GTIRBToIR.scala +++ b/src/main/scala/translating/GTIRBToIR.scala @@ -261,6 +261,7 @@ class GTIRBToIR(mods: Seq[Module], parserMap: immutable.Map[String, Array[Array[ val (in, out) = createArguments(name) val procedure = Procedure(name, address, formalInParam = in.map(_._1), formalOutParam = out, inParamDefaultBinding=in.toMap) + procedure.inParamDefaultBinding = immutable.SortedMap.from(in.map((l,r) => l -> LocalVar(l.name, BitVecType(64)))) uuidToProcedure += (functionUUID -> procedure) entranceUUIDtoProcedure += (entranceUUID -> procedure) diff --git a/src/main/scala/translating/IRExpToSMT2.scala b/src/main/scala/translating/IRExpToSMT2.scala index 3add1ddbe..fd71a3440 100644 --- a/src/main/scala/translating/IRExpToSMT2.scala +++ b/src/main/scala/translating/IRExpToSMT2.scala @@ -159,24 +159,25 @@ object BasilIRToSMT2 extends BasilIRExpWithVis[Sexp] { } } - def interpretFun(x: UninterpretedFunction): Sexp[Expr] = { + def interpretFun(x: UninterpretedFunction): Option[Sexp[Expr]] = { x.name match { case "bool2bv1" => { - list( + Some(list( sym("define-fun"), sym(x.name), list(list(sym("arg"), basilTypeToSMTType(BoolType))), basilTypeToSMTType(x.returnType), list(sym("ite"), sym("arg"), bv2smt(BitVecLiteral(1, 1)), bv2smt(BitVecLiteral(0, 1))) - ) + )) } + case "bvsaddo" => None case _ => { - list( + Some(list( sym("declare-fun"), sym(x.name), Sexp.Slist(x.params.toList.map(a => basilTypeToSMTType(a.getType))), basilTypeToSMTType(x.returnType) - ) + )) } } } @@ -189,7 +190,7 @@ object BasilIRToSMT2 extends BasilIRExpWithVis[Sexp] { override def vexpr(e: Expr) = e match { case f: UninterpretedFunction => { val decl = interpretFun(f) - decled = decled + decl + decled = decled ++ decl.toSet DoChildren() // get variables out of args } case v: Variable => { diff --git a/src/main/scala/translating/ReadELFLoader.scala b/src/main/scala/translating/ReadELFLoader.scala index 0fca50e72..01af088d8 100644 --- a/src/main/scala/translating/ReadELFLoader.scala +++ b/src/main/scala/translating/ReadELFLoader.scala @@ -1,5 +1,6 @@ package translating +import util.Logger import Parsers.ReadELFParser.* import boogie.* import specification.* @@ -111,11 +112,14 @@ object ReadELFLoader { private def getFunctionAddress(ctx: SymbolTableContext, functionName: String): Option[BigInt] = { if (ctx.symbolTableHeader.tableName.STRING.getText == ".symtab") { val rows = ctx.symbolTableRow.asScala - val mainAddress = rows.collectFirst { - case r if r.entrytype.getText == "FUNC" && r.bind.getText == "GLOBAL" && r.name.getText == functionName => + val mainAddress = rows.collect { + case r if r.entrytype.getText == "FUNC" && r.name.getText == functionName => hexToBigInt(r.value.getText) } - mainAddress + if (mainAddress.size > 1) { + Logger.warn(s"Multiple procedures with name $functionName") + } + mainAddress.headOption } else { None } diff --git a/src/main/scala/util/RunUtils.scala b/src/main/scala/util/RunUtils.scala index 0ddf35b38..8a20a5070 100644 --- a/src/main/scala/util/RunUtils.scala +++ b/src/main/scala/util/RunUtils.scala @@ -310,8 +310,6 @@ object StaticAnalysis { Logger.debug("Subroutine Addresses:") Logger.debug(subroutines) - assert(invariant.blocksUniqueToEachProcedure(ctx.program)) - Logger.info("reducible loops") // reducible loops val detector = LoopDetector(IRProgram) @@ -530,31 +528,39 @@ object RunUtils { def doSimplify(ctx: IRContext, config: Option[StaticAnalysisConfig]) : Unit = { Logger.info("[!] Running Simplify") + val timer = PerformanceTimer("Simplify") + transforms.applyRPO(ctx.program) + Logger.info(s"RPO ${timer.checkPoint("RPO")} ms ") Logger.info("[!] Simplify :: DynamicSingleAssignment") - val liveVars : Map[CFGPosition, Set[Variable]] = analysis.IntraLiveVarsAnalysis(ctx.program).analyze() - - writeToFile(serialiseIL(ctx.program), s"il-before-dsa.il") - transforms.DynamicSingleAssignment.applyTransform(ctx.program, liveVars) + // val liveVars : Map[CFGPosition, Set[Variable]] = analysis.IntraLiveVarsAnalysis(ctx.program).analyze() + + // writeToFile(serialiseIL(ctx.program), s"il-before-dsa.il") + + // transforms.DynamicSingleAssignment.applyTransform(ctx.program, liveVars) + transforms.OnePassDSA().applyTransform(ctx.program) + Logger.info(s"DSA ${timer.checkPoint("DSA ")} ms ") + // writeToFile(dotBlockGraph(ctx.program, ctx.program.filter(_.isInstanceOf[Block]).map(b => b -> b.toString).toMap), s"blockgraph-after-dsa.dot") + Logger.info("DSA Check") + val x = ctx.program.procedures.forall(transforms.rdDSAProperty) + assert(x) + Logger.info("DSA Check passed") + assert(invariant.singleCallBlockEnd(ctx.program)) + assert(invariant.cfgCorrect(ctx.program)) + assert(invariant.blocksUniqueToEachProcedure(ctx.program)) - config.foreach(_.analysisDotPath.foreach { s => - writeToFile(dotBlockGraph(ctx.program, ctx.program.filter(_.isInstanceOf[Block]).map(b => b -> b.toString).toMap), s"${s}_blockgraph-after-dsa.dot") - }) - writeToFile(serialiseIL(ctx.program), s"il-before-copyprop.il") + // writeToFile(serialiseIL(ctx.program), s"il-before-copyprop.il") transforms.doCopyPropTransform(ctx.program) - writeToFile(serialiseIL(ctx.program), s"il-after-copyprop.il") + Logger.info(s"CopyProp ${timer.checkPoint("CopyProp")} ms ") + // writeToFile(serialiseIL(ctx.program), s"il-after-copyprop.il") // run this after cond recovery because sign bit calculations often need high bits // which go away in high level conss - Logger.info("[!] Simplify :: RemoveSlices") - transforms.removeSlices(ctx.program) - writeToFile(serialiseIL(ctx.program), s"il-after-slices.il") + // writeToFile(serialiseIL(ctx.program), s"il-after-slices.il") - config.foreach(_.analysisDotPath.foreach { s => - writeToFile(dotBlockGraph(ctx.program, ctx.program.filter(_.isInstanceOf[Block]).map(b => b -> b.toString).toMap), s"${s}_blockgraph-after-simp.dot") - }) + // writeToFile(dotBlockGraph(ctx.program, ctx.program.filter(_.isInstanceOf[Block]).map(b => b -> b.toString).toMap), s"blockgraph-after-simp.dot") if (ir.eval.SimplifyValidation.validate) { Logger.info("[!] Simplify :: Writing simplification validation") @@ -572,6 +578,10 @@ object RunUtils { var ctx = IRLoading.load(q.loading) + assert(invariant.singleCallBlockEnd(ctx.program)) + assert(invariant.cfgCorrect(ctx.program)) + assert(invariant.blocksUniqueToEachProcedure(ctx.program)) + ctx = IRTransform.doCleanup(ctx) if (q.loading.trimEarly) { @@ -580,9 +590,6 @@ object RunUtils { Logger.info( s"[!] Removed ${before - ctx.program.procedures.size} functions (${ctx.program.procedures.size} remaining)" ) - - val dumpdomain = computeDomain[CFGPosition, CFGPosition](InterProcIRCursor, ctx.program.procedures) - writeToFile(toDot(dumpdomain, InterProcIRCursor, Map.empty), s"ldakjldajnew_ir_intercfg.dot") } if (q.loading.parameterForm) { @@ -680,6 +687,7 @@ object RunUtils { } def writeToFile(content: String, fileName: String): Unit = { + Logger.debug(s"Writing $fileName (${content.size} bytes)") val outFile = File(fileName) val pw = PrintWriter(outFile, "UTF-8") pw.write(content) diff --git a/src/test/correct/basic_arrays_write/basic_arrays_write.spec b/src/test/correct/basic_arrays_write/basic_arrays_write.spec index cfe441557..6e181ea4b 100644 --- a/src/test/correct/basic_arrays_write/basic_arrays_write.spec +++ b/src/test/correct/basic_arrays_write/basic_arrays_write.spec @@ -6,4 +6,4 @@ Rely: true Guarantee: old(arr[0]) == arr[0] Subroutine: main -Requires: Gamma_main_argc == false \ No newline at end of file +Requires: Gamma_R0 == false diff --git a/src/test/correct/basic_function_call_caller/basic_function_call_caller.spec b/src/test/correct/basic_function_call_caller/basic_function_call_caller.spec index 3a35abe63..b7b91e75e 100644 --- a/src/test/correct/basic_function_call_caller/basic_function_call_caller.spec +++ b/src/test/correct/basic_function_call_caller/basic_function_call_caller.spec @@ -7,7 +7,7 @@ Rely: x == old(x), y == old(y) Guarantee: old(x) == 0bv32 ==> x == 0bv32, old(Gamma_y) ==> x == 0bv32 || Gamma_y Subroutine: main -Requires: Gamma_main_argc == false +Requires: Gamma_R0 == false Subroutine: zero Ensures: zero_result == 0bv32 && Gamma_R0 diff --git a/src/test/correct/basic_lock_security_write/basic_lock_security_write.spec b/src/test/correct/basic_lock_security_write/basic_lock_security_write.spec index 5804dfbcc..f7515ed10 100644 --- a/src/test/correct/basic_lock_security_write/basic_lock_security_write.spec +++ b/src/test/correct/basic_lock_security_write/basic_lock_security_write.spec @@ -8,4 +8,4 @@ Guarantee: old(z) == 0bv32 ==> x == old(x) && z == old(z) Subroutine: main Requires: z != 0bv32 -Requires: Gamma_main_argc == false \ No newline at end of file +Requires: Gamma_R0 == false diff --git a/src/test/correct/basic_sec_policy_write/basic_sec_policy_write.spec b/src/test/correct/basic_sec_policy_write/basic_sec_policy_write.spec index a66060980..3780fc89a 100644 --- a/src/test/correct/basic_sec_policy_write/basic_sec_policy_write.spec +++ b/src/test/correct/basic_sec_policy_write/basic_sec_policy_write.spec @@ -7,4 +7,4 @@ Rely: old(z) == z Guarantee: old(z) != 0bv32 ==> z != 0bv32 Subroutine: main -Requires: Gamma_main_argc == false \ No newline at end of file +Requires: Gamma_R0 == false diff --git a/src/test/correct/functionpointer/functionpointer.spec b/src/test/correct/functionpointer/functionpointer.spec index 84cc2d8c6..f49fabf7f 100644 --- a/src/test/correct/functionpointer/functionpointer.spec +++ b/src/test/correct/functionpointer/functionpointer.spec @@ -1,2 +1,2 @@ Subroutine: main -Requires: Gamma_main_argc == true \ No newline at end of file +Requires: Gamma_R0 == true diff --git a/src/test/correct/ifbranches/ifbranches.spec b/src/test/correct/ifbranches/ifbranches.spec index 84cc2d8c6..f49fabf7f 100644 --- a/src/test/correct/ifbranches/ifbranches.spec +++ b/src/test/correct/ifbranches/ifbranches.spec @@ -1,2 +1,2 @@ Subroutine: main -Requires: Gamma_main_argc == true \ No newline at end of file +Requires: Gamma_R0 == true diff --git a/src/test/indirect_calls/functionpointer/functionpointer.spec b/src/test/indirect_calls/functionpointer/functionpointer.spec index 84cc2d8c6..f49fabf7f 100644 --- a/src/test/indirect_calls/functionpointer/functionpointer.spec +++ b/src/test/indirect_calls/functionpointer/functionpointer.spec @@ -1,2 +1,2 @@ Subroutine: main -Requires: Gamma_main_argc == true \ No newline at end of file +Requires: Gamma_R0 == true diff --git a/src/test/indirect_calls/jumptable3/jumptable3.spec b/src/test/indirect_calls/jumptable3/jumptable3.spec index 84cc2d8c6..f49fabf7f 100644 --- a/src/test/indirect_calls/jumptable3/jumptable3.spec +++ b/src/test/indirect_calls/jumptable3/jumptable3.spec @@ -1,2 +1,2 @@ Subroutine: main -Requires: Gamma_main_argc == true \ No newline at end of file +Requires: Gamma_R0 == true diff --git a/src/test/indirect_calls/switch2/switch2.spec b/src/test/indirect_calls/switch2/switch2.spec index 84cc2d8c6..f49fabf7f 100644 --- a/src/test/indirect_calls/switch2/switch2.spec +++ b/src/test/indirect_calls/switch2/switch2.spec @@ -1,2 +1,2 @@ Subroutine: main -Requires: Gamma_main_argc == true \ No newline at end of file +Requires: Gamma_R0 == true