Skip to content

Commit

Permalink
booltobv1 and copyprop heuristic
Browse files Browse the repository at this point in the history
  • Loading branch information
ailrst committed Nov 1, 2024
1 parent f313ed3 commit 95744f7
Show file tree
Hide file tree
Showing 9 changed files with 171 additions and 74 deletions.
8 changes: 8 additions & 0 deletions src/main/scala/boogie/BExpr.scala
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,7 @@ case class BFunctionCall(name: String, args: List[BExpr], outType: BType, uninte

case class UnaryBExpr(op: UnOp, arg: BExpr) extends BExpr {
override def getType: BType = (op, arg.getType) match {
case (BoolToBV1, BoolBType) => BitVecBType(1)
case (_: BoolUnOp, BoolBType) => BoolBType
case (_: BVUnOp, bv: BitVecBType) => bv
case (_: IntUnOp, IntBType) => IntBType
Expand All @@ -282,13 +283,15 @@ case class UnaryBExpr(op: UnOp, arg: BExpr) extends BExpr {
}

override def toString: String = op match {
case BoolToBV1 => s"$op($arg)"
case uOp: BoolUnOp => s"($uOp$arg)"
case uOp: BVUnOp => s"bv$uOp$inSize($arg)"
case uOp: IntUnOp => s"($uOp$arg)"
}

override def functionOps: Set[FunctionOp] = {
val thisFn = op match {
case b @ BoolToBV1 => Set(BoolToBV1Op(arg))
case b: BVUnOp =>
Set(BVFunctionOp(s"bv$b$inSize", s"bv$b", List(BParam(arg.getType)), BParam(getType)))
case _ => Set()
Expand Down Expand Up @@ -680,6 +683,11 @@ case class BInBounds(base: BExpr, len: BExpr, endian: Endian, i: BExpr) extends
override def loads: Set[BExpr] = base.loads ++ len.loads ++ i.loads
}


case class BoolToBV1Op(arg: BExpr) extends FunctionOp {
val fnName: String = "bool2bv1"
}

case class BMemoryLoad(memory: BMapVar, index: BExpr, endian: Endian, bits: Int) extends BExpr {
override def toString: String = s"$fnName($memory, $index)"

Expand Down
28 changes: 16 additions & 12 deletions src/main/scala/ir/Expr.scala
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ sealed trait Expr {
lazy val variablesCached = variables
}


def size(e: Expr) = {
e.getType match {
case BitVecType(s) => Some(s)
Expand Down Expand Up @@ -67,9 +66,10 @@ case class IntLiteral(value: BigInt) extends Literal {
override def toString: String = value.toString
}

/**
* @param end : high bit exclusive
* @param start : low bit inclusive
/** @param end
* : high bit exclusive
* @param start
* : low bit inclusive
* @param body
*/
case class Extract(end: Int, start: Int, body: Expr) extends Expr {
Expand Down Expand Up @@ -130,6 +130,7 @@ case class UnaryExpr(op: UnOp, arg: Expr) extends Expr {
override def variables: Set[Variable] = arg.variables
override def loads: Set[MemoryLoad] = arg.loads
override def getType: IRType = (op, arg.getType) match {
case (BoolToBV1, BoolType) => BitVecType(1)
case (_: BoolUnOp, BoolType) => BoolType
case (_: BVUnOp, bv: BitVecType) => bv
case (_: IntUnOp, IntType) => IntType
Expand Down Expand Up @@ -157,6 +158,7 @@ sealed trait BoolUnOp(op: String) extends UnOp {
}

case object BoolNOT extends BoolUnOp("!")
case object BoolToBV1 extends BoolUnOp("bool2bv1")

sealed trait IntUnOp(op: String) extends UnOp {
override def toString: String = op
Expand All @@ -165,7 +167,6 @@ sealed trait IntUnOp(op: String) extends UnOp {

case object IntNEG extends IntUnOp("-")


sealed trait BVUnOp(op: String) extends UnOp {
override def toString: String = op
}
Expand Down Expand Up @@ -212,7 +213,9 @@ case class BinaryExpr(op: BinOp, arg1: Expr, arg2: Expr) extends Expr {
case IntEQ | IntNEQ | IntLT | IntLE | IntGT | IntGE => BoolType
}
case _ =>
throw new Exception("type mismatch, operator " + op.getClass.getSimpleName + s" type doesn't match args: (" + arg1 + ", " + arg2 + ")")
throw new Exception(
"type mismatch, operator " + op.getClass.getSimpleName + s" type doesn't match args: (" + arg1 + ", " + arg2 + ")"
)
}

private def inSize = arg1.getType match {
Expand All @@ -237,7 +240,7 @@ case class BinaryExpr(op: BinOp, arg1: Expr, arg2: Expr) extends Expr {
}

trait BinOp {
def opName : String
def opName: String
}

sealed trait BoolBinOp(op: String) extends BinOp {
Expand Down Expand Up @@ -386,7 +389,7 @@ case class Register(override val name: String, size: Int) extends Variable with
}

// Variable with scope local to the procedure, typically a temporary variable created in the lifting process
case class LocalVar(varName: String, override val irType: IRType, val index : Int = 0) extends Variable {
case class LocalVar(varName: String, override val irType: IRType, val index: Int = 0) extends Variable {
override val name = varName + (if (index > 0) then s"_$index" else "")
override def toGamma: BVar = BVariable(s"Gamma_$name", BoolBType, Scope.Local)
override def toBoogie: BVar = BVariable(s"$name", irType.toBoogie, Scope.Local)
Expand All @@ -396,11 +399,10 @@ case class LocalVar(varName: String, override val irType: IRType, val index : In

object LocalVar {

def unapply(l: LocalVar) : Option[(String, IRType)] = Some((s"${l.name}_${l.index}", l.irType))
def unapply(l: LocalVar): Option[(String, IRType)] = Some((s"${l.name}_${l.index}", l.irType))

}


// A memory section
sealed trait Memory extends Global {
val name: String
Expand All @@ -416,11 +418,13 @@ sealed trait Memory extends Global {
}

// A stack section of memory, which is local to a thread
case class StackMemory(override val name: String, override val addressSize: Int, override val valueSize: Int) extends Memory {
case class StackMemory(override val name: String, override val addressSize: Int, override val valueSize: Int)
extends Memory {
override def acceptVisit(visitor: Visitor): Memory = visitor.visitStackMemory(this)
}

// A non-stack region of memory, which is shared between threads
case class SharedMemory(override val name: String, override val addressSize: Int, override val valueSize: Int) extends Memory {
case class SharedMemory(override val name: String, override val addressSize: Int, override val valueSize: Int)
extends Memory {
override def acceptVisit(visitor: Visitor): Memory = visitor.visitSharedMemory(this)
}
23 changes: 12 additions & 11 deletions src/main/scala/ir/eval/ExprEval.scala
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,8 @@ def evalUnOp(op: UnOp, body: Literal): Expr = {
case (i: IntLiteral, IntNEG) => IntLiteral(-i.value)
case (FalseLiteral, BoolNOT) => TrueLiteral
case (TrueLiteral, BoolNOT) => FalseLiteral
case (TrueLiteral, BoolToBV1) => BitVecLiteral(1, 0)
case (FalseLiteral, BoolToBV1) => BitVecLiteral(0, 1)
case (_, _) => throw Exception(s"Unreachable ${(body, op)}")
}
}
Expand Down Expand Up @@ -154,15 +156,15 @@ def fastPartialEvalExpr(exp: Expr): Expr = {
exp match {
case f: UninterpretedFunction => f
case unOp: UnaryExpr => {
unOp.arg match {
case l: Literal => evalUnOp(unOp.op, l)
case o => UnaryExpr(unOp.op, o)
}
unOp.arg match {
case l: Literal => evalUnOp(unOp.op, l)
case o => UnaryExpr(unOp.op, o)
}
}
case binOp: BinaryExpr =>
val lhs = binOp.arg1
val rhs = binOp.arg2
binOp.getType match {
val lhs = binOp.arg1
val rhs = binOp.arg2
binOp.getType match {
case m: MapType => binOp
case b: BitVecType => {
(binOp.op, lhs, rhs) match {
Expand Down Expand Up @@ -217,14 +219,13 @@ def fastPartialEvalExpr(exp: Expr): Expr = {
case o => r.copy(body = o)
}
case variable: Variable => variable
case ml: MemoryLoad =>
case ml: MemoryLoad =>
val addr = ml.index
ml.copy(index= addr)
case b: Literal => b
ml.copy(index = addr)
case b: Literal => b
}
}


def statePartialEvalExpr[S](l: Loader[S, InterpreterError])(exp: Expr): State[S, Expr, InterpreterError] = {
val eval = statePartialEvalExpr(l)
val ns = exp match {
Expand Down
31 changes: 15 additions & 16 deletions src/main/scala/ir/eval/SimplifyExpr.scala
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def simplifyExpr(e: Expr): (Expr, Boolean) = {
def bool2bv1(e: Expr) = {
e.getType match {
case BitVecType(1) => e
case BoolType => UninterpretedFunction("bool2bv1", Seq(e), BitVecType(1))
case BoolType => UnaryExpr(BoolToBV1, e)
case _ => ???
}

Expand Down Expand Up @@ -221,28 +221,27 @@ def simplifyExpr(e: Expr): (Expr, Boolean) = {
/* push bool2bv upwards */
case BinaryExpr(
bop,
UninterpretedFunction("bool2bv1", Seq(l), _),
UninterpretedFunction("bool2bv1", Seq(r), _)
UnaryExpr(BoolToBV1, l),
UnaryExpr(BoolToBV1, r),
) if bvLogOpToBoolOp.contains(bop) => {
bool2bv1(BinaryExpr(bvLogOpToBoolOp(bop), (l), (r)))
}
case BinaryExpr(
bop,
UninterpretedFunction("bool2bv1", Seq(l), _),
UninterpretedFunction("bool2bv1", Seq(r), _)
UnaryExpr(BoolToBV1, l),
UnaryExpr(BoolToBV1, r),
) if bvLogOpToBoolOp.contains(bop) => {
bool2bv1(BinaryExpr(bvLogOpToBoolOp(bop), (l), (r)))
}

case UnaryExpr(BVNOT, UninterpretedFunction("bool2bv1", Seq(arg), BitVecType(1))) =>
case UnaryExpr(BVNOT, UnaryExpr(BoolToBV1, arg)) =>
bool2bv1(UnaryExpr(BoolNOT, arg))

/* remove bool2bv in boolean context */
case BinaryExpr(BVEQ, UninterpretedFunction("bool2bv1", Seq(body), _), BitVecLiteral(1, 1)) => (body)
case BinaryExpr(BVEQ, UninterpretedFunction("bool2bv1", Seq(l), _), UninterpretedFunction("bool2bv1", Seq(r), _)) =>
BinaryExpr(BoolEQ, (l), (r))
case UninterpretedFunction("bool2bv1", Seq(FalseLiteral), _) => BitVecLiteral(0, 1)
case UninterpretedFunction("bool2bv1", Seq(TrueLiteral), _) => BitVecLiteral(1, 1)
case BinaryExpr(BVEQ, UnaryExpr(BoolToBV1, body), BitVecLiteral(1, 1)) => body
case BinaryExpr(BVEQ, UnaryExpr(BoolToBV1, l), UnaryExpr(BoolToBV1, r)) => BinaryExpr(BoolEQ, (l), (r))
case UnaryExpr(BoolToBV1, FalseLiteral) => BitVecLiteral(0, 1)
case UnaryExpr(BoolToBV1, TrueLiteral) => BitVecLiteral(1, 1)

case BinaryExpr(BoolAND, x, TrueLiteral) => x
case BinaryExpr(BoolAND, x, FalseLiteral) => FalseLiteral
Expand Down Expand Up @@ -288,7 +287,7 @@ def simplifyExpr(e: Expr): (Expr, Boolean) = {
) // high precision op
)
)
if (o1 == o2) && o1 == BVADD && (lhs) == (orig)
if sz > 1 && (o1 == o2) && o1 == BVADD && (lhs) == (orig)
&& AlgebraicSimplifications(SignExtend(exts, x1)) == x2
&& AlgebraicSimplifications(SignExtend(exts, y1)) == y2 => {
BinaryExpr(BVSGE, x1, UnaryExpr(BVNEG, y1))
Expand Down Expand Up @@ -333,7 +332,7 @@ def simplifyExpr(e: Expr): (Expr, Boolean) = {
)
)
)
if (o1 == o2) && o2 == o4 && o1 == BVADD && (lhs) == (orig)
if sz > 1 && (o1 == o2) && o2 == o4 && o1 == BVADD && (lhs) == (orig)
&& AlgebraicSimplifications(x2) == AlgebraicSimplifications(SignExtend(exts, x1))
&& AlgebraicSimplifications(y2) == AlgebraicSimplifications(SignExtend(exts, y1))
&& AlgebraicSimplifications(z2) == AlgebraicSimplifications(SignExtend(exts, z1)) => {
Expand All @@ -346,7 +345,7 @@ def simplifyExpr(e: Expr): (Expr, Boolean) = {
ZeroExtend(exts, orig @ BinaryExpr(o1, x1, y1)),
compar @ BinaryExpr(o2, x2, y2)
)
if (o1 == o2) && o1 == BVADD
if size(x1).get > 1 && (o1 == o2) && o1 == BVADD
&& AlgebraicSimplifications(x2) == AlgebraicSimplifications(ZeroExtend(exts, x1))
&& AlgebraicSimplifications(y2) == AlgebraicSimplifications(ZeroExtend(exts, y1)) => {
// C not Set
Expand All @@ -358,7 +357,7 @@ def simplifyExpr(e: Expr): (Expr, Boolean) = {
ZeroExtend(exts, orig @ BinaryExpr(o1, BinaryExpr(o3, x1, y1), z1)),
BinaryExpr(o2, compar @ BinaryExpr(o4, x2, y2), z2) // high precision op
)
if (o1 == o2) && o2 == o4 && o1 == BVADD
if size(x1).get > 1 && (o1 == o2) && o2 == o4 && o1 == BVADD
&& (x2) == (ZeroExtend(exts, x1))
&& (y2) == (ZeroExtend(exts, y1))
&& (z2) == (ZeroExtend(exts, z1)) => {
Expand All @@ -371,7 +370,7 @@ def simplifyExpr(e: Expr): (Expr, Boolean) = {
ZeroExtend(exts, orig @ BinaryExpr(o1, x1, UnaryExpr(BVNEG, y1))),
BinaryExpr(o2, compar @ BinaryExpr(o4, ZeroExtend(ext1, x2), ZeroExtend(ext2, UnaryExpr(BVNOT, y2))), BitVecLiteral(1, _)) // high precision op
)
if (o1 == o2) && o2 == o4 && o1 == BVADD
if size(x1).get > 1 && (o1 == o2) && o2 == o4 && o1 == BVADD
&& exts == ext1 && exts == ext2
&& x1 == x2 && y1 == y2 => {
// C not Set
Expand Down
20 changes: 20 additions & 0 deletions src/main/scala/ir/invariant/BlocksUniqueToProcedure.scala
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,26 @@ private class BVis extends CILVisitor {
}
}


def blockUniqueLabels(p: Program) : Boolean = {
p.procedures.forall(blockUniqueLabels)
}

def blockUniqueLabels(p: Procedure) : Boolean = {
val blockNames = mutable.Set[String]()
var passed = true

for (block <- p.blocks) {
if (blockNames.contains(block.label)) {
passed = false
Logger.error("Duplicate block name: " + block.label)
}
blockNames.add(block.label)
}
passed
}


def blocksUniqueToEachProcedure(p: Program) : Boolean = {
val v = BVis()
visit_prog(v, p)
Expand Down
Loading

0 comments on commit 95744f7

Please sign in to comment.