From 61a43bca89cc2d27eddf4f6a32d26833f1b1f27e Mon Sep 17 00:00:00 2001 From: Dale Wijnand Date: Wed, 9 Oct 2024 14:59:42 +0100 Subject: [PATCH] Allow autotupling if fn's param is a type param --- .../src/dotty/tools/dotc/typer/Applications.scala | 14 +++++++++++++- tests/pos/i21682.1.scala | 15 +++++++++++++++ tests/pos/i21682.2.scala | 7 +++++++ tests/pos/i21682.3.scala | 4 ++++ 4 files changed, 39 insertions(+), 1 deletion(-) create mode 100644 tests/pos/i21682.1.scala create mode 100644 tests/pos/i21682.2.scala create mode 100644 tests/pos/i21682.3.scala diff --git a/compiler/src/dotty/tools/dotc/typer/Applications.scala b/compiler/src/dotty/tools/dotc/typer/Applications.scala index 6bb95e20fcaf..5fb91694b8a6 100644 --- a/compiler/src/dotty/tools/dotc/typer/Applications.scala +++ b/compiler/src/dotty/tools/dotc/typer/Applications.scala @@ -2237,7 +2237,19 @@ trait Applications extends Compatibility { def isCorrectUnaryFunction(alt: TermRef): Boolean = val formals = params(alt) - formals.length == 1 && ptIsCorrectProduct(formals.head, args) + formals.length == 1 && { + formals.head match + case formal: TypeParamRef => + // While `formal` isn't a tuple type of the correct arity, + // it's a type parameter (a method type parameter presumably) + // so check its bounds allow for a tuple type of the correct arity. + // See i21682 for an example. + val tup = defn.tupleType(args.map(v => if v.tpt.isEmpty then WildcardType else typedAheadType(v.tpt).tpe)) + val TypeBounds(lo, hi) = formal.paramInfo + lo <:< tup && tup <:< hi + case formal => + ptIsCorrectProduct(formal, args) + } val numArgs = args.length if numArgs > 1 diff --git a/tests/pos/i21682.1.scala b/tests/pos/i21682.1.scala new file mode 100644 index 000000000000..7340edcaeb4d --- /dev/null +++ b/tests/pos/i21682.1.scala @@ -0,0 +1,15 @@ +sealed abstract class Gen[+T1] +given [T2]: Conversion[T2, Gen[T2]] = ??? + +trait Show[T3] +given Show[Boolean] = ??? +given [A1: Show, B1: Show, C1: Show]: Show[(A1, B1, C1)] = ??? + +object ForAll: + def apply[A2: Show, B2](f: A2 => B2): Unit = ??? + def apply[A3: Show, B3: Show, C3](f: (A3, B3) => C3): Unit = ??? + def apply[A4: Show, B4](gen: Gen[A4])(f: A4 => B4): Unit = ??? + +@main def Test = + ForAll: (b1: Boolean, b2: Boolean, b3: Boolean) => + ??? diff --git a/tests/pos/i21682.2.scala b/tests/pos/i21682.2.scala new file mode 100644 index 000000000000..6717d36c78a6 --- /dev/null +++ b/tests/pos/i21682.2.scala @@ -0,0 +1,7 @@ +object ForAll: + def apply[A1, B](f: A1 => B): Unit = ??? + def apply[A1, A2, B](f: (A1, A2) => B): Unit = ??? + +@main def Test = + ForAll: (b1: Boolean, b2: Boolean, b3: Boolean) => + ??? diff --git a/tests/pos/i21682.3.scala b/tests/pos/i21682.3.scala new file mode 100644 index 000000000000..b44b9a7c91fc --- /dev/null +++ b/tests/pos/i21682.3.scala @@ -0,0 +1,4 @@ +class Test: + def foo[A1 >: (Nothing, Boolean, Nothing) <: (Any, Boolean, Any), B](f: A1 => B): Unit = ??? + def test(): Unit = + val res4 = this.foo((b1: Boolean, b2: Boolean, b3: Boolean) => ???)