Skip to content

Commit

Permalink
Checkpointing IR.
Browse files Browse the repository at this point in the history
  • Loading branch information
robby-phd committed Feb 4, 2025
1 parent ecf7abe commit c996b39
Show file tree
Hide file tree
Showing 4 changed files with 256 additions and 44 deletions.
26 changes: 20 additions & 6 deletions ast/shared/src/main/scala/org/sireum/lang/ast/IR.scala
Original file line number Diff line number Diff line change
Expand Up @@ -279,29 +279,43 @@ object IR {

object Decl {

@datatype class Local(val undecl: B, val isVal: B, val tipe: Typed, val id: String, val pos: Position) extends Decl {
@strictpure def undeclare: Decl = {
@datatype trait Ground extends Decl {
@strictpure def undeclare: Ground
}

@datatype class Local(val undecl: B, val isVal: B, val tipe: Typed, val id: String, val pos: Position) extends Ground {
@strictpure def undeclare: Ground = {
val thiz = this
thiz(undecl = T)
}
@strictpure def prettyST: ST = st"${if (undecl) "de" else ""}local $id: $tipe"
}

@datatype class Register(val undecl: B, val tipe: Typed, val n: Z, val pos: Position) extends Decl {
@strictpure def undeclare: Decl = {
@datatype class Register(val undecl: B, val tipe: Typed, val n: Z, val pos: Position) extends Ground {
@strictpure def undeclare: Ground = {
val thiz = this
thiz(undecl = T)
}
@strictpure def prettyST: ST = st"${if (undecl) "de" else ""}register $$$n: $tipe"
}

@datatype class Multiple(val undecl: B, val decls: ISZ[Decl]) extends Decl {
@datatype class Multiple(val undecl: B, val decls: ISZ[Ground]) extends Decl {
@strictpure def pos: Position = decls(0).pos.to(decls(decls.size - 1).pos)
@strictpure def undeclare: Decl = {
val thiz = this
thiz(undecl = T, decls = for (d <- decls) yield d.undeclare)
}
@strictpure def prettyST: ST = st"${(for (d <- decls) yield d.prettyST, "\n")}"
@pure def prettyST: ST = {
var ds = ISZ[ST]()
for (d <- decls) {
d match {
case d: Local => ds = ds :+ (if (undecl) st"${d.id}" else st"${d.id}: ${d.tipe}")
case d: Register => ds = ds :+ (if (undecl) st"$$${d.n}" else st"$$${d.n}: ${d.tipe}")
}
}
val r = st"${if (undecl) "un" else ""}decls ${(ds, ", ")}"
return r
}
}

}
Expand Down
93 changes: 84 additions & 9 deletions ast/shared/src/main/scala/org/sireum/lang/ast/IRTransformer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -322,18 +322,37 @@ object IRTransformer {
}

@pure def preIRStmtDecl(ctx: Context, o: IR.Stmt.Decl): PreResult[Context, IR.Stmt.Decl] = {
o match {
case o: IR.Stmt.Decl.Local =>
val r: PreResult[Context, IR.Stmt.Decl] = preIRStmtDeclLocal(ctx, o) match {
case PreResult(preCtx, continu, Some(r: IR.Stmt.Decl)) => PreResult(preCtx, continu, Some[IR.Stmt.Decl](r))
case PreResult(_, _, Some(_)) => halt("Can only produce object of type IR.Stmt.Decl")
case PreResult(preCtx, continu, _) => PreResult(preCtx, continu, None[IR.Stmt.Decl]())
}
return r
case o: IR.Stmt.Decl.Register =>
val r: PreResult[Context, IR.Stmt.Decl] = preIRStmtDeclRegister(ctx, o) match {
case PreResult(preCtx, continu, Some(r: IR.Stmt.Decl)) => PreResult(preCtx, continu, Some[IR.Stmt.Decl](r))
case PreResult(_, _, Some(_)) => halt("Can only produce object of type IR.Stmt.Decl")
case PreResult(preCtx, continu, _) => PreResult(preCtx, continu, None[IR.Stmt.Decl]())
}
return r
case o: IR.Stmt.Decl.Multiple => return preIRStmtDeclMultiple(ctx, o)
}
}

@pure def preIRStmtDeclGround(ctx: Context, o: IR.Stmt.Decl.Ground): PreResult[Context, IR.Stmt.Decl.Ground] = {
o match {
case o: IR.Stmt.Decl.Local => return preIRStmtDeclLocal(ctx, o)
case o: IR.Stmt.Decl.Register => return preIRStmtDeclRegister(ctx, o)
case o: IR.Stmt.Decl.Multiple => return preIRStmtDeclMultiple(ctx, o)
}
}

@pure def preIRStmtDeclLocal(ctx: Context, o: IR.Stmt.Decl.Local): PreResult[Context, IR.Stmt.Decl] = {
@pure def preIRStmtDeclLocal(ctx: Context, o: IR.Stmt.Decl.Local): PreResult[Context, IR.Stmt.Decl.Ground] = {
return PreResult(ctx, T, None())
}

@pure def preIRStmtDeclRegister(ctx: Context, o: IR.Stmt.Decl.Register): PreResult[Context, IR.Stmt.Decl] = {
@pure def preIRStmtDeclRegister(ctx: Context, o: IR.Stmt.Decl.Register): PreResult[Context, IR.Stmt.Decl.Ground] = {
return PreResult(ctx, T, None())
}

Expand Down Expand Up @@ -734,18 +753,37 @@ object IRTransformer {
}

@pure def postIRStmtDecl(ctx: Context, o: IR.Stmt.Decl): TPostResult[Context, IR.Stmt.Decl] = {
o match {
case o: IR.Stmt.Decl.Local =>
val r: TPostResult[Context, IR.Stmt.Decl] = postIRStmtDeclLocal(ctx, o) match {
case TPostResult(postCtx, Some(result: IR.Stmt.Decl)) => TPostResult(postCtx, Some[IR.Stmt.Decl](result))
case TPostResult(_, Some(_)) => halt("Can only produce object of type IR.Stmt.Decl")
case TPostResult(postCtx, _) => TPostResult(postCtx, None[IR.Stmt.Decl]())
}
return r
case o: IR.Stmt.Decl.Register =>
val r: TPostResult[Context, IR.Stmt.Decl] = postIRStmtDeclRegister(ctx, o) match {
case TPostResult(postCtx, Some(result: IR.Stmt.Decl)) => TPostResult(postCtx, Some[IR.Stmt.Decl](result))
case TPostResult(_, Some(_)) => halt("Can only produce object of type IR.Stmt.Decl")
case TPostResult(postCtx, _) => TPostResult(postCtx, None[IR.Stmt.Decl]())
}
return r
case o: IR.Stmt.Decl.Multiple => return postIRStmtDeclMultiple(ctx, o)
}
}

@pure def postIRStmtDeclGround(ctx: Context, o: IR.Stmt.Decl.Ground): TPostResult[Context, IR.Stmt.Decl.Ground] = {
o match {
case o: IR.Stmt.Decl.Local => return postIRStmtDeclLocal(ctx, o)
case o: IR.Stmt.Decl.Register => return postIRStmtDeclRegister(ctx, o)
case o: IR.Stmt.Decl.Multiple => return postIRStmtDeclMultiple(ctx, o)
}
}

@pure def postIRStmtDeclLocal(ctx: Context, o: IR.Stmt.Decl.Local): TPostResult[Context, IR.Stmt.Decl] = {
@pure def postIRStmtDeclLocal(ctx: Context, o: IR.Stmt.Decl.Local): TPostResult[Context, IR.Stmt.Decl.Ground] = {
return TPostResult(ctx, None())
}

@pure def postIRStmtDeclRegister(ctx: Context, o: IR.Stmt.Decl.Register): TPostResult[Context, IR.Stmt.Decl] = {
@pure def postIRStmtDeclRegister(ctx: Context, o: IR.Stmt.Decl.Register): TPostResult[Context, IR.Stmt.Decl.Ground] = {
return TPostResult(ctx, None())
}

Expand Down Expand Up @@ -1165,7 +1203,7 @@ import IRTransformer._
else
TPostResult(r0.ctx, None())
case o2: IR.Stmt.Decl.Multiple =>
val r0: TPostResult[Context, IS[Z, IR.Stmt.Decl]] = transformISZ(preR.ctx, o2.decls, transformIRStmtDecl _)
val r0: TPostResult[Context, IS[Z, IR.Stmt.Decl.Ground]] = transformISZ(preR.ctx, o2.decls, transformIRStmtDeclGround _)
if (hasChanged || r0.resultOpt.nonEmpty)
TPostResult(r0.ctx, Some(o2(decls = r0.resultOpt.getOrElse(o2.decls))))
else
Expand Down Expand Up @@ -1244,7 +1282,7 @@ import IRTransformer._
else
TPostResult(r0.ctx, None())
case o2: IR.Stmt.Decl.Multiple =>
val r0: TPostResult[Context, IS[Z, IR.Stmt.Decl]] = transformISZ(preR.ctx, o2.decls, transformIRStmtDecl _)
val r0: TPostResult[Context, IS[Z, IR.Stmt.Decl.Ground]] = transformISZ(preR.ctx, o2.decls, transformIRStmtDeclGround _)
if (hasChanged || r0.resultOpt.nonEmpty)
TPostResult(r0.ctx, Some(o2(decls = r0.resultOpt.getOrElse(o2.decls))))
else
Expand Down Expand Up @@ -1348,7 +1386,7 @@ import IRTransformer._
else
TPostResult(r0.ctx, None())
case o2: IR.Stmt.Decl.Multiple =>
val r0: TPostResult[Context, IS[Z, IR.Stmt.Decl]] = transformISZ(preR.ctx, o2.decls, transformIRStmtDecl _)
val r0: TPostResult[Context, IS[Z, IR.Stmt.Decl.Ground]] = transformISZ(preR.ctx, o2.decls, transformIRStmtDeclGround _)
if (hasChanged || r0.resultOpt.nonEmpty)
TPostResult(r0.ctx, Some(o2(decls = r0.resultOpt.getOrElse(o2.decls))))
else
Expand All @@ -1372,6 +1410,43 @@ import IRTransformer._
}
}

@pure def transformIRStmtDeclGround(ctx: Context, o: IR.Stmt.Decl.Ground): TPostResult[Context, IR.Stmt.Decl.Ground] = {
val preR: PreResult[Context, IR.Stmt.Decl.Ground] = pp.preIRStmtDeclGround(ctx, o)
val r: TPostResult[Context, IR.Stmt.Decl.Ground] = if (preR.continu) {
val o2: IR.Stmt.Decl.Ground = preR.resultOpt.getOrElse(o)
val hasChanged: B = preR.resultOpt.nonEmpty
val rOpt: TPostResult[Context, IR.Stmt.Decl.Ground] = o2 match {
case o2: IR.Stmt.Decl.Local =>
val r0: TPostResult[Context, Typed] = transformTyped(preR.ctx, o2.tipe)
if (hasChanged || r0.resultOpt.nonEmpty)
TPostResult(r0.ctx, Some(o2(tipe = r0.resultOpt.getOrElse(o2.tipe))))
else
TPostResult(r0.ctx, None())
case o2: IR.Stmt.Decl.Register =>
val r0: TPostResult[Context, Typed] = transformTyped(preR.ctx, o2.tipe)
if (hasChanged || r0.resultOpt.nonEmpty)
TPostResult(r0.ctx, Some(o2(tipe = r0.resultOpt.getOrElse(o2.tipe))))
else
TPostResult(r0.ctx, None())
}
rOpt
} else if (preR.resultOpt.nonEmpty) {
TPostResult(preR.ctx, Some(preR.resultOpt.getOrElse(o)))
} else {
TPostResult(preR.ctx, None())
}
val hasChanged: B = r.resultOpt.nonEmpty
val o2: IR.Stmt.Decl.Ground = r.resultOpt.getOrElse(o)
val postR: TPostResult[Context, IR.Stmt.Decl.Ground] = pp.postIRStmtDeclGround(r.ctx, o2)
if (postR.resultOpt.nonEmpty) {
return postR
} else if (hasChanged) {
return TPostResult(postR.ctx, Some(o2))
} else {
return TPostResult(postR.ctx, None())
}
}

@pure def transformIRJump(ctx: Context, o: IR.Jump): TPostResult[Context, IR.Jump] = {
val preR: PreResult[Context, IR.Jump] = pp.preIRJump(ctx, o)
val r: TPostResult[Context, IR.Jump] = if (preR.continu) {
Expand Down
101 changes: 88 additions & 13 deletions ast/shared/src/main/scala/org/sireum/lang/ast/MIRTransformer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -181,13 +181,13 @@ object MIRTransformer {

val PostResultIRStmtReturn: MOption[IR.Stmt] = MNone()

val PreResultIRStmtDeclLocal: PreResult[IR.Stmt.Decl] = PreResult(T, MNone())
val PreResultIRStmtDeclLocal: PreResult[IR.Stmt.Decl.Ground] = PreResult(T, MNone())

val PostResultIRStmtDeclLocal: MOption[IR.Stmt.Decl] = MNone()
val PostResultIRStmtDeclLocal: MOption[IR.Stmt.Decl.Ground] = MNone()

val PreResultIRStmtDeclRegister: PreResult[IR.Stmt.Decl] = PreResult(T, MNone())
val PreResultIRStmtDeclRegister: PreResult[IR.Stmt.Decl.Ground] = PreResult(T, MNone())

val PostResultIRStmtDeclRegister: MOption[IR.Stmt.Decl] = MNone()
val PostResultIRStmtDeclRegister: MOption[IR.Stmt.Decl.Ground] = MNone()

val PreResultIRStmtDeclMultiple: PreResult[IR.Stmt.Decl] = PreResult(T, MNone())

Expand Down Expand Up @@ -560,18 +560,37 @@ import MIRTransformer._
}

def preIRStmtDecl(o: IR.Stmt.Decl): PreResult[IR.Stmt.Decl] = {
o match {
case o: IR.Stmt.Decl.Local =>
val r: PreResult[IR.Stmt.Decl] = preIRStmtDeclLocal(o) match {
case PreResult(continu, MSome(r: IR.Stmt.Decl)) => PreResult(continu, MSome[IR.Stmt.Decl](r))
case PreResult(_, MSome(_)) => halt("Can only produce object of type IR.Stmt.Decl")
case PreResult(continu, _) => PreResult(continu, MNone[IR.Stmt.Decl]())
}
return r
case o: IR.Stmt.Decl.Register =>
val r: PreResult[IR.Stmt.Decl] = preIRStmtDeclRegister(o) match {
case PreResult(continu, MSome(r: IR.Stmt.Decl)) => PreResult(continu, MSome[IR.Stmt.Decl](r))
case PreResult(_, MSome(_)) => halt("Can only produce object of type IR.Stmt.Decl")
case PreResult(continu, _) => PreResult(continu, MNone[IR.Stmt.Decl]())
}
return r
case o: IR.Stmt.Decl.Multiple => return preIRStmtDeclMultiple(o)
}
}

def preIRStmtDeclGround(o: IR.Stmt.Decl.Ground): PreResult[IR.Stmt.Decl.Ground] = {
o match {
case o: IR.Stmt.Decl.Local => return preIRStmtDeclLocal(o)
case o: IR.Stmt.Decl.Register => return preIRStmtDeclRegister(o)
case o: IR.Stmt.Decl.Multiple => return preIRStmtDeclMultiple(o)
}
}

def preIRStmtDeclLocal(o: IR.Stmt.Decl.Local): PreResult[IR.Stmt.Decl] = {
def preIRStmtDeclLocal(o: IR.Stmt.Decl.Local): PreResult[IR.Stmt.Decl.Ground] = {
return PreResultIRStmtDeclLocal
}

def preIRStmtDeclRegister(o: IR.Stmt.Decl.Register): PreResult[IR.Stmt.Decl] = {
def preIRStmtDeclRegister(o: IR.Stmt.Decl.Register): PreResult[IR.Stmt.Decl.Ground] = {
return PreResultIRStmtDeclRegister
}

Expand Down Expand Up @@ -972,18 +991,37 @@ import MIRTransformer._
}

def postIRStmtDecl(o: IR.Stmt.Decl): MOption[IR.Stmt.Decl] = {
o match {
case o: IR.Stmt.Decl.Local =>
val r: MOption[IR.Stmt.Decl] = postIRStmtDeclLocal(o) match {
case MSome(result: IR.Stmt.Decl) => MSome[IR.Stmt.Decl](result)
case MSome(_) => halt("Can only produce object of type IR.Stmt.Decl")
case _ => MNone[IR.Stmt.Decl]()
}
return r
case o: IR.Stmt.Decl.Register =>
val r: MOption[IR.Stmt.Decl] = postIRStmtDeclRegister(o) match {
case MSome(result: IR.Stmt.Decl) => MSome[IR.Stmt.Decl](result)
case MSome(_) => halt("Can only produce object of type IR.Stmt.Decl")
case _ => MNone[IR.Stmt.Decl]()
}
return r
case o: IR.Stmt.Decl.Multiple => return postIRStmtDeclMultiple(o)
}
}

def postIRStmtDeclGround(o: IR.Stmt.Decl.Ground): MOption[IR.Stmt.Decl.Ground] = {
o match {
case o: IR.Stmt.Decl.Local => return postIRStmtDeclLocal(o)
case o: IR.Stmt.Decl.Register => return postIRStmtDeclRegister(o)
case o: IR.Stmt.Decl.Multiple => return postIRStmtDeclMultiple(o)
}
}

def postIRStmtDeclLocal(o: IR.Stmt.Decl.Local): MOption[IR.Stmt.Decl] = {
def postIRStmtDeclLocal(o: IR.Stmt.Decl.Local): MOption[IR.Stmt.Decl.Ground] = {
return PostResultIRStmtDeclLocal
}

def postIRStmtDeclRegister(o: IR.Stmt.Decl.Register): MOption[IR.Stmt.Decl] = {
def postIRStmtDeclRegister(o: IR.Stmt.Decl.Register): MOption[IR.Stmt.Decl.Ground] = {
return PostResultIRStmtDeclRegister
}

Expand Down Expand Up @@ -1365,7 +1403,7 @@ import MIRTransformer._
else
MNone()
case o2: IR.Stmt.Decl.Multiple =>
val r0: MOption[IS[Z, IR.Stmt.Decl]] = transformISZ(o2.decls, transformIRStmtDecl _)
val r0: MOption[IS[Z, IR.Stmt.Decl.Ground]] = transformISZ(o2.decls, transformIRStmtDeclGround _)
if (hasChanged || r0.nonEmpty)
MSome(o2(decls = r0.getOrElse(o2.decls)))
else
Expand Down Expand Up @@ -1444,7 +1482,7 @@ import MIRTransformer._
else
MNone()
case o2: IR.Stmt.Decl.Multiple =>
val r0: MOption[IS[Z, IR.Stmt.Decl]] = transformISZ(o2.decls, transformIRStmtDecl _)
val r0: MOption[IS[Z, IR.Stmt.Decl.Ground]] = transformISZ(o2.decls, transformIRStmtDeclGround _)
if (hasChanged || r0.nonEmpty)
MSome(o2(decls = r0.getOrElse(o2.decls)))
else
Expand Down Expand Up @@ -1548,7 +1586,7 @@ import MIRTransformer._
else
MNone()
case o2: IR.Stmt.Decl.Multiple =>
val r0: MOption[IS[Z, IR.Stmt.Decl]] = transformISZ(o2.decls, transformIRStmtDecl _)
val r0: MOption[IS[Z, IR.Stmt.Decl.Ground]] = transformISZ(o2.decls, transformIRStmtDeclGround _)
if (hasChanged || r0.nonEmpty)
MSome(o2(decls = r0.getOrElse(o2.decls)))
else
Expand All @@ -1572,6 +1610,43 @@ import MIRTransformer._
}
}

def transformIRStmtDeclGround(o: IR.Stmt.Decl.Ground): MOption[IR.Stmt.Decl.Ground] = {
val preR: PreResult[IR.Stmt.Decl.Ground] = preIRStmtDeclGround(o)
val r: MOption[IR.Stmt.Decl.Ground] = if (preR.continu) {
val o2: IR.Stmt.Decl.Ground = preR.resultOpt.getOrElse(o)
val hasChanged: B = preR.resultOpt.nonEmpty
val rOpt: MOption[IR.Stmt.Decl.Ground] = o2 match {
case o2: IR.Stmt.Decl.Local =>
val r0: MOption[Typed] = transformTyped(o2.tipe)
if (hasChanged || r0.nonEmpty)
MSome(o2(tipe = r0.getOrElse(o2.tipe)))
else
MNone()
case o2: IR.Stmt.Decl.Register =>
val r0: MOption[Typed] = transformTyped(o2.tipe)
if (hasChanged || r0.nonEmpty)
MSome(o2(tipe = r0.getOrElse(o2.tipe)))
else
MNone()
}
rOpt
} else if (preR.resultOpt.nonEmpty) {
MSome(preR.resultOpt.getOrElse(o))
} else {
MNone()
}
val hasChanged: B = r.nonEmpty
val o2: IR.Stmt.Decl.Ground = r.getOrElse(o)
val postR: MOption[IR.Stmt.Decl.Ground] = postIRStmtDeclGround(o2)
if (postR.nonEmpty) {
return postR
} else if (hasChanged) {
return MSome(o2)
} else {
return MNone()
}
}

def transformIRJump(o: IR.Jump): MOption[IR.Jump] = {
val preR: PreResult[IR.Jump] = preIRJump(o)
val r: MOption[IR.Jump] = if (preR.continu) {
Expand Down
Loading

0 comments on commit c996b39

Please sign in to comment.