Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Scala 3 fixes and improvements #509

Merged
merged 3 commits into from
Feb 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 13 additions & 7 deletions shared/src/main/scala-3/org/scalamock/clazz/MockMaker.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
package org.scalamock.clazz

import org.scalamock.context.MockContext

import scala.quoted.*
import scala.reflect.Selectable

Expand All @@ -42,16 +41,23 @@ private[clazz] object MockMaker:
def asParent(tree: TypeTree): TypeTree | Term =
val constructorFieldsFilledWithNulls: List[List[Term]] =
tree.tpe.dealias.typeSymbol.primaryConstructor.paramSymss
.filter(_.exists(!_.isType))
.map(_.map(_.typeRef.asType match { case '[t] => '{ null.asInstanceOf[t] }.asTerm }))
.filterNot(_.exists(_.isType))
.map(_.map(_.info.widen match {
case t@AppliedType(inner, applied) =>
Select.unique('{null}.asTerm, "asInstanceOf").appliedToTypes(List(inner.appliedTo(tpe.typeArgs)))
case other =>
Select.unique('{null}.asTerm, "asInstanceOf").appliedToTypes(List(other))
}))

if constructorFieldsFilledWithNulls.forall(_.isEmpty) then
tree
else
Select(
New(TypeIdent(tree.tpe.typeSymbol)),
tree.tpe.typeSymbol.primaryConstructor
).appliedToArgss(constructorFieldsFilledWithNulls)
).appliedToTypes(tree.tpe.typeArgs)
.appliedToArgss(constructorFieldsFilledWithNulls)



val parents =
Expand Down Expand Up @@ -91,15 +97,15 @@ private[clazz] object MockMaker:
Symbol.newVal(
parent = classSymbol,
name = definition.symbol.name,
tpe = definition.tpeWithSubstitutedPathDependentFor(classSymbol),
tpe = definition.tpeWithSubstitutedInnerTypesFor(classSymbol),
flags = Flags.Override,
privateWithin = Symbol.noSymbol
)
else
Symbol.newMethod(
parent = classSymbol,
name = definition.symbol.name,
tpe = definition.tpeWithSubstitutedPathDependentFor(classSymbol),
tpe = definition.tpeWithSubstitutedInnerTypesFor(classSymbol),
flags = Flags.Override,
privateWithin = Symbol.noSymbol
)
Expand Down Expand Up @@ -177,7 +183,7 @@ private[clazz] object MockMaker:
"asInstanceOf"
),
definition.tpe
.resolveParamRefs(definition.resTypeWithPathDependentOverrideFor(classSymbol), args)
.resolveParamRefs(definition.resTypeWithInnerTypesOverrideFor(classSymbol), args)
.asType match { case '[t] => List(TypeTree.of[t]) }
)
)
Expand Down
98 changes: 62 additions & 36 deletions shared/src/main/scala-3/org/scalamock/clazz/Utils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,22 @@ package org.scalamock.clazz
import scala.quoted.*
import org.scalamock.context.MockContext

import scala.annotation.tailrec
import scala.annotation.{experimental, tailrec}
private[clazz] class Utils(using val quotes: Quotes):
import quotes.reflect.*

extension (tpe: TypeRepr)
def collectPathDependent(ownerSymbol: Symbol): List[TypeRepr] =
def collectInnerTypes(ownerSymbol: Symbol): List[TypeRepr] =
def loop(currentTpe: TypeRepr, names: List[String]): List[TypeRepr] =
currentTpe match
case AppliedType(inner, appliedTypes) => loop(inner, names) ++ appliedTypes.flatMap(_.collectPathDependent(ownerSymbol))
case AppliedType(inner, appliedTypes) => loop(inner, names) ++ appliedTypes.flatMap(_.collectInnerTypes(ownerSymbol))
case TypeRef(inner, name) if name == ownerSymbol.name && names.nonEmpty => List(tpe)
case TypeRef(inner, name) => loop(inner, name :: names)
case _ => Nil

loop(tpe, Nil)

