Skip to content

Commit

Permalink
path dependent types support added to scala 3
Browse files Browse the repository at this point in the history
  • Loading branch information
goshacodes committed Feb 26, 2024
1 parent 10f89c6 commit 29dd436
Show file tree
Hide file tree
Showing 3 changed files with 169 additions and 37 deletions.
6 changes: 3 additions & 3 deletions shared/src/main/scala-3/org/scalamock/clazz/MockMaker.scala
Original file line number Diff line number Diff line change
Expand Up @@ -91,15 +91,15 @@ private[clazz] object MockMaker:
Symbol.newVal(
parent = classSymbol,
name = definition.symbol.name,
tpe = definition.tpeWithSubstitutedPathDependentFor(classSymbol),
tpe = definition.tpeWithSubstitutedInnerTypesFor(classSymbol),
flags = Flags.Override,
privateWithin = Symbol.noSymbol
)
else
Symbol.newMethod(
parent = classSymbol,
name = definition.symbol.name,
tpe = definition.tpeWithSubstitutedPathDependentFor(classSymbol),
tpe = definition.tpeWithSubstitutedInnerTypesFor(classSymbol),
flags = Flags.Override,
privateWithin = Symbol.noSymbol
)
Expand Down Expand Up @@ -177,7 +177,7 @@ private[clazz] object MockMaker:
"asInstanceOf"
),
definition.tpe
.resolveParamRefs(definition.resTypeWithPathDependentOverrideFor(classSymbol), args)
.resolveParamRefs(definition.resTypeWithInnerTypesOverrideFor(classSymbol), args)
.asType match { case '[t] => List(TypeTree.of[t]) }
)
)
Expand Down
93 changes: 59 additions & 34 deletions shared/src/main/scala-3/org/scalamock/clazz/Utils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,22 @@ package org.scalamock.clazz
import scala.quoted.*
import org.scalamock.context.MockContext

import scala.annotation.tailrec
import scala.annotation.{experimental, tailrec}
private[clazz] class Utils(using val quotes: Quotes):
import quotes.reflect.*

extension (tpe: TypeRepr)
def collectPathDependent(ownerSymbol: Symbol): List[TypeRepr] =
def collectInnerTypes(ownerSymbol: Symbol): List[TypeRepr] =
def loop(currentTpe: TypeRepr, names: List[String]): List[TypeRepr] =
currentTpe match
case AppliedType(inner, appliedTypes) => loop(inner, names) ++ appliedTypes.flatMap(_.collectPathDependent(ownerSymbol))
case AppliedType(inner, appliedTypes) => loop(inner, names) ++ appliedTypes.flatMap(_.collectInnerTypes(ownerSymbol))
case TypeRef(inner, name) if name == ownerSymbol.name && names.nonEmpty => List(tpe)
case TypeRef(inner, name) => loop(inner, name :: names)
case _ => Nil

loop(tpe, Nil)

def pathDependentOverride(ownerSymbol: Symbol, newOwnerSymbol: Symbol, applyTypes: Boolean): TypeRepr =
def innerTypeOverride(ownerSymbol: Symbol, newOwnerSymbol: Symbol, applyTypes: Boolean): TypeRepr =
@tailrec
def loop(currentTpe: TypeRepr, names: List[(String, List[TypeRepr])], appliedTypes: List[TypeRepr]): TypeRepr =
currentTpe match
Expand Down Expand Up @@ -53,55 +53,80 @@ private[clazz] class Utils(using val quotes: Quotes):
case _ =>
tpe

@experimental
def resolveParamRefs(resType: TypeRepr, methodArgs: List[List[Tree]]) =
def loop(baseBindings: TypeRepr, typeRepr: TypeRepr): TypeRepr =
typeRepr match
case pr@ParamRef(bindings, idx) if bindings == baseBindings =>
methodArgs.head(idx).asInstanceOf[TypeTree].tpe
tpe match
case baseBindings: PolyType =>
def loop(typeRepr: TypeRepr): TypeRepr =
typeRepr match
case pr@ParamRef(bindings, idx) if bindings == baseBindings =>
methodArgs.head(idx).asInstanceOf[TypeTree].tpe

case AppliedType(tycon, args) =>
AppliedType(tycon, args.map(arg => loop(baseBindings, arg)))
case AppliedType(tycon, args) =>
AppliedType(loop(tycon), args.map(arg => loop(arg)))

