Skip to content

Commit

Permalink
Fixed PR issues
Browse files Browse the repository at this point in the history
  • Loading branch information
ziggyfish committed Oct 13, 2023
1 parent 9c5d4b8 commit f9e55b8
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 29 deletions.
31 changes: 20 additions & 11 deletions src/main/scala/analysis/NonReturningFunctions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,13 @@ import scala.collection.mutable.ArrayBuffer
import scala.collection.parallel.CollectionConverters.*

class NonReturningFunctions {

def transform(procedures: ArrayBuffer[Procedure]): Unit = {
private val knownNonReturningFunctions = List("exit", "_exit", "abort", "__stack_chk_fail", "__assert_fail", "longjump")
def transform(procedures: ArrayBuffer[Procedure], externalFunctions: Set[ExternalFunction]): Unit = {
val blocksToRemove: Queue[String] = Queue()
val mapJumpsToBlocks: Map[String, ArrayBuffer[(Jump, Block)]] = Map()
val mapBlocksToProcedure: Map[String, (Procedure, Integer)] = Map()


val externalFunctionNames = externalFunctions.map(func => func.name)
// Check if the goto is part of a endless loop by checking to see if the index its jumping to is earlier on in the method,
// and that any jumps between the index its jumping to, and the index of the instruction, don't jump outside of that range.
// This method checks for both continue (which jumps back to the while loop, and break (which jumps outside of loops).
Expand Down Expand Up @@ -51,10 +51,11 @@ class NonReturningFunctions {
// look into each procedure, and calculate the number of return statements in each block
// and create maps between jumps and blocks, and blocks and procedures.
// this also looks at endless loops, and removes unreachable code after endless blocks


for (proc <- procedures) {

for ((block, index) <- proc.blocks.zipWithIndex) {

mapBlocksToProcedure.addOne(block.label, (proc, index))
for (jump <- block.jumps) {

Expand All @@ -64,10 +65,16 @@ class NonReturningFunctions {
block.countOfReturnStatements += 1
}
case directCall: DirectCall =>

if (knownNonReturningFunctions.contains(directCall.target.name)) {
directCall.returnTarget = None
}

mapJumpsToBlocks.put(directCall.target.name, mapJumpsToBlocks.getOrElse(directCall.target.name, ArrayBuffer()).addOne((directCall, block)))

case goTo: GoTo =>
mapJumpsToBlocks.put(goTo.target.label, mapJumpsToBlocks.getOrElse(goTo.target.label, ArrayBuffer()).addOne((goTo, block)))
if (proc.blocks.length > index && isEndlessLoop(proc, goTo, index)) {
if (proc.blocks.length > index+1 && isEndlessLoop(proc, goTo, index)) {
blocksToRemove.enqueue(proc.blocks(index + 1).label)
}
case _ =>
Expand All @@ -84,7 +91,7 @@ class NonReturningFunctions {
// find all direct calls that are non-returning, not an external function and currently have a return target.
// add the return targets to the queue for removal, and mark the function as non-returning
for (proc <- procedures) {
if (!proc.externalFunction && proc.calculateReturnCount() == 0) {
if (!externalFunctionNames.contains(proc.name) && proc.calculateReturnCount() == 0) {
mapJumpsToBlocks.get(proc.name) match {
case Some(v) => for (block <- v) {
val (_, containingBlock) = block
Expand Down Expand Up @@ -114,10 +121,10 @@ class NonReturningFunctions {
val (procedure, _) = mapBlocksToProcedure(label)
if (!mapJumpsToBlocks.contains(label) || mapJumpsToBlocks(label).length <= 1) {

var procedureBlock: Integer = null
var procedureBlock: Option[Integer] = None
for ((block, index) <- procedure.blocks.zipWithIndex) {
if (block.label == label) {
procedureBlock = index
procedureBlock = Some(index)
for (jump <- block.jumps) {
jump match {
case goTo: GoTo =>
Expand All @@ -127,9 +134,11 @@ class NonReturningFunctions {
}
}
}
if (procedureBlock != null) {
procedure.blocks.remove(procedureBlock)
blocksDeleted = true
procedureBlock match {
case Some(x) =>
procedure.blocks.remove(procedureBlock.get)
blocksDeleted = true
case _ =>
}
}
}
Expand Down
8 changes: 3 additions & 5 deletions src/main/scala/ir/Program.scala
Original file line number Diff line number Diff line change
Expand Up @@ -106,10 +106,8 @@ class Procedure(
var address: Option[Int],
var blocks: ArrayBuffer[Block],
var in: ArrayBuffer[Parameter],
var out: ArrayBuffer[Parameter],
val externalFunction: Boolean
var out: ArrayBuffer[Parameter]
) {

def calls: Set[Procedure] = blocks.flatMap(_.calls).toSet
override def toString: String = {
s"Procedure $name at ${address.getOrElse("None")} with ${blocks.size} blocks and ${in.size} in and ${out.size} out parameters"
Expand Down Expand Up @@ -177,9 +175,9 @@ class Block(
var label: String,
var address: Option[Int],
var statements: ArrayBuffer[Statement],
var jumps: ArrayBuffer[Jump],
var countOfReturnStatements: Int
var jumps: ArrayBuffer[Jump]
) {
var countOfReturnStatements: Int = 0
def calls: Set[Procedure] = jumps.flatMap(_.calls).toSet
def modifies: Set[Global] = statements.flatMap(_.modifies).toSet
//def locals: Set[Variable] = statements.flatMap(_.locals).toSet ++ jumps.flatMap(_.locals).toSet
Expand Down
6 changes: 0 additions & 6 deletions src/main/scala/translating/BAPLoader.scala
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,6 @@ import scala.jdk.CollectionConverters._

object BAPLoader {

private val knownNonReturningFunctions = List("exit", "_exit", "abort", "__stack_chk_fail", "__assert_fail", "longjump")

def isNonReturning(s: String): Boolean = knownNonReturningFunctions.contains(s)

def visitProject(ctx: ProjectContext): BAPProgram = {
val memorySections = visitSections(ctx.sections)
Expand Down Expand Up @@ -129,9 +126,6 @@ object BAPLoader {
val line = visitQuoteString(ctx.tid.name)
val insn = parseFromAttrs(ctx.attrs, "insn").getOrElse("")
val function = visitQuoteString(ctx.callee.tid.name).stripPrefix("@")
if (knownNonReturningFunctions.contains(function))
BAPDirectCall(parseAllowed(function), visitExp(ctx.cond), None, line, insn)
else
BAPDirectCall(parseAllowed(function), visitExp(ctx.cond), returnTarget, line, insn)
}

Expand Down
7 changes: 3 additions & 4 deletions src/main/scala/translating/BAPToIR.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,18 @@ import scala.collection.mutable
import scala.collection.mutable.Map
import scala.collection.mutable.ArrayBuffer

class BAPToIR(var program: BAPProgram, mainAddress: Int, externalFunctions: Set[ExternalFunction]) {
class BAPToIR(var program: BAPProgram, mainAddress: Int) {

private val nameToProcedure: mutable.Map[String, Procedure] = mutable.Map()
private val labelToBlock: mutable.Map[String, Block] = mutable.Map()

def translate: Program = {
var mainProcedure: Option[Procedure] = None
val procedures: ArrayBuffer[Procedure] = ArrayBuffer()
val externalFunctionNames = externalFunctions.map(func => func.name)
for (s <- program.subroutines) {
val blocks: ArrayBuffer[Block] = ArrayBuffer()
for (b <- s.blocks) {
val block = Block(b.label, b.address, ArrayBuffer(), ArrayBuffer(), 0)
val block = Block(b.label, b.address, ArrayBuffer(), ArrayBuffer())
blocks.append(block)
labelToBlock.addOne(b.label, block)
}
Expand All @@ -33,7 +32,7 @@ class BAPToIR(var program: BAPProgram, mainAddress: Int, externalFunctions: Set[
for (p <- s.out) {
out.append(p.toIR)
}
val procedure = Procedure(s.name, Some(s.address), blocks, in, out, externalFunctionNames.contains(s.name))
val procedure = Procedure(s.name, Some(s.address), blocks, in, out)
if (s.address == mainAddress) {
mainProcedure = Some(procedure)
}
Expand Down
6 changes: 3 additions & 3 deletions src/main/scala/util/RunUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,9 @@ object RunUtils {

val (externalFunctions, globals, globalOffsets, mainAddress) = loadReadELF(readELFFileName)

val IRTranslator = BAPToIR(bapProgram, mainAddress, externalFunctions)
val IRTranslator = BAPToIR(bapProgram, mainAddress)
var IRProgram = IRTranslator.translate
NonReturningFunctions().transform(IRProgram.procedures)
NonReturningFunctions().transform(IRProgram.procedures, externalFunctions)

val specification = loadSpecification(specFileName, IRProgram, globals)

Expand Down Expand Up @@ -316,7 +316,7 @@ object RunUtils {
}

def addFakeProcedure(name: String): Unit = {
IRProgram.procedures += Procedure(name, None, ArrayBuffer(), ArrayBuffer(), ArrayBuffer(), true)
IRProgram.procedures += Procedure(name, None, ArrayBuffer(), ArrayBuffer(), ArrayBuffer())
}

def resolveAddresses(valueSet: Set[Value]): Set[AddressValue] = {
Expand Down

0 comments on commit f9e55b8

Please sign in to comment.