Skip to content

Commit

Permalink
Be more careful computing underlying types of reach capabilities
Browse files Browse the repository at this point in the history
We can use the dcs only if there are no type variables.
  • Loading branch information
odersky committed Nov 5, 2024
1 parent ce3c01d commit ac06cb5
Show file tree
Hide file tree
Showing 11 changed files with 165 additions and 46 deletions.
23 changes: 18 additions & 5 deletions compiler/src/dotty/tools/dotc/cc/CaptureOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -220,8 +220,8 @@ extension (tp: Type)
* a singleton capability `x` or a reach capability `x*`, the deep capture
* set can be narrowed to`{x*}`.
*/
def deepCaptureSet(using Context): CaptureSet =
val dcs = CaptureSet.ofTypeDeeply(tp.widen.stripCapturing)
def deepCaptureSet(includeTypevars: Boolean)(using Context): CaptureSet =
val dcs = CaptureSet.ofTypeDeeply(tp.widen.stripCapturing, includeTypevars)
if dcs.isAlwaysEmpty then tp.captureSet
else tp match
case tp @ ReachCapability(_) =>
Expand All @@ -231,6 +231,9 @@ extension (tp: Type)
case _ =>
tp.captureSet ++ dcs

def deepCaptureSet(using Context): CaptureSet =
deepCaptureSet(includeTypevars = false)

/** A type capturing `ref` */
def capturing(ref: CaptureRef)(using Context): Type =
if tp.captureSet.accountsFor(ref) then tp
Expand Down Expand Up @@ -593,16 +596,26 @@ extension (sym: Symbol)
def isRefiningParamAccessor(using Context): Boolean =
sym.is(ParamAccessor)
&& {
val param = sym.owner.primaryConstructor.paramSymss
.nestedFind(_.name == sym.name)
.getOrElse(NoSymbol)
val param = sym.owner.primaryConstructor.paramNamed(sym.name)
!param.hasAnnotation(defn.ConstructorOnlyAnnot)
&& !param.hasAnnotation(defn.UntrackedCapturesAnnot)
}

def hasTrackedParts(using Context): Boolean =
!CaptureSet.ofTypeDeeply(sym.info).isAlwaysEmpty

/** `sym` is annotated @use or it is a type parameter with a matching
* @use-annotated term parameter that contains `sym` in its deep capture set.
*/
def isUseParam(using Context): Boolean =
sym.hasAnnotation(defn.UseAnnot)
|| sym.is(TypeParam)
&& sym.owner.rawParamss.nestedExists: param =>
param.is(TermParam) && param.hasAnnotation(defn.UseAnnot)
&& param.info.deepCaptureSet.elems.exists:
case c: TypeRef => c.symbol == sym
case _ => false

extension (tp: AnnotatedType)
/** Is this a boxed capturing type? */
def isBoxed(using Context): Boolean = tp.annot match
Expand Down
11 changes: 7 additions & 4 deletions compiler/src/dotty/tools/dotc/cc/CaptureSet.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1064,8 +1064,9 @@ object CaptureSet:
case ref: (TermRef | TermParamRef) if ref.isMaxCapability =>
if ref.isTrackableRef then ref.singletonCaptureSet
else CaptureSet.universal
case ReachCapability(ref1) => ref1.widen.deepCaptureSet
.showing(i"Deep capture set of $ref: ${ref1.widen} = $result", capt)
case ReachCapability(ref1) =>
ref1.widen.deepCaptureSet(includeTypevars = true)
.showing(i"Deep capture set of $ref: ${ref1.widen} = ${result}", capt)
case _ => ofType(ref.underlying, followResult = true)

/** Capture set of a type */
Expand Down Expand Up @@ -1120,7 +1121,7 @@ object CaptureSet:
* arguments. This have to be included to be conservative in dcs but must be
* excluded in narrowCaps.
*/
def ofTypeDeeply(tp: Type)(using Context): CaptureSet =
def ofTypeDeeply(tp: Type, includeTypevars: Boolean = false)(using Context): CaptureSet =
val collect = new TypeAccumulator[CaptureSet]:
val seen = util.HashSet[Symbol]()
def apply(cs: CaptureSet, t: Type) =
Expand All @@ -1132,7 +1133,9 @@ object CaptureSet:
this(cs, parent)
case t: TypeRef if t.symbol.isAbstractOrParamType && !seen.contains(t.symbol) =>
seen += t.symbol
this(cs, t.info.bounds.hi)
val upper = t.info.bounds.hi
if includeTypevars && upper.isExactlyAny then CaptureSet.universal
else this(cs, t.info.bounds.hi)
case t @ FunctionOrMethod(args, res @ Existential(_, _))
if args.forall(_.isAlwaysPure) =>
this(cs, Existential.toCap(res))
Expand Down
74 changes: 40 additions & 34 deletions compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala
Original file line number Diff line number Diff line change
Expand Up @@ -182,34 +182,6 @@ object CheckCaptures:
if ccConfig.useSealed then check.traverse(tp)
end disallowRootCapabilitiesIn

