diff --git a/src/main/scala/translating/BAPToIR.scala b/src/main/scala/translating/BAPToIR.scala index 9793022cb..4a0d89fd8 100644 --- a/src/main/scala/translating/BAPToIR.scala +++ b/src/main/scala/translating/BAPToIR.scala @@ -4,6 +4,7 @@ import bap.* import boogie.UnaryBExpr import ir.{UnaryExpr, BinaryExpr, *} import specification.* +import util.Logger import scala.collection.mutable import scala.collection.mutable.Map @@ -37,7 +38,7 @@ class BAPToIR(var program: BAPProgram, mainAddress: BigInt) { for (p <- s.out) { procedure.out.append(translateParameter(p)) } - if (s.address.get == mainAddress) { + if (s.address.contains(mainAddress)) { mainProcedure = Some(procedure) } procedures.append(procedure) @@ -102,7 +103,18 @@ class BAPToIR(var program: BAPProgram, mainAddress: BigInt) { } } - private def translateExpr(e: BAPExpr): (Expr, Option[MemoryLoad]) = e match { + + private def translateExpr(e: BAPExpr): (Expr, Option[MemoryLoad]) = + try { + _translateExpr(e) + } catch { + case exp: Exception => { + Logger.error(s"Error translating $e:\n\t" + e) + throw exp + } + } + + private def _translateExpr(e: BAPExpr): (Expr, Option[MemoryLoad]) = e match { case b @ BAPConcat(left, right) => val (arg0, load0) = translateExpr(left) val (arg1, load1) = translateExpr(right) @@ -146,43 +158,49 @@ class BAPToIR(var program: BAPProgram, mainAddress: BigInt) { case NOT => (UnaryExpr(BVNOT, translateExprOnly(exp)), None) case NEG => (UnaryExpr(BVNEG, translateExprOnly(exp)), None) } - case BAPBinOp(operator, lhs, rhs) => operator match { - case PLUS => (BinaryExpr(BVADD, translateExprOnly(lhs), translateExprOnly(rhs)), None) - case MINUS => (BinaryExpr(BVSUB, translateExprOnly(lhs), translateExprOnly(rhs)), None) - case TIMES => (BinaryExpr(BVMUL, translateExprOnly(lhs), translateExprOnly(rhs)), None) - case DIVIDE => (BinaryExpr(BVUDIV, translateExprOnly(lhs), translateExprOnly(rhs)), None) - case SDIVIDE => (BinaryExpr(BVSDIV, translateExprOnly(lhs), translateExprOnly(rhs)), None) + case BAPBinOp(operator, lhs, rhs) => + val (tlhs, l1) = translateExpr(lhs) + val (trhs, l2) = translateExpr(rhs) + assert(!(l1.isDefined && l2.isDefined), "Don't expect two loads in an expression") + val load : Option[MemoryLoad] = l1.orElse(l2) + + operator match { + case PLUS => (BinaryExpr(BVADD, tlhs, trhs), None) + case MINUS => (BinaryExpr(BVSUB, tlhs, trhs), None) + case TIMES => (BinaryExpr(BVMUL, tlhs, trhs), None) + case DIVIDE => (BinaryExpr(BVUDIV, tlhs, trhs), None) + case SDIVIDE => (BinaryExpr(BVSDIV, tlhs, trhs), None) // counterintuitive but correct according to BAP source - case MOD => (BinaryExpr(BVSREM, translateExprOnly(lhs), translateExprOnly(rhs)), None) + case MOD => (BinaryExpr(BVSREM, tlhs, trhs), None) // counterintuitive but correct according to BAP source - case SMOD => (BinaryExpr(BVUREM, translateExprOnly(lhs), translateExprOnly(rhs)), None) + case SMOD => (BinaryExpr(BVUREM, tlhs, trhs), None) case LSHIFT => // BAP says caring about this case is necessary? if (lhs.size == rhs.size) { - (BinaryExpr(BVSHL, translateExprOnly(lhs), translateExprOnly(rhs)), None) + (BinaryExpr(BVSHL, tlhs, trhs), None) } else { - (BinaryExpr(BVSHL, translateExprOnly(lhs), ZeroExtend(lhs.size - rhs.size, translateExprOnly(rhs))), None) + (BinaryExpr(BVSHL, tlhs, ZeroExtend(lhs.size - rhs.size, trhs)), None) } case RSHIFT => if (lhs.size == rhs.size) { - (BinaryExpr(BVLSHR, translateExprOnly(lhs), translateExprOnly(rhs)), None) + (BinaryExpr(BVLSHR, tlhs, trhs), None) } else { - (BinaryExpr(BVLSHR, translateExprOnly(lhs), ZeroExtend(lhs.size - rhs.size, translateExprOnly(rhs))), None) + (BinaryExpr(BVLSHR, tlhs, ZeroExtend(lhs.size - rhs.size, trhs)), None) } case ARSHIFT => if (lhs.size == rhs.size) { - (BinaryExpr(BVASHR, translateExprOnly(lhs), translateExprOnly(rhs)), None) + (BinaryExpr(BVASHR, tlhs, trhs), None) } else { - (BinaryExpr(BVASHR, translateExprOnly(lhs), ZeroExtend(lhs.size - rhs.size, translateExprOnly(rhs))), None) + (BinaryExpr(BVASHR, tlhs, ZeroExtend(lhs.size - rhs.size, trhs)), None) } - case AND => (BinaryExpr(BVAND, translateExprOnly(lhs), translateExprOnly(rhs)), None) - case OR => (BinaryExpr(BVOR, translateExprOnly(lhs), translateExprOnly(rhs)), None) - case XOR => (BinaryExpr(BVXOR, translateExprOnly(lhs), translateExprOnly(rhs)), None) - case EQ => (BinaryExpr(BVCOMP, translateExprOnly(lhs), translateExprOnly(rhs)), None) - case NEQ => (UnaryExpr(BVNOT, BinaryExpr(BVCOMP, translateExprOnly(lhs), translateExprOnly(rhs))), None) - case LT => (BinaryExpr(BVULT, translateExprOnly(lhs), translateExprOnly(rhs)), None) - case LE => (BinaryExpr(BVULE, translateExprOnly(lhs), translateExprOnly(rhs)), None) - case SLT => (BinaryExpr(BVSLT, translateExprOnly(lhs), translateExprOnly(rhs)), None) - case SLE => (BinaryExpr(BVSLE, translateExprOnly(lhs), translateExprOnly(rhs)), None) + case AND => (BinaryExpr(BVAND, tlhs, trhs), None) + case OR => (BinaryExpr(BVOR, tlhs, trhs), None) + case XOR => (BinaryExpr(BVXOR, tlhs, trhs), None) + case EQ => (BinaryExpr(BVCOMP, tlhs, trhs), None) + case NEQ => (UnaryExpr(BVNOT, BinaryExpr(BVCOMP, tlhs, trhs)), None) + case LT => (BinaryExpr(BVULT, tlhs, trhs), None) + case LE => (BinaryExpr(BVULE, tlhs, trhs), None) + case SLT => (BinaryExpr(BVSLT, tlhs, trhs), None) + case SLE => (BinaryExpr(BVSLE, tlhs, trhs), None) } case b: BAPVar => (translateVar(b), None) case BAPMemAccess(memory, index, endian, size) => diff --git a/src/main/scala/translating/GTIRBLoader.scala b/src/main/scala/translating/GTIRBLoader.scala index 836ed5de0..4ab2a9ff8 100644 --- a/src/main/scala/translating/GTIRBLoader.scala +++ b/src/main/scala/translating/GTIRBLoader.scala @@ -58,7 +58,10 @@ class GTIRBLoader(parserMap: immutable.Map[String, Array[Array[StmtContext]]]) { Seq() case a: AssertContext => visitAssert(a, label).toSeq case t: TCallContext => visitTCall(t, label).toSeq - case i: IfContext => visitIf(i, label).toSeq + case i: IfContext => { + val (tif, load) = visitIf(i, label) + Seq() ++ load ++ tif + } case t: ThrowContext => Seq(visitThrow(t, label)) } } @@ -80,6 +83,8 @@ class GTIRBLoader(parserMap: immutable.Map[String, Array[Array[StmtContext]]]) { private def visitTCall(ctx: TCallContext, label: Option[String] = None): Option[Statement] = { val function = visitIdent(ctx.ident) + val atomicStart = Procedure("intrinsic$AtomicStart") + val atomicEnd = Procedure("intrinsic$AtomicEnd") val typeArgs = Option(ctx.tes).toList.flatMap(_.expr.asScala) val args = Option(ctx.args).toList.flatMap(_.expr.asScala) @@ -102,6 +107,8 @@ class GTIRBLoader(parserMap: immutable.Map[String, Array[Array[StmtContext]]]) { } else { None } + case "AtomicStart.0" => Some(DirectCall(atomicStart)) + case "AtomicEnd.0" => Some(DirectCall(atomicEnd)) case "unsupported_opcode.0" => { val op = args.headOption.flatMap(visitExprOnly) match { case Some(IntLiteral(s)) => Some("%08x".format(s)) @@ -127,8 +134,8 @@ class GTIRBLoader(parserMap: immutable.Map[String, Array[Array[StmtContext]]]) { case _ => throw Exception(s"expected ${ctx.getText} to be an integer literal") } - private def visitIf(ctx: IfContext, label: Option[String] = None): Option[TempIf] = { - val condition = visitExprOnly(ctx.cond) + private def visitIf(ctx: IfContext, label: Option[String] = None): (Option[TempIf], Option[MemoryLoad]) = { + val (condition, load) = visitExpr(ctx.cond) val thenStmts = ctx.thenStmts.stmt.asScala.flatMap(visitStmt(_, label)) val elseStmts = Option(ctx.elseStmts) match { @@ -137,9 +144,9 @@ class GTIRBLoader(parserMap: immutable.Map[String, Array[Array[StmtContext]]]) { } if (condition.isDefined) { - Some(TempIf(condition.get, thenStmts, elseStmts, label)) + (Some(TempIf(condition.get, thenStmts, elseStmts, label)), load) } else { - None + (None, None) } } @@ -222,21 +229,27 @@ class GTIRBLoader(parserMap: immutable.Map[String, Array[Array[StmtContext]]]) { } private def visitExpr(ctx: ExprContext): (Option[Expr], Option[MemoryLoad]) = { - ctx match { - case e: ExprVarContext => (visitExprVar(e), None) - case e: ExprTApplyContext => visitExprTApply(e) - case e: ExprSlicesContext => visitExprSlices(e) - case e: ExprFieldContext => (Some(visitExprField(e)), None) - case e: ExprArrayContext => (Some(visitExprArray(e)), None) - case e: ExprLitIntContext => (Some(IntLiteral(parseInt(e))), None) - case e: ExprLitBitsContext => (Some(visitExprLitBits(e)), None) + try { + ctx match { + case e: ExprVarContext => (visitExprVar(e), None) + case e: ExprTApplyContext => visitExprTApply(e) + case e: ExprSlicesContext => visitExprSlices(e) + case e: ExprFieldContext => (Some(visitExprField(e)), None) + case e: ExprArrayContext => (Some(visitExprArray(e)), None) + case e: ExprLitIntContext => (Some(IntLiteral(parseInt(e))), None) + case e: ExprLitBitsContext => (Some(visitExprLitBits(e)), None) + } + } catch { + case ex: Exception => + Logger.error(s"Error in : $ctx :" + ex) + throw ex } } private def visitExprOnly(ctx: ExprContext): Option[Expr] = { val (expr, load) = visitExpr(ctx) if (load.isDefined) { - throw Exception("") + throw Exception(s"Unexpected load $load ; $expr ") } else { expr } @@ -309,29 +322,29 @@ class GTIRBLoader(parserMap: immutable.Map[String, Array[Array[StmtContext]]]) { (result, None) case "not_bool.0" => (resolveUnaryOp(BoolNOT, function, 0, typeArgs, args, ctx.getText), None) - case "eq_enum.0" => (resolveBinaryOp(BoolEQ, function, 0, typeArgs, args, ctx.getText), None) - case "or_bool.0" => (resolveBinaryOp(BoolOR, function, 0, typeArgs, args, ctx.getText), None) - case "and_bool.0" => (resolveBinaryOp(BoolAND, function, 0, typeArgs, args, ctx.getText), None) + case "eq_enum.0" => (resolveBinaryOp(BoolEQ, function, 0, typeArgs, args, ctx.getText)) + case "or_bool.0" => (resolveBinaryOp(BoolOR, function, 0, typeArgs, args, ctx.getText)) + case "and_bool.0" => (resolveBinaryOp(BoolAND, function, 0, typeArgs, args, ctx.getText)) case "not_bits.0" => (resolveUnaryOp(BVNOT, function, 1, typeArgs, args, ctx.getText), None) - case "or_bits.0" => (resolveBinaryOp(BVOR, function, 1, typeArgs, args, ctx.getText), None) - case "and_bits.0" => (resolveBinaryOp(BVAND, function, 1, typeArgs, args, ctx.getText), None) - case "eor_bits.0" => (resolveBinaryOp(BVXOR, function, 1, typeArgs, args, ctx.getText), None) - case "eq_bits.0" => (resolveBinaryOp(BVEQ, function, 1, typeArgs, args, ctx.getText), None) - case "add_bits.0" => (resolveBinaryOp(BVADD, function, 1, typeArgs, args, ctx.getText), None) - case "sub_bits.0" => (resolveBinaryOp(BVSUB, function, 1, typeArgs, args, ctx.getText), None) - case "mul_bits.0" => (resolveBinaryOp(BVMUL, function, 1, typeArgs, args, ctx.getText), None) - case "sdiv_bits.0" => (resolveBinaryOp(BVSDIV, function, 1, typeArgs, args, ctx.getText), None) - - case "slt_bits.0" => (resolveBinaryOp(BVSLT, function, 1, typeArgs, args, ctx.getText), None) - case "sle_bits.0" => (resolveBinaryOp(BVSLE, function, 1, typeArgs, args, ctx.getText), None) + case "or_bits.0" => (resolveBinaryOp(BVOR, function, 1, typeArgs, args, ctx.getText)) + case "and_bits.0" => (resolveBinaryOp(BVAND, function, 1, typeArgs, args, ctx.getText)) + case "eor_bits.0" => (resolveBinaryOp(BVXOR, function, 1, typeArgs, args, ctx.getText)) + case "eq_bits.0" => (resolveBinaryOp(BVEQ, function, 1, typeArgs, args, ctx.getText)) + case "add_bits.0" => (resolveBinaryOp(BVADD, function, 1, typeArgs, args, ctx.getText)) + case "sub_bits.0" => (resolveBinaryOp(BVSUB, function, 1, typeArgs, args, ctx.getText)) + case "mul_bits.0" => (resolveBinaryOp(BVMUL, function, 1, typeArgs, args, ctx.getText)) + case "sdiv_bits.0" => (resolveBinaryOp(BVSDIV, function, 1, typeArgs, args, ctx.getText)) + + case "slt_bits.0" => (resolveBinaryOp(BVSLT, function, 1, typeArgs, args, ctx.getText)) + case "sle_bits.0" => (resolveBinaryOp(BVSLE, function, 1, typeArgs, args, ctx.getText)) case "lsl_bits.0" => (resolveBitShiftOp(BVSHL, function, typeArgs, args, ctx.getText), None) case "lsr_bits.0" => (resolveBitShiftOp(BVLSHR, function, typeArgs, args, ctx.getText), None) case "asr_bits.0" => (resolveBitShiftOp(BVASHR, function, typeArgs, args, ctx.getText), None) case "append_bits.0" => - (resolveBinaryOp(BVCONCAT, function, 2, typeArgs, args, ctx.getText), None) + (resolveBinaryOp(BVCONCAT, function, 2, typeArgs, args, ctx.getText)) case "replicate_bits.0" => checkArgs(function, 2, 2, typeArgs.size, args.size, ctx.getText) @@ -482,16 +495,18 @@ class GTIRBLoader(parserMap: immutable.Map[String, Array[Array[StmtContext]]]) { typeArgs: mutable.Buffer[ExprContext], args: mutable.Buffer[ExprContext], token: String - ): Option[BinaryExpr] = { + ): (Option[BinaryExpr], Option[MemoryLoad]) = { checkArgs(function, typeArgsExpected, 2, typeArgs.size, args.size, token) // we don't currently check the size for BV ops which is the type arg // memory loads shouldn't appear inside binary operations? - val arg0 = visitExprOnly(args(0)) - val arg1 = visitExprOnly(args(1)) + val (arg0, l0) = visitExpr(args(0)) + val (arg1, l1) = visitExpr(args(1)) + assert(!(l0.isDefined && l1.isDefined), "Do not expect two loads in an expression") + val load = l0.orElse(l1) if (arg0.isDefined && arg1.isDefined) { - Some(BinaryExpr(operator, arg0.get, arg1.get)) + (Some(BinaryExpr(operator, arg0.get, arg1.get)), load) } else { - None + (None, load) } } diff --git a/src/main/scala/translating/GTIRBToIR.scala b/src/main/scala/translating/GTIRBToIR.scala index 31049706c..a7226b74f 100644 --- a/src/main/scala/translating/GTIRBToIR.scala +++ b/src/main/scala/translating/GTIRBToIR.scala @@ -162,25 +162,67 @@ class GTIRBToIR(mods: Seq[Module], parserMap: immutable.Map[String, Array[Array[ val procedure = uuidToProcedure(functionUUID) var blockCount = 0 for (blockUUID <- blockUUIDs) { - val block = uuidToBlock(blockUUID) + var block = uuidToBlock(blockUUID) - val statements = semanticsLoader.visitBlock(blockUUID, blockCount, block.address) - blockCount += 1 - block.statements.addAll(statements) + val _statements = semanticsLoader.visitBlock(blockUUID, blockCount, block.address) + val blockNonEmpty = _statements.nonEmpty - if (block.statements.isEmpty && !blockOutgoingEdges.contains(blockUUID)) { + for (s <- _statements) { + // add intrinsics + s match { + case d: DirectCall if !procedures.exists(d.target.name == _.name) => procedures += d.target + case _ => () + } + } + + val blocks = ArrayBuffer[List[Statement]]() + + while (_statements.nonEmpty) { + val blockStatements = _statements.takeWhile(s => !(s.isInstanceOf[Call])) + _statements.remove(0, blockStatements.size) + + val maybeCall = _statements.headOption + if (maybeCall.isDefined) { + _statements.remove(0) + } + + blocks.addOne(blockStatements.toList ++ maybeCall) + } + + blocks.toList match { + case Nil => () + case statements::Nil => { + blockCount += 1 + block.statements.addAll(statements) + } + case statements::tl => { + blockCount += 1 + block.statements.addAll(statements) + + val labels = (1 to tl.size).map(i => block.label + "_" + i) + val newBlocks = labels.zip(tl).map((l, statements) => Block(l, statements=statements)) + + newBlocks.foldLeft(block)((l, r) => { + l.replaceJump(GoTo(r)) + r + }) + + block.parent.addBlocks(newBlocks) + + block = newBlocks.last + } + } + + if (blockNonEmpty && !blockOutgoingEdges.contains(blockUUID)) { // remove blocks that are just nop padding // TODO cleanup blocks that are entirely nop but have fallthrough edges? Logger.debug(s"removing block ${block.label}") procedure.removeBlocks(block) + } else if (!blockOutgoingEdges.contains(blockUUID) || blockOutgoingEdges(blockUUID).isEmpty) { + block.replaceJump(Unreachable()) + Logger.debug(s"block ${block.label} in subroutine ${procedure.name} has no outgoing edges") } else { - if (!blockOutgoingEdges.contains(blockUUID)) { - throw Exception (s"block ${block.label} in subroutine ${procedure.name} has no outgoing edges") - } val outgoingEdges = blockOutgoingEdges(blockUUID) - if (outgoingEdges.isEmpty) { - throw Exception(s"block ${block.label} in subroutine ${procedure.name} has no outgoing edges") - } val (calls, jump) = if (outgoingEdges.size == 1) { val edge = outgoingEdges.head @@ -285,7 +327,7 @@ class GTIRBToIR(mods: Seq[Module], parserMap: immutable.Map[String, Array[Array[ val blockLabel = convertLabel(procedure, blockUUID, blockCount) val blockAddress = blockUUIDToAddress.get(blockUUID) - val block = Block(blockLabel, blockAddress) + val block = Block(blockLabel, blockAddress, jump=Unreachable()) procedure.addBlocks(block) if (uuidToBlock.contains(blockUUID)) { // TODO this is a case that requires special consideration diff --git a/src/main/scala/util/RunUtils.scala b/src/main/scala/util/RunUtils.scala index 08cb17f53..9ef79fcfe 100644 --- a/src/main/scala/util/RunUtils.scala +++ b/src/main/scala/util/RunUtils.scala @@ -246,7 +246,7 @@ object IRTransform { val renamer = Renamer(boogieReserved) renamer.visitProgram(ctx.program) - assert(invariant.singleCallBlockEnd(ctx.program)) + assert(invariant.singleCallBlockEnd(ctx.program), "Single call at block end") } def generateProcedureSummaries( @@ -562,7 +562,7 @@ object RunUtils { val boogieTranslator = IRToBoogie(ctx.program, ctx.specification, None, q.outputPrefix, regionInjector, q.boogieTranslation) ArrayBuffer(boogieTranslator.translate) } - assert(invariant.singleCallBlockEnd(ctx.program)) + assert(invariant.singleCallBlockEnd(ctx.program), "Single call at block end") BASILResult(ctx, analysis, boogiePrograms) } @@ -643,7 +643,7 @@ object RunUtils { None } - assert(invariant.singleCallBlockEnd(ctx.program)) + assert(invariant.singleCallBlockEnd(ctx.program), "Single call at block end") Logger.debug(s"[!] Finished indirect call resolution after $iteration iterations") analysisResult.last.copy( symbolicAddresses = symResults,