Skip to content

Commit

Permalink
Merge pull request #142 from UQ-PAC/immutable-expr
Browse files Browse the repository at this point in the history
Make Exprs Immutable
  • Loading branch information
l-kent authored Jan 12, 2024
2 parents 2a7d05f + 72dde92 commit 40a0b0f
Show file tree
Hide file tree
Showing 5 changed files with 116 additions and 93 deletions.
21 changes: 10 additions & 11 deletions src/main/scala/ir/Expr.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@ package ir

import boogie._

trait Expr {
var ssa_id: Int = 0
sealed trait Expr {
def toBoogie: BExpr
def toGamma: BExpr = {
val gammaVars: Set[BExpr] = gammas.map(_.toGamma)
Expand All @@ -24,7 +23,7 @@ trait Expr {
def acceptVisit(visitor: Visitor): Expr = throw new Exception("visitor " + visitor + " unimplemented for: " + this)
}

trait Literal extends Expr {
sealed trait Literal extends Expr {
override def acceptVisit(visitor: Visitor): Literal = visitor.visitLiteral(this)
}

Expand Down Expand Up @@ -54,7 +53,7 @@ case class IntLiteral(value: BigInt) extends Literal {
override def toString: String = value.toString
}

class Extract(var end: Int, var start: Int, var body: Expr) extends Expr {
case class Extract(end: Int, start: Int, body: Expr) extends Expr {
override def toBoogie: BExpr = BVExtract(end, start, body.toBoogie)
override def gammas: Set[Expr] = body.gammas
override def variables: Set[Variable] = body.variables
Expand All @@ -64,7 +63,7 @@ class Extract(var end: Int, var start: Int, var body: Expr) extends Expr {
override def loads: Set[MemoryLoad] = body.loads
}

class Repeat(var repeats: Int, var body: Expr) extends Expr {
case class Repeat(repeats: Int, body: Expr) extends Expr {
override def toBoogie: BExpr = BVRepeat(repeats, body.toBoogie)
override def gammas: Set[Expr] = body.gammas
override def variables: Set[Variable] = body.variables
Expand All @@ -78,7 +77,7 @@ class Repeat(var repeats: Int, var body: Expr) extends Expr {
override def loads: Set[MemoryLoad] = body.loads
}

class ZeroExtend(var extension: Int, var body: Expr) extends Expr {
case class ZeroExtend(extension: Int, body: Expr) extends Expr {
override def toBoogie: BExpr = BVZeroExtend(extension, body.toBoogie)
override def gammas: Set[Expr] = body.gammas
override def variables: Set[Variable] = body.variables
Expand All @@ -92,7 +91,7 @@ class ZeroExtend(var extension: Int, var body: Expr) extends Expr {
override def loads: Set[MemoryLoad] = body.loads
}

class SignExtend(var extension: Int, var body: Expr) extends Expr {
case class SignExtend(extension: Int, body: Expr) extends Expr {
override def toBoogie: BExpr = BVSignExtend(extension, body.toBoogie)
override def gammas: Set[Expr] = body.gammas
override def variables: Set[Variable] = body.variables
Expand All @@ -106,7 +105,7 @@ class SignExtend(var extension: Int, var body: Expr) extends Expr {
override def loads: Set[MemoryLoad] = body.loads
}

class UnaryExpr(var op: UnOp, var arg: Expr) extends Expr {
case class UnaryExpr(op: UnOp, arg: Expr) extends Expr {
override def toBoogie: BExpr = UnaryBExpr(op, arg.toBoogie)
override def gammas: Set[Expr] = arg.gammas
override def variables: Set[Variable] = arg.variables
Expand Down Expand Up @@ -154,7 +153,7 @@ sealed trait BVUnOp(op: String) extends UnOp {
case object BVNOT extends BVUnOp("not")
case object BVNEG extends BVUnOp("neg")

class BinaryExpr(var op: BinOp, var arg1: Expr, var arg2: Expr) extends Expr {
case class BinaryExpr(op: BinOp, arg1: Expr, arg2: Expr) extends Expr {
override def toBoogie: BExpr = BinaryBExpr(op, arg1.toBoogie, arg2.toBoogie)
override def gammas: Set[Expr] = arg1.gammas ++ arg2.gammas
override def variables: Set[Variable] = arg1.variables ++ arg2.variables
Expand Down Expand Up @@ -298,7 +297,7 @@ enum Endian {
case BigEndian
}

class MemoryStore(var mem: Memory, var index: Expr, var value: Expr, var endian: Endian, var size: Int) extends Expr {
case class MemoryStore(mem: Memory, index: Expr, value: Expr, endian: Endian, size: Int) extends Expr {
override def toBoogie: BMemoryStore = BMemoryStore(mem.toBoogie, index.toBoogie, value.toBoogie, endian, size)
override def toGamma: GammaStore =
GammaStore(mem.toGamma, index.toBoogie, value.toGamma, size, size / mem.valueSize)
Expand All @@ -312,7 +311,7 @@ class MemoryStore(var mem: Memory, var index: Expr, var value: Expr, var endian:
override def acceptVisit(visitor: Visitor): Expr = visitor.visitMemoryStore(this)
}

class MemoryLoad(var mem: Memory, var index: Expr, var endian: Endian, var size: Int) extends Expr {
case class MemoryLoad(mem: Memory, index: Expr, endian: Endian, size: Int) extends Expr {
override def toBoogie: BMemoryLoad = BMemoryLoad(mem.toBoogie, index.toBoogie, endian, size)
override def toGamma: BExpr = if (mem.name == "stack") {
GammaLoad(mem.toGamma, index.toBoogie, size, size / mem.valueSize)
Expand Down
60 changes: 0 additions & 60 deletions src/main/scala/ir/Program.scala
Original file line number Diff line number Diff line change
Expand Up @@ -73,12 +73,6 @@ class Program(var procedures: ArrayBuffer[Procedure], var mainProcedure: Procedu
}
}

def stackIdentification(): Unit = {
for (p <- procedures) {
p.stackIdentification()
}
}

/**
* Takes all the memory sections we get from the ADT (previously in initialMemory) and restricts initialMemory to
* just the .data section (which contains things such as global variables which are mutable) and puts the .rodata
Expand Down Expand Up @@ -124,60 +118,6 @@ class Procedure(
}
var modifies: mutable.Set[Global] = mutable.Set()

def stackIdentification(): Unit = {
val stackPointer = Register("R31", BitVecType(64))
val stackRefs: mutable.Set[Variable] = mutable.Set(stackPointer)
val visitedBlocks: mutable.Set[Block] = mutable.Set()
val stackMemory = Memory("stack", 64, 8)
val firstBlock = blocks.headOption
firstBlock.foreach(visitBlock)

// does not handle loops but we do not currently support loops in block CFG so this should do for now anyway
def visitBlock(b: Block): Unit = {
if (visitedBlocks.contains(b)) {
return
}
for (s <- b.statements) {
s match {
case l: LocalAssign =>
// replace mem with stack in loads if index contains stack references
val loads = l.rhs.loads
for (load <- loads) {
val loadStackRefs = load.index.variables.intersect(stackRefs)
if (loadStackRefs.nonEmpty) {
load.mem = stackMemory
}
}

// update stack references
val variableVisitor = VariablesWithoutStoresLoads()
variableVisitor.visitExpr(l.rhs)

val rhsStackRefs = variableVisitor.variables.toSet.intersect(stackRefs)
if (rhsStackRefs.nonEmpty) {
stackRefs.add(l.lhs)
} else if (stackRefs.contains(l.lhs) && l.lhs != stackPointer) {
stackRefs.remove(l.lhs)
}
case m: MemoryAssign =>
// replace mem with stack if index contains stack reference
val indexStackRefs = m.rhs.index.variables.intersect(stackRefs)
if (indexStackRefs.nonEmpty) {
m.lhs = stackMemory
m.rhs.mem = stackMemory
}
case _ =>
}
}
visitedBlocks.add(b)
b.jump match {
case g: GoTo => g.targets.foreach(visitBlock)
case d: DirectCall => d.returnTarget.foreach(visitBlock)
case i: IndirectCall => i.returnTarget.foreach(visitBlock)
}
}
}

}

class Block(
Expand Down
122 changes: 102 additions & 20 deletions src/main/scala/ir/Visitor.scala
Original file line number Diff line number Diff line change
Expand Up @@ -85,47 +85,35 @@ abstract class Visitor {
}

def visitExtract(node: Extract): Expr = {
node.body = visitExpr(node.body)
node
node.copy(body = visitExpr(node.body))
}

def visitRepeat(node: Repeat): Expr = {
node.body = visitExpr(node.body)
node
node.copy(body = visitExpr(node.body))
}

def visitZeroExtend(node: ZeroExtend): Expr = {
node.body = visitExpr(node.body)
node
node.copy(body = visitExpr(node.body))
}

def visitSignExtend(node: SignExtend): Expr = {
node.body = visitExpr(node.body)
node
node.copy(body = visitExpr(node.body))
}

def visitUnaryExpr(node: UnaryExpr): Expr = {
node.arg = visitExpr(node.arg)
node
node.copy(arg = visitExpr(node.arg))
}

def visitBinaryExpr(node: BinaryExpr): Expr = {
node.arg1 = visitExpr(node.arg1)
node.arg2 = visitExpr(node.arg2)
node
node.copy(arg1 = visitExpr(node.arg1), arg2 = visitExpr(node.arg2))
}

def visitMemoryStore(node: MemoryStore): MemoryStore = {
node.mem = visitMemory(node.mem)
node.index = visitExpr(node.index)
node.value = visitExpr(node.value)
node
node.copy(mem = visitMemory(node.mem), index = visitExpr(node.index), value = visitExpr(node.value))
}

def visitMemoryLoad(node: MemoryLoad): Expr = {
node.mem = visitMemory(node.mem)
node.index = visitExpr(node.index)
node
node.copy(mem = visitMemory(node.mem), index = visitExpr(node.index))
}

def visitMemory(node: Memory): Memory = node
Expand Down Expand Up @@ -255,6 +243,100 @@ abstract class ReadOnlyVisitor extends Visitor {

}

/**
* Visits all reachable blocks in a procedure, depth-first, in the order they are reachable from the start of the
* procedure.
* Does not jump to other procedures.
* Only modifies statements and jumps.
* */
abstract class IntraproceduralControlFlowVisitor extends Visitor {
private val visitedBlocks: mutable.Set[Block] = mutable.Set()

override def visitProcedure(node: Procedure): Procedure = {
node.blocks.headOption.foreach(visitBlock)
node
}

override def visitBlock(node: Block): Block = {
if (visitedBlocks.contains(node)) {
return node
}
for (i <- node.statements.indices) {
node.statements(i) = visitStatement(node.statements(i))
}
visitedBlocks.add(node)
node.jump = visitJump(node.jump)
node
}

override def visitGoTo(node: GoTo): Jump = {
node.targets.foreach(visitBlock)
node
}

override def visitDirectCall(node: DirectCall): Jump = {
node.returnTarget.foreach(visitBlock)
node
}

override def visitIndirectCall(node: IndirectCall): Jump = {
node.target = visitVariable(node.target)
node.returnTarget.foreach(visitBlock)
node
}
}

// TODO: does this break for programs with loops? need to calculate a fixed-point?
class StackSubstituter extends IntraproceduralControlFlowVisitor {
private val stackPointer = Register("R31", BitVecType(64))
private val stackMemory = Memory("stack", 64, 8)
val stackRefs: mutable.Set[Variable] = mutable.Set(stackPointer)

override def visitProcedure(node: Procedure): Procedure = {
// reset for each procedure
stackRefs.clear()
stackRefs.add(stackPointer)
super.visitProcedure(node)
}

override def visitMemoryLoad(node: MemoryLoad): MemoryLoad = {
// replace mem with stack in load if index contains stack references
val loadStackRefs = node.index.variables.intersect(stackRefs)
if (loadStackRefs.nonEmpty) {
node.copy(mem = stackMemory)
} else {
node
}
}

override def visitLocalAssign(node: LocalAssign): Statement = {
node.lhs = visitVariable(node.lhs)
node.rhs = visitExpr(node.rhs)

// update stack references
val variableVisitor = VariablesWithoutStoresLoads()
variableVisitor.visitExpr(node.rhs)

val rhsStackRefs = variableVisitor.variables.toSet.intersect(stackRefs)
if (rhsStackRefs.nonEmpty) {
stackRefs.add(node.lhs)
} else if (stackRefs.contains(node.lhs) && node.lhs != stackPointer) {
stackRefs.remove(node.lhs)
}
node
}

override def visitMemoryAssign(node: MemoryAssign): Statement = {
val indexStackRefs = node.rhs.index.variables.intersect(stackRefs)
if (indexStackRefs.nonEmpty) {
node.lhs = stackMemory
node.rhs = node.rhs.copy(mem = stackMemory)
}
node
}

}

class Substituter(variables: Map[Variable, Variable] = Map(), memories: Map[Memory, Memory] = Map()) extends Visitor {
override def visitVariable(node: Variable): Variable = variables.get(node) match {
case Some(v: Variable) => v
Expand Down
3 changes: 2 additions & 1 deletion src/main/scala/util/RunUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,8 @@ object RunUtils {

IRProgram.determineRelevantMemory(globalOffsets)
IRProgram.stripUnreachableFunctions()
IRProgram.stackIdentification()
val stackIdentification = StackSubstituter()
stackIdentification.visitProgram(IRProgram)

val specModifies = specification.subroutines.map(s => s.name -> s.modifies).toMap
IRProgram.setModifies(specModifies)
Expand Down
3 changes: 2 additions & 1 deletion src/test/scala/ir/InterpreterTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ class InterpreterTests extends AnyFunSuite with BeforeAndAfter {
IRProgram = ExternalRemover(externalFunctions.map(e => e.name)).visitProgram(IRProgram)
IRProgram = Renamer(Set("free")).visitProgram(IRProgram)
IRProgram.stripUnreachableFunctions()
IRProgram.stackIdentification()
val stackIdentification = StackSubstituter()
stackIdentification.visitProgram(IRProgram)
IRProgram.setModifies(Map())

(IRProgram, globals)
Expand Down

0 comments on commit 40a0b0f

Please sign in to comment.