Skip to content

Commit

Permalink
Fix #21619: Refactor NotNullInfo to record every reference which is r…
Browse files Browse the repository at this point in the history
…etracted once. (#21624)

This PR improves the flow typing for returning and exceptions.

The `NotNullInfo` is defined as following now:

```scala
case class NotNullInfo(asserted: Set[TermRef] | Null, retracted: Set[TermRef]):
```

* `retracted` contains variable references that are ever assigned to
null;
* if `asserted` is not `null`, it contains `val` or `var` references
that are known to be not null, after the tree finishes executing
normally (non-exceptionally);
* if `asserted` is `null`, the tree is know to terminate, by throwing,
returning, or calling a function with `Nothing` type. Hence, it acts
like a universal set.

`alt` is defined as `<a1,r1>.alt(<a2,r2>) = <a1 intersect a2, r1 union
r2>`.

The difficult part is the `try ... catch ... finally ...`. We don't know
at which point an exception is thrown in the body, and the catch cases
may be not exhaustive, we have to collect any reference that is once
retracted.

Fix #21619
  • Loading branch information
noti0na1 authored Dec 10, 2024
2 parents e6b4222 + 200c038 commit ee0dd7a
Show file tree
Hide file tree
Showing 9 changed files with 257 additions and 78 deletions.
6 changes: 3 additions & 3 deletions compiler/src/dotty/tools/dotc/core/Contexts.scala
Original file line number Diff line number Diff line change
Expand Up @@ -777,13 +777,13 @@ object Contexts {

extension (c: Context)
def addNotNullInfo(info: NotNullInfo) =
c.withNotNullInfos(c.notNullInfos.extendWith(info))
if c.explicitNulls then c.withNotNullInfos(c.notNullInfos.extendWith(info)) else c

def addNotNullRefs(refs: Set[TermRef]) =
c.addNotNullInfo(NotNullInfo(refs, Set()))
if c.explicitNulls then c.addNotNullInfo(NotNullInfo(refs, Set())) else c

def withNotNullInfos(infos: List[NotNullInfo]): Context =
if c.notNullInfos eq infos then c else c.fresh.setNotNullInfos(infos)
if !c.explicitNulls || (c.notNullInfos eq infos) then c else c.fresh.setNotNullInfos(infos)

def relaxedOverrideContext: Context =
c.withModeBits(c.mode &~ Mode.SafeNulls | Mode.RelaxedOverriding)
Expand Down
7 changes: 6 additions & 1 deletion compiler/src/dotty/tools/dotc/typer/Applications.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1134,7 +1134,7 @@ trait Applications extends Compatibility {
case _ => ()
else ()

fun1.tpe match {
val result = fun1.tpe match {
case err: ErrorType => cpy.Apply(tree)(fun1, proto.typedArgs()).withType(err)
case TryDynamicCallType =>
val isInsertedApply = fun1 match {
Expand Down Expand Up @@ -1208,6 +1208,11 @@ trait Applications extends Compatibility {
else tryWithImplicitOnQualifier(fun1, proto).getOrElse(fail))
}
}

if result.tpe.isNothingType then
val nnInfo = result.notNullInfo
result.withNotNullInfo(nnInfo.terminatedInfo)
else result
}

/** Convert expression like
Expand Down
116 changes: 76 additions & 40 deletions compiler/src/dotty/tools/dotc/typer/Nullables.scala
Original file line number Diff line number Diff line change
Expand Up @@ -52,34 +52,46 @@ object Nullables:
val hiTree = if(hiTpe eq hi.typeOpt) hi else TypeTree(hiTpe)
TypeBoundsTree(lo, hiTree, alias)

/** A set of val or var references that are known to be not null, plus a set of
* variable references that are not known (anymore) to be not null
/** A set of val or var references that are known to be not null
* after the tree finishes executing normally (non-exceptionally),
* plus a set of variable references that are ever assigned to null,
* and may therefore be null if execution of the tree is interrupted
* by an exception.
*/
case class NotNullInfo(asserted: Set[TermRef], retracted: Set[TermRef]):
assert((asserted & retracted).isEmpty)

case class NotNullInfo(asserted: Set[TermRef] | Null, retracted: Set[TermRef]):
def isEmpty = this eq NotNullInfo.empty

def retractedInfo = NotNullInfo(Set(), retracted)

def terminatedInfo = NotNullInfo(null, retracted)

/** The sequential combination with another not-null info */
def seq(that: NotNullInfo): NotNullInfo =
if this.isEmpty then that
else if that.isEmpty then this
else NotNullInfo(
this.asserted.union(that.asserted).diff(that.retracted),
this.retracted.union(that.retracted).diff(that.asserted))
else
val newAsserted =
if this.asserted == null || that.asserted == null then null
else this.asserted.diff(that.retracted).union(that.asserted)
val newRetracted = this.retracted.union(that.retracted)
NotNullInfo(newAsserted, newRetracted)

/** The alternative path combination with another not-null info. Used to merge
* the nullability info of the two branches of an if.
* the nullability info of the branches of an if or match.
*/
def alt(that: NotNullInfo): NotNullInfo =
NotNullInfo(this.asserted.intersect(that.asserted), this.retracted.union(that.retracted))
val newAsserted =
if this.asserted == null then that.asserted
else if that.asserted == null then this.asserted
else this.asserted.intersect(that.asserted)
val newRetracted = this.retracted.union(that.retracted)
NotNullInfo(newAsserted, newRetracted)
end NotNullInfo

object NotNullInfo:
val empty = new NotNullInfo(Set(), Set())
def apply(asserted: Set[TermRef], retracted: Set[TermRef]): NotNullInfo =
if asserted.isEmpty && retracted.isEmpty then empty
def apply(asserted: Set[TermRef] | Null, retracted: Set[TermRef]): NotNullInfo =
if asserted != null && asserted.isEmpty && retracted.isEmpty then empty
else new NotNullInfo(asserted, retracted)
end NotNullInfo

Expand Down Expand Up @@ -223,7 +235,7 @@ object Nullables:
*/
@tailrec def impliesNotNull(ref: TermRef): Boolean = infos match
case info :: infos1 =>
if info.asserted.contains(ref) then true
if info.asserted == null || info.asserted.contains(ref) then true
else if info.retracted.contains(ref) then false
else infos1.impliesNotNull(ref)
case _ =>
Expand All @@ -233,16 +245,15 @@ object Nullables:
* or retractions in `info` supersede infos in existing entries of `infos`.
*/
def extendWith(info: NotNullInfo) =
if info.isEmpty
|| info.asserted.forall(infos.impliesNotNull(_))
&& !info.retracted.exists(infos.impliesNotNull(_))
then infos
if info.isEmpty then infos
else info :: infos

/** Retract all references to mutable variables */
def retractMutables(using Context) =
val mutables = infos.foldLeft(Set[TermRef]())((ms, info) =>
ms.union(info.asserted.filter(_.symbol.is(Mutable))))
val mutables = infos.foldLeft(Set[TermRef]()):
(ms, info) => ms.union(
if info.asserted == null then Set.empty
else info.asserted.filter(_.symbol.is(Mutable)))
infos.extendWith(NotNullInfo(Set(), mutables))

end extension
Expand Down Expand Up @@ -304,15 +315,35 @@ object Nullables:
extension (tree: Tree)

/* The `tree` with added nullability attachment */
def withNotNullInfo(info: NotNullInfo): tree.type =
if !info.isEmpty then tree.putAttachment(NNInfo, info)
def withNotNullInfo(info: NotNullInfo)(using Context): tree.type =
if ctx.explicitNulls && !info.isEmpty then tree.putAttachment(NNInfo, info)
tree

/* Collect the nullability info from parts of `tree` */
def collectNotNullInfo(using Context): NotNullInfo = tree match
case Typed(expr, _) =>
expr.notNullInfo
case Apply(fn, args) =>
val argsInfo = args.map(_.notNullInfo)
val fnInfo = fn.notNullInfo
argsInfo.foldLeft(fnInfo)(_ seq _)
case TypeApply(fn, _) =>
fn.notNullInfo
case _ =>
// Other cases are handled specially in typer.
NotNullInfo.empty

/* The nullability info of `tree` */
def notNullInfo(using Context): NotNullInfo =
stripInlined(tree).getAttachment(NNInfo) match
case Some(info) if !ctx.erasedTypes => info
case _ => NotNullInfo.empty
if !ctx.explicitNulls then NotNullInfo.empty
else
val tree1 = stripInlined(tree)
tree1.getAttachment(NNInfo) match
case Some(info) if !ctx.erasedTypes => info
case _ =>
val nnInfo = tree1.collectNotNullInfo
tree1.withNotNullInfo(nnInfo)
nnInfo

/* The nullability info of `tree`, assuming it is a condition that evaluates to `c` */
def notNullInfoIf(c: Boolean)(using Context): NotNullInfo =
Expand Down Expand Up @@ -393,21 +424,23 @@ object Nullables:
end extension

extension (tree: Assign)
def computeAssignNullable()(using Context): tree.type = tree.lhs match
case TrackedRef(ref) =>
val rhstp = tree.rhs.typeOpt
if ctx.explicitNulls && ref.isNullableUnion then
if rhstp.isNullType || rhstp.isNullableUnion then
// If the type of rhs is nullable (`T|Null` or `Null`), then the nullability of the
// lhs variable is no longer trackable. We don't need to check whether the type `T`
// is correct here, as typer will check it.
tree.withNotNullInfo(NotNullInfo(Set(), Set(ref)))
else
// If the initial type is nullable and the assigned value is non-null,
// we add it to the NotNull.
tree.withNotNullInfo(NotNullInfo(Set(ref), Set()))
else tree
case _ => tree
def computeAssignNullable()(using Context): tree.type =
var nnInfo = tree.rhs.notNullInfo
tree.lhs match
case TrackedRef(ref) if ctx.explicitNulls && ref.isNullableUnion =>
nnInfo = nnInfo.seq:
val rhstp = tree.rhs.typeOpt
if rhstp.isNullType || rhstp.isNullableUnion then
// If the type of rhs is nullable (`T|Null` or `Null`), then the nullability of the
// lhs variable is no longer trackable. We don't need to check whether the type `T`
// is correct here, as typer will check it.
NotNullInfo(Set(), Set(ref))
else
// If the initial type is nullable and the assigned value is non-null,
// we add it to the NotNull.
NotNullInfo(Set(ref), Set())
case _ =>
tree.withNotNullInfo(nnInfo)
end extension

private val analyzedOps = Set(nme.EQ, nme.NE, nme.eq, nme.ne, nme.ZAND, nme.ZOR, nme.UNARY_!)
Expand Down Expand Up @@ -515,7 +548,10 @@ object Nullables:
&& assignmentSpans.getOrElse(sym.span.start, Nil).exists(whileSpan.contains(_))
&& ctx.notNullInfos.impliesNotNull(ref)

val retractedVars = ctx.notNullInfos.flatMap(_.asserted.filter(isRetracted)).toSet
val retractedVars = ctx.notNullInfos.flatMap(info =>
if info.asserted == null then Set.empty
else info.asserted.filter(isRetracted)
).toSet
ctx.addNotNullInfo(NotNullInfo(Set(), retractedVars))
end whileContext

Expand Down
61 changes: 36 additions & 25 deletions compiler/src/dotty/tools/dotc/typer/Typer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1201,7 +1201,6 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
untpd.unsplice(tree.expr).putAttachment(AscribedToUnit, ())
typed(tree.expr, underlyingTreeTpe.tpe.widenSkolem)
assignType(cpy.Typed(tree)(expr1, tpt), underlyingTreeTpe)
.withNotNullInfo(expr1.notNullInfo)
}

if (untpd.isWildcardStarArg(tree)) {
Expand Down Expand Up @@ -1551,11 +1550,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer

def thenPathInfo = cond1.notNullInfoIf(true).seq(result.thenp.notNullInfo)
def elsePathInfo = cond1.notNullInfoIf(false).seq(result.elsep.notNullInfo)
result.withNotNullInfo(
if result.thenp.tpe.isRef(defn.NothingClass) then elsePathInfo
else if result.elsep.tpe.isRef(defn.NothingClass) then thenPathInfo
else thenPathInfo.alt(elsePathInfo)
)
result.withNotNullInfo(thenPathInfo.alt(elsePathInfo))
end typedIf

/** Decompose function prototype into a list of parameter prototypes and a result
Expand Down Expand Up @@ -2139,20 +2134,25 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
case1
}
.asInstanceOf[List[CaseDef]]
var nni = sel.notNullInfo
if cases1.nonEmpty then nni = nni.seq(cases1.map(_.notNullInfo).reduce(_.alt(_)))
assignType(cpy.Match(tree)(sel, cases1), sel, cases1).cast(pt).withNotNullInfo(nni)
assignType(cpy.Match(tree)(sel, cases1), sel, cases1).cast(pt)
.withNotNullInfo(notNullInfoFromCases(sel.notNullInfo, cases1))
}

// Overridden in InlineTyper for inline matches
def typedMatchFinish(tree: untpd.Match, sel: Tree, wideSelType: Type, cases: List[untpd.CaseDef], pt: Type)(using Context): Tree = {
val cases1 = harmonic(harmonize, pt)(typedCases(cases, sel, wideSelType, pt.dropIfProto))
.asInstanceOf[List[CaseDef]]
var nni = sel.notNullInfo
if cases1.nonEmpty then nni = nni.seq(cases1.map(_.notNullInfo).reduce(_.alt(_)))
assignType(cpy.Match(tree)(sel, cases1), sel, cases1).withNotNullInfo(nni)
assignType(cpy.Match(tree)(sel, cases1), sel, cases1)
.withNotNullInfo(notNullInfoFromCases(sel.notNullInfo, cases1))
}

private def notNullInfoFromCases(initInfo: NotNullInfo, cases: List[CaseDef])(using Context): NotNullInfo =
if cases.isEmpty then
// Empty cases is not allowed for match tree in the source code,
// but it can be generated by inlining: `tests/pos/i19198.scala`.
initInfo
else cases.map(_.notNullInfo).reduce(_.alt(_))

def typedCases(cases: List[untpd.CaseDef], sel: Tree, wideSelType0: Type, pt: Type)(using Context): List[CaseDef] =
var caseCtx = ctx
var wideSelType = wideSelType0
Expand Down Expand Up @@ -2241,7 +2241,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
def typedLabeled(tree: untpd.Labeled)(using Context): Labeled = {
val bind1 = typedBind(tree.bind, WildcardType).asInstanceOf[Bind]
val expr1 = typed(tree.expr, bind1.symbol.info)
assignType(cpy.Labeled(tree)(bind1, expr1))
assignType(cpy.Labeled(tree)(bind1, expr1)).withNotNullInfo(expr1.notNullInfo.retractedInfo)
}

/** Type a case of a type match */
Expand Down Expand Up @@ -2291,7 +2291,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
// Hence no adaptation is possible, and we assume WildcardType as prototype.
(from, proto)
val expr1 = typedExpr(tree.expr orElse untpd.syntheticUnitLiteral.withSpan(tree.span), proto)
assignType(cpy.Return(tree)(expr1, from))
assignType(cpy.Return(tree)(expr1, from)).withNotNullInfo(expr1.notNullInfo.terminatedInfo)
end typedReturn

def typedWhileDo(tree: untpd.WhileDo)(using Context): Tree =
Expand Down Expand Up @@ -2332,7 +2332,8 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
val capabilityProof = caughtExceptions.reduce(OrType(_, _, true))
untpd.Block(makeCanThrow(capabilityProof), expr)

def typedTry(tree: untpd.Try, pt: Type)(using Context): Try = {
def typedTry(tree: untpd.Try, pt: Type)(using Context): Try =
var nnInfo = NotNullInfo.empty
val expr2 :: cases2x = harmonic(harmonize, pt) {
// We want to type check tree.expr first to comput NotNullInfo, but `addCanThrowCapabilities`
// uses the types of patterns in `tree.cases` to determine the capabilities.
Expand All @@ -2344,18 +2345,26 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
val casesEmptyBody1 = tree.cases.mapconserve(cpy.CaseDef(_)(body = EmptyTree))
val casesEmptyBody2 = typedCases(casesEmptyBody1, EmptyTree, defn.ThrowableType, WildcardType)
val expr1 = typed(addCanThrowCapabilities(tree.expr, casesEmptyBody2), pt.dropIfProto)
val casesCtx = ctx.addNotNullInfo(expr1.notNullInfo.retractedInfo)

// Since we don't know at which point the the exception is thrown in the body,
// we have to collect any reference that is once retracted.
nnInfo = expr1.notNullInfo.retractedInfo

val casesCtx = ctx.addNotNullInfo(nnInfo)
val cases1 = typedCases(tree.cases, EmptyTree, defn.ThrowableType, pt.dropIfProto)(using casesCtx)
expr1 :: cases1
}: @unchecked
val cases2 = cases2x.asInstanceOf[List[CaseDef]]

var nni = expr2.notNullInfo.retractedInfo
if cases2.nonEmpty then nni = nni.seq(cases2.map(_.notNullInfo.retractedInfo).reduce(_.alt(_)))
val finalizer1 = typed(tree.finalizer, defn.UnitType)(using ctx.addNotNullInfo(nni))
nni = nni.seq(finalizer1.notNullInfo)
assignType(cpy.Try(tree)(expr2, cases2, finalizer1), expr2, cases2).withNotNullInfo(nni)
}
// It is possible to have non-exhaustive cases, and some exceptions are thrown and not caught.
// Therefore, the code in the finalizer and after the try block can only rely on the retracted
// info from the cases' body.
if cases2.nonEmpty then
nnInfo = nnInfo.seq(cases2.map(_.notNullInfo.retractedInfo).reduce(_.alt(_)))

val finalizer1 = typed(tree.finalizer, defn.UnitType)(using ctx.addNotNullInfo(nnInfo))
nnInfo = nnInfo.seq(finalizer1.notNullInfo)
assignType(cpy.Try(tree)(expr2, cases2, finalizer1), expr2, cases2).withNotNullInfo(nnInfo)

def typedTry(tree: untpd.ParsedTry, pt: Type)(using Context): Try =
val cases: List[untpd.CaseDef] = tree.handler match
Expand All @@ -2369,15 +2378,15 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
def typedThrow(tree: untpd.Throw)(using Context): Tree =
val expr1 = typed(tree.expr, defn.ThrowableType)
val cap = checkCanThrow(expr1.tpe.widen, tree.span)
val res = Throw(expr1).withSpan(tree.span)
var res = Throw(expr1).withSpan(tree.span)
if Feature.ccEnabled && !cap.isEmpty && !ctx.isAfterTyper then
// Record access to the CanThrow capabulity recovered in `cap` by wrapping
// the type of the `throw` (i.e. Nothing) in a `@requiresCapability` annotation.
Typed(res,
res = Typed(res,
TypeTree(
AnnotatedType(res.tpe,
Annotation(defn.RequiresCapabilityAnnot, cap, tree.span))))
else res
res.withNotNullInfo(expr1.notNullInfo.terminatedInfo)

def typedSeqLiteral(tree: untpd.SeqLiteral, pt: Type)(using Context): SeqLiteral = {
val elemProto = pt.stripNull().elemType match {
Expand Down Expand Up @@ -2842,6 +2851,8 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
val vdef1 = assignType(cpy.ValDef(vdef)(name, tpt1, rhs1), sym)
postProcessInfo(vdef1, sym)
vdef1.setDefTree
val nnInfo = rhs1.notNullInfo
vdef1.withNotNullInfo(if sym.is(Lazy) then nnInfo.retractedInfo else nnInfo)
}

private def retractDefDef(sym: Symbol)(using Context): Tree =
Expand Down
18 changes: 18 additions & 0 deletions tests/explicit-nulls/neg/i21380b.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,22 @@ def test3(i: Int) =
i match
case 1 if x != null => ()
case _ => x = " "
x.trim() // ok

def test4(i: Int) =
var x: String | Null = null
var y: String | Null = null
i match
case 1 => x = "1"
case _ => y = " "
x.trim() // error

def test5(i: Int): String =
var x: String | Null = null
var y: String | Null = null
i match
case 1 => x = "1"
case _ =>
y = " "
return y
x.trim() // ok
6 changes: 3 additions & 3 deletions tests/explicit-nulls/neg/i21380c.scala
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@ def test4: Int =
case npe: NullPointerException => x = ""
case _ => x = ""
x.length // error
// Although the catch block here is exhaustive,
// it is possible that the exception is thrown and not caught.
// Therefore, the code after the try block can only rely on the retracted info.
// Although the catch block here is exhaustive, it is possible to have non-exhaustive cases,
// and some exceptions are thrown and not caught. Therefore, the code in the finalizer and
// after the try block can only rely on the retracted info from the cases' body.

def test5: Int =
var x: String | Null = null
Expand Down
Loading

0 comments on commit ee0dd7a

Please sign in to comment.