Skip to content

Commit

Permalink
frontend fixes to support CAS
Browse files Browse the repository at this point in the history
- fix bug: no longer require every symbol table entry to have an address
- frontend: allow loads in binary expressions and branch conditions to support CAS instruction
- frontend (gtirb): support AtomicStart and AtomicEnd intrinsics: split blocks containing calls
                    into multiple blocks, each of which may end in a call
  • Loading branch information
ailrst committed Dec 10, 2024
1 parent 4cb64c0 commit 2860aee
Show file tree
Hide file tree
Showing 4 changed files with 149 additions and 74 deletions.
68 changes: 43 additions & 25 deletions src/main/scala/translating/BAPToIR.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import bap.*
import boogie.UnaryBExpr
import ir.{UnaryExpr, BinaryExpr, *}
import specification.*
import util.Logger

import scala.collection.mutable
import scala.collection.mutable.Map
Expand Down Expand Up @@ -37,7 +38,7 @@ class BAPToIR(var program: BAPProgram, mainAddress: BigInt) {
for (p <- s.out) {
procedure.out.append(translateParameter(p))
}
if (s.address.get == mainAddress) {
if (s.address.contains(mainAddress)) {
mainProcedure = Some(procedure)
}
procedures.append(procedure)
Expand Down Expand Up @@ -102,7 +103,18 @@ class BAPToIR(var program: BAPProgram, mainAddress: BigInt) {
}
}

private def translateExpr(e: BAPExpr): (Expr, Option[MemoryLoad]) = e match {

/** Translates a BAP expression, logging the expression and the failure if translation throws.
  *
  * Delegates to `_translateExpr` and returns its result unchanged. Any `Exception` is logged
  * together with the expression being translated and then rethrown, so callers still see the
  * original exception.
  *
  * Note: `_translateExpr` calls back into this wrapper for sub-expressions, so a failure in a
  * nested sub-expression is logged once for each enclosing expression as it propagates.
  *
  * @param e the BAP expression to translate
  * @return the pair produced by `_translateExpr` for `e`
  * @throws Exception whatever `_translateExpr` throws, rethrown after logging
  */
private def translateExpr(e: BAPExpr): (Expr, Option[MemoryLoad]) =
  try {
    _translateExpr(e)
  } catch {
    case exp: Exception =>
      // Report the exception itself after the expression; previously the expression was
      // printed twice and the cause never reached the log.
      Logger.error(s"Error translating $e:\n\t$exp")
      throw exp
  }

private def _translateExpr(e: BAPExpr): (Expr, Option[MemoryLoad]) = e match {
case b @ BAPConcat(left, right) =>
val (arg0, load0) = translateExpr(left)
val (arg1, load1) = translateExpr(right)
Expand Down Expand Up @@ -146,43 +158,49 @@ class BAPToIR(var program: BAPProgram, mainAddress: BigInt) {
case NOT => (UnaryExpr(BVNOT, translateExprOnly(exp)), None)
case NEG => (UnaryExpr(BVNEG, translateExprOnly(exp)), None)
}
case BAPBinOp(operator, lhs, rhs) => operator match {
case PLUS => (BinaryExpr(BVADD, translateExprOnly(lhs), translateExprOnly(rhs)), None)
case MINUS => (BinaryExpr(BVSUB, translateExprOnly(lhs), translateExprOnly(rhs)), None)
case TIMES => (BinaryExpr(BVMUL, translateExprOnly(lhs), translateExprOnly(rhs)), None)
case DIVIDE => (BinaryExpr(BVUDIV, translateExprOnly(lhs), translateExprOnly(rhs)), None)
case SDIVIDE => (BinaryExpr(BVSDIV, translateExprOnly(lhs), translateExprOnly(rhs)), None)
case BAPBinOp(operator, lhs, rhs) =>
val (tlhs, l1) = translateExpr(lhs)
val (trhs, l2) = translateExpr(rhs)
assert(!(l1.isDefined && l2.isDefined), "Don't expect two loads in an expression")
val load : Option[MemoryLoad] = l1.orElse(l2)

operator match {
case PLUS => (BinaryExpr(BVADD, tlhs, trhs), None)
case MINUS => (BinaryExpr(BVSUB, tlhs, trhs), None)
case TIMES => (BinaryExpr(BVMUL, tlhs, trhs), None)
case DIVIDE => (BinaryExpr(BVUDIV, tlhs, trhs), None)
case SDIVIDE => (BinaryExpr(BVSDIV, tlhs, trhs), None)
// counterintuitive but correct according to BAP source
case MOD => (BinaryExpr(BVSREM, translateExprOnly(lhs), translateExprOnly(rhs)), None)
case MOD => (BinaryExpr(BVSREM, tlhs, trhs), None)
// counterintuitive but correct according to BAP source
case SMOD => (BinaryExpr(BVUREM, translateExprOnly(lhs), translateExprOnly(rhs)), None)
case SMOD => (BinaryExpr(BVUREM, tlhs, trhs), None)
case LSHIFT => // BAP says caring about this case is necessary?
if (lhs.size == rhs.size) {
(BinaryExpr(BVSHL, translateExprOnly(lhs), translateExprOnly(rhs)), None)
(BinaryExpr(BVSHL, tlhs, trhs), None)
} else {
(BinaryExpr(BVSHL, translateExprOnly(lhs), ZeroExtend(lhs.size - rhs.size, translateExprOnly(rhs))), None)
(BinaryExpr(BVSHL, tlhs, ZeroExtend(lhs.size - rhs.size, trhs)), None)
}
case RSHIFT =>
if (lhs.size == rhs.size) {
(BinaryExpr(BVLSHR, translateExprOnly(lhs), translateExprOnly(rhs)), None)
(BinaryExpr(BVLSHR, tlhs, trhs), None)
} else {
(BinaryExpr(BVLSHR, translateExprOnly(lhs), ZeroExtend(lhs.size - rhs.size, translateExprOnly(rhs))), None)
(BinaryExpr(BVLSHR, tlhs, ZeroExtend(lhs.size - rhs.size, trhs)), None)
}
case ARSHIFT =>
if (lhs.size == rhs.size) {
(BinaryExpr(BVASHR, translateExprOnly(lhs), translateExprOnly(rhs)), None)
(BinaryExpr(BVASHR, tlhs, trhs), None)
} else {
(BinaryExpr(BVASHR, translateExprOnly(lhs), ZeroExtend(lhs.size - rhs.size, translateExprOnly(rhs))), None)
(BinaryExpr(BVASHR, tlhs, ZeroExtend(lhs.size - rhs.size, trhs)), None)
}
case AND => (BinaryExpr(BVAND, translateExprOnly(lhs), translateExprOnly(rhs)), None)
case OR => (BinaryExpr(BVOR, translateExprOnly(lhs), translateExprOnly(rhs)), None)
case XOR => (BinaryExpr(BVXOR, translateExprOnly(lhs), translateExprOnly(rhs)), None)
case EQ => (BinaryExpr(BVCOMP, translateExprOnly(lhs), translateExprOnly(rhs)), None)
case NEQ => (UnaryExpr(BVNOT, BinaryExpr(BVCOMP, translateExprOnly(lhs), translateExprOnly(rhs))), None)
case LT => (BinaryExpr(BVULT, translateExprOnly(lhs), translateExprOnly(rhs)), None)
case LE => (BinaryExpr(BVULE, translateExprOnly(lhs), translateExprOnly(rhs)), None)
case SLT => (BinaryExpr(BVSLT, translateExprOnly(lhs), translateExprOnly(rhs)), None)
case SLE => (BinaryExpr(BVSLE, translateExprOnly(lhs), translateExprOnly(rhs)), None)
case AND => (BinaryExpr(BVAND, tlhs, trhs), None)
case OR => (BinaryExpr(BVOR, tlhs, trhs), None)
case XOR => (BinaryExpr(BVXOR, tlhs, trhs), None)
case EQ => (BinaryExpr(BVCOMP, tlhs, trhs), None)
case NEQ => (UnaryExpr(BVNOT, BinaryExpr(BVCOMP, tlhs, trhs)), None)
case LT => (BinaryExpr(BVULT, tlhs, trhs), None)
case LE => (BinaryExpr(BVULE, tlhs, trhs), None)
case SLT => (BinaryExpr(BVSLT, tlhs, trhs), None)
case SLE => (BinaryExpr(BVSLE, tlhs, trhs), None)
}
case b: BAPVar => (translateVar(b), None)
case BAPMemAccess(memory, index, endian, size) =>
Expand Down
83 changes: 49 additions & 34 deletions src/main/scala/translating/GTIRBLoader.scala
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,10 @@ class GTIRBLoader(parserMap: immutable.Map[String, Array[Array[StmtContext]]]) {
Seq()
case a: AssertContext => visitAssert(a, label).toSeq
case t: TCallContext => visitTCall(t, label).toSeq
case i: IfContext => visitIf(i, label).toSeq
case i: IfContext => {
val (tif, load) = visitIf(i, label)
Seq() ++ load ++ tif
}
case t: ThrowContext => Seq(visitThrow(t, label))
}
}
Expand All @@ -80,6 +83,8 @@ class GTIRBLoader(parserMap: immutable.Map[String, Array[Array[StmtContext]]]) {
private def visitTCall(ctx: TCallContext, label: Option[String] = None): Option[Statement] = {
val function = visitIdent(ctx.ident)

val atomicStart = Procedure("intrinsic$AtomicStart")
val atomicEnd = Procedure("intrinsic$AtomicEnd")
val typeArgs = Option(ctx.tes).toList.flatMap(_.expr.asScala)
val args = Option(ctx.args).toList.flatMap(_.expr.asScala)

Expand All @@ -102,6 +107,8 @@ class GTIRBLoader(parserMap: immutable.Map[String, Array[Array[StmtContext]]]) {
} else {
None
}
case "AtomicStart.0" => Some(DirectCall(atomicStart))
case "AtomicEnd.0" => Some(DirectCall(atomicEnd))
case "unsupported_opcode.0" => {
val op = args.headOption.flatMap(visitExprOnly) match {
case Some(IntLiteral(s)) => Some("%08x".format(s))
Expand All @@ -127,8 +134,8 @@ class GTIRBLoader(parserMap: immutable.Map[String, Array[Array[StmtContext]]]) {
case _ => throw Exception(s"expected ${ctx.getText} to be an integer literal")
}

private def visitIf(ctx: IfContext, label: Option[String] = None): Option[TempIf] = {
val condition = visitExprOnly(ctx.cond)
private def visitIf(ctx: IfContext, label: Option[String] = None): (Option[TempIf], Option[MemoryLoad]) = {
val (condition, load) = visitExpr(ctx.cond)
val thenStmts = ctx.thenStmts.stmt.asScala.flatMap(visitStmt(_, label))

val elseStmts = Option(ctx.elseStmts) match {
Expand All @@ -137,9 +144,9 @@ class GTIRBLoader(parserMap: immutable.Map[String, Array[Array[StmtContext]]]) {
}

if (condition.isDefined) {
Some(TempIf(condition.get, thenStmts, elseStmts, label))
(Some(TempIf(condition.get, thenStmts, elseStmts, label)), load)
} else {
None
(None, None)
}
}

Expand Down Expand Up @@ -222,21 +229,27 @@ class GTIRBLoader(parserMap: immutable.Map[String, Array[Array[StmtContext]]]) {
}

private def visitExpr(ctx: ExprContext): (Option[Expr], Option[MemoryLoad]) = {
ctx match {
case e: ExprVarContext => (visitExprVar(e), None)
case e: ExprTApplyContext => visitExprTApply(e)
case e: ExprSlicesContext => visitExprSlices(e)
case e: ExprFieldContext => (Some(visitExprField(e)), None)
case e: ExprArrayContext => (Some(visitExprArray(e)), None)
case e: ExprLitIntContext => (Some(IntLiteral(parseInt(e))), None)
case e: ExprLitBitsContext => (Some(visitExprLitBits(e)), None)
try {
ctx match {
case e: ExprVarContext => (visitExprVar(e), None)
case e: ExprTApplyContext => visitExprTApply(e)
case e: ExprSlicesContext => visitExprSlices(e)
case e: ExprFieldContext => (Some(visitExprField(e)), None)
case e: ExprArrayContext => (Some(visitExprArray(e)), None)
case e: ExprLitIntContext => (Some(IntLiteral(parseInt(e))), None)
case e: ExprLitBitsContext => (Some(visitExprLitBits(e)), None)
}
} catch {
case ex: Exception =>
Logger.error(s"Error in : $ctx :" + ex)
throw ex
}
}

private def visitExprOnly(ctx: ExprContext): Option[Expr] = {
val (expr, load) = visitExpr(ctx)
if (load.isDefined) {
throw Exception("")
throw Exception(s"Unexpected load $load ; $expr ")
} else {
expr
}
Expand Down Expand Up @@ -309,29 +322,29 @@ class GTIRBLoader(parserMap: immutable.Map[String, Array[Array[StmtContext]]]) {
(result, None)

case "not_bool.0" => (resolveUnaryOp(BoolNOT, function, 0, typeArgs, args, ctx.getText), None)
case "eq_enum.0" => (resolveBinaryOp(BoolEQ, function, 0, typeArgs, args, ctx.getText), None)
case "or_bool.0" => (resolveBinaryOp(BoolOR, function, 0, typeArgs, args, ctx.getText), None)
case "and_bool.0" => (resolveBinaryOp(BoolAND, function, 0, typeArgs, args, ctx.getText), None)
case "eq_enum.0" => (resolveBinaryOp(BoolEQ, function, 0, typeArgs, args, ctx.getText))
case "or_bool.0" => (resolveBinaryOp(BoolOR, function, 0, typeArgs, args, ctx.getText))
case "and_bool.0" => (resolveBinaryOp(BoolAND, function, 0, typeArgs, args, ctx.getText))

case "not_bits.0" => (resolveUnaryOp(BVNOT, function, 1, typeArgs, args, ctx.getText), None)
case "or_bits.0" => (resolveBinaryOp(BVOR, function, 1, typeArgs, args, ctx.getText), None)
case "and_bits.0" => (resolveBinaryOp(BVAND, function, 1, typeArgs, args, ctx.getText), None)
case "eor_bits.0" => (resolveBinaryOp(BVXOR, function, 1, typeArgs, args, ctx.getText), None)
case "eq_bits.0" => (resolveBinaryOp(BVEQ, function, 1, typeArgs, args, ctx.getText), None)
case "add_bits.0" => (resolveBinaryOp(BVADD, function, 1, typeArgs, args, ctx.getText), None)
case "sub_bits.0" => (resolveBinaryOp(BVSUB, function, 1, typeArgs, args, ctx.getText), None)
case "mul_bits.0" => (resolveBinaryOp(BVMUL, function, 1, typeArgs, args, ctx.getText), None)
case "sdiv_bits.0" => (resolveBinaryOp(BVSDIV, function, 1, typeArgs, args, ctx.getText), None)

case "slt_bits.0" => (resolveBinaryOp(BVSLT, function, 1, typeArgs, args, ctx.getText), None)
case "sle_bits.0" => (resolveBinaryOp(BVSLE, function, 1, typeArgs, args, ctx.getText), None)
case "or_bits.0" => (resolveBinaryOp(BVOR, function, 1, typeArgs, args, ctx.getText))
case "and_bits.0" => (resolveBinaryOp(BVAND, function, 1, typeArgs, args, ctx.getText))
case "eor_bits.0" => (resolveBinaryOp(BVXOR, function, 1, typeArgs, args, ctx.getText))
case "eq_bits.0" => (resolveBinaryOp(BVEQ, function, 1, typeArgs, args, ctx.getText))
case "add_bits.0" => (resolveBinaryOp(BVADD, function, 1, typeArgs, args, ctx.getText))
case "sub_bits.0" => (resolveBinaryOp(BVSUB, function, 1, typeArgs, args, ctx.getText))
case "mul_bits.0" => (resolveBinaryOp(BVMUL, function, 1, typeArgs, args, ctx.getText))
case "sdiv_bits.0" => (resolveBinaryOp(BVSDIV, function, 1, typeArgs, args, ctx.getText))

case "slt_bits.0" => (resolveBinaryOp(BVSLT, function, 1, typeArgs, args, ctx.getText))
case "sle_bits.0" => (resolveBinaryOp(BVSLE, function, 1, typeArgs, args, ctx.getText))

case "lsl_bits.0" => (resolveBitShiftOp(BVSHL, function, typeArgs, args, ctx.getText), None)
case "lsr_bits.0" => (resolveBitShiftOp(BVLSHR, function, typeArgs, args, ctx.getText), None)
case "asr_bits.0" => (resolveBitShiftOp(BVASHR, function, typeArgs, args, ctx.getText), None)

case "append_bits.0" =>
(resolveBinaryOp(BVCONCAT, function, 2, typeArgs, args, ctx.getText), None)
(resolveBinaryOp(BVCONCAT, function, 2, typeArgs, args, ctx.getText))

case "replicate_bits.0" =>
checkArgs(function, 2, 2, typeArgs.size, args.size, ctx.getText)
Expand Down Expand Up @@ -482,16 +495,18 @@ class GTIRBLoader(parserMap: immutable.Map[String, Array[Array[StmtContext]]]) {
typeArgs: mutable.Buffer[ExprContext],
args: mutable.Buffer[ExprContext],
token: String
): Option[BinaryExpr] = {
): (Option[BinaryExpr], Option[MemoryLoad]) = {
checkArgs(function, typeArgsExpected, 2, typeArgs.size, args.size, token)
// we don't currently check the size for BV ops which is the type arg
// memory loads shouldn't appear inside binary operations?
val arg0 = visitExprOnly(args(0))
val arg1 = visitExprOnly(args(1))
val (arg0, l0) = visitExpr(args(0))
val (arg1, l1) = visitExpr(args(1))
assert(!(l0.isDefined && l1.isDefined), "Do not expect two loads in an expression")
val load = l0.orElse(l1)
if (arg0.isDefined && arg1.isDefined) {
Some(BinaryExpr(operator, arg0.get, arg1.get))
(Some(BinaryExpr(operator, arg0.get, arg1.get)), load)
} else {
None
(None, load)
}
}

Expand Down
66 changes: 54 additions & 12 deletions src/main/scala/translating/GTIRBToIR.scala
Original file line number Diff line number Diff line change
Expand Up @@ -162,25 +162,67 @@ class GTIRBToIR(mods: Seq[Module], parserMap: immutable.Map[String, Array[Array[
val procedure = uuidToProcedure(functionUUID)
var blockCount = 0
for (blockUUID <- blockUUIDs) {
val block = uuidToBlock(blockUUID)
var block = uuidToBlock(blockUUID)

val statements = semanticsLoader.visitBlock(blockUUID, blockCount, block.address)
blockCount += 1
block.statements.addAll(statements)
val _statements = semanticsLoader.visitBlock(blockUUID, blockCount, block.address)
val blockNonEmpty = _statements.nonEmpty

if (block.statements.isEmpty && !blockOutgoingEdges.contains(blockUUID)) {
for (s <- _statements) {
// add intrinsics
s match {
case d: DirectCall if !procedures.exists(d.target.name == _.name) => procedures += d.target
case _ => ()
}
}

val blocks = ArrayBuffer[List[Statement]]()

while (_statements.nonEmpty) {
val blockStatements = _statements.takeWhile(s => !(s.isInstanceOf[Call]))
_statements.remove(0, blockStatements.size)

val maybeCall = _statements.headOption
if (maybeCall.isDefined) {
_statements.remove(0)
}

blocks.addOne(blockStatements.toList ++ maybeCall)
}

blocks.toList match {
case Nil => ()
case statements::Nil => {
blockCount += 1
block.statements.addAll(statements)
}
case statements::tl => {
blockCount += 1
block.statements.addAll(statements)

val labels = (1 to tl.size).map(i => block.label + "_" + i)
val newBlocks = labels.zip(tl).map((l, statements) => Block(l, statements=statements))

newBlocks.foldLeft(block)((l, r) => {
l.replaceJump(GoTo(r))
r
})

block.parent.addBlocks(newBlocks)

block = newBlocks.last
}
}

if (blockNonEmpty && !blockOutgoingEdges.contains(blockUUID)) {
// remove blocks that are just nop padding
// TODO cleanup blocks that are entirely nop but have fallthrough edges?
Logger.debug(s"removing block ${block.label}")
procedure.removeBlocks(block)
} else if (!blockOutgoingEdges.contains(blockUUID) || blockOutgoingEdges(blockUUID).isEmpty) {
block.replaceJump(Unreachable())
Logger.debug(s"block ${block.label} in subroutine ${procedure.name} has no outgoing edges")
} else {
if (!blockOutgoingEdges.contains(blockUUID)) {
throw Exception (s"block ${block.label} in subroutine ${procedure.name} has no outgoing edges")
}
val outgoingEdges = blockOutgoingEdges(blockUUID)
if (outgoingEdges.isEmpty) {
throw Exception(s"block ${block.label} in subroutine ${procedure.name} has no outgoing edges")
}

val (calls, jump) = if (outgoingEdges.size == 1) {
val edge = outgoingEdges.head
Expand Down Expand Up @@ -285,7 +327,7 @@ class GTIRBToIR(mods: Seq[Module], parserMap: immutable.Map[String, Array[Array[
val blockLabel = convertLabel(procedure, blockUUID, blockCount)

val blockAddress = blockUUIDToAddress.get(blockUUID)
val block = Block(blockLabel, blockAddress)
val block = Block(blockLabel, blockAddress, jump=Unreachable())
procedure.addBlocks(block)
if (uuidToBlock.contains(blockUUID)) {
// TODO this is a case that requires special consideration
Expand Down
6 changes: 3 additions & 3 deletions src/main/scala/util/RunUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ object IRTransform {
val renamer = Renamer(boogieReserved)
renamer.visitProgram(ctx.program)

assert(invariant.singleCallBlockEnd(ctx.program))
assert(invariant.singleCallBlockEnd(ctx.program), "Single call at block end")
}

def generateProcedureSummaries(
Expand Down Expand Up @@ -562,7 +562,7 @@ object RunUtils {
val boogieTranslator = IRToBoogie(ctx.program, ctx.specification, None, q.outputPrefix, regionInjector, q.boogieTranslation)
ArrayBuffer(boogieTranslator.translate)
}
assert(invariant.singleCallBlockEnd(ctx.program))
assert(invariant.singleCallBlockEnd(ctx.program), "Single call at block end")

BASILResult(ctx, analysis, boogiePrograms)
}
Expand Down Expand Up @@ -643,7 +643,7 @@ object RunUtils {
None
}

assert(invariant.singleCallBlockEnd(ctx.program))
assert(invariant.singleCallBlockEnd(ctx.program), "Single call at block end")
Logger.debug(s"[!] Finished indirect call resolution after $iteration iterations")
analysisResult.last.copy(
symbolicAddresses = symResults,
Expand Down

0 comments on commit 2860aee

Please sign in to comment.