/** Under the sealed policy, disallow the root capability in type arguments.
* Type arguments come either from a TypeApply node or from an AppliedType
* which represents a trait parent in a template.
* @param fn the type application, of type TypeApply or TypeTree
* @param sym the constructor symbol (could be a method or a val or a class)
* @param args the type arguments
*/
private def disallowCapInTypeArgs(fn: Tree, sym: Symbol, args: List[Tree], thisPhase: Phase)(using Context): Unit =
def isExempt = sym.isTypeTestOrCast || sym == defn.Compiletime_erasedValue
if ccConfig.useSealed && !isExempt then
val paramNames = atPhase(thisPhase.prev):
fn.tpe.widenDealias match
case tl: TypeLambda => tl.paramNames
case ref: AppliedType if ref.typeSymbol.isClass => ref.typeSymbol.typeParams.map(_.name)
case t =>
println(i"parent type: $t")
args.map(_ => EmptyTypeName)
for case (arg: TypeTree, pname) <- args.lazyZip(paramNames) do
def where = if sym.exists then i" in an argument of $sym" else ""
val (addendum, pos) =
if arg.isInferred
then ("\nThis is often caused by a local capability$where\nleaking as part of its result.", fn.srcPos)
else if arg.span.exists then ("", arg.srcPos)
else ("", fn.srcPos)
disallowRootCapabilitiesIn(arg.knownType, NoSymbol,
i"Type variable $pname of $sym", "be instantiated to", addendum, pos)
end disallowCapInTypeArgs

/** If we are not under the sealed policy, and a tree is an application that unboxes
* its result or is a try, check that the tree's type does not have covariant universal
* capabilities.
Expand Down Expand Up @@ -404,14 +376,14 @@ class CheckCaptures extends Recheck, SymTransformer:
if lastEnv != null && env.nestedClosure.exists && env.nestedClosure == lastEnv.owner then
() // access is from a nested closure, so it's OK
else c.pathRoot match
case ref: NamedType if !ref.symbol.hasAnnotation(defn.UseAnnot) =>
case ref: NamedType if !ref.symbol.isUseParam =>
val what = if ref.isType then "Capture set parameter" else "Local reach capability"
report.error(
em"""$what $c leaks into capture scope of ${env.ownerString}.
|To allow this, the ${ref.symbol} should be declared with a @use annotation""", pos)
case _ =>

def recur(cs: CaptureSet, env: Env, lastEnv: Env | Null)(using Context): Unit =
def recur(cs: CaptureSet, env: Env, lastEnv: Env | Null): Unit =
if env.isOpen && !env.owner.isStaticOwner && !cs.isAlwaysEmpty then
// Only captured references that are visible from the environment
// should be included.
Expand Down Expand Up @@ -475,6 +447,40 @@ class CheckCaptures extends Recheck, SymTransformer:
case _ =>
if sym.exists && curEnv.isOpen then markFree(capturedVars(sym), pos)

/** Under the sealed policy, disallow the root capability in type arguments.
* Type arguments come either from a TypeApply node or from an AppliedType
* which represents a trait parent in a template. Also, if a corresponding
* formal type parameter is declared or implied @use, charge the deep capture
* set of the argument to the environent.
* @param fn the type application, of type TypeApply or TypeTree
* @param sym the constructor symbol (could be a method or a val or a class)
* @param args the type arguments
*/
def disallowCapInTypeArgs(fn: Tree, sym: Symbol, args: List[Tree])(using Context): Unit =
def isExempt = sym.isTypeTestOrCast || sym == defn.Compiletime_erasedValue
if ccConfig.useSealed && !isExempt then
val paramNames = atPhase(thisPhase.prev):
fn.tpe.widenDealias match
case tl: TypeLambda => tl.paramNames
case ref: AppliedType if ref.typeSymbol.isClass => ref.typeSymbol.typeParams.map(_.name)
case t =>
println(i"parent type: $t")
args.map(_ => EmptyTypeName)

