Skip to content

Commit

Permalink
Attempt to beta reduce only if parameter and argument lists have same…
Browse files Browse the repository at this point in the history
… shape

It's possible to define Functions with wrong apply methods by hand which will
give an error but pass on a function that does fails beta reduction.

Fixes #21952
  • Loading branch information
odersky committed Nov 18, 2024
1 parent 58f88a6 commit 6896ea2
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 23 deletions.
57 changes: 37 additions & 20 deletions compiler/src/dotty/tools/dotc/transform/BetaReduce.scala
Original file line number Diff line number Diff line change
Expand Up @@ -76,10 +76,10 @@ object BetaReduce:
val bindingsBuf = new ListBuffer[DefTree]
def recur(fn: Tree, argss: List[List[Tree]]): Option[Tree] = fn match
case Block((ddef : DefDef) :: Nil, closure: Closure) if ddef.symbol == closure.meth.symbol =>
Some(reduceApplication(ddef, argss, bindingsBuf))
reduceApplication(ddef, argss, bindingsBuf)
case Block((TypeDef(_, template: Template)) :: Nil, Typed(Apply(Select(New(_), _), _), _)) if template.constr.rhs.isEmpty =>
template.body match
case (ddef: DefDef) :: Nil => Some(reduceApplication(ddef, argss, bindingsBuf))
case (ddef: DefDef) :: Nil => reduceApplication(ddef, argss, bindingsBuf)
case _ => None
case Block(stats, expr) if stats.forall(isPureBinding) =>
recur(expr, argss).map(cpy.Block(fn)(stats, _))
Expand All @@ -106,12 +106,22 @@ object BetaReduce:
case _ =>
tree

/** Beta-reduces a call to `ddef` with arguments `args` and registers new bindings */
def reduceApplication(ddef: DefDef, argss: List[List[Tree]], bindings: ListBuffer[DefTree])(using Context): Tree =
/** Beta-reduces a call to `ddef` with arguments `args` and registers new bindings.
* @return optionally, the expanded call, or none if the actual argument
* lists do not match in shape the formal parameters
*/
def reduceApplication(ddef: DefDef, argss: List[List[Tree]], bindings: ListBuffer[DefTree])
(using Context): Option[Tree] =
val (targs, args) = argss.flatten.partition(_.isType)
val tparams = ddef.leadingTypeParams
val vparams = ddef.termParamss.flatten

def shapeMatch(paramss: List[ParamClause], argss: List[List[Tree]]): Boolean = (paramss, argss) match
case (params :: paramss1, args :: argss1) if params.length == args.length =>
shapeMatch(paramss1, argss1)
case (Nil, Nil) => true
case _ => false

val targSyms =
for (targ, tparam) <- targs.zip(tparams) yield
targ.tpe.dealias match
Expand Down Expand Up @@ -143,19 +153,26 @@ object BetaReduce:
bindings += binding.withSpan(arg.span)
bindingSymbol

val expansion = TreeTypeMap(
oldOwners = ddef.symbol :: Nil,
newOwners = ctx.owner :: Nil,
substFrom = (tparams ::: vparams).map(_.symbol),
substTo = targSyms ::: argSyms
).transform(ddef.rhs)

val expansion1 = new TreeMap {
override def transform(tree: Tree)(using Context) = tree.tpe.widenTermRefExpr match
case ConstantType(const) if isPureExpr(tree) => cpy.Literal(tree)(const)
case tpe: TypeRef if tree.isTerm && tpe.derivesFrom(defn.UnitClass) && isPureExpr(tree) =>
cpy.Literal(tree)(Constant(()))
case _ => super.transform(tree)
}.transform(expansion)

expansion1
if shapeMatch(ddef.paramss, argss) then
// We can't assume arguments always match. It's possible to construct a
// function with wrong apply method by hand which causes `shapeMatch` to fail.
// See neg/i21952.scala
val expansion = TreeTypeMap(
oldOwners = ddef.symbol :: Nil,
newOwners = ctx.owner :: Nil,
substFrom = (tparams ::: vparams).map(_.symbol),
substTo = targSyms ::: argSyms
).transform(ddef.rhs)

val expansion1 = new TreeMap {
override def transform(tree: Tree)(using Context) = tree.tpe.widenTermRefExpr match
case ConstantType(const) if isPureExpr(tree) => cpy.Literal(tree)(const)
case tpe: TypeRef if tree.isTerm && tpe.derivesFrom(defn.UnitClass) && isPureExpr(tree) =>
cpy.Literal(tree)(Constant(()))
case _ => super.transform(tree)
}.transform(expansion)

Some(expansion1)
else None
end reduceApplication
end BetaReduce
8 changes: 5 additions & 3 deletions compiler/src/dotty/tools/dotc/transform/InlinePatterns.scala
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,11 @@ class InlinePatterns extends MiniPhase:
template.body match
case List(ddef @ DefDef(`name`, _, _, _)) =>
val bindings = new ListBuffer[DefTree]()
val expansion1 = BetaReduce.reduceApplication(ddef, argss, bindings)
val bindings1 = bindings.result()
seq(bindings1, expansion1)
BetaReduce.reduceApplication(ddef, argss, bindings) match
case Some(expansion1) =>
val bindings1 = bindings.result()
seq(bindings1, expansion1)
case None => tree
case _ => tree
case _ => tree

Expand Down
1 change: 1 addition & 0 deletions tests/neg/i21952.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
val _ = (new Function[(Int, Int), Int] {def apply(a: Int, b: Int): Int = a * b})(2, 3) // error

0 comments on commit 6896ea2

Please sign in to comment.