From b9a6eb38c661e1bc3ba9fe0b09d9898309304152 Mon Sep 17 00:00:00 2001 From: yousifpatti Date: Wed, 6 Dec 2023 11:49:08 +1000 Subject: [PATCH] Set Based SSA --- src/main/scala/analysis/SSAForm.scala | 154 ++++++++++++------ .../scala/analysis/SteensgaardAnalysis.scala | 17 +- src/main/scala/ir/Expr.scala | 10 +- src/main/scala/util/RunUtils.scala | 2 +- 4 files changed, 120 insertions(+), 63 deletions(-) diff --git a/src/main/scala/analysis/SSAForm.scala b/src/main/scala/analysis/SSAForm.scala index d6eeaa7cb..a5b182188 100644 --- a/src/main/scala/analysis/SSAForm.scala +++ b/src/main/scala/analysis/SSAForm.scala @@ -22,82 +22,142 @@ object SSAForm { // } // } - def applySSA(program: Program, invasive: Boolean = false): Unit = { - val variableMapping = new mutable.HashMap[String, Int]().withDefault(_ => 0) +// def applySSA(program: Program, invasive: Boolean = false): Unit = { +// val variableMapping = new mutable.HashMap[String, Int]().withDefault(_ => 0) +// for (proc <- program.procedures) { +// val blockBasedMappings = new mutable.HashMap[Block, mutable.Map[String, Int]]() +// variableMapping.keys.foreach(key => variableMapping.update(key, variableMapping(key) + 1)) +// for (block <- proc.blocks) { +// blockBasedMappings.update(block, new mutable.HashMap[String, Int]().withDefault(_ => 0)) +// for (stmt <- block.statements) { +// println(stmt) +// stmt match { +// case localAssign: LocalAssign => { +// if (invasive) { +// localAssign.rhs.variables.foreach(v => { +// v.name = v.name + "_" + variableMapping(v.name) +// }) +// variableMapping(localAssign.lhs.name) += 1 +// localAssign.lhs.name = localAssign.lhs.name + "_" + variableMapping(localAssign.lhs.name) +// } else { +// localAssign.rhs.variables.foreach(v => { +// v.ssa_id = variableMapping(v.name) +// }) +// variableMapping(localAssign.lhs.name) += 1 +// localAssign.lhs.ssa_id = variableMapping(localAssign.lhs.name) +// } +// } +// case memoryAssign: MemoryAssign => { +// if (invasive) { +// memoryAssign.lhs.variables.foreach(v => { +// v.name = v.name + "_" + variableMapping(v.name) +// }) +// memoryAssign.rhs.variables.foreach(v => { +// v.name = v.name + "_" + variableMapping(v.name) +// }) +// } else { +// memoryAssign.lhs.variables.foreach(v => { +// v.ssa_id = variableMapping(v.name) +// }) +// memoryAssign.rhs.variables.foreach(v => { +// v.ssa_id = variableMapping(v.name) +// }) +// } +// } +// case assume: Assume => { +// if (invasive) { +// assume.body.variables.foreach(v => { +// v.name = v.name + "_" + variableMapping(v.name) +// }) +// } else { +// assume.body.variables.foreach(v => { +// v.ssa_id = variableMapping(v.name) +// }) +// } +// } // no required for analyses +// case assume: Assert => { +// if (invasive) { +// assume.body.variables.foreach(v => { +// v.name = v.name + "_" + variableMapping(v.name) +// }) +// } else { +// assume.body.variables.foreach(v => { +// v.ssa_id = variableMapping(v.name) +// }) +// } +// } // no required for analyses +// case _ => throw new RuntimeException("No SSA form for " + stmt.getClass + " yet") +// } +// } +// block.jump match { +// case indirectCall: IndirectCall => { +// if (invasive) { +// indirectCall.target.variables.foreach(v => { +// v.name = v.name + "_" + variableMapping(v.name) +// }) +// } else { +// indirectCall.target.variables.foreach(v => { +// v.ssa_id = variableMapping(v.name) +// }) +// } +// } +// case _ => {} +// } +// } +// } +// } + + def applySSA(program: Program): Unit = { + val varMaxTracker = new mutable.HashMap[String, Int]() + val blockBasedMappings = new mutable.HashMap[(Block, String), Set[Int]]().withDefault(_ => Set()) for (proc <- program.procedures) { - variableMapping.keys.foreach(key => variableMapping.update(key, variableMapping(key) + 1)) for (block <- proc.blocks) { for (stmt <- block.statements) { println(stmt) stmt match { case localAssign: LocalAssign => { - if (invasive) { - localAssign.rhs.variables.foreach(v => { - v.name = v.name + "_" + variableMapping(v.name) - }) - variableMapping(localAssign.lhs.name) += 1 - localAssign.lhs.name = localAssign.lhs.name + "_" + variableMapping(localAssign.lhs.name) - } else { localAssign.rhs.variables.foreach(v => { - v.ssa_id = variableMapping(v.name) + v.ssa_id = blockBasedMappings.getOrElseUpdate((block, v.name), Set()) }) - variableMapping(localAssign.lhs.name) += 1 - localAssign.lhs.ssa_id = variableMapping(localAssign.lhs.name) - } + val maxVal = varMaxTracker.getOrElse(localAssign.lhs.name, 0) + blockBasedMappings((block, localAssign.lhs.name)) = Set(maxVal + 1) + + localAssign.lhs.ssa_id = blockBasedMappings((block, localAssign.lhs.name)) + varMaxTracker(localAssign.lhs.name) = blockBasedMappings((block, localAssign.lhs.name)).max } case memoryAssign: MemoryAssign => { - if (invasive) { memoryAssign.lhs.variables.foreach(v => { - v.name = v.name + "_" + variableMapping(v.name) + v.ssa_id = blockBasedMappings.getOrElseUpdate((block, v.name), Set()) }) memoryAssign.rhs.variables.foreach(v => { - v.name = v.name + "_" + variableMapping(v.name) - }) - } else { - memoryAssign.lhs.variables.foreach(v => { - v.ssa_id = variableMapping(v.name) + v.ssa_id = blockBasedMappings.getOrElseUpdate((block, v.name), Set()) }) - memoryAssign.rhs.variables.foreach(v => { - v.ssa_id = variableMapping(v.name) - }) - } } case assume: Assume => { - if (invasive) { - assume.body.variables.foreach(v => { - v.name = v.name + "_" + variableMapping(v.name) - }) - } else { assume.body.variables.foreach(v => { - v.ssa_id = variableMapping(v.name) + v.ssa_id = blockBasedMappings.getOrElseUpdate((block, v.name), Set()) }) - } } // no required for analyses case assume: Assert => { - if (invasive) { - assume.body.variables.foreach(v => { - v.name = v.name + "_" + variableMapping(v.name) - }) - } else { assume.body.variables.foreach(v => { - v.ssa_id = variableMapping(v.name) + v.ssa_id = blockBasedMappings.getOrElseUpdate((block, v.name), Set()) }) - } } // no required for analyses case _ => throw new RuntimeException("No SSA form for " + stmt.getClass + " yet") } } block.jump match { case indirectCall: IndirectCall => { - if (invasive) { - indirectCall.target.variables.foreach(v => { - v.name = v.name + "_" + variableMapping(v.name) - }) - } else { - indirectCall.target.variables.foreach(v => { - v.ssa_id = variableMapping(v.name) + indirectCall.target.variables.foreach(v => { + v.ssa_id = blockBasedMappings.getOrElseUpdate((block, v.name), Set()) + }) + } + case goTo: GoTo => { + goTo.targets.foreach(b => { + varMaxTracker.keys.foreach(varr => { + blockBasedMappings((b, varr)) = blockBasedMappings(b, varr) ++ blockBasedMappings(block, varr) }) - } + }) } case _ => {} } diff --git a/src/main/scala/analysis/SteensgaardAnalysis.scala b/src/main/scala/analysis/SteensgaardAnalysis.scala index c137d6104..bc687fdb6 100644 --- a/src/main/scala/analysis/SteensgaardAnalysis.scala +++ b/src/main/scala/analysis/SteensgaardAnalysis.scala @@ -3,10 +3,8 @@ package analysis import analysis.solvers.{Cons, Term, UnionFindSolver, Var} import ir.* import util.Logger - -import java.io.{File, PrintWriter} import scala.collection.mutable -import scala.collection.mutable.{ArrayBuffer, ListBuffer} +import scala.collection.mutable.ListBuffer /** Steensgaard-style pointer analysis. The analysis associates an [[StTerm]] with each variable declaration and * expression node in the AST. It is implemented using [[tip.solvers.UnionFindSolver]]. @@ -20,8 +18,6 @@ class SteensgaardAnalysis( val solver: UnionFindSolver[StTerm] = UnionFindSolver() - val stringArr: ArrayBuffer[String] = ArrayBuffer() - private val stackPointer = Register("R31", BitVecType(64)) private val linkRegister = Register("R30", BitVecType(64)) private val framePointer = Register("R29", BitVecType(64)) @@ -141,16 +137,9 @@ class SteensgaardAnalysis( /** @inheritdoc */ def analyze(): Unit = - // generate the constraints by traversing the AST and solve them on-the-fly + // generate the constraints by traversing the AST and solve them on-the-fly cfg.nodes.foreach(visit(_, ())) -// def dump_file(content: ArrayBuffer[String], name: String): Unit = { -// val outFile = File(s"$name") -// val pw = PrintWriter(outFile, "UTF-8") -// for (s <- content) { pw.append(s + "\n") } -// pw.close() -// } - /** Generates the constraints for the given sub-AST. * @param node * the node for which it generates the constraints @@ -179,7 +168,7 @@ class SteensgaardAnalysis( case localAssign: LocalAssign => localAssign.rhs match { case binOp: BinaryExpr => - // X1 = &X: [[X1]] = ↑[[X2]] + // X1 = &X2: [[X1]] = ↑[[X2]] if (binOp.arg1 == stackPointer) { evaluateExpression(binOp.arg2, constantProp(n)) match { case Some(b: BitVecLiteral) => diff --git a/src/main/scala/ir/Expr.scala b/src/main/scala/ir/Expr.scala index 3bee8beea..8c68d815f 100644 --- a/src/main/scala/ir/Expr.scala +++ b/src/main/scala/ir/Expr.scala @@ -3,7 +3,7 @@ package ir import boogie._ trait Expr { - var ssa_id: Int = 0 + var ssa_id: Set[Int] = Set() def toBoogie: BExpr def toGamma: BExpr = { val gammaVars: Set[BExpr] = gammas.map(_.toGamma) @@ -361,6 +361,14 @@ sealed trait Variable extends Expr { override def acceptVisit(visitor: Visitor): Variable = throw new Exception("visitor " + visitor + " unimplemented for: " + this) + + override def equals(obj: Any): Boolean = + obj match { + case v: Variable => v.name == name && v.irType == irType && (v.ssa_id == ssa_id || v.ssa_id.intersect(ssa_id).nonEmpty) + case _ => false + } + + override def hashCode(): Int = name.hashCode + irType.hashCode + ssa_id.hashCode() } case class Register(var name: String, override val irType: IRType) extends Variable with Global { diff --git a/src/main/scala/util/RunUtils.scala b/src/main/scala/util/RunUtils.scala index 0a001fedb..68123f7ea 100644 --- a/src/main/scala/util/RunUtils.scala +++ b/src/main/scala/util/RunUtils.scala @@ -143,7 +143,7 @@ object RunUtils { Logger.info(subroutines) val mergedSubroutines = subroutines ++ externalAddresses - applySSA(IRProgram, false) + applySSA(IRProgram) val cfg = ProgramCfgFactory().fromIR(IRProgram) Logger.info("[!] Running Constant Propagation")