cntlm test, fixed lots of little errors
Megatomato committed Feb 5, 2024
1 parent d111b60 commit 0d344f0
Showing 8 changed files with 720 additions and 41 deletions.
Binary file added examples/cntlm_gtirb/cntlm
Binary file not shown.
Binary file added examples/cntlm_gtirb/cntlm.gts
Binary file not shown.
648 changes: 648 additions & 0 deletions examples/cntlm_gtirb/cntlm.relf

Large diffs are not rendered by default.

5 changes: 3 additions & 2 deletions src/main/antlr4/Semantics.g4
@@ -9,7 +9,8 @@ assignment_stmt:
'Stmt_Assign' OPEN_PAREN lexpr COMMA expr CLOSE_PAREN # Assign
| 'Stmt_ConstDecl' OPEN_PAREN type COMMA METHOD COMMA expr CLOSE_PAREN # ConstDecl
| 'Stmt_VarDecl' OPEN_PAREN type COMMA METHOD COMMA expr CLOSE_PAREN # VarDecl
| 'Stmt_VarDeclsNoInit' OPEN_PAREN type COMMA OPEN_BRACKET OPEN_PAREN METHOD (COMMA METHOD)* CLOSE_PAREN CLOSE_BRACKET CLOSE_PAREN # VarDeclsNoInit;
| 'Stmt_VarDeclsNoInit' OPEN_PAREN type COMMA OPEN_BRACKET OPEN_PAREN METHOD (COMMA METHOD)* CLOSE_PAREN CLOSE_BRACKET CLOSE_PAREN # VarDeclsNoInit
| 'Stmt_Assert' OPEN_PAREN expr CLOSE_PAREN # Assert;

call_stmt:
'Stmt_TCall' OPEN_PAREN (SSYMBOL | METHOD) (
@@ -26,7 +27,7 @@ call_stmt:

conditional_stmt:
'Stmt_If' OPEN_PAREN expr COMMA OPEN_BRACKET stmt* COMMA? CLOSE_BRACKET COMMA
OPEN_BRACKET CLOSE_BRACKET COMMA (OPEN_PAREN 'else' conditional_stmt CLOSE_PAREN)? (OPEN_PAREN 'else' else_stmt CLOSE_PAREN)? CLOSE_PAREN;
OPEN_BRACKET CLOSE_BRACKET COMMA (OPEN_PAREN 'else' conditional_stmt CLOSE_PAREN)? (OPEN_PAREN 'else' else_stmt* CLOSE_PAREN)? CLOSE_PAREN;
else_stmt: stmt;

type : 'Type_Bits' OPEN_PAREN expr CLOSE_PAREN # TypeBits;
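
As a rough sketch, the two grammar changes above accept inputs shaped like the following (the <expr> and <stmt> placeholders are illustrative stand-ins, not productions taken from this grammar):

    Stmt_Assert(<expr>)                                    new Assert alternative
    Stmt_If(<expr>, [<stmt>], [], (else <stmt> <stmt>))    else group may now hold several statements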
1 change: 1 addition & 0 deletions src/main/scala/boogie/BExpr.scala
@@ -339,6 +339,7 @@ case class BinaryBExpr(op: BinOp, arg1: BExpr, arg2: BExpr) extends BExpr {
if (bv1.size == bv2.size) {
bv1
} else {
println(s"$arg1, $arg2")
throw new Exception("bitvector size mismatch")
}
case BVCOMP =>
23 changes: 10 additions & 13 deletions src/main/scala/translating/GtirbToIR.scala
@@ -33,10 +33,10 @@ import intrusivelist.{IntrusiveList, IntrusiveListElement}
*
*/
class TempIf(val isLongIf: Boolean, val conds: ArrayBuffer[Expr], val stmts: ArrayBuffer[ArrayBuffer[Statement]],
val elseStatement: Option[Statement] = None, override val label: Option[String] = None) extends Assert(conds.head) {}
val elseStatement: ArrayBuffer[Statement], override val label: Option[String] = None) extends Assert(conds.head) {}

object TempIf {
def unapply(tempIf: TempIf): Option[(Boolean, ArrayBuffer[Expr], ArrayBuffer[ArrayBuffer[Statement]], Option[Statement], Option[String])] = {
def unapply(tempIf: TempIf): Option[(Boolean, ArrayBuffer[Expr], ArrayBuffer[ArrayBuffer[Statement]], ArrayBuffer[Statement], Option[String])] = {
Some((tempIf.isLongIf, tempIf.conds, tempIf.stmts, tempIf.elseStatement, tempIf.label))
}
}
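
A minimal sketch of constructing the widened TempIf, assuming hypothetical placeholders (cond1, cond2, thenStmt1, thenStmt2, elseStmt1, elseStmt2 stand for whatever Exprs and Statements the lifter produced; they are not values from this commit):

    import scala.collection.mutable.ArrayBuffer

    // elseStatement is now a (possibly empty) buffer of statements rather than an
    // Option[Statement], so else branches with more than one statement survive.
    val shortIf = TempIf(false, ArrayBuffer(cond1), ArrayBuffer(ArrayBuffer(thenStmt1)), ArrayBuffer[Statement]())
    val longIf  = TempIf(true, ArrayBuffer(cond1, cond2),
                         ArrayBuffer(ArrayBuffer(thenStmt1), ArrayBuffer(thenStmt2)),
                         ArrayBuffer(elseStmt1, elseStmt2))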
@@ -251,7 +251,7 @@ class GtirbToIR (mods: Seq[com.grammatech.gtirb.proto.Module.Module], parserMap:

blks ++= funcblks


extraBlkCount = 2
blks = blks.flatMap(handle_long_if)

return blks.flatMap(handle_unlifted_indirects)
@@ -354,7 +354,7 @@ class GtirbToIR (mods: Seq[com.grammatech.gtirb.proto.Module.Module], parserMap:
*/
def handle_long_if(blk: Block): ArrayBuffer[Block] = {

val create_TempIf: Expr => TempIf = conds => TempIf(false, ArrayBuffer(conds), ArrayBuffer[ArrayBuffer[Statement]]())
val create_TempIf: Expr => TempIf = conds => TempIf(false, ArrayBuffer(conds), ArrayBuffer[ArrayBuffer[Statement]](), ArrayBuffer[Statement]())

val create_endEdge: (String, String) => proto.CFG.Edge = (uuid, endUuid) =>
proto.CFG.Edge(get_ByteString(uuid), get_ByteString(endUuid), Option(proto.CFG.EdgeLabel(false, true, proto.CFG.EdgeType.Type_Fallthrough)))
@@ -364,7 +364,6 @@ class GtirbToIR (mods: Seq[com.grammatech.gtirb.proto.Module.Module], parserMap:

while (block.statements.exists {elem => elem.isInstanceOf[TempIf] && elem.asInstanceOf[TempIf].isLongIf}) {
extraBlkCount += 1

// Split Blk in two and remove IfStmt
val ifStmt = block.statements.find {elem => elem.isInstanceOf[TempIf] && elem.asInstanceOf[TempIf].isLongIf}.get.asInstanceOf[TempIf]

@@ -373,7 +372,6 @@ class GtirbToIR (mods: Seq[com.grammatech.gtirb.proto.Module.Module], parserMap:
startStmts.remove(ifStmt)
startStmts.append(create_TempIf(ifStmt.conds.remove(0)))


val startBlk = Block(block.label, block.address, startStmts)
val endBlk = Block(get_blkLabel(block.label, extraBlkCount), block.address, endStmts)
edgeMap += (endBlk.label -> edgeMap.get(startBlk.label).get)
@@ -382,20 +380,19 @@ class GtirbToIR (mods: Seq[com.grammatech.gtirb.proto.Module.Module], parserMap:
blkMap(get_ByteString(startBlk.label)) = startBlk
blkMap += (get_ByteString(endBlk.label) -> endBlk)


// falseBlock creation
val tempFalseBlks: ArrayBuffer[Block] = ifStmt.conds.map { stmts =>
val label = get_blkLabel(block.label, extraBlkCount)
val falseBlock = Block(label, block.address, ArrayBuffer(create_TempIf(stmts)))
extraBlkCount += 1
falseBlock
}


// trueBlock creation
val trueBlks: ArrayBuffer[Block] = ifStmt.stmts.map { stmts =>
val label = get_blkLabel(block.label, extraBlkCount)
val trueBlock = Block(label, block.address, stmts)

val edge = create_endEdge(label, endBlk.label)
edgeMap += (label -> ArrayBuffer(edge))
blkMap += (get_ByteString(label) -> trueBlock)
@@ -405,10 +402,9 @@ class GtirbToIR (mods: Seq[com.grammatech.gtirb.proto.Module.Module], parserMap:

// Adds Else to FalseBlocks
val label = get_blkLabel(block.label, extraBlkCount)
val elseBlk = Block(label, block.address, ArrayBuffer(ifStmt.elseStatement.get))

val elseBlk = Block(label, block.address, ifStmt.elseStatement)
val edge = create_endEdge(label, endBlk.label)
edgeMap += (label -> ArrayBuffer(edge))
edgeMap += (elseBlk.label -> ArrayBuffer(edge))
blkMap += (get_ByteString(elseBlk.label) -> elseBlk)
val falseBlks = (startBlk +: tempFalseBlks.toList).to(ArrayBuffer) += elseBlk
extraBlkCount += 1
@@ -479,7 +475,7 @@ class GtirbToIR (mods: Seq[com.grammatech.gtirb.proto.Module.Module], parserMap:

// Ditto above, but handles the case where a block has more than one outgoing edge
def multiJump(procedures: ArrayBuffer[Procedure], block: Block, edges: ArrayBuffer[proto.CFG.Edge],
entries: List[ByteString], ifStmt: Statement): Either[Jump, ArrayBuffer[Block]] = {
entries: List[ByteString]): Either[Jump, ArrayBuffer[Block]] = {

val types = edges.map(_.label.get.`type`)

@@ -510,6 +506,7 @@ class GtirbToIR (mods: Seq[com.grammatech.gtirb.proto.Module.Module], parserMap:

case (false, true, _) => //If statement, need to create TRUE and FALSE blocks that contain asserts
val blks = ArrayBuffer[Block]()
val ifStmt = block.statements.lastElem.get
val cond: Statement = Assume(ifStmt.asInstanceOf[TempIf].conds(0), checkSecurity=true)
val notCond = Assume(UnaryExpr(BoolNOT, ifStmt.asInstanceOf[TempIf].conds(0)), checkSecurity=true) // Inverted Condition
edges.foreach { elem =>
@@ -550,7 +547,7 @@ class GtirbToIR (mods: Seq[com.grammatech.gtirb.proto.Module.Module], parserMap:
val edges = edgeMap(b.label)

if (edges.size > 1) {
multiJump(cpy, b, edges, entries, b.statements(b.statements.size - 1)) match {
multiJump(cpy, b, edges, entries) match {
case Left(jump) =>
b.replaceJump(jump)
case Right(blks) =>
66 changes: 45 additions & 21 deletions src/main/scala/translating/SemanticsLoader.scala
@@ -46,7 +46,7 @@ class SemanticsLoader(targetuuid: ByteString, parserMap: Map[String, Array[Array
statements.to(ArrayBuffer)
}

def visitAssignment_stmt(ctx: Assignment_stmtContext): Option[LocalAssign] = {
def visitAssignment_stmt(ctx: Assignment_stmtContext): Option[Statement] = {
ctx match
case a: AssignContext => return Option(visitAssign(a))

@@ -55,6 +55,13 @@ class SemanticsLoader(targetuuid: ByteString, parserMap: Map[String, Array[Array
case v: VarDeclContext => return Option(visitVarDecl(v))

case v: VarDeclsNoInitContext => return visitVarDeclsNoInit(v)

case a: AssertContext => return Option(visitAssert(a))
}

override def visitAssert(ctx: AssertContext): Assert = {
val expr = visitExpr(ctx.expr)
return Assert(expr)
}
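
A sketch of the intended round trip for the new assert handling; the input string is made up, parse_insn lives in RunUtils (below), and the exact spelling of the inner expr is an assumption rather than something taken from the grammar:

    val ctx  = parse_insn("Stmt_Assert(Expr_Var(TRUE))")    // hypothetical serialised ASLp assert
    val stmt = visitAssignment_stmt(ctx.assignment_stmt())  // dispatches to the new AssertContext case
    // expected: Some(Assert(TrueLiteral)), since createExprVar maps "TRUE" to TrueLiteral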


@@ -69,8 +76,6 @@ class SemanticsLoader(targetuuid: ByteString, parserMap: Map[String, Array[Array
}

override def visitConditional_stmt(ctx: Conditional_stmtContext): TempIf = {


val totalStmts = ArrayBuffer.newBuilder[ArrayBuffer[Statement]]
val conds = ArrayBuffer.newBuilder[Expr]

@@ -93,20 +98,21 @@ class SemanticsLoader(targetuuid: ByteString, parserMap: Map[String, Array[Array
}

val elseBranch = prevContext.else_stmt()
if (elseBranch != null) {


val elseStmt = elseBranch.stmt() match {
case a if (a.assignment_stmt() != null) =>
visitAssignment_stmt(a.assignment_stmt())
case c if (c.call_stmt() != null) =>
Option(visitCall_stmt(c.call_stmt()))
case _ => None
}
if (!elseBranch.isEmpty) {

val elseStmt: ArrayBuffer[Statement] = elseBranch.asScala.flatMap { elem =>
elem.stmt() match {
case a if (a.assignment_stmt() != null) =>
visitAssignment_stmt(a.assignment_stmt())
case c if (c.call_stmt() != null) =>
Option(visitCall_stmt(c.call_stmt()))
case _ => None
}
}.to(ArrayBuffer)
return TempIf(true, conds.result(), totalStmts.result(), elseStmt)

} else {
return TempIf(false, conds.result(), totalStmts.result())
return TempIf(false, conds.result(), totalStmts.result(), ArrayBuffer[Statement]())
}

}
@@ -145,9 +151,7 @@ class SemanticsLoader(targetuuid: ByteString, parserMap: Map[String, Array[Array
val ty = visitType(ctx.`type`())
val name = ctx.METHOD().getText()

if (name.startsWith("Cse")) {
cseMap += (name -> ty)
}
cseMap += (name -> ty)

val expr = visitExpr(ctx.expr())
if (expr != null) {
@@ -208,16 +212,32 @@ class SemanticsLoader(targetuuid: ByteString, parserMap: Map[String, Array[Array
def fix_size(expr1: Expr, expr2: Expr): Expr = {

val size1 = expr1 match {
case e: Extract => e.body.asInstanceOf[Register].irType.asInstanceOf[BitVecType].size
case e: Extract => e.end - e.start
case r: Register => r.irType.asInstanceOf[BitVecType].size
case b: BitVecLiteral => b.size
case z: ZeroExtend =>
val innerSize = z.body match {
case e: Extract => e.end - e.start
case r: Register => r.irType.asInstanceOf[BitVecType].size
case b: BitVecLiteral => b.size
case _ => ???
}
innerSize + z.extension
case _ => ???
}

val size2 = expr2 match {
case e: Extract => e.body.asInstanceOf[Register].irType.asInstanceOf[BitVecType].size
case e: Extract => e.end - e.start
case r: Register => r.irType.asInstanceOf[BitVecType].size
case b: BitVecLiteral => b.size
case z: ZeroExtend =>
val innerSize = z.body match {
case e: Extract => e.end - e.start
case r: Register => r.irType.asInstanceOf[BitVecType].size
case b: BitVecLiteral => b.size
case _ => ???
}
innerSize + z.extension
case _ => ???
}

@@ -263,7 +283,10 @@ class SemanticsLoader(targetuuid: ByteString, parserMap: Map[String, Array[Array
case "sub_bits" => return BinaryExpr(BVSUB, visitExpr(ctx.expr(0)), visitExpr(ctx.expr(1)))
case "mul_bits" => return BinaryExpr(BVMUL, visitExpr(ctx.expr(0)), visitExpr(ctx.expr(1)))
case "sdiv_bits" => return BinaryExpr(BVSDIV, visitExpr(ctx.expr(0)), visitExpr(ctx.expr(1)))
case "lsl_bits" => return ??? // can't find logical shift left binop?
case "lsl_bits" =>
val expr1 = visitExpr(ctx.expr(0))
val expr2 = fix_size(expr1, visitExpr(ctx.expr(1)))
return BinaryExpr(BVSHL, expr1, expr2) // need to fix size here?
case "lsr_bits" =>
val expr1 = visitExpr(ctx.expr(0))
val expr2 = fix_size(expr1, visitExpr(ctx.expr(1)))
@@ -413,7 +436,7 @@ class SemanticsLoader(targetuuid: ByteString, parserMap: Map[String, Array[Array

def createExprVar(name: String): Expr = {
name match
case n if n.startsWith("Cse") => return LocalVar(n.dropRight(3) + "_" + blkCount + instructionCount, cseMap.get(n).get)
case n if cseMap.contains(n) => return LocalVar(n.dropRight(3) + "_" + blkCount + instructionCount, cseMap.get(n).get)
case v if varMap.contains(v) => return LocalVar(v + "_" + blkCount + instructionCount, varMap.get(v).get)
case "TRUE" => return TrueLiteral
case "FALSE" => return FalseLiteral
Expand All @@ -425,6 +448,7 @@ class SemanticsLoader(targetuuid: ByteString, parserMap: Map[String, Array[Array
) // "_PC" flag, useful for jumps later on
case "__BranchTaken" => null
case "BTypeNext" => null
case "BTypeCompatible" => null
}

def createExprVarArray(v: ArrayBuffer[String]): Variable = {
18 changes: 13 additions & 5 deletions src/main/scala/util/RunUtils.scala
@@ -22,6 +22,7 @@ import specification.*
import Parsers.*
import Parsers.SemanticsParser.*
import org.antlr.v4.runtime.tree.ParseTreeWalker
import org.antlr.v4.runtime.BailErrorStrategy
import org.antlr.v4.runtime.{CharStreams, CommonTokenStream}
import translating.*
import util.Logger
@@ -63,11 +64,18 @@ object RunUtils {
val semantics = mods.map(_.auxData("ast").data.toStringUtf8.parseJson.convertTo[Map[String, Array[Array[String]]]]);

def parse_insn (f: String) : StmtContext = {
val semanticsLexer = SemanticsLexer(CharStreams.fromString(f))
val tokens = CommonTokenStream(semanticsLexer)
val parser = SemanticsParser(tokens)
parser.setBuildParseTree(true)
return parser.stmt()
try {
val semanticsLexer = SemanticsLexer(CharStreams.fromString(f))
val tokens = CommonTokenStream(semanticsLexer)
val parser = SemanticsParser(tokens)
parser.setErrorHandler(new BailErrorStrategy())
parser.setBuildParseTree(true)
return parser.stmt()
} catch {
case e: org.antlr.v4.runtime.misc.ParseCancellationException =>
println(f)
throw new RuntimeException(e)
}
}

val parserMap = semantics.map(_.map(((k: String,v: Array[Array[String]]) => (k, v.map(_.map(parse_insn(_)))))))
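
The BailErrorStrategy switch makes parsing fail fast: ANTLR's default error handler attempts recovery and can hand back a mangled parse tree, whereas BailErrorStrategy throws ParseCancellationException at the first syntax error, so the offending semantics string is printed before the pipeline stops. A hypothetical failure, with a made-up malformed input:

    parse_insn("Stmt_Assign(LExpr_Var(R3)")   // missing COMMA expr CLOSE_PAREN per the Assign rule
    // prints the string above, then rethrows as a RuntimeException wrapping the ParseCancellationException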
