Skip to content

Commit

Permalink
remove automatic coloring
Browse files Browse the repository at this point in the history
  • Loading branch information
rssh committed Oct 17, 2023
1 parent 699b2e2 commit 456179c
Show file tree
Hide file tree
Showing 51 changed files with 227 additions and 1,511 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
package cpstest

import cps.*

import cps.{ComputationBound, Thunk}
import org.junit.{Ignore, Test}
import org.junit.Assert.*

import scala.annotation.experimental
import scala.util.*
import java.util.concurrent.atomic.*
import cps.testconfig.given

// rewrited with CpsDirect instead automatic coloring
@experimental
class TestCBS2ACCntDirect:

def qqq:Int = 0

val LOG_MOD = 10
val LOG_TRESHOLD = 100

def createCounter(n:Int) = new AtomicInteger(n)

// disable loom, to prevent compiler crash
given noLoom1: cps.macros.flags.UseLoomAwait.type = ???
given noLoom2: cps.macros.flags.UseLoomAwait.type = ???


//implicit val printCode: cps.macroFlags.PrintCode.type = cps.macroFlags.PrintCode
//implicit val printTree = cps.macroFlags.PrintTree
//implicit val debugLevel = cps.macroFlags.DebugLevel(20)

def increment(cnt: AtomicInteger)(using CpsDirect[ComputationBound]): Int =
val cb: ComputationBound[Int] = Thunk( () => ComputationBound.pure(cnt.incrementAndGet()) )
await(cb)

class Log:
private var lines = Vector[String]()

def log(msg:String): Unit =
lines = lines :+ msg

def all: Vector[String] = lines


def cntDirect(counter: AtomicInteger): ComputationBound[Log] = async[ComputationBound]{
val log = new Log
val value = increment(counter)
if value % LOG_MOD == 0 then
log.log(s"counter value = ${value}")
if (value - 1 == LOG_TRESHOLD) then
// Conversion will not be appliyed for == . For this example we want automatic conversion, so -1
log.log("counter TRESHOLD")
log
}


