diff --git a/mtags/src/main/scala/scala/meta/internal/mtags/Mtags.scala b/mtags/src/main/scala/scala/meta/internal/mtags/Mtags.scala index af9233fbbb5..805e88eced0 100644 --- a/mtags/src/main/scala/scala/meta/internal/mtags/Mtags.scala +++ b/mtags/src/main/scala/scala/meta/internal/mtags/Mtags.scala @@ -34,6 +34,7 @@ final class Mtags(implicit rc: ReportContext) { addLines(language, input.text) mtags .index() + .textDocument .occurrences .iterator .filterNot(_.symbol.isPackage) @@ -55,8 +56,9 @@ final class Mtags(implicit rc: ReportContext) { JavaMtags .index(input, includeMembers = true) .index() + .textDocument } else if (language.isScala) { - ScalaMtags.index(input, dialect).index() + ScalaMtags.index(input, dialect).index().textDocument } else { TextDocument() } @@ -91,11 +93,11 @@ object Mtags { .toList } - def allToplevels( + def allToplevelsEnriched( input: Input.VirtualFile, dialect: Dialect, includeMembers: Boolean = true - )(implicit rc: ReportContext = EmptyReportContext): TextDocument = { + )(implicit rc: ReportContext = EmptyReportContext): EnrichedTextDocument = { input.toLanguage match { case Language.JAVA => new JavaMtags(input, includeMembers = true).index() @@ -104,9 +106,17 @@ object Mtags { new ScalaToplevelMtags(input, true, includeMembers, dialect) mtags.index() case _ => - TextDocument() + JustDocument(TextDocument()) } } + + def allToplevels( + input: Input.VirtualFile, + dialect: Dialect, + includeMembers: Boolean = true + )(implicit rc: ReportContext = EmptyReportContext): TextDocument = + allToplevelsEnriched(input, dialect, includeMembers).textDocument + def toplevels( input: Input.VirtualFile, dialect: Dialect = dialects.Scala213 diff --git a/mtags/src/main/scala/scala/meta/internal/mtags/MtagsIndexer.scala b/mtags/src/main/scala/scala/meta/internal/mtags/MtagsIndexer.scala index 730feaab2e1..09113451a7c 100644 --- a/mtags/src/main/scala/scala/meta/internal/mtags/MtagsIndexer.scala +++ b/mtags/src/main/scala/scala/meta/internal/mtags/MtagsIndexer.scala @@ -11,18 +11,21 @@ import scala.meta.internal.semanticdb.Scala._ import scala.meta.internal.semanticdb.SymbolInformation.Kind import scala.meta.internal.{semanticdb => s} -trait MtagsIndexer { +trait GenericMtagsIndexer[T <: EnrichedTextDocument] { def language: Language def indexRoot(): Unit def input: Input.VirtualFile - def index(): s.TextDocument = { + protected def documentToResult(doc: s.TextDocument): T + def index(): T = { indexRoot() - s.TextDocument( - uri = input.path, - text = input.text, - language = language, - occurrences = names.result(), - symbols = symbols.result() + documentToResult( + s.TextDocument( + uri = input.path, + text = input.text, + language = language, + occurrences = names.result(), + symbols = symbols.result() + ) ) } // This method is intentionally non-final to allow accessing this stream directly without building a s.TextDocument. @@ -64,8 +67,8 @@ trait MtagsIndexer { ) } } - def term(name: String, pos: m.Position, kind: Kind, properties: Int, overriddenSymbols: List[(String, m.Position)] = List.empty): String = - addSignature(Descriptor.Term(name), pos, kind, properties, overriddenSymbols) + def term(name: String, pos: m.Position, kind: Kind, properties: Int): String = + addSignature(Descriptor.Term(name), pos, kind, properties) def term(name: Term.Name, kind: Kind, properties: Int): String = addSignature(Descriptor.Term(name.value), name.pos, kind, properties) def tparam(name: Name, kind: Kind, properties: Int): String = @@ -122,8 +125,8 @@ trait MtagsIndexer { properties ) } - def tpe(name: String, pos: m.Position, kind: Kind, properties: Int, overriddenSymbols: List[(String, m.Position)] = List.empty): String = - addSignature(Descriptor.Type(name), pos, kind, properties, overriddenSymbols) + def tpe(name: String, pos: m.Position, kind: Kind, properties: Int): String = + addSignature(Descriptor.Type(name), pos, kind, properties) def tpe(name: Name, kind: Kind, properties: Int): String = addSignature(Descriptor.Type(name.value), name.pos, kind, properties) def pkg(name: String, pos: m.Position): String = { @@ -141,8 +144,7 @@ trait MtagsIndexer { signature: Descriptor, definition: m.Position, kind: s.SymbolInformation.Kind, - properties: Int, - overriddenSymbols: List[(String, m.Position)] = List.empty + properties: Int ): String = { val previousOwner = currentOwner currentOwner = symbol(signature) @@ -156,16 +158,12 @@ trait MtagsIndexer { syntax, role ) - val encodedOverriddenSymbols = overriddenSymbols.map{ - case (simpleName, pos) => UnresolvedOverriddenSymbol(simpleName, pos.start) - } val info = s.SymbolInformation( symbol = syntax, language = language, kind = kind, properties = properties, - displayName = signature.name.value, - overriddenSymbols = encodedOverriddenSymbols + displayName = signature.name.value ) visitOccurrence(occ, info, previousOwner) syntax @@ -175,3 +173,15 @@ trait MtagsIndexer { Symbols.Global(Symbols.RootPackage, signature) else Symbols.Global(currentOwner, signature) } + +trait EnrichedTextDocument { + def textDocument: s.TextDocument +} + +case class JustDocument(textDocument: s.TextDocument) + extends EnrichedTextDocument + +trait MtagsIndexer extends GenericMtagsIndexer[JustDocument] { + protected def documentToResult(doc: s.TextDocument): JustDocument = + JustDocument(doc) +} diff --git a/mtags/src/main/scala/scala/meta/internal/mtags/OverriddenSymbol.scala b/mtags/src/main/scala/scala/meta/internal/mtags/OverriddenSymbol.scala new file mode 100644 index 00000000000..259f1c961bf --- /dev/null +++ b/mtags/src/main/scala/scala/meta/internal/mtags/OverriddenSymbol.scala @@ -0,0 +1,6 @@ +package scala.meta.internal.mtags + +sealed trait OverriddenSymbol +case class UnresolvedOverriddenSymbol(name: String, pos: Int) + extends OverriddenSymbol +case class ResolvedOverriddenSymbol(symbol: String) extends OverriddenSymbol diff --git a/mtags/src/main/scala/scala/meta/internal/mtags/ScalaToplevelMtags.scala b/mtags/src/main/scala/scala/meta/internal/mtags/ScalaToplevelMtags.scala index ce4145b6dd9..a83c8156b38 100644 --- a/mtags/src/main/scala/scala/meta/internal/mtags/ScalaToplevelMtags.scala +++ b/mtags/src/main/scala/scala/meta/internal/mtags/ScalaToplevelMtags.scala @@ -14,6 +14,7 @@ import scala.meta.internal.semanticdb.Scala import scala.meta.internal.semanticdb.Scala._ import scala.meta.internal.semanticdb.SymbolInformation import scala.meta.internal.semanticdb.SymbolInformation.Kind +import scala.meta.internal.semanticdb.TextDocument import scala.meta.internal.tokenizers.LegacyScanner import scala.meta.internal.tokenizers.LegacyToken._ import scala.meta.tokenizers.TokenizeException @@ -46,7 +47,17 @@ class ScalaToplevelMtags( includeMembers: Boolean, dialect: Dialect )(implicit rc: ReportContext) - extends MtagsIndexer { + extends GenericMtagsIndexer[TextDocumentWithOverridden] { + + override protected def documentToResult( + doc: TextDocument + ): TextDocumentWithOverridden = + new TextDocumentWithOverridden(doc, overridden.result) + + private val overridden = List.newBuilder[(String, List[OverriddenSymbol])] + + private def addOverridden(symbols: List[OverriddenSymbol]) = + overridden.addOne(currentOwner, symbols) import ScalaToplevelMtags._ @@ -177,13 +188,13 @@ class ScalaToplevelMtags( newExpectExtensionTemplate(nextOwner) ) case CLASS | TRAIT | OBJECT | ENUM if needEmitMember(currRegion) => - val maybeNewIdent = emitMember(false, currRegion.owner) + emitMember(false, currRegion.owner) val template = expectTemplate match { case Some(expect) if expect.isCaseClassConstructor => newExpectCaseClassTemplate case _ => newExpectClassTemplate } - loop(maybeNewIdent.getOrElse(indent), isAfterNewline = maybeNewIdent.isDefined, currRegion, template) + loop(indent, isAfterNewline = false, currRegion, template) // also covers extension methods because of `def` inside case DEF // extension group @@ -402,6 +413,23 @@ class ScalaToplevelMtags( currRegion.changeCaseClassState(true), nextExpectTemplate ) + case EXTENDS => + val (overridden, maybeNewIdent) = findOverridden(List.empty) + expectTemplate.map(tmpl => + withOwner(tmpl.owner) { + addOverridden( + overridden.reverse.map(id => + UnresolvedOverriddenSymbol(id.name, id.pos.start) + ) + ) + } + ) + loop( + maybeNewIdent.getOrElse(indent), + isAfterNewline = maybeNewIdent.isDefined, + currRegion, + expectTemplate + ) case IDENTIFIER if currRegion.emitIdentifier && includeMembers => withOwner(currRegion.owner) { term( @@ -419,17 +447,14 @@ class ScalaToplevelMtags( ) case CASE => val nextIsNewLine = nextIsNL() - val (shouldCreateClassTemplate, isAfterNewline) = + val isAfterNewline = emitEnumCases(region, nextIsNewLine) - val nextExpectTemplate = - if (shouldCreateClassTemplate) newExpectClassTemplate - else expectTemplate.filter(!_.isPackageBody) loop( indent, isAfterNewline, currRegion, if (scanner.curr.token == CLASS) newExpectCaseClassTemplate - else nextExpectTemplate + else newExpectClassTemplate ) case t => val nextExpectTemplate = expectTemplate.filter(!_.isPackageBody) @@ -486,7 +511,8 @@ class ScalaToplevelMtags( } @tailrec - private def acceptAllAfterOverriddenIdentifier(): Unit = { + private def acceptAllAfterOverriddenIdentifier(): Option[Int] = { + val maybeNewIdent = acceptTrivia() scanner.curr.token match { case LPAREN => acceptBalancedDelimeters(LPAREN, RPAREN) @@ -494,50 +520,49 @@ class ScalaToplevelMtags( case LBRACKET => acceptBalancedDelimeters(LBRACKET, RBRACKET) acceptAllAfterOverriddenIdentifier() - case _ => + case _ => maybeNewIdent } } @tailrec - private def findOverridden(acc : List[Identifier]): (List[Identifier], Option[Int]) = { - val maybeNewIdent = acceptTrivia() + private def findOverridden( + acc: List[Identifier] + ): (List[Identifier], Option[Int]) = { + acceptTrivia() + val acc1 = newIdentifier.toList ++ acc + val maybeNewIdent = acceptAllAfterOverriddenIdentifier() scanner.curr.token match { - case EXTENDS | WITH => - acceptTrivia() - val curr = newIdentifier.toList - acceptAllAfterOverriddenIdentifier() - findOverridden(curr ++ acc) - case _ => (acc, maybeNewIdent) + case WITH => findOverridden(acc1) + case _ => (acc1, maybeNewIdent) } + } /** * Enters a toplevel symbol such as class, trait or object */ - def emitMember(isPackageObject: Boolean, owner: String): Option[Int] = { + def emitMember(isPackageObject: Boolean, owner: String): Unit = { val kind = scanner.curr.token acceptTrivia() val maybeName = newIdentifier currentOwner = owner - val (overridden0, maybeNewIdent) = findOverridden(List.empty) - val overridden = overridden0.map(id => (id.name, id.pos)) maybeName.foreach { name => kind match { case CLASS | ENUM => - tpe(name.name, name.pos, Kind.CLASS, 0, overridden) + tpe(name.name, name.pos, Kind.CLASS, 0) case TRAIT => - tpe(name.name, name.pos, Kind.TRAIT, 0, overridden) + tpe(name.name, name.pos, Kind.TRAIT, 0) case OBJECT => if (isPackageObject) { currentOwner = symbol(Scala.Descriptor.Package(name.name)) term("package", name.pos, Kind.OBJECT, 0) } else { - term(name.name, name.pos, Kind.OBJECT, 0, overridden) + term(name.name, name.pos, Kind.OBJECT, 0) } } } - maybeNewIdent + scanner.nextToken() } /** @@ -597,7 +622,7 @@ class ScalaToplevelMtags( private def emitEnumCases( region: Region, nextIsNewLine: Boolean - ): (Boolean, Boolean) = { + ): Boolean = { def ownerCompanionObject = if (currentOwner.endsWith("#")) s"${currentOwner.stripSuffix("#")}." @@ -607,19 +632,22 @@ class ScalaToplevelMtags( val pos = newPosition val name = scanner.curr.name def emitEnumCaseObject() = { - withOwner(ownerCompanionObject) { - term( - name, - pos, - Kind.METHOD, - SymbolInformation.Property.VAL.value - ) - } + currentOwner = ownerCompanionObject + term( + name, + pos, + Kind.METHOD, + SymbolInformation.Property.VAL.value + ) } + def emitOverridden() = addOverridden( + List(ResolvedOverriddenSymbol(region.owner)) + ) val nextIsNewLine0 = nextIsNL() scanner.curr.token match { case COMMA => emitEnumCaseObject() + emitOverridden() resetRegion(region) val nextIsNewLine1 = nextIsNL() emitEnumCases(region, nextIsNewLine1) @@ -631,12 +659,15 @@ class ScalaToplevelMtags( Kind.CLASS, SymbolInformation.Property.VAL.value ) - (true, false) - case _ => + false + case tok => emitEnumCaseObject() - (false, nextIsNewLine0) + if (tok != EXTENDS) { + emitOverridden() + } + nextIsNewLine0 } - case _ => (false, nextIsNewLine) + case _ => nextIsNewLine } } @@ -708,15 +739,15 @@ class ScalaToplevelMtags( case _ => false }) ) { - if(isNewline){ + if (isNewline) { includedNewline = true ident = 0 - } else if(scanner.curr.token == WHITESPACE) { + } else if (scanner.curr.token == WHITESPACE) { ident += 1 } scanner.nextToken() } - if(includedNewline) Some(ident) else None + if (includedNewline) Some(ident) else None } private def nextIsNL(): Boolean = { @@ -991,3 +1022,8 @@ object ScalaToplevelMtags { } } } + +case class TextDocumentWithOverridden( + textDocument: TextDocument, + overridden: List[(String, List[OverriddenSymbol])] +) extends EnrichedTextDocument diff --git a/mtags/src/main/scala/scala/meta/internal/mtags/UnresolvedOverriddenSymbol.scala b/mtags/src/main/scala/scala/meta/internal/mtags/UnresolvedOverriddenSymbol.scala deleted file mode 100644 index 791bad1e680..00000000000 --- a/mtags/src/main/scala/scala/meta/internal/mtags/UnresolvedOverriddenSymbol.scala +++ /dev/null @@ -1,12 +0,0 @@ -package scala.meta.internal.mtags - -object UnresolvedOverriddenSymbol { - def apply(name: String, pos: Int): String = - s"unresolved::$name::$pos" - - def unapply(unresolved: String): Option[(String, Int)] = - unresolved match { - case s"unresolved::$name::$pos" => pos.toIntOption.map((name, _)) - case _ => None - } -} diff --git a/tests/unit/src/test/scala/tests/ScalaToplevelSuite.scala b/tests/unit/src/test/scala/tests/ScalaToplevelSuite.scala index 86a049abc1c..c8ede54b6cf 100644 --- a/tests/unit/src/test/scala/tests/ScalaToplevelSuite.scala +++ b/tests/unit/src/test/scala/tests/ScalaToplevelSuite.scala @@ -4,6 +4,8 @@ import scala.meta.Dialect import scala.meta.dialects import scala.meta.inputs.Input import scala.meta.internal.mtags.Mtags +import scala.meta.internal.mtags.ResolvedOverriddenSymbol +import scala.meta.internal.mtags.TextDocumentWithOverridden import scala.meta.internal.mtags.UnresolvedOverriddenSymbol import munit.TestOptions @@ -52,7 +54,7 @@ class ScalaToplevelSuite extends BaseSuite { List( "_empty_/A.", "_empty_/A.foo().", "_empty_/A.Z#", "_empty_/B#", "_empty_/B#X#", "_empty_/B#foo().", "_empty_/B#v.", "_empty_/C#", - "_empty_/C#i.", "_empty_/D#", "_empty_/D.Da.", "_empty_/D.Db.", + "_empty_/C#i.", "_empty_/D#", "_empty_/D.Da. -> D", "_empty_/D.Db. -> D", "_empty_/D#getI().", "_empty_/D#i.", ), mode = All, @@ -100,7 +102,7 @@ class ScalaToplevelSuite extends BaseSuite { List( "_empty_/A.", "_empty_/A.foo().", "_empty_/A.Z#", "_empty_/B#", "_empty_/B#X#", "_empty_/B#foo().", "_empty_/C#", "_empty_/D#", - "_empty_/D.Da.", "_empty_/D.Db.", + "_empty_/D.Da. -> _empty_/D#", "_empty_/D.Db. -> _empty_/D#", ), mode = All, ) @@ -472,9 +474,10 @@ class ScalaToplevelSuite extends BaseSuite { | |enum NotPlanets{ case Vase } |""".stripMargin, - List("a/", "a/Planets#", "a/Planets.Earth.", "a/Planets.Mercury.", - "a/Planets#num.", "a/Planets.Venus.", "a/NotPlanets#", - "a/NotPlanets.Vase."), + List("a/", "a/Planets#", "a/Planets.Earth. -> Planets", + "a/Planets.Mercury. -> Planets", "a/Planets#num.", + "a/Planets.Venus. -> Planets", "a/NotPlanets#", + "a/NotPlanets.Vase. -> a/NotPlanets#"), dialect = dialects.Scala3, mode = All, ) @@ -495,9 +498,10 @@ class ScalaToplevelSuite extends BaseSuite { |enum NotPlanets: | case Vase |""".stripMargin, - List("a/", "a/Planets#", "a/Planets.Earth.", "a/Planets.Mercury.", - "a/Planets#num.", "a/Planets.Venus.", "a/NotPlanets#", - "a/NotPlanets.Vase."), + List("a/", "a/Planets#", "a/Planets.Earth. -> Planets", + "a/Planets.Mercury. -> Planets", "a/Planets#num.", + "a/Planets.Venus. -> Planets", "a/NotPlanets#", + "a/NotPlanets.Vase. -> a/NotPlanets#"), dialect = dialects.Scala3, mode = All, ) @@ -514,9 +518,10 @@ class ScalaToplevelSuite extends BaseSuite { |enum NotPlanets: | case Vase |""".stripMargin, - List("a/", "a/Planets#", "a/Planets#mmm().", "a/Planets.Earth#", - "a/Planets.Earth#v.", "a/Planets.Mercury#", "a/Planets#num.", - "a/Planets.Venus#", "a/NotPlanets#", "a/NotPlanets.Vase."), + List("a/", "a/Planets#", "a/Planets#mmm().", "a/Planets.Earth# -> Planets", + "a/Planets.Earth#v.", "a/Planets.Mercury# -> Planets", "a/Planets#num.", + "a/Planets.Venus# -> Planets", "a/NotPlanets#", + "a/NotPlanets.Vase. -> a/NotPlanets#"), dialect = dialects.Scala3, mode = All, ) @@ -598,13 +603,24 @@ class ScalaToplevelSuite extends BaseSuite { mode = All, ) + check( + "overridden", + """|package a + |case class A[T](v: Int)(using Context) extends B[Int](2) with C: + | object O extends H + |class M(ctx: Context) extends W(1)(ctx) + |""".stripMargin, + List("a/", "a/A# -> B, C", "a/A#v.", "a/A#O. -> H", "a/M# -> W"), + dialect = dialects.Scala3, + mode = All, + ) + def check( options: TestOptions, code: String, expected: List[String], mode: Mode = Toplevel, dialect: Dialect = dialects.Scala3, - includeOverridden: Boolean = true )(implicit location: munit.Location): Unit = { test(options) { val input = Input.VirtualFile("Test.scala", code) @@ -612,20 +628,29 @@ class ScalaToplevelSuite extends BaseSuite { mode match { case All | ToplevelWithInner => val includeMembers = mode == All - Mtags - .allToplevels(input, dialect, includeMembers) - .symbols - .map{ si => - if(!includeOverridden || si.overriddenSymbols.isEmpty) si.symbol - else { - val overridden = - si.overriddenSymbols.collect{ - case UnresolvedOverriddenSymbol(name, _) => name - }.mkString(", ") - s"${si.symbol} -> $overridden" + val enrichedDoc = + Mtags.allToplevelsEnriched(input, dialect, includeMembers) + val symbols = + enrichedDoc.textDocument.occurrences.map(_.symbol).toList + enrichedDoc match { + case doc: TextDocumentWithOverridden => + val overriddenMap = doc.overridden.toMap + symbols.map { symbol => + overriddenMap.get(symbol) match { + case None => symbol + case Some(symbols) => + val overridden = + symbols + .map { + case ResolvedOverriddenSymbol(symbol) => symbol + case UnresolvedOverriddenSymbol(name, _) => name + } + .mkString(", ") + s"$symbol -> $overridden" + } } - } - .toList + case _ => symbols + } case Toplevel => Mtags.toplevels(input, dialect) } assertNoDiff(