diff --git a/src/main/scala/ir/Program.scala b/src/main/scala/ir/Program.scala index ffd9f8e5e..8b4796615 100644 --- a/src/main/scala/ir/Program.scala +++ b/src/main/scala/ir/Program.scala @@ -73,12 +73,6 @@ class Program(var procedures: ArrayBuffer[Procedure], var mainProcedure: Procedu } } - def stackIdentification(): Unit = { - for (p <- procedures) { - p.stackIdentification() - } - } - /** * Takes all the memory sections we get from the ADT (previously in initialMemory) and restricts initialMemory to * just the .data section (which contains things such as global variables which are mutable) and puts the .rodata @@ -124,55 +118,6 @@ class Procedure( } var modifies: mutable.Set[Global] = mutable.Set() - def stackIdentification(): Unit = { - val stackPointer = Register("R31", BitVecType(64)) - val stackSubstituter = StackSubstituter() - stackSubstituter.stackRefs.add(stackPointer) - val visitedBlocks: mutable.Set[Block] = mutable.Set() - val stackMemory = Memory("stack", 64, 8) - val firstBlock = blocks.headOption - firstBlock.foreach(visitBlock) - - // does not handle loops but we do not currently support loops in block CFG so this should do for now anyway - def visitBlock(b: Block): Unit = { - if (visitedBlocks.contains(b)) { - return - } - for (s <- b.statements) { - s match { - case l: LocalAssign => - // replace mem with stack in loads if index contains stack references - stackSubstituter.visitLocalAssign(l) - - // update stack references - val variableVisitor = VariablesWithoutStoresLoads() - variableVisitor.visitExpr(l.rhs) - - val rhsStackRefs = variableVisitor.variables.toSet.intersect(stackSubstituter.stackRefs) - if (rhsStackRefs.nonEmpty) { - stackSubstituter.stackRefs.add(l.lhs) - } else if (stackSubstituter.stackRefs.contains(l.lhs) && l.lhs != stackPointer) { - stackSubstituter.stackRefs.remove(l.lhs) - } - case m: MemoryAssign => - // replace mem with stack if index contains stack reference - val indexStackRefs = m.rhs.index.variables.intersect(stackSubstituter.stackRefs) - if (indexStackRefs.nonEmpty) { - m.lhs = stackMemory - m.rhs = m.rhs.copy(mem = stackMemory) - } - case _ => - } - } - visitedBlocks.add(b) - b.jump match { - case g: GoTo => g.targets.foreach(visitBlock) - case d: DirectCall => d.returnTarget.foreach(visitBlock) - case i: IndirectCall => i.returnTarget.foreach(visitBlock) - } - } - } - } class Block( diff --git a/src/main/scala/ir/Visitor.scala b/src/main/scala/ir/Visitor.scala index 97281f6a0..0b6d79339 100644 --- a/src/main/scala/ir/Visitor.scala +++ b/src/main/scala/ir/Visitor.scala @@ -243,11 +243,64 @@ abstract class ReadOnlyVisitor extends Visitor { } -class StackSubstituter extends Visitor { - val stackRefs: mutable.Set[Variable] = mutable.Set() - val stackMemory: Memory = Memory("stack", 64, 8) +/** + * Visits all reachable blocks in a procedure, depth-first, in the order they are reachable from the start of the + * procedure. + * Does not jump to other procedures. + * Only modifies statements and jumps. + * */ +abstract class IntraproceduralControlFlowVisitor extends Visitor { + private val visitedBlocks: mutable.Set[Block] = mutable.Set() + + override def visitProcedure(node: Procedure): Procedure = { + node.blocks.headOption.foreach(visitBlock) + node + } + + override def visitBlock(node: Block): Block = { + if (visitedBlocks.contains(node)) { + return node + } + for (i <- node.statements.indices) { + node.statements(i) = visitStatement(node.statements(i)) + } + visitedBlocks.add(node) + node.jump = visitJump(node.jump) + node + } + + override def visitGoTo(node: GoTo): Jump = { + node.targets.foreach(visitBlock) + node + } + + override def visitDirectCall(node: DirectCall): Jump = { + node.returnTarget.foreach(visitBlock) + node + } + + override def visitIndirectCall(node: IndirectCall): Jump = { + node.target = visitVariable(node.target) + node.returnTarget.foreach(visitBlock) + node + } +} + +// TODO: does this break for programs with loops? need to calculate a fixed-point? +class StackSubstituter extends IntraproceduralControlFlowVisitor { + private val stackPointer = Register("R31", BitVecType(64)) + private val stackMemory = Memory("stack", 64, 8) + val stackRefs: mutable.Set[Variable] = mutable.Set(stackPointer) + + override def visitProcedure(node: Procedure): Procedure = { + // reset for each procedure + stackRefs.clear() + stackRefs.add(stackPointer) + super.visitProcedure(node) + } override def visitMemoryLoad(node: MemoryLoad): MemoryLoad = { + // replace mem with stack in load if index contains stack references val loadStackRefs = node.index.variables.intersect(stackRefs) if (loadStackRefs.nonEmpty) { node.copy(mem = stackMemory) @@ -256,6 +309,32 @@ class StackSubstituter extends Visitor { } } + override def visitLocalAssign(node: LocalAssign): Statement = { + node.lhs = visitVariable(node.lhs) + node.rhs = visitExpr(node.rhs) + + // update stack references + val variableVisitor = VariablesWithoutStoresLoads() + variableVisitor.visitExpr(node.rhs) + + val rhsStackRefs = variableVisitor.variables.toSet.intersect(stackRefs) + if (rhsStackRefs.nonEmpty) { + stackRefs.add(node.lhs) + } else if (stackRefs.contains(node.lhs) && node.lhs != stackPointer) { + stackRefs.remove(node.lhs) + } + node + } + + override def visitMemoryAssign(node: MemoryAssign): Statement = { + val indexStackRefs = node.rhs.index.variables.intersect(stackRefs) + if (indexStackRefs.nonEmpty) { + node.lhs = stackMemory + node.rhs = node.rhs.copy(mem = stackMemory) + } + node + } + } class Substituter(variables: Map[Variable, Variable] = Map(), memories: Map[Memory, Memory] = Map()) extends Visitor { diff --git a/src/main/scala/util/RunUtils.scala b/src/main/scala/util/RunUtils.scala index bcc90b7c7..f7321aa74 100644 --- a/src/main/scala/util/RunUtils.scala +++ b/src/main/scala/util/RunUtils.scala @@ -99,7 +99,8 @@ object RunUtils { IRProgram.determineRelevantMemory(globalOffsets) IRProgram.stripUnreachableFunctions() - IRProgram.stackIdentification() + val stackIdentification = StackSubstituter() + stackIdentification.visitProgram(IRProgram) val specModifies = specification.subroutines.map(s => s.name -> s.modifies).toMap IRProgram.setModifies(specModifies) diff --git a/src/test/scala/ir/InterpreterTests.scala b/src/test/scala/ir/InterpreterTests.scala index d448b132e..dc30b32a6 100644 --- a/src/test/scala/ir/InterpreterTests.scala +++ b/src/test/scala/ir/InterpreterTests.scala @@ -21,7 +21,8 @@ class InterpreterTests extends AnyFunSuite with BeforeAndAfter { IRProgram = ExternalRemover(externalFunctions.map(e => e.name)).visitProgram(IRProgram) IRProgram = Renamer(Set("free")).visitProgram(IRProgram) IRProgram.stripUnreachableFunctions() - IRProgram.stackIdentification() + val stackIdentification = StackSubstituter() + stackIdentification.visitProgram(IRProgram) IRProgram.setModifies(Map()) (IRProgram, globals)