From 5ff3f7fafc96a5938d9807baf1b4303926ea5127 Mon Sep 17 00:00:00 2001 From: Ruslan Shevchenko Date: Sun, 1 Oct 2023 20:51:00 +0300 Subject: [PATCH] fixed regression in the handling of try/catch statement with singleton cases --- .../scala/cps/plugin/forest/CpsTree.scala | 17 +++++---- .../cps/plugin/forest/TryTransform.scala | 35 ++++++++++++------- .../src/test/scala/cc/DotcInvocations.scala | 4 +-- .../src/test/scala/cc/Test9Try.scala | 2 +- .../testdata/set9Try/m6_100/Test9m6_100.scala | 2 ++ 5 files changed, 38 insertions(+), 22 deletions(-) diff --git a/compiler-plugin/src/main/scala/cps/plugin/forest/CpsTree.scala b/compiler-plugin/src/main/scala/cps/plugin/forest/CpsTree.scala index 7a1a41478..73469c90e 100644 --- a/compiler-plugin/src/main/scala/cps/plugin/forest/CpsTree.scala +++ b/compiler-plugin/src/main/scala/cps/plugin/forest/CpsTree.scala @@ -514,7 +514,10 @@ case class MapCpsTree( mapFun.body.originType override def castOriginType(ntpe: Type)(using Context, CpsTopLevelContext): CpsTree = { - copy(mapFun = mapFun.copy(body = mapFun.body.castOriginType(ntpe))) + if (ntpe =:= originType) then + this + else + copy(mapFun = mapFun.copy(body = mapFun.body.castOriginType(ntpe))) } override def internalAsyncKind(using Context, CpsTopLevelContext) = @@ -736,7 +739,9 @@ case class LambdaCpsTree( } override def castOriginType(ntpe: Type)(using Context, CpsTopLevelContext): CpsTree = - if (defn.isFunctionType(ntpe) || defn.isContextFunctionType(ntpe)) then + if (ntpe =:= originType) then + this + else if (defn.isFunctionType(ntpe) || defn.isContextFunctionType(ntpe)) then ntpe match case AppliedType(ntpfun, ntpargs) => // TODO: check that params are the same. @@ -853,9 +858,6 @@ case class LambdaCpsTree( tctx.cpsNonDirectContext ) ) - if (tss.isEmpty) then - println(s"tss.isEmpty, originParans=${originParams}, tss=${tss}") - println(s"originType=${originType.show}, originClosureType=${originClosureType.show}") TransformUtil.substParams(nBody, originParams, tss.head).changeOwner(cpsBody.owner, meth) }) CpsTree.pure(origin,owner,closure) @@ -921,7 +923,10 @@ case class OpaqueAsyncLambdaTermCpsTree( override def normalizeAsyncKind(using Context, CpsTopLevelContext) = this override def castOriginType(ntpe: Type)(using Context, CpsTopLevelContext): CpsTree = { - typed(Typed(origin,TypeTree(ntpe))) + if (origin.tpe =:= ntpe) then + this + else + typed(Typed(origin,TypeTree(ntpe))) } override def unpure(using Context, CpsTopLevelContext): Option[Tree] = None diff --git a/compiler-plugin/src/main/scala/cps/plugin/forest/TryTransform.scala b/compiler-plugin/src/main/scala/cps/plugin/forest/TryTransform.scala index 91d274be2..b0215f956 100644 --- a/compiler-plugin/src/main/scala/cps/plugin/forest/TryTransform.scala +++ b/compiler-plugin/src/main/scala/cps/plugin/forest/TryTransform.scala @@ -65,15 +65,19 @@ object TryTransform { val newTree = Try(cpsExpr.unpure.get, cases.unpureCaseDefs, EmptyTree) CpsTree.pure(origin, owner, newTree) case (AsyncKind.Sync, AsyncKind.Async(ik)) => - val retval = generateWithAsyncCasesWithTry(origin, owner, cpsExpr, cases, casesAsyncKind, nesting) - Log.trace(s"TryTransform:applyNoFinalizer return ${retval.show}", nesting) - retval + val unwrapedType = origin.tpe.widenUnion + val castedCpsExpr = cpsExpr.castOriginType(unwrapedType) + val retval = generateWithAsyncCasesWithTry(origin, owner, castedCpsExpr, cases, casesAsyncKind, nesting) + Log.trace(s"TryTransform:applyNoFinalizer return ${retval.show}", nesting) + retval case _ => - val targetKind = cpsExpr.asyncKind unify casesAsyncKind match + val targetKind = cpsExpr.asyncKind unify casesAsyncKind match case Left((k1,k2)) => throw CpsTransformException("Incompatible async kinds of try expr and cases", origin.srcPos) case Right(k) => k - generateWithAsyncCases(origin, owner, cpsExpr, cases, targetKind, nesting) + val unwrapedType = origin.tpe.widenUnion + val castedCpsExpr = cpsExpr.castOriginType(unwrapedType) + generateWithAsyncCases(origin, owner, castedCpsExpr, cases, targetKind, nesting) } } @@ -122,7 +126,8 @@ object TryTransform { case Left(p) => throw CpsTransformException(s"Non-compatible async shape in try exppression and handlers ${p}", origin.srcPos) case Right(k) => - val expr1 = generateWithAsyncCases(origin, owner, cpsExpr, cases, k, nesting) + val castedCpsExpr = cpsExpr.castOriginType(cpsExpr.originType.widenUnion) + val expr1 = generateWithAsyncCases(origin, owner, castedCpsExpr, cases, k, nesting) val expr2 = generateWithAsyncFinalizer(origin, owner, expr1, cpsFinalizer) expr2 } @@ -134,10 +139,13 @@ object TryTransform { exprCpsTree: CpsTree, finalizerCpsTree: CpsTree, )(using Context, CpsTopLevelContext): CpsTree = { + val castedExprCpsTree = if (!(origin.tpe.widenUnion =:= exprCpsTree.originType)) then { + exprCpsTree.castOriginType(origin.tpe.widenUnion) + } else exprCpsTree generateWithAsyncFinalizerTree( origin, owner, - wrapPureCpsTreeInTry(origin, exprCpsTree), - exprCpsTree.originType.widen, - exprCpsTree.asyncKind, + wrapPureCpsTreeInTry(origin, castedExprCpsTree), + castedExprCpsTree.originType.widen, + castedExprCpsTree.asyncKind, finalizerCpsTree ) } @@ -194,7 +202,7 @@ object TryTransform { val mt = MethodType(List("ex".toTermName), List(defn.ThrowableType), lambdaResultType) val lambdaSym = newAnonFun(owner,mt) val lambda = Closure(lambdaSym, tss => { - val defaultCase = generateDefaultCaseDef(origin, origin.tpe.widen)(using summon[Context].withOwner(lambdaSym), summon[CpsTopLevelContext]) + val defaultCase = generateDefaultCaseDef(origin, unwrappedTpe)(using summon[Context].withOwner(lambdaSym), summon[CpsTopLevelContext]) Match(tss.head.head, transformedCases :+ defaultCase).changeOwner(owner,lambdaSym) } ) @@ -208,13 +216,14 @@ object TryTransform { ), List(lambda) ).withSpan(origin.span) + val typedOrigin = if (unwrappedTpe =:= origin.tpe.widen) then origin else Typed(origin, TypeTree(unwrappedTpe)).withSpan(origin.span) targetKind match case AsyncKind.Sync => - CpsTree.impure(origin,owner,tree,AsyncKind.Sync) + CpsTree.impure(typedOrigin,owner,tree,AsyncKind.Sync) case AsyncKind.Async(ik) => - CpsTree.impure(origin,owner,tree,ik) + CpsTree.impure(typedOrigin,owner,tree,ik) case AsyncKind.AsyncLambda(bodyKind) => - CpsTree.opaqueAsyncLambda(origin,owner,tree,bodyKind) + CpsTree.opaqueAsyncLambda(typedOrigin,owner,tree,bodyKind) } diff --git a/compiler-plugin/src/test/scala/cc/DotcInvocations.scala b/compiler-plugin/src/test/scala/cc/DotcInvocations.scala index ba8a8b391..46421d164 100644 --- a/compiler-plugin/src/test/scala/cc/DotcInvocations.scala +++ b/compiler-plugin/src/test/scala/cc/DotcInvocations.scala @@ -110,7 +110,7 @@ class DotcInvocations(silent: Boolean = false) { } -case class TestRun(inputDir: String, mainClass: String, expectedOutput: String = "Ok\n") +case class TestRun(inputDir: String, mainClass: String, expectedOutput: String = "Ok\n", extraDotcArgs:List[String] = List.empty) @@ -194,7 +194,7 @@ object DotcInvocations { ): Unit = { for(r <- runs) { if (selection.matches(r.inputDir)) { - compileAndRunFilesInDirAndCheckResult(r.inputDir,r.mainClass,r.expectedOutput,dotcArgs) + compileAndRunFilesInDirAndCheckResult(r.inputDir,r.mainClass,r.expectedOutput,dotcArgs.copy(extraDotcArgs = dotcArgs.extraDotcArgs ++ r.extraDotcArgs)) } } } diff --git a/compiler-plugin/src/test/scala/cc/Test9Try.scala b/compiler-plugin/src/test/scala/cc/Test9Try.scala index 7b5e7a487..8655c4ee6 100644 --- a/compiler-plugin/src/test/scala/cc/Test9Try.scala +++ b/compiler-plugin/src/test/scala/cc/Test9Try.scala @@ -8,7 +8,7 @@ class Test9Try { @Test def testCompileAndRun(): Unit = { - DotcInvocations.checkRuns(selection = (".*".r))( + DotcInvocations.checkRuns(selection = (".*".r),dotcArgs = DotcInvocationArgs(extraDotcArgs = List()))( TestRun("testdata/set9Try/m1", "cpstest.Test9m1", "Right(10)\n"), TestRun("testdata/set9Try/m2", "cpstest.Test9m2"), TestRun("testdata/set9Try/m3", "cpstest.Test9m3"), diff --git a/compiler-plugin/testdata/set9Try/m6_100/Test9m6_100.scala b/compiler-plugin/testdata/set9Try/m6_100/Test9m6_100.scala index a572ff57f..03120868c 100644 --- a/compiler-plugin/testdata/set9Try/m6_100/Test9m6_100.scala +++ b/compiler-plugin/testdata/set9Try/m6_100/Test9m6_100.scala @@ -3,10 +3,12 @@ package cpstest import scala.annotation.experimental import cps.* import cps.monads.{*,given} +import cps.plugin.annotation.CpsDebugLevel import testUtil.* @experimental +@CpsDebugLevel(20) object Test9m6_100 { var finallyWasRun=false