Skip to content

Commit

Permalink
Refactor pattern matching, skipping cases when safe to do so
Browse files Browse the repository at this point in the history
  • Loading branch information
EnzeXing committed Dec 10, 2024
1 parent 1775d0b commit 966a174
Showing 1 changed file with 38 additions and 18 deletions.
56 changes: 38 additions & 18 deletions compiler/src/dotty/tools/dotc/transform/init/Objects.scala
Original file line number Diff line number Diff line change
Expand Up @@ -602,6 +602,12 @@ class Objects(using Context @constructorOnly):
case (ValueSet(values), b : ValueElement) => ValueSet(values + b)
case (a : ValueElement, b : ValueElement) => ValueSet(ListSet(a, b))

def remove(b: Value): Value = (a, b) match
case (ValueSet(values1), b: ValueElement) => ValueSet(values1 - b)
case (ValueSet(values1), ValueSet(values2)) => ValueSet(values1.removedAll(values2))
case (a: Ref, b: Ref) if a.equals(b) => Bottom
case _ => a

def widen(height: Int)(using Context): Value =
if height == 0 then Cold
else
Expand Down Expand Up @@ -1341,29 +1347,25 @@ class Objects(using Context @constructorOnly):
def getMemberMethod(receiver: Type, name: TermName, tp: Type): Denotation =
receiver.member(name).suchThat(receiver.memberInfo(_) <:< tp)

def evalCase(caseDef: CaseDef): Value =
evalPattern(scrutinee, caseDef.pat)
eval(caseDef.guard, thisV, klass)
eval(caseDef.body, thisV, klass)

/** Abstract evaluation of patterns.
*
* It augments the local environment for bound pattern variables. As symbols are globally
* unique, we can put them in a single environment.
*
* Currently, we assume all cases are reachable, thus all patterns are assumed to match.
*/
def evalPattern(scrutinee: Value, pat: Tree): Value = log("match " + scrutinee.show + " against " + pat.show, printer, (_: Value).show):
def evalPattern(scrutinee: Value, pat: Tree): (Type, Value) = log("match " + scrutinee.show + " against " + pat.show, printer, (_: (Type, Value))._2.show):
val trace2 = Trace.trace.add(pat)
pat match
case Alternative(pats) =>
for pat <- pats do evalPattern(scrutinee, pat)
scrutinee
val (types, values) = pats.map(evalPattern(scrutinee, _)).unzip()
val orType = types.fold(defn.NothingType)(OrType(_, _, false))
(orType, values.join)

case bind @ Bind(_, pat) =>
val value = evalPattern(scrutinee, pat)
val (tpe, value) = evalPattern(scrutinee, pat)
initLocal(bind.symbol, value)
scrutinee
(tpe, value)

case UnApply(fun, implicits, pats) =>
given Trace = trace2
Expand All @@ -1372,6 +1374,10 @@ class Objects(using Context @constructorOnly):
val funRef = fun1.tpe.asInstanceOf[TermRef]
val unapplyResTp = funRef.widen.finalResultType

val receiverType = fun1 match
case ident: Ident => funRef.prefix
case select: Select => select.qualifier.tpe

val receiver = fun1 match
case ident: Ident =>
evalType(funRef.prefix, thisV, klass)
Expand Down Expand Up @@ -1460,17 +1466,18 @@ class Objects(using Context @constructorOnly):
end if
end if
end if
scrutinee
(receiverType, scrutinee.filterType(receiverType))

case Ident(nme.WILDCARD) | Ident(nme.WILDCARD_STAR) =>
scrutinee
(defn.ThrowableType, scrutinee)

case Typed(pat, _) =>
evalPattern(scrutinee, pat)
case Typed(pat, typeTree) =>
val (_, value) = evalPattern(scrutinee.filterType(typeTree.tpe), pat)
(typeTree.tpe, value)

case tree =>
// For all other trees, the semantics is normal.
eval(tree, thisV, klass)
(defn.ThrowableType, eval(tree, thisV, klass))

end evalPattern

Expand All @@ -1494,12 +1501,12 @@ class Objects(using Context @constructorOnly):
if isWildcardStarArgList(pats) then
if pats.size == 1 then
// call .toSeq
val toSeqDenot = getMemberMethod(scrutineeType, nme.toSeq, toSeqType(elemType))
val toSeqDenot = scrutineeType.member(nme.toSeq).suchThat(_.info.isParameterless)
val toSeqRes = call(scrutinee, toSeqDenot.symbol, Nil, scrutineeType, superType = NoType, needResolve = true)
evalPattern(toSeqRes, pats.head)
else
// call .drop
val dropDenot = getMemberMethod(scrutineeType, nme.drop, dropType(elemType))
val dropDenot = getMemberMethod(scrutineeType, nme.drop, applyType(elemType))
val dropRes = call(scrutinee, dropDenot.symbol, ArgInfo(Bottom, summon[Trace], EmptyTree) :: Nil, scrutineeType, superType = NoType, needResolve = true)
for pat <- pats.init do evalPattern(applyRes, pat)
evalPattern(dropRes, pats.last)
Expand All @@ -1510,8 +1517,21 @@ class Objects(using Context @constructorOnly):
end if
end evalSeqPatterns

def canSkipCase(remainingScrutinee: Value, catchValue: Value) =
(remainingScrutinee == Bottom && scrutinee != Bottom) ||
(catchValue == Bottom && remainingScrutinee != Bottom)

cases.map(evalCase).join
var remainingScrutinee = scrutinee
val caseResults: mutable.ArrayBuffer[Value] = mutable.ArrayBuffer()
for caseDef <- cases do
val (tpe, value) = evalPattern(remainingScrutinee, caseDef.pat)
eval(caseDef.guard, thisV, klass)
if !canSkipCase(remainingScrutinee, value) then
caseResults.addOne(eval(caseDef.body, thisV, klass))
if catchesAllOf(caseDef, tpe) then
remainingScrutinee = remainingScrutinee.remove(value)

caseResults.join
end patternMatch

/** Handle semantics of leaf nodes
Expand Down

0 comments on commit 966a174

Please sign in to comment.