Skip to content

Commit

Permalink
Checkpointing IR.
Browse files Browse the repository at this point in the history
  • Loading branch information
robby-phd committed Feb 13, 2025
1 parent 1d621da commit e176141
Showing 1 changed file with 67 additions and 29 deletions.
96 changes: 67 additions & 29 deletions frontend/shared/src/main/scala/org/sireum/lang/IRTranslator.scala
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import org.sireum.lang.ast.IR
import org.sireum.lang.symbol.TypeInfo
import org.sireum.lang.tipe.TypeHierarchy
import org.sireum.lang.{ast => AST}
import org.sireum.U32._

object IRTranslator {
@msig trait Fresh {
Expand Down Expand Up @@ -228,7 +229,7 @@ object IRTranslator {
return IR.Body.Basic(blocks)
}

def translateStmt(stmt: AST.Stmt, registerOpt: Option[Z]): Unit = {
def translateStmt(stmt: AST.Stmt, localOpt: Option[(String, AST.Typed)]): Unit = {
val pos = stmt.posOpt.get
stmt match {
case stmt: AST.Stmt.Var =>
Expand All @@ -237,16 +238,18 @@ object IRTranslator {
stmts = ISZ()
val t = stmt.attr.typedOpt.get
val varRhs: IR.Exp = init match {
case init: AST.Stmt.Expr =>
translateExp(init.exp)
case init: AST.Stmt.Expr => translateExp(init.exp)
case _ =>
val n = fresh.temp()
translateAssignExp(init, n)
IR.Exp.Temp(n, t, init.asStmt.posOpt.get)
val aePos = init.asStmt.posOpt.get
val id = assignExpId(Some(stmt.id.value), aePos)
stmts = stmts :+ AST.IR.Stmt.Decl(F, T, methodContext, ISZ(AST.IR.Stmt.Decl.Local(id, t)), aePos)
translateAssignExp(init, (id, t))
IR.Exp.LocalVarRef(T, methodContext, id, t, aePos)
}
stmts = stmts :+ IR.Stmt.Assign.Local(shouldCopy(t), methodContext, stmt.id.value, t, varRhs, pos)
oldStmts = oldStmts :+ IR.Stmt.Decl(F, stmt.isVal, methodContext, ISZ(IR.Stmt.Decl.Local(stmt.id.value, t)), pos)
stmts = oldStmts ++ stmts
fresh.setTemp(0)
case stmt: AST.Stmt.Assign =>
val copy = shouldCopy(stmt.rhs.typedOpt.get)
val oldStmts = stmts
Expand All @@ -255,10 +258,12 @@ object IRTranslator {
stmt.rhs match {
case rhs: AST.Stmt.Expr => return translateExp(rhs.exp)
case _ =>
val n = fresh.temp()
val t = stmt.rhs.typedOpt.get
translateAssignExp(stmt.rhs, n)
return IR.Exp.Temp(n, t, stmt.rhs.asStmt.posOpt.get)
val aePos = stmt.rhs.asStmt.posOpt.get
val id = assignExpId(None(), aePos)
val t = stmt.lhs.typedOpt.get
stmts = stmts :+ AST.IR.Stmt.Decl(F, T, methodContext, ISZ(AST.IR.Stmt.Decl.Local(id, t)), aePos)
translateAssignExp(stmt.rhs, (id, t))
return IR.Exp.LocalVarRef(T, methodContext, id, t, aePos)
}
}
stmt.lhs match {
Expand Down Expand Up @@ -291,11 +296,12 @@ object IRTranslator {
stmt.rhs match {
case rhs: AST.Stmt.Expr => return translateExp(rhs.exp)
case _ =>
val n = fresh.temp()
val t = stmt.rhs.typedOpt.get
val rhsPos = stmt.rhs.asStmt.posOpt.get
translateAssignExp(stmt.rhs, n)
return IR.Exp.Temp(n, t, rhsPos)
val aePos = stmt.rhs.asStmt.posOpt.get
val id = assignExpId(None(), aePos)
val t = stmt.lhs.typedOpt.get
stmts = stmts :+ AST.IR.Stmt.Decl(F, T, methodContext, ISZ(AST.IR.Stmt.Decl.Local(id, t)), aePos)
translateAssignExp(stmt.rhs, (id, t))
return IR.Exp.LocalVarRef(T, methodContext, id, t, aePos)
}
}

Expand All @@ -314,45 +320,53 @@ object IRTranslator {
val invokeRhs: IR.Exp = stmt.rhs match {
case rhs: AST.Stmt.Expr => translateExp(rhs.exp)
case _ =>
val n = fresh.temp()
val rhsPos = stmt.rhs.asStmt.posOpt.get
val t = stmt.rhs.typedOpt.get
translateAssignExp(stmt.rhs, n)
IR.Exp.Temp(n, t, rhsPos)
val aePos = stmt.rhs.asStmt.posOpt.get
val id = assignExpId(None(), aePos)
val t = stmt.lhs.typedOpt.get
stmts = stmts :+ AST.IR.Stmt.Decl(F, T, methodContext, ISZ(AST.IR.Stmt.Decl.Local(id, t)), aePos)
translateAssignExp(stmt.rhs, (id, t))
IR.Exp.LocalVarRef(T, methodContext, id, t, aePos)
}
stmts = stmts :+ IR.Stmt.Assign.Index(copy, receiver, index, invokeRhs, pos)
case _ => halt("Infeasible")
}
stmts = oldStmts ++ stmts
fresh.setTemp(0)
case stmt: AST.Stmt.If =>
val oldStmts = stmts
stmts = ISZ()
val cond = translateExp(stmt.cond)
val condStmts = stmts
fresh.setTemp(0)
stmts = ISZ()
translateBody(stmt.thenBody, registerOpt)
translateBody(stmt.thenBody, localOpt)
val thenPos = bodyPos(stmt.thenBody, pos)
val thenStmts = stmts
fresh.setTemp(0)
stmts = ISZ()
translateBody(stmt.elseBody, registerOpt)
translateBody(stmt.elseBody, localOpt)
val elsePos = bodyPos(stmt.elseBody, pos)
val elseStmts = stmts
stmts = oldStmts ++ condStmts :+
IR.Stmt.If(cond, IR.Stmt.Block(thenStmts, thenPos), IR.Stmt.Block(elseStmts, elsePos), pos)
fresh.setTemp(0)
case stmt: AST.Stmt.While =>
val oldStmts = stmts
stmts = ISZ()
val cond = translateExp(stmt.cond)
val condStmts = stmts
stmts = ISZ()
fresh.setTemp(0)
translateBody(stmt.body, None())
val bPos = bodyPos(stmt.body, pos)
stmts = oldStmts :+ IR.Stmt.While(IR.Stmt.Block(condStmts, cond.pos), cond, IR.Stmt.Block(stmts, bPos), pos)
fresh.setTemp(0)
case stmt: AST.Stmt.Expr =>
val e = translateExp(stmt.exp)
if (e.tipe == AST.Typed.unit || e.tipe == AST.Typed.nothing) {
stmts = stmts :+ IR.Stmt.Expr(e.asInstanceOf[IR.Exp.Apply])
}
fresh.setTemp(0)
case stmt: AST.Stmt.Return =>
stmt.expOpt match {
case Some(exp) =>
Expand All @@ -361,7 +375,10 @@ object IRTranslator {
case _ =>
stmts = stmts :+ IR.Stmt.Return(None(), pos)
}
case stmt: AST.Stmt.Block => translateBody(stmt.body, registerOpt)
fresh.setTemp(0)
case stmt: AST.Stmt.Block =>
translateBody(stmt.body, localOpt)
fresh.setTemp(0)
case stmt: AST.Stmt.Match => halt(s"TODO: $stmt")
case stmt: AST.Stmt.For => halt(s"TODO: $stmt")
case stmt: AST.Stmt.VarPattern => halt(s"TODO: $stmt")
Expand All @@ -386,19 +403,28 @@ object IRTranslator {
return body.stmts(0).posOpt.get.to(body.stmts(body.stmts.size - 1).posOpt.get)
}

def translateBody(body: AST.Body, registerOpt: Option[Z]): Unit = {
for (stmt <- body.stmts) {
translateStmt(stmt, registerOpt)
def translateBody(body: AST.Body, localOpt: Option[(String, AST.Typed)]): Unit = {
val stmts = body.stmts
localOpt match {
case Some((_, _)) =>
for (i <- 0 until stmts.size - 1) {
translateStmt(stmts(i), None())
}
translateAssignExp(stmts(stmts.size - 1).asAssignExp, localOpt.get)
case _ =>
for (stmt <- body.stmts) {
translateStmt(stmt, None())
}
}
}

def translateAssignExp(stmt: AST.AssignExp, register: Z): Unit = {
def translateAssignExp(stmt: AST.AssignExp, local: (String, AST.Typed)): Unit = {
val pos = stmt.asStmt.posOpt.get
stmt match {
case stmt: AST.Stmt.Expr =>
val exp = translateExp(stmt.exp)
stmts = stmts :+ IR.Stmt.Assign.Temp(register, exp, pos)
case _ => translateStmt(stmt.asStmt, Some(register))
stmts = stmts :+ IR.Stmt.Assign.Local(F, methodContext, local._1, local._2, exp, pos)
case _ => translateStmt(stmt.asStmt, Some(local))
}
}

Expand Down Expand Up @@ -684,4 +710,16 @@ object IRTranslator {
case exp: AST.ProofAst.StepId => halt(s"Infeasible: $exp")
}
}

@pure def sha3(s: String): U32 = {
val sha = crypto.SHA3.init512
sha.update(conversions.String.toU8is(s))
val bs = sha.finalise()
return conversions.U8.toU32(bs(0)) << u32"24" | conversions.U8.toU32(bs(1)) << u32"16" |
conversions.U8.toU32(bs(2)) << u32"8" | conversions.U8.toU32(bs(3))
}

@strictpure def assignExpId(idOpt: Option[String], pos: message.Position): String = {
st"${idOpt.getOrElse("$ae")}.${pos.beginLine}.${pos.beginColumn}.${sha3(pos.string)}".render
}
}

0 comments on commit e176141

Please sign in to comment.