Skip to content

Commit

Permalink
Checkpointing IR.
Browse files Browse the repository at this point in the history
  • Loading branch information
robby-phd committed Feb 7, 2025
1 parent f848fd1 commit 4a7f350
Show file tree
Hide file tree
Showing 4 changed files with 196 additions and 196 deletions.
32 changes: 9 additions & 23 deletions ast/shared/src/main/scala/org/sireum/lang/ast/IR.scala
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,11 @@ object IR {

@datatype trait Ground extends Stmt

@datatype class Expr(val exp: Exp.Apply) extends Ground {
@strictpure def pos: Position = exp.pos
@strictpure def prettyST: ST = exp.prettyST
}

@datatype trait Assign extends Ground {
@strictpure def rhs: Exp
}
Expand Down Expand Up @@ -301,25 +306,6 @@ object IR {
@strictpure def prettyST: ST = st"${if (undecl) "un" else ""}decl $$$n: $tipe"
}

@datatype class Multiple(val undecl: B, val decls: ISZ[Single]) extends Decl {
@strictpure def pos: Position = decls(0).pos.to(decls(decls.size - 1).pos)
@strictpure def undeclare: Decl = {
val thiz = this
thiz(undecl = T, decls = for (d <- decls) yield d.undeclare)
}
@pure def prettyST: ST = {
var ds = ISZ[ST]()
for (d <- decls) {
d match {
case d: Local => ds = ds :+ (if (undecl) st"${d.id}" else st"${d.id}: ${d.tipe}")
case d: Temp => ds = ds :+ (if (undecl) st"$$${d.n}" else st"$$${d.n}: ${d.tipe}")
}
}
val r = st"${if (undecl) "un" else ""}decls ${(ds, ", ")}"
return r
}
}

}

}
Expand Down Expand Up @@ -404,13 +390,13 @@ object IR {
}
}

