Skip to content

Commit

Permalink
Checkpointing IR.
Browse files Browse the repository at this point in the history
  • Loading branch information
robby-phd committed Feb 4, 2025
1 parent 62369a4 commit 88c6873
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 56 deletions.
19 changes: 10 additions & 9 deletions ast/shared/src/main/scala/org/sireum/lang/ast/IR.scala
Original file line number Diff line number Diff line change
Expand Up @@ -279,27 +279,27 @@ object IR {

object Decl {

@datatype trait Ground extends Decl {
@strictpure def undeclare: Ground
@datatype trait Single extends Decl {
@strictpure def undeclare: Single
}

@datatype class Local(val undecl: B, val isVal: B, val tipe: Typed, val id: String, val pos: Position) extends Ground {
@strictpure def undeclare: Ground = {
@datatype class Local(val undecl: B, val isVal: B, val tipe: Typed, val id: String, val pos: Position) extends Single {
@strictpure def undeclare: Single = {
val thiz = this
thiz(undecl = T)
}
@strictpure def prettyST: ST = st"${if (undecl) "de" else ""}local $id: $tipe"
}

@datatype class Temp(val undecl: B, val tipe: Typed, val n: Z, val pos: Position) extends Ground {
@strictpure def undeclare: Ground = {
@datatype class Temp(val undecl: B, val tipe: Typed, val n: Z, val pos: Position) extends Single {
@strictpure def undeclare: Single = {
val thiz = this
thiz(undecl = T)
}
@strictpure def prettyST: ST = st"${if (undecl) "de" else ""}register $$$n: $tipe"
@strictpure def prettyST: ST = st"${if (undecl) "un" else ""}decl $$$n: $tipe"
}

@datatype class Multiple(val undecl: B, val decls: ISZ[Ground]) extends Decl {
@datatype class Multiple(val undecl: B, val decls: ISZ[Single]) extends Decl {
@strictpure def pos: Position = decls(0).pos.to(decls(decls.size - 1).pos)
@strictpure def undeclare: Decl = {
val thiz = this
Expand Down Expand Up @@ -387,7 +387,8 @@ object IR {
val pos: Position) {
@strictpure def prettyST: ST = {
val pt: ST = if (typeParams.isEmpty) st"" else st"[${(typeParams, ", ")}]"
st"procedure ${(owner, ".")}${if (isInObject) "." else "#"}$id$pt(${(for (p <- ops.ISZOps(paramNames).zip(tipe.args)) yield st"${p._1}: ${p._2}", ", ")}): ${tipe.ret} ${body.prettyST}"
val ownerOpt: Option[ST] = if (owner.isEmpty) None() else Some(st"${(owner, ".")}${if (isInObject) "." else "#"}")
st"procedure $ownerOpt$id$pt(${(for (p <- ops.ISZOps(paramNames).zip(tipe.args)) yield st"${p._1}: ${p._2}", ", ")}): ${tipe.ret} ${body.prettyST}"
}
@pure override def string: String = {
return prettyST.render
Expand Down
32 changes: 16 additions & 16 deletions ast/shared/src/main/scala/org/sireum/lang/ast/IRTransformer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -341,18 +341,18 @@ object IRTransformer {
}
}

@pure def preIRStmtDeclGround(ctx: Context, o: IR.Stmt.Decl.Ground): PreResult[Context, IR.Stmt.Decl.Ground] = {
@pure def preIRStmtDeclSingle(ctx: Context, o: IR.Stmt.Decl.Single): PreResult[Context, IR.Stmt.Decl.Single] = {
o match {
case o: IR.Stmt.Decl.Local => return preIRStmtDeclLocal(ctx, o)
case o: IR.Stmt.Decl.Temp => return preIRStmtDeclTemp(ctx, o)
}
}

@pure def preIRStmtDeclLocal(ctx: Context, o: IR.Stmt.Decl.Local): PreResult[Context, IR.Stmt.Decl.Ground] = {
@pure def preIRStmtDeclLocal(ctx: Context, o: IR.Stmt.Decl.Local): PreResult[Context, IR.Stmt.Decl.Single] = {
return PreResult(ctx, T, None())
}

@pure def preIRStmtDeclTemp(ctx: Context, o: IR.Stmt.Decl.Temp): PreResult[Context, IR.Stmt.Decl.Ground] = {
@pure def preIRStmtDeclTemp(ctx: Context, o: IR.Stmt.Decl.Temp): PreResult[Context, IR.Stmt.Decl.Single] = {
return PreResult(ctx, T, None())
}

Expand Down Expand Up @@ -772,18 +772,18 @@ object IRTransformer {
}
}

@pure def postIRStmtDeclGround(ctx: Context, o: IR.Stmt.Decl.Ground): TPostResult[Context, IR.Stmt.Decl.Ground] = {
@pure def postIRStmtDeclSingle(ctx: Context, o: IR.Stmt.Decl.Single): TPostResult[Context, IR.Stmt.Decl.Single] = {
o match {
case o: IR.Stmt.Decl.Local => return postIRStmtDeclLocal(ctx, o)
case o: IR.Stmt.Decl.Temp => return postIRStmtDeclTemp(ctx, o)
}
}

@pure def postIRStmtDeclLocal(ctx: Context, o: IR.Stmt.Decl.Local): TPostResult[Context, IR.Stmt.Decl.Ground] = {
@pure def postIRStmtDeclLocal(ctx: Context, o: IR.Stmt.Decl.Local): TPostResult[Context, IR.Stmt.Decl.Single] = {
return TPostResult(ctx, None())
}

@pure def postIRStmtDeclTemp(ctx: Context, o: IR.Stmt.Decl.Temp): TPostResult[Context, IR.Stmt.Decl.Ground] = {
@pure def postIRStmtDeclTemp(ctx: Context, o: IR.Stmt.Decl.Temp): TPostResult[Context, IR.Stmt.Decl.Single] = {
return TPostResult(ctx, None())
}

Expand Down Expand Up @@ -1203,7 +1203,7 @@ import IRTransformer._
else
TPostResult(r0.ctx, None())
case o2: IR.Stmt.Decl.Multiple =>
val r0: TPostResult[Context, IS[Z, IR.Stmt.Decl.Ground]] = transformISZ(preR.ctx, o2.decls, transformIRStmtDeclGround _)
val r0: TPostResult[Context, IS[Z, IR.Stmt.Decl.Single]] = transformISZ(preR.ctx, o2.decls, transformIRStmtDeclSingle _)
if (hasChanged || r0.resultOpt.nonEmpty)
TPostResult(r0.ctx, Some(o2(decls = r0.resultOpt.getOrElse(o2.decls))))
else
Expand Down Expand Up @@ -1282,7 +1282,7 @@ import IRTransformer._
else
TPostResult(r0.ctx, None())
case o2: IR.Stmt.Decl.Multiple =>
val r0: TPostResult[Context, IS[Z, IR.Stmt.Decl.Ground]] = transformISZ(preR.ctx, o2.decls, transformIRStmtDeclGround _)
val r0: TPostResult[Context, IS[Z, IR.Stmt.Decl.Single]] = transformISZ(preR.ctx, o2.decls, transformIRStmtDeclSingle _)
if (hasChanged || r0.resultOpt.nonEmpty)
TPostResult(r0.ctx, Some(o2(decls = r0.resultOpt.getOrElse(o2.decls))))
else
Expand Down Expand Up @@ -1386,7 +1386,7 @@ import IRTransformer._
else
TPostResult(r0.ctx, None())
case o2: IR.Stmt.Decl.Multiple =>
val r0: TPostResult[Context, IS[Z, IR.Stmt.Decl.Ground]] = transformISZ(preR.ctx, o2.decls, transformIRStmtDeclGround _)
val r0: TPostResult[Context, IS[Z, IR.Stmt.Decl.Single]] = transformISZ(preR.ctx, o2.decls, transformIRStmtDeclSingle _)
if (hasChanged || r0.resultOpt.nonEmpty)
TPostResult(r0.ctx, Some(o2(decls = r0.resultOpt.getOrElse(o2.decls))))
else
Expand All @@ -1410,12 +1410,12 @@ import IRTransformer._
}
}

@pure def transformIRStmtDeclGround(ctx: Context, o: IR.Stmt.Decl.Ground): TPostResult[Context, IR.Stmt.Decl.Ground] = {
val preR: PreResult[Context, IR.Stmt.Decl.Ground] = pp.preIRStmtDeclGround(ctx, o)
val r: TPostResult[Context, IR.Stmt.Decl.Ground] = if (preR.continu) {
val o2: IR.Stmt.Decl.Ground = preR.resultOpt.getOrElse(o)
@pure def transformIRStmtDeclSingle(ctx: Context, o: IR.Stmt.Decl.Single): TPostResult[Context, IR.Stmt.Decl.Single] = {
val preR: PreResult[Context, IR.Stmt.Decl.Single] = pp.preIRStmtDeclSingle(ctx, o)
val r: TPostResult[Context, IR.Stmt.Decl.Single] = if (preR.continu) {
val o2: IR.Stmt.Decl.Single = preR.resultOpt.getOrElse(o)
val hasChanged: B = preR.resultOpt.nonEmpty
val rOpt: TPostResult[Context, IR.Stmt.Decl.Ground] = o2 match {
val rOpt: TPostResult[Context, IR.Stmt.Decl.Single] = o2 match {
case o2: IR.Stmt.Decl.Local =>
val r0: TPostResult[Context, Typed] = transformTyped(preR.ctx, o2.tipe)
if (hasChanged || r0.resultOpt.nonEmpty)
Expand All @@ -1436,8 +1436,8 @@ import IRTransformer._
TPostResult(preR.ctx, None())
}
val hasChanged: B = r.resultOpt.nonEmpty
val o2: IR.Stmt.Decl.Ground = r.resultOpt.getOrElse(o)
val postR: TPostResult[Context, IR.Stmt.Decl.Ground] = pp.postIRStmtDeclGround(r.ctx, o2)
val o2: IR.Stmt.Decl.Single = r.resultOpt.getOrElse(o)
val postR: TPostResult[Context, IR.Stmt.Decl.Single] = pp.postIRStmtDeclSingle(r.ctx, o2)
if (postR.resultOpt.nonEmpty) {
return postR
} else if (hasChanged) {
Expand Down
40 changes: 20 additions & 20 deletions ast/shared/src/main/scala/org/sireum/lang/ast/MIRTransformer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -181,13 +181,13 @@ object MIRTransformer {

val PostResultIRStmtReturn: MOption[IR.Stmt] = MNone()

val PreResultIRStmtDeclLocal: PreResult[IR.Stmt.Decl.Ground] = PreResult(T, MNone())
val PreResultIRStmtDeclLocal: PreResult[IR.Stmt.Decl.Single] = PreResult(T, MNone())

val PostResultIRStmtDeclLocal: MOption[IR.Stmt.Decl.Ground] = MNone()
val PostResultIRStmtDeclLocal: MOption[IR.Stmt.Decl.Single] = MNone()

val PreResultIRStmtDeclTemp: PreResult[IR.Stmt.Decl.Ground] = PreResult(T, MNone())
val PreResultIRStmtDeclTemp: PreResult[IR.Stmt.Decl.Single] = PreResult(T, MNone())

val PostResultIRStmtDeclTemp: MOption[IR.Stmt.Decl.Ground] = MNone()
val PostResultIRStmtDeclTemp: MOption[IR.Stmt.Decl.Single] = MNone()

val PreResultIRStmtDeclMultiple: PreResult[IR.Stmt.Decl] = PreResult(T, MNone())

Expand Down Expand Up @@ -579,18 +579,18 @@ import MIRTransformer._
}
}

def preIRStmtDeclGround(o: IR.Stmt.Decl.Ground): PreResult[IR.Stmt.Decl.Ground] = {
def preIRStmtDeclSingle(o: IR.Stmt.Decl.Single): PreResult[IR.Stmt.Decl.Single] = {
o match {
case o: IR.Stmt.Decl.Local => return preIRStmtDeclLocal(o)
case o: IR.Stmt.Decl.Temp => return preIRStmtDeclTemp(o)
}
}

def preIRStmtDeclLocal(o: IR.Stmt.Decl.Local): PreResult[IR.Stmt.Decl.Ground] = {
def preIRStmtDeclLocal(o: IR.Stmt.Decl.Local): PreResult[IR.Stmt.Decl.Single] = {
return PreResultIRStmtDeclLocal
}

def preIRStmtDeclTemp(o: IR.Stmt.Decl.Temp): PreResult[IR.Stmt.Decl.Ground] = {
def preIRStmtDeclTemp(o: IR.Stmt.Decl.Temp): PreResult[IR.Stmt.Decl.Single] = {
return PreResultIRStmtDeclTemp
}

Expand Down Expand Up @@ -1010,18 +1010,18 @@ import MIRTransformer._
}
}

def postIRStmtDeclGround(o: IR.Stmt.Decl.Ground): MOption[IR.Stmt.Decl.Ground] = {
def postIRStmtDeclSingle(o: IR.Stmt.Decl.Single): MOption[IR.Stmt.Decl.Single] = {
o match {
case o: IR.Stmt.Decl.Local => return postIRStmtDeclLocal(o)
case o: IR.Stmt.Decl.Temp => return postIRStmtDeclTemp(o)
}
}

def postIRStmtDeclLocal(o: IR.Stmt.Decl.Local): MOption[IR.Stmt.Decl.Ground] = {
def postIRStmtDeclLocal(o: IR.Stmt.Decl.Local): MOption[IR.Stmt.Decl.Single] = {
return PostResultIRStmtDeclLocal
}

def postIRStmtDeclTemp(o: IR.Stmt.Decl.Temp): MOption[IR.Stmt.Decl.Ground] = {
def postIRStmtDeclTemp(o: IR.Stmt.Decl.Temp): MOption[IR.Stmt.Decl.Single] = {
return PostResultIRStmtDeclTemp
}

Expand Down Expand Up @@ -1403,7 +1403,7 @@ import MIRTransformer._
else
MNone()
case o2: IR.Stmt.Decl.Multiple =>
val r0: MOption[IS[Z, IR.Stmt.Decl.Ground]] = transformISZ(o2.decls, transformIRStmtDeclGround _)
val r0: MOption[IS[Z, IR.Stmt.Decl.Single]] = transformISZ(o2.decls, transformIRStmtDeclSingle _)
if (hasChanged || r0.nonEmpty)
MSome(o2(decls = r0.getOrElse(o2.decls)))
else
Expand Down Expand Up @@ -1482,7 +1482,7 @@ import MIRTransformer._
else
MNone()
case o2: IR.Stmt.Decl.Multiple =>
val r0: MOption[IS[Z, IR.Stmt.Decl.Ground]] = transformISZ(o2.decls, transformIRStmtDeclGround _)
val r0: MOption[IS[Z, IR.Stmt.Decl.Single]] = transformISZ(o2.decls, transformIRStmtDeclSingle _)
if (hasChanged || r0.nonEmpty)
MSome(o2(decls = r0.getOrElse(o2.decls)))
else
Expand Down Expand Up @@ -1586,7 +1586,7 @@ import MIRTransformer._
else
MNone()
case o2: IR.Stmt.Decl.Multiple =>
val r0: MOption[IS[Z, IR.Stmt.Decl.Ground]] = transformISZ(o2.decls, transformIRStmtDeclGround _)
val r0: MOption[IS[Z, IR.Stmt.Decl.Single]] = transformISZ(o2.decls, transformIRStmtDeclSingle _)
if (hasChanged || r0.nonEmpty)
MSome(o2(decls = r0.getOrElse(o2.decls)))
else
Expand All @@ -1610,12 +1610,12 @@ import MIRTransformer._
}
}

def transformIRStmtDeclGround(o: IR.Stmt.Decl.Ground): MOption[IR.Stmt.Decl.Ground] = {
val preR: PreResult[IR.Stmt.Decl.Ground] = preIRStmtDeclGround(o)
val r: MOption[IR.Stmt.Decl.Ground] = if (preR.continu) {
val o2: IR.Stmt.Decl.Ground = preR.resultOpt.getOrElse(o)
def transformIRStmtDeclSingle(o: IR.Stmt.Decl.Single): MOption[IR.Stmt.Decl.Single] = {
val preR: PreResult[IR.Stmt.Decl.Single] = preIRStmtDeclSingle(o)
val r: MOption[IR.Stmt.Decl.Single] = if (preR.continu) {
val o2: IR.Stmt.Decl.Single = preR.resultOpt.getOrElse(o)
val hasChanged: B = preR.resultOpt.nonEmpty
val rOpt: MOption[IR.Stmt.Decl.Ground] = o2 match {
val rOpt: MOption[IR.Stmt.Decl.Single] = o2 match {
case o2: IR.Stmt.Decl.Local =>
val r0: MOption[Typed] = transformTyped(o2.tipe)
if (hasChanged || r0.nonEmpty)
Expand All @@ -1636,8 +1636,8 @@ import MIRTransformer._
MNone()
}
val hasChanged: B = r.nonEmpty
val o2: IR.Stmt.Decl.Ground = r.getOrElse(o)
val postR: MOption[IR.Stmt.Decl.Ground] = postIRStmtDeclGround(o2)
val o2: IR.Stmt.Decl.Single = r.getOrElse(o)
val postR: MOption[IR.Stmt.Decl.Single] = postIRStmtDeclSingle(o2)
if (postR.nonEmpty) {
return postR
} else if (hasChanged) {
Expand Down
26 changes: 15 additions & 11 deletions frontend/shared/src/main/scala/org/sireum/lang/IRTranslator.scala
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,11 @@ import org.sireum.lang.{ast => AST}
object IRTranslator {
@record class BlockDeclPreamble extends MIRTransformer {
override def postIRStmtBlock(o: Stmt.Block): MOption[IR.Stmt] = {
var decls = ISZ[IR.Stmt.Decl.Ground]()
var decls = ISZ[IR.Stmt.Decl.Single]()
var nonDecls = ISZ[IR.Stmt]()
for (stmt <- o.stmts) {
stmt match {
case stmt: IR.Stmt.Decl.Ground => decls = decls :+ stmt
case stmt: IR.Stmt.Decl.Single => decls = decls :+ stmt
case _ => nonDecls = nonDecls :+ stmt
}
}
Expand All @@ -49,7 +49,7 @@ object IRTranslator {
}
}

@record class IRTranslator(val threeAddressCode: B, val th: TypeHierarchy) {
@record class IRTranslator(val threeAddressCode: B, val undeclare: B, val mergeDecls: B, val th: TypeHierarchy) {

var methodContext: IR.MethodContext = IR.MethodContext.empty
var _freshTemp: Z = 0
Expand Down Expand Up @@ -95,14 +95,14 @@ object IRTranslator {
var grounds = ISZ[IR.Stmt.Ground]()
var decls = ISZ[IR.Stmt.Decl]()

def mergeDecls(stmts: ISZ[IR.Stmt.Ground]): ISZ[IR.Stmt.Ground] = {
def mergeMultipleDecls(stmts: ISZ[IR.Stmt.Ground]): ISZ[IR.Stmt.Ground] = {
var r = ISZ[IR.Stmt.Ground]()
var i = 0
while (i < stmts.size) {
stmts(i) match {
case stmt: IR.Stmt.Decl.Multiple =>
var j = i
var mdecls = ISZ[IR.Stmt.Decl.Ground]()
var mdecls = ISZ[IR.Stmt.Decl.Single]()
while (j < stmts.size && stmts(j).isInstanceOf[IR.Stmt.Decl.Multiple] && stmts(j).asInstanceOf[IR.Stmt.Decl.Multiple].undecl == stmt.undecl) {
mdecls = mdecls ++ stmts(j).asInstanceOf[IR.Stmt.Decl.Multiple].decls
j = j + 1
Expand All @@ -118,7 +118,7 @@ object IRTranslator {
}

@pure def basicBlock(label:Z, stmts: ISZ[IR.Stmt.Ground], jump: IR.Jump): IR.BasicBlock = {
return IR.BasicBlock(label, mergeDecls(stmts), jump)
return if (this.mergeDecls) IR.BasicBlock(label, stmts, jump) else IR.BasicBlock(label, stmts, jump)
}

def stmtToBasic(label: Z, stmt: IR.Stmt): Option[Z] = {
Expand All @@ -135,8 +135,10 @@ object IRTranslator {
Some(IR.Exp.LocalVarRef(F, methodContext, "Res", exp.tipe, exp.pos))
case _ => None()
}
for (d <- decls) {
grounds = grounds :+ d.undeclare
if (undeclare) {
for (d <- decls) {
grounds = grounds :+ d.undeclare
}
}
blocks = blocks :+ basicBlock(label, grounds, IR.Jump.Return(expOpt, stmt.pos))
grounds = ISZ()
Expand Down Expand Up @@ -199,8 +201,10 @@ object IRTranslator {
case _ => return None()
}
}
for (d <- decls) {
grounds = grounds :+ d.undeclare
if (undeclare) {
for (d <- decls) {
grounds = grounds :+ d.undeclare
}
}
decls = oldDecls
return Some(l)
Expand All @@ -211,7 +215,7 @@ object IRTranslator {
case _ =>
}
if (methodContext.t.ret != AST.Typed.unit) {
blocks = blocks(0 ~> blocks(0)(grounds = mergeDecls(
blocks = blocks(0 ~> blocks(0)(grounds = mergeMultipleDecls(
IR.Stmt.Decl.Multiple(F, ISZ(IR.Stmt.Decl.Local(F, F, methodContext.t.ret, "Res", pos))) +: blocks(0).grounds)))
}
return IR.Body.Basic(blocks)
Expand Down

0 comments on commit 88c6873

Please sign in to comment.