def pathDependentOverride(ownerSymbol: Symbol, newOwnerSymbol: Symbol, applyTypes: Boolean): TypeRepr =
def innerTypeOverride(ownerSymbol: Symbol, newOwnerSymbol: Symbol, applyTypes: Boolean): TypeRepr =
@tailrec
def loop(currentTpe: TypeRepr, names: List[(String, List[TypeRepr])], appliedTypes: List[TypeRepr]): TypeRepr =
currentTpe match
Expand Down Expand Up @@ -53,55 +53,80 @@ private[clazz] class Utils(using val quotes: Quotes):
case _ =>
tpe

@experimental
def resolveParamRefs(resType: TypeRepr, methodArgs: List[List[Tree]]) =
def loop(baseBindings: TypeRepr, typeRepr: TypeRepr): TypeRepr =
typeRepr match
case pr@ParamRef(bindings, idx) if bindings == baseBindings =>
methodArgs.head(idx).asInstanceOf[TypeTree].tpe
tpe match
case baseBindings: PolyType =>
def loop(typeRepr: TypeRepr): TypeRepr =
typeRepr match
case pr@ParamRef(bindings, idx) if bindings == baseBindings =>
methodArgs.head(idx).asInstanceOf[TypeTree].tpe

case AppliedType(tycon, args) =>
AppliedType(tycon, args.map(arg => loop(baseBindings, arg)))
case AppliedType(tycon, args) =>
AppliedType(loop(tycon), args.map(arg => loop(arg)))

case other => other
case ff @ TypeRef(ref @ ParamRef(bindings, idx), name) =>
def getIndex(bindings: TypeRepr): Int =
@tailrec
def loop(bindings: TypeRepr, idx: Int): Int =
bindings match
case MethodType(_, _, method: MethodType) => loop(method, idx + 1)
case _ => idx

tpe match
case pt: PolyType => loop(pt, resType)
case _ => resType
loop(bindings, 1)

val maxIndex = methodArgs.length
val parameterListIdx = maxIndex - getIndex(bindings)

TypeSelect(methodArgs(parameterListIdx)(idx).asInstanceOf[Term], name).tpe

case other => other

loop(resType)
case _ =>
resType


def collectTypes: List[TypeRepr] =
def loop(currentTpe: TypeRepr, params: List[TypeRepr]): List[TypeRepr] =
def collectTypes: (List[TypeRepr], TypeRepr) =
@tailrec
def loop(currentTpe: TypeRepr, argTypesAcc: List[List[TypeRepr]], resType: TypeRepr): (List[TypeRepr], TypeRepr) =
currentTpe match
case PolyType(_, _, res) => loop(res, Nil)
case MethodType(_, argTypes, res) => argTypes ++ loop(res, params)
case other => List(other)
loop(tpe, Nil)
case PolyType(_, _, res) => loop(res, List.empty[TypeRepr] :: argTypesAcc, resType)
case MethodType(_, argTypes, res) => loop(res, argTypes :: argTypesAcc, resType)
case other => (argTypesAcc.reverse.flatten, other)
loop(tpe, Nil, TypeRepr.of[Nothing])

case class MockableDefinition(idx: Int, symbol: Symbol, ownerTpe: TypeRepr):
val mockValName = s"mock$$${symbol.name}$$$idx"
val tpe = ownerTpe.memberType(symbol)
private val rawTypes = tpe.widen.collectTypes
private val (rawTypes, rawResType) = tpe.widen.collectTypes
val parameterTypes = prepareTypesFor(ownerTpe.typeSymbol).map(_.tpe).init

def resTypeWithPathDependentOverrideFor(classSymbol: Symbol): TypeRepr =
val pd = rawTypes.last.collectPathDependent(ownerTpe.typeSymbol)
val pdUpdated = pd.map(_.pathDependentOverride(ownerTpe.typeSymbol, classSymbol, applyTypes = false))
rawTypes.last.substituteTypes(pd.map(_.typeSymbol), pdUpdated)
def resTypeWithInnerTypesOverrideFor(classSymbol: Symbol): TypeRepr =
updatePathDependent(rawResType, List(rawResType), classSymbol)

def tpeWithSubstitutedInnerTypesFor(classSymbol: Symbol): TypeRepr =
updatePathDependent(tpe, rawResType :: rawTypes, classSymbol)