@Test def cnt_direct(): Unit =
val counter = createCounter(9)
val c = cntDirect(counter)
val r: Try[Log] = c.run()
//println(s"cn_automatic_coloring, r=$r, r.get.all=${r.get.all} counter.get()=${counter.get()} ")
assert(r.isSuccess, "r should be success")
assert(r.get.all.size == 1, "r.get.all.size==1")
assert(counter.get() == 10, "counter.get() == 10")


Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,7 @@ case class CpsTopLevelContext(
val optTrySupport: Option[Tree],
val debugSettings: DebugSettings,
val pluginSettings: CpsPluginSettings,
val isBeforeInliner: Boolean,
val automaticColoring: Option[CpsAutomaticColoring],
val customValueDiscard: Boolean
val isBeforeInliner: Boolean
) {

def isAfterInliner = !isBeforeInliner
Expand Down
11 changes: 0 additions & 11 deletions compiler-plugin/src/main/scala/cps/plugin/CpsTransformHelper.scala
Original file line number Diff line number Diff line change
Expand Up @@ -251,17 +251,6 @@ object CpsTransformHelper {
findWrapperForMonad("cps.CpsTrySupport", monadType, span)
}

def findAutomaticColoringTag(monadType: Type, span: Span)(using ctx: Context): Option[Tree] = {
findWrapperForMonad("cps.automaticColoring.AutomaticColoringTag", monadType, span)
}

def findCpsMonadMemoization(monadType: Type, span: Span)(using ctx: Context): Option[Tree] = {
findWrapperForMonad("cps.CpsMonadMemoization", monadType, span)
}

def findCustomValueDiscardTag(span: Span)(using ctx: Context): Option[Tree] = {
findImplicitInstance(Symbols.requiredClassRef("cps.ValueDiscard.CustomTag"), span)
}


}
15 changes: 1 addition & 14 deletions compiler-plugin/src/main/scala/cps/plugin/PhaseCps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,6 @@ class PhaseCps(settings: CpsPluginSettings,
case _ => throw CpsTransformException(s"excepted that cpsMonadContext is ValDef, but we have ${cpsMonadContext.show}", asyncCallTree.srcPos)
val (tctx, monadValDef) = makeCpsTopLevelContext(contextParam, ddef.symbol, asyncCallTree.srcPos, DebugSettings.make(asyncCallTree), CpsTransformHelper.cpsMonadContextClassSymbol)
val ddefCtx = ctx.withOwner(ddef.symbol)
tctx.automaticColoring.foreach(_.analyzer.observe(ddef.rhs)(using ddefCtx))
val nRhsCps = RootTransform(ddef.rhs, ddef.symbol, 0)(using ddefCtx, tctx)
val nRhsTerm = wrapTopLevelCpsTree(nRhsCps)(using ddefCtx, tctx)
val nRhsType = nRhsTerm.tpe.widen
Expand Down Expand Up @@ -390,22 +389,10 @@ class PhaseCps(settings: CpsPluginSettings,
else if (runsAfter.contains(Inlining.name)) { false }
else
throw new CpsTransformException("plugins runsBefore/After Inlining not found", srcPos)
val automaticColoringTag = CpsTransformHelper.findAutomaticColoringTag(monadType, srcPos.span)
val automaticColoring = if (automaticColoringTag.isDefined) {
val memoization = CpsTransformHelper.findCpsMonadMemoization(monadType, srcPos.span)
if (memoization.isDefined) {
val analyzer = new AutomaticColoringAnalyzer()
Some(CpsAutomaticColoring(memoization.get,analyzer))
} else {
throw CpsTransformException(s"Can't find instance of cps.CpsMemoization for ${monadType.show}", srcPos)
}
} else None
val customValueDiscard = automaticColoring.isDefined || CpsTransformHelper.findCustomValueDiscardTag(srcPos.span).isDefined
val tc = CpsTopLevelContext(monadType, monadRef, cpsDirectOrSimpleContext,
optRuntimeAwait, optRuntimeAwaitProvider,
optThrowSupport, optTrySupport,
debugSettings, settings, isBeforeInliner,
automaticColoring, customValueDiscard)
debugSettings, settings, isBeforeInliner)
(tc, monadValDef)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,7 @@ object BlockTransform {
// imports here ?
s
case _ =>
val cpsE0 = RootTransform(e, owner, nesting+1)
val cpsE = maybeApplyCustomDiscard(cpsE0, owner, nesting)
val cpsE = RootTransform(e, owner, nesting+1)
val r = s.appendInBlock(cpsE)
r
}
Expand All @@ -75,88 +74,4 @@ object BlockTransform {
retval
}

// TODO: remove after elimination of automatic coloring
def maybeApplyCustomDiscard(cpsTree:CpsTree, owner:Symbol, nesting:Int)(using Context, CpsTopLevelContext): CpsTree = {
val tctx = summon[CpsTopLevelContext]
Log.trace(s"BlockTransform.maybeApplyCustomDiscard: customValueDiscard=${tctx.customValueDiscard}", nesting)
if (!tctx.customValueDiscard) then
cpsTree
else
if (cpsTree.originType != defn.UnitType && cpsTree.originType != defn.NothingType) then
val valueDiscardType = Symbols.requiredClassRef("cps.ValueDiscard").appliedTo(cpsTree.originType.widen)
CpsTransformHelper.findImplicitInstance(valueDiscardType, cpsTree.origin.span) match
case Some(discard) =>
applyImplicitDiscard(cpsTree, owner, discard)
case None =>
report.warning(s"custom discard is enablde, but no implicit instance for ${valueDiscardType.show} found", cpsTree.origin.srcPos)
cpsTree
else
cpsTree

}

def applyImplicitDiscard(cpsTree:CpsTree, owner:Symbol, inDiscard:Tree)(using Context, CpsTopLevelContext): CpsTree = {
import inlines.Inlines


def genDiscardApply(discard: Tree, arg: Tree): Tree = {
val discardApply = Select(discard, "apply".toTermName)
val call = discardApply.appliedTo(arg)
val retval = if (discardApply.symbol.flags.is(Flags.Inline)) then {
Inlines.inlineCall(call)
} else
call
retval
}

val discard = if (inDiscard.symbol.flags.is(Flags.Inline)) then
Inlines.inlineCall(inDiscard)
else
inDiscard

if (discard.tpe.baseType(Symbols.requiredClass("cps.AwaitValueDiscard")) != NoType) then
if !(cpsTree.originType <:< summon[CpsTopLevelContext].monadType.appliedTo(Types.WildcardType)) then
throw CpsTransformException(s"await discard is not applicable to ${cpsTree.originType.show}", cpsTree.origin.srcPos)
cpsTree.unpure match
case Some(stat) =>
CpsTree.impure(cpsTree.origin, cpsTree.owner, stat, AsyncKind.Sync)
case None =>
cpsTree.asyncKind match
case AsyncKind.Sync =>
throw CpsTransformException(s"impossible: sync tree with empty unpure: ${cpsTree}", cpsTree.origin.srcPos)
case AsyncKind.Async(internalKind) =>
if (internalKind != AsyncKind.Sync) then
throw CpsTransformException(s"impossible: async tree with non-sync internal kind: ${cpsTree}", cpsTree.origin.srcPos)
val untpdTree = untpd.Apply(
untpd.Select(untpd.TypedSplice(summon[CpsTopLevelContext].cpsMonadRef), "flatten".toTermName),
List(untpd.TypedSplice(cpsTree.transformed))
)
val typedTree = ctx.typer.typed(untpdTree, summon[CpsTopLevelContext].monadType.appliedTo(Types.WildcardType))
val fakeOrigin = Apply(Select(discard, "apply".toTermName),List(cpsTree.origin) )
CpsTree.impure(fakeOrigin, owner, typedTree, internalKind)
case AsyncKind.AsyncLambda(bodyKind) =>
throw CpsTransformException(s"discarede lambda expression: ${cpsTree}", cpsTree.origin.srcPos)
else
cpsTree.unpure match
case Some(stat) =>
val tree = genDiscardApply(discard, stat)
val fakeOrigin = Typed(cpsTree.origin, TypeTree(defn.UnitType)).withSpan(cpsTree.origin.span)
CpsTree.pure(fakeOrigin, cpsTree.owner, tree)
case None =>
cpsTree.asyncKind match
case AsyncKind.Sync =>
throw CpsTransformException(s"impossible: sync tree with empty unpure: ${cpsTree}", cpsTree.origin.srcPos)
case AsyncKind.Async(ik) =>
val toDiscardSym = Symbols.newSymbol(owner, "toDiscard".toTermName, Flags.Synthetic, cpsTree.originType)
val toDiscardRef = ref(toDiscardSym)
val toDiscardValDef = ValDef(toDiscardSym, EmptyTree)
val discardBody = genDiscardApply(discard, toDiscardRef)
val fakeOrigin0 = Apply(Select(discard, "apply".toTermName),List(cpsTree.origin) ).withSpan(cpsTree.origin.span)
val fakeOrigin1 = Apply(Select(discard, "apply".toTermName),List(toDiscardRef) ).withSpan(cpsTree.origin.span)
MapCpsTree(fakeOrigin0, owner, cpsTree,
MapCpsTreeArgument(Some(toDiscardValDef), CpsTree.pure(fakeOrigin1 , owner, discardBody)))
case AsyncKind.AsyncLambda(bodyKind) =>
throw CpsTransformException(s"discarede lambda expression: ${cpsTree}", cpsTree.origin.srcPos)
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -258,9 +258,7 @@ object InlinedTransform {
optTrySupport = CpsTransformHelper.findCpsTrySupport(gMonadType, inlinedTerm.span),
debugSettings = summon[CpsTopLevelContext].debugSettings,
pluginSettings = summon[CpsTopLevelContext].pluginSettings,
isBeforeInliner = summon[CpsTopLevelContext].isBeforeInliner,
automaticColoring = summon[CpsTopLevelContext].automaticColoring,
customValueDiscard = summon[CpsTopLevelContext].customValueDiscard
isBeforeInliner = summon[CpsTopLevelContext].isBeforeInliner
)
val List(ctx) = tss.head
val newContext = summon[Context].withOwner(inclusionLambdaSym)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,22 +18,7 @@ object ValDefTransform {
val tctx = summon[CpsTopLevelContext]
if (term.rhs.isEmpty) then
throw CpsTransformException(s"ValDef without right part: $term", term.srcPos)
val cpsRhs0 = RootTransform(term.rhs,term.symbol,nesting+1)
val cpsRhs = tctx.automaticColoring match
case Some(c) if (cpsRhs0.originType <:< tctx.monadType.appliedTo(Types.WildcardType)) =>
c.analyzer.usageRecords.get(term.symbol) match
case Some(record) =>
record.reportCases()
if (record.nInAwaits > 0 && record.nWithoutAwaits == 0) {
applyMemoization(cpsRhs0, owner, c.memoization, term)
} else if (record.nInAwaits >0 && record.nWithoutAwaits > 0) {
record.reportCases()
throw CpsTransformException(s"val ${term.name} used in both sync and async way with automatic coloring", term.srcPos)
} else {
cpsRhs0
}
case None => cpsRhs0
case _ => cpsRhs0
val cpsRhs = RootTransform(term.rhs,term.symbol,nesting+1)

cpsRhs.asyncKind match
case AsyncKind.Sync =>
Expand Down
2 changes: 2 additions & 0 deletions compiler-plugin/src/test/scala/cc/Test14Run.scala
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,8 @@ class Test14Run {
}




@Test
def testCBS2Dynamic(): Unit = {
val dirname = "testdata/set14runtests/m11"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ object ComputationBoundLoomUsage {

def useLoomFastImpl(using Quotes): Expr[Boolean] = {
val r = Expr.summon[cps.macros.flags.UseLoomAwait.type].isDefined &&
Expr.summon[CpsFastRuntimeAwait[ComputationBound]].isDefined
Expr.summon[CpsRuntimeAsyncAwait[ComputationBound]].isDefined
Expr(r)
}

Expand All @@ -23,7 +23,7 @@ object ComputationBoundLoomUsage {

def useLoomHybridImpl(using Quotes): Expr[Boolean] = {
val r = Expr.summon[cps.macros.flags.UseLoomAwait.type].isDefined &&
!Expr.summon[CpsFastRuntimeAwait[ComputationBound]].isDefined
!Expr.summon[CpsRuntimeAsyncAwait[ComputationBound]].isDefined
Expr(r)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,12 +118,5 @@ given PureEffectCpsMonad: CpsConcurrentEffectMonad[PureEffect] with CpsTryMonadI



given CpsMonadMemoization.Pure[PureEffect] with

def apply[T](ft:PureEffect[T]): PureEffect[PureEffect[T]] =
ft.memoize()


inline transparent given ValueDiscard[PureEffect[Unit]] = AwaitValueDiscard[PureEffect,Unit]


12 changes: 6 additions & 6 deletions compiler-plugin/testdata/set14runtests/m10/TestFizzBuzz.scala
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,17 @@ import scala.concurrent.*
import scala.concurrent.duration.*

import cps.*
import cps.automaticColoring.given
import cps.plugin.annotation.CpsDebugLevel
import scala.language.implicitConversions


import cps.testconfig.given

import scala.annotation.experimental
import scala.concurrent.ExecutionContext.Implicits.global

// This test will be deleted after disabling of automatic coloring.
@CpsDebugLevel(20)
@experimental
class TestFizzBuzz:


Expand All @@ -32,18 +32,18 @@ class TestFizzBuzz:
@Test def testFizBuzz =
val c = async[PureEffect] {
val logger = PEToyLogger.make()
val counter = PEIntRef.make(-1)
val counter = await(PEIntRef.make(-1))
//println(s"crrate counter, value=${counter.value}")
while {
val v = counter.increment()
logger.log(await(v).toString)
val v = await(counter.increment())
logger.log(v.toString)
if (v % 3 == 0) then
logger.log("Fizz")
if (v % 5 == 0) then
logger.log("Buzz")
v < 10
} do ()
await(logger.all())
logger.all()
}
println(s"PE:fizbuzz, c=${c} ")
val future = c.unsafeRunFuture().map{ log =>
Expand Down
18 changes: 11 additions & 7 deletions compiler-plugin/testdata/set14runtests/m10/ToyLogger.scala
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
package cps.pe

import scala.annotation.experimental

import cps.*
import cps.testconfig.given

class ToyLogger{
Expand All @@ -12,25 +15,26 @@ class ToyLogger{

}


@experimental
class PEToyLogger{

private val logger = new ToyLogger()

def log(msg:String): PureEffect[Unit] =
PureEffect.delay(logger.log(msg))
def log(msg:String)(using CpsDirect[PureEffect]): Unit =
logger.log(msg)

def all(): PureEffect[Vector[String]] =
PureEffect.delay(logger.lines)
def all()(using CpsDirect[PureEffect]): Vector[String] =
logger.lines

def __all(): Vector[String] =
logger.lines

}

object PEToyLogger {

def make(): PureEffect[PEToyLogger] =
PureEffect.delay(new PEToyLogger)
def make()(using CpsDirect[PureEffect]): PEToyLogger =
new PEToyLogger

}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@ import org.junit.Assert.*

import scala.util.*

import cps.automaticColoring
import scala.language.implicitConversions
import scala.language.dynamics

import cps.testconfig.given
Expand Down
Loading

0 comments on commit 456179c

Please sign in to comment.