diff --git a/shared/src/main/scala-3/org/scalamock/clazz/MockMaker.scala b/shared/src/main/scala-3/org/scalamock/clazz/MockMaker.scala index e7514c28..b5f2da90 100644 --- a/shared/src/main/scala-3/org/scalamock/clazz/MockMaker.scala +++ b/shared/src/main/scala-3/org/scalamock/clazz/MockMaker.scala @@ -91,7 +91,7 @@ private[clazz] object MockMaker: Symbol.newVal( parent = classSymbol, name = definition.symbol.name, - tpe = definition.tpeWithSubstitutedPathDependentFor(classSymbol), + tpe = definition.tpeWithSubstitutedInnerTypesFor(classSymbol), flags = Flags.Override, privateWithin = Symbol.noSymbol ) @@ -99,7 +99,7 @@ private[clazz] object MockMaker: Symbol.newMethod( parent = classSymbol, name = definition.symbol.name, - tpe = definition.tpeWithSubstitutedPathDependentFor(classSymbol), + tpe = definition.tpeWithSubstitutedInnerTypesFor(classSymbol), flags = Flags.Override, privateWithin = Symbol.noSymbol ) @@ -177,7 +177,7 @@ private[clazz] object MockMaker: "asInstanceOf" ), definition.tpe - .resolveParamRefs(definition.resTypeWithPathDependentOverrideFor(classSymbol), args) + .resolveParamRefs(definition.resTypeWithInnerTypesOverrideFor(classSymbol), args) .asType match { case '[t] => List(TypeTree.of[t]) } ) ) diff --git a/shared/src/main/scala-3/org/scalamock/clazz/Utils.scala b/shared/src/main/scala-3/org/scalamock/clazz/Utils.scala index 52a5cf21..60bfb4dd 100644 --- a/shared/src/main/scala-3/org/scalamock/clazz/Utils.scala +++ b/shared/src/main/scala-3/org/scalamock/clazz/Utils.scala @@ -3,22 +3,22 @@ package org.scalamock.clazz import scala.quoted.* import org.scalamock.context.MockContext -import scala.annotation.tailrec +import scala.annotation.{experimental, tailrec} private[clazz] class Utils(using val quotes: Quotes): import quotes.reflect.* extension (tpe: TypeRepr) - def collectPathDependent(ownerSymbol: Symbol): List[TypeRepr] = + def collectInnerTypes(ownerSymbol: Symbol): List[TypeRepr] = def loop(currentTpe: TypeRepr, names: List[String]): List[TypeRepr] = currentTpe match - case AppliedType(inner, appliedTypes) => loop(inner, names) ++ appliedTypes.flatMap(_.collectPathDependent(ownerSymbol)) + case AppliedType(inner, appliedTypes) => loop(inner, names) ++ appliedTypes.flatMap(_.collectInnerTypes(ownerSymbol)) case TypeRef(inner, name) if name == ownerSymbol.name && names.nonEmpty => List(tpe) case TypeRef(inner, name) => loop(inner, name :: names) case _ => Nil loop(tpe, Nil) - def pathDependentOverride(ownerSymbol: Symbol, newOwnerSymbol: Symbol, applyTypes: Boolean): TypeRepr = + def innerTypeOverride(ownerSymbol: Symbol, newOwnerSymbol: Symbol, applyTypes: Boolean): TypeRepr = @tailrec def loop(currentTpe: TypeRepr, names: List[(String, List[TypeRepr])], appliedTypes: List[TypeRepr]): TypeRepr = currentTpe match @@ -53,55 +53,80 @@ private[clazz] class Utils(using val quotes: Quotes): case _ => tpe + @experimental def resolveParamRefs(resType: TypeRepr, methodArgs: List[List[Tree]]) = - def loop(baseBindings: TypeRepr, typeRepr: TypeRepr): TypeRepr = - typeRepr match - case pr@ParamRef(bindings, idx) if bindings == baseBindings => - methodArgs.head(idx).asInstanceOf[TypeTree].tpe + tpe match + case baseBindings: PolyType => + def loop(typeRepr: TypeRepr): TypeRepr = + typeRepr match + case pr@ParamRef(bindings, idx) if bindings == baseBindings => + methodArgs.head(idx).asInstanceOf[TypeTree].tpe - case AppliedType(tycon, args) => - AppliedType(tycon, args.map(arg => loop(baseBindings, arg))) + case AppliedType(tycon, args) => + AppliedType(loop(tycon), args.map(arg => loop(arg))) - case other => other + case ff @ TypeRef(ref @ ParamRef(bindings, idx), name) => + def getIndex(bindings: TypeRepr): Int = + @tailrec + def loop(bindings: TypeRepr, idx: Int): Int = + bindings match + case MethodType(_, _, method: MethodType) => loop(method, idx + 1) + case _ => idx - tpe match - case pt: PolyType => loop(pt, resType) - case _ => resType + loop(bindings, 1) + + val maxIndex = methodArgs.length + val parameterListIdx = maxIndex - getIndex(bindings) + + TypeSelect(methodArgs(parameterListIdx)(idx).asInstanceOf[Term], name).tpe + + case other => other + + loop(resType) + case _ => + resType - def collectTypes: List[TypeRepr] = - def loop(currentTpe: TypeRepr, params: List[TypeRepr]): List[TypeRepr] = + def collectTypes: (List[TypeRepr], TypeRepr) = + @tailrec + def loop(currentTpe: TypeRepr, argTypesAcc: List[List[TypeRepr]], resType: TypeRepr): (List[TypeRepr], TypeRepr) = currentTpe match - case PolyType(_, _, res) => loop(res, Nil) - case MethodType(_, argTypes, res) => argTypes ++ loop(res, params) - case other => List(other) - loop(tpe, Nil) + case PolyType(_, _, res) => loop(res, List.empty[TypeRepr] :: argTypesAcc, resType) + case MethodType(_, argTypes, res) => loop(res, argTypes :: argTypesAcc, resType) + case other => (argTypesAcc.reverse.flatten, other) + loop(tpe, Nil, TypeRepr.of[Nothing]) case class MockableDefinition(idx: Int, symbol: Symbol, ownerTpe: TypeRepr): val mockValName = s"mock$$${symbol.name}$$$idx" val tpe = ownerTpe.memberType(symbol) - private val rawTypes = tpe.widen.collectTypes + private val (rawTypes, rawResType) = tpe.widen.collectTypes val parameterTypes = prepareTypesFor(ownerTpe.typeSymbol).map(_.tpe).init - def resTypeWithPathDependentOverrideFor(classSymbol: Symbol): TypeRepr = - val pd = rawTypes.last.collectPathDependent(ownerTpe.typeSymbol) - val pdUpdated = pd.map(_.pathDependentOverride(ownerTpe.typeSymbol, classSymbol, applyTypes = false)) - rawTypes.last.substituteTypes(pd.map(_.typeSymbol), pdUpdated) + def resTypeWithInnerTypesOverrideFor(classSymbol: Symbol): TypeRepr = + updatePathDependent(rawResType, List(rawResType), classSymbol) + + def tpeWithSubstitutedInnerTypesFor(classSymbol: Symbol): TypeRepr = + updatePathDependent(tpe, rawResType :: rawTypes, classSymbol) - def tpeWithSubstitutedPathDependentFor(classSymbol: Symbol): TypeRepr = - val pathDependentTypes = rawTypes.flatMap(_.collectPathDependent(ownerTpe.typeSymbol)) - val pdUpdated = pathDependentTypes.map(_.pathDependentOverride(ownerTpe.typeSymbol, classSymbol, applyTypes = false)) - tpe.substituteTypes(pathDependentTypes.map(_.typeSymbol), pdUpdated) + private def updatePathDependent(where: TypeRepr, types: List[TypeRepr], classSymbol: Symbol): TypeRepr = + val pathDependentTypes = types.flatMap(_.collectInnerTypes(ownerTpe.typeSymbol)) + val pdUpdated = pathDependentTypes.map(_.innerTypeOverride(ownerTpe.typeSymbol, classSymbol, applyTypes = false)) + where.substituteTypes(pathDependentTypes.map(_.typeSymbol), pdUpdated) - def prepareTypesFor(classSymbol: Symbol) = rawTypes - .map(_.pathDependentOverride(ownerTpe.typeSymbol, classSymbol, applyTypes = true)) + def prepareTypesFor(classSymbol: Symbol) = (rawTypes :+ rawResType) + .map(_.innerTypeOverride(ownerTpe.typeSymbol, classSymbol, applyTypes = true)) .map { typeRepr => val adjusted = typeRepr.widen.mapParamRefWithWildcard match case TypeBounds(lower, upper) => upper case AppliedType(TypeRef(_, ""), elemTyps) => TypeRepr.typeConstructorOf(classOf[Seq[_]]).appliedTo(elemTyps) - case other => other + case TypeRef(_: ParamRef, _) => + TypeRepr.of[Any] + case AppliedType(TypeRef(_: ParamRef, _), _) => + TypeRepr.of[Any] + case other => + other adjusted.asType match case '[t] => TypeTree.of[t] } @@ -128,9 +153,9 @@ private[clazz] class Utils(using val quotes: Quotes): def apply(tpe: TypeRepr): List[MockableDefinition] = val methods = (tpe.typeSymbol.methodMembers.toSet -- TypeRepr.of[Object].typeSymbol.methodMembers).toList - .filter(sym => + .filter(sym => !sym.flags.is(Flags.Private) && - !sym.flags.is(Flags.Final) && + !sym.flags.is(Flags.Final) && !sym.flags.is(Flags.Mutable) && !sym.name.contains("$default$") ) diff --git a/shared/src/test/scala-3/com/paulbutcher/test/PathDependentParamSpec.scala b/shared/src/test/scala-3/com/paulbutcher/test/PathDependentParamSpec.scala new file mode 100644 index 00000000..3524ef43 --- /dev/null +++ b/shared/src/test/scala-3/com/paulbutcher/test/PathDependentParamSpec.scala @@ -0,0 +1,107 @@ +package com.paulbutcher.test + +import org.scalamock.matchers.Matchers +import org.scalamock.scalatest.MockFactory +import org.scalatest.funspec.AnyFunSpec + +class PathDependentParamSpec extends AnyFunSpec with Matchers with MockFactory { + + trait Command { + type Answer + type AnswerConstructor[A] + } + + case class IntCommand() extends Command { + override type Answer = Int + override type AnswerConstructor[A] = Option[A] + } + + val cmd = IntCommand() + + trait PathDependent { + + def call0[T <: Command](cmd: T): cmd.Answer + + def call1[T <: Command](x: Int)(cmd: T): cmd.Answer + + def call2[T <: Command](y: String)(cmd: T)(x: Int): cmd.Answer + + def call3[T <: Command](cmd: T)(y: String)(x: Int): cmd.Answer + + def call4[T <: Command](cmd: T): Option[cmd.Answer] + + def call5[T <: Command](cmd: T)(x: cmd.Answer): Unit + + def call6[T <: Command](cmd: T): cmd.AnswerConstructor[Int] + + def call7[T <: Command](cmd: T)(x: cmd.AnswerConstructor[String])(y: cmd.Answer): Unit + } + + + it("path dependent in return type") { + val pathDependent = mock[PathDependent] + + (pathDependent.call0[IntCommand] _).expects(cmd).returns(5) + + assert(pathDependent.call0(cmd) == 5) + } + + it("path dependent in return type and parameter in last parameter list") { + val pathDependent = mock[PathDependent] + + (pathDependent.call1(_: Int)(_: IntCommand)).expects(5, cmd).returns(5) + + assert(pathDependent.call1(5)(cmd) == 5) + } + + it("path dependent in return type and parameter in middle parameter list ") { + val pathDependent = mock[PathDependent] + + (pathDependent.call2(_: String)(_: IntCommand)(_: Int)).expects("5", cmd, 5).returns(5) + + assert(pathDependent.call2("5")(cmd)(5) == 5) + } + + it("path dependent in return type and parameter in first parameter list ") { + val pathDependent = mock[PathDependent] + + (pathDependent.call3(_: IntCommand)(_: String)(_: Int)).expects(cmd, "5", 5).returns(5) + + assert(pathDependent.call3(cmd)("5")(5) == 5) + } + + it("path dependent in tycon return type") { + val pathDependent = mock[PathDependent] + + (pathDependent.call4[IntCommand] _).expects(cmd).returns(Some(5)) + + assert(pathDependent.call4(cmd) == Some(5)) + } + + it("path dependent in parameter list") { + val pathDependent = mock[PathDependent] + + (pathDependent.call5(_: IntCommand)(_: Int)).expects(cmd, 5).returns(()) + + assert(pathDependent.call5(cmd)(5) == ()) + } + + it("path dependent tycon in return type") { + val pathDependent = mock[PathDependent] + + (pathDependent.call6[IntCommand] _).expects(cmd).returns(Some(5)) + + assert(pathDependent.call6(cmd) == Some(5)) + } + + it("path dependent tycon in parameter list") { + val pathDependent = mock[PathDependent] + + (pathDependent.call7[IntCommand](_: IntCommand)(_: Option[String])(_: Int)) + .expects(cmd, Some("5"), 6) + .returns(()) + + assert(pathDependent.call7(cmd)(Some("5"))(6) == ()) + } + +}