def tpeWithSubstitutedPathDependentFor(classSymbol: Symbol): TypeRepr =
val pathDependentTypes = rawTypes.flatMap(_.collectPathDependent(ownerTpe.typeSymbol))
val pdUpdated = pathDependentTypes.map(_.pathDependentOverride(ownerTpe.typeSymbol, classSymbol, applyTypes = false))
tpe.substituteTypes(pathDependentTypes.map(_.typeSymbol), pdUpdated)
private def updatePathDependent(where: TypeRepr, types: List[TypeRepr], classSymbol: Symbol): TypeRepr =
val pathDependentTypes = types.flatMap(_.collectInnerTypes(ownerTpe.typeSymbol))
val pdUpdated = pathDependentTypes.map(_.innerTypeOverride(ownerTpe.typeSymbol, classSymbol, applyTypes = false))
where.substituteTypes(pathDependentTypes.map(_.typeSymbol), pdUpdated)

def prepareTypesFor(classSymbol: Symbol) = rawTypes
.map(_.pathDependentOverride(ownerTpe.typeSymbol, classSymbol, applyTypes = true))
def prepareTypesFor(classSymbol: Symbol) = (rawTypes :+ rawResType)
.map(_.innerTypeOverride(ownerTpe.typeSymbol, classSymbol, applyTypes = true))
.map { typeRepr =>
val adjusted =
typeRepr.widen.mapParamRefWithWildcard match
case TypeBounds(lower, upper) => upper
case AppliedType(TypeRef(_, "<repeated>"), elemTyps) =>
TypeRepr.typeConstructorOf(classOf[Seq[_]]).appliedTo(elemTyps)
case other => other
case TypeRef(_: ParamRef, _) =>
TypeRepr.of[Any]
case AppliedType(TypeRef(_: ParamRef, _), _) =>
TypeRepr.of[Any]
case other =>
other
adjusted.asType match
case '[t] => TypeTree.of[t]
}
Expand All @@ -128,10 +153,11 @@ private[clazz] class Utils(using val quotes: Quotes):

