diff --git a/compiler/src/dotty/tools/dotc/transform/BetaReduce.scala b/compiler/src/dotty/tools/dotc/transform/BetaReduce.scala index 60c1bc7c61bb..16219055b8c0 100644 --- a/compiler/src/dotty/tools/dotc/transform/BetaReduce.scala +++ b/compiler/src/dotty/tools/dotc/transform/BetaReduce.scala @@ -76,10 +76,10 @@ object BetaReduce: val bindingsBuf = new ListBuffer[DefTree] def recur(fn: Tree, argss: List[List[Tree]]): Option[Tree] = fn match case Block((ddef : DefDef) :: Nil, closure: Closure) if ddef.symbol == closure.meth.symbol => - Some(reduceApplication(ddef, argss, bindingsBuf)) + reduceApplication(ddef, argss, bindingsBuf) case Block((TypeDef(_, template: Template)) :: Nil, Typed(Apply(Select(New(_), _), _), _)) if template.constr.rhs.isEmpty => template.body match - case (ddef: DefDef) :: Nil => Some(reduceApplication(ddef, argss, bindingsBuf)) + case (ddef: DefDef) :: Nil => reduceApplication(ddef, argss, bindingsBuf) case _ => None case Block(stats, expr) if stats.forall(isPureBinding) => recur(expr, argss).map(cpy.Block(fn)(stats, _)) @@ -106,12 +106,22 @@ object BetaReduce: case _ => tree - /** Beta-reduces a call to `ddef` with arguments `args` and registers new bindings */ - def reduceApplication(ddef: DefDef, argss: List[List[Tree]], bindings: ListBuffer[DefTree])(using Context): Tree = + /** Beta-reduces a call to `ddef` with arguments `args` and registers new bindings. + * @return optionally, the expanded call, or none if the actual argument + * lists do not match in shape the formal parameters + */ + def reduceApplication(ddef: DefDef, argss: List[List[Tree]], bindings: ListBuffer[DefTree]) + (using Context): Option[Tree] = val (targs, args) = argss.flatten.partition(_.isType) val tparams = ddef.leadingTypeParams val vparams = ddef.termParamss.flatten + def shapeMatch(paramss: List[ParamClause], argss: List[List[Tree]]): Boolean = (paramss, argss) match + case (params :: paramss1, args :: argss1) if params.length == args.length => + shapeMatch(paramss1, argss1) + case (Nil, Nil) => true + case _ => false + val targSyms = for (targ, tparam) <- targs.zip(tparams) yield targ.tpe.dealias match @@ -143,19 +153,26 @@ object BetaReduce: bindings += binding.withSpan(arg.span) bindingSymbol - val expansion = TreeTypeMap( - oldOwners = ddef.symbol :: Nil, - newOwners = ctx.owner :: Nil, - substFrom = (tparams ::: vparams).map(_.symbol), - substTo = targSyms ::: argSyms - ).transform(ddef.rhs) - - val expansion1 = new TreeMap { - override def transform(tree: Tree)(using Context) = tree.tpe.widenTermRefExpr match - case ConstantType(const) if isPureExpr(tree) => cpy.Literal(tree)(const) - case tpe: TypeRef if tree.isTerm && tpe.derivesFrom(defn.UnitClass) && isPureExpr(tree) => - cpy.Literal(tree)(Constant(())) - case _ => super.transform(tree) - }.transform(expansion) - - expansion1 + if shapeMatch(ddef.paramss, argss) then + // We can't assume arguments always match. It's possible to construct a + // function with wrong apply method by hand which causes `shapeMatch` to fail. + // See neg/i21952.scala + val expansion = TreeTypeMap( + oldOwners = ddef.symbol :: Nil, + newOwners = ctx.owner :: Nil, + substFrom = (tparams ::: vparams).map(_.symbol), + substTo = targSyms ::: argSyms + ).transform(ddef.rhs) + + val expansion1 = new TreeMap { + override def transform(tree: Tree)(using Context) = tree.tpe.widenTermRefExpr match + case ConstantType(const) if isPureExpr(tree) => cpy.Literal(tree)(const) + case tpe: TypeRef if tree.isTerm && tpe.derivesFrom(defn.UnitClass) && isPureExpr(tree) => + cpy.Literal(tree)(Constant(())) + case _ => super.transform(tree) + }.transform(expansion) + + Some(expansion1) + else None + end reduceApplication +end BetaReduce \ No newline at end of file diff --git a/compiler/src/dotty/tools/dotc/transform/InlinePatterns.scala b/compiler/src/dotty/tools/dotc/transform/InlinePatterns.scala index 18333ae506fd..d2a72e10fcfc 100644 --- a/compiler/src/dotty/tools/dotc/transform/InlinePatterns.scala +++ b/compiler/src/dotty/tools/dotc/transform/InlinePatterns.scala @@ -60,9 +60,11 @@ class InlinePatterns extends MiniPhase: template.body match case List(ddef @ DefDef(`name`, _, _, _)) => val bindings = new ListBuffer[DefTree]() - val expansion1 = BetaReduce.reduceApplication(ddef, argss, bindings) - val bindings1 = bindings.result() - seq(bindings1, expansion1) + BetaReduce.reduceApplication(ddef, argss, bindings) match + case Some(expansion1) => + val bindings1 = bindings.result() + seq(bindings1, expansion1) + case None => tree case _ => tree case _ => tree diff --git a/tests/neg/i21952.scala b/tests/neg/i21952.scala new file mode 100644 index 000000000000..0365d82463c0 --- /dev/null +++ b/tests/neg/i21952.scala @@ -0,0 +1 @@ +val _ = (new Function[(Int, Int), Int] {def apply(a: Int, b: Int): Int = a * b})(2, 3) // error