Skip to content

Commit

Permalink
move trimUnreachable to transforms
Browse files Browse the repository at this point in the history
  • Loading branch information
ailrst committed Aug 9, 2024
1 parent f713fae commit b7378b6
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 39 deletions.
37 changes: 2 additions & 35 deletions src/main/scala/ir/Program.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,40 +18,6 @@ class Program(var procedures: ArrayBuffer[Procedure],
serialiseIL(this)
}

// This shouldn't be run before indirect calls are resolved
def stripUnreachableFunctions(depth: Int = Int.MaxValue): Unit = {
val procedureCalleeNames = procedures.map(f => f.name -> f.calls.map(_.name)).toMap

val toVisit: mutable.LinkedHashSet[(Int, String)] = mutable.LinkedHashSet((0, mainProcedure.name))
var reachableFound = true
val reachableNames = mutable.HashMap[String, Int]()
while (toVisit.nonEmpty) {
val next = toVisit.head
toVisit.remove(next)

if (next._1 <= depth) {

def addName(depth: Int, name: String): Unit = {
val oldDepth = reachableNames.getOrElse(name, Integer.MAX_VALUE)
reachableNames.put(next._2, if depth < oldDepth then depth else oldDepth)
}
addName(next._1, next._2)

val callees = procedureCalleeNames(next._2)

toVisit.addAll(callees.diff(reachableNames.keySet).map(c => (next._1 + 1, c)))
callees.foreach(c => addName(next._1 + 1, c))
}
}
procedures = procedures.filter(f => reachableNames.keySet.contains(f.name))

for (elem <- procedures.filter(c => c.calls.exists(s => !procedures.contains(s)))) {
// last layer is analysed only as specifications so we remove the body for anything that calls
// a function we have removed

elem.clearBlocks()
}
}

def setModifies(specModifies: Map[String, List[String]]): Unit = {
val procToCalls: mutable.Map[Procedure, Set[Procedure]] = mutable.Map()
Expand Down Expand Up @@ -318,7 +284,7 @@ class Procedure private (

def clearBlocks(): Unit = {
// O(n) because we are careful to unlink the parents etc.
removeBlocks(_blocks)
removeBlocksDisconnect(_blocks)
}

def callers(): Iterable[Procedure] = _callers.map(_.parent.parent).toSet[Procedure]
Expand Down Expand Up @@ -369,6 +335,7 @@ class Block private (
this(label, address, IntrusiveList().addAll(statements), jump, mutable.HashSet.empty)
}

def isReturn: Boolean = parent.returnBlock.contains(this)
def isEntry: Boolean = parent.entryBlock.contains(this)

def jump: Jump = _jump
Expand Down
2 changes: 0 additions & 2 deletions src/main/scala/ir/dsl/DSL.scala
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,6 @@ def block(label: String, sl: (Statement | EventuallyStatement | EventuallyJump)*
val statements : Seq[EventuallyStatement] = sl.flatMap {
case s: Statement => Some(ResolvableStatement(s))
case o: EventuallyStatement => Some(o)
case o: EventuallyCall => Some(o)
case o: EventuallyIndirectCall => Some(o)
case g: EventuallyJump => None
}
val jump = sl.collectFirst {
Expand Down
38 changes: 38 additions & 0 deletions src/main/scala/ir/transforms/StripUnreachableFunctions.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
package ir.transforms
import ir._
import collection.mutable

// This shouldn't be run before indirect calls are resolved
def stripUnreachableFunctions(p: Program, depth: Int = Int.MaxValue): Unit = {
val procedureCalleeNames = p.procedures.map(f => f.name -> f.calls.map(_.name)).toMap

val toVisit: mutable.LinkedHashSet[(Int, String)] = mutable.LinkedHashSet((0, p.mainProcedure.name))
var reachableFound = true
val reachableNames = mutable.HashMap[String, Int]()
while (toVisit.nonEmpty) {
val next = toVisit.head
toVisit.remove(next)

if (next._1 <= depth) {

def addName(depth: Int, name: String): Unit = {
val oldDepth = reachableNames.getOrElse(name, Integer.MAX_VALUE)
reachableNames.put(next._2, if depth < oldDepth then depth else oldDepth)
}
addName(next._1, next._2)

val callees = procedureCalleeNames(next._2)

toVisit.addAll(callees.diff(reachableNames.keySet).map(c => (next._1 + 1, c)))
callees.foreach(c => addName(next._1 + 1, c))
}
}
p.procedures = p.procedures.filter(f => reachableNames.keySet.contains(f.name))

for (elem <- p.procedures.filter(c => c.calls.exists(s => !p.procedures.contains(s)))) {
// last layer is analysed only as specifications so we remove the body for anything that calls
// a function we have removed

elem.clearBlocks()
}
}
2 changes: 1 addition & 1 deletion src/main/scala/util/RunUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ object IRTransform {

Logger.info("[!] Stripping unreachable")
val before = ctx.program.procedures.size
ctx.program.stripUnreachableFunctions(config.procedureTrimDepth)
transforms.stripUnreachableFunctions(ctx.program, config.procedureTrimDepth)
Logger.info(
s"[!] Removed ${before - ctx.program.procedures.size} functions (${ctx.program.procedures.size} remaining)"
)
Expand Down
2 changes: 1 addition & 1 deletion src/test/scala/ir/InterpreterTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class InterpreterTests extends AnyFunSuite with BeforeAndAfter {
var IRProgram = IRTranslator.translate
IRProgram = ExternalRemover(externalFunctions.map(e => e.name)).visitProgram(IRProgram)
IRProgram = Renamer(Set("free")).visitProgram(IRProgram)
IRProgram.stripUnreachableFunctions()
transforms.stripUnreachableFunctions(IRProgram)
val stackIdentification = StackSubstituter()
stackIdentification.visitProgram(IRProgram)
IRProgram.setModifies(Map())
Expand Down

0 comments on commit b7378b6

Please sign in to comment.