@datatype class Program(val globals: ISZ[Global], val procedures: ISZ[Procedure], val stmts: ISZ[Stmt]) {
@datatype class Program(val threeAddressCode: B,
val globals: ISZ[Global],
val procedures: ISZ[Procedure]) {
@strictpure def prettyST: ST =
st"""${(for (g <- globals) yield g.prettyST, "\n")}
|
|${(for (p <- procedures) yield p.prettyST, "\n\n")}
|
|${(for (stmt <- stmts) yield stmt.prettyST, "\n")}"""
|${(for (p <- procedures) yield p.prettyST, "\n\n")}"""
@pure override def string: String = {
return prettyST.render
}
Expand Down
136 changes: 76 additions & 60 deletions ast/shared/src/main/scala/org/sireum/lang/ast/IRTransformer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,13 @@ object IRTransformer {

@pure def preIRStmt(ctx: Context, o: IR.Stmt): PreResult[Context, IR.Stmt] = {
o match {
case o: IR.Stmt.Expr =>
val r: PreResult[Context, IR.Stmt] = preIRStmtExpr(ctx, o) match {
case PreResult(preCtx, continu, Some(r: IR.Stmt)) => PreResult(preCtx, continu, Some[IR.Stmt](r))
case PreResult(_, _, Some(_)) => halt("Can only produce object of type IR.Stmt")
case PreResult(preCtx, continu, _) => PreResult(preCtx, continu, None[IR.Stmt]())
}
return r
case o: IR.Stmt.Assign.Local =>
val r: PreResult[Context, IR.Stmt] = preIRStmtAssignLocal(ctx, o) match {
case PreResult(preCtx, continu, Some(r: IR.Stmt)) => PreResult(preCtx, continu, Some[IR.Stmt](r))
Expand Down Expand Up @@ -204,18 +211,12 @@ object IRTransformer {
case PreResult(preCtx, continu, _) => PreResult(preCtx, continu, None[IR.Stmt]())
}
return r
case o: IR.Stmt.Decl.Multiple =>
val r: PreResult[Context, IR.Stmt] = preIRStmtDeclMultiple(ctx, o) match {
case PreResult(preCtx, continu, Some(r: IR.Stmt)) => PreResult(preCtx, continu, Some[IR.Stmt](r))
case PreResult(_, _, Some(_)) => halt("Can only produce object of type IR.Stmt")
case PreResult(preCtx, continu, _) => PreResult(preCtx, continu, None[IR.Stmt]())
}
return r
}
}

@pure def preIRStmtGround(ctx: Context, o: IR.Stmt.Ground): PreResult[Context, IR.Stmt.Ground] = {
o match {
case o: IR.Stmt.Expr => return preIRStmtExpr(ctx, o)
case o: IR.Stmt.Assign.Local =>
val r: PreResult[Context, IR.Stmt.Ground] = preIRStmtAssignLocal(ctx, o) match {
case PreResult(preCtx, continu, Some(r: IR.Stmt.Ground)) => PreResult(preCtx, continu, Some[IR.Stmt.Ground](r))
Expand Down Expand Up @@ -265,16 +266,13 @@ object IRTransformer {
case PreResult(preCtx, continu, _) => PreResult(preCtx, continu, None[IR.Stmt.Ground]())
}
return r
case o: IR.Stmt.Decl.Multiple =>
val r: PreResult[Context, IR.Stmt.Ground] = preIRStmtDeclMultiple(ctx, o) match {
case PreResult(preCtx, continu, Some(r: IR.Stmt.Ground)) => PreResult(preCtx, continu, Some[IR.Stmt.Ground](r))
case PreResult(_, _, Some(_)) => halt("Can only produce object of type IR.Stmt.Ground")
case PreResult(preCtx, continu, _) => PreResult(preCtx, continu, None[IR.Stmt.Ground]())
}
return r
}
}

@pure def preIRStmtExpr(ctx: Context, o: IR.Stmt.Expr): PreResult[Context, IR.Stmt.Ground] = {
return PreResult(ctx, T, None())
}

@pure def preIRStmtAssign(ctx: Context, o: IR.Stmt.Assign): PreResult[Context, IR.Stmt.Assign] = {
o match {
case o: IR.Stmt.Assign.Local => return preIRStmtAssignLocal(ctx, o)
Expand Down Expand Up @@ -337,7 +335,6 @@ object IRTransformer {
case PreResult(preCtx, continu, _) => PreResult(preCtx, continu, None[IR.Stmt.Decl]())
}
return r
case o: IR.Stmt.Decl.Multiple => return preIRStmtDeclMultiple(ctx, o)
}
}

Expand All @@ -356,10 +353,6 @@ object IRTransformer {
return PreResult(ctx, T, None())
}

@pure def preIRStmtDeclMultiple(ctx: Context, o: IR.Stmt.Decl.Multiple): PreResult[Context, IR.Stmt.Decl] = {
return PreResult(ctx, T, None())
}

@pure def preIRJump(ctx: Context, o: IR.Jump): PreResult[Context, IR.Jump] = {
o match {
case o: IR.Jump.Goto => return preIRJumpGoto(ctx, o)
Expand Down Expand Up @@ -582,6 +575,13 @@ object IRTransformer {

@pure def postIRStmt(ctx: Context, o: IR.Stmt): TPostResult[Context, IR.Stmt] = {
o match {
case o: IR.Stmt.Expr =>
val r: TPostResult[Context, IR.Stmt] = postIRStmtExpr(ctx, o) match {
case TPostResult(postCtx, Some(result: IR.Stmt)) => TPostResult(postCtx, Some[IR.Stmt](result))
case TPostResult(_, Some(_)) => halt("Can only produce object of type IR.Stmt")
case TPostResult(postCtx, _) => TPostResult(postCtx, None[IR.Stmt]())
}
return r
case o: IR.Stmt.Assign.Local =>
val r: TPostResult[Context, IR.Stmt] = postIRStmtAssignLocal(ctx, o) match {
case TPostResult(postCtx, Some(result: IR.Stmt)) => TPostResult(postCtx, Some[IR.Stmt](result))
Expand Down Expand Up @@ -635,18 +635,12 @@ object IRTransformer {
case TPostResult(postCtx, _) => TPostResult(postCtx, None[IR.Stmt]())
}
return r
case o: IR.Stmt.Decl.Multiple =>
val r: TPostResult[Context, IR.Stmt] = postIRStmtDeclMultiple(ctx, o) match {
case TPostResult(postCtx, Some(result: IR.Stmt)) => TPostResult(postCtx, Some[IR.Stmt](result))
case TPostResult(_, Some(_)) => halt("Can only produce object of type IR.Stmt")
case TPostResult(postCtx, _) => TPostResult(postCtx, None[IR.Stmt]())
}
return r
}
}

@pure def postIRStmtGround(ctx: Context, o: IR.Stmt.Ground): TPostResult[Context, IR.Stmt.Ground] = {
o match {
case o: IR.Stmt.Expr => return postIRStmtExpr(ctx, o)
case o: IR.Stmt.Assign.Local =>
val r: TPostResult[Context, IR.Stmt.Ground] = postIRStmtAssignLocal(ctx, o) match {
case TPostResult(postCtx, Some(result: IR.Stmt.Ground)) => TPostResult(postCtx, Some[IR.Stmt.Ground](result))
Expand Down Expand Up @@ -696,16 +690,13 @@ object IRTransformer {
case TPostResult(postCtx, _) => TPostResult(postCtx, None[IR.Stmt.Ground]())
}
return r
case o: IR.Stmt.Decl.Multiple =>
val r: TPostResult[Context, IR.Stmt.Ground] = postIRStmtDeclMultiple(ctx, o) match {
case TPostResult(postCtx, Some(result: IR.Stmt.Ground)) => TPostResult(postCtx, Some[IR.Stmt.Ground](result))
case TPostResult(_, Some(_)) => halt("Can only produce object of type IR.Stmt.Ground")
case TPostResult(postCtx, _) => TPostResult(postCtx, None[IR.Stmt.Ground]())
}
return r
}
}

@pure def postIRStmtExpr(ctx: Context, o: IR.Stmt.Expr): TPostResult[Context, IR.Stmt.Ground] = {
return TPostResult(ctx, None())
}

@pure def postIRStmtAssign(ctx: Context, o: IR.Stmt.Assign): TPostResult[Context, IR.Stmt.Assign] = {
o match {
case o: IR.Stmt.Assign.Local => return postIRStmtAssignLocal(ctx, o)
Expand Down Expand Up @@ -768,7 +759,6 @@ object IRTransformer {
case TPostResult(postCtx, _) => TPostResult(postCtx, None[IR.Stmt.Decl]())
}
return r
case o: IR.Stmt.Decl.Multiple => return postIRStmtDeclMultiple(ctx, o)
}
}

Expand All @@ -787,10 +777,6 @@ object IRTransformer {
return TPostResult(ctx, None())
}

@pure def postIRStmtDeclMultiple(ctx: Context, o: IR.Stmt.Decl.Multiple): TPostResult[Context, IR.Stmt.Decl] = {
return TPostResult(ctx, None())
}

@pure def postIRJump(ctx: Context, o: IR.Jump): TPostResult[Context, IR.Jump] = {
o match {
case o: IR.Jump.Goto => return postIRJumpGoto(ctx, o)
Expand Down Expand Up @@ -1127,6 +1113,12 @@ import IRTransformer._
val o2: IR.Stmt = preR.resultOpt.getOrElse(o)
val hasChanged: B = preR.resultOpt.nonEmpty
val rOpt: TPostResult[Context, IR.Stmt] = o2 match {
case o2: IR.Stmt.Expr =>
val r0: TPostResult[Context, IR.Exp.Apply] = transformIRExpApply(preR.ctx, o2.exp)
if (hasChanged || r0.resultOpt.nonEmpty)
TPostResult(r0.ctx, Some(o2(exp = r0.resultOpt.getOrElse(o2.exp))))
else
TPostResult(r0.ctx, None())
case o2: IR.Stmt.Assign.Local =>
val r0: TPostResult[Context, IR.MethodContext] = transformIRMethodContext(preR.ctx, o2.context)
val r1: TPostResult[Context, IR.Exp] = transformIRExp(r0.ctx, o2.rhs)
Expand Down Expand Up @@ -1203,12 +1195,6 @@ import IRTransformer._
TPostResult(r0.ctx, Some(o2(tipe = r0.resultOpt.getOrElse(o2.tipe))))
else
TPostResult(r0.ctx, None())
case o2: IR.Stmt.Decl.Multiple =>
val r0: TPostResult[Context, IS[Z, IR.Stmt.Decl.Single]] = transformISZ(preR.ctx, o2.decls, transformIRStmtDeclSingle _)
if (hasChanged || r0.resultOpt.nonEmpty)
TPostResult(r0.ctx, Some(o2(decls = r0.resultOpt.getOrElse(o2.decls))))
else
TPostResult(r0.ctx, None())
}
rOpt
} else if (preR.resultOpt.nonEmpty) {
Expand All @@ -1234,6 +1220,12 @@ import IRTransformer._
val o2: IR.Stmt.Ground = preR.resultOpt.getOrElse(o)
val hasChanged: B = preR.resultOpt.nonEmpty
val rOpt: TPostResult[Context, IR.Stmt.Ground] = o2 match {
case o2: IR.Stmt.Expr =>
val r0: TPostResult[Context, IR.Exp.Apply] = transformIRExpApply(preR.ctx, o2.exp)
if (hasChanged || r0.resultOpt.nonEmpty)
TPostResult(r0.ctx, Some(o2(exp = r0.resultOpt.getOrElse(o2.exp))))
else
TPostResult(r0.ctx, None())
case o2: IR.Stmt.Assign.Local =>
val r0: TPostResult[Context, IR.MethodContext] = transformIRMethodContext(preR.ctx, o2.context)
val r1: TPostResult[Context, IR.Exp] = transformIRExp(r0.ctx, o2.rhs)
Expand Down Expand Up @@ -1282,12 +1274,6 @@ import IRTransformer._
TPostResult(r0.ctx, Some(o2(tipe = r0.resultOpt.getOrElse(o2.tipe))))
else
TPostResult(r0.ctx, None())
case o2: IR.Stmt.Decl.Multiple =>
val r0: TPostResult[Context, IS[Z, IR.Stmt.Decl.Single]] = transformISZ(preR.ctx, o2.decls, transformIRStmtDeclSingle _)
if (hasChanged || r0.resultOpt.nonEmpty)
TPostResult(r0.ctx, Some(o2(decls = r0.resultOpt.getOrElse(o2.decls))))
else
TPostResult(r0.ctx, None())
}
rOpt
} else if (preR.resultOpt.nonEmpty) {
Expand Down Expand Up @@ -1386,12 +1372,6 @@ import IRTransformer._
TPostResult(r0.ctx, Some(o2(tipe = r0.resultOpt.getOrElse(o2.tipe))))
else
TPostResult(r0.ctx, None())
case o2: IR.Stmt.Decl.Multiple =>
val r0: TPostResult[Context, IS[Z, IR.Stmt.Decl.Single]] = transformISZ(preR.ctx, o2.decls, transformIRStmtDeclSingle _)
if (hasChanged || r0.resultOpt.nonEmpty)
TPostResult(r0.ctx, Some(o2(decls = r0.resultOpt.getOrElse(o2.decls))))
else
TPostResult(r0.ctx, None())
}
rOpt
} else if (preR.resultOpt.nonEmpty) {
Expand Down Expand Up @@ -1617,11 +1597,10 @@ import IRTransformer._
val hasChanged: B = preR.resultOpt.nonEmpty
val r0: TPostResult[Context, IS[Z, IR.Global]] = transformISZ(preR.ctx, o2.globals, transformIRGlobal _)
val r1: TPostResult[Context, IS[Z, IR.Procedure]] = transformISZ(r0.ctx, o2.procedures, transformIRProcedure _)
val r2: TPostResult[Context, IS[Z, IR.Stmt]] = transformISZ(r1.ctx, o2.stmts, transformIRStmt _)
if (hasChanged || r0.resultOpt.nonEmpty || r1.resultOpt.nonEmpty || r2.resultOpt.nonEmpty)
TPostResult(r2.ctx, Some(o2(globals = r0.resultOpt.getOrElse(o2.globals), procedures = r1.resultOpt.getOrElse(o2.procedures), stmts = r2.resultOpt.getOrElse(o2.stmts))))
if (hasChanged || r0.resultOpt.nonEmpty || r1.resultOpt.nonEmpty)
TPostResult(r1.ctx, Some(o2(globals = r0.resultOpt.getOrElse(o2.globals), procedures = r1.resultOpt.getOrElse(o2.procedures))))
else
TPostResult(r2.ctx, None())
TPostResult(r1.ctx, None())
} else if (preR.resultOpt.nonEmpty) {
TPostResult(preR.ctx, Some(preR.resultOpt.getOrElse(o)))
} else {
Expand Down Expand Up @@ -1766,6 +1745,43 @@ import IRTransformer._
}
}

@pure def transformIRExpApply(ctx: Context, o: IR.Exp.Apply): TPostResult[Context, IR.Exp.Apply] = {
val preR: PreResult[Context, IR.Exp.Apply] = pp.preIRExpApply(ctx, o) match {
case PreResult(preCtx, continu, Some(r: IR.Exp.Apply)) => PreResult(preCtx, continu, Some[IR.Exp.Apply](r))
case PreResult(_, _, Some(_)) => halt("Can only produce object of type IR.Exp.Apply")
case PreResult(preCtx, continu, _) => PreResult(preCtx, continu, None[IR.Exp.Apply]())
}
val r: TPostResult[Context, IR.Exp.Apply] = if (preR.continu) {
val o2: IR.Exp.Apply = preR.resultOpt.getOrElse(o)
val hasChanged: B = preR.resultOpt.nonEmpty
val r0: TPostResult[Context, IS[Z, IR.Exp]] = transformISZ(preR.ctx, o2.args, transformIRExp _)
val r1: TPostResult[Context, Typed.Fun] = transformTypedFun(r0.ctx, o2.methodType)
val r2: TPostResult[Context, Typed] = transformTyped(r1.ctx, o2.tipe)
if (hasChanged || r0.resultOpt.nonEmpty || r1.resultOpt.nonEmpty || r2.resultOpt.nonEmpty)
TPostResult(r2.ctx, Some(o2(args = r0.resultOpt.getOrElse(o2.args), methodType = r1.resultOpt.getOrElse(o2.methodType), tipe = r2.resultOpt.getOrElse(o2.tipe))))
else
TPostResult(r2.ctx, None())
} else if (preR.resultOpt.nonEmpty) {
TPostResult(preR.ctx, Some(preR.resultOpt.getOrElse(o)))
} else {
TPostResult(preR.ctx, None())
}
val hasChanged: B = r.resultOpt.nonEmpty
val o2: IR.Exp.Apply = r.resultOpt.getOrElse(o)
val postR: TPostResult[Context, IR.Exp.Apply] = pp.postIRExpApply(r.ctx, o2) match {
case TPostResult(postCtx, Some(result: IR.Exp.Apply)) => TPostResult(postCtx, Some[IR.Exp.Apply](result))
case TPostResult(_, Some(_)) => halt("Can only produce object of type IR.Exp.Apply")
case TPostResult(postCtx, _) => TPostResult(postCtx, None[IR.Exp.Apply]())
}
if (postR.resultOpt.nonEmpty) {
return postR
} else if (hasChanged) {
return TPostResult(postR.ctx, Some(o2))
} else {
return TPostResult(postR.ctx, None())
}
}

@pure def transformTypedName(ctx: Context, o: Typed.Name): TPostResult[Context, Typed.Name] = {
val preR: PreResult[Context, Typed.Name] = pp.preTypedName(ctx, o) match {
case PreResult(preCtx, continu, Some(r: Typed.Name)) => PreResult(preCtx, continu, Some[Typed.Name](r))
Expand Down
Loading

0 comments on commit 4a7f350

Please sign in to comment.