Skip to content

Commit

Permalink
Set Based SSA
Browse files Browse the repository at this point in the history
  • Loading branch information
yousifpatti committed Dec 6, 2023
1 parent 4a5b6c4 commit b9a6eb3
Show file tree
Hide file tree
Showing 4 changed files with 120 additions and 63 deletions.
154 changes: 107 additions & 47 deletions src/main/scala/analysis/SSAForm.scala
Original file line number Diff line number Diff line change
Expand Up @@ -22,82 +22,142 @@ object SSAForm {
// }
// }

def applySSA(program: Program, invasive: Boolean = false): Unit = {
val variableMapping = new mutable.HashMap[String, Int]().withDefault(_ => 0)
// def applySSA(program: Program, invasive: Boolean = false): Unit = {
// val variableMapping = new mutable.HashMap[String, Int]().withDefault(_ => 0)
// for (proc <- program.procedures) {
// val blockBasedMappings = new mutable.HashMap[Block, mutable.Map[String, Int]]()
// variableMapping.keys.foreach(key => variableMapping.update(key, variableMapping(key) + 1))
// for (block <- proc.blocks) {
// blockBasedMappings.update(block, new mutable.HashMap[String, Int]().withDefault(_ => 0))
// for (stmt <- block.statements) {
// println(stmt)
// stmt match {
// case localAssign: LocalAssign => {
// if (invasive) {
// localAssign.rhs.variables.foreach(v => {
// v.name = v.name + "_" + variableMapping(v.name)
// })
// variableMapping(localAssign.lhs.name) += 1
// localAssign.lhs.name = localAssign.lhs.name + "_" + variableMapping(localAssign.lhs.name)
// } else {
// localAssign.rhs.variables.foreach(v => {
// v.ssa_id = variableMapping(v.name)
// })
// variableMapping(localAssign.lhs.name) += 1
// localAssign.lhs.ssa_id = variableMapping(localAssign.lhs.name)
// }
// }
// case memoryAssign: MemoryAssign => {
// if (invasive) {
// memoryAssign.lhs.variables.foreach(v => {
// v.name = v.name + "_" + variableMapping(v.name)
// })
// memoryAssign.rhs.variables.foreach(v => {
// v.name = v.name + "_" + variableMapping(v.name)
// })
// } else {
// memoryAssign.lhs.variables.foreach(v => {
// v.ssa_id = variableMapping(v.name)
// })
// memoryAssign.rhs.variables.foreach(v => {
// v.ssa_id = variableMapping(v.name)
// })
// }
// }
// case assume: Assume => {
// if (invasive) {
// assume.body.variables.foreach(v => {
// v.name = v.name + "_" + variableMapping(v.name)
// })
// } else {
// assume.body.variables.foreach(v => {
// v.ssa_id = variableMapping(v.name)
// })
// }
// } // no required for analyses
// case assume: Assert => {
// if (invasive) {
// assume.body.variables.foreach(v => {
// v.name = v.name + "_" + variableMapping(v.name)
// })
// } else {
// assume.body.variables.foreach(v => {
// v.ssa_id = variableMapping(v.name)
// })
// }
// } // no required for analyses
// case _ => throw new RuntimeException("No SSA form for " + stmt.getClass + " yet")
// }
// }
// block.jump match {
// case indirectCall: IndirectCall => {
// if (invasive) {
// indirectCall.target.variables.foreach(v => {
// v.name = v.name + "_" + variableMapping(v.name)
// })
// } else {
// indirectCall.target.variables.foreach(v => {
// v.ssa_id = variableMapping(v.name)
// })
// }
// }
// case _ => {}
// }
// }
// }
// }

def applySSA(program: Program): Unit = {
val varMaxTracker = new mutable.HashMap[String, Int]()
val blockBasedMappings = new mutable.HashMap[(Block, String), Set[Int]]().withDefault(_ => Set())
for (proc <- program.procedures) {
variableMapping.keys.foreach(key => variableMapping.update(key, variableMapping(key) + 1))
for (block <- proc.blocks) {
for (stmt <- block.statements) {
println(stmt)
stmt match {
case localAssign: LocalAssign => {
if (invasive) {
localAssign.rhs.variables.foreach(v => {
v.name = v.name + "_" + variableMapping(v.name)
})
variableMapping(localAssign.lhs.name) += 1
localAssign.lhs.name = localAssign.lhs.name + "_" + variableMapping(localAssign.lhs.name)
} else {
localAssign.rhs.variables.foreach(v => {
v.ssa_id = variableMapping(v.name)
v.ssa_id = blockBasedMappings.getOrElseUpdate((block, v.name), Set())
})
variableMapping(localAssign.lhs.name) += 1
localAssign.lhs.ssa_id = variableMapping(localAssign.lhs.name)
}
val maxVal = varMaxTracker.getOrElse(localAssign.lhs.name, 0)
blockBasedMappings((block, localAssign.lhs.name)) = Set(maxVal + 1)

localAssign.lhs.ssa_id = blockBasedMappings((block, localAssign.lhs.name))
varMaxTracker(localAssign.lhs.name) = blockBasedMappings((block, localAssign.lhs.name)).max
}
case memoryAssign: MemoryAssign => {
if (invasive) {
memoryAssign.lhs.variables.foreach(v => {
v.name = v.name + "_" + variableMapping(v.name)
v.ssa_id = blockBasedMappings.getOrElseUpdate((block, v.name), Set())
})
memoryAssign.rhs.variables.foreach(v => {
v.name = v.name + "_" + variableMapping(v.name)
})
} else {
memoryAssign.lhs.variables.foreach(v => {
v.ssa_id = variableMapping(v.name)
v.ssa_id = blockBasedMappings.getOrElseUpdate((block, v.name), Set())
})
memoryAssign.rhs.variables.foreach(v => {
v.ssa_id = variableMapping(v.name)
})
}
}
case assume: Assume => {
if (invasive) {
assume.body.variables.foreach(v => {
v.name = v.name + "_" + variableMapping(v.name)
})
} else {
assume.body.variables.foreach(v => {
v.ssa_id = variableMapping(v.name)
v.ssa_id = blockBasedMappings.getOrElseUpdate((block, v.name), Set())
})
}
} // no required for analyses
case assume: Assert => {
if (invasive) {
assume.body.variables.foreach(v => {
v.name = v.name + "_" + variableMapping(v.name)
})
} else {
assume.body.variables.foreach(v => {
v.ssa_id = variableMapping(v.name)
v.ssa_id = blockBasedMappings.getOrElseUpdate((block, v.name), Set())
})
}
} // no required for analyses
case _ => throw new RuntimeException("No SSA form for " + stmt.getClass + " yet")
}
}
block.jump match {
case indirectCall: IndirectCall => {
if (invasive) {
indirectCall.target.variables.foreach(v => {
v.name = v.name + "_" + variableMapping(v.name)
})
} else {
indirectCall.target.variables.foreach(v => {
v.ssa_id = variableMapping(v.name)
indirectCall.target.variables.foreach(v => {
v.ssa_id = blockBasedMappings.getOrElseUpdate((block, v.name), Set())
})
}
case goTo: GoTo => {
goTo.targets.foreach(b => {
varMaxTracker.keys.foreach(varr => {
blockBasedMappings((b, varr)) = blockBasedMappings(b, varr) ++ blockBasedMappings(block, varr)
})
}
})
}
case _ => {}
}
Expand Down
17 changes: 3 additions & 14 deletions src/main/scala/analysis/SteensgaardAnalysis.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,8 @@ package analysis
import analysis.solvers.{Cons, Term, UnionFindSolver, Var}
import ir.*
import util.Logger

import java.io.{File, PrintWriter}
import scala.collection.mutable
import scala.collection.mutable.{ArrayBuffer, ListBuffer}
import scala.collection.mutable.ListBuffer

/** Steensgaard-style pointer analysis. The analysis associates an [[StTerm]] with each variable declaration and
* expression node in the AST. It is implemented using [[tip.solvers.UnionFindSolver]].
Expand All @@ -20,8 +18,6 @@ class SteensgaardAnalysis(

val solver: UnionFindSolver[StTerm] = UnionFindSolver()

val stringArr: ArrayBuffer[String] = ArrayBuffer()

private val stackPointer = Register("R31", BitVecType(64))
private val linkRegister = Register("R30", BitVecType(64))
private val framePointer = Register("R29", BitVecType(64))
Expand Down Expand Up @@ -141,16 +137,9 @@ class SteensgaardAnalysis(
/** @inheritdoc
*/
def analyze(): Unit =
// generate the constraints by traversing the AST and solve them on-the-fly
// generate the constraints by traversing the AST and solve them on-the-fly
cfg.nodes.foreach(visit(_, ()))

// def dump_file(content: ArrayBuffer[String], name: String): Unit = {
// val outFile = File(s"$name")
// val pw = PrintWriter(outFile, "UTF-8")
// for (s <- content) { pw.append(s + "\n") }
// pw.close()
// }

/** Generates the constraints for the given sub-AST.
* @param node
* the node for which it generates the constraints
Expand Down Expand Up @@ -179,7 +168,7 @@ class SteensgaardAnalysis(
case localAssign: LocalAssign =>
localAssign.rhs match {
case binOp: BinaryExpr =>
// X1 = &X: [[X1]] = ↑[[X2]]
// X1 = &X2: [[X1]] = ↑[[X2]]
if (binOp.arg1 == stackPointer) {
evaluateExpression(binOp.arg2, constantProp(n)) match {
case Some(b: BitVecLiteral) =>
Expand Down
10 changes: 9 additions & 1 deletion src/main/scala/ir/Expr.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package ir
import boogie._

trait Expr {
var ssa_id: Int = 0
var ssa_id: Set[Int] = Set()
def toBoogie: BExpr
def toGamma: BExpr = {
val gammaVars: Set[BExpr] = gammas.map(_.toGamma)
Expand Down Expand Up @@ -361,6 +361,14 @@ sealed trait Variable extends Expr {

override def acceptVisit(visitor: Visitor): Variable =
throw new Exception("visitor " + visitor + " unimplemented for: " + this)

override def equals(obj: Any): Boolean =
obj match {
case v: Variable => v.name == name && v.irType == irType && (v.ssa_id == ssa_id || v.ssa_id.intersect(ssa_id).nonEmpty)
case _ => false
}

override def hashCode(): Int = name.hashCode + irType.hashCode + ssa_id.hashCode()
}

case class Register(var name: String, override val irType: IRType) extends Variable with Global {
Expand Down
2 changes: 1 addition & 1 deletion src/main/scala/util/RunUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ object RunUtils {
Logger.info(subroutines)

val mergedSubroutines = subroutines ++ externalAddresses
applySSA(IRProgram, false)
applySSA(IRProgram)
val cfg = ProgramCfgFactory().fromIR(IRProgram)

Logger.info("[!] Running Constant Propagation")
Expand Down

0 comments on commit b9a6eb3

Please sign in to comment.