Skip to content

Commit

Permalink
Fix scala#21619: Refactor NotNullInfo to record every reference which…
Browse files Browse the repository at this point in the history
… is retracted once. (scala#21624)

This PR improves the flow typing for returning and exceptions.

The `NotNullInfo` is defined as following now:

```scala
case class NotNullInfo(asserted: Set[TermRef] | Null, retracted: Set[TermRef]):
```

* `retracted` contains variable references that are ever assigned to
null;
* if `asserted` is not `null`, it contains `val` or `var` references
that are known to be not null, after the tree finishes executing
normally (non-exceptionally);
* if `asserted` is `null`, the tree is know to terminate, by throwing,
returning, or calling a function with `Nothing` type. Hence, it acts
like a universal set.

`alt` is defined as `<a1,r1>.alt(<a2,r2>) = <a1 intersect a2, r1 union
r2>`.

The difficult part is the `try ... catch ... finally ...`. We don't know
at which point an exception is thrown in the body, and the catch cases
may be not exhaustive, we have to collect any reference that is once
retracted.

Fix scala#21619
  • Loading branch information
noti0na1 authored Dec 10, 2024
2 parents e6b4222 + 200c038 commit ee0dd7a
Show file tree
Hide file tree
Showing 9 changed files with 257 additions and 78 deletions.
6 changes: 3 additions & 3 deletions compiler/src/dotty/tools/dotc/core/Contexts.scala
Original file line number Diff line number Diff line change
Expand Up @@ -777,13 +777,13 @@ object Contexts {

extension (c: Context)
def addNotNullInfo(info: NotNullInfo) =
c.withNotNullInfos(c.notNullInfos.extendWith(info))
if c.explicitNulls then c.withNotNullInfos(c.notNullInfos.extendWith(info)) else c

def addNotNullRefs(refs: Set[TermRef]) =
c.addNotNullInfo(NotNullInfo(refs, Set()))
if c.explicitNulls then c.addNotNullInfo(NotNullInfo(refs, Set())) else c

def withNotNullInfos(infos: List[NotNullInfo]): Context =
if c.notNullInfos eq infos then c else c.fresh.setNotNullInfos(infos)
if !c.explicitNulls || (c.notNullInfos eq infos) then c else c.fresh.setNotNullInfos(infos)

def relaxedOverrideContext: Context =
c.withModeBits(c.mode &~ Mode.SafeNulls | Mode.RelaxedOverriding)
Expand Down
7 changes: 6 additions & 1 deletion compiler/src/dotty/tools/dotc/typer/Applications.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1134,7 +1134,7 @@ trait Applications extends Compatibility {
case _ => ()
else ()

fun1.tpe match {
val result = fun1.tpe match {
case err: ErrorType => cpy.Apply(tree)(fun1, proto.typedArgs()).withType(err)
case TryDynamicCallType =>
val isInsertedApply = fun1 match {
Expand Down Expand Up @@ -1208,6 +1208,11 @@ trait Applications extends Compatibility {
else tryWithImplicitOnQualifier(fun1, proto).getOrElse(fail))
}
}

if result.tpe.isNothingType then
val nnInfo = result.notNullInfo
result.withNotNullInfo(nnInfo.terminatedInfo)
else result
}

/** Convert expression like
Expand Down
116 changes: 76 additions & 40 deletions compiler/src/dotty/tools/dotc/typer/Nullables.scala
Original file line number Diff line number Diff line change
Expand Up @@ -52,34 +52,46 @@ object Nullables:
val hiTree = if(hiTpe eq hi.typeOpt) hi else TypeTree(hiTpe)
TypeBoundsTree(lo, hiTree, alias)

/** A set of val or var references that are known to be not null, plus a set of
* variable references that are not known (anymore) to be not null
/** A set of val or var references that are known to be not null
* after the tree finishes executing normally (non-exceptionally),
* plus a set of variable references that are ever assigned to null,
* and may therefore be null if execution of the tree is interrupted
* by an exception.
*/
case class NotNullInfo(asserted: Set[TermRef], retracted: Set[TermRef]):
assert((asserted & retracted).isEmpty)

case class NotNullInfo(asserted: Set[TermRef] | Null, retracted: Set[TermRef]):
def isEmpty = this eq NotNullInfo.empty

def retractedInfo = NotNullInfo(Set(), retracted)

def terminatedInfo = NotNullInfo(null, retracted)

/** The sequential combination with another not-null info */
def seq(that: NotNullInfo): NotNullInfo =
if this.isEmpty then that
else if that.isEmpty then this
else NotNullInfo(
this.asserted.union(that.asserted).diff(that.retracted),
this.retracted.union(that.retracted).diff(that.asserted))
else
val newAsserted =
if this.asserted == null || that.asserted == null then null
else this.asserted.diff(that.retracted).union(that.asserted)
val newRetracted = this.retracted.union(that.retracted)
NotNullInfo(newAsserted, newRetracted)

/** The alternative path combination with another not-null info. Used to merge
* the nullability info of the two branches of an if.
* the nullability info of the branches of an if or match.
*/
def alt(that: NotNullInfo): NotNullInfo =
NotNullInfo(this.asserted.intersect(that.asserted), this.retracted.union(that.retracted))
val newAsserted =
if this.asserted == null then that.asserted
else if that.asserted == null then this.asserted
else this.asserted.intersect(that.asserted)
val newRetracted = this.retracted.union(that.retracted)
NotNullInfo(newAsserted, newRetracted)
end NotNullInfo

object NotNullInfo:
val empty = new NotNullInfo(Set(), Set())
def apply(asserted: Set[TermRef], retracted: Set[TermRef]): NotNullInfo =
if asserted.isEmpty && retracted.isEmpty then empty
def apply(asserted: Set[TermRef] | Null, retracted: Set[TermRef]): NotNullInfo =
if asserted != null && asserted.isEmpty && retracted.isEmpty then empty
else new NotNullInfo(asserted, retracted)
end NotNullInfo

Expand Down Expand Up @@ -223,7 +235,7 @@ object Nullables:
*/
@tailrec def impliesNotNull(ref: TermRef): Boolean = infos match
case info :: infos1 =>
if info.asserted.contains(ref) then true
if info.asserted == null || info.asserted.contains(ref) then true
else if info.retracted.contains(ref) then false
else infos1.impliesNotNull(ref)
case _ =>
Expand All @@ -233,16 +245,15 @@ object Nullables:
* or retractions in `info` supersede infos in existing entries of `infos`.
*/
def extendWith(info: NotNullInfo) =
if info.isEmpty
|| info.asserted.forall(infos.impliesNotNull(_))
&& !info.retracted.exists(infos.impliesNotNull(_))
then infos
if info.isEmpty then infos
else info :: infos

/** Retract all references to mutable variables */
def retractMutables(using Context) =
val mutables = infos.foldLeft(Set[TermRef]())((ms, info) =>
ms.union(info.asserted.filter(_.symbol.is(Mutable))))
val mutables = infos.foldLeft(Set[TermRef]()):
(ms, info) => ms.union(
if info.asserted == null then Set.empty
else info.asserted.filter(_.symbol.is(Mutable)))
infos.extendWith(NotNullInfo(Set(), mutables))

end extension
Expand Down Expand Up @@ -304,15 +315,35 @@ object Nullables:
extension (tree: Tree)

/* The `tree` with added nullability attachment */
def withNotNullInfo(info: NotNullInfo): tree.type =
if !info.isEmpty then tree.putAttachment(NNInfo, info)
def withNotNullInfo(info: NotNullInfo)(using Context): tree.type =
if ctx.explicitNulls && !info.isEmpty then tree.putAttachment(NNInfo, info)
tree

/* Collect the nullability info from parts of `tree` */
def collectNotNullInfo(using Context): NotNullInfo = tree match
case Typed(expr, _) =>
expr.notNullInfo
case Apply(fn, args) =>
val argsInfo = args.map(_.notNullInfo)
val fnInfo = fn.notNullInfo
argsInfo.foldLeft(fnInfo)(_ seq _)
case TypeApply(fn, _) =>
fn.notNullInfo
case _ =>
// Other cases are handled specially in typer.
NotNullInfo.empty

/* The nullability info of `tree` */
def notNullInfo(using Context): NotNullInfo =
stripInlined(tree).getAttachment(NNInfo) match
case Some(info) if !ctx.erasedTypes => info
case _ => NotNullInfo.empty
if !ctx.explicitNulls then NotNullInfo.empty
else
val tree1 = stripInlined(tree)
tree1.getAttachment(NNInfo) match
case Some(info) if !ctx.erasedTypes => info
case _ =>
val nnInfo = tree1.collectNotNullInfo
tree1.withNotNullInfo(nnInfo)
nnInfo

/* The nullability info of `tree`, assuming it is a condition that evaluates to `c` */
def notNullInfoIf(c: Boolean)(using Context): NotNullInfo =
Expand Down Expand Up @@ -393,21 +424,23 @@ object Nullables:
end extension

extension (tree: Assign)
def computeAssignNullable()(using Context): tree.type = tree.lhs match
case TrackedRef(ref) =>
val rhstp = tree.rhs.typeOpt
if ctx.explicitNulls && ref.isNullableUnion then
if rhstp.isNullType || rhstp.isNullableUnion then
// If the type of rhs is nullable (`T|Null` or `Null`), then the nullability of the
// lhs variable is no longer trackable. We don't need to check whether the type `T`
// is correct here, as typer will check it.
tree.withNotNullInfo(NotNullInfo(Set(), Set(ref)))
else
// If the initial type is nullable and the assigned value is non-null,
// we add it to the NotNull.
tree.withNotNullInfo(NotNullInfo(Set(ref), Set()))
else tree
case _ => tree
def computeAssignNullable()(using Context): tree.type =
var nnInfo = tree.rhs.notNullInfo
tree.lhs match
case TrackedRef(ref) if ctx.explicitNulls && ref.isNullableUnion =>
nnInfo = nnInfo.seq:
val rhstp = tree.rhs.typeOpt
if rhstp.isNullType || rhstp.isNullableUnion then
// If the type of rhs is nullable (`T|Null` or `Null`), then the nullability of the
// lhs variable is no longer trackable. We don't need to check whether the type `T`
// is correct here, as typer will check it.
NotNullInfo(Set(), Set(ref))
else
// If the initial type is nullable and the assigned value is non-null,
// we add it to the NotNull.
NotNullInfo(Set(ref), Set())
case _ =>
tree.withNotNullInfo(nnInfo)
end extension

private val analyzedOps = Set(nme.EQ, nme.NE, nme.eq, nme.ne, nme.ZAND, nme.ZOR, nme.UNARY_!)
Expand Down Expand Up @@ -515,7 +548,10 @@ object Nullables:
&& assignmentSpans.getOrElse(sym.span.start, Nil).exists(whileSpan.contains(_))
&& ctx.notNullInfos.impliesNotNull(ref)

val retractedVars = ctx.notNullInfos.flatMap(_.asserted.filter(isRetracted)).toSet
val retractedVars = ctx.notNullInfos.flatMap(info =>
if info.asserted == null then Set.empty
else info.asserted.filter(isRetracted)
).toSet
ctx.addNotNullInfo(NotNullInfo(Set(), retractedVars))
end whileContext

Expand Down
61 changes: 36 additions & 25 deletions compiler/src/dotty/tools/dotc/typer/Typer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1201,7 +1201,6 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
untpd.unsplice(tree.expr).putAttachment(AscribedToUnit, ())
typed(tree.expr, underlyingTreeTpe.tpe.widenSkolem)
assignType(cpy.Typed(tree)(expr1, tpt), underlyingTreeTpe)
.withNotNullInfo(expr1.notNullInfo)
}

if (untpd.isWildcardStarArg(tree)) {
Expand Down Expand Up @@ -1551,11 +1550,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer

def thenPathInfo = cond1.notNullInfoIf(true).seq(result.thenp.notNullInfo)
def elsePathInfo = cond1.notNullInfoIf(false).seq(result.elsep.notNullInfo)
result.withNotNullInfo(
if result.thenp.tpe.isRef(defn.NothingClass) then elsePathInfo
else if result.elsep.tpe.isRef(defn.NothingClass) then thenPathInfo
else thenPathInfo.alt(elsePathInfo)
)
result.withNotNullInfo(thenPathInfo.alt(elsePathInfo))
end typedIf

/** Decompose function prototype into a list of parameter prototypes and a result
Expand Down Expand Up @@ -2139,20 +2134,25 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
case1
}
.asInstanceOf[List[CaseDef]]
var nni = sel.notNullInfo
if cases1.nonEmpty then nni = nni.seq(cases1.map(_.notNullInfo).reduce(_.alt(_)))
assignType(cpy.Match(tree)(sel, cases1), sel, cases1).cast(pt).withNotNullInfo(nni)
assignType(cpy.Match(tree)(sel, cases1), sel, cases1).cast(pt)
.withNotNullInfo(notNullInfoFromCases(sel.notNullInfo, cases1))
}

// Overridden in InlineTyper for inline matches
def typedMatchFinish(tree: untpd.Match, sel: Tree, wideSelType: Type, cases: List[untpd.CaseDef], pt: Type)(using Context): Tree = {
val cases1 = harmonic(harmonize, pt)(typedCases(cases, sel, wideSelType, pt.dropIfProto))
.asInstanceOf[List[CaseDef]]
var nni = sel.notNullInfo
if cases1.nonEmpty then nni = nni.seq(cases1.map(_.notNullInfo).reduce(_.alt(_)))
assignType(cpy.Match(tree)(sel, cases1), sel, cases1).withNotNullInfo(nni)
assignType(cpy.Match(tree)(sel, cases1), sel, cases1)
.withNotNullInfo(notNullInfoFromCases(sel.notNullInfo, cases1))
}

private def notNullInfoFromCases(initInfo: NotNullInfo, cases: List[CaseDef])(using Context): NotNullInfo =
if cases.isEmpty then
// Empty cases is not allowed for match tree in the source code,
// but it can be generated by inlining: `tests/pos/i19198.scala`.
initInfo
else cases.map(_.notNullInfo).reduce(_.alt(_))

def typedCases(cases: List[untpd.CaseDef], sel: Tree, wideSelType0: Type, pt: Type)(using Context): List[CaseDef] =
var caseCtx = ctx
var wideSelType = wideSelType0
Expand Down Expand Up @@ -2241,7 +2241,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
def typedLabeled(tree: untpd.Labeled)(using Context): Labeled = {
val bind1 = typedBind(tree.bind, WildcardType).asInstanceOf[Bind]
val expr1 = typed(tree.expr, bind1.symbol.info)
assignType(cpy.Labeled(tree)(bind1, expr1))
assignType(cpy.Labeled(tree)(bind1, expr1)).withNotNullInfo(expr1.notNullInfo.retractedInfo)
}

/** Type a case of a type match */
Expand Down Expand Up @@ -2291,7 +2291,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
// Hence no adaptation is possible, and we assume WildcardType as prototype.
(from, proto)
val expr1 = typedExpr(tree.expr orElse untpd.syntheticUnitLiteral.withSpan(tree.span), proto)
assignType(cpy.Return(tree)(expr1, from))
assignType(cpy.Return(tree)(expr1, from)).withNotNullInfo(expr1.notNullInfo.terminatedInfo)
end typedReturn

def typedWhileDo(tree: untpd.WhileDo)(using Context): Tree =
Expand Down Expand Up @@ -2332,7 +2332,8 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
val capabilityProof = caughtExceptions.reduce(OrType(_, _, true))
untpd.Block(makeCanThrow(capabilityProof), expr)

def typedTry(tree: untpd.Try, pt: Type)(using Context): Try = {
def typedTry(tree: untpd.Try, pt: Type)(using Context): Try =
var nnInfo = NotNullInfo.empty
val expr2 :: cases2x = harmonic(harmonize, pt) {
// We want to type check tree.expr first to comput NotNullInfo, but `addCanThrowCapabilities`
// uses the types of patterns in `tree.cases` to determine the capabilities.
Expand All @@ -2344,18 +2345,26 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
val casesEmptyBody1 = tree.cases.mapconserve(cpy.CaseDef(_)(body = EmptyTree))
val casesEmptyBody2 = typedCases(casesEmptyBody1, EmptyTree, defn.ThrowableType, WildcardType)
val expr1 = typed(addCanThrowCapabilities(tree.expr, casesEmptyBody2), pt.dropIfProto)
val casesCtx = ctx.addNotNullInfo(expr1.notNullInfo.retractedInfo)

// Since we don't know at which point the the exception is thrown in the body,
// we have to collect any reference that is once retracted.
nnInfo = expr1.notNullInfo.retractedInfo

val casesCtx = ctx.addNotNullInfo(nnInfo)
val cases1 = typedCases(tree.cases, EmptyTree, defn.ThrowableType, pt.dropIfProto)(using casesCtx)
expr1 :: cases1
}: @unchecked
val cases2 = cases2x.asInstanceOf[List[CaseDef]]

var nni = expr2.notNullInfo.retractedInfo
if cases2.nonEmpty then nni = nni.seq(cases2.map(_.notNullInfo.retractedInfo).reduce(_.alt(_)))
val finalizer1 = typed(tree.finalizer, defn.UnitType)(using ctx.addNotNullInfo(nni))
nni = nni.seq(finalizer1.notNullInfo)
assignType(cpy.Try(tree)(expr2, cases2, finalizer1), expr2, cases2).withNotNullInfo(nni)
}
// It is possible to have non-exhaustive cases, and some exceptions are thrown and not caught.
// Therefore, the code in the finalizer and after the try block can only rely on the retracted
// info from the cases' body.
if cases2.nonEmpty then
nnInfo = nnInfo.seq(cases2.map(_.notNullInfo.retractedInfo).reduce(_.alt(_)))

val finalizer1 = typed(tree.finalizer, defn.UnitType)(using ctx.addNotNullInfo(nnInfo))
nnInfo = nnInfo.seq(finalizer1.notNullInfo)
assignType(cpy.Try(tree)(expr2, cases2, finalizer1), expr2, cases2).withNotNullInfo(nnInfo)

def typedTry(tree: untpd.ParsedTry, pt: Type)(using Context): Try =
val cases: List[untpd.CaseDef] = tree.handler match
Expand All @@ -2369,15 +2378,15 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
def typedThrow(tree: untpd.Throw)(using Context): Tree =
val expr1 = typed(tree.expr, defn.ThrowableType)
val cap = checkCanThrow(expr1.tpe.widen, tree.span)
val res = Throw(expr1).withSpan(tree.span)
var res = Throw(expr1).withSpan(tree.span)
if Feature.ccEnabled && !cap.isEmpty && !ctx.isAfterTyper then
// Record access to the CanThrow capabulity recovered in `cap` by wrapping
// the type of the `throw` (i.e. Nothing) in a `@requiresCapability` annotation.
Typed(res,
res = Typed(res,
TypeTree(
AnnotatedType(res.tpe,
Annotation(defn.RequiresCapabilityAnnot, cap, tree.span))))
else res
res.withNotNullInfo(expr1.notNullInfo.terminatedInfo)

def typedSeqLiteral(tree: untpd.SeqLiteral, pt: Type)(using Context): SeqLiteral = {
val elemProto = pt.stripNull().elemType match {
Expand Down Expand Up @@ -2842,6 +2851,8 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
val vdef1 = assignType(cpy.ValDef(vdef)(name, tpt1, rhs1), sym)
postProcessInfo(vdef1, sym)
vdef1.setDefTree
val nnInfo = rhs1.notNullInfo
vdef1.withNotNullInfo(if sym.is(Lazy) then nnInfo.retractedInfo else nnInfo)
}

private def retractDefDef(sym: Symbol)(using Context): Tree =
Expand Down
18 changes: 18 additions & 0 deletions tests/explicit-nulls/neg/i21380b.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,22 @@ def test3(i: Int) =
i match
case 1 if x != null => ()
case _ => x = " "
x.trim() // ok

def test4(i: Int) =
var x: String | Null = null
var y: String | Null = null
i match
case 1 => x = "1"
case _ => y = " "
x.trim() // error

def test5(i: Int): String =
var x: String | Null = null
var y: String | Null = null
i match
case 1 => x = "1"
case _ =>
y = " "
return y
x.trim() // ok
6 changes: 3 additions & 3 deletions tests/explicit-nulls/neg/i21380c.scala
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@ def test4: Int =
case npe: NullPointerException => x = ""
case _ => x = ""
x.length // error
// Although the catch block here is exhaustive,
// it is possible that the exception is thrown and not caught.
// Therefore, the code after the try block can only rely on the retracted info.
// Although the catch block here is exhaustive, it is possible to have non-exhaustive cases,
// and some exceptions are thrown and not caught. Therefore, the code in the finalizer and
// after the try block can only rely on the retracted info from the cases' body.

def test5: Int =
var x: String | Null = null
Expand Down
Loading

0 comments on commit ee0dd7a

Please sign in to comment.