def apply(tpe: TypeRepr): List[MockableDefinition] =
val methods = (tpe.typeSymbol.methodMembers.toSet -- TypeRepr.of[Object].typeSymbol.methodMembers).toList
.filter(sym => !sym.flags.is(Flags.Private) && !sym.flags.is(Flags.Final) && !sym.flags.is(Flags.Mutable))
.filterNot(sym => tpe.memberType(sym) match
case defaultParam @ ByNameType(AnnotatedType(_, Apply(Select(New(Inferred()), "<init>"), Nil))) => true
case _ => false
.filter(sym =>
!sym.flags.is(Flags.Private) &&
!sym.flags.is(Flags.Final) &&
!sym.flags.is(Flags.Mutable) &&
!sym.name.contains("$default$")
)
.zipWithIndex
.map((sym, idx) => MockableDefinition(idx, sym, tpe))
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package com.paulbutcher.test

import org.scalamock.scalatest.MockFactory
import org.scalatest.funspec.AnyFunSpec

import scala.reflect.ClassTag

class ClassWithContextBoundSpec extends AnyFunSpec with MockFactory {

it("compile without args") {
class ContextBounded[T: ClassTag] {
def method(x: Int): Unit = ()
}

val m = mock[ContextBounded[String]]

}

it("compile with args") {
class ContextBounded[T: ClassTag](x: Int) {
def method(x: Int): Unit = ()
}

val m = mock[ContextBounded[String]]

}

it("compile with provided explicitly type class") {
class ContextBounded[T](x: ClassTag[T]) {
def method(x: Int): Unit = ()
}

val m = mock[ContextBounded[String]]

}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
package com.paulbutcher.test

import org.scalamock.matchers.Matchers
import org.scalamock.scalatest.MockFactory
import org.scalatest.funspec.AnyFunSpec

class PathDependentParamSpec extends AnyFunSpec with Matchers with MockFactory {

trait Command {
type Answer
type AnswerConstructor[A]
}

case class IntCommand() extends Command {
override type Answer = Int
override type AnswerConstructor[A] = Option[A]
}

val cmd = IntCommand()

trait PathDependent {

def call0[T <: Command](cmd: T): cmd.Answer

def call1[T <: Command](x: Int)(cmd: T): cmd.Answer

def call2[T <: Command](y: String)(cmd: T)(x: Int): cmd.Answer

def call3[T <: Command](cmd: T)(y: String)(x: Int): cmd.Answer

def call4[T <: Command](cmd: T): Option[cmd.Answer]

def call5[T <: Command](cmd: T)(x: cmd.Answer): Unit

def call6[T <: Command](cmd: T): cmd.AnswerConstructor[Int]

def call7[T <: Command](cmd: T)(x: cmd.AnswerConstructor[String])(y: cmd.Answer): Unit
}


it("path dependent in return type") {
val pathDependent = mock[PathDependent]

(pathDependent.call0[IntCommand] _).expects(cmd).returns(5)

assert(pathDependent.call0(cmd) == 5)
}

it("path dependent in return type and parameter in last parameter list") {
val pathDependent = mock[PathDependent]

(pathDependent.call1(_: Int)(_: IntCommand)).expects(5, cmd).returns(5)

assert(pathDependent.call1(5)(cmd) == 5)
}

it("path dependent in return type and parameter in middle parameter list ") {
val pathDependent = mock[PathDependent]

(pathDependent.call2(_: String)(_: IntCommand)(_: Int)).expects("5", cmd, 5).returns(5)

assert(pathDependent.call2("5")(cmd)(5) == 5)
}

it("path dependent in return type and parameter in first parameter list ") {
val pathDependent = mock[PathDependent]

(pathDependent.call3(_: IntCommand)(_: String)(_: Int)).expects(cmd, "5", 5).returns(5)

assert(pathDependent.call3(cmd)("5")(5) == 5)
}

it("path dependent in tycon return type") {
val pathDependent = mock[PathDependent]

(pathDependent.call4[IntCommand] _).expects(cmd).returns(Some(5))

assert(pathDependent.call4(cmd) == Some(5))
}

it("path dependent in parameter list") {
val pathDependent = mock[PathDependent]

(pathDependent.call5(_: IntCommand)(_: Int)).expects(cmd, 5).returns(())

assert(pathDependent.call5(cmd)(5) == ())
}

it("path dependent tycon in return type") {
val pathDependent = mock[PathDependent]

(pathDependent.call6[IntCommand] _).expects(cmd).returns(Some(5))

assert(pathDependent.call6(cmd) == Some(5))
}

it("path dependent tycon in parameter list") {
val pathDependent = mock[PathDependent]

(pathDependent.call7[IntCommand](_: IntCommand)(_: Option[String])(_: Int))
.expects(cmd, Some("5"), 6)
.returns(())

assert(pathDependent.call7(cmd)(Some("5"))(6) == ())
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ class MethodsWithDefaultParamsTest extends IsolatedSpec {

trait TraitHavingMethodsWithDefaultParams {
def withAllDefaultParams(a: String = "default", b: CaseClass = CaseClass(42)): String

def withDefaultParamAndTypeParam[T](a: String = "default", b: Int = 5): T
}

behavior of "Mocks"
Expand Down Expand Up @@ -84,5 +86,13 @@ class MethodsWithDefaultParamsTest extends IsolatedSpec {
m.withAllDefaultParams("other", CaseClass(99))
}

they should "mock trait methods with type param and default parameters" in {
val m = mock[TraitHavingMethodsWithDefaultParams]

(m.withDefaultParamAndTypeParam[Int] _).expects("default", 5).returns(5)

m.withDefaultParamAndTypeParam[Int]("default", 5) shouldBe 5
}

override def newInstance = new MethodsWithDefaultParamsTest
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,12 @@

package org.scalamock.test.scalatest

import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest._
import org.scalatest.flatspec.{AnyFlatSpec, AsyncFlatSpec}
import org.scalamock.scalatest.{MockFactory, AsyncMockFactory}

/**
* Tests for issue #371
*/
@Ignore
class AsyncSyncMixinTest extends AnyFlatSpec {

"MockFactory" should "be mixed only with Any*Spec and not Async*Spec traits" in {
Expand Down
Loading