for case (arg: TypeTree, pname) <- args.lazyZip(paramNames) do
def where = if sym.exists then i" in an argument of $sym" else ""
val (addendum, pos) =
if arg.isInferred
then ("\nThis is often caused by a local capability$where\nleaking as part of its result.", fn.srcPos)
else if arg.span.exists then ("", arg.srcPos)
else ("", fn.srcPos)
disallowRootCapabilitiesIn(arg.knownType, NoSymbol,
i"Type variable $pname of $sym", "be instantiated to", addendum, pos)

val param = fn.symbol.paramNamed(pname)
if param.isUseParam then markFree(arg.knownType.deepCaptureSet, pos)
end disallowCapInTypeArgs

override def recheckIdent(tree: Ident, pt: Type)(using Context): Type =
val sym = tree.symbol
if sym.is(Method) then
Expand Down Expand Up @@ -553,8 +559,8 @@ class CheckCaptures extends Recheck, SymTransformer:
*/
override def prepareFunction(funtpe: MethodType, meth: Symbol)(using Context): MethodType =
val paramInfosWithUses = funtpe.paramInfos.zipWithConserve(funtpe.paramNames): (formal, pname) =>
val paramOpt = meth.rawParamss.nestedFind(_.name == pname)
paramOpt.flatMap(_.getAnnotation(defn.UseAnnot)) match
val param = meth.paramNamed(pname)
param.getAnnotation(defn.UseAnnot) match
case Some(ann) => AnnotatedType(formal, ann)
case _ => formal
funtpe.derivedLambdaType(paramInfos = paramInfosWithUses)
Expand Down Expand Up @@ -720,7 +726,7 @@ class CheckCaptures extends Recheck, SymTransformer:
val meth = tree.fun match
case fun @ Select(qual, nme.apply) => qual.symbol.orElse(fun.symbol)
case fun => fun.symbol
disallowCapInTypeArgs(tree.fun, meth, tree.args, thisPhase)
disallowCapInTypeArgs(tree.fun, meth, tree.args)
val res = Existential.toCap(super.recheckTypeApply(tree, pt))
includeCallCaptures(tree.symbol, res, tree.srcPos)
checkContains(tree)
Expand Down Expand Up @@ -951,7 +957,7 @@ class CheckCaptures extends Recheck, SymTransformer:
for case tpt: TypeTree <- impl.parents do
tpt.tpe match
case AppliedType(fn, args) =>
disallowCapInTypeArgs(tpt, fn.typeSymbol, args.map(TypeTree(_)), thisPhase)
disallowCapInTypeArgs(tpt, fn.typeSymbol, args.map(TypeTree(_)))
case _ =>
inNestedLevelUnless(cls.is(Module)):
super.recheckClassDef(tree, impl, cls)
Expand Down
3 changes: 3 additions & 0 deletions compiler/src/dotty/tools/dotc/core/SymUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,9 @@ class SymUtils:
self.owner.info.decl(fieldName).suchThat(!_.is(Method)).symbol
}

def paramNamed(name: Name)(using Context): Symbol =
self.rawParamss.nestedFind(_.name == name).getOrElse(NoSymbol)

