diff --git a/src/main/scala/analysis/NonReturningFunctions.scala b/src/main/scala/analysis/NonReturningFunctions.scala index 38b0b48bc..524583e92 100644 --- a/src/main/scala/analysis/NonReturningFunctions.scala +++ b/src/main/scala/analysis/NonReturningFunctions.scala @@ -11,13 +11,13 @@ import scala.collection.mutable.ArrayBuffer import scala.collection.parallel.CollectionConverters.* class NonReturningFunctions { - - def transform(procedures: ArrayBuffer[Procedure]): Unit = { + private val knownNonReturningFunctions = List("exit", "_exit", "abort", "__stack_chk_fail", "__assert_fail", "longjump") + def transform(procedures: ArrayBuffer[Procedure], externalFunctions: Set[ExternalFunction]): Unit = { val blocksToRemove: Queue[String] = Queue() val mapJumpsToBlocks: Map[String, ArrayBuffer[(Jump, Block)]] = Map() val mapBlocksToProcedure: Map[String, (Procedure, Integer)] = Map() - + val externalFunctionNames = externalFunctions.map(func => func.name) // Check if the goto is part of a endless loop by checking to see if the index its jumping to is earlier on in the method, // and that any jumps between the index its jumping to, and the index of the instruction, don't jump outside of that range. // This method checks for both continue (which jumps back to the while loop, and break (which jumps outside of loops). @@ -51,10 +51,11 @@ class NonReturningFunctions { // look into each procedure, and calculate the number of return statements in each block // and create maps between jumps and blocks, and blocks and procedures. // this also looks at endless loops, and removes unreachable code after endless blocks + + for (proc <- procedures) { for ((block, index) <- proc.blocks.zipWithIndex) { - mapBlocksToProcedure.addOne(block.label, (proc, index)) for (jump <- block.jumps) { @@ -64,10 +65,16 @@ class NonReturningFunctions { block.countOfReturnStatements += 1 } case directCall: DirectCall => + + if (knownNonReturningFunctions.contains(directCall.target.name)) { + directCall.returnTarget = None + } + mapJumpsToBlocks.put(directCall.target.name, mapJumpsToBlocks.getOrElse(directCall.target.name, ArrayBuffer()).addOne((directCall, block))) + case goTo: GoTo => mapJumpsToBlocks.put(goTo.target.label, mapJumpsToBlocks.getOrElse(goTo.target.label, ArrayBuffer()).addOne((goTo, block))) - if (proc.blocks.length > index && isEndlessLoop(proc, goTo, index)) { + if (proc.blocks.length > index+1 && isEndlessLoop(proc, goTo, index)) { blocksToRemove.enqueue(proc.blocks(index + 1).label) } case _ => @@ -84,7 +91,7 @@ class NonReturningFunctions { // find all direct calls that are non-returning, not an external function and currently have a return target. // add the return targets to the queue for removal, and mark the function as non-returning for (proc <- procedures) { - if (!proc.externalFunction && proc.calculateReturnCount() == 0) { + if (!externalFunctionNames.contains(proc.name) && proc.calculateReturnCount() == 0) { mapJumpsToBlocks.get(proc.name) match { case Some(v) => for (block <- v) { val (_, containingBlock) = block @@ -114,10 +121,10 @@ class NonReturningFunctions { val (procedure, _) = mapBlocksToProcedure(label) if (!mapJumpsToBlocks.contains(label) || mapJumpsToBlocks(label).length <= 1) { - var procedureBlock: Integer = null + var procedureBlock: Option[Integer] = None for ((block, index) <- procedure.blocks.zipWithIndex) { if (block.label == label) { - procedureBlock = index + procedureBlock = Some(index) for (jump <- block.jumps) { jump match { case goTo: GoTo => @@ -127,9 +134,11 @@ class NonReturningFunctions { } } } - if (procedureBlock != null) { - procedure.blocks.remove(procedureBlock) - blocksDeleted = true + procedureBlock match { + case Some(x) => + procedure.blocks.remove(procedureBlock.get) + blocksDeleted = true + case _ => } } } diff --git a/src/main/scala/ir/Program.scala b/src/main/scala/ir/Program.scala index 70dae7e9c..6f59ae1ba 100644 --- a/src/main/scala/ir/Program.scala +++ b/src/main/scala/ir/Program.scala @@ -106,10 +106,8 @@ class Procedure( var address: Option[Int], var blocks: ArrayBuffer[Block], var in: ArrayBuffer[Parameter], - var out: ArrayBuffer[Parameter], - val externalFunction: Boolean + var out: ArrayBuffer[Parameter] ) { - def calls: Set[Procedure] = blocks.flatMap(_.calls).toSet override def toString: String = { s"Procedure $name at ${address.getOrElse("None")} with ${blocks.size} blocks and ${in.size} in and ${out.size} out parameters" @@ -177,9 +175,9 @@ class Block( var label: String, var address: Option[Int], var statements: ArrayBuffer[Statement], - var jumps: ArrayBuffer[Jump], - var countOfReturnStatements: Int + var jumps: ArrayBuffer[Jump] ) { + var countOfReturnStatements: Int = 0 def calls: Set[Procedure] = jumps.flatMap(_.calls).toSet def modifies: Set[Global] = statements.flatMap(_.modifies).toSet //def locals: Set[Variable] = statements.flatMap(_.locals).toSet ++ jumps.flatMap(_.locals).toSet diff --git a/src/main/scala/translating/BAPLoader.scala b/src/main/scala/translating/BAPLoader.scala index 805b56435..be917f08b 100644 --- a/src/main/scala/translating/BAPLoader.scala +++ b/src/main/scala/translating/BAPLoader.scala @@ -10,9 +10,6 @@ import scala.jdk.CollectionConverters._ object BAPLoader { - private val knownNonReturningFunctions = List("exit", "_exit", "abort", "__stack_chk_fail", "__assert_fail", "longjump") - - def isNonReturning(s: String): Boolean = knownNonReturningFunctions.contains(s) def visitProject(ctx: ProjectContext): BAPProgram = { val memorySections = visitSections(ctx.sections) @@ -129,9 +126,6 @@ object BAPLoader { val line = visitQuoteString(ctx.tid.name) val insn = parseFromAttrs(ctx.attrs, "insn").getOrElse("") val function = visitQuoteString(ctx.callee.tid.name).stripPrefix("@") - if (knownNonReturningFunctions.contains(function)) - BAPDirectCall(parseAllowed(function), visitExp(ctx.cond), None, line, insn) - else BAPDirectCall(parseAllowed(function), visitExp(ctx.cond), returnTarget, line, insn) } diff --git a/src/main/scala/translating/BAPToIR.scala b/src/main/scala/translating/BAPToIR.scala index c976c44f2..4d2a8b418 100644 --- a/src/main/scala/translating/BAPToIR.scala +++ b/src/main/scala/translating/BAPToIR.scala @@ -9,7 +9,7 @@ import scala.collection.mutable import scala.collection.mutable.Map import scala.collection.mutable.ArrayBuffer -class BAPToIR(var program: BAPProgram, mainAddress: Int, externalFunctions: Set[ExternalFunction]) { +class BAPToIR(var program: BAPProgram, mainAddress: Int) { private val nameToProcedure: mutable.Map[String, Procedure] = mutable.Map() private val labelToBlock: mutable.Map[String, Block] = mutable.Map() @@ -17,11 +17,10 @@ class BAPToIR(var program: BAPProgram, mainAddress: Int, externalFunctions: Set[ def translate: Program = { var mainProcedure: Option[Procedure] = None val procedures: ArrayBuffer[Procedure] = ArrayBuffer() - val externalFunctionNames = externalFunctions.map(func => func.name) for (s <- program.subroutines) { val blocks: ArrayBuffer[Block] = ArrayBuffer() for (b <- s.blocks) { - val block = Block(b.label, b.address, ArrayBuffer(), ArrayBuffer(), 0) + val block = Block(b.label, b.address, ArrayBuffer(), ArrayBuffer()) blocks.append(block) labelToBlock.addOne(b.label, block) } @@ -33,7 +32,7 @@ class BAPToIR(var program: BAPProgram, mainAddress: Int, externalFunctions: Set[ for (p <- s.out) { out.append(p.toIR) } - val procedure = Procedure(s.name, Some(s.address), blocks, in, out, externalFunctionNames.contains(s.name)) + val procedure = Procedure(s.name, Some(s.address), blocks, in, out) if (s.address == mainAddress) { mainProcedure = Some(procedure) } diff --git a/src/main/scala/util/RunUtils.scala b/src/main/scala/util/RunUtils.scala index b3504cb66..f1e01eb65 100644 --- a/src/main/scala/util/RunUtils.scala +++ b/src/main/scala/util/RunUtils.scala @@ -73,9 +73,9 @@ object RunUtils { val (externalFunctions, globals, globalOffsets, mainAddress) = loadReadELF(readELFFileName) - val IRTranslator = BAPToIR(bapProgram, mainAddress, externalFunctions) + val IRTranslator = BAPToIR(bapProgram, mainAddress) var IRProgram = IRTranslator.translate - NonReturningFunctions().transform(IRProgram.procedures) + NonReturningFunctions().transform(IRProgram.procedures, externalFunctions) val specification = loadSpecification(specFileName, IRProgram, globals) @@ -316,7 +316,7 @@ object RunUtils { } def addFakeProcedure(name: String): Unit = { - IRProgram.procedures += Procedure(name, None, ArrayBuffer(), ArrayBuffer(), ArrayBuffer(), true) + IRProgram.procedures += Procedure(name, None, ArrayBuffer(), ArrayBuffer(), ArrayBuffer()) } def resolveAddresses(valueSet: Set[Value]): Set[AddressValue] = {