diff --git a/src/main/scala/analysis/ContextTransfer.scala b/src/main/scala/analysis/ContextTransfer.scala new file mode 100644 index 000000000..f4cf346c0 --- /dev/null +++ b/src/main/scala/analysis/ContextTransfer.scala @@ -0,0 +1,74 @@ +package analysis +import ir.* +import util.Logger + +import scala.collection.mutable + +/** Steensgaard-style pointer analysis. The analysis associates an [[UnifiableTerm]] with each variable declaration and + * expression node in the AST. It is implemented using [[tip.solvers.UnionFindSolver]]. + */ +class ContextTransfer( + cfg: ProgramCfg, + constantProp: Map[CfgNode, Map[Variable, FlatElement[BitVecLiteral]]]) extends Analysis[mutable.Map[Procedure, mutable.Map[Variable, mutable.Set[FlatElement[BitVecLiteral]]]]] { + + val functionMergedCtx = mutable.Map[Procedure, mutable.Map[Variable, mutable.Set[FlatElement[BitVecLiteral]]]]() + + + def mergeContexts(ctx1: Map[Variable, FlatElement[BitVecLiteral]], + ctx2: mutable.Map[Variable, mutable.Set[FlatElement[BitVecLiteral]]] = mutable.Map.empty): mutable.Map[Variable, mutable.Set[FlatElement[BitVecLiteral]]] = { + val mergedCtx = mutable.Map[Variable, mutable.Set[FlatElement[BitVecLiteral]]]() + ctx1.foreach { case (v, e) => { + val set = mutable.Set[FlatElement[BitVecLiteral]]() + set += e + mergedCtx(v) = set + }} + if (ctx2.isEmpty) return mergedCtx + + ctx2.foreach { case (v, e) => { + val set = mergedCtx.getOrElse(v, mutable.Set[FlatElement[BitVecLiteral]]()) + set ++= e + mergedCtx(v) = set + }} + mergedCtx + } + + /** @inheritdoc + */ + def analyze(): mutable.Map[Procedure, mutable.Map[Variable, mutable.Set[FlatElement[BitVecLiteral]]]] = + // generate the constraints by traversing the AST and solve them on-the-fly + cfg.nodes.foreach(n => visit(n, ())) + functionMergedCtx + + /** Generates the constraints for the given sub-AST. + * + * @param node + * the node for which it generates the constraints + * @param arg + * unused for this visitor + */ + def visit(n: CfgNode, arg: Unit): Unit = { + + n match { + case cfgJumpNode: CfgJumpNode => { + cfgJumpNode.data match { + case directCall: DirectCall => { + val currentCtx = constantProp(n) + val procedure = directCall.target + val procedureCtx = functionMergedCtx.get(procedure) + procedureCtx match { + case Some(ctx) => { + val mergedCtx = mergeContexts(currentCtx, ctx) + functionMergedCtx(procedure) = mergedCtx + } + case None => { + val mergedCtx = mergeContexts(currentCtx) + functionMergedCtx(procedure) = mergedCtx + } + } + } + } + } + case _ => // do nothing + } + } +} \ No newline at end of file diff --git a/src/main/scala/analysis/SteensgaardAnalysis.scala b/src/main/scala/analysis/SteensgaardAnalysis.scala index a15b87b18..969190271 100644 --- a/src/main/scala/analysis/SteensgaardAnalysis.scala +++ b/src/main/scala/analysis/SteensgaardAnalysis.scala @@ -33,6 +33,7 @@ case class RegisterVariableWrapper(variable: Variable) extends VariableWrapper { class SteensgaardAnalysis( cfg: ProgramCfg, constantProp: Map[CfgNode, Map[Variable, FlatElement[BitVecLiteral]]], + contextTransfer: mutable.Map[Procedure, mutable.Map[Variable, mutable.Set[FlatElement[BitVecLiteral]]]], globals: Map[BigInt, String], globalOffsets: Map[BigInt, BigInt], subroutines: Map[BigInt, String]) extends Analysis[Any] { @@ -97,6 +98,85 @@ class SteensgaardAnalysis( buffers } + + def evalUsingContext(exp: Expr, n: CfgCommandNode, returnSet: mutable.Set[MemoryRegion]): mutable.Set[MemoryRegion] = { + Logger.debug(s"evaluating using context $exp") + Logger.debug(s"n: $n") + exp match { + case binOp: BinaryExpr => + if (binOp.arg1 == stackPointer) { + evaluateExpression(binOp.arg2, constantProp(n)) match { + case Some(b: BitVecLiteral) => + returnSet.add(poolMaster(b, n.parent)) + returnSet + case None => + returnSet + } + } else { + evaluateExpression(binOp, constantProp(n)) match { + case Some(b: BitVecLiteral) => + returnSet.addAll(evalUsingContext(b, n, returnSet)) + returnSet + case None => + returnSet + } + } + case bitVecLiteral: BitVecLiteral => + if (globals.contains(bitVecLiteral.value)) { + val globalName = globals(bitVecLiteral.value) + returnSet.add(DataRegion(globalName, bitVecLiteral)) + returnSet + } else if (subroutines.contains(bitVecLiteral.value)) { + val subroutineName = subroutines(bitVecLiteral.value) + returnSet.add(DataRegion(subroutineName, bitVecLiteral)) + returnSet + } else if (globalOffsets.contains(bitVecLiteral.value)) { + val val1 = globalOffsets(bitVecLiteral.value) + if (subroutines.contains(val1)) { + val globalName = subroutines(val1) + returnSet.add(DataRegion(globalName, bitVecLiteral)) + returnSet + } else { + returnSet.add(DataRegion(s"Unknown_$bitVecLiteral", bitVecLiteral)) + returnSet + } + } else { + //throw new Exception(s"Unknown type for $bitVecLiteral") + // unknown region here + returnSet.add(DataRegion(s"Unknown_$bitVecLiteral", bitVecLiteral)) + returnSet + } + case variable: Variable => + variable match { + case _: LocalVar => + returnSet + case reg: Register if reg == stackPointer => + returnSet + case _ => + evaluateExpression(variable, constantProp(n)) match { + case Some(b: BitVecLiteral) => + returnSet.addAll(evalUsingContext(b, n, returnSet)) + returnSet + case _ => + contextTransfer(n.parent.data).get(variable) match { + case Some(set) => + set.foreach { + case FlatEl(el) => returnSet.addAll(evalUsingContext(el, n, returnSet)) + case _ => + } + returnSet + case None => + returnSet // we cannot evaluate this to a concrete value, we need VSA for this + } + } + } + // we cannot evaluate this to a concrete value, we need VSA for this + case _ => + Logger.debug(s"type: ${exp.getClass} $exp\n") + throw new Exception("Unknown type") + } + } + def eval(exp: Expr, n: CfgCommandNode): MemoryRegion | Expr = { Logger.debug(s"evaluating $exp") Logger.debug(s"n: $n") @@ -210,7 +290,7 @@ class SteensgaardAnalysis( unify(alpha, varToStTerm(RegisterVariableWrapper(X1))) // TODO: This might not be correct for globals - // X1 = &X: [[X1]] = ^[[X2]] (but for globals) + // X1 = &X: [[X1]] = ↑[[X2]] (but for globals) val $X2 = eval(memoryLoad.index, cmd) $X2 match case region: MemoryRegion => diff --git a/src/main/scala/util/RunUtils.scala b/src/main/scala/util/RunUtils.scala index 68123f7ea..4937be77b 100644 --- a/src/main/scala/util/RunUtils.scala +++ b/src/main/scala/util/RunUtils.scala @@ -153,8 +153,9 @@ object RunUtils { config.analysisDotPath.foreach(s => writeToFile(cfg.toDot(Output.labeler(constPropResult, true), Output.dotIder), s"${s}_constprop$iteration.dot")) config.analysisResultsPath.foreach(s => writeToFile(printAnalysisResults(cfg, constPropResult, iteration), s"${s}_constprop$iteration.txt")) + val contextTransfer = ContextTransfer(cfg, constPropResult).analyze() Logger.info("[!] Running Steensgaard") - val steensgaardSolver = SteensgaardAnalysis(cfg, constPropResult, globalAddresses, globalOffsets, mergedSubroutines) + val steensgaardSolver = SteensgaardAnalysis(cfg, constPropResult, contextTransfer, globalAddresses, globalOffsets, mergedSubroutines) steensgaardSolver.analyze() steensgaardSolver.pointsTo() steensgaardSolver.mayAlias()