/** Is this symbol a constant expression final val?
*
* This is the case if all of the following are true:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,4 @@ extension [T](@use fs: Seq[Future[T]^])
val collector//: Collector[T]{val futures: Seq[Future[T]^{fs*}]}
= Collector(fs)
// val ch = collector.results // also errors
val fut: Future[T]^{fs*} = collector.results.read().get // found ...^{caps.cap}
val fut: Future[T]^{fs*} = collector.results.read().get // error
15 changes: 15 additions & 0 deletions tests/neg-custom-args/captures/gears-problem.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/gears-problem.scala:19:62 --------------------------------
19 | val fut: Future[T]^{fs*} = collector.results.read().right.get // error
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
| Found: Future[T]^{collector.futures*}
| Required: Future[T]^{fs*}
|
| longer explanation available when compiling with `-explain`
-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/gears-problem.scala:24:34 --------------------------------
24 | val fut2: Future[T]^{fs*} = r.get // error
| ^^^^^
| Found: Future[box T^?]^{collector.futures*}
| Required: Future[T]^{fs*}
|
| longer explanation available when compiling with `-explain`
there were 4 deprecation warnings; re-run with -deprecation for details
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@ extension [T](@use fs: Seq[Future[T]^])
val collector: Collector[T]{val futures: Seq[Future[T]^{fs*}]}
= Collector(fs)
// val ch = collector.results // also errors
val fut: Future[T]^{fs*} = collector.results.read().right.get // found ...^{caps.cap}
val fut: Future[T]^{fs*} = collector.results.read().right.get // error

val ch = collector.results
val item = ch.read()
val r = item.right
val fut2: Future[T]^{fs*} = r.get
val fut2: Future[T]^{fs*} = r.get // error
15 changes: 15 additions & 0 deletions tests/neg-custom-args/captures/unsound-reach-7.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import language.experimental.captureChecking
import caps.{cap, use}

trait IO
trait Async

def main(io: IO^, async: Async^) =
def bad[X](ops: List[(X, () ->{io} Unit)])(f: () ->{ops*} Unit): () ->{io} Unit = f // error
def runOps(@use ops: List[(() => Unit, () => Unit)]): () ->{ops*} Unit =
() => ops.foreach((f1, f2) => { f1(); f2() })
def delayOps(@use ops: List[(() ->{async} Unit, () ->{io} Unit)]): () ->{io} Unit =
val runner: () ->{ops*} Unit = runOps(ops)
val badRunner: () ->{io} Unit = bad[() ->{async} Unit](ops)(runner)
// it uses both async and io, but we losed track of async.
badRunner
19 changes: 19 additions & 0 deletions tests/neg-custom-args/captures/use-capset.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
-- Error: tests/neg-custom-args/captures/use-capset.scala:7:50 ---------------------------------------------------------
7 |private def g[C^] = (xs: List[Object^{C^}]) => xs.head // error
| ^^^^^^^
| Capture set parameter C leaks into capture scope of method g.
| To allow this, the type C should be declared with a @use annotation
-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/use-capset.scala:13:22 -----------------------------------
13 | val _: () -> Unit = h // error: should be ->{io}
| ^
| Found: (h : () ->{io} Unit)
| Required: () -> Unit
|
| longer explanation available when compiling with `-explain`
-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/use-capset.scala:15:50 -----------------------------------
15 | val _: () -> List[Object^{io}] -> Object^{io} = h2 // error, should be ->{io}
| ^^
| Found: () ->? (x$0: List[box Object^{io}]^{}) ->{io} (ex$13: caps.Exists) -> Object^{io}
| Required: () -> List[box Object^{io}] -> Object^{io}
|
| longer explanation available when compiling with `-explain`
16 changes: 16 additions & 0 deletions tests/neg-custom-args/captures/use-capset.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import caps.{use, CapSet}



def f[C^](@use xs: List[Object^{C^}]): Unit = ???

private def g[C^] = (xs: List[Object^{C^}]) => xs.head // error

private def g2[@use C^] = (xs: List[Object^{C^}]) => xs.head // ok

def test(io: Object^)(@use xs: List[Object^{io}]): Unit =
val h = () => f(xs)
val _: () -> Unit = h // error: should be ->{io}
val h2 = () => g[CapSet^{io}]
val _: () -> List[Object^{io}] -> Object^{io} = h2 // error, should be ->{io}

29 changes: 29 additions & 0 deletions tests/pos-custom-args/captures/gears-problem-poly.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import language.experimental.captureChecking
import caps.{use, CapSet}

trait Future[+T]:
def await: T

trait Channel[+T]:
def read(): Ok[T]

class Collector[T, C^](val futures: Seq[Future[T]^{C^}]):
val results: Channel[Future[T]^{C^}] = ???
end Collector

class Result[+T, +E]:
def get: T = ???

case class Err[+E](e: E) extends Result[Nothing, E]
case class Ok[+T](x: T) extends Result[T, Nothing]

extension [T, C^](@use fs: Seq[Future[T]^{C^}])
def awaitAllPoly =
val collector = Collector(fs)
val fut: Future[T]^{C^} = collector.results.read().get

extension [T](@use fs: Seq[Future[T]^])
def awaitAll = fs.awaitAllPoly

def awaitExplicit[T](@use fs: Seq[Future[T]^]): Unit =
awaitAllPoly[T, CapSet^{fs*}](fs)

0 comments on commit ac06cb5

Please sign in to comment.