From 834f34e68889ddbc74950088f684cd682f368cbe Mon Sep 17 00:00:00 2001 From: yousifpatti Date: Tue, 13 Aug 2024 16:02:23 +1000 Subject: [PATCH] Merging only region injection files --- src/main/scala/analysis/MemoryModelMap.scala | 227 +++++++++++++- src/main/scala/analysis/RegionInjector.scala | 301 +++++++++++++++++++ src/main/scala/util/RunUtils.scala | 4 + 3 files changed, 529 insertions(+), 3 deletions(-) create mode 100644 src/main/scala/analysis/RegionInjector.scala diff --git a/src/main/scala/analysis/MemoryModelMap.scala b/src/main/scala/analysis/MemoryModelMap.scala index d91340d5b..eb14b983c 100644 --- a/src/main/scala/analysis/MemoryModelMap.scala +++ b/src/main/scala/analysis/MemoryModelMap.scala @@ -30,6 +30,8 @@ class MemoryModelMap { private val heapMap: mutable.Map[RangeKey, HeapRegion] = mutable.TreeMap() private val dataMap: mutable.Map[RangeKey, DataRegion] = mutable.TreeMap() + private val uf = new UnionFind() + /** Add a range and object to the mapping * * @param offset the offset of the range @@ -66,6 +68,15 @@ class MemoryModelMap { currentDataMap.addOne(updatedRange -> currentMaxRegion) currentDataMap(RangeKey(offset, MAX_BIGINT)) = d } + case h: HeapRegion => + val currentHeapMap = heapMap + if (currentHeapMap.isEmpty) { + currentHeapMap(RangeKey(offset, offset + h.size.value - 1)) = h + } else { + val currentMaxRange = currentHeapMap.keys.maxBy(_.end) + val currentMaxRegion = currentHeapMap(currentMaxRange) + currentHeapMap(RangeKey(currentMaxRange.start + 1, h.size.value - 1)) = h + } } } @@ -164,6 +175,22 @@ class MemoryModelMap { for (dataRgn <- allDataRgns) { add(dataRgn.start.value, dataRgn) } + + // add heap regions + val rangeStart = 0 + for ((position, regions) <- memoryRegions) { + regions match { + case Lift(node) => + for (region <- node) { + region match { + case heapRegion: HeapRegion => + add(BigInt(0), heapRegion) + case _ => + } + } + case LiftedBottom => + } + } } // TODO: push and pop could be optimised by caching the results def pushContext(funName: String): Unit = { @@ -201,15 +228,139 @@ class MemoryModelMap { } } + /* All regions that either: + * 1. starts at value but size less than region size + * 2. starts at value but size more than region size (add both regions ie. next region) + * 3. starts between regions (start, end) and (value + size) => end + * 4. starts between regions (start, end) and (value + size) < end (add both regions ie. next region) + */ + def findStackPartialAccessesOnly(value: BigInt, size: BigInt): Set[StackRegion] = { + val matchingRegions = scala.collection.mutable.Set[StackRegion]() + + stackMap.foreach { case (range, region) => + // Condition 1: Starts at value but size less than region size + if (range.start == value && range.size > size) { + matchingRegions += region + } + // Condition 2: Starts at value but size more than region size (add subsequent regions) + else if (range.start == value && range.size < size) { + matchingRegions += region + var remainingSize = size - range.size + var nextStart = range.end + stackMap.toSeq.sortBy(_._1.start).dropWhile(_._1.start <= range.start).foreach { case (nextRange, nextRegion) => + if (remainingSize > 0) { + matchingRegions += nextRegion + remainingSize -= nextRange.size + nextStart = nextRange.end + } + } + } + // Condition 3: Starts between regions (start, end) and (value + size) => end + else if (range.start < value && (value + size) <= range.end) { + matchingRegions += region + } + // Condition 4: Starts between regions (start, end) and (value + size) < end (add subsequent regions) + else if (range.start < value && (value + size) > range.end) { + matchingRegions += region + var remainingSize = (value + size) - range.end + var nextStart = range.end + stackMap.toSeq.sortBy(_._1.start).dropWhile(_._1.start <= range.start).foreach { case (nextRange, nextRegion) => + if (remainingSize > 0) { + matchingRegions += nextRegion + remainingSize -= nextRange.size + nextStart = nextRange.end + } + } + } + } + + matchingRegions.toSet.map(returnRegion) + } + + def getRegionsWithSize(size: BigInt, function: String, negateCondition: Boolean = false): Set[MemoryRegion] = { + val matchingRegions = scala.collection.mutable.Set[MemoryRegion]() + + pushContext(function) + stackMap.foreach { + case (range, region) => + if (negateCondition) { + if (range.size != size) { + matchingRegions += region + } + } else if (range.size == size) { + matchingRegions += region + } + } + popContext() + + heapMap.foreach { case (range, region) => + if (negateCondition) { + if (range.size != size) { + matchingRegions += region + } + } else if (range.size == size) { + matchingRegions += region + } + } + + dataMap.foreach { case (range, region) => + if (negateCondition) { + if (range.size != size) { + matchingRegions += region + } + } else if (range.size == size) { + matchingRegions += region + } + } + + matchingRegions.toSet.map(returnRegion) + } + + def getAllocsPerProcedure: Map[String, Set[StackRegion]] = { + localStacks.map((name, stackRegions) => (name, stackRegions.toSet.map(returnRegion))).toMap + } + + def getAllStackRegions: Set[StackRegion] = { + localStacks.values.toSet.flatten.map(returnRegion) + } + + def getAllDataRegions: Set[DataRegion] = { + dataMap.values.toSet.map(returnRegion) + } + + def getAllHeapRegions: Set[HeapRegion] = { + heapMap.values.toSet.map(returnRegion) + } + + def getAllRegions: Set[MemoryRegion] = { + getAllStackRegions ++ getAllDataRegions ++ getAllHeapRegions + } + + def getEnd(memoryRegion: MemoryRegion): BigInt = { // TODO: This would return a list of ends + val range = memoryRegion match { + case stackRegion: StackRegion => + stackMap.find((_, obj) => obj == stackRegion).map((range, _) => range).getOrElse(RangeKey(0, 0)) + case heapRegion: HeapRegion => + heapMap.find((_, obj) => obj == heapRegion).map((range, _) => range).getOrElse(RangeKey(0, 0)) + case dataRegion: DataRegion => + dataMap.find((_, obj) => obj == dataRegion).map((range, _) => range).getOrElse(RangeKey(0, 0)) + } + range.end + } + + /* All regions that start at value and are exactly of length size */ + def findStackFullAccessesOnly(value: BigInt, size: BigInt): Option[StackRegion] = { + stackMap.find((range, _) => range.start == value && range.size == size).map((range, obj) => returnRegion(obj)) + } def findStackObject(value: BigInt): Option[StackRegion] = - stackMap.find((range, _) => range.start <= value && value <= range.end).map((range, obj) => obj) + stackMap.find((range, _) => range.start <= value && value <= range.end).map((range, obj) => returnRegion(obj)) def findSharedStackObject(value: BigInt): Set[StackRegion] = - sharedStackMap.values.flatMap(_.find((range, _) => range.start <= value && value <= range.end).map((range, obj) => obj)).toSet + sharedStackMap.values.flatMap(_.find((range, _) => range.start <= value && value <= range.end).map((range, obj) => returnRegion(obj))).toSet def findDataObject(value: BigInt): Option[DataRegion] = - dataMap.find((range, _) => range.start <= value && value <= range.end).map((range, obj) => obj) + dataMap.find((range, _) => range.start <= value && value <= range.end).map((range, obj) => returnRegion(obj)) override def toString: String = s"Stack: $stackMap\n Heap: $heapMap\n Data: $dataMap\n" @@ -254,6 +405,29 @@ class MemoryModelMap { logRegion(range, region) } } + + def mergeRegions(regions: Set[MemoryRegion]): MemoryRegion = { + // assert regions are of the same type + regions.foreach(uf.makeSet) + regions.foreach(uf.union(regions.head, _)) + uf.find(regions.head) + } + + private def returnRegion(region: MemoryRegion): MemoryRegion = { + uf.find(region) + } + + private def returnRegion(region: StackRegion): StackRegion = { + uf.find(region.asInstanceOf[MemoryRegion]).asInstanceOf[StackRegion] + } + + private def returnRegion(region: DataRegion): DataRegion = { + uf.find(region.asInstanceOf[MemoryRegion]).asInstanceOf[DataRegion] + } + + private def returnRegion(region: HeapRegion): HeapRegion = { + uf.find(region.asInstanceOf[MemoryRegion]).asInstanceOf[HeapRegion] + } } trait MemoryRegion { @@ -271,3 +445,50 @@ case class HeapRegion(override val regionIdentifier: String, size: BitVecLiteral case class DataRegion(override val regionIdentifier: String, start: BitVecLiteral) extends MemoryRegion { override def toString: String = s"Data($regionIdentifier, $start)" } + +class UnionFind { + // Map to store the parent of each region + private val parent: mutable.Map[MemoryRegion, MemoryRegion] = mutable.Map() + + // Map to store the size of each set, used for union by rank + private val size: mutable.Map[MemoryRegion, Int] = mutable.Map() + + // Initialise each region to be its own parent and set size to 1 + def makeSet(region: MemoryRegion): Unit = { + parent(region) = region + size(region) = 1 + } + + // Find operation with path compression + def find(region: MemoryRegion): MemoryRegion = { + if (!parent.contains(region)) { + makeSet(region) + } + + if (parent(region) != region) { + parent(region) = find(parent(region)) // Path compression + } + parent(region) + } + + // Union operation with union by rank + def union(region1: MemoryRegion, region2: MemoryRegion): Unit = { + val root1 = find(region1) + val root2 = find(region2) + + if (root1 != root2) { + if (size(root1) < size(root2)) { + parent(root1) = root2 + size(root2) += size(root1) + } else { + parent(root2) = root1 + size(root1) += size(root2) + } + } + } + + // Check if two regions are in the same set + def connected(region1: MemoryRegion, region2: MemoryRegion): Boolean = { + find(region1) == find(region2) + } +} \ No newline at end of file diff --git a/src/main/scala/analysis/RegionInjector.scala b/src/main/scala/analysis/RegionInjector.scala new file mode 100644 index 000000000..1d6ad0781 --- /dev/null +++ b/src/main/scala/analysis/RegionInjector.scala @@ -0,0 +1,301 @@ +package analysis + +import ir.* +import util.Logger +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer + +/** + * Replaces the region access with the calculated memory region. + */ +class RegionInjector(domain: mutable.Set[CFGPosition], + program: Program, + constantProp: Map[CFGPosition, Map[RegisterWrapperEqualSets, Set[BitVecLiteral]]], + mmm: MemoryModelMap, + reachingDefs: Map[CFGPosition, (Map[Variable, Set[Assign]], Map[Variable, Set[Assign]])], + globalOffsets: Map[BigInt, BigInt]) { + private val stackPointer = Register("R31", 64) + + def nodeVisitor(): Unit = { + for (elem <- domain) {localTransfer(elem)} + program.initialMemory = transformMemorySections(program.initialMemory) + program.readOnlyMemory = transformMemorySections(program.readOnlyMemory) + } + + /** + * In expressions that have accesses within a region, we need to relocate + * the base address to the actual address using the relocation table. + * MUST RELOCATE because MMM iterate to find the lowest address + * TODO: May need to iterate over the relocation table to find the actual address + * + * @param address + * @param globalOffsets + * @return BitVecLiteral: the relocated address + */ + def relocatedBase(address: BitVecLiteral, globalOffsets: Map[BigInt, BigInt]): BitVecLiteral = { + val tableAddress = globalOffsets.getOrElse(address.value, address.value) + // this condition checks if the address is not layered and returns if it is not + if (tableAddress != address.value && !globalOffsets.contains(tableAddress)) { + return address + } + BitVecLiteral(tableAddress, address.size) + } + + /** + * Used to reduce an expression that may be a sub-region of a memory region. + * Pointer reduction example: + * R2 = R31 + 20 + * Mem[R2 + 8] <- R1 + * + * Steps: + * 1) R2 = R31 + 20 <- ie. stack access (assume R31 = stackPointer) + * ↓ + * R2 = StackRegion("stack_1", 20) + * + * 2) Mem[R2 + 8] <- R1 <- ie. memStore + * ↓ + * (StackRegion("stack_1", 20) + 8) <- R1 + * ↓ + * MMM.get(20 + 8) <- R1 + * + * @param binExpr + * @param n + * @return Set[MemoryRegion]: a set of regions that the expression may be pointing to + */ + def reducibleToRegion(binExpr: BinaryExpr, n: Command): Set[MemoryRegion] = { + var reducedRegions = Set.empty[MemoryRegion] + binExpr.arg1 match { + case variable: Variable => + evaluateExpressionWithSSA(binExpr, constantProp(n), n, reachingDefs).foreach { b => + val region = mmm.findDataObject(b.value) + reducedRegions = reducedRegions ++ region + } + if (reducedRegions.nonEmpty) { + return reducedRegions + } + val ctx = getUse(variable, n, reachingDefs) + for (i <- ctx) { + if (i != n) { // handles loops (ie. R19 = R19 + 1) %00000662 in jumptable2 + val regions = i.rhs match { + case loadL: MemoryLoad => + val foundRegions = exprToRegion(loadL.index, i) + val toReturn = mutable.Set[MemoryRegion]().addAll(foundRegions) + for { + f <- foundRegions + } { + // TODO: Must enable this (probably need to calculate those contents beforehand) +// if (memoryRegionContents.contains(f)) { +// memoryRegionContents(f).foreach { +// case b: BitVecLiteral => +// // val region = mmm.findDataObject(b.value) +// // if (region.isDefined) { +// // toReturn.addOne(region.get) +// // } +// case r: MemoryRegion => +// toReturn.addOne(r) +// toReturn.remove(f) +// } +// } + } + toReturn.toSet + case _: BitVecLiteral => + Set.empty[MemoryRegion] + case _ => + println(s"Unknown expression: ${i}") + println(ctx) + exprToRegion(i.rhs, i) + } + val results = evaluateExpressionWithSSA(binExpr.arg2, constantProp(n), n, reachingDefs) + for { + b <- results + r <- regions + } { + r match { + case stackRegion: StackRegion => + println(s"StackRegion: ${stackRegion.start}") + println(s"BitVecLiteral: ${b}") + if (b.size == stackRegion.start.size) { + val nextOffset = BinaryExpr(binExpr.op, stackRegion.start, b) + evaluateExpressionWithSSA(nextOffset, constantProp(n), n, reachingDefs).foreach { b2 => + reducedRegions ++= exprToRegion(BinaryExpr(binExpr.op, stackPointer, b2), n) + } + } + case dataRegion: DataRegion => + val nextOffset = BinaryExpr(binExpr.op, relocatedBase(dataRegion.start, globalOffsets), b) + evaluateExpressionWithSSA(nextOffset, constantProp(n), n, reachingDefs).foreach { b2 => + reducedRegions ++= exprToRegion(b2, n) + } + case _ => + } + } + } + } + case _ => + } + reducedRegions + } + + /** + * Finds a region for a given expression using MMM results + * + * @param expr + * @param n + * @return Set[MemoryRegion]: a set of regions that the expression may be pointing to + */ + def exprToRegion(expr: Expr, n: Command): Set[MemoryRegion] = { + var res = Set[MemoryRegion]() + mmm.popContext() + mmm.pushContext(IRWalk.procedure(n).name) + expr match { // TODO: Stack detection here should be done in a better way or just merged with data + case binOp: BinaryExpr if binOp.arg1 == stackPointer => + evaluateExpressionWithSSA(binOp.arg2, constantProp(n), n, reachingDefs).foreach { b => + if binOp.arg2.variables.exists { v => v.sharedVariable } then { + Logger.debug("Shared stack object: " + b) + Logger.debug("Shared in: " + expr) + val regions = mmm.findSharedStackObject(b.value) + Logger.debug("found: " + regions) + res ++= regions + } else { + val region = mmm.findStackObject(b.value) + if (region.isDefined) { + res = res + region.get + } + } + } + res + case binaryExpr: BinaryExpr => + res ++= reducibleToRegion(binaryExpr, n) + res + case v: Variable if v == stackPointer => + res ++= mmm.findStackObject(0) + res + case v: Variable => + evaluateExpressionWithSSA(expr, constantProp(n), n, reachingDefs).foreach { b => + Logger.debug("BitVecLiteral: " + b) + val region = mmm.findDataObject(b.value) + if (region.isDefined) { + res += region.get + } + } + if (res.isEmpty) { // may be passed as param + val ctx = getUse(v, n, reachingDefs) + for (i <- ctx) { + i.rhs match { + case load: MemoryLoad => // treat as a region + res ++= exprToRegion(load.index, i) + case binaryExpr: BinaryExpr => + res ++= reducibleToRegion(binaryExpr, i) + case _ => // also treat as a region (for now) even if just Base + Offset without memLoad + res ++= exprToRegion(i.rhs, i) + } + } + } + res + case _ => + evaluateExpressionWithSSA(expr, constantProp(n), n, reachingDefs).foreach { b => + Logger.debug("BitVecLiteral: " + b) + val region = mmm.findDataObject(b.value) + if (region.isDefined) { + res += region.get + } + } + res + } + } + + /** Default implementation of eval. + */ + def eval(expr: Expr, cmd: Command): Expr = { + expr match + case literal: Literal => literal // ignore literals + case Extract(end, start, body) => + Extract(end, start, eval(body, cmd)) + case Repeat(repeats, body) => + Repeat(repeats, eval(body, cmd)) + case ZeroExtend(extension, body) => + ZeroExtend(extension, eval(body, cmd)) + case SignExtend(extension, body) => + SignExtend(extension, eval(body, cmd)) + case UnaryExpr(op, arg) => + UnaryExpr(op, eval(arg, cmd)) + case BinaryExpr(op, arg1, arg2) => + BinaryExpr(op, eval(arg1, cmd), eval(arg2, cmd)) + case MemoryLoad(mem, index, endian, size) => + // TODO: index should be replaced region + MemoryLoad(renameMemory(mem, index, cmd), eval(index, cmd), endian, size) + case variable: Variable => variable // ignore variables + } + + def renameMemory(mem: Memory, expr: Expr, cmd : Command): Memory = { + val regions = exprToRegion(eval(expr, cmd), cmd) + if (regions.size == 1) { + Logger.warn(s"Mem CMD is: ${cmd}") + Logger.warn(s"Region found for mem: ${regions.head}") + regions.head match { + case stackRegion: StackRegion => + StackMemory(stackRegion.regionIdentifier, mem.addressSize, mem.valueSize) + case dataRegion: DataRegion => + SharedMemory(dataRegion.regionIdentifier, mem.addressSize, mem.valueSize) + case _ => + } + } else if (regions.size > 1) { + Logger.warn(s"Mem CMD is: ${cmd}") + Logger.warn(s"Multiple regions found for mem: ${regions}") + mmm.mergeRegions(regions) match { + case stackRegion: StackRegion => + StackMemory(stackRegion.regionIdentifier, mem.addressSize, mem.valueSize) + case dataRegion: DataRegion => + SharedMemory(dataRegion.regionIdentifier, mem.addressSize, mem.valueSize) + case _ => + } + } else { + Logger.warn(s"Mem CMD is: ${cmd}") + Logger.warn(s"No region found for mem") + } + mem + } + + /** Transfer function for state lattice elements. + */ + def localTransfer(n: CFGPosition): Unit = n match { + case cmd: Command => + cmd match + case statement: Statement => statement match + case assign: Assign => + assign.rhs = eval(assign.rhs, cmd) + case mAssign: MemoryAssign => + mAssign.mem = renameMemory(mAssign.mem, mAssign.index, cmd) + mAssign.index = eval(mAssign.index, cmd) + mAssign.value = eval(mAssign.value, cmd) + case nop: NOP => // ignore NOP + case assert: Assert => + assert.body = eval(assert.body, cmd) + case assume: Assume => + assume.body = eval(assume.body, cmd) + case jump: Jump => jump match + case to: GoTo => // ignore GoTo + case call: Call => call match + case call: DirectCall => // ignore DirectCall + case call: IndirectCall => // ignore IndirectCall + case _ => // ignore other kinds of nodes + } + + def transformMemorySections(memorySegment: ArrayBuffer[MemorySection]): ArrayBuffer[MemorySection] = { + val newArrayBuffer = ArrayBuffer.empty[MemorySection] + for (elem <- memorySegment) { + elem match { + case mem: MemorySection => + val regions = mmm.findDataObject(mem.address) + if (regions.size == 1) { + newArrayBuffer += MemorySection(regions.head.regionIdentifier, mem.address, mem.size, mem.bytes) + Logger.warn(s"Region ${regions.get.regionIdentifier} found for memory section ${mem.address}") + } else { + newArrayBuffer += mem + Logger.warn(s"No region found for memory section ${mem.address}") + } + case _ => + } + } + newArrayBuffer + } +} \ No newline at end of file diff --git a/src/main/scala/util/RunUtils.scala b/src/main/scala/util/RunUtils.scala index da353ac21..12138a67e 100644 --- a/src/main/scala/util/RunUtils.scala +++ b/src/main/scala/util/RunUtils.scala @@ -677,6 +677,10 @@ object StaticAnalysis { mmm.convertMemoryRegions(mraResult, mergedSubroutines, globalOffsets, mraSolver.procedureToSharedRegions) mmm.logRegions() + Logger.info("[!] Injecting regions") + val regionInjector = RegionInjector(domain, IRProgram, constPropResultWithSSA, mmm, reachingDefinitionsAnalysisResults, globalOffsets) + regionInjector.nodeVisitor() + Logger.info("[!] Running Steensgaard") val steensgaardSolver = InterprocSteensgaardAnalysis(IRProgram, constPropResultWithSSA, regionAccessesAnalysisResults, mmm, reachingDefinitionsAnalysisResults, globalOffsets) steensgaardSolver.analyze()