diff --git a/plugin/src/main/scala-2/BetterToStringPlugin.scala b/plugin/src/main/scala-2/BetterToStringPlugin.scala index 9844c8c..699bd83 100644 --- a/plugin/src/main/scala-2/BetterToStringPlugin.scala +++ b/plugin/src/main/scala-2/BetterToStringPlugin.scala @@ -22,27 +22,33 @@ final class BetterToStringPluginComponent(val global: Global) extends PluginComp override val phaseName: String = "better-tostring-phase" override val runsAfter: List[String] = List("parser") - private val impl: BetterToStringImpl[Scala2CompilerApi[global.type]] = - BetterToStringImpl.instance(Scala2CompilerApi.instance(global)) + private val api: Scala2CompilerApi[global.type] = Scala2CompilerApi.instance(global) + private val impl = BetterToStringImpl.instance(api) private def modifyClasses(tree: Tree, enclosingObject: Option[ModuleDef]): Tree = tree match { - case p: PackageDef => p.copy(stats = p.stats.map(modifyClasses(_, None))) - // https://github.com/polyvariant/better-tostring/issues/59 - // start here - ModuleDef which is a case object should be transformed. - // We might need to change the type of CompilerApi#Clazz to allow objects. - case m: ModuleDef => + case p: PackageDef => p.copy(stats = p.stats.map(modifyClasses(_, None))) + + case m: ModuleDef if m.mods.isCase => + // isNested=false for the same reason as in the ClassDef case + impl.transformClass(api.Classable.Obj(m), isNested = false, enclosingObject).merge + + case m: ModuleDef => m.copy(impl = m.impl.copy(body = m.impl.body.map(modifyClasses(_, Some(m))))) + case clazz: ClassDef => - impl.transformClass( - clazz, - // If it was nested, we wouldn't be in this branch. - // Scala 2.x compiler API limitation (classes can't tell what the owner is). - // This should be more optimal as we don't traverse every template, but it hasn't been benchmarked. - isNested = false, - enclosingObject - ) - case other => other + impl + .transformClass( + api.Classable.Clazz(clazz), + // If it was nested, we wouldn't be in this branch. + // Scala 2.x compiler API limitation (classes can't tell what the owner is). + // This should be more optimal as we don't traverse every template, but it hasn't been benchmarked. + isNested = false, + enclosingObject + ) + .merge + + case other => other } override def newPhase(prev: Phase): Phase = new StdPhase(prev) { diff --git a/plugin/src/main/scala-2/Scala2CompilerApi.scala b/plugin/src/main/scala-2/Scala2CompilerApi.scala index ea3c006..93d92a8 100644 --- a/plugin/src/main/scala-2/Scala2CompilerApi.scala +++ b/plugin/src/main/scala-2/Scala2CompilerApi.scala @@ -6,8 +6,36 @@ import scala.tools.nsc.Global trait Scala2CompilerApi[G <: Global] extends CompilerApi { val theGlobal: G import theGlobal._ + + sealed trait Classable extends Product with Serializable { + + def bimap( + clazz: ClassDef => ClassDef, + obj: ModuleDef => ModuleDef + ): Classable = this match { + case Classable.Clazz(c) => Classable.Clazz(clazz(c)) + case Classable.Obj(o) => Classable.Obj(obj(o)) + } + + def fold[A]( + clazz: ClassDef => A, + obj: ModuleDef => A + ): A = this match { + case Classable.Clazz(c) => clazz(c) + case Classable.Obj(o) => obj(o) + } + + def merge: ImplDef = fold(identity, identity) + + } + + object Classable { + case class Clazz(c: ClassDef) extends Classable + case class Obj(o: ModuleDef) extends Classable + } + type Tree = theGlobal.Tree - type Clazz = ClassDef + type Clazz = Classable type Param = ValDef type ParamName = TermName type Method = DefDef @@ -21,11 +49,14 @@ object Scala2CompilerApi { val theGlobal: global.type = global import global._ - def params(clazz: Clazz): List[Param] = clazz.impl.body.collect { - case v: ValDef if v.mods.hasFlag(Flags.CASEACCESSOR) => v - } + def params(clazz: Clazz): List[Param] = clazz.fold( + clazz = _.impl.body.collect { + case v: ValDef if v.mods.isCaseAccessor => v + }, + obj = _ => Nil + ) - def className(clazz: Clazz): String = clazz.name.toString + def className(clazz: Clazz): String = clazz.merge.name.toString def isPackageOrPackageObject(enclosingObject: EnclosingObject): Boolean = // couldn't find any nice api for this. `m.symbol.isPackageObject` does not work after the parser compiler phase (needs to run later). @@ -46,16 +77,27 @@ object Scala2CompilerApi { body ) - def addMethod(clazz: Clazz, method: Method): Clazz = - clazz.copy(impl = clazz.impl.copy(body = clazz.impl.body :+ method)) + def addMethod(clazz: Clazz, method: Method): Clazz = { + val newBody = clazz.merge.impl.copy(body = clazz.merge.impl.body :+ method) + clazz.bimap( + clazz = _.copy(impl = newBody), + obj = _.copy(impl = newBody) + ) + } - def methodNames(clazz: Clazz): List[String] = clazz.impl.body.collect { + def methodNames(clazz: Clazz): List[String] = clazz.merge.impl.body.collect { case d: DefDef => d.name.toString case d: ValDef => d.name.toString } - def isCaseClass(clazz: Clazz): Boolean = clazz.mods.hasFlag(Flags.CASE) - def isObject(clazz: Clazz): Boolean = clazz.mods.hasFlag(Flags.MODULE) + def isCaseClass(clazz: Clazz): Boolean = clazz.merge.mods.isCase + + // Always return true for ModuleDef - apparently ModuleDef doesn't have the module flag... + def isObject(clazz: Clazz): Boolean = clazz.fold( + clazz = _.mods.hasModuleFlag, + obj = _ => true + ) + } } diff --git a/tests/src/test/scala/Tests.scala b/tests/src/test/scala/Tests.scala index c6d57fb..b60b19a 100644 --- a/tests/src/test/scala/Tests.scala +++ b/tests/src/test/scala/Tests.scala @@ -1,6 +1,4 @@ import munit.FunSuite -import munit.TestOptions -import b2s.buildinfo.BuildInfo class Tests extends FunSuite { @@ -52,8 +50,7 @@ class Tests extends FunSuite { ) } - // https://github.com/polyvariant/better-tostring/issues/59 - test(onlyScala3("Case object nested in an object should include enclosing object's name")) { + test("Case object nested in an object should include enclosing object's name") { assertEquals( ObjectNestedParent.ObjectNestedObject.toString, "ObjectNestedParent.ObjectNestedObject" @@ -109,11 +106,6 @@ class Tests extends FunSuite { ) } - def onlyScala3(name: String) = { - val isScala3 = BuildInfo.scalaVersion.startsWith("3") - if (isScala3) name: TestOptions else name.fail - } - } case object CaseObject