Skip to content

Commit

Permalink
fixed regression in the handling of try/catch statement with singleton
Browse files Browse the repository at this point in the history
cases
  • Loading branch information
rssh committed Oct 1, 2023
1 parent 190a31a commit 5ff3f7f
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 22 deletions.
17 changes: 11 additions & 6 deletions compiler-plugin/src/main/scala/cps/plugin/forest/CpsTree.scala
Original file line number Diff line number Diff line change
Expand Up @@ -514,7 +514,10 @@ case class MapCpsTree(
mapFun.body.originType

override def castOriginType(ntpe: Type)(using Context, CpsTopLevelContext): CpsTree = {
copy(mapFun = mapFun.copy(body = mapFun.body.castOriginType(ntpe)))
if (ntpe =:= originType) then
this
else
copy(mapFun = mapFun.copy(body = mapFun.body.castOriginType(ntpe)))
}

override def internalAsyncKind(using Context, CpsTopLevelContext) =
Expand Down Expand Up @@ -736,7 +739,9 @@ case class LambdaCpsTree(
}

override def castOriginType(ntpe: Type)(using Context, CpsTopLevelContext): CpsTree =
if (defn.isFunctionType(ntpe) || defn.isContextFunctionType(ntpe)) then
if (ntpe =:= originType) then
this
else if (defn.isFunctionType(ntpe) || defn.isContextFunctionType(ntpe)) then
ntpe match
case AppliedType(ntpfun, ntpargs) =>
// TODO: check that params are the same.
Expand Down Expand Up @@ -853,9 +858,6 @@ case class LambdaCpsTree(
tctx.cpsNonDirectContext
)
)
if (tss.isEmpty) then
println(s"tss.isEmpty, originParans=${originParams}, tss=${tss}")
println(s"originType=${originType.show}, originClosureType=${originClosureType.show}")
TransformUtil.substParams(nBody, originParams, tss.head).changeOwner(cpsBody.owner, meth)
})
CpsTree.pure(origin,owner,closure)
Expand Down Expand Up @@ -921,7 +923,10 @@ case class OpaqueAsyncLambdaTermCpsTree(
override def normalizeAsyncKind(using Context, CpsTopLevelContext) = this

override def castOriginType(ntpe: Type)(using Context, CpsTopLevelContext): CpsTree = {
typed(Typed(origin,TypeTree(ntpe)))
if (origin.tpe =:= ntpe) then
this
else
typed(Typed(origin,TypeTree(ntpe)))
}

override def unpure(using Context, CpsTopLevelContext): Option[Tree] = None
Expand Down
35 changes: 22 additions & 13 deletions compiler-plugin/src/main/scala/cps/plugin/forest/TryTransform.scala
Original file line number Diff line number Diff line change
Expand Up @@ -65,15 +65,19 @@ object TryTransform {
val newTree = Try(cpsExpr.unpure.get, cases.unpureCaseDefs, EmptyTree)
CpsTree.pure(origin, owner, newTree)
case (AsyncKind.Sync, AsyncKind.Async(ik)) =>
val retval = generateWithAsyncCasesWithTry(origin, owner, cpsExpr, cases, casesAsyncKind, nesting)
Log.trace(s"TryTransform:applyNoFinalizer return ${retval.show}", nesting)
retval
val unwrapedType = origin.tpe.widenUnion
val castedCpsExpr = cpsExpr.castOriginType(unwrapedType)
val retval = generateWithAsyncCasesWithTry(origin, owner, castedCpsExpr, cases, casesAsyncKind, nesting)
Log.trace(s"TryTransform:applyNoFinalizer return ${retval.show}", nesting)
retval
case _ =>
val targetKind = cpsExpr.asyncKind unify casesAsyncKind match
val targetKind = cpsExpr.asyncKind unify casesAsyncKind match
case Left((k1,k2)) =>
throw CpsTransformException("Incompatible async kinds of try expr and cases", origin.srcPos)
case Right(k) => k
generateWithAsyncCases(origin, owner, cpsExpr, cases, targetKind, nesting)
val unwrapedType = origin.tpe.widenUnion
val castedCpsExpr = cpsExpr.castOriginType(unwrapedType)
generateWithAsyncCases(origin, owner, castedCpsExpr, cases, targetKind, nesting)
}
}