case other => other
case ff @ TypeRef(ref @ ParamRef(bindings, idx), name) =>
def getIndex(bindings: TypeRepr): Int =
@tailrec
def loop(bindings: TypeRepr, idx: Int): Int =
bindings match
case MethodType(_, _, method: MethodType) => loop(method, idx + 1)
case _ => idx

tpe match
case pt: PolyType => loop(pt, resType)
case _ => resType
loop(bindings, 1)

val maxIndex = methodArgs.length
val parameterListIdx = maxIndex - getIndex(bindings)

TypeSelect(methodArgs(parameterListIdx)(idx).asInstanceOf[Term], name).tpe

case other => other

loop(resType)
case _ =>
resType


def collectTypes: List[TypeRepr] =
def loop(currentTpe: TypeRepr, params: List[TypeRepr]): List[TypeRepr] =
def collectTypes: (List[TypeRepr], TypeRepr) =
@tailrec
def loop(currentTpe: TypeRepr, argTypesAcc: List[List[TypeRepr]], resType: TypeRepr): (List[TypeRepr], TypeRepr) =
currentTpe match
case PolyType(_, _, res) => loop(res, Nil)
case MethodType(_, argTypes, res) => argTypes ++ loop(res, params)
case other => List(other)
loop(tpe, Nil)
case PolyType(_, _, res) => loop(res, List.empty[TypeRepr] :: argTypesAcc, resType)
case MethodType(_, argTypes, res) => loop(res, argTypes :: argTypesAcc, resType)
case other => (argTypesAcc.reverse.flatten, other)
loop(tpe, Nil, TypeRepr.of[Nothing])

case class MockableDefinition(idx: Int, symbol: Symbol, ownerTpe: TypeRepr):
val mockValName = s"mock$$${symbol.name}$$$idx"
val tpe = ownerTpe.memberType(symbol)
private val rawTypes = tpe.widen.collectTypes
private val (rawTypes, rawResType) = tpe.widen.collectTypes
val parameterTypes = prepareTypesFor(ownerTpe.typeSymbol).map(_.tpe).init

def resTypeWithPathDependentOverrideFor(classSymbol: Symbol): TypeRepr =
val pd = rawTypes.last.collectPathDependent(ownerTpe.typeSymbol)
val pdUpdated = pd.map(_.pathDependentOverride(ownerTpe.typeSymbol, classSymbol, applyTypes = false))
rawTypes.last.substituteTypes(pd.map(_.typeSymbol), pdUpdated)
def resTypeWithInnerTypesOverrideFor(classSymbol: Symbol): TypeRepr =
updatePathDependent(rawResType, List(rawResType), classSymbol)

def tpeWithSubstitutedInnerTypesFor(classSymbol: Symbol): TypeRepr =
updatePathDependent(tpe, rawResType :: rawTypes, classSymbol)

def tpeWithSubstitutedPathDependentFor(classSymbol: Symbol): TypeRepr =
val pathDependentTypes = rawTypes.flatMap(_.collectPathDependent(ownerTpe.typeSymbol))
val pdUpdated = pathDependentTypes.map(_.pathDependentOverride(ownerTpe.typeSymbol, classSymbol, applyTypes = false))
tpe.substituteTypes(pathDependentTypes.map(_.typeSymbol), pdUpdated)
private def updatePathDependent(where: TypeRepr, types: List[TypeRepr], classSymbol: Symbol): TypeRepr =
val pathDependentTypes = types.flatMap(_.collectInnerTypes(ownerTpe.typeSymbol))
val pdUpdated = pathDependentTypes.map(_.innerTypeOverride(ownerTpe.typeSymbol, classSymbol, applyTypes = false))
where.substituteTypes(pathDependentTypes.map(_.typeSymbol), pdUpdated)

def prepareTypesFor(classSymbol: Symbol) = rawTypes
.map(_.pathDependentOverride(ownerTpe.typeSymbol, classSymbol, applyTypes = true))
def prepareTypesFor(classSymbol: Symbol) = (rawTypes :+ rawResType)
.map(_.innerTypeOverride(ownerTpe.typeSymbol, classSymbol, applyTypes = true))
.map { typeRepr =>
val adjusted =
typeRepr.widen.mapParamRefWithWildcard match
case TypeBounds(lower, upper) => upper
case AppliedType(TypeRef(_, "<repeated>"), elemTyps) =>
TypeRepr.typeConstructorOf(classOf[Seq[_]]).appliedTo(elemTyps)
case other => other
case TypeRef(_: ParamRef, _) =>
TypeRepr.of[Any]
case AppliedType(TypeRef(_: ParamRef, _), _) =>
TypeRepr.of[Any]
case other =>
other
adjusted.asType match
case '[t] => TypeTree.of[t]
}
Expand All @@ -128,9 +153,9 @@ private[clazz] class Utils(using val quotes: Quotes):

