diff --git a/src/main/scala/ir/Expr.scala b/src/main/scala/ir/Expr.scala index c86741d2d..304db4a16 100644 --- a/src/main/scala/ir/Expr.scala +++ b/src/main/scala/ir/Expr.scala @@ -2,8 +2,7 @@ package ir import boogie._ -trait Expr { - var ssa_id: Int = 0 +sealed trait Expr { def toBoogie: BExpr def toGamma: BExpr = { val gammaVars: Set[BExpr] = gammas.map(_.toGamma) @@ -24,7 +23,7 @@ trait Expr { def acceptVisit(visitor: Visitor): Expr = throw new Exception("visitor " + visitor + " unimplemented for: " + this) } -trait Literal extends Expr { +sealed trait Literal extends Expr { override def acceptVisit(visitor: Visitor): Literal = visitor.visitLiteral(this) } @@ -54,7 +53,7 @@ case class IntLiteral(value: BigInt) extends Literal { override def toString: String = value.toString } -class Extract(var end: Int, var start: Int, var body: Expr) extends Expr { +case class Extract(end: Int, start: Int, body: Expr) extends Expr { override def toBoogie: BExpr = BVExtract(end, start, body.toBoogie) override def gammas: Set[Expr] = body.gammas override def variables: Set[Variable] = body.variables @@ -64,7 +63,7 @@ class Extract(var end: Int, var start: Int, var body: Expr) extends Expr { override def loads: Set[MemoryLoad] = body.loads } -class Repeat(var repeats: Int, var body: Expr) extends Expr { +case class Repeat(repeats: Int, body: Expr) extends Expr { override def toBoogie: BExpr = BVRepeat(repeats, body.toBoogie) override def gammas: Set[Expr] = body.gammas override def variables: Set[Variable] = body.variables @@ -78,7 +77,7 @@ class Repeat(var repeats: Int, var body: Expr) extends Expr { override def loads: Set[MemoryLoad] = body.loads } -class ZeroExtend(var extension: Int, var body: Expr) extends Expr { +case class ZeroExtend(extension: Int, body: Expr) extends Expr { override def toBoogie: BExpr = BVZeroExtend(extension, body.toBoogie) override def gammas: Set[Expr] = body.gammas override def variables: Set[Variable] = body.variables @@ -92,7 +91,7 @@ class ZeroExtend(var extension: Int, var body: Expr) extends Expr { override def loads: Set[MemoryLoad] = body.loads } -class SignExtend(var extension: Int, var body: Expr) extends Expr { +case class SignExtend(extension: Int, body: Expr) extends Expr { override def toBoogie: BExpr = BVSignExtend(extension, body.toBoogie) override def gammas: Set[Expr] = body.gammas override def variables: Set[Variable] = body.variables @@ -106,7 +105,7 @@ class SignExtend(var extension: Int, var body: Expr) extends Expr { override def loads: Set[MemoryLoad] = body.loads } -class UnaryExpr(var op: UnOp, var arg: Expr) extends Expr { +case class UnaryExpr(op: UnOp, arg: Expr) extends Expr { override def toBoogie: BExpr = UnaryBExpr(op, arg.toBoogie) override def gammas: Set[Expr] = arg.gammas override def variables: Set[Variable] = arg.variables @@ -154,7 +153,7 @@ sealed trait BVUnOp(op: String) extends UnOp { case object BVNOT extends BVUnOp("not") case object BVNEG extends BVUnOp("neg") -class BinaryExpr(var op: BinOp, var arg1: Expr, var arg2: Expr) extends Expr { +case class BinaryExpr(op: BinOp, arg1: Expr, arg2: Expr) extends Expr { override def toBoogie: BExpr = BinaryBExpr(op, arg1.toBoogie, arg2.toBoogie) override def gammas: Set[Expr] = arg1.gammas ++ arg2.gammas override def variables: Set[Variable] = arg1.variables ++ arg2.variables @@ -298,7 +297,7 @@ enum Endian { case BigEndian } -class MemoryStore(var mem: Memory, var index: Expr, var value: Expr, var endian: Endian, var size: Int) extends Expr { +case class MemoryStore(mem: Memory, index: Expr, value: Expr, endian: Endian, size: Int) extends Expr { override def toBoogie: BMemoryStore = BMemoryStore(mem.toBoogie, index.toBoogie, value.toBoogie, endian, size) override def toGamma: GammaStore = GammaStore(mem.toGamma, index.toBoogie, value.toGamma, size, size / mem.valueSize) @@ -312,7 +311,7 @@ class MemoryStore(var mem: Memory, var index: Expr, var value: Expr, var endian: override def acceptVisit(visitor: Visitor): Expr = visitor.visitMemoryStore(this) } -class MemoryLoad(var mem: Memory, var index: Expr, var endian: Endian, var size: Int) extends Expr { +case class MemoryLoad(mem: Memory, index: Expr, endian: Endian, size: Int) extends Expr { override def toBoogie: BMemoryLoad = BMemoryLoad(mem.toBoogie, index.toBoogie, endian, size) override def toGamma: BExpr = if (mem.name == "stack") { GammaLoad(mem.toGamma, index.toBoogie, size, size / mem.valueSize) diff --git a/src/main/scala/ir/Program.scala b/src/main/scala/ir/Program.scala index 9d8e6f7aa..8b4796615 100644 --- a/src/main/scala/ir/Program.scala +++ b/src/main/scala/ir/Program.scala @@ -73,12 +73,6 @@ class Program(var procedures: ArrayBuffer[Procedure], var mainProcedure: Procedu } } - def stackIdentification(): Unit = { - for (p <- procedures) { - p.stackIdentification() - } - } - /** * Takes all the memory sections we get from the ADT (previously in initialMemory) and restricts initialMemory to * just the .data section (which contains things such as global variables which are mutable) and puts the .rodata @@ -124,60 +118,6 @@ class Procedure( } var modifies: mutable.Set[Global] = mutable.Set() - def stackIdentification(): Unit = { - val stackPointer = Register("R31", BitVecType(64)) - val stackRefs: mutable.Set[Variable] = mutable.Set(stackPointer) - val visitedBlocks: mutable.Set[Block] = mutable.Set() - val stackMemory = Memory("stack", 64, 8) - val firstBlock = blocks.headOption - firstBlock.foreach(visitBlock) - - // does not handle loops but we do not currently support loops in block CFG so this should do for now anyway - def visitBlock(b: Block): Unit = { - if (visitedBlocks.contains(b)) { - return - } - for (s <- b.statements) { - s match { - case l: LocalAssign => - // replace mem with stack in loads if index contains stack references - val loads = l.rhs.loads - for (load <- loads) { - val loadStackRefs = load.index.variables.intersect(stackRefs) - if (loadStackRefs.nonEmpty) { - load.mem = stackMemory - } - } - - // update stack references - val variableVisitor = VariablesWithoutStoresLoads() - variableVisitor.visitExpr(l.rhs) - - val rhsStackRefs = variableVisitor.variables.toSet.intersect(stackRefs) - if (rhsStackRefs.nonEmpty) { - stackRefs.add(l.lhs) - } else if (stackRefs.contains(l.lhs) && l.lhs != stackPointer) { - stackRefs.remove(l.lhs) - } - case m: MemoryAssign => - // replace mem with stack if index contains stack reference - val indexStackRefs = m.rhs.index.variables.intersect(stackRefs) - if (indexStackRefs.nonEmpty) { - m.lhs = stackMemory - m.rhs.mem = stackMemory - } - case _ => - } - } - visitedBlocks.add(b) - b.jump match { - case g: GoTo => g.targets.foreach(visitBlock) - case d: DirectCall => d.returnTarget.foreach(visitBlock) - case i: IndirectCall => i.returnTarget.foreach(visitBlock) - } - } - } - } class Block( diff --git a/src/main/scala/ir/Visitor.scala b/src/main/scala/ir/Visitor.scala index 37768d1bd..0b6d79339 100644 --- a/src/main/scala/ir/Visitor.scala +++ b/src/main/scala/ir/Visitor.scala @@ -85,47 +85,35 @@ abstract class Visitor { } def visitExtract(node: Extract): Expr = { - node.body = visitExpr(node.body) - node + node.copy(body = visitExpr(node.body)) } def visitRepeat(node: Repeat): Expr = { - node.body = visitExpr(node.body) - node + node.copy(body = visitExpr(node.body)) } def visitZeroExtend(node: ZeroExtend): Expr = { - node.body = visitExpr(node.body) - node + node.copy(body = visitExpr(node.body)) } def visitSignExtend(node: SignExtend): Expr = { - node.body = visitExpr(node.body) - node + node.copy(body = visitExpr(node.body)) } def visitUnaryExpr(node: UnaryExpr): Expr = { - node.arg = visitExpr(node.arg) - node + node.copy(arg = visitExpr(node.arg)) } def visitBinaryExpr(node: BinaryExpr): Expr = { - node.arg1 = visitExpr(node.arg1) - node.arg2 = visitExpr(node.arg2) - node + node.copy(arg1 = visitExpr(node.arg1), arg2 = visitExpr(node.arg2)) } def visitMemoryStore(node: MemoryStore): MemoryStore = { - node.mem = visitMemory(node.mem) - node.index = visitExpr(node.index) - node.value = visitExpr(node.value) - node + node.copy(mem = visitMemory(node.mem), index = visitExpr(node.index), value = visitExpr(node.value)) } def visitMemoryLoad(node: MemoryLoad): Expr = { - node.mem = visitMemory(node.mem) - node.index = visitExpr(node.index) - node + node.copy(mem = visitMemory(node.mem), index = visitExpr(node.index)) } def visitMemory(node: Memory): Memory = node @@ -255,6 +243,100 @@ abstract class ReadOnlyVisitor extends Visitor { } +/** + * Visits all reachable blocks in a procedure, depth-first, in the order they are reachable from the start of the + * procedure. + * Does not jump to other procedures. + * Only modifies statements and jumps. + * */ +abstract class IntraproceduralControlFlowVisitor extends Visitor { + private val visitedBlocks: mutable.Set[Block] = mutable.Set() + + override def visitProcedure(node: Procedure): Procedure = { + node.blocks.headOption.foreach(visitBlock) + node + } + + override def visitBlock(node: Block): Block = { + if (visitedBlocks.contains(node)) { + return node + } + for (i <- node.statements.indices) { + node.statements(i) = visitStatement(node.statements(i)) + } + visitedBlocks.add(node) + node.jump = visitJump(node.jump) + node + } + + override def visitGoTo(node: GoTo): Jump = { + node.targets.foreach(visitBlock) + node + } + + override def visitDirectCall(node: DirectCall): Jump = { + node.returnTarget.foreach(visitBlock) + node + } + + override def visitIndirectCall(node: IndirectCall): Jump = { + node.target = visitVariable(node.target) + node.returnTarget.foreach(visitBlock) + node + } +} + +// TODO: does this break for programs with loops? need to calculate a fixed-point? +class StackSubstituter extends IntraproceduralControlFlowVisitor { + private val stackPointer = Register("R31", BitVecType(64)) + private val stackMemory = Memory("stack", 64, 8) + val stackRefs: mutable.Set[Variable] = mutable.Set(stackPointer) + + override def visitProcedure(node: Procedure): Procedure = { + // reset for each procedure + stackRefs.clear() + stackRefs.add(stackPointer) + super.visitProcedure(node) + } + + override def visitMemoryLoad(node: MemoryLoad): MemoryLoad = { + // replace mem with stack in load if index contains stack references + val loadStackRefs = node.index.variables.intersect(stackRefs) + if (loadStackRefs.nonEmpty) { + node.copy(mem = stackMemory) + } else { + node + } + } + + override def visitLocalAssign(node: LocalAssign): Statement = { + node.lhs = visitVariable(node.lhs) + node.rhs = visitExpr(node.rhs) + + // update stack references + val variableVisitor = VariablesWithoutStoresLoads() + variableVisitor.visitExpr(node.rhs) + + val rhsStackRefs = variableVisitor.variables.toSet.intersect(stackRefs) + if (rhsStackRefs.nonEmpty) { + stackRefs.add(node.lhs) + } else if (stackRefs.contains(node.lhs) && node.lhs != stackPointer) { + stackRefs.remove(node.lhs) + } + node + } + + override def visitMemoryAssign(node: MemoryAssign): Statement = { + val indexStackRefs = node.rhs.index.variables.intersect(stackRefs) + if (indexStackRefs.nonEmpty) { + node.lhs = stackMemory + node.rhs = node.rhs.copy(mem = stackMemory) + } + node + } + +} + class Substituter(variables: Map[Variable, Variable] = Map(), memories: Map[Memory, Memory] = Map()) extends Visitor { override def visitVariable(node: Variable): Variable = variables.get(node) match { case Some(v: Variable) => v diff --git a/src/main/scala/util/RunUtils.scala b/src/main/scala/util/RunUtils.scala index dd4173871..6ec84a96e 100644 --- a/src/main/scala/util/RunUtils.scala +++ b/src/main/scala/util/RunUtils.scala @@ -99,7 +99,8 @@ object RunUtils { IRProgram.determineRelevantMemory(globalOffsets) IRProgram.stripUnreachableFunctions() - IRProgram.stackIdentification() + val stackIdentification = StackSubstituter() + stackIdentification.visitProgram(IRProgram) val specModifies = specification.subroutines.map(s => s.name -> s.modifies).toMap IRProgram.setModifies(specModifies) diff --git a/src/test/scala/ir/InterpreterTests.scala b/src/test/scala/ir/InterpreterTests.scala index d448b132e..dc30b32a6 100644 --- a/src/test/scala/ir/InterpreterTests.scala +++ b/src/test/scala/ir/InterpreterTests.scala @@ -21,7 +21,8 @@ class InterpreterTests extends AnyFunSuite with BeforeAndAfter { IRProgram = ExternalRemover(externalFunctions.map(e => e.name)).visitProgram(IRProgram) IRProgram = Renamer(Set("free")).visitProgram(IRProgram) IRProgram.stripUnreachableFunctions() - IRProgram.stackIdentification() + val stackIdentification = StackSubstituter() + stackIdentification.visitProgram(IRProgram) IRProgram.setModifies(Map()) (IRProgram, globals)