Skip to content

Commit

Permalink
make exprs immutable again
Browse files Browse the repository at this point in the history
  • Loading branch information
l-kent committed Nov 14, 2023
1 parent 87b911e commit fc3e754
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 45 deletions.
21 changes: 10 additions & 11 deletions src/main/scala/ir/Expr.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@ package ir

import boogie._

trait Expr {
var ssa_id: Int = 0
sealed trait Expr {
def toBoogie: BExpr
def toGamma: BExpr = {
val gammaVars: Set[BExpr] = gammas.map(_.toGamma)
Expand All @@ -24,7 +23,7 @@ trait Expr {
def acceptVisit(visitor: Visitor): Expr = throw new Exception("visitor " + visitor + " unimplemented for: " + this)
}

trait Literal extends Expr {
sealed trait Literal extends Expr {
override def acceptVisit(visitor: Visitor): Literal = visitor.visitLiteral(this)
}

Expand Down Expand Up @@ -54,7 +53,7 @@ case class IntLiteral(value: BigInt) extends Literal {
override def toString: String = value.toString
}

class Extract(var end: Int, var start: Int, var body: Expr) extends Expr {
case class Extract(end: Int, start: Int, body: Expr) extends Expr {
override def toBoogie: BExpr = BVExtract(end, start, body.toBoogie)
override def gammas: Set[Expr] = body.gammas
override def variables: Set[Variable] = body.variables
Expand All @@ -64,7 +63,7 @@ class Extract(var end: Int, var start: Int, var body: Expr) extends Expr {
override def loads: Set[MemoryLoad] = body.loads
}

class Repeat(var repeats: Int, var body: Expr) extends Expr {
case class Repeat(repeats: Int, body: Expr) extends Expr {
override def toBoogie: BExpr = BVRepeat(repeats, body.toBoogie)
override def gammas: Set[Expr] = body.gammas
override def variables: Set[Variable] = body.variables
Expand All @@ -78,7 +77,7 @@ class Repeat(var repeats: Int, var body: Expr) extends Expr {
override def loads: Set[MemoryLoad] = body.loads
}

class ZeroExtend(var extension: Int, var body: Expr) extends Expr {
case class ZeroExtend(extension: Int, body: Expr) extends Expr {
override def toBoogie: BExpr = BVZeroExtend(extension, body.toBoogie)
override def gammas: Set[Expr] = body.gammas
override def variables: Set[Variable] = body.variables
Expand All @@ -92,7 +91,7 @@ class ZeroExtend(var extension: Int, var body: Expr) extends Expr {
override def loads: Set[MemoryLoad] = body.loads
}

class SignExtend(var extension: Int, var body: Expr) extends Expr {
case class SignExtend(extension: Int, body: Expr) extends Expr {
override def toBoogie: BExpr = BVSignExtend(extension, body.toBoogie)
override def gammas: Set[Expr] = body.gammas
override def variables: Set[Variable] = body.variables
Expand All @@ -106,7 +105,7 @@ class SignExtend(var extension: Int, var body: Expr) extends Expr {
override def loads: Set[MemoryLoad] = body.loads
}

class UnaryExpr(var op: UnOp, var arg: Expr) extends Expr {
case class UnaryExpr(op: UnOp, arg: Expr) extends Expr {
override def toBoogie: BExpr = UnaryBExpr(op, arg.toBoogie)
override def gammas: Set[Expr] = arg.gammas
override def variables: Set[Variable] = arg.variables
Expand Down Expand Up @@ -154,7 +153,7 @@ sealed trait BVUnOp(op: String) extends UnOp {
case object BVNOT extends BVUnOp("not")
case object BVNEG extends BVUnOp("neg")

class BinaryExpr(var op: BinOp, var arg1: Expr, var arg2: Expr) extends Expr {
case class BinaryExpr(op: BinOp, arg1: Expr, arg2: Expr) extends Expr {
override def toBoogie: BExpr = BinaryBExpr(op, arg1.toBoogie, arg2.toBoogie)
override def gammas: Set[Expr] = arg1.gammas ++ arg2.gammas
override def variables: Set[Variable] = arg1.variables ++ arg2.variables
Expand Down Expand Up @@ -298,7 +297,7 @@ enum Endian {
case BigEndian
}

class MemoryStore(var mem: Memory, var index: Expr, var value: Expr, var endian: Endian, var size: Int) extends Expr {
case class MemoryStore(mem: Memory, index: Expr, value: Expr, endian: Endian, size: Int) extends Expr {
override def toBoogie: BMemoryStore = BMemoryStore(mem.toBoogie, index.toBoogie, value.toBoogie, endian, size)
override def toGamma: GammaStore =
GammaStore(mem.toGamma, index.toBoogie, value.toGamma, size, size / mem.valueSize)
Expand All @@ -312,7 +311,7 @@ class MemoryStore(var mem: Memory, var index: Expr, var value: Expr, var endian:
override def acceptVisit(visitor: Visitor): Expr = visitor.visitMemoryStore(this)
}

class MemoryLoad(var mem: Memory, var index: Expr, var endian: Endian, var size: Int) extends Expr {
case class MemoryLoad(mem: Memory, index: Expr, endian: Endian, size: Int) extends Expr {
override def toBoogie: BMemoryLoad = BMemoryLoad(mem.toBoogie, index.toBoogie, endian, size)
override def toGamma: BExpr = if (mem.name == "stack") {
GammaLoad(mem.toGamma, index.toBoogie, size, size / mem.valueSize)
Expand Down
23 changes: 9 additions & 14 deletions src/main/scala/ir/Program.scala
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,8 @@ class Procedure(

def stackIdentification(): Unit = {
val stackPointer = Register("R31", BitVecType(64))
val stackRefs: mutable.Set[Variable] = mutable.Set(stackPointer)
val stackSubstituter = StackSubstituter()
stackSubstituter.stackRefs.add(stackPointer)
val visitedBlocks: mutable.Set[Block] = mutable.Set()
val stackMemory = Memory("stack", 64, 8)
val firstBlock = blocks.headOption
Expand All @@ -141,30 +142,24 @@ class Procedure(
s match {
case l: LocalAssign =>
// replace mem with stack in loads if index contains stack references
val loads = l.rhs.loads
for (load <- loads) {
val loadStackRefs = load.index.variables.intersect(stackRefs)
if (loadStackRefs.nonEmpty) {
load.mem = stackMemory
}
}
stackSubstituter.visitLocalAssign(l)

// update stack references
val variableVisitor = VariablesWithoutStoresLoads()
variableVisitor.visitExpr(l.rhs)

val rhsStackRefs = variableVisitor.variables.toSet.intersect(stackRefs)
val rhsStackRefs = variableVisitor.variables.toSet.intersect(stackSubstituter.stackRefs)
if (rhsStackRefs.nonEmpty) {
stackRefs.add(l.lhs)
} else if (stackRefs.contains(l.lhs) && l.lhs != stackPointer) {
stackRefs.remove(l.lhs)
stackSubstituter.stackRefs.add(l.lhs)
} else if (stackSubstituter.stackRefs.contains(l.lhs) && l.lhs != stackPointer) {
stackSubstituter.stackRefs.remove(l.lhs)
}
case m: MemoryAssign =>
// replace mem with stack if index contains stack reference
val indexStackRefs = m.rhs.index.variables.intersect(stackRefs)
val indexStackRefs = m.rhs.index.variables.intersect(stackSubstituter.stackRefs)
if (indexStackRefs.nonEmpty) {
m.lhs = stackMemory
m.rhs.mem = stackMemory
m.rhs = m.rhs.copy(mem = stackMemory)
}
case _ =>
}
Expand Down
43 changes: 23 additions & 20 deletions src/main/scala/ir/Visitor.scala
Original file line number Diff line number Diff line change
Expand Up @@ -85,47 +85,35 @@ abstract class Visitor {
}

def visitExtract(node: Extract): Expr = {
node.body = visitExpr(node.body)
node
node.copy(body = visitExpr(node.body))
}

def visitRepeat(node: Repeat): Expr = {
node.body = visitExpr(node.body)
node
node.copy(body = visitExpr(node.body))
}

def visitZeroExtend(node: ZeroExtend): Expr = {
node.body = visitExpr(node.body)
node
node.copy(body = visitExpr(node.body))
}

def visitSignExtend(node: SignExtend): Expr = {
node.body = visitExpr(node.body)
node
node.copy(body = visitExpr(node.body))
}

def visitUnaryExpr(node: UnaryExpr): Expr = {
node.arg = visitExpr(node.arg)
node
node.copy(arg = visitExpr(node.arg))
}

def visitBinaryExpr(node: BinaryExpr): Expr = {
node.arg1 = visitExpr(node.arg1)
node.arg2 = visitExpr(node.arg2)
node
node.copy(arg1 = visitExpr(node.arg1), arg2 = visitExpr(node.arg2))
}

def visitMemoryStore(node: MemoryStore): MemoryStore = {
node.mem = visitMemory(node.mem)
node.index = visitExpr(node.index)
node.value = visitExpr(node.value)
node
node.copy(mem = visitMemory(node.mem), index = visitExpr(node.index), value = visitExpr(node.value))
}

def visitMemoryLoad(node: MemoryLoad): Expr = {
node.mem = visitMemory(node.mem)
node.index = visitExpr(node.index)
node
node.copy(mem = visitMemory(node.mem), index = visitExpr(node.index))
}

def visitMemory(node: Memory): Memory = node
Expand Down Expand Up @@ -255,6 +243,21 @@ abstract class ReadOnlyVisitor extends Visitor {

}

class StackSubstituter extends Visitor {
val stackRefs: mutable.Set[Variable] = mutable.Set()
val stackMemory: Memory = Memory("stack", 64, 8)

override def visitMemoryLoad(node: MemoryLoad): MemoryLoad = {
val loadStackRefs = node.index.variables.intersect(stackRefs)
if (loadStackRefs.nonEmpty) {
node.copy(mem = stackMemory)
} else {
node
}
}

}

class Substituter(variables: Map[Variable, Variable] = Map(), memories: Map[Memory, Memory] = Map()) extends Visitor {
override def visitVariable(node: Variable): Variable = variables.get(node) match {
case Some(v: Variable) => v
Expand Down

0 comments on commit fc3e754

Please sign in to comment.