Skip to content

Commit

Permalink
add labels to statements from gtirb
Browse files Browse the repository at this point in the history
  • Loading branch information
l-kent committed Feb 20, 2024
1 parent f84c7f1 commit 5477e6d
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 57 deletions.
7 changes: 5 additions & 2 deletions src/main/scala/translating/GTIRBToIR.scala
Original file line number Diff line number Diff line change
Expand Up @@ -137,14 +137,16 @@ class GTIRBToIR(mods: Seq[Module], parserMap: immutable.Map[String, Array[Array[

// maybe good to sort blocks by address around here?

val semanticsLoader = SemanticsLoader(parserMap)

for ((functionUUID, blockUUIDs) <- functionBlocks) {
val procedure = uuidToProcedure(functionUUID)
var blockCount = 0
for (blockUUID <- blockUUIDs) {
val block = uuidToBlock(blockUUID)
val semanticsLoader = SemanticsLoader(blockUUID, parserMap, blockCount)

val statements = semanticsLoader.visitBlock(blockUUID, blockCount, block.address)
blockCount += 1
val statements = semanticsLoader.createStatements()
block.statements.addAll(statements)

if (block.statements.isEmpty && !blockOutgoingEdges.contains(blockUUID)) {
Expand Down Expand Up @@ -230,6 +232,7 @@ class GTIRBToIR(mods: Seq[Module], parserMap: immutable.Map[String, Array[Array[
uuidToProcedure += (functionUUID -> procedure)
entranceUUIDtoProcedure += (entranceUUID -> procedure)

// sort blocks by address to give a more practical order
val blockUUIDs = functionBlocks(functionUUID)
val blockUUIDsSorted = blockUUIDs.toSeq.sortBy(addresses(_))
// should probably check if empty?
Expand Down
118 changes: 64 additions & 54 deletions src/main/scala/translating/SemanticsLoader.scala
Original file line number Diff line number Diff line change
Expand Up @@ -13,63 +13,73 @@ import scala.collection.mutable.ArrayBuffer
import com.grammatech.gtirb.proto.Module.ByteOrder.LittleEndian
import util.Logger

class SemanticsLoader(targetuuid: ByteString, parserMap: immutable.Map[String, Array[Array[StmtContext]]], blkCount: Int) extends SemanticsBaseVisitor[Any] {
class SemanticsLoader(parserMap: immutable.Map[String, Array[Array[StmtContext]]]) {

private var constMap = mutable.Map[String, IRType]()
private var varMap = mutable.Map[String, IRType]()
private val constMap = mutable.Map[String, IRType]()
private val varMap = mutable.Map[String, IRType]()
private var instructionCount = 0
private var blockCount = 0

def createStatements(): ArrayBuffer[Statement] = {
visitInstructions(parserMap(Base64.getEncoder.encodeToString(targetuuid.toByteArray)))
}
val opcodeSize = 4

def visitBlock(blockUUID: ByteString, blockCountIn: Int, blockAddress: Option[Int]): ArrayBuffer[Statement] = {
blockCount = blockCountIn
instructionCount = 0
val instructions = parserMap(Base64.getEncoder.encodeToString(blockUUID.toByteArray))

def visitInstructions(stmts: Array[Array[StmtContext]]): ArrayBuffer[Statement] = {
val statements: ArrayBuffer[Statement] = ArrayBuffer()

for (insn <- stmts) {
constMap = constMap.empty
varMap = varMap.empty
for (s <- insn) {
val s2 = visitStmt(s)
if (s2.isDefined) {
statements.append(s2.get)
for (instruction <- instructions) {
constMap.clear
varMap.clear

for ((s, i) <- instruction.zipWithIndex) {

val label = blockAddress.map {(a: Int) =>
val instructionAddress = a + (opcodeSize * instructionCount)
instructionAddress.toString + "$" + i
}

val statement = visitStmt(s, label)
if (statement.isDefined) {
statements.append(statement.get)
}
}
instructionCount += 1
}
statements
}

def visitStmt(ctx: StmtContext): Option[Statement] = {
private def visitStmt(ctx: StmtContext, label: Option[String] = None): Option[Statement] = {
ctx match {
case a: AssignContext => visitAssign(a)
case c: ConstDeclContext => visitConstDecl(c)
case v: VarDeclContext => visitVarDecl(v)
case a: AssignContext => visitAssign(a, label)
case c: ConstDeclContext => visitConstDecl(c, label)
case v: VarDeclContext => visitVarDecl(v, label)
case v: VarDeclsNoInitContext =>
visitVarDeclsNoInit(v)
None
case a: AssertContext => visitAssert(a)
case t: TCallContext => visitTCall(t)
case i: IfContext => visitIf(i)
case t: ThrowContext => Some(visitThrow(t))
case a: AssertContext => visitAssert(a, label)
case t: TCallContext => visitTCall(t, label)
case i: IfContext => visitIf(i, label)
case t: ThrowContext => Some(visitThrow(t, label))
}
}

override def visitAssert(ctx: AssertContext): Option[Assert] = {
private def visitAssert(ctx: AssertContext, label: Option[String] = None): Option[Assert] = {
val expr = visitExpr(ctx.expr)
if (expr.isDefined) {
Some(Assert(expr.get))
Some(Assert(expr.get, None, label))
} else {
None
}
}

override def visitThrow(ctx: ThrowContext): Assert = {
private def visitThrow(ctx: ThrowContext, label: Option[String] = None): Assert = {
val message = ctx.ID().asScala.map(_.getText).mkString(" ,")
Assert(FalseLiteral, Some(message))
Assert(FalseLiteral, Some(message), label)
}

override def visitTCall(ctx: TCallContext): Option[Statement] = {
private def visitTCall(ctx: TCallContext, label: Option[String] = None): Option[Statement] = {
val function = ctx.ID.getText

val typeArgs = Option(ctx.tes) match {
Expand Down Expand Up @@ -100,7 +110,7 @@ class SemanticsLoader(targetuuid: ByteString, parserMap: immutable.Map[String, A
// LittleEndian is an assumption
if (index.isDefined && value.isDefined) {
val memstore = MemoryStore(mem, index.get, value.get, Endian.LittleEndian, size)
Some(MemoryAssign(mem, memstore))
Some(MemoryAssign(mem, memstore, label))
} else {
None
}
Expand All @@ -122,63 +132,63 @@ class SemanticsLoader(targetuuid: ByteString, parserMap: immutable.Map[String, A
case _ => throw Exception(s"expected ${ctx.getText} to be an integer literal")
}

override def visitIf(ctx: IfContext): Option[TempIf] = {
private def visitIf(ctx: IfContext, label: Option[String] = None): Option[TempIf] = {
val condition = visitExpr(ctx.cond)
val thenStmts = ctx.stmt().asScala.flatMap(visitStmt)
val thenStmts = ctx.stmt().asScala.flatMap(visitStmt(_, label))

val elseStmts = Option(ctx.elseStmt) match {
case Some(_) => ctx.elseStmt.stmt().asScala.flatMap(visitStmt)
case Some(_) => ctx.elseStmt.stmt().asScala.flatMap(visitStmt(_, label))
case None => mutable.Buffer()
}

if (condition.isDefined) {
Some(TempIf(condition.get, thenStmts, elseStmts))
Some(TempIf(condition.get, thenStmts, elseStmts, label))
} else {
None
}
}

override def visitVarDeclsNoInit(ctx: VarDeclsNoInitContext): Unit = {
private def visitVarDeclsNoInit(ctx: VarDeclsNoInitContext): Unit = {
val ty = visitType(ctx.`type`())
ctx.lvars.ID().asScala.foreach(lvar => varMap += (lvar.getText -> ty))
}

override def visitVarDecl(ctx: VarDeclContext): Option[LocalAssign] = {
private def visitVarDecl(ctx: VarDeclContext, label: Option[String] = None): Option[LocalAssign] = {
val ty = visitType(ctx.`type`())
val name = ctx.lvar.getText
varMap += (name -> ty)

val expr = visitExpr(ctx.expr())
if (expr.isDefined) {
Some(LocalAssign(LocalVar(name, ty), expr.get))
Some(LocalAssign(LocalVar(name, ty), expr.get, label))
} else {
None
}
}

override def visitAssign(ctx: AssignContext): Option[LocalAssign] = {
private def visitAssign(ctx: AssignContext, label: Option[String] = None): Option[LocalAssign] = {
val lhs = visitLexpr(ctx.lexpr)
val rhs = visitExpr(ctx.expr)
if (lhs.isDefined && rhs.isDefined) {
Some(LocalAssign(lhs.get, rhs.get))
Some(LocalAssign(lhs.get, rhs.get, label))
} else {
None
}
}

override def visitConstDecl(ctx: ConstDeclContext): Option[LocalAssign] = {
private def visitConstDecl(ctx: ConstDeclContext, label: Option[String] = None): Option[LocalAssign] = {
val ty = visitType(ctx.`type`())
val name = ctx.lvar.getText
constMap += (name -> ty)
val expr = visitExpr(ctx.expr)
if (expr.isDefined) {
Some(LocalAssign(LocalVar(name + "_" + blkCount + "_" + instructionCount, ty), expr.get))
Some(LocalAssign(LocalVar(name + "$" + blockCount + "$" + instructionCount, ty), expr.get, label))
} else {
None
}
}

def visitType(ctx: TypeContext): IRType = {
private def visitType(ctx: TypeContext): IRType = {
ctx match
case e: TypeBitsContext => BitVecType(parseInt(e.size))
case r: TypeRegisterContext =>
Expand All @@ -192,7 +202,7 @@ class SemanticsLoader(targetuuid: ByteString, parserMap: immutable.Map[String, A
case _ => throw Exception(s"unknown type ${ctx.getText}")
}

def visitExpr(ctx: ExprContext): Option[Expr] = {
private def visitExpr(ctx: ExprContext): Option[Expr] = {
ctx match {
case e: ExprVarContext => visitExprVar(e)
case e: ExprTApplyContext => visitExprTApply(e)
Expand All @@ -204,10 +214,10 @@ class SemanticsLoader(targetuuid: ByteString, parserMap: immutable.Map[String, A
}
}

override def visitExprVar(ctx: ExprVarContext): Option[Expr] = {
private def visitExprVar(ctx: ExprVarContext): Option[Expr] = {
val name = ctx.ID.getText
name match {
case n if constMap.contains(n) => Some(LocalVar(n + "_" + blkCount + "_" + instructionCount, constMap(n)))
case n if constMap.contains(n) => Some(LocalVar(n + "$" + blockCount + "$" + instructionCount, constMap(n)))
case v if varMap.contains(v) => Some(LocalVar(v, varMap(v)))
case "SP_EL0" => Some(Register("R31", BitVecType(64)))
case "_PC" => Some(Register("_PC", BitVecType(64)))
Expand All @@ -222,7 +232,7 @@ class SemanticsLoader(targetuuid: ByteString, parserMap: immutable.Map[String, A
}
}

override def visitExprTApply(ctx: ExprTApplyContext): Option[Expr] = {
private def visitExprTApply(ctx: ExprTApplyContext): Option[Expr] = {

val function = ctx.ID.getText

Expand Down Expand Up @@ -391,7 +401,7 @@ class SemanticsLoader(targetuuid: ByteString, parserMap: immutable.Map[String, A
}
}

override def visitExprSlices(ctx: ExprSlicesContext): Option[Extract] = {
private def visitExprSlices(ctx: ExprSlicesContext): Option[Extract] = {
val slices = ctx.slices.slice().asScala
if (slices.size != 1) {
// need to determine the semantics for this case
Expand All @@ -406,7 +416,7 @@ class SemanticsLoader(targetuuid: ByteString, parserMap: immutable.Map[String, A
}
}

def visitSliceContext(ctx: SliceContext): (Int, Int) = {
private def visitSliceContext(ctx: SliceContext): (Int, Int) = {
ctx match {
case s: Slice_HiLoContext =>
val hi = parseInt(s.hi)
Expand All @@ -419,7 +429,7 @@ class SemanticsLoader(targetuuid: ByteString, parserMap: immutable.Map[String, A
}
}

override def visitExprField(ctx: ExprFieldContext): LocalVar = {
private def visitExprField(ctx: ExprFieldContext): LocalVar = {
val name = ctx.expr match {
case e: ExprVarContext => e.ID.getText
case _ => throw Exception(s"expected ${ctx.getText} to have an Expr_Var as first parameter")
Expand All @@ -429,7 +439,7 @@ class SemanticsLoader(targetuuid: ByteString, parserMap: immutable.Map[String, A
resolveFieldExpr(name, field)
}

override def visitExprArray(ctx: ExprArrayContext): Register = {
private def visitExprArray(ctx: ExprArrayContext): Register = {
val name = ctx.array match {
case e: ExprVarContext => e.ID.getText
case _ => throw Exception(s"expected ${ctx.getText} to have an Expr_Var as first parameter")
Expand All @@ -439,7 +449,7 @@ class SemanticsLoader(targetuuid: ByteString, parserMap: immutable.Map[String, A
resolveArrayExpr(name, index)
}

override def visitExprLitBits(ctx: ExprLitBitsContext): BitVecLiteral = {
private def visitExprLitBits(ctx: ExprLitBitsContext): BitVecLiteral = {
var num = BigInt(ctx.value.getText, 2)
val len = ctx.value.getText.length
if (num < 0) {
Expand All @@ -448,18 +458,18 @@ class SemanticsLoader(targetuuid: ByteString, parserMap: immutable.Map[String, A
BitVecLiteral(num, len)
}

def visitLexpr(ctx: LexprContext): Option[Variable] = {
private def visitLexpr(ctx: LexprContext): Option[Variable] = {
ctx match {
case l: LExprVarContext => visitLExprVar(l)
case l: LExprFieldContext => Some(visitLExprField(l))
case l: LExprArrayContext => Some(visitLExprArray(l))
}
}

override def visitLExprVar(ctx: LExprVarContext): Option[Variable] = {
private def visitLExprVar(ctx: LExprVarContext): Option[Variable] = {
val name = ctx.ID.getText
name match {
case n if constMap.contains(n) => Some(LocalVar(n + "_" + blkCount + "_" + instructionCount, constMap(n)))
case n if constMap.contains(n) => Some(LocalVar(n + "$" + blockCount + "$" + instructionCount, constMap(n)))
case v if varMap.contains(v) => Some(LocalVar(v, varMap(v)))
case "SP_EL0" => Some(Register("R31", BitVecType(64)))
case "_PC" => Some(Register("_PC", BitVecType(64)))
Expand All @@ -473,7 +483,7 @@ class SemanticsLoader(targetuuid: ByteString, parserMap: immutable.Map[String, A
}
}

override def visitLExprField(ctx: LExprFieldContext): LocalVar = {
private def visitLExprField(ctx: LExprFieldContext): LocalVar = {
val name = ctx.lexpr match {
case l: LExprVarContext => l.ID.getText
case _ => throw Exception(s"expected ${ctx.getText} to have an LExpr_Var as first parameter")
Expand All @@ -483,7 +493,7 @@ class SemanticsLoader(targetuuid: ByteString, parserMap: immutable.Map[String, A
resolveFieldExpr(name, field)
}

override def visitLExprArray(ctx: LExprArrayContext): Register = {
private def visitLExprArray(ctx: LExprArrayContext): Register = {
val name = ctx.lexpr match {
case l: LExprVarContext => l.ID.getText
case _ => throw Exception(s"expected ${ctx.getText} to have an LExpr_Var as first parameter")
Expand Down
2 changes: 1 addition & 1 deletion src/main/scala/util/RunUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import org.antlr.v4.runtime.{CharStreams, CommonTokenStream}
import translating.*
import util.Logger
import java.util.Base64
import spray.json.DefaultJsonProtocol._
import spray.json.DefaultJsonProtocol.*
import intrusivelist.IntrusiveList
import analysis.CfgCommandNode

Expand Down

0 comments on commit 5477e6d

Please sign in to comment.