def apply(tpe: TypeRepr): List[MockableDefinition] =
val methods = (tpe.typeSymbol.methodMembers.toSet -- TypeRepr.of[Object].typeSymbol.methodMembers).toList
.filter(sym =>
.filter(sym =>
!sym.flags.is(Flags.Private) &&
!sym.flags.is(Flags.Final) &&
!sym.flags.is(Flags.Final) &&
!sym.flags.is(Flags.Mutable) &&
!sym.name.contains("$default$")
)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
package com.paulbutcher.test

import org.scalamock.matchers.Matchers
import org.scalamock.scalatest.MockFactory
import org.scalatest.funspec.AnyFunSpec

class PathDependentParamSpec extends AnyFunSpec with Matchers with MockFactory {

trait Command {
type Answer
type AnswerConstructor[A]
}

case class IntCommand() extends Command {
override type Answer = Int
override type AnswerConstructor[A] = Option[A]
}

val cmd = IntCommand()

trait PathDependent {

def call0[T <: Command](cmd: T): cmd.Answer

def call1[T <: Command](x: Int)(cmd: T): cmd.Answer

def call2[T <: Command](y: String)(cmd: T)(x: Int): cmd.Answer

def call3[T <: Command](cmd: T)(y: String)(x: Int): cmd.Answer

def call4[T <: Command](cmd: T): Option[cmd.Answer]

def call5[T <: Command](cmd: T)(x: cmd.Answer): Unit

def call6[T <: Command](cmd: T): cmd.AnswerConstructor[Int]

def call7[T <: Command](cmd: T)(x: cmd.AnswerConstructor[String])(y: cmd.Answer): Unit
}


it("path dependent in return type") {
val pathDependent = mock[PathDependent]

(pathDependent.call0[IntCommand] _).expects(cmd).returns(5)

assert(pathDependent.call0(cmd) == 5)
}

it("path dependent in return type and parameter in last parameter list") {
val pathDependent = mock[PathDependent]

(pathDependent.call1(_: Int)(_: IntCommand)).expects(5, cmd).returns(5)

assert(pathDependent.call1(5)(cmd) == 5)
}

it("path dependent in return type and parameter in middle parameter list ") {
val pathDependent = mock[PathDependent]

(pathDependent.call2(_: String)(_: IntCommand)(_: Int)).expects("5", cmd, 5).returns(5)

assert(pathDependent.call2("5")(cmd)(5) == 5)
}

it("path dependent in return type and parameter in first parameter list ") {
val pathDependent = mock[PathDependent]

(pathDependent.call3(_: IntCommand)(_: String)(_: Int)).expects(cmd, "5", 5).returns(5)

assert(pathDependent.call3(cmd)("5")(5) == 5)
}

it("path dependent in tycon return type") {
val pathDependent = mock[PathDependent]

(pathDependent.call4[IntCommand] _).expects(cmd).returns(Some(5))

assert(pathDependent.call4(cmd) == Some(5))
}

it("path dependent in parameter list") {
val pathDependent = mock[PathDependent]

(pathDependent.call5(_: IntCommand)(_: Int)).expects(cmd, 5).returns(())

assert(pathDependent.call5(cmd)(5) == ())
}

it("path dependent tycon in return type") {
val pathDependent = mock[PathDependent]

(pathDependent.call6[IntCommand] _).expects(cmd).returns(Some(5))

assert(pathDependent.call6(cmd) == Some(5))
}

it("path dependent tycon in parameter list") {
val pathDependent = mock[PathDependent]

(pathDependent.call7[IntCommand](_: IntCommand)(_: Option[String])(_: Int))
.expects(cmd, Some("5"), 6)
.returns(())

assert(pathDependent.call7(cmd)(Some("5"))(6) == ())
}

}

0 comments on commit 29dd436

Please sign in to comment.