diff --git a/compiler/src/dotty/tools/dotc/ast/Desugar.scala b/compiler/src/dotty/tools/dotc/ast/Desugar.scala index e66c71731b4f..56c153498f87 100644 --- a/compiler/src/dotty/tools/dotc/ast/Desugar.scala +++ b/compiler/src/dotty/tools/dotc/ast/Desugar.scala @@ -52,6 +52,10 @@ object desugar { */ val ContextBoundParam: Property.Key[Unit] = Property.StickyKey() + /** Marks a poly fcuntion apply method, so that we can handle adding evidence parameters to them in a special way + */ + val PolyFunctionApply: Property.Key[Unit] = Property.StickyKey() + /** What static check should be applied to a Match? */ enum MatchCheck { case None, Exhaustive, IrrefutablePatDef, IrrefutableGenFrom @@ -242,7 +246,7 @@ object desugar { * def f$default$2[T](x: Int) = x + "m" */ private def defDef(meth: DefDef, isPrimaryConstructor: Boolean = false)(using Context): Tree = - addDefaultGetters(elimContextBounds(meth, isPrimaryConstructor)) + addDefaultGetters(elimContextBounds(meth, isPrimaryConstructor).asInstanceOf[DefDef]) /** Drop context bounds in given TypeDef, replacing them with evidence ValDefs that * get added to a buffer. @@ -304,10 +308,8 @@ object desugar { tdef1 end desugarContextBounds - private def elimContextBounds(meth: DefDef, isPrimaryConstructor: Boolean)(using Context): DefDef = - val DefDef(_, paramss, tpt, rhs) = meth + def elimContextBounds(meth: Tree, isPrimaryConstructor: Boolean = false)(using Context): Tree = val evidenceParamBuf = mutable.ListBuffer[ValDef]() - var seenContextBounds: Int = 0 def freshName(unused: Tree) = seenContextBounds += 1 // Start at 1 like FreshNameCreator. @@ -317,7 +319,7 @@ object desugar { // parameters of the method since shadowing does not affect // implicit resolution in Scala 3. - val paramssNoContextBounds = + def paramssNoContextBounds(paramss: List[ParamClause]): List[ParamClause] = val iflag = paramss.lastOption.flatMap(_.headOption) match case Some(param) if param.mods.isOneOf(GivenOrImplicit) => param.mods.flags & GivenOrImplicit @@ -329,15 +331,32 @@ object desugar { tparam => desugarContextBounds(tparam, evidenceParamBuf, flags, freshName, paramss) }(identity) - rhs match - case MacroTree(call) => - cpy.DefDef(meth)(rhs = call).withMods(meth.mods | Macro | Erased) - case _ => - addEvidenceParams( - cpy.DefDef(meth)( - name = normalizeName(meth, tpt).asTermName, - paramss = paramssNoContextBounds), - evidenceParamBuf.toList) + meth match + case meth @ DefDef(_, paramss, tpt, rhs) => + val newParamss = paramssNoContextBounds(paramss) + rhs match + case MacroTree(call) => + cpy.DefDef(meth)(rhs = call).withMods(meth.mods | Macro | Erased) + case _ => + addEvidenceParams( + cpy.DefDef(meth)( + name = normalizeName(meth, tpt).asTermName, + paramss = newParamss + ), + evidenceParamBuf.toList + ) + case meth @ PolyFunction(tparams, fun) => + val PolyFunction(tparams: List[untpd.TypeDef] @unchecked, fun) = meth: @unchecked + val Function(vparams: List[untpd.ValDef] @unchecked, rhs) = fun: @unchecked + val newParamss = paramssNoContextBounds(tparams :: vparams :: Nil) + val params = evidenceParamBuf.toList + if params.isEmpty then + meth + else + val boundNames = getBoundNames(params, newParamss) + val recur = fitEvidenceParams(params, nme.apply, boundNames) + val (paramsFst, paramsSnd) = recur(newParamss) + functionsOf((paramsFst ++ paramsSnd).filter(_.nonEmpty), rhs) end elimContextBounds def addDefaultGetters(meth: DefDef)(using Context): Tree = @@ -465,6 +484,74 @@ object desugar { case _ => (Nil, tree) + private def referencesName(vdef: ValDef, names: Set[TermName])(using Context): Boolean = + vdef.tpt.existsSubTree: + case Ident(name: TermName) => names.contains(name) + case _ => false + + /** Fit evidence `params` into the `mparamss` parameter lists, making sure + * that all parameters referencing `params` are after them. + * - for methods the final parameter lists are := result._1 ++ result._2 + * - for poly functions, each element of the pair contains at most one term + * parameter list + * + * @param params the evidence parameters list that should fit into `mparamss` + * @param methName the name of the method that `mparamss` belongs to + * @param boundNames the names of the evidence parameters + * @param mparamss the original parameter lists of the method + * @return a pair of parameter lists containing all parameter lists in a + * reference-correct order; make sure that `params` is always at the + * intersection of the pair elements; this is relevant, for poly functions + * where `mparamss` is guaranteed to have exectly one term parameter list, + * then each pair element will have at most one term parameter list + */ + private def fitEvidenceParams( + params: List[ValDef], + methName: Name, + boundNames: Set[TermName] + )(mparamss: List[ParamClause])(using Context): (List[ParamClause], List[ParamClause]) = mparamss match + case ValDefs(mparams) :: _ if mparams.exists(referencesName(_, boundNames)) => + (params :: Nil) -> mparamss + case ValDefs(mparams @ (mparam :: _)) :: Nil if mparam.mods.isOneOf(GivenOrImplicit) => + val normParams = + if params.head.mods.flags.is(Given) != mparam.mods.flags.is(Given) then + params.map: param => + val normFlags = param.mods.flags &~ GivenOrImplicit | (mparam.mods.flags & (GivenOrImplicit)) + param.withMods(param.mods.withFlags(normFlags)) + .showing(i"adapted param $result ${result.mods.flags} for ${methName}", Printers.desugar) + else params + ((normParams ++ mparams) :: Nil) -> Nil + case mparams :: mparamss1 => + val (fst, snd) = fitEvidenceParams(params, methName, boundNames)(mparamss1) + (mparams :: fst) -> snd + case Nil => + Nil -> (params :: Nil) + + /** Create a chain of possibly contextual functions from the parameter lists */ + private def functionsOf(paramss: List[ParamClause], rhs: Tree)(using Context): Tree = paramss match + case Nil => rhs + case ValDefs(head @ (fst :: _)) :: rest if fst.mods.isOneOf(GivenOrImplicit) => + val paramTpts = head.map(_.tpt) + val paramNames = head.map(_.name) + val paramsErased = head.map(_.mods.flags.is(Erased)) + makeContextualFunction(paramTpts, paramNames, functionsOf(rest, rhs), paramsErased).withSpan(rhs.span) + case ValDefs(head) :: rest => + Function(head, functionsOf(rest, rhs)) + case TypeDefs(head) :: rest => + PolyFunction(head, functionsOf(rest, rhs)) + case _ => + assert(false, i"unexpected paramss $paramss") + EmptyTree + + private def getBoundNames(params: List[ValDef], paramss: List[ParamClause])(using Context): Set[TermName] = + var boundNames = params.map(_.name).toSet // all evidence parameter + context bound proxy names + for mparams <- paramss; mparam <- mparams do + mparam match + case tparam: TypeDef if tparam.mods.annotations.exists(WitnessNamesAnnot.unapply(_).isDefined) => + boundNames += tparam.name.toTermName + case _ => + boundNames + /** Add all evidence parameters in `params` as implicit parameters to `meth`. * The position of the added parameters is determined as follows: * @@ -479,36 +566,23 @@ object desugar { private def addEvidenceParams(meth: DefDef, params: List[ValDef])(using Context): DefDef = if params.isEmpty then return meth - var boundNames = params.map(_.name).toSet // all evidence parameter + context bound proxy names - for mparams <- meth.paramss; mparam <- mparams do - mparam match - case tparam: TypeDef if tparam.mods.annotations.exists(WitnessNamesAnnot.unapply(_).isDefined) => - boundNames += tparam.name.toTermName - case _ => + val boundNames = getBoundNames(params, meth.paramss) - def referencesBoundName(vdef: ValDef): Boolean = - vdef.tpt.existsSubTree: - case Ident(name: TermName) => boundNames.contains(name) - case _ => false + val fitParams = fitEvidenceParams(params, meth.name, boundNames) - def recur(mparamss: List[ParamClause]): List[ParamClause] = mparamss match - case ValDefs(mparams) :: _ if mparams.exists(referencesBoundName) => - params :: mparamss - case ValDefs(mparams @ (mparam :: _)) :: Nil if mparam.mods.isOneOf(GivenOrImplicit) => - val normParams = - if params.head.mods.flags.is(Given) != mparam.mods.flags.is(Given) then - params.map: param => - val normFlags = param.mods.flags &~ GivenOrImplicit | (mparam.mods.flags & (GivenOrImplicit)) - param.withMods(param.mods.withFlags(normFlags)) - .showing(i"adapted param $result ${result.mods.flags} for ${meth.name}", Printers.desugar) - else params - (normParams ++ mparams) :: Nil - case mparams :: mparamss1 => - mparams :: recur(mparamss1) - case Nil => - params :: Nil - - cpy.DefDef(meth)(paramss = recur(meth.paramss)) + if meth.removeAttachment(PolyFunctionApply).isDefined then + // for PolyFunctions we are limited to a single term param list, so we + // reuse the fitEvidenceParams logic to compute the new parameter lists + // and then we add the other parameter lists as function types to the + // return type + val (paramsFst, paramsSnd) = fitParams(meth.paramss) + if ctx.mode.is(Mode.Type) then + cpy.DefDef(meth)(paramss = paramsFst, tpt = functionsOf(paramsSnd, meth.tpt)) + else + cpy.DefDef(meth)(paramss = paramsFst, rhs = functionsOf(paramsSnd, meth.rhs)) + else + val (paramsFst, paramsSnd) = fitParams(meth.paramss) + cpy.DefDef(meth)(paramss = paramsFst ++ paramsSnd) end addEvidenceParams /** The parameters generated from the contextual bounds of `meth`, as generated by `desugar.defDef` */ @@ -1224,27 +1298,29 @@ object desugar { /** Desugar [T_1, ..., T_M] => (P_1, ..., P_N) => R * Into scala.PolyFunction { def apply[T_1, ..., T_M](x$1: P_1, ..., x$N: P_N): R } */ - def makePolyFunctionType(tree: PolyFunction)(using Context): RefinedTypeTree = - val PolyFunction(tparams: List[untpd.TypeDef] @unchecked, fun @ untpd.Function(vparamTypes, res)) = tree: @unchecked - val paramFlags = fun match - case fun: FunctionWithMods => - // TODO: make use of this in the desugaring when pureFuns is enabled. - // val isImpure = funFlags.is(Impure) - - // Function flags to be propagated to each parameter in the desugared method type. - val givenFlag = fun.mods.flags.toTermFlags & Given - fun.erasedParams.map(isErased => if isErased then givenFlag | Erased else givenFlag) - case _ => - vparamTypes.map(_ => EmptyFlags) - - val vparams = vparamTypes.lazyZip(paramFlags).zipWithIndex.map { - case ((p: ValDef, paramFlags), n) => p.withAddedFlags(paramFlags) - case ((p, paramFlags), n) => makeSyntheticParameter(n + 1, p).withAddedFlags(paramFlags) - }.toList - - RefinedTypeTree(ref(defn.PolyFunctionType), List( - DefDef(nme.apply, tparams :: vparams :: Nil, res, EmptyTree).withFlags(Synthetic) - )).withSpan(tree.span) + def makePolyFunctionType(tree: PolyFunction)(using Context): RefinedTypeTree = (tree: @unchecked) match + case PolyFunction(tparams: List[untpd.TypeDef] @unchecked, fun @ untpd.Function(vparamTypes, res)) => + val paramFlags = fun match + case fun: FunctionWithMods => + // TODO: make use of this in the desugaring when pureFuns is enabled. + // val isImpure = funFlags.is(Impure) + + // Function flags to be propagated to each parameter in the desugared method type. + val givenFlag = fun.mods.flags.toTermFlags & Given + fun.erasedParams.map(isErased => if isErased then givenFlag | Erased else givenFlag) + case _ => + vparamTypes.map(_ => EmptyFlags) + + val vparams = vparamTypes.lazyZip(paramFlags).zipWithIndex.map { + case ((p: ValDef, paramFlags), n) => p.withAddedFlags(paramFlags) + case ((p, paramFlags), n) => makeSyntheticParameter(n + 1, p).withAddedFlags(paramFlags) + }.toList + + RefinedTypeTree(ref(defn.PolyFunctionType), List( + DefDef(nme.apply, tparams :: vparams :: Nil, res, EmptyTree) + .withFlags(Synthetic) + .withAttachment(PolyFunctionApply, ()) + )).withSpan(tree.span) end makePolyFunctionType /** Invent a name for an anonympus given of type or template `impl`. */ diff --git a/compiler/src/dotty/tools/dotc/parsing/Parsers.scala b/compiler/src/dotty/tools/dotc/parsing/Parsers.scala index 2bcc6ae1ce9f..2e441553689c 100644 --- a/compiler/src/dotty/tools/dotc/parsing/Parsers.scala +++ b/compiler/src/dotty/tools/dotc/parsing/Parsers.scala @@ -3459,7 +3459,7 @@ object Parsers { * * TypTypeParamClause::= ‘[’ TypTypeParam {‘,’ TypTypeParam} ‘]’ * TypTypeParam ::= {Annotation} - * (id | ‘_’) [HkTypeParamClause] TypeBounds + * (id | ‘_’) [HkTypeParamClause] TypeAndCtxBounds * * HkTypeParamClause ::= ‘[’ HkTypeParam {‘,’ HkTypeParam} ‘]’ * HkTypeParam ::= {Annotation} [‘+’ | ‘-’] @@ -3490,7 +3490,9 @@ object Parsers { else ident().toTypeName val hkparams = typeParamClauseOpt(ParamOwner.Hk) val bounds = - if paramOwner.acceptsCtxBounds then typeAndCtxBounds(name) else typeBounds() + if paramOwner.acceptsCtxBounds then typeAndCtxBounds(name) + else if in.featureEnabled(Feature.modularity) && paramOwner == ParamOwner.Type then typeAndCtxBounds(name) + else typeBounds() TypeDef(name, lambdaAbstract(hkparams, bounds)).withMods(mods) } } diff --git a/compiler/src/dotty/tools/dotc/typer/Typer.scala b/compiler/src/dotty/tools/dotc/typer/Typer.scala index f8241edc941a..5463386fa771 100644 --- a/compiler/src/dotty/tools/dotc/typer/Typer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Typer.scala @@ -1920,7 +1920,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer def typedPolyFunction(tree: untpd.PolyFunction, pt: Type)(using Context): Tree = val tree1 = desugar.normalizePolyFunction(tree) if (ctx.mode is Mode.Type) typed(desugar.makePolyFunctionType(tree1), pt) - else typedPolyFunctionValue(tree1, pt) + else typedPolyFunctionValue(desugar.elimContextBounds(tree1).asInstanceOf[untpd.PolyFunction], pt) def typedPolyFunctionValue(tree: untpd.PolyFunction, pt: Type)(using Context): Tree = val untpd.PolyFunction(tparams: List[untpd.TypeDef] @unchecked, fun) = tree: @unchecked @@ -2474,7 +2474,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer val TypeDef(_, impl: Template) = typed(refineClsDef): @unchecked val refinements1 = impl.body val seen = mutable.Set[Symbol]() - for (refinement <- refinements1) { // TODO: get clarity whether we want to enforce these conditions + for refinement <- refinements1 do // TODO: get clarity whether we want to enforce these conditions typr.println(s"adding refinement $refinement") checkRefinementNonCyclic(refinement, refineCls, seen) val rsym = refinement.symbol @@ -2488,7 +2488,6 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer val member = refineCls.info.member(rsym.name) if (member.isOverloaded) report.error(OverloadInRefinement(rsym), refinement.srcPos) - } assignType(cpy.RefinedTypeTree(tree)(tpt1, refinements1), tpt1, refinements1, refineCls) } diff --git a/tests/pos/contextbounds-for-poly-functions.scala b/tests/pos/contextbounds-for-poly-functions.scala new file mode 100644 index 000000000000..13411a3ad769 --- /dev/null +++ b/tests/pos/contextbounds-for-poly-functions.scala @@ -0,0 +1,93 @@ +import scala.language.experimental.modularity +import scala.language.future + +trait Ord[X]: + def compare(x: X, y: X): Int + type T + +trait Show[X]: + def show(x: X): String + +val less0: [X: Ord] => (X, X) => Boolean = ??? + +val less1 = [X: Ord] => (x: X, y: X) => summon[Ord[X]].compare(x, y) < 0 + +type PolyTest1 = [X] => X => Ord[X] ?=> Boolean + +val less1_type_test: [X: Ord] => (X, X) => Boolean = + [X: Ord] => (x: X, y: X) => summon[Ord[X]].compare(x, y) < 0 + +val less2 = [X: Ord as ord] => (x: X, y: X) => ord.compare(x, y) < 0 + +val less2_type_test: [X: Ord as ord] => (X, X) => Boolean = + [X: Ord as ord] => (x: X, y: X) => ord.compare(x, y) < 0 + +type CtxFunctionRef = Ord[Int] ?=> Boolean +type ComparerRef = [X] => (x: X, y: X) => Ord[X] ?=> Boolean +type Comparer = [X: Ord] => (x: X, y: X) => Boolean +val less3: Comparer = [X: Ord as ord] => (x: X, y: X) => ord.compare(x, y) < 0 + +type CmpRest[X] = X => Boolean +type CmpMid[X] = X => CmpRest[X] +type Cmp3 = [X: Ord] => X => CmpMid[X] +val lessCmp3: Cmp3 = [X: Ord] => (x: X) => (y: X) => (z: X) => summon[Ord[X]].compare(x, y) < 0 +val lessCmp3_1: Cmp3 = [X: Ord as ord] => (x: X) => (y: X) => (z: X) => ord.compare(x, y) < 0 + +// type Cmp[X] = (x: X, y: X) => Boolean +// type Comparer2 = [X: Ord] => Cmp[X] +// val less4: Comparer2 = [X: Ord] => (x: X, y: X) => summon[Ord[X]].compare(x, y) < 0 + +type CmpWeak[X] = X => Boolean +type Comparer2Weak = [X: Ord] => X => CmpWeak[X] +val less4_0: [X: Ord] => X => X => Boolean = + [X: Ord] => (x: X) => (y: X) => summon[Ord[X]].compare(x, y) < 0 +val less4_1: Comparer2Weak = + [X: Ord] => (x: X) => (y: X) => summon[Ord[X]].compare(x, y) < 0 + +val less5 = [X: [X] =>> Ord[X]] => (x: X, y: X) => summon[Ord[X]].compare(x, y) < 0 + +val less5_type_test: [X: [X] =>> Ord[X]] => (X, X) => Boolean = + [X: [X] =>> Ord[X]] => (x: X, y: X) => summon[Ord[X]].compare(x, y) < 0 + +val less6 = [X: {Ord, Show}] => (x: X, y: X) => summon[Ord[X]].compare(x, y) < 0 + +val less6_type_test: [X: {Ord, Show}] => (X, X) => Boolean = + [X: {Ord, Show}] => (x: X, y: X) => summon[Ord[X]].compare(x, y) < 0 + +val less7 = [X: {Ord as ord, Show}] => (x: X, y: X) => ord.compare(x, y) < 0 + +val less7_type_test: [X: {Ord as ord, Show}] => (X, X) => Boolean = + [X: {Ord as ord, Show}] => (x: X, y: X) => ord.compare(x, y) < 0 + +val less8 = [X: {Ord, Show as show}] => (x: X, y: X) => summon[Ord[X]].compare(x, y) < 0 + +val less8_type_test: [X: {Ord, Show as show}] => (X, X) => Boolean = + [X: {Ord, Show as show}] => (x: X, y: X) => summon[Ord[X]].compare(x, y) < 0 + +val less9 = [X: {Ord as ord, Show as show}] => (x: X, y: X) => ord.compare(x, y) < 0 + +val less9_type_test: [X: {Ord as ord, Show as show}] => (X, X) => Boolean = + [X: {Ord as ord, Show as show}] => (x: X, y: X) => ord.compare(x, y) < 0 + +type CmpNested = [X: Ord] => X => [Y: Ord] => Y => Boolean +val less10: CmpNested = [X: Ord] => (x: X) => [Y: Ord] => (y: Y) => true +val less10Explicit: CmpNested = [X] => (x: X) => (ordx: Ord[X]) ?=> [Y] => (y: Y) => (ordy: Ord[Y]) ?=> true + +type CmpAlias[X] = X => Boolean +type CmpNestedAliased = [X: Ord] => X => [Y] => Y => CmpAlias[Y] + +val less11: CmpNestedAliased = [X: Ord] => (x: X) => [Y] => (y: Y) => (y1: Y) => true +val less11Explicit: CmpNestedAliased = [X] => (x: X) => (ordx: Ord[X]) ?=> [Y] => (y: Y) => (y1: Y) => true + +val notationalExample: [X: Ord] => X => [Y: Ord] => Y => Int = + [X] => (x: X) => (ordx: Ord[X]) ?=> [Y] => (y: Y) => (ordy: Ord[Y]) ?=> 1 + +val namedConstraintRef = [X: {Ord as ord}] => (x: ord.T) => x +type DependentCmp = [X: {Ord as ord}] => ord.T => Boolean +type DependentCmp1 = [X: {Ord as ord}] => (ord.T, Int) => ord.T => Boolean +val dependentCmp: DependentCmp = [X: {Ord as ord}] => (x: ord.T) => true +val dependentCmp_1: [X: {Ord as ord}] => ord.T => Boolean = [X: {Ord as ord}] => (x: ord.T) => true + +val dependentCmp1: DependentCmp1 = [X: {Ord as ord}] => (x: ord.T, y: Int) => (z: ord.T) => true +val dependentCmp1_1: [X: {Ord as ord}] => (ord.T, Int) => ord.T => Boolean = + [X: {Ord as ord}] => (x: ord.T, y: Int) => (z: ord.T) => true diff --git a/tests/run/contextbounds-for-poly-functions.check b/tests/run/contextbounds-for-poly-functions.check new file mode 100644 index 000000000000..2e7f62a3914f --- /dev/null +++ b/tests/run/contextbounds-for-poly-functions.check @@ -0,0 +1,6 @@ +42 +a string +Kate is 27 years old +42 and a string +a string and Kate is 27 years old +Kate is 27 years old and 42 diff --git a/tests/run/contextbounds-for-poly-functions.scala b/tests/run/contextbounds-for-poly-functions.scala new file mode 100644 index 000000000000..dcc974fce198 --- /dev/null +++ b/tests/run/contextbounds-for-poly-functions.scala @@ -0,0 +1,30 @@ +import scala.language.experimental.modularity +import scala.language.future + +trait Show[X]: + def show(x: X): String + +given Show[Int] with + def show(x: Int) = x.toString + +given Show[String] with + def show(x: String) = x + +case class Person(name: String, age: Int) + +given Show[Person] with + def show(x: Person) = s"${x.name} is ${x.age} years old" + +type Shower = [X: Show] => X => String +val shower: Shower = [X: {Show as show}] => (x: X) => show.show(x) + +type DoubleShower = [X: Show] => X => [Y: Show] => Y => String +val doubleShower: DoubleShower = [X: {Show as show1}] => (x: X) => [Y: {Show as show2}] => (y: Y) => s"${show1.show(x)} and ${show2.show(y)}" + +object Test extends App: + println(shower(42)) + println(shower("a string")) + println(shower(Person("Kate", 27))) + println(doubleShower(42)("a string")) + println(doubleShower("a string")(Person("Kate", 27))) + println(doubleShower(Person("Kate", 27))(42))