diff --git a/src/main/scala/analysis/SSAForm.scala b/src/main/scala/analysis/SSAForm.scala index a5b182188..c1d444460 100644 --- a/src/main/scala/analysis/SSAForm.scala +++ b/src/main/scala/analysis/SSAForm.scala @@ -120,10 +120,10 @@ object SSAForm { v.ssa_id = blockBasedMappings.getOrElseUpdate((block, v.name), Set()) }) val maxVal = varMaxTracker.getOrElse(localAssign.lhs.name, 0) - blockBasedMappings((block, localAssign.lhs.name)) = Set(maxVal + 1) + blockBasedMappings((block, localAssign.lhs.name)) = Set(maxVal) localAssign.lhs.ssa_id = blockBasedMappings((block, localAssign.lhs.name)) - varMaxTracker(localAssign.lhs.name) = blockBasedMappings((block, localAssign.lhs.name)).max + varMaxTracker(localAssign.lhs.name) = blockBasedMappings((block, localAssign.lhs.name)).max + 1 } case memoryAssign: MemoryAssign => { memoryAssign.lhs.variables.foreach(v => { diff --git a/src/main/scala/analysis/SteensgaardAnalysis.scala b/src/main/scala/analysis/SteensgaardAnalysis.scala index bc687fdb6..e412106ac 100644 --- a/src/main/scala/analysis/SteensgaardAnalysis.scala +++ b/src/main/scala/analysis/SteensgaardAnalysis.scala @@ -6,6 +6,27 @@ import util.Logger import scala.collection.mutable import scala.collection.mutable.ListBuffer +/** Wrapper for variables so we can have Steensgaard-specific equals method indirectly */ +sealed trait VariableWrapper { + val variable: Variable +} + +/** Wrapper for variables so we can have Steensgaard-specific equals method indirectly */ +case class RegisterVariableWrapper(variable: Variable) extends VariableWrapper { + override def equals(obj: Any): Boolean = { + obj match { + case RegisterVariableWrapper(other) => + variable == other && (variable.ssa_id == other.ssa_id || variable.ssa_id.intersect(other.ssa_id).nonEmpty) + case _ => + false + } + } + + override def hashCode(): Int = { + variable.hashCode() + variable.ssa_id.hashCode() + } +} + /** Steensgaard-style pointer analysis. The analysis associates an [[StTerm]] with each variable declaration and * expression node in the AST. It is implemented using [[tip.solvers.UnionFindSolver]]. */ @@ -148,7 +169,7 @@ class SteensgaardAnalysis( */ def visit(n: CfgNode, arg: Unit): Unit = { - def varToStTerm(vari: Variable): Term[StTerm] = IdentifierVariable(vari) + def varToStTerm(vari: VariableWrapper): Term[StTerm] = IdentifierVariable(vari) def exprToStTerm(expr: MemoryRegion | Expr): Term[StTerm] = ExpressionVariable(expr) def allocToTerm(alloc: MemoryRegion): Term[StTerm] = AllocVariable(alloc) //def identifierToTerm(id: AIdentifier): Term[StTerm] = IdentifierVariable(id) @@ -162,7 +183,7 @@ class SteensgaardAnalysis( evaluateExpression(mallocVariable, constantProp(n)) match { case Some(b: BitVecLiteral) => val alloc = HeapRegion(nextMallocCount(), b) - unify(varToStTerm(mallocVariable), PointerRef(allocToTerm(alloc))) + unify(varToStTerm(RegisterVariableWrapper(mallocVariable)), PointerRef(allocToTerm(alloc))) } } case localAssign: LocalAssign => @@ -174,7 +195,7 @@ class SteensgaardAnalysis( case Some(b: BitVecLiteral) => val $X2 = poolMaster(b, cmd.parent) val X1 = localAssign.lhs - unify(varToStTerm(X1), PointerRef(allocToTerm($X2))) + unify(varToStTerm(RegisterVariableWrapper(X1)), PointerRef(allocToTerm($X2))) } } // TODO: should lookout for global base + offset case as well @@ -186,26 +207,27 @@ class SteensgaardAnalysis( val X2_star = eval(memoryLoad.index, cmd) val alpha = FreshVariable() unify(exprToStTerm(X2_star), PointerRef(alpha)) // TODO: X2_star should be the value of the memload not the memload itself - unify(alpha, varToStTerm(X1)) + unify(alpha, varToStTerm(RegisterVariableWrapper(X1))) // TODO: This might not be correct for globals // X1 = &X: [[X1]] = ^[[X2]] (but for globals) val $X2 = eval(memoryLoad.index, cmd) $X2 match case region: MemoryRegion => - unify(varToStTerm(X1), PointerRef(allocToTerm(region))) + unify(varToStTerm(RegisterVariableWrapper(X1)), PointerRef(allocToTerm(region))) case _ => case variable: Variable => // X1 = X2: [[X1]] = [[X2]] val X1 = localAssign.lhs val X2 = variable - unify(varToStTerm(X1), varToStTerm(X2)) + unify(varToStTerm(RegisterVariableWrapper(X1)), varToStTerm(RegisterVariableWrapper(X2))) } } case memoryAssign: MemoryAssign => // *X1 = X2: [[X1]] = ↑a ^ [[X2]] = a where a is a fresh term variable val X1_star = eval(memoryAssign.rhs.index, cmd) val X2 = evaluateExpression(memoryAssign.rhs.value, constantProp(n)).getOrElse(memoryAssign.rhs.value) + println(X2) val alpha = FreshVariable() unify(exprToStTerm(X1_star), PointerRef(alpha)) unify(alpha, exprToStTerm(X2)) @@ -226,7 +248,7 @@ class SteensgaardAnalysis( /** @inheritdoc */ - def pointsTo(): Map[Object, Set[Variable | MemoryRegion]] = { + def pointsTo(): Map[Object, Set[VariableWrapper | MemoryRegion]] = { val solution = solver.solution() val unifications = solver.unifications() Logger.info(s"Solution: \n${solution.mkString(",\n")}\n") @@ -237,7 +259,7 @@ class SteensgaardAnalysis( .mkString(", ")}") val vars = solution.keys.collect { case id: IdentifierVariable => id } - val pointsto = vars.foldLeft(Map[Object, Set[Variable | MemoryRegion]]()) { case (a, v: IdentifierVariable) => + val pointsto = vars.foldLeft(Map[Object, Set[VariableWrapper | MemoryRegion]]()) { case (a, v: IdentifierVariable) => val pt = unifications(solution(v)) .collect({ case PointerRef(IdentifierVariable(id)) => id @@ -252,9 +274,9 @@ class SteensgaardAnalysis( /** @inheritdoc */ - def mayAlias(): (Variable, Variable) => Boolean = { + def mayAlias(): (VariableWrapper, VariableWrapper) => Boolean = { val solution = solver.solution() - (id1: Variable, id2: Variable) => + (id1: VariableWrapper, id2: VariableWrapper) => val sol1 = solution(IdentifierVariable(id1)) val sol2 = solution(IdentifierVariable(id2)) sol1 == sol2 && sol1.isInstanceOf[PointerRef] // same equivalence class, and it contains a reference @@ -274,7 +296,7 @@ case class AllocVariable(alloc: MemoryRegion) extends StTerm with Var[StTerm] { /** A term variable that represents an identifier in the program. */ -case class IdentifierVariable(id: Variable) extends StTerm with Var[StTerm] { +case class IdentifierVariable(id: VariableWrapper) extends StTerm with Var[StTerm] { override def toString: String = s"$id" } diff --git a/src/main/scala/ir/Expr.scala b/src/main/scala/ir/Expr.scala index 8c68d815f..c4dbc2c0b 100644 --- a/src/main/scala/ir/Expr.scala +++ b/src/main/scala/ir/Expr.scala @@ -361,14 +361,6 @@ sealed trait Variable extends Expr { override def acceptVisit(visitor: Visitor): Variable = throw new Exception("visitor " + visitor + " unimplemented for: " + this) - - override def equals(obj: Any): Boolean = - obj match { - case v: Variable => v.name == name && v.irType == irType && (v.ssa_id == ssa_id || v.ssa_id.intersect(ssa_id).nonEmpty) - case _ => false - } - - override def hashCode(): Int = name.hashCode + irType.hashCode + ssa_id.hashCode() } case class Register(var name: String, override val irType: IRType) extends Variable with Global {