Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix #21619: Refactor NotNullInfo to record every reference which is retracted once. #21624

Merged
merged 8 commits into from
Dec 10, 2024
6 changes: 3 additions & 3 deletions compiler/src/dotty/tools/dotc/core/Contexts.scala
Original file line number Diff line number Diff line change
Expand Up @@ -777,13 +777,13 @@ object Contexts {

extension (c: Context)
def addNotNullInfo(info: NotNullInfo) =
c.withNotNullInfos(c.notNullInfos.extendWith(info))
if c.explicitNulls then c.withNotNullInfos(c.notNullInfos.extendWith(info)) else c

def addNotNullRefs(refs: Set[TermRef]) =
c.addNotNullInfo(NotNullInfo(refs, Set()))
if c.explicitNulls then c.addNotNullInfo(NotNullInfo(refs, Set())) else c

def withNotNullInfos(infos: List[NotNullInfo]): Context =
if c.notNullInfos eq infos then c else c.fresh.setNotNullInfos(infos)
if !c.explicitNulls || (c.notNullInfos eq infos) then c else c.fresh.setNotNullInfos(infos)

def relaxedOverrideContext: Context =
c.withModeBits(c.mode &~ Mode.SafeNulls | Mode.RelaxedOverriding)
Expand Down
7 changes: 6 additions & 1 deletion compiler/src/dotty/tools/dotc/typer/Applications.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1134,7 +1134,7 @@ trait Applications extends Compatibility {
case _ => ()
else ()

fun1.tpe match {
val result = fun1.tpe match {
case err: ErrorType => cpy.Apply(tree)(fun1, proto.typedArgs()).withType(err)
case TryDynamicCallType =>
val isInsertedApply = fun1 match {
Expand Down Expand Up @@ -1208,6 +1208,11 @@ trait Applications extends Compatibility {
else tryWithImplicitOnQualifier(fun1, proto).getOrElse(fail))
}
}

if result.tpe.isNothingType then
val nnInfo = result.notNullInfo
result.withNotNullInfo(nnInfo.terminatedInfo)
else result
}

/** Convert expression like
Expand Down
116 changes: 76 additions & 40 deletions compiler/src/dotty/tools/dotc/typer/Nullables.scala
Original file line number Diff line number Diff line change
Expand Up @@ -52,34 +52,46 @@ object Nullables:
val hiTree = if(hiTpe eq hi.typeOpt) hi else TypeTree(hiTpe)
TypeBoundsTree(lo, hiTree, alias)

/** A set of val or var references that are known to be not null, plus a set of
* variable references that are not known (anymore) to be not null
/** A set of val or var references that are known to be not null
* after the tree finishes executing normally (non-exceptionally),
* plus a set of variable references that are ever assigned to null,
* and may therefore be null if execution of the tree is interrupted
* by an exception.
*/
case class NotNullInfo(asserted: Set[TermRef], retracted: Set[TermRef]):
assert((asserted & retracted).isEmpty)

case class NotNullInfo(asserted: Set[TermRef] | Null, retracted: Set[TermRef]):
def isEmpty = this eq NotNullInfo.empty

def retractedInfo = NotNullInfo(Set(), retracted)

def terminatedInfo = NotNullInfo(null, retracted)

/** The sequential combination with another not-null info */
def seq(that: NotNullInfo): NotNullInfo =
if this.isEmpty then that
else if that.isEmpty then this
else NotNullInfo(
this.asserted.union(that.asserted).diff(that.retracted),
this.retracted.union(that.retracted).diff(that.asserted))
else
val newAsserted =
if this.asserted == null || that.asserted == null then null
else this.asserted.diff(that.retracted).union(that.asserted)
val newRetracted = this.retracted.union(that.retracted)
NotNullInfo(newAsserted, newRetracted)

/** The alternative path combination with another not-null info. Used to merge
* the nullability info of the two branches of an if.
* the nullability info of the branches of an if or match.
*/
def alt(that: NotNullInfo): NotNullInfo =
NotNullInfo(this.asserted.intersect(that.asserted), this.retracted.union(that.retracted))
val newAsserted =
if this.asserted == null then that.asserted
else if that.asserted == null then this.asserted
else this.asserted.intersect(that.asserted)
val newRetracted = this.retracted.union(that.retracted)
NotNullInfo(newAsserted, newRetracted)
end NotNullInfo

object NotNullInfo:
val empty = new NotNullInfo(Set(), Set())
def apply(asserted: Set[TermRef], retracted: Set[TermRef]): NotNullInfo =
if asserted.isEmpty && retracted.isEmpty then empty
def apply(asserted: Set[TermRef] | Null, retracted: Set[TermRef]): NotNullInfo =
if asserted != null && asserted.isEmpty && retracted.isEmpty then empty
else new NotNullInfo(asserted, retracted)
end NotNullInfo

Expand Down Expand Up @@ -223,7 +235,7 @@ object Nullables:
*/
@tailrec def impliesNotNull(ref: TermRef): Boolean = infos match
case info :: infos1 =>
if info.asserted.contains(ref) then true
if info.asserted == null || info.asserted.contains(ref) then true
else if info.retracted.contains(ref) then false
else infos1.impliesNotNull(ref)
case _ =>
Expand All @@ -233,16 +245,15 @@ object Nullables:
* or retractions in `info` supersede infos in existing entries of `infos`.
*/
def extendWith(info: NotNullInfo) =
if info.isEmpty
|| info.asserted.forall(infos.impliesNotNull(_))
&& !info.retracted.exists(infos.impliesNotNull(_))
then infos
if info.isEmpty then infos
else info :: infos

/** Retract all references to mutable variables */
def retractMutables(using Context) =
val mutables = infos.foldLeft(Set[TermRef]())((ms, info) =>
ms.union(info.asserted.filter(_.symbol.is(Mutable))))
val mutables = infos.foldLeft(Set[TermRef]()):
(ms, info) => ms.union(
if info.asserted == null then Set.empty
else info.asserted.filter(_.symbol.is(Mutable)))
infos.extendWith(NotNullInfo(Set(), mutables))

end extension
Expand Down Expand Up @@ -304,15 +315,35 @@ object Nullables:
extension (tree: Tree)

/* The `tree` with added nullability attachment */
def withNotNullInfo(info: NotNullInfo): tree.type =
if !info.isEmpty then tree.putAttachment(NNInfo, info)
def withNotNullInfo(info: NotNullInfo)(using Context): tree.type =
if ctx.explicitNulls && !info.isEmpty then tree.putAttachment(NNInfo, info)
tree

/* Collect the nullability info from parts of `tree` */
def collectNotNullInfo(using Context): NotNullInfo = tree match
case Typed(expr, _) =>
expr.notNullInfo
case Apply(fn, args) =>
val argsInfo = args.map(_.notNullInfo)
val fnInfo = fn.notNullInfo
argsInfo.foldLeft(fnInfo)(_ seq _)
case TypeApply(fn, _) =>
fn.notNullInfo
case _ =>
// Other cases are handled specially in typer.
NotNullInfo.empty

/* The nullability info of `tree` */
def notNullInfo(using Context): NotNullInfo =
stripInlined(tree).getAttachment(NNInfo) match
case Some(info) if !ctx.erasedTypes => info
case _ => NotNullInfo.empty
if !ctx.explicitNulls then NotNullInfo.empty
else
val tree1 = stripInlined(tree)
tree1.getAttachment(NNInfo) match
case Some(info) if !ctx.erasedTypes => info
case _ =>
val nnInfo = tree1.collectNotNullInfo
tree1.withNotNullInfo(nnInfo)
nnInfo

/* The nullability info of `tree`, assuming it is a condition that evaluates to `c` */
def notNullInfoIf(c: Boolean)(using Context): NotNullInfo =
Expand Down Expand Up @@ -393,21 +424,23 @@ object Nullables:
end extension

extension (tree: Assign)
def computeAssignNullable()(using Context): tree.type = tree.lhs match
case TrackedRef(ref) =>
val rhstp = tree.rhs.typeOpt
if ctx.explicitNulls && ref.isNullableUnion then
if rhstp.isNullType || rhstp.isNullableUnion then
// If the type of rhs is nullable (`T|Null` or `Null`), then the nullability of the
// lhs variable is no longer trackable. We don't need to check whether the type `T`
// is correct here, as typer will check it.
tree.withNotNullInfo(NotNullInfo(Set(), Set(ref)))
else
// If the initial type is nullable and the assigned value is non-null,
// we add it to the NotNull.
tree.withNotNullInfo(NotNullInfo(Set(ref), Set()))
else tree
case _ => tree
def computeAssignNullable()(using Context): tree.type =
var nnInfo = tree.rhs.notNullInfo
tree.lhs match
case TrackedRef(ref) if ctx.explicitNulls && ref.isNullableUnion =>
nnInfo = nnInfo.seq:
val rhstp = tree.rhs.typeOpt
if rhstp.isNullType || rhstp.isNullableUnion then
// If the type of rhs is nullable (`T|Null` or `Null`), then the nullability of the
// lhs variable is no longer trackable. We don't need to check whether the type `T`
// is correct here, as typer will check it.
NotNullInfo(Set(), Set(ref))
else
// If the initial type is nullable and the assigned value is non-null,
// we add it to the NotNull.
NotNullInfo(Set(ref), Set())
case _ =>
tree.withNotNullInfo(nnInfo)
end extension

private val analyzedOps = Set(nme.EQ, nme.NE, nme.eq, nme.ne, nme.ZAND, nme.ZOR, nme.UNARY_!)
Expand Down Expand Up @@ -515,7 +548,10 @@ object Nullables:
&& assignmentSpans.getOrElse(sym.span.start, Nil).exists(whileSpan.contains(_))
&& ctx.notNullInfos.impliesNotNull(ref)

val retractedVars = ctx.notNullInfos.flatMap(_.asserted.filter(isRetracted)).toSet
val retractedVars = ctx.notNullInfos.flatMap(info =>
if info.asserted == null then Set.empty
else info.asserted.filter(isRetracted)
).toSet
ctx.addNotNullInfo(NotNullInfo(Set(), retractedVars))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here we lose knowledge about the loop being unreachable. The result always has asserted == Set(), even if it was null before.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is fine for the NotNullInfo in the ctx. Suppose there is a terminated info in the ctx, adding a new non-terminated info will not change its behaviour: still treating all symbols as non-nullable.

end whileContext

Expand Down
59 changes: 34 additions & 25 deletions compiler/src/dotty/tools/dotc/typer/Typer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1201,7 +1201,6 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
untpd.unsplice(tree.expr).putAttachment(AscribedToUnit, ())
typed(tree.expr, underlyingTreeTpe.tpe.widenSkolem)
assignType(cpy.Typed(tree)(expr1, tpt), underlyingTreeTpe)
.withNotNullInfo(expr1.notNullInfo)
}

if (untpd.isWildcardStarArg(tree)) {
Expand Down Expand Up @@ -1551,11 +1550,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer

def thenPathInfo = cond1.notNullInfoIf(true).seq(result.thenp.notNullInfo)
def elsePathInfo = cond1.notNullInfoIf(false).seq(result.elsep.notNullInfo)
result.withNotNullInfo(
if result.thenp.tpe.isRef(defn.NothingClass) then elsePathInfo
else if result.elsep.tpe.isRef(defn.NothingClass) then thenPathInfo
else thenPathInfo.alt(elsePathInfo)
)
result.withNotNullInfo(thenPathInfo.alt(elsePathInfo))
end typedIf

/** Decompose function prototype into a list of parameter prototypes and a result
Expand Down Expand Up @@ -2139,20 +2134,23 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
case1
}
.asInstanceOf[List[CaseDef]]
var nni = sel.notNullInfo
if cases1.nonEmpty then nni = nni.seq(cases1.map(_.notNullInfo).reduce(_.alt(_)))
assignType(cpy.Match(tree)(sel, cases1), sel, cases1).cast(pt).withNotNullInfo(nni)
assignType(cpy.Match(tree)(sel, cases1), sel, cases1).cast(pt)
.withNotNullInfo(notNullInfoFromCases(sel.notNullInfo, cases1))
}

// Overridden in InlineTyper for inline matches
def typedMatchFinish(tree: untpd.Match, sel: Tree, wideSelType: Type, cases: List[untpd.CaseDef], pt: Type)(using Context): Tree = {
val cases1 = harmonic(harmonize, pt)(typedCases(cases, sel, wideSelType, pt.dropIfProto))
.asInstanceOf[List[CaseDef]]
var nni = sel.notNullInfo
if cases1.nonEmpty then nni = nni.seq(cases1.map(_.notNullInfo).reduce(_.alt(_)))
assignType(cpy.Match(tree)(sel, cases1), sel, cases1).withNotNullInfo(nni)
assignType(cpy.Match(tree)(sel, cases1), sel, cases1)
.withNotNullInfo(notNullInfoFromCases(sel.notNullInfo, cases1))
}

private def notNullInfoFromCases(initInfo: NotNullInfo, cases: List[CaseDef])(using Context): NotNullInfo =
if cases.nonEmpty then
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In what instances would the cases be empty? Depending on the answer, perhaps a comment would be useful.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think the cases can be empty. Just want to avoid the exception when calling reduce on empty list.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well, it seems the empty cases can happen after inlining.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK. A question is, if empty cases does happen, what does it mean in terms of runtime semantics? If it definitely fails to match, it would be more precise to return terminatedInfo, but initInfo is sound and safer in case empty cases means something different in some special situations.

I think it's fine to leave it as it is, and you can merge this PR now.

initInfo.seq(cases.map(_.notNullInfo).reduce(_.alt(_)))
else initInfo

def typedCases(cases: List[untpd.CaseDef], sel: Tree, wideSelType0: Type, pt: Type)(using Context): List[CaseDef] =
var caseCtx = ctx
var wideSelType = wideSelType0
Expand Down Expand Up @@ -2241,7 +2239,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
def typedLabeled(tree: untpd.Labeled)(using Context): Labeled = {
val bind1 = typedBind(tree.bind, WildcardType).asInstanceOf[Bind]
val expr1 = typed(tree.expr, bind1.symbol.info)
assignType(cpy.Labeled(tree)(bind1, expr1))
assignType(cpy.Labeled(tree)(bind1, expr1)).withNotNullInfo(expr1.notNullInfo.retractedInfo)
}

/** Type a case of a type match */
Expand Down Expand Up @@ -2291,7 +2289,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
// Hence no adaptation is possible, and we assume WildcardType as prototype.
(from, proto)
val expr1 = typedExpr(tree.expr orElse untpd.syntheticUnitLiteral.withSpan(tree.span), proto)
assignType(cpy.Return(tree)(expr1, from))
assignType(cpy.Return(tree)(expr1, from)).withNotNullInfo(expr1.notNullInfo.terminatedInfo)
end typedReturn

def typedWhileDo(tree: untpd.WhileDo)(using Context): Tree =
Expand Down Expand Up @@ -2332,7 +2330,8 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
val capabilityProof = caughtExceptions.reduce(OrType(_, _, true))
untpd.Block(makeCanThrow(capabilityProof), expr)

def typedTry(tree: untpd.Try, pt: Type)(using Context): Try = {
def typedTry(tree: untpd.Try, pt: Type)(using Context): Try =
var nnInfo = NotNullInfo.empty
val expr2 :: cases2x = harmonic(harmonize, pt) {
// We want to type check tree.expr first to comput NotNullInfo, but `addCanThrowCapabilities`
// uses the types of patterns in `tree.cases` to determine the capabilities.
Expand All @@ -2344,18 +2343,26 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
val casesEmptyBody1 = tree.cases.mapconserve(cpy.CaseDef(_)(body = EmptyTree))
val casesEmptyBody2 = typedCases(casesEmptyBody1, EmptyTree, defn.ThrowableType, WildcardType)
val expr1 = typed(addCanThrowCapabilities(tree.expr, casesEmptyBody2), pt.dropIfProto)
val casesCtx = ctx.addNotNullInfo(expr1.notNullInfo.retractedInfo)

// Since we don't know at which point the the exception is thrown in the body,
// we have to collect any reference that is once retracted.
nnInfo = expr1.notNullInfo.retractedInfo

val casesCtx = ctx.addNotNullInfo(nnInfo)
val cases1 = typedCases(tree.cases, EmptyTree, defn.ThrowableType, pt.dropIfProto)(using casesCtx)
expr1 :: cases1
}: @unchecked
val cases2 = cases2x.asInstanceOf[List[CaseDef]]

var nni = expr2.notNullInfo.retractedInfo
if cases2.nonEmpty then nni = nni.seq(cases2.map(_.notNullInfo.retractedInfo).reduce(_.alt(_)))
val finalizer1 = typed(tree.finalizer, defn.UnitType)(using ctx.addNotNullInfo(nni))
nni = nni.seq(finalizer1.notNullInfo)
assignType(cpy.Try(tree)(expr2, cases2, finalizer1), expr2, cases2).withNotNullInfo(nni)
}
// It is possible to have non-exhaustive cases, and some exceptions are thrown and not caught.
// Therefore, the code in the finalizer and after the try block can only rely on the retracted
// info from the cases' body.
if cases2.nonEmpty then
nnInfo = nnInfo.seq(cases2.map(_.notNullInfo.retractedInfo).reduce(_.alt(_)))

val finalizer1 = typed(tree.finalizer, defn.UnitType)(using ctx.addNotNullInfo(nnInfo))
nnInfo = nnInfo.seq(finalizer1.notNullInfo)
assignType(cpy.Try(tree)(expr2, cases2, finalizer1), expr2, cases2).withNotNullInfo(nnInfo)

def typedTry(tree: untpd.ParsedTry, pt: Type)(using Context): Try =
val cases: List[untpd.CaseDef] = tree.handler match
Expand All @@ -2369,15 +2376,15 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
def typedThrow(tree: untpd.Throw)(using Context): Tree =
val expr1 = typed(tree.expr, defn.ThrowableType)
val cap = checkCanThrow(expr1.tpe.widen, tree.span)
val res = Throw(expr1).withSpan(tree.span)
var res = Throw(expr1).withSpan(tree.span)
if Feature.ccEnabled && !cap.isEmpty && !ctx.isAfterTyper then
// Record access to the CanThrow capabulity recovered in `cap` by wrapping
// the type of the `throw` (i.e. Nothing) in a `@requiresCapability` annotation.
Typed(res,
res = Typed(res,
TypeTree(
AnnotatedType(res.tpe,
Annotation(defn.RequiresCapabilityAnnot, cap, tree.span))))
else res
res.withNotNullInfo(expr1.notNullInfo.terminatedInfo)

def typedSeqLiteral(tree: untpd.SeqLiteral, pt: Type)(using Context): SeqLiteral = {
val elemProto = pt.stripNull().elemType match {
Expand Down Expand Up @@ -2842,6 +2849,8 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
val vdef1 = assignType(cpy.ValDef(vdef)(name, tpt1, rhs1), sym)
postProcessInfo(vdef1, sym)
vdef1.setDefTree
val nnInfo = rhs1.notNullInfo
vdef1.withNotNullInfo(if sym.is(Lazy) then nnInfo.retractedInfo else nnInfo)
}

private def retractDefDef(sym: Symbol)(using Context): Tree =
Expand Down
18 changes: 18 additions & 0 deletions tests/explicit-nulls/neg/i21380b.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,22 @@ def test3(i: Int) =
i match
case 1 if x != null => ()
case _ => x = " "
x.trim() // ok

def test4(i: Int) =
var x: String | Null = null
var y: String | Null = null
i match
case 1 => x = "1"
case _ => y = " "
x.trim() // error

def test5(i: Int): String =
var x: String | Null = null
var y: String | Null = null
i match
case 1 => x = "1"
case _ =>
y = " "
return y
x.trim() // ok
6 changes: 3 additions & 3 deletions tests/explicit-nulls/neg/i21380c.scala
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@ def test4: Int =
case npe: NullPointerException => x = ""
case _ => x = ""
x.length // error
// Although the catch block here is exhaustive,
// it is possible that the exception is thrown and not caught.
// Therefore, the code after the try block can only rely on the retracted info.
// Although the catch block here is exhaustive, it is possible to have non-exhaustive cases,
// and some exceptions are thrown and not caught. Therefore, the code in the finalizer and
// after the try block can only rely on the retracted info from the cases' body.

def test5: Int =
var x: String | Null = null
Expand Down
Loading
Loading