Expand Down Expand Up @@ -122,7 +126,8 @@ object TryTransform {
case Left(p) =>
throw CpsTransformException(s"Non-compatible async shape in try exppression and handlers ${p}", origin.srcPos)
case Right(k) =>
val expr1 = generateWithAsyncCases(origin, owner, cpsExpr, cases, k, nesting)
val castedCpsExpr = cpsExpr.castOriginType(cpsExpr.originType.widenUnion)
val expr1 = generateWithAsyncCases(origin, owner, castedCpsExpr, cases, k, nesting)
val expr2 = generateWithAsyncFinalizer(origin, owner, expr1, cpsFinalizer)
expr2
}
Expand All @@ -134,10 +139,13 @@ object TryTransform {
exprCpsTree: CpsTree,
finalizerCpsTree: CpsTree,
)(using Context, CpsTopLevelContext): CpsTree = {
val castedExprCpsTree = if (!(origin.tpe.widenUnion =:= exprCpsTree.originType)) then {
exprCpsTree.castOriginType(origin.tpe.widenUnion)
} else exprCpsTree
generateWithAsyncFinalizerTree( origin, owner,
wrapPureCpsTreeInTry(origin, exprCpsTree),
exprCpsTree.originType.widen,
exprCpsTree.asyncKind,
wrapPureCpsTreeInTry(origin, castedExprCpsTree),
castedExprCpsTree.originType.widen,
castedExprCpsTree.asyncKind,
finalizerCpsTree
)
}
Expand Down Expand Up @@ -194,7 +202,7 @@ object TryTransform {
val mt = MethodType(List("ex".toTermName), List(defn.ThrowableType), lambdaResultType)
val lambdaSym = newAnonFun(owner,mt)
val lambda = Closure(lambdaSym, tss => {
val defaultCase = generateDefaultCaseDef(origin, origin.tpe.widen)(using summon[Context].withOwner(lambdaSym), summon[CpsTopLevelContext])
val defaultCase = generateDefaultCaseDef(origin, unwrappedTpe)(using summon[Context].withOwner(lambdaSym), summon[CpsTopLevelContext])
Match(tss.head.head, transformedCases :+ defaultCase).changeOwner(owner,lambdaSym)
}
)
Expand All @@ -208,13 +216,14 @@ object TryTransform {
),
List(lambda)
).withSpan(origin.span)
val typedOrigin = if (unwrappedTpe =:= origin.tpe.widen) then origin else Typed(origin, TypeTree(unwrappedTpe)).withSpan(origin.span)
targetKind match
case AsyncKind.Sync =>
CpsTree.impure(origin,owner,tree,AsyncKind.Sync)
CpsTree.impure(typedOrigin,owner,tree,AsyncKind.Sync)
case AsyncKind.Async(ik) =>
CpsTree.impure(origin,owner,tree,ik)
CpsTree.impure(typedOrigin,owner,tree,ik)
case AsyncKind.AsyncLambda(bodyKind) =>
CpsTree.opaqueAsyncLambda(origin,owner,tree,bodyKind)
CpsTree.opaqueAsyncLambda(typedOrigin,owner,tree,bodyKind)
}


Expand Down
4 changes: 2 additions & 2 deletions compiler-plugin/src/test/scala/cc/DotcInvocations.scala
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ class DotcInvocations(silent: Boolean = false) {
}


case class TestRun(inputDir: String, mainClass: String, expectedOutput: String = "Ok\n")
case class TestRun(inputDir: String, mainClass: String, expectedOutput: String = "Ok\n", extraDotcArgs:List[String] = List.empty)



Expand Down Expand Up @@ -194,7 +194,7 @@ object DotcInvocations {
): Unit = {
for(r <- runs) {
if (selection.matches(r.inputDir)) {
compileAndRunFilesInDirAndCheckResult(r.inputDir,r.mainClass,r.expectedOutput,dotcArgs)
compileAndRunFilesInDirAndCheckResult(r.inputDir,r.mainClass,r.expectedOutput,dotcArgs.copy(extraDotcArgs = dotcArgs.extraDotcArgs ++ r.extraDotcArgs))
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion compiler-plugin/src/test/scala/cc/Test9Try.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ class Test9Try {

@Test
def testCompileAndRun(): Unit = {
DotcInvocations.checkRuns(selection = (".*".r))(
DotcInvocations.checkRuns(selection = (".*".r),dotcArgs = DotcInvocationArgs(extraDotcArgs = List()))(
TestRun("testdata/set9Try/m1", "cpstest.Test9m1", "Right(10)\n"),
TestRun("testdata/set9Try/m2", "cpstest.Test9m2"),
TestRun("testdata/set9Try/m3", "cpstest.Test9m3"),
Expand Down
2 changes: 2 additions & 0 deletions compiler-plugin/testdata/set9Try/m6_100/Test9m6_100.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@ package cpstest
import scala.annotation.experimental
import cps.*
import cps.monads.{*,given}
import cps.plugin.annotation.CpsDebugLevel

import testUtil.*

@experimental
@CpsDebugLevel(20)
object Test9m6_100 {

var finallyWasRun=false
Expand Down

0 comments on commit 5ff3f7f

Please sign in to comment.