Skip to content

Commit

Permalink
Merge pull request #258 from UQ-PAC/points_to_analysis_overlapping_ac…
Browse files Browse the repository at this point in the history
…cess

added functionality to track overlapping accesses
  • Loading branch information
l-kent authored Nov 25, 2024
2 parents 4586ce2 + 3baa6e3 commit 7ef566f
Show file tree
Hide file tree
Showing 31 changed files with 5,717 additions and 109 deletions.
Empty file modified scripts/lift.sh
100644 → 100755
Empty file.
Original file line number Diff line number Diff line change
Expand Up @@ -124,27 +124,30 @@ class DataStructureAnalysis(program: Program,
calleeGraph.globalMapping.foreach {
case (range: AddressRange, Field(node: Node, offset: BigInt)) =>
val field = calleeGraph.find(node)
buGraph.mergeCells(
val res = buGraph.mergeCells(
buGraph.globalMapping(range).node.getCell(buGraph.globalMapping(range).offset),
field.node.getCell(field.offset + offset)
)
buGraph.handleOverlapping(res)
}

if (buGraph.varToCell.contains(callee)) {
buGraph.varToCell(callee).keys.foreach { variable =>
if (!ignoreRegisters.contains(variable)) {
val formal = buGraph.varToCell(callee)(variable)
buGraph.mergeCells(buGraph.adjust(formal), buGraph.adjust(callSite.paramCells(variable)))
val res = buGraph.mergeCells(buGraph.adjust(formal), buGraph.adjust(callSite.paramCells(variable)))
buGraph.handleOverlapping(res)
}
}
}

writesTo(callee).foreach { reg =>
val returnCells = buGraph.getCells(IRWalk.lastInProc(callee).get, reg)
// assert(returnCells.nonEmpty)
returnCells.foldLeft(buGraph.adjust(callSite.returnCells(reg))) { (c, ret) =>
val res = returnCells.foldLeft(buGraph.adjust(callSite.returnCells(reg))) { (c, ret) =>
buGraph.mergeCells(c, buGraph.adjust(ret))
}
buGraph.handleOverlapping(res)
}
}
buGraph.collectNodes()
Expand Down Expand Up @@ -181,32 +184,37 @@ class DataStructureAnalysis(program: Program,

callSite.returnCells.values.foreach { slice =>
val node = callersGraph.find(slice).node
node.cloneNode(callersGraph, callersGraph)
node.cloneNode(callersGraph, calleesGraph)
}

callersGraph.globalMapping.foreach { case (range: AddressRange, Field(oldNode, internal)) =>
// val node = callersGraph
val field = callersGraph.find(oldNode)
calleesGraph.mergeCells(
val res = calleesGraph.mergeCells(
calleesGraph.globalMapping(range).node.getCell(calleesGraph.globalMapping(range).offset),
field.node.getCell(field.offset + internal)
)
calleesGraph.handleOverlapping(res)
}

callSite.paramCells.keySet.foreach { variable =>
val paramCells = calleesGraph.getCells(callSite.call, variable) // wrong param offset
paramCells.foldLeft(calleesGraph.adjust(calleesGraph.formals(variable))) {
val res = paramCells.foldLeft(calleesGraph.adjust(calleesGraph.formals(variable))) {
(cell, slice) => calleesGraph.mergeCells(calleesGraph.adjust(slice), cell)
}
calleesGraph.handleOverlapping(res)

}

if (calleesGraph.varToCell.contains(callSite.call)) {
calleesGraph.varToCell(callSite.call).foreach { (variable, oldSlice) =>
val slice = callersGraph.find(oldSlice)
val returnCells = calleesGraph.getCells(IRWalk.lastInProc(callee).get, variable)
returnCells.foldLeft(calleesGraph.adjust(slice)) {
val res = returnCells.foldLeft(calleesGraph.adjust(slice)) {
(c, retCell) => calleesGraph.mergeCells(c, calleesGraph.adjust(retCell))
}

calleesGraph.handleOverlapping(res)
}
}
}
Expand Down
218 changes: 195 additions & 23 deletions src/main/scala/analysis/data_structure_analysis/Graph.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import analysis.solvers.DSAUnionFindSolver
import analysis.evaluateExpression
import cfg_visualiser.*
import ir.*
import specification.{ExternalFunction, SymbolTableEntry}
import specification.{ExternalFunction, FuncEntry, SpecGlobal, SymbolTableEntry}

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
Expand Down Expand Up @@ -125,16 +125,94 @@ class Graph(val proc: Procedure,
nextValidOffset = offset + byteSize
}


/**
* Takes a cell and returns all corresponding stack offsets to it if any
*/
def getStackOffsets(cell: Cell): Set[BigInt] = { // TODO replace with tracking through merges
stackMapping.foldLeft(Set[BigInt]()) {
(s, f) =>
f match
case (offset: BigInt, node: Node) =>
s ++ node.cells.foldLeft(Set[BigInt]()) {
(se, g) =>
g match
case (internal: BigInt, stackCell: Cell) =>
if cell == find(stackCell) then
se + (offset + internal)
else
se
}
}
}

def getStack(offset: BigInt, size: Int): Cell = {
var last: BigInt = 0
var headNodeOffset: BigInt = -1

val head: Cell =
if stackMapping.contains(offset) then
headNodeOffset = offset
stackMapping(offset).cells(0)
else
breakable {
stackMapping.keys.toSeq.sorted.foreach(
elementOffset =>
if offset < elementOffset then
break
else
last = elementOffset
)
}
val diff = offset - last
headNodeOffset = last
assert(stackMapping.contains(last))
stackMapping(last).getCell(diff)

// DSA grows cell size with size
// selfCollapse at the end to merge all the overlapping accessed size
// However, the above approach prevents distinct multi loads
// find(head).growSize(size)
val headOffset = headNodeOffset + head.offset
stackMapping.keys.toSeq.filter(off => off > headOffset && off < headOffset + size).sorted.foreach {
off =>
val stackDiff = off - headOffset
val updatedHead = find(head)
val newHeadOffset = updatedHead.offset
val headNode = updatedHead.node.get
mergeCells(headNode.addCell(newHeadOffset + stackDiff, 0), find(stackMapping(off).cells(0)))
}
// selfCollapse(head.node.get)
find(head)
}


private val swappedOffsets = globalOffsets.map(_.swap)

// creates the globals from the symbol tables
val globalMapping = mutable.Map[AddressRange, Field]()
globals.foreach { global =>
val node = Node(Some(this), global.size)
node.allocationRegions.add(DataLocation(global.name, global.address, global.size / 8))
node.flags.global = true
node.flags.incomplete = true
globalMapping.update(AddressRange(global.address, global.address + global.size / 8), Field(node, 0))
globals.foreach {
case FuncEntry(name, size, address) =>
val func = Node(Some(this), size)
func.allocationRegions.add(Function(name))
func.flags.global = true
func.flags.incomplete = true
func.flags.function = true
// globalMapping.update(AddressRange(address, address + (size / 8)), Field(func, 0))

val pointer = Node(Some(this), size)
pointer.allocationRegions.add(DataPointer(name, address, size / 8)) // todo check that size 0 is correct
pointer.flags.global = true
pointer.flags.incomplete = true
pointer.cells(0).pointee = Some(Slice(func.cells(0), 0))
globalMapping.update(AddressRange(address, address + (size / 8)), Field(pointer, 0))
case SpecGlobal(name, size, arraySize, address) =>
val node = Node(Some(this), size)
node.allocationRegions.add(DataPointer(name, address, size / 8))
node.flags.global = true
node.flags.incomplete = true
globalMapping.update(AddressRange(address, address + size / 8), Field(node, 0))
case _ => ???
}

// creates a global for each relocation entry in the symbol table
Expand All @@ -157,7 +235,7 @@ class Graph(val proc: Procedure,

case None =>
val node = Node(Some(this))
node.allocationRegions.add(DataLocation(s"Relocated_$relocatedAddress", relocatedAddress, 8))
node.allocationRegions.add(DataPointer(s"Relocated_$relocatedAddress", relocatedAddress, 8))
node.flags.global = true
node.flags.incomplete = true
globalMapping.update(AddressRange(relocatedAddress, relocatedAddress + 8), Field(node, 0))
Expand All @@ -171,7 +249,7 @@ class Graph(val proc: Procedure,

externalFunctions.foreach { external =>
val node = Node(Some(this))
node.allocationRegions.add(DataLocation(external.name, external.offset, 0))
node.allocationRegions.add(DataPointer(external.name, external.offset, 0))
node.flags.global = true
node.flags.incomplete = true
globalMapping.update(AddressRange(external.offset, external.offset), Field(node, 0))
Expand All @@ -190,6 +268,59 @@ class Graph(val proc: Procedure,
global
}

// determine if an address is a global and return the corresponding global(s) if it is.
private def getGlobals(address: BigInt, size: Int): Seq[DSAGlobal] =
var global: Seq[DSAGlobal] = Seq.empty
for ((range, field) <- globalMapping) {
if (address < range.end && range.start < address + size) ||
(address + size > range.end && address < range.end) ||
(address >= range.start && (address < range.end || (range.start == range.end && range.end == address))) then
global = global ++ Seq(DSAGlobal(range, field))

}
global.sortBy(f => f.addressRange.start)



def getGlobal(address: BigInt, size: Int): Option[Cell] = {
val globals = getGlobals(address, size)
if globals.nonEmpty then
val head = globals.head
val DSAGlobal(range: AddressRange, Field(node, internal)) = head
val headOffset: BigInt = if address > range.start then address - range.start + internal else internal
val headNode = node
val headCell: Cell = node.addCell(headOffset, 0) // DSA has the size of the added cell should as size with
// selfCollapse at the end to merge all the overlapping accessed size
// However, the above approach prevent distinct multi loads
// graph.selfCollapse(headNode)
val tail = globals.tail
tail.foreach {
g =>
val DSAGlobal(range: AddressRange, Field(node, internal)) = g
val offset: BigInt = if address > range.start then address - range.start + internal else internal
node.addCell(offset, 0)
selfCollapse(node)
assert(range.start >= address)
mergeCells(find(headNode.addCell(range.start - address, 0)), find(node.getCell(offset)))
}
selfCollapse(find(headCell).node.get)
Some(find(headCell))
else
None
}

def getGlobalAddresses(cell: Cell): Set[BigInt] = {
globalMapping.foldLeft(Set[BigInt]()) {
(s, g) =>
g match
case (range: AddressRange, field: Field) =>
if cell == find(field.node.getCell(field.offset)) then
s + range.start
else
s
}
}

def getCells(pos: CFGPosition, arg: Variable): Set[Slice] = {
if (reachingDefs(pos).contains(arg)) {
reachingDefs(pos)(arg).map(definition => varToCell(definition)(arg))
Expand Down Expand Up @@ -284,7 +415,7 @@ class Graph(val proc: Procedure,
val offset = field.offset + find(field.node).offset
val cellOffset = node.getCell(offset).offset
val internalOffset = offset - cellOffset
arrows.append(StructArrow(DotStructElement(s"Global_${range.start}_${range.end}", None), DotStructElement(node.id.toString, Some(cellOffset.toString)), internalOffset.toString))
// arrows.append(StructArrow(DotStructElement(s"Global_${range.start}_${range.end}", None), DotStructElement(node.id.toString, Some(cellOffset.toString)), internalOffset.toString))
}

stackMapping.foreach { (offset, dsn) =>
Expand Down Expand Up @@ -509,12 +640,12 @@ class Graph(val proc: Procedure,
resultNode.children(k) = node2.children(k) + delta
}
resultNode.children += (node2 -> delta)
if node2.flags.global then // node 2 may have been adjusted depending on cell1 and cell2 offsets
globalMapping.foreach { // update global mapping if node 2 was global
case (range: AddressRange, Field(node, offset)) =>
if node.equals(node2) then
globalMapping.update(range, Field(node, offset + delta))
}
// if node2.flags.global then // node 2 may have been adjusted depending on cell1 and cell2 offsets
// globalMapping.foreach { // update global mapping if node 2 was global
// case (range: AddressRange, Field(node, offset)) =>
// if node.equals(node2) then
// globalMapping.update(range, Field(node, offset + delta))
// }

// compute the cells present in the resulting unified node
// a mapping from offsets to the set of old cells which are merged to form a cell in the new unified node
Expand Down Expand Up @@ -542,7 +673,8 @@ class Graph(val proc: Procedure,

resultCells.keys.foreach { offset =>
val collapsedCell = resultNode.addCell(offset, resultLargestAccesses(offset))
val outgoing: Set[Slice] = cells.flatMap { (_, cell) =>
val cells = resultCells(offset)
val outgoing: Set[Slice] = cells.flatMap { cell =>
if (cell.pointee.isDefined) {
Some(cell.getPointee)
} else {
Expand Down Expand Up @@ -590,6 +722,29 @@ class Graph(val proc: Procedure,
Slice(newCell, offset - newCell.offset)
}

def handleOverlapping(cell: Cell): Cell = {
val size = cell.node.get.getSize - cell.offset // if it's stack the size is rest of the node
val result =
if cell.node.get.flags.stack then
getStackOffsets(cell).foldLeft(cell) {
(res, offset) =>
val stack = getStack(offset, size.toInt)
mergeCells(res, stack)
}
else
cell


// size = result.largestAccessedSize
if result.node.get.flags.global then
getGlobalAddresses(result).foldLeft(result) {
(res, offset) =>
mergeCells(res, getGlobal(offset, size.toInt).get)
}
else
result
}

private def isFormal(pos: CFGPosition, variable: Variable): Boolean = !reachingDefs(pos).contains(variable)

// formal arguments to this function
Expand Down Expand Up @@ -650,6 +805,22 @@ class Graph(val proc: Procedure,
varToCell
}

def SSAVar(posLabel:String, varName: String): Slice = {
assert(posLabel.matches("%[0-9]{8}?\\$\\d"))

val res = varToCell.keys.filter(pos => pos.toShortString.startsWith(posLabel))
assert(res.size == 1)
val key = res.head

val map = varToCell(key).toMap

val temp = map.keys.filter(variable => variable.name == varName)
assert(temp.size == 1)
val variable = temp.head
map(variable)
}


def cloneSelf(): Graph = {
val newGraph = Graph(proc, constProp, varToSym, globals, globalOffsets, externalFunctions, reachingDefs, writesTo, params)
assert(formals.size == newGraph.formals.size)
Expand Down Expand Up @@ -695,12 +866,13 @@ class Graph(val proc: Procedure,

globalMapping.foreach { case (range: AddressRange, Field(node, offset)) =>
assert(newGraph.globalMapping.contains(range))
val field = find(node)
nodes.add(field.node)
if !idToNode.contains(field.node.id) then
val newNode = node.cloneSelf(newGraph)
idToNode.update(field.node.id, newNode)
newGraph.globalMapping.update(range, Field(idToNode(field.node.id), field.offset + offset))
val cell: Cell = find(node.getCell(offset))
val finalNode: Node = cell.node.get
nodes.add(finalNode)
if !idToNode.contains(finalNode.id) then
val newNode = finalNode.cloneSelf(newGraph)
idToNode.update(finalNode.id, newNode)
newGraph.globalMapping.update(range, Field(idToNode(finalNode.id), cell.offset + (node.getCell(offset).offset - offset)))
}

val queue = mutable.Queue[Node]()
Expand Down
Loading

0 comments on commit 7ef566f

Please sign in to comment.