From d12e21e56f6b9559b0275c0ee0776e300de97ba8 Mon Sep 17 00:00:00 2001 From: Katarzyna Marek Date: Wed, 22 May 2024 11:33:50 +0200 Subject: [PATCH] improvement: use pc for finding references of local symbols and when semanticdb is missing (#5940) * improvement: use pc for finding references of local symbols * delete unused `referencedPackages` * use pc for references as fallback when missing semanticdb * move collecting identifiers to mtags * fixes * small review changes * add back `compileAndLookForNewReferences` * benchmarks * get file content on demand * small fix * scalafix * fixes after rebase * add test for rename with un-compiled build target * post rebase fixes * refactor: drop using buffers in pc * review fixes * filter if empty locations * post rebase fixes --- .../src/main/scala/bench/Inflated.scala | 16 +- .../src/main/scala/bench/MetalsBench.scala | 72 ++- .../meta/internal/metals/AdjustLspData.scala | 12 + .../meta/internal/metals/Compilers.scala | 41 ++ .../internal/metals/IdentifierIndex.scala | 62 +++ .../scala/meta/internal/metals/Indexer.scala | 63 ++- .../metals/InteractiveSemanticdbs.scala | 1 - .../scala/meta/internal/metals/Memory.scala | 2 + .../internal/metals/MetalsLspService.scala | 174 +++---- .../internal/metals/ReferenceProvider.scala | 293 ++++++++---- .../meta/internal/rename/RenameProvider.scala | 54 ++- .../scala/meta/pc/PresentationCompiler.java | 7 + .../java/scala/meta/pc/ReferencesRequest.java | 11 + .../java/scala/meta/pc/ReferencesResult.java | 9 + .../pc/JavaPresentationCompiler.scala | 7 + .../meta/internal/pc/ReferencesRequest.scala | 26 + .../scala/meta/internal/pc/PcCollector.scala | 355 +++----------- .../pc/PcDocumentHighlightProvider.scala | 2 +- .../pc/PcInlineValueProviderImpl.scala | 4 +- .../internal/pc/PcReferencesProvider.scala | 93 ++++ .../meta/internal/pc/PcRenameProvider.scala | 2 +- .../pc/PcSemanticTokensProvider.scala | 2 +- .../meta/internal/pc/PcSymbolSearch.scala | 145 ++++++ .../pc/ScalaPresentationCompiler.scala | 15 + .../internal/pc/WithCompilationUnit.scala | 84 ++++ .../internal/pc/WorkspaceSymbolSearch.scala | 72 +-- .../scala/meta/internal/pc/PcCollector.scala | 450 +++--------------- .../pc/PcDocumentHighlightProvider.scala | 4 +- .../pc/PcInlineValueProviderImpl.scala | 4 +- .../internal/pc/PcReferencesProvider.scala | 64 +++ .../meta/internal/pc/PcRenameProvider.scala | 13 +- .../pc/PcSemanticTokensProvider.scala | 2 +- .../meta/internal/pc/PcSymbolSearch.scala | 284 +++++++++++ .../pc/ScalaPresentationCompiler.scala | 16 +- .../pc/SymbolInformationProvider.scala | 110 +++-- .../internal/pc/WithCompilationUnit.scala | 101 ++++ .../metals/SemanticdbDefinition.scala | 25 +- .../meta/internal/mtags/MtagsIndexer.scala | 2 + .../internal/mtags/ScalaToplevelMtags.scala | 98 ++-- .../src/main/scala/tests/RangeReplace.scala | 19 +- .../tests/feature/PcReferencesLspSuite.scala | 232 +++++++++ .../scala/tests/gradle/GradleLspSuite.scala | 3 +- .../scala/tests/maven/MavenLspSuite.scala | 3 +- .../test/scala/tests/mill/MillLspSuite.scala | 3 +- .../scala/tests/sbt/SbtBloopLspSuite.scala | 3 +- .../test/scala/tests/sbt/SbtServerSuite.scala | 46 ++ .../src/main/scala/tests/TestRanges.scala | 2 +- .../src/main/scala/tests/TestingServer.scala | 63 +-- .../scala/tests/FingerprintsLspSuite.scala | 3 +- .../src/test/scala/tests/RenameLspSuite.scala | 28 ++ 50 files changed, 2109 insertions(+), 1093 deletions(-) create mode 100644 metals/src/main/scala/scala/meta/internal/metals/IdentifierIndex.scala create mode 100644 mtags-interfaces/src/main/java/scala/meta/pc/ReferencesRequest.java create mode 100644 mtags-interfaces/src/main/java/scala/meta/pc/ReferencesResult.java create mode 100644 mtags-shared/src/main/scala/scala/meta/internal/pc/ReferencesRequest.scala create mode 100644 mtags/src/main/scala-2/scala/meta/internal/pc/PcReferencesProvider.scala create mode 100644 mtags/src/main/scala-2/scala/meta/internal/pc/PcSymbolSearch.scala create mode 100644 mtags/src/main/scala-2/scala/meta/internal/pc/WithCompilationUnit.scala create mode 100644 mtags/src/main/scala-3/scala/meta/internal/pc/PcReferencesProvider.scala create mode 100644 mtags/src/main/scala-3/scala/meta/internal/pc/PcSymbolSearch.scala create mode 100644 mtags/src/main/scala-3/scala/meta/internal/pc/WithCompilationUnit.scala create mode 100644 tests/slow/src/test/scala/tests/feature/PcReferencesLspSuite.scala diff --git a/metals-bench/src/main/scala/bench/Inflated.scala b/metals-bench/src/main/scala/bench/Inflated.scala index 1639ce06111..a109269e443 100644 --- a/metals-bench/src/main/scala/bench/Inflated.scala +++ b/metals-bench/src/main/scala/bench/Inflated.scala @@ -7,16 +7,22 @@ import scala.meta.internal.io.FileIO import scala.meta.io.AbsolutePath import scala.meta.io.Classpath -case class Inflated(inputs: List[Input.VirtualFile], linesOfCode: Long) { +case class Inflated( + inputs: List[(Input.VirtualFile, AbsolutePath)], + linesOfCode: Long, +) { def filter(f: Input.VirtualFile => Boolean): Inflated = { - val newInputs = inputs.filter(input => f(input)) + val newInputs = inputs.filter { case (input, _) => f(input) } val newLinesOfCode = newInputs.foldLeft(0) { case (accum, input) => - accum + input.text.linesIterator.length + accum + input._1.text.linesIterator.length } Inflated(newInputs, newLinesOfCode) } def +(other: Inflated): Inflated = Inflated(other.inputs ++ inputs, other.linesOfCode + linesOfCode) + + def foreach(f: Input.VirtualFile => Unit): Unit = + inputs.foreach { case (file, _) => f(file) } } object Inflated { @@ -33,12 +39,12 @@ object Inflated { close = true, ) { root => var lines = 0L - val buf = List.newBuilder[Input.VirtualFile] + val buf = List.newBuilder[(Input.VirtualFile, AbsolutePath)] FileIO.listAllFilesRecursively(root).foreach { file => val path = file.toURI.toString() val text = FileIO.slurp(file, StandardCharsets.UTF_8) lines += text.linesIterator.length - buf += Input.VirtualFile(path, text) + buf += ((Input.VirtualFile(path, text), file)) } val inputs = buf.result() Inflated(inputs, lines) diff --git a/metals-bench/src/main/scala/bench/MetalsBench.scala b/metals-bench/src/main/scala/bench/MetalsBench.scala index 04c6fa6da80..6442d0c9b07 100644 --- a/metals-bench/src/main/scala/bench/MetalsBench.scala +++ b/metals-bench/src/main/scala/bench/MetalsBench.scala @@ -6,6 +6,8 @@ import scala.tools.nsc.interactive.Global import scala.meta.dialects import scala.meta.interactive.InteractiveSemanticdb +import scala.meta.internal.metals.EmptyReportContext +import scala.meta.internal.metals.IdentifierIndex import scala.meta.internal.metals.JdkSources import scala.meta.internal.metals.LoggerReportContext import scala.meta.internal.metals.ReportContext @@ -24,6 +26,7 @@ import scala.meta.internal.tokenizers.LegacyToken import scala.meta.io.AbsolutePath import scala.meta.io.Classpath +import ch.epfl.scala.bsp4j.BuildTargetIdentifier import org.openjdk.jmh.annotations.Benchmark import org.openjdk.jmh.annotations.BenchmarkMode import org.openjdk.jmh.annotations.Mode @@ -78,10 +81,11 @@ class MetalsBench { .flatMap(_.sources.entries) .filter(_.toNIO.getFileName.toString.endsWith(".jar")) ) + @Benchmark @BenchmarkMode(Array(Mode.SingleShotTime)) def mtagsScalaIndex(): Unit = { - scalaDependencySources.inputs.foreach { input => + scalaDependencySources.foreach { input => ScalaMtags.index(input, dialects.Scala213).index() } } @@ -89,7 +93,7 @@ class MetalsBench { @Benchmark @BenchmarkMode(Array(Mode.SingleShotTime)) def toplevelsScalaIndex(): Unit = { - scalaDependencySources.inputs.foreach { input => + scalaDependencySources.inputs.foreach { case (input, _) => implicit val rc: ReportContext = LoggerReportContext new ScalaToplevelMtags( input, @@ -103,7 +107,7 @@ class MetalsBench { @Benchmark @BenchmarkMode(Array(Mode.SingleShotTime)) def typeHierarchyIndex(): Unit = { - scalaDependencySources.inputs.foreach { input => + scalaDependencySources.inputs.foreach { case (input, _) => implicit val rc: ReportContext = LoggerReportContext new ScalaToplevelMtags( input, @@ -117,7 +121,7 @@ class MetalsBench { @Benchmark @BenchmarkMode(Array(Mode.SingleShotTime)) def scalaTokenize(): Unit = { - scalaDependencySources.inputs.foreach { input => + scalaDependencySources.foreach { input => val scanner = new LegacyScanner(input, Trees.defaultTokenizerDialect) var i = 0 scanner.foreach(_ => i += 1) @@ -128,7 +132,7 @@ class MetalsBench { @BenchmarkMode(Array(Mode.SingleShotTime)) def scalacTokenize(): Unit = { val g = global - scalaDependencySources.inputs.foreach { input => + scalaDependencySources.foreach { input => val unit = new g.CompilationUnit( new BatchSourceFile(new VirtualFile(input.path), input.chars) ) @@ -143,7 +147,7 @@ class MetalsBench { @Benchmark @BenchmarkMode(Array(Mode.SingleShotTime)) def scalametaParse(): Unit = { - scalaDependencySources.inputs.foreach { input => + scalaDependencySources.foreach { input => import scala.meta._ Trees.defaultTokenizerDialect(input).parse[Source].get } @@ -155,7 +159,7 @@ class MetalsBench { @BenchmarkMode(Array(Mode.SingleShotTime)) def scalacParse(): Unit = { val g = global - scalaDependencySources.inputs.foreach { input => + scalaDependencySources.foreach { input => val unit = new g.CompilationUnit( new BatchSourceFile(new VirtualFile(input.path), input.chars) ) @@ -173,7 +177,7 @@ class MetalsBench { @Benchmark @BenchmarkMode(Array(Mode.SingleShotTime)) def mtagsJavaParse(): Unit = { - javaDependencySources.inputs.foreach { input => + javaDependencySources.foreach { input => JavaMtags .index(input, includeMembers = true) .index() @@ -183,7 +187,7 @@ class MetalsBench { @Benchmark @BenchmarkMode(Array(Mode.SingleShotTime)) def toplevelJavaMtags(): Unit = { - javaDependencySources.inputs.foreach { input => + javaDependencySources.inputs.foreach { case (input, _) => new JavaToplevelMtags(input, includeInnerClasses = true)( LoggerReportContext ).index() @@ -202,8 +206,54 @@ class MetalsBench { @Benchmark @BenchmarkMode(Array(Mode.SingleShotTime)) def alltoplevelsScalaIndex(): Unit = { - scalaDependencySources.inputs.foreach { input => - Mtags.allToplevels(input, dialects.Scala3) + scalaDependencySources.foreach { input => + Mtags.allToplevels(input, dialects.Scala213) + } + } + + @Benchmark + @BenchmarkMode(Array(Mode.SingleShotTime)) + def alltoplevelsScalaIndexWithCollectIdents(): Unit = { + scalaDependencySources.foreach { input => + new ScalaToplevelMtags( + input, + includeInnerClasses = true, + includeMembers = true, + dialects.Scala213, + collectIdentifiers = true, + )(EmptyReportContext).indexRoot() + } + } + + @Benchmark + @BenchmarkMode(Array(Mode.SingleShotTime)) + def alltoplevelsScalaIndexWithBuildIdentifierIndex(): Unit = { + val buildTargetIdent = List( + new BuildTargetIdentifier("id1"), + new BuildTargetIdentifier("id2"), + new BuildTargetIdentifier("id3"), + ) + var btIndex = 0 + val index = new IdentifierIndex + scalaDependencySources.inputs.foreach { case (input, path) => + val mtags = new ScalaToplevelMtags( + input, + includeInnerClasses = true, + includeMembers = true, + dialects.Scala213, + collectIdentifiers = true, + )(EmptyReportContext) + + mtags.indexRoot() + + val identifiers = mtags.allIdentifiers + if (identifiers.nonEmpty) + index.addIdentifiers( + path, + buildTargetIdent(btIndex), + identifiers, + ) + btIndex = (btIndex + 1) % 3 } } diff --git a/metals/src/main/scala/scala/meta/internal/metals/AdjustLspData.scala b/metals/src/main/scala/scala/meta/internal/metals/AdjustLspData.scala index 5c2494e1aae..ab50702d156 100644 --- a/metals/src/main/scala/scala/meta/internal/metals/AdjustLspData.scala +++ b/metals/src/main/scala/scala/meta/internal/metals/AdjustLspData.scala @@ -3,6 +3,7 @@ package scala.meta.internal.metals import java.{util => ju} import scala.meta.internal.metals.MetalsEnrichments._ +import scala.meta.pc import scala.meta.pc.AutoImportsResult import scala.meta.pc.HoverSignature @@ -50,6 +51,17 @@ trait AdjustLspData { diag } + def adjustLocation(location: Location): Location = + new Location(location.getUri(), adjustRange(location.getRange())) + + def adjustReferencesResult( + referencesResult: pc.ReferencesResult + ): ReferencesResult = + new ReferencesResult( + referencesResult.symbol, + referencesResult.locations().asScala.map(adjustLocation).toList, + ) + def adjustLocations( locations: java.util.List[Location] ): ju.List[Location] diff --git a/metals/src/main/scala/scala/meta/internal/metals/Compilers.scala b/metals/src/main/scala/scala/meta/internal/metals/Compilers.scala index 4018e32c346..071b2170aae 100644 --- a/metals/src/main/scala/scala/meta/internal/metals/Compilers.scala +++ b/metals/src/main/scala/scala/meta/internal/metals/Compilers.scala @@ -15,6 +15,7 @@ import scala.util.control.NonFatal import scala.meta.inputs.Input import scala.meta.inputs.Position +import scala.meta.internal import scala.meta.internal.builds.SbtBuildTool import scala.meta.internal.metals.CompilerOffsetParamsUtils import scala.meta.internal.metals.CompilerRangeParamsUtils @@ -48,6 +49,7 @@ import org.eclipse.lsp4j.InitializeParams import org.eclipse.lsp4j.InlayHint import org.eclipse.lsp4j.InlayHintKind import org.eclipse.lsp4j.InlayHintParams +import org.eclipse.lsp4j.ReferenceParams import org.eclipse.lsp4j.RenameParams import org.eclipse.lsp4j.SelectionRange import org.eclipse.lsp4j.SelectionRangeParams @@ -57,6 +59,7 @@ import org.eclipse.lsp4j.SignatureHelp import org.eclipse.lsp4j.TextDocumentIdentifier import org.eclipse.lsp4j.TextDocumentPositionParams import org.eclipse.lsp4j.TextEdit +import org.eclipse.lsp4j.jsonrpc.messages.{Either => JEither} import org.eclipse.lsp4j.{Position => LspPosition} import org.eclipse.lsp4j.{Range => LspRange} import org.eclipse.lsp4j.{debug => d} @@ -727,6 +730,44 @@ class Compilers( } }.getOrElse(Future.successful(Nil.asJava)) + def references( + params: ReferenceParams, + token: CancelToken, + ): Future[List[ReferencesResult]] = { + withPCAndAdjustLsp(params) { case (pc, pos, adjust) => + val requestParams = new internal.pc.PcReferencesRequest( + CompilerOffsetParamsUtils.fromPos(pos, token), + params.getContext().isIncludeDeclaration(), + JEither.forLeft(pos.start), + ) + pc.references(requestParams) + .asScala + .map(_.asScala.map(adjust.adjustReferencesResult).toList) + } + }.getOrElse(Future.successful(Nil)) + + def references( + searchFile: AbsolutePath, + includeDefinition: Boolean, + symbol: String, + ): Future[List[ReferencesResult]] = + loadCompiler(searchFile) + .map { compiler => + val uri = searchFile.toURI + val (input, _, adjust) = + sourceAdjustments(uri.toString(), compiler.scalaVersion()) + val requestParams = new internal.pc.PcReferencesRequest( + CompilerVirtualFileParams(uri, input.text), + includeDefinition, + JEither.forRight(symbol), + ) + compiler + .references(requestParams) + .asScala + .map(_.asScala.map(adjust.adjustReferencesResult).toList) + } + .getOrElse(Future.successful(Nil)) + def extractMethod( doc: TextDocumentIdentifier, range: LspRange, diff --git a/metals/src/main/scala/scala/meta/internal/metals/IdentifierIndex.scala b/metals/src/main/scala/scala/meta/internal/metals/IdentifierIndex.scala new file mode 100644 index 00000000000..4f8c4c5c5e5 --- /dev/null +++ b/metals/src/main/scala/scala/meta/internal/metals/IdentifierIndex.scala @@ -0,0 +1,62 @@ +package scala.meta.internal.metals + +import java.nio.charset.StandardCharsets +import java.nio.file.Path + +import scala.collection.concurrent.TrieMap +import scala.util.control.NonFatal + +import scala.meta.Dialect +import scala.meta.inputs.Input +import scala.meta.internal.tokenizers.LegacyScanner +import scala.meta.internal.tokenizers.LegacyToken._ +import scala.meta.io.AbsolutePath + +import ch.epfl.scala.bsp4j.BuildTargetIdentifier +import com.google.common.hash.BloomFilter +import com.google.common.hash.Funnels + +class IdentifierIndex { + val index: TrieMap[Path, IdentifierIndex.IndexEntry] = TrieMap.empty + + def addIdentifiers( + file: AbsolutePath, + id: BuildTargetIdentifier, + set: Iterable[String], + ): Unit = { + val bloom = BloomFilter.create( + Funnels.stringFunnel(StandardCharsets.UTF_8), + Integer.valueOf(set.size * 2), + 0.01, + ) + + val entry = IdentifierIndex.IndexEntry(id, bloom) + index(file.toNIO) = entry + set.foreach(bloom.put) + } + + def collectIdentifiers( + text: String, + dialect: Dialect, + ): Iterable[String] = { + val identifiers = Set.newBuilder[String] + + try { + new LegacyScanner(Input.String(text), dialect).foreach { + case ident if ident.token == IDENTIFIER => identifiers += ident.name + case _ => + } + } catch { + case NonFatal(_) => + } + + identifiers.result() + } +} + +object IdentifierIndex { + case class IndexEntry( + id: BuildTargetIdentifier, + bloom: BloomFilter[CharSequence], + ) +} diff --git a/metals/src/main/scala/scala/meta/internal/metals/Indexer.scala b/metals/src/main/scala/scala/meta/internal/metals/Indexer.scala index 2ec86b11aff..fda41a0dc1c 100644 --- a/metals/src/main/scala/scala/meta/internal/metals/Indexer.scala +++ b/metals/src/main/scala/scala/meta/internal/metals/Indexer.scala @@ -174,6 +174,7 @@ final case class Indexer( List( ("definition index", definitionIndex), ("references index", referencesProvider().index), + ("identifier index", referencesProvider().identifierIndex.index), ("workspace symbol index", workspaceSymbols().inWorkspace), ("build targets", buildTargets), ( @@ -544,40 +545,50 @@ final case class Indexer( val input = sourceToIndex0.toInput val symbols = ArrayBuffer.empty[WorkspaceSymbolInformation] val methodSymbols = ArrayBuffer.empty[WorkspaceSymbolInformation] - SemanticdbDefinition.foreach(input, dialect, includeMembers = true) { - case SemanticdbDefinition(info, occ, owner) => - if (info.isExtension) { + val optMtags = SemanticdbDefinition.foreachWithReturnMtags( + input, + dialect, + includeMembers = true, + collectIdentifiers = true, + ) { case SemanticdbDefinition(info, occ, owner) => + if (info.isExtension) { + occ.range.foreach { range => + methodSymbols += WorkspaceSymbolInformation( + info.symbol, + info.kind, + range.toLsp, + ) + } + } else { + if (info.kind.isRelevantKind) { occ.range.foreach { range => - methodSymbols += WorkspaceSymbolInformation( + symbols += WorkspaceSymbolInformation( info.symbol, info.kind, range.toLsp, ) } - } else { - if (info.kind.isRelevantKind) { - occ.range.foreach { range => - symbols += WorkspaceSymbolInformation( - info.symbol, - info.kind, - range.toLsp, - ) - } - } - } - if ( - sourceItem.isDefined && - !info.symbol.isPackage && - (owner.isPackage || source.isAmmoniteScript) - ) { - definitionIndex.addToplevelSymbol( - reluri, - source, - info.symbol, - dialect, - ) } + } + if ( + sourceItem.isDefined && + !info.symbol.isPackage && + (owner.isPackage || source.isAmmoniteScript) + ) { + definitionIndex.addToplevelSymbol( + reluri, + source, + info.symbol, + dialect, + ) + } } + optMtags + .map(_.allIdentifiers) + .filter(_.nonEmpty) + .foreach(identifiers => + referencesProvider().addIdentifiers(source, identifiers) + ) workspaceSymbols().didChange(source, symbols.toSeq, methodSymbols.toSeq) // Since the `symbols` here are toplevel symbols, diff --git a/metals/src/main/scala/scala/meta/internal/metals/InteractiveSemanticdbs.scala b/metals/src/main/scala/scala/meta/internal/metals/InteractiveSemanticdbs.scala index 2395178d7e9..ff9c77e11c8 100644 --- a/metals/src/main/scala/scala/meta/internal/metals/InteractiveSemanticdbs.scala +++ b/metals/src/main/scala/scala/meta/internal/metals/InteractiveSemanticdbs.scala @@ -30,7 +30,6 @@ final class InteractiveSemanticdbs( charset: Charset, tables: Tables, compilers: () => Compilers, - clientConfig: ClientConfiguration, semanticdbIndexer: () => SemanticdbIndexer, javaInteractiveSemanticdb: Option[JavaInteractiveSemanticdb], buffers: Buffers, diff --git a/metals/src/main/scala/scala/meta/internal/metals/Memory.scala b/metals/src/main/scala/scala/meta/internal/metals/Memory.scala index 24492b814eb..54ccfe43002 100644 --- a/metals/src/main/scala/scala/meta/internal/metals/Memory.scala +++ b/metals/src/main/scala/scala/meta/internal/metals/Memory.scala @@ -38,6 +38,8 @@ object Memory { val elements = i.values.foldLeft(0L) { case (n, b: BloomFilter[_]) => n + b.approximateElementCount() + case (n, i: IdentifierIndex.IndexEntry) => + n + i.bloom.approximateElementCount() case (n, c: CompressedPackageIndex) => n + c.bloom.approximateElementCount() case (n, _) => diff --git a/metals/src/main/scala/scala/meta/internal/metals/MetalsLspService.scala b/metals/src/main/scala/scala/meta/internal/metals/MetalsLspService.scala index 9e13818e2cc..f617fc6d534 100644 --- a/metals/src/main/scala/scala/meta/internal/metals/MetalsLspService.scala +++ b/metals/src/main/scala/scala/meta/internal/metals/MetalsLspService.scala @@ -170,9 +170,6 @@ abstract class MetalsLspService( ReportLevel.fromString(MetalsServerConfig.default.loglevel), ) - val folderReportsZippper: FolderReportsZippper = - FolderReportsZippper(doctor, reports) - def javaHome = userConfig.javaHome protected val optJavaHome: Option[AbsolutePath] = JdkSources.defaultJavaHome(javaHome).headOption @@ -360,15 +357,6 @@ abstract class MetalsLspService( ) } - protected val referencesProvider: ReferenceProvider = new ReferenceProvider( - folder, - semanticdbs, - buffers, - definitionProvider, - trees, - buildTargets, - ) - protected val formattingProvider: FormattingProvider = new FormattingProvider( folder, buffers, @@ -388,26 +376,6 @@ abstract class MetalsLspService( semanticdbs, ) - protected val packageProvider: PackageProvider = - new PackageProvider( - buildTargets, - trees, - referencesProvider, - buffers, - definitionProvider, - ) - - protected val newFileProvider: NewFileProvider = new NewFileProvider( - languageClient, - packageProvider, - scalaVersionSelector, - clientConfig.icons, - onCreate = path => { - onCreate(path) - onChange(List(path)) - }, - ) - protected def onCreate(path: AbsolutePath): Unit = { buildTargets.onCreate(path) compilers.didChange(path) @@ -424,7 +392,6 @@ abstract class MetalsLspService( charset, tables, () => compilers, - clientConfig, () => semanticDBIndexer, javaInteractiveSemanticdb, buffers, @@ -487,6 +454,37 @@ abstract class MetalsLspService( ) ) + protected val referencesProvider: ReferenceProvider = new ReferenceProvider( + folder, + semanticdbs, + buffers, + definitionProvider, + trees, + buildTargets, + compilers, + scalaVersionSelector, + ) + + protected val packageProvider: PackageProvider = + new PackageProvider( + buildTargets, + trees, + referencesProvider, + buffers, + definitionProvider, + ) + + protected val newFileProvider: NewFileProvider = new NewFileProvider( + languageClient, + packageProvider, + scalaVersionSelector, + clientConfig.icons, + onCreate = path => { + onCreate(path) + onChange(List(path)) + }, + ) + protected val javaFormattingProvider: JavaFormattingProvider = new JavaFormattingProvider( buffers, @@ -891,10 +889,12 @@ abstract class MetalsLspService( val path = params.getTextDocument.getUri.toAbsolutePath savedFiles.add(path) // read file from disk, we only remove files from buffers on didClose. - buffers.put(path, path.toInput.text) + val text = path.toInput.text + buffers.put(path, text) Future .sequence( List( + referencesProvider.indexIdentifiers(path, text), renameProvider.runSave(), parseTrees(path), onChange(List(path)), @@ -1111,12 +1111,14 @@ abstract class MetalsLspService( override def references( params: ReferenceParams ): CompletableFuture[util.List[Location]] = - CancelTokens { _ => referencesResult(params).flatMap(_.locations).asJava } + CancelTokens.future { _ => + referencesResult(params).map(_.flatMap(_.locations).asJava) + } // Triggers a cascade compilation and tries to find new references to a given symbol. // It's not possible to stream reference results so if we find new symbols we notify the // user to run references again to see updated results. - protected def compileAndLookForNewReferences( + private def compileAndLookForNewReferences( params: ReferenceParams, result: List[ReferencesResult], ): Unit = { @@ -1147,44 +1149,50 @@ abstract class MetalsLspService( newParams match { case None => case Some(p) => - val newResult = referencesProvider.references(p) - val diff = newResult - .flatMap(_.locations) - .length - result.flatMap(_.locations).length - val diffSyms: Set[String] = - newResult.map(_.symbol).toSet -- result.map(_.symbol).toSet - if (diffSyms.nonEmpty && diff > 0) { - import scala.meta.internal.semanticdb.Scala._ - val names = - diffSyms.map(sym => s"'${sym.desc.name.value}'").mkString(" and ") - val message = - s"Found new symbol references for $names, try running again." - scribe.info(message) - statusBar - .addMessage(clientConfig.icons.info + message) + referencesProvider.references(p).foreach { newResult => + val diff = newResult + .flatMap(_.locations) + .length - result.flatMap(_.locations).length + val diffSyms: Set[String] = + newResult.map(_.symbol).toSet -- result.map(_.symbol).toSet + if (diffSyms.nonEmpty && diff > 0) { + import scala.meta.internal.semanticdb.Scala._ + val names = + diffSyms + .map(sym => s"'${sym.desc.name.value}'") + .mkString(" and ") + val message = + s"Found new symbol references for $names, try running again." + scribe.info(message) + statusBar + .addMessage(clientConfig.icons.info + message) + } } } } } - def referencesResult(params: ReferenceParams): List[ReferencesResult] = { + def referencesResult( + params: ReferenceParams + ): Future[List[ReferencesResult]] = { val timer = new Timer(time) - val results: List[ReferencesResult] = referencesProvider.references(params) - if (clientConfig.initialConfig.statistics.isReferences) { - if (results.forall(_.symbol.isEmpty)) { - scribe.info(s"time: found 0 references in $timer") - } else { - scribe.info( - s"time: found ${results.flatMap(_.locations).length} references to symbol '${results - .map(_.symbol) - .mkString("and")}' in $timer" - ) + referencesProvider.references(params).map { results => + if (clientConfig.initialConfig.statistics.isReferences) { + if (results.forall(_.symbol.isEmpty)) { + scribe.info(s"time: found 0 references in $timer") + } else { + scribe.info( + s"time: found ${results.flatMap(_.locations).length} references to symbol '${results + .map(_.symbol) + .mkString("and")}' in $timer" + ) + } } + if (results.nonEmpty) { + compileAndLookForNewReferences(params, results) + } + results } - if (results.nonEmpty) { - compileAndLookForNewReferences(params, results) - } - results } override def semanticTokensFull( @@ -1652,6 +1660,9 @@ abstract class MetalsLspService( () => projectInfo, ) + val folderReportsZippper: FolderReportsZippper = + FolderReportsZippper(doctor, reports) + protected def check(): Unit = { doctor.check(headDoctor) } @@ -1708,22 +1719,23 @@ abstract class MetalsLspService( positionParams.getPosition(), new ReferenceContext(false), ) - val results = referencesResult(refParams) - if (results.flatMap(_.locations).isEmpty) { - // Fallback again to the original behavior that returns - // the definition location itself if no reference locations found, - // for avoiding the confusing messages like "No definition found ..." - definitionResult(positionParams, token) - } else { - Future.successful( - DefinitionResult( - locations = results.flatMap(_.locations).asJava, - symbol = results.head.symbol, - definition = None, - semanticdb = None, - querySymbol = results.head.symbol, + referencesResult(refParams).flatMap { results => + if (results.flatMap(_.locations).isEmpty) { + // Fallback again to the original behavior that returns + // the definition location itself if no reference locations found, + // for avoiding the confusing messages like "No definition found ..." + definitionResult(positionParams, token) + } else { + Future.successful( + DefinitionResult( + locations = results.flatMap(_.locations).asJava, + symbol = results.head.symbol, + definition = None, + semanticdb = None, + querySymbol = results.head.symbol, + ) ) - ) + } } } else { definitionResult(positionParams, token) diff --git a/metals/src/main/scala/scala/meta/internal/metals/ReferenceProvider.scala b/metals/src/main/scala/scala/meta/internal/metals/ReferenceProvider.scala index d081f289c38..7ec855afa2b 100644 --- a/metals/src/main/scala/scala/meta/internal/metals/ReferenceProvider.scala +++ b/metals/src/main/scala/scala/meta/internal/metals/ReferenceProvider.scala @@ -4,6 +4,7 @@ import java.nio.charset.StandardCharsets import java.nio.file.Path import scala.collection.concurrent.TrieMap +import scala.collection.mutable import scala.concurrent.ExecutionContext import scala.concurrent.Future import scala.util.control.NonFatal @@ -40,13 +41,28 @@ final class ReferenceProvider( definition: DefinitionProvider, trees: Trees, buildTargets: BuildTargets, -) extends SemanticdbFeatureProvider { - - case class IndexEntry( - id: BuildTargetIdentifier, - bloom: BloomFilter[CharSequence], - ) - val index: TrieMap[Path, IndexEntry] = TrieMap.empty + compilers: Compilers, + scalaVersionSelector: ScalaVersionSelector, +)(implicit ec: ExecutionContext) + extends SemanticdbFeatureProvider { + val index: TrieMap[Path, IdentifierIndex.IndexEntry] = TrieMap.empty + val identifierIndex: IdentifierIndex = new IdentifierIndex + + def addIdentifiers(file: AbsolutePath, set: Iterable[String]): Unit = + buildTargets + .inverseSources(file) + .map(id => identifierIndex.addIdentifiers(file, id, set)) + + def indexIdentifiers( + path: AbsolutePath, + text: String, + ): Future[Unit] = Future { + buildTargets.inverseSources(path).map { id => + val dialect = scalaVersionSelector.getDialect(path) + val set = identifierIndex.collectIdentifiers(text, dialect) + identifierIndex.addIdentifiers(path, id, set) + } + } override def reset(): Unit = { index.clear() @@ -66,7 +82,7 @@ final class ReferenceProvider( 0.01, ) - val entry = IndexEntry(id, bloom) + val entry = IdentifierIndex.IndexEntry(id, bloom) index(file.toNIO) = entry docs.documents.foreach { d => d.occurrences.foreach { o => @@ -126,7 +142,7 @@ final class ReferenceProvider( params: ReferenceParams, findRealRange: AdjustRange = noAdjustRange, includeSynthetics: Synthetic => Boolean = _ => true, - )(implicit report: ReportContext): List[ReferencesResult] = { + )(implicit report: ReportContext): Future[List[ReferencesResult]] = { val source = params.getTextDocument.getUri.toAbsolutePath semanticdbs().textDocument(source).documentIncludingStale match { case Some(doc) => @@ -135,7 +151,7 @@ final class ReferenceProvider( definition.positionOccurrences(source, params.getPosition, doc) if (posOccurrences.isEmpty) // handling case `import a.{A as @@B}` - occerencesForRenamedImport(source, params, doc) + occurrencesForRenamedImport(source, params, doc) else posOccurrences } if (results.isEmpty) { @@ -143,55 +159,87 @@ final class ReferenceProvider( s"No symbol found at ${params.getPosition()} for $source" ) } - results.map { result => - val occurrence = result.occurrence.get - val distance = result.distance - val alternatives = - referenceAlternatives(occurrence.symbol, source, doc) - val locations = references( - source, - params, - doc, - distance, - occurrence, - alternatives, - params.getContext.isIncludeDeclaration, - findRealRange, - includeSynthetics, - ) - // It's possible to return nothing is we exclude declaration - if (locations.isEmpty && params.getContext().isIncludeDeclaration()) { - val fileInIndex = - if (index.contains(source.toNIO)) - s"Current file ${source} is present" - else s"Missing current file ${source}" - scribe.debug( - s"No references found, index size ${index.size}\n" + fileInIndex - ) - report.unsanitized.create( - Report( - "empty-references", - index - .map { case (path, entry) => - s"$path -> ${entry.bloom.approximateElementCount()}" - } - .mkString("\n"), - s"Could not find any locations for ${result.occurrence}, printing index state", - Some(source.toString()), - Some(source.toString() + ":" + result.occurrence.getOrElse("")), - ) - ) + val semanticdbResult = Future.sequence { + results.map { result => + val occurrence = result.occurrence.get + val distance = result.distance + val alternatives = + referenceAlternatives(occurrence.symbol, source, doc) + references( + source, + params, + doc, + distance, + occurrence, + alternatives, + params.getContext.isIncludeDeclaration, + findRealRange, + includeSynthetics, + ).map { locations => + // It's possible to return nothing is we exclude declaration + if ( + locations.isEmpty && params.getContext().isIncludeDeclaration() + ) { + val fileInIndex = + if (index.contains(source.toNIO)) + s"Current file ${source} is present" + else s"Missing current file ${source}" + scribe.debug( + s"No references found, index size ${index.size}\n" + fileInIndex + ) + report.unsanitized.create( + Report( + "empty-references", + index + .map { case (path, entry) => + s"$path -> ${entry.bloom.approximateElementCount()}" + } + .mkString("\n"), + s"Could not find any locations for ${result.occurrence}, printing index state", + Some(source.toString()), + Some( + source.toString() + ":" + result.occurrence.getOrElse("") + ), + ) + ) + } + ReferencesResult(occurrence.symbol, locations) + } } - ReferencesResult(occurrence.symbol, locations) } + val pcResult = + pcReferences( + source, + results.flatMap(_.occurrence).map(_.symbol), + params.getContext().isIncludeDeclaration(), + path => !index.contains(path.toNIO), + ) + + Future + .sequence(List(semanticdbResult, pcResult)) + .map( + _.flatten + .groupBy(_.symbol) + .collect { case (symbol, refs) => + ReferencesResult(symbol, refs.flatMap(_.locations)) + } + .toList + ) case None => - Nil + scribe.debug(s"No semanticdb for $source") + pcReferences(source, params).map( + _.groupBy(_.symbol) + .collect { case (symbol, refs) => + ReferencesResult(symbol, refs.flatMap(_.locations)) + } + .toList + ) } } - // for `import package.{AA as B@@B}` we look for occurences at `import package.{@@AA as BB}`, - // since rename is not a position occurence in semanticDB - private def occerencesForRenamedImport( + // for `import package.{AA as B@@B}` we look for occurrences at `import package.{@@AA as BB}`, + // since rename is not a position occurrence in semanticDB + private def occurrencesForRenamedImport( source: AbsolutePath, params: ReferenceParams, document: TextDocument, @@ -279,6 +327,77 @@ final class ReferenceProvider( } } + private def pcReferences( + path: AbsolutePath, + params: ReferenceParams, + ): Future[List[ReferencesResult]] = { + compilers.references(params, EmptyCancelToken).flatMap { foundRefs => + pcReferences( + path, + foundRefs.map(_.symbol), + includeDeclaration = params.getContext().isIncludeDeclaration(), + filterTargetFiles = _ != path, + ).map(_ ++ foundRefs) + } + } + + private def pcReferences( + path: AbsolutePath, + symbols: List[String], + includeDeclaration: Boolean, + filterTargetFiles: AbsolutePath => Boolean, + ): Future[List[ReferencesResult]] = { + val visited = mutable.Set[AbsolutePath]() + val results = for { + buildTarget <- buildTargets.inverseSources(path).toList + _ = visited.clear() + symbol <- symbols + name = nameFromSymbol(symbol) + searchFile <- pathsForName(buildTarget, name) + if (filterTargetFiles(searchFile) && !visited(searchFile)) + } yield { + visited += searchFile + compilers.references( + searchFile, + includeDeclaration, + symbol, + ) + } + Future.sequence(results).map(_.flatten) + } + + private def nameFromSymbol( + semanticDBSymbol: String + ): String = { + val desc = semanticDBSymbol.desc + val actualSym = + if ( + desc.isMethod && (desc.name.value == "apply" || desc.name.value == "unapply") + ) { + val owner = semanticDBSymbol.owner + if (owner != s.Scala.Symbols.None) owner.desc + else desc + } else desc + actualSym.name.value + } + + private def pathsForName( + buildTarget: BuildTargetIdentifier, + name: String, + ): Iterator[AbsolutePath] = { + val allowedBuildTargets = buildTargets.allInverseDependencies(buildTarget) + val visited = scala.collection.mutable.Set.empty[AbsolutePath] + for { + (path, entry) <- identifierIndex.index.iterator + if allowedBuildTargets.contains(entry.id) && + entry.bloom.mightContain(name) + sourcePath = AbsolutePath(path) + if !visited(sourcePath) + _ = visited.add(sourcePath) + if sourcePath.exists + } yield sourcePath + } + /** * Return all paths to files which contain at least one symbol from isSymbol set. */ @@ -390,40 +509,42 @@ final class ReferenceProvider( isIncludeDeclaration: Boolean, findRealRange: AdjustRange, includeSynthetics: Synthetic => Boolean, - ): Seq[Location] = { + ): Future[Seq[Location]] = { val isSymbol = alternatives + occ.symbol val isLocal = occ.symbol.isLocal + if (isLocal) + compilers + .references(params, EmptyCancelToken) + .map(_.flatMap(_.locations)) + else { + /* search local in the following cases: + * - it's a dependency source. + * We can't search references inside dependencies so at least show them in a source file. + * - it's a standalone file that doesn't belong to any build target + */ + val searchLocal = + source.isDependencySource(workspace) || + buildTargets.inverseSources(source).isEmpty - /* search local in the following cases: - * - it's local symbol - * - it's a dependency source. - * We can't search references inside dependencies so at least show them in a source file. - * - it's a standalone file that doesn't belong to any build target - */ - val searchLocal = - isLocal || source.isDependencySource(workspace) || - buildTargets.inverseSources(source).isEmpty - val local = - if (searchLocal) - referenceLocations( - snapshot, - isSymbol, - distance, - params.getTextDocument.getUri, - isIncludeDeclaration, - findRealRange, - includeSynthetics, - source.isJava, - ) - else Seq.empty - - val workspaceRefs = - if (!isLocal) { - val sourceContainsDefinition = - occ.role.isDefinition || snapshot.symbols.exists( - _.symbol == occ.symbol + val local = + if (searchLocal) + referenceLocations( + snapshot, + isSymbol, + distance, + params.getTextDocument.getUri, + isIncludeDeclaration, + findRealRange, + includeSynthetics, + source.isJava, ) + else Seq.empty + val sourceContainsDefinition = + occ.role.isDefinition || snapshot.symbols.exists( + _.symbol == occ.symbol + ) + val workspaceRefs = workspaceReferences( source, isSymbol, @@ -432,11 +553,8 @@ final class ReferenceProvider( includeSynthetics, sourceContainsDefinition, ) - - } else - Seq.empty - - workspaceRefs ++ local + Future.successful(local ++ workspaceRefs) + } } private def referenceLocations( @@ -501,6 +619,7 @@ final class ReferenceProvider( private val noAdjustRange: AdjustRange = (range: s.Range, _: String, _: String) => Some(range) type AdjustRange = (s.Range, String, String) => Option[s.Range] + } class SymbolAlternatives(symbol: String, name: String) { diff --git a/metals/src/main/scala/scala/meta/internal/rename/RenameProvider.scala b/metals/src/main/scala/scala/meta/internal/rename/RenameProvider.scala index 6e7e120869e..be7557106ac 100644 --- a/metals/src/main/scala/scala/meta/internal/rename/RenameProvider.scala +++ b/metals/src/main/scala/scala/meta/internal/rename/RenameProvider.scala @@ -129,9 +129,9 @@ final class RenameProvider( ) .recoverWith { case _ => compilations.compilationFinished(source).flatMap { _ => - val defininionFuture = definitionProvider + val definitionFuture = definitionProvider .definition(source, params, token) - defininionFuture + definitionFuture .flatMap { definition => val textParams = new TextDocumentPositionParams( params.getTextDocument(), @@ -218,7 +218,7 @@ final class RenameProvider( findRealRange = findRealRange(newName), includeSynthetic, ) - .flatMap(_.locations) + .map(_.flatMap(_.locations)) definitionLocation = { if (parentSymbols.isEmpty) definition.locations.asScala @@ -239,9 +239,13 @@ final class RenameProvider( ), newName, ) - } yield implReferences.map(implLocs => - currentReferences ++ implLocs ++ companionRefs ++ definitionLocation - ) + } yield Future + .sequence( + List(implReferences, currentReferences, companionRefs) + ) + .map( + _.flatten ++ definitionLocation + ) Future .sequence(allReferences) .map(locs => @@ -399,7 +403,7 @@ final class RenameProvider( sym: String, source: AbsolutePath, newName: String, - ): Seq[Location] = { + ): Future[Seq[Location]] = { val results = for { companionSymbol <- companion(sym).toIterable loc <- @@ -408,15 +412,15 @@ final class RenameProvider( .asScala // no companion objects in Java files if loc.getUri().isScalaFilename - companionLocs <- - referenceProvider - .references( - toReferenceParams(loc, includeDeclaration = false), - findRealRange = findRealRange(newName), - ) - .flatMap(_.locations) :+ loc - } yield companionLocs - results.toList + } yield { + referenceProvider + .references( + toReferenceParams(loc, includeDeclaration = false), + findRealRange = findRealRange(newName), + ) + .map(_.flatMap(_.locations :+ loc)) + } + Future.sequence(results).map(_.flatten.toSeq) } private def companion(sym: String) = { @@ -459,20 +463,22 @@ final class RenameProvider( if (shouldCheckImplementation) { for { implLocs <- implementationProvider.implementations(textParams) - } yield { - for { - implLoc <- implLocs - locParams = toReferenceParams(implLoc, includeDeclaration = true) - loc <- + result <- { + val result = for { + implLoc <- implLocs + locParams = toReferenceParams(implLoc, includeDeclaration = true) + } yield { referenceProvider .references( locParams, findRealRange = findRealRange(newName), includeSynthetic, ) - .flatMap(_.locations) - } yield loc - } + .map(_.flatMap(_.locations)) + } + Future.sequence(result) + } + } yield result.flatten } else { Future.successful(Nil) } diff --git a/mtags-interfaces/src/main/java/scala/meta/pc/PresentationCompiler.java b/mtags-interfaces/src/main/java/scala/meta/pc/PresentationCompiler.java index 4b7eebdbf03..2dfbb93668a 100644 --- a/mtags-interfaces/src/main/java/scala/meta/pc/PresentationCompiler.java +++ b/mtags-interfaces/src/main/java/scala/meta/pc/PresentationCompiler.java @@ -112,6 +112,13 @@ public CompletableFuture> semanticTokens(VirtualFileParams params) { */ public abstract CompletableFuture> documentHighlight(OffsetParams params); + /** + * Returns the references of the symbol under the current position in the target files. + */ + public CompletableFuture> references(ReferencesRequest params) { + return CompletableFuture.completedFuture(Collections.emptyList()); + } + /** * Return decoded and pretty printed TASTy content for .scala or .tasty file. * diff --git a/mtags-interfaces/src/main/java/scala/meta/pc/ReferencesRequest.java b/mtags-interfaces/src/main/java/scala/meta/pc/ReferencesRequest.java new file mode 100644 index 00000000000..f173b1eecb7 --- /dev/null +++ b/mtags-interfaces/src/main/java/scala/meta/pc/ReferencesRequest.java @@ -0,0 +1,11 @@ +package scala.meta.pc; + +import java.util.List; +import java.net.URI; +import org.eclipse.lsp4j.jsonrpc.messages.Either; + +public interface ReferencesRequest { + VirtualFileParams file(); + boolean includeDefinition(); + Either offsetOrSymbol(); +} diff --git a/mtags-interfaces/src/main/java/scala/meta/pc/ReferencesResult.java b/mtags-interfaces/src/main/java/scala/meta/pc/ReferencesResult.java new file mode 100644 index 00000000000..cfab7391978 --- /dev/null +++ b/mtags-interfaces/src/main/java/scala/meta/pc/ReferencesResult.java @@ -0,0 +1,9 @@ +package scala.meta.pc; + +import java.util.List; +import org.eclipse.lsp4j.Location; + +public interface ReferencesResult { + String symbol(); + List locations(); +} diff --git a/mtags-java/src/main/scala/scala/meta/internal/pc/JavaPresentationCompiler.scala b/mtags-java/src/main/scala/scala/meta/internal/pc/JavaPresentationCompiler.scala index 17239c2a7cc..6b8c67d6665 100644 --- a/mtags-java/src/main/scala/scala/meta/internal/pc/JavaPresentationCompiler.scala +++ b/mtags-java/src/main/scala/scala/meta/internal/pc/JavaPresentationCompiler.scala @@ -21,6 +21,8 @@ import scala.meta.pc.OffsetParams import scala.meta.pc.PresentationCompiler import scala.meta.pc.PresentationCompilerConfig import scala.meta.pc.RangeParams +import scala.meta.pc.ReferencesRequest +import scala.meta.pc.ReferencesResult import scala.meta.pc.SymbolSearch import scala.meta.pc.VirtualFileParams @@ -99,6 +101,11 @@ case class JavaPresentationCompiler( ): CompletableFuture[util.List[DocumentHighlight]] = CompletableFuture.completedFuture(Nil.asJava) + override def references( + params: ReferencesRequest + ): CompletableFuture[util.List[ReferencesResult]] = + CompletableFuture.completedFuture(Nil.asJava) + override def getTasty( targetUri: URI, isHttpEnabled: Boolean diff --git a/mtags-shared/src/main/scala/scala/meta/internal/pc/ReferencesRequest.scala b/mtags-shared/src/main/scala/scala/meta/internal/pc/ReferencesRequest.scala new file mode 100644 index 00000000000..d721dd21a21 --- /dev/null +++ b/mtags-shared/src/main/scala/scala/meta/internal/pc/ReferencesRequest.scala @@ -0,0 +1,26 @@ +package scala.meta.internal.pc + +import java.{util => ju} + +import scala.meta.pc.ReferencesRequest +import scala.meta.pc.ReferencesResult +import scala.meta.pc.VirtualFileParams + +import org.eclipse.lsp4j.Location +import org.eclipse.lsp4j.jsonrpc.messages.{Either => JEither} + +case class PcReferencesRequest( + file: VirtualFileParams, + includeDefinition: Boolean, + offsetOrSymbol: JEither[Integer, String] +) extends ReferencesRequest + +case class PcReferencesResult( + symbol: String, + locations: ju.List[Location] +) extends ReferencesResult + +object PcReferencesResult { + def empty: ReferencesResult = + PcReferencesResult("", ju.Collections.emptyList()) +} diff --git a/mtags/src/main/scala-2/scala/meta/internal/pc/PcCollector.scala b/mtags/src/main/scala-2/scala/meta/internal/pc/PcCollector.scala index 0a8bcff3deb..b7984216d97 100644 --- a/mtags/src/main/scala-2/scala/meta/internal/pc/PcCollector.scala +++ b/mtags/src/main/scala-2/scala/meta/internal/pc/PcCollector.scala @@ -5,291 +5,78 @@ import scala.reflect.internal.util.RangePosition import scala.meta.pc.OffsetParams import scala.meta.pc.VirtualFileParams -abstract class PcCollector[T]( - val compiler: MetalsGlobal, - params: VirtualFileParams -) { +trait PcCollector[T] { self: WithCompilationUnit => import compiler._ def collect( parent: Option[Tree] )(tree: Tree, pos: Position, sym: Option[Symbol]): T - val unit: RichCompilationUnit = addCompilationUnit( - code = params.text(), - filename = params.uri().toString(), - cursor = None - ) - val offset: Int = params match { - case p: OffsetParams => p.offset() - case _: VirtualFileParams => 0 - } - val pos: Position = unit.position(offset) - lazy val text = unit.source.content - private val caseClassSynthetics: Set[Name] = Set(nme.apply, nme.copy) - typeCheck(unit) - lazy val typedTree: Tree = locateTree(pos) match { - // Check actual object if apply is synthetic - case sel @ Select(qual, name) if name == nme.apply && qual.pos == sel.pos => - qual - case Import(expr, _) if expr.pos.includes(pos) => - // imports seem to be marked as transparent - locateTree(pos, expr, acceptTransparent = true) - case t => t - } - - /** - * Find all symbols that should be shown together. - * For example class we want to also show companion object. - * - * @param sym symbol to find the alternative candidates for - * @return set of possible symbols - */ - def symbolAlternatives(sym: Symbol): Set[Symbol] = { - val all = - if (sym.isClass) { - if (sym.owner.isMethod) Set(sym) ++ sym.localCompanion(pos) - else Set(sym, sym.companionModule, sym.companion.moduleClass) - } else if (sym.isModuleOrModuleClass) { - if (sym.owner.isMethod) Set(sym) ++ sym.localCompanion(pos) - else Set(sym, sym.companionClass, sym.moduleClass) - } else if (sym.isTerm && (sym.owner.isClass || sym.owner.isConstructor)) { - val info = - if (sym.owner.isClass) sym.owner.info - else sym.owner.owner.info - Set( - sym, - info.member(sym.getterName), - info.member(sym.setterName), - info.member(sym.localName) - ) ++ constructorParam(sym) ++ sym.allOverriddenSymbols.toSet - } else Set(sym) - all.filter(s => s != NoSymbol && !s.isError) - } - - private def constructorParam( - symbol: Symbol - ): Set[Symbol] = { - if (symbol.owner.isClass) { - val info = symbol.owner.info.member(nme.CONSTRUCTOR).info - info.paramss.flatten.find(_.name == symbol.name).toSet - } else Set.empty - } - private lazy val namedArgCache = { - val parsedTree = parseTree(unit.source) - parsedTree.collect { case arg @ AssignOrNamedArg(_, rhs) => - rhs.pos -> arg - }.toMap - } - def fallbackSymbol(name: Name, pos: Position): Option[Symbol] = { - val context = doLocateImportContext(pos) - context.lookupSymbol(name, sym => sym.isType) match { - case LookupSucceeded(_, symbol) => - Some(symbol) - case _ => None - } - } - - // First identify the symbol we are at, comments identify @@ as current cursor position - lazy val soughtSymbols: Option[(Set[Symbol], Position)] = typedTree match { - /* simple identifier: - * val a = val@@ue + value + def resultWithSought(sought: Set[Symbol]): List[T] = { + val owners = sought + .map(_.owner) + .flatMap(o => symbolAlternatives(o)) + .filter(_ != NoSymbol) + val soughtNames: Set[Name] = sought.map(_.name) + /* + * For comprehensions have two owners, one for the enumerators and one for + * yield. This is a heuristic to find that out. */ - case (id: Ident) => - // might happen in type trees - // also this doesn't seem to be picked up by semanticdb - if (id.symbol == NoSymbol) - fallbackSymbol(id.name, pos).map(sym => - (symbolAlternatives(sym), id.pos) + def isForComprehensionOwner(named: NameTree) = { + if (named.symbol.pos.isDefined) { + def alternativeSymbol = sought.exists(symbol => + symbol.name == named.name && + symbol.pos.isDefined && + symbol.pos.start == named.symbol.pos.start ) - else { - Some(symbolAlternatives(id.symbol), id.pos) - } - /* Anonynous function parameters such as: - * List(1).map{ <>: Int => abc} - * In this case, parameter has incorrect namePosition, so we need to handle it separately - */ - case (vd: ValDef) if isAnonFunctionParam(vd) => - val namePos = vd.pos.withEnd(vd.pos.start + vd.name.length) - if (namePos.includes(pos)) Some(symbolAlternatives(vd.symbol), namePos) - else None - - /* all definitions: - * def fo@@o = ??? - * class Fo@@o = ??? - * etc. - */ - case (df: DefTree) if df.namePosition.includes(pos) => - Some(symbolAlternatives(df.symbol), df.namePosition) - /* Import selectors: - * import scala.util.Tr@@y - */ - case (imp: Import) if imp.pos.includes(pos) => - imp.selector(pos).map(sym => (symbolAlternatives(sym), sym.pos)) - /* simple selector: - * object.val@@ue - */ - case (sel: NameTree) if sel.namePosition.includes(pos) => - Some(symbolAlternatives(sel.symbol), sel.namePosition) - - // needed for classOf[AB@@C]` - case lit @ Literal(Constant(TypeRef(_, sym, _))) if lit.pos.includes(pos) => - val posStart = text.indexOfSlice(sym.decodedName, lit.pos.start) - if (posStart == -1) None - else - Some( - ( - symbolAlternatives(sym), - new RangePosition( - lit.pos.source, - posStart, - lit.pos.point, - posStart + sym.decodedName.length - ) + def sameOwner = { + val owner = named.symbol.owner + owner.isAnonymousFunction && owners.exists(o => + pos.isDefined && o.pos.isDefined && o.pos.point == owner.pos.point ) - ) - /* named argument, which is a bit complex: - * foo(nam@@e = "123") - */ - case _ => - val apply = typedTree match { - case (apply: Apply) => Some(apply) - /** - * For methods with multiple parameter lists and default args, the tree looks like this: - * Block(List(val x&1, val x&2, ...), Apply(<>, List(x&1, x&2, ...))) - */ - case _ => - typedTree.children.collectFirst { - case Block(_, apply: Apply) if apply.pos.includes(pos) => - apply - } - } - apply - .collect { - case apply if apply.symbol != null => - collectArguments(apply) - .flatMap(arg => namedArgCache.find(_._1.includes(arg.pos))) - .collectFirst { - case (_, AssignOrNamedArg(id: Ident, _)) - if id.pos.includes(pos) => - apply.symbol.paramss.flatten.find(_.name == id.name).map { - s => - // if it's a case class we need to look for parameters also - if ( - caseClassSynthetics(s.owner.name) && s.owner.isSynthetic - ) { - val applyOwner = s.owner.owner - val constructorOwner = - if (applyOwner.isCaseClass) applyOwner - else - applyOwner.companion match { - case NoSymbol => - applyOwner - .localCompanion(pos) - .getOrElse(NoSymbol) - case comp => comp - } - val info = constructorOwner.info - val constructorParams = info.members - .filter(_.isConstructor) - .flatMap(_.paramss) - .flatten - .filter(_.name == id.name) - .toSet - ( - (constructorParams ++ Set( - s, - info.member(s.getterName), - info.member(s.setterName), - info.member(s.localName) - )).filter(_ != NoSymbol), - id.pos - ) - } else (Set(s), id.pos) - } - } - .flatten } - .getOrElse(None) - } - - def result(): List[T] = { - params match { - case _: OffsetParams => resultWithSought() - case _ => resultAllOccurences().toList + alternativeSymbol && sameOwner + } else false } - } - - def resultWithSought(): List[T] = { - soughtSymbols match { - case Some((sought, _)) => - val owners = sought - .map(_.owner) - .flatMap(o => symbolAlternatives(o)) - .filter(_ != NoSymbol) - val soughtNames: Set[Name] = sought.map(_.name) - /* - * For comprehensions have two owners, one for the enumerators and one for - * yield. This is a heuristic to find that out. - */ - def isForComprehensionOwner(named: NameTree) = { - if (named.symbol.pos.isDefined) { - def alternativeSymbol = sought.exists(symbol => - symbol.name == named.name && - symbol.pos.isDefined && - symbol.pos.start == named.symbol.pos.start - ) - def sameOwner = { - val owner = named.symbol.owner - owner.isAnonymousFunction && owners.exists(o => - pos.isDefined && o.pos.isDefined && o.pos.point == owner.pos.point - ) - } - alternativeSymbol && sameOwner - } else false - } - - def soughtOrOverride(sym: Symbol) = - sought(sym) || sym.allOverriddenSymbols.exists(sought(_)) - def soughtTreeFilter(tree: Tree): Boolean = - tree match { - case ident: Ident - if (soughtOrOverride(ident.symbol) || - isForComprehensionOwner(ident)) => - true - case tpe: TypeTree => - sought(tpe.original.symbol) - case sel: Select => - soughtOrOverride(sel.symbol) - case df: MemberDef => - (soughtOrOverride(df.symbol) || - isForComprehensionOwner(df)) - case appl: Apply => - appl.symbol != null && - (owners(appl.symbol) || - symbolAlternatives(appl.symbol.owner).exists(owners(_))) - case imp: Import => - owners(imp.expr.symbol) && imp.selectors - .exists(sel => soughtNames(sel.name)) - case bind: Bind => - (soughtOrOverride(bind.symbol)) || - isForComprehensionOwner(bind) - case _ => false - } - - // 1. In most cases, we try to compare by Symbol#==. - // 2. If there is NoSymbol at a node, we check if the identifier there has the same decoded name. - // It it does, we look up the symbol at this position using `fallbackSymbol` or `members`, - // and again check if this time we got symbol equal by Symbol#==. - def soughtFilter(f: Symbol => Boolean): Boolean = { - sought.exists(f) - } + def soughtOrOverride(sym: Symbol) = + sought(sym) || sym.allOverriddenSymbols.exists(sought(_)) - traverseSought(soughtTreeFilter, soughtFilter).toList + def soughtTreeFilter(tree: Tree): Boolean = + tree match { + case ident: Ident + if (soughtOrOverride(ident.symbol) || + isForComprehensionOwner(ident)) => + true + case tpe: TypeTree => + sought(tpe.original.symbol) + case sel: Select => + soughtOrOverride(sel.symbol) + case df: MemberDef => + (soughtOrOverride(df.symbol) || + isForComprehensionOwner(df)) + case appl: Apply => + appl.symbol != null && + (owners(appl.symbol) || + symbolAlternatives(appl.symbol.owner).exists(owners(_))) + case imp: Import => + owners(imp.expr.symbol) && imp.selectors + .exists(sel => soughtNames(sel.name)) + case bind: Bind => + (soughtOrOverride(bind.symbol)) || + isForComprehensionOwner(bind) + case _ => false + } - case None => Nil + // 1. In most cases, we try to compare by Symbol#==. + // 2. If there is NoSymbol at a node, we check if the identifier there has the same decoded name. + // It it does, we look up the symbol at this position using `fallbackSymbol` or `members`, + // and again check if this time we got symbol equal by Symbol#==. + def soughtFilter(f: Symbol => Boolean): Boolean = { + sought.exists(f) } + + traverseSought(soughtTreeFilter, soughtFilter).toList } def resultAllOccurences(): Set[T] = { @@ -518,8 +305,7 @@ abstract class PcCollector[T]( tree.children.foldLeft(acc)(traverse(_, _)) } } - val all = traverseWithParent(None)(Set.empty[T], unit.lastBody) - all + traverseWithParent(None)(Set.empty[T], unit.lastBody) } private def annotationChildren(mdef: MemberDef): List[Tree] = { @@ -540,14 +326,23 @@ abstract class PcCollector[T]( } } - private def collectArguments(apply: Apply): List[Tree] = { - apply.fun match { - case appl: Apply => collectArguments(appl) ++ apply.args - case _ => apply.args - } - } +} - private def isAnonFunctionParam(vd: ValDef): Boolean = - vd.symbol != null && vd.symbol.owner.isAnonymousFunction && vd.rhs.isEmpty +abstract class SimpleCollector[T]( + compiler: MetalsGlobal, + params: VirtualFileParams +) extends WithCompilationUnit(compiler, params) + with PcCollector[T] { + def result(): List[T] = resultAllOccurences().toList +} +abstract class WithSymbolSearchCollector[T]( + compiler: MetalsGlobal, + params: OffsetParams +) extends WithCompilationUnit(compiler, params) + with PcCollector[T] + with PcSymbolSearch { + def result(): List[T] = soughtSymbols + .map { case (sought, _) => resultWithSought(sought) } + .getOrElse(Nil) } diff --git a/mtags/src/main/scala-2/scala/meta/internal/pc/PcDocumentHighlightProvider.scala b/mtags/src/main/scala-2/scala/meta/internal/pc/PcDocumentHighlightProvider.scala index ad23b9c1dc7..9575fa2eb19 100644 --- a/mtags/src/main/scala-2/scala/meta/internal/pc/PcDocumentHighlightProvider.scala +++ b/mtags/src/main/scala-2/scala/meta/internal/pc/PcDocumentHighlightProvider.scala @@ -9,7 +9,7 @@ import org.eclipse.lsp4j.DocumentHighlightKind final class PcDocumentHighlightProvider( override val compiler: MetalsGlobal, params: OffsetParams -) extends PcCollector[DocumentHighlight](compiler, params) { +) extends WithSymbolSearchCollector[DocumentHighlight](compiler, params) { import compiler._ def collect( diff --git a/mtags/src/main/scala-2/scala/meta/internal/pc/PcInlineValueProviderImpl.scala b/mtags/src/main/scala-2/scala/meta/internal/pc/PcInlineValueProviderImpl.scala index f29a8397ae6..cc3fce542c0 100644 --- a/mtags/src/main/scala-2/scala/meta/internal/pc/PcInlineValueProviderImpl.scala +++ b/mtags/src/main/scala-2/scala/meta/internal/pc/PcInlineValueProviderImpl.scala @@ -15,8 +15,8 @@ final class PcInlineValueProviderImpl( ) extends InlineValueProvider { import cp._ - val pcCollector: PcCollector[Occurence] = - new PcCollector[Occurence](cp, params) { + val pcCollector: WithSymbolSearchCollector[Occurence] = + new WithSymbolSearchCollector[Occurence](cp, params) { def collect( parent: Option[compiler.Tree] )( diff --git a/mtags/src/main/scala-2/scala/meta/internal/pc/PcReferencesProvider.scala b/mtags/src/main/scala-2/scala/meta/internal/pc/PcReferencesProvider.scala new file mode 100644 index 00000000000..8c9e53aa49d --- /dev/null +++ b/mtags/src/main/scala-2/scala/meta/internal/pc/PcReferencesProvider.scala @@ -0,0 +1,93 @@ +package scala.meta.internal.pc + +import scala.meta.internal.jdk.CollectionConverters._ +import scala.meta.internal.metals.CompilerOffsetParams +import scala.meta.internal.mtags.MtagsEnrichments._ +import scala.meta.pc.OffsetParams +import scala.meta.pc.ReferencesRequest +import scala.meta.pc.VirtualFileParams + +import org.eclipse.{lsp4j => l} + +trait PcReferencesProvider { + _: WithCompilationUnit with PcCollector[(String, Option[l.Range])] => + import compiler._ + protected def includeDefinition: Boolean + protected def result(): List[(String, Option[l.Range])] + + def collect( + parent: Option[Tree] + )( + tree: Tree, + toAdjust: Position, + sym: Option[Symbol] + ): (String, Option[l.Range]) = { + val (pos, _) = toAdjust.adjust(text) + tree match { + case t: DefTree if !includeDefinition => + (compiler.semanticdbSymbol(t.symbol), None) + case t => + (compiler.semanticdbSymbol(t.symbol), Some(pos.toLsp)) + } + } + + def references(): List[PcReferencesResult] = + result() + .groupBy(_._1) + .map { case (symbol, locs) => + PcReferencesResult( + symbol, + locs.flatMap { case (_, optRange) => + optRange.map(new l.Location(params.uri().toString(), _)) + }.asJava + ) + } + .toList +} + +class LocalPcReferencesProvider( + override val compiler: MetalsGlobal, + params: OffsetParams, + override val includeDefinition: Boolean +) extends WithSymbolSearchCollector[(String, Option[l.Range])](compiler, params) + with PcReferencesProvider + +class BySymbolPCReferencesProvider( + override val compiler: MetalsGlobal, + params: VirtualFileParams, + override val includeDefinition: Boolean, + semanticDbSymbol: String +) extends WithCompilationUnit(compiler, params) + with PcCollector[(String, Option[l.Range])] + with PcReferencesProvider { + def result(): List[(String, Option[l.Range])] = + compiler + .compilerSymbol(semanticDbSymbol) + .map(sought => resultWithSought(symbolAlternatives(sought))) + .getOrElse(Nil) +} + +object PcReferencesProvider { + def apply( + compiler: MetalsGlobal, + params: ReferencesRequest + ): PcReferencesProvider = + if (params.offsetOrSymbol().isLeft()) { + val offsetParams = CompilerOffsetParams( + params.file().uri(), + params.file().text(), + params.offsetOrSymbol().getLeft() + ) + new LocalPcReferencesProvider( + compiler, + offsetParams, + params.includeDefinition() + ) + } else + new BySymbolPCReferencesProvider( + compiler, + params.file(), + params.includeDefinition(), + params.offsetOrSymbol().getRight() + ) +} diff --git a/mtags/src/main/scala-2/scala/meta/internal/pc/PcRenameProvider.scala b/mtags/src/main/scala-2/scala/meta/internal/pc/PcRenameProvider.scala index 0cef6a61ce1..0b88de0fedb 100644 --- a/mtags/src/main/scala-2/scala/meta/internal/pc/PcRenameProvider.scala +++ b/mtags/src/main/scala-2/scala/meta/internal/pc/PcRenameProvider.scala @@ -9,7 +9,7 @@ class PcRenameProvider( override val compiler: MetalsGlobal, params: OffsetParams, name: Option[String] -) extends PcCollector[l.TextEdit](compiler, params) { +) extends WithSymbolSearchCollector[l.TextEdit](compiler, params) { import compiler._ private val forbiddenMethods = Set("equals", "hashCode", "unapply", "unary_!", "!") diff --git a/mtags/src/main/scala-2/scala/meta/internal/pc/PcSemanticTokensProvider.scala b/mtags/src/main/scala-2/scala/meta/internal/pc/PcSemanticTokensProvider.scala index 28b2cdfd1cf..3016c5d2180 100644 --- a/mtags/src/main/scala-2/scala/meta/internal/pc/PcSemanticTokensProvider.scala +++ b/mtags/src/main/scala-2/scala/meta/internal/pc/PcSemanticTokensProvider.scala @@ -16,7 +16,7 @@ final class PcSemanticTokensProvider( val params: VirtualFileParams ) { // Initialize Tree - object Collector extends PcCollector[Option[Node]](cp, params) { + object Collector extends SimpleCollector[Option[Node]](cp, params) { /** * Declaration is set for: diff --git a/mtags/src/main/scala-2/scala/meta/internal/pc/PcSymbolSearch.scala b/mtags/src/main/scala-2/scala/meta/internal/pc/PcSymbolSearch.scala new file mode 100644 index 00000000000..62b6655e955 --- /dev/null +++ b/mtags/src/main/scala-2/scala/meta/internal/pc/PcSymbolSearch.scala @@ -0,0 +1,145 @@ +package scala.meta.internal.pc + +import scala.reflect.internal.util.RangePosition + +trait PcSymbolSearch { self: WithCompilationUnit => + import compiler._ + + private val caseClassSynthetics: Set[Name] = Set(nme.apply, nme.copy) + + lazy val typedTree: Tree = locateTree(pos) match { + // Check actual object if apply is synthetic + case sel @ Select(qual, name) if name == nme.apply && qual.pos == sel.pos => + qual + case Import(expr, _) if expr.pos.includes(pos) => + // imports seem to be marked as transparent + locateTree(pos, expr, acceptTransparent = true) + case t => t + } + + // First identify the symbol we are at, comments identify @@ as current cursor position + lazy val soughtSymbols: Option[(Set[Symbol], Position)] = typedTree match { + /* simple identifier: + * val a = val@@ue + value + */ + case (id: Ident) => + // might happen in type trees + // also this doesn't seem to be picked up by semanticdb + if (id.symbol == NoSymbol) + fallbackSymbol(id.name, pos).map(sym => + (symbolAlternatives(sym), id.pos) + ) + else { + Some(symbolAlternatives(id.symbol), id.pos) + } + /* Anonynous function parameters such as: + * List(1).map{ <>: Int => abc} + * In this case, parameter has incorrect namePosition, so we need to handle it separately + */ + case (vd: ValDef) if isAnonFunctionParam(vd) => + val namePos = vd.pos.withEnd(vd.pos.start + vd.name.length) + if (namePos.includes(pos)) Some(symbolAlternatives(vd.symbol), namePos) + else None + + /* all definitions: + * def fo@@o = ??? + * class Fo@@o = ??? + * etc. + */ + case (df: DefTree) if df.namePosition.includes(pos) => + Some(symbolAlternatives(df.symbol), df.namePosition) + /* Import selectors: + * import scala.util.Tr@@y + */ + case (imp: Import) if imp.pos.includes(pos) => + imp.selector(pos).map(sym => (symbolAlternatives(sym), sym.pos)) + /* simple selector: + * object.val@@ue + */ + case (sel: NameTree) if sel.namePosition.includes(pos) => + Some(symbolAlternatives(sel.symbol), sel.namePosition) + + // needed for classOf[AB@@C]` + case lit @ Literal(Constant(TypeRef(_, sym, _))) if lit.pos.includes(pos) => + val posStart = text.indexOfSlice(sym.decodedName, lit.pos.start) + if (posStart == -1) None + else + Some( + ( + symbolAlternatives(sym), + new RangePosition( + lit.pos.source, + posStart, + lit.pos.point, + posStart + sym.decodedName.length + ) + ) + ) + /* named argument, which is a bit complex: + * foo(nam@@e = "123") + */ + case _ => + val apply = typedTree match { + case (apply: Apply) => Some(apply) + /** + * For methods with multiple parameter lists and default args, the tree looks like this: + * Block(List(val x&1, val x&2, ...), Apply(<>, List(x&1, x&2, ...))) + */ + case _ => + typedTree.children.collectFirst { + case Block(_, apply: Apply) if apply.pos.includes(pos) => + apply + } + } + apply + .collect { + case apply if apply.symbol != null => + collectArguments(apply) + .flatMap(arg => namedArgCache.find(_._1.includes(arg.pos))) + .collectFirst { + case (_, AssignOrNamedArg(id: Ident, _)) + if id.pos.includes(pos) => + apply.symbol.paramss.flatten.find(_.name == id.name).map { + s => + // if it's a case class we need to look for parameters also + if ( + caseClassSynthetics(s.owner.name) && s.owner.isSynthetic + ) { + val applyOwner = s.owner.owner + val constructorOwner = + if (applyOwner.isCaseClass) applyOwner + else + applyOwner.companion match { + case NoSymbol => + applyOwner + .localCompanion(pos) + .getOrElse(NoSymbol) + case comp => comp + } + val info = constructorOwner.info + val constructorParams = info.members + .filter(_.isConstructor) + .flatMap(_.paramss) + .flatten + .filter(_.name == id.name) + .toSet + ( + (constructorParams ++ Set( + s, + info.member(s.getterName), + info.member(s.setterName), + info.member(s.localName) + )).filter(_ != NoSymbol), + id.pos + ) + } else (Set(s), id.pos) + } + } + .flatten + } + .getOrElse(None) + } + + private def isAnonFunctionParam(vd: ValDef): Boolean = + vd.symbol != null && vd.symbol.owner.isAnonymousFunction && vd.rhs.isEmpty +} diff --git a/mtags/src/main/scala-2/scala/meta/internal/pc/ScalaPresentationCompiler.scala b/mtags/src/main/scala-2/scala/meta/internal/pc/ScalaPresentationCompiler.scala index 8659ca01df8..510659f60e6 100644 --- a/mtags/src/main/scala-2/scala/meta/internal/pc/ScalaPresentationCompiler.scala +++ b/mtags/src/main/scala-2/scala/meta/internal/pc/ScalaPresentationCompiler.scala @@ -37,6 +37,8 @@ import scala.meta.pc.OffsetParams import scala.meta.pc.PresentationCompiler import scala.meta.pc.PresentationCompilerConfig import scala.meta.pc.RangeParams +import scala.meta.pc.ReferencesRequest +import scala.meta.pc.ReferencesResult import scala.meta.pc.SymbolSearch import scala.meta.pc.VirtualFileParams import scala.meta.pc.{PcSymbolInformation => IPcSymbolInformation} @@ -425,6 +427,19 @@ case class ScalaPresentationCompiler( .asJava } + override def references( + params: ReferencesRequest + ): CompletableFuture[ju.List[ReferencesResult]] = { + compilerAccess.withInterruptableCompiler(Some(params.file()))( + List.empty[ReferencesResult].asJava, + params.file.token() + ) { pc => + val res: List[ReferencesResult] = + PcReferencesProvider(pc.compiler(), params).references() + res.asJava + } + } + override def semanticdbTextDocument( fileUri: URI, code: String diff --git a/mtags/src/main/scala-2/scala/meta/internal/pc/WithCompilationUnit.scala b/mtags/src/main/scala-2/scala/meta/internal/pc/WithCompilationUnit.scala new file mode 100644 index 00000000000..a3b49d0a924 --- /dev/null +++ b/mtags/src/main/scala-2/scala/meta/internal/pc/WithCompilationUnit.scala @@ -0,0 +1,84 @@ +package scala.meta.internal.pc + +import scala.meta.pc.OffsetParams +import scala.meta.pc.VirtualFileParams + +class WithCompilationUnit( + val compiler: MetalsGlobal, + val params: VirtualFileParams +) { + import compiler._ + val unit: RichCompilationUnit = addCompilationUnit( + code = params.text(), + filename = params.uri().toString(), + cursor = None + ) + val offset: Int = params match { + case p: OffsetParams => p.offset() + case _: VirtualFileParams => 0 + } + val pos: Position = unit.position(offset) + lazy val text = unit.source.content + typeCheck(unit) + + protected lazy val namedArgCache: Map[Position, Tree] = { + val parsedTree = parseTree(unit.source) + parsedTree.collect { case arg @ AssignOrNamedArg(_, rhs) => + rhs.pos -> arg + }.toMap + } + + /** + * Find all symbols that should be shown together. + * For example class we want to also show companion object. + * + * @param sym symbol to find the alternative candidates for + * @return set of possible symbols + */ + def symbolAlternatives(sym: Symbol): Set[Symbol] = { + val all = + if (sym.isClass) { + if (sym.owner.isMethod) Set(sym) ++ sym.localCompanion(pos) + else Set(sym, sym.companionModule, sym.companion.moduleClass) + } else if (sym.isModuleOrModuleClass) { + if (sym.owner.isMethod) Set(sym) ++ sym.localCompanion(pos) + else Set(sym, sym.companionClass, sym.moduleClass) + } else if (sym.isTerm && (sym.owner.isClass || sym.owner.isConstructor)) { + val info = + if (sym.owner.isClass) sym.owner.info + else sym.owner.owner.info + Set( + sym, + info.member(sym.getterName), + info.member(sym.setterName), + info.member(sym.localName) + ) ++ constructorParam(sym) ++ sym.allOverriddenSymbols.toSet + } else Set(sym) + all.filter(s => s != NoSymbol && !s.isError) + } + + private def constructorParam( + symbol: Symbol + ): Set[Symbol] = { + if (symbol.owner.isClass) { + val info = symbol.owner.info.member(nme.CONSTRUCTOR).info + info.paramss.flatten.find(_.name == symbol.name).toSet + } else Set.empty + } + + protected def collectArguments(apply: Apply): List[Tree] = { + apply.fun match { + case appl: Apply => collectArguments(appl) ++ apply.args + case _ => apply.args + } + } + + def fallbackSymbol(name: Name, pos: Position): Option[Symbol] = { + val context = doLocateImportContext(pos) + context.lookupSymbol(name, sym => sym.isType) match { + case LookupSucceeded(_, symbol) => + Some(symbol) + case _ => None + } + } +} diff --git a/mtags/src/main/scala-2/scala/meta/internal/pc/WorkspaceSymbolSearch.scala b/mtags/src/main/scala-2/scala/meta/internal/pc/WorkspaceSymbolSearch.scala index 8e42984fe5b..f2ec4eded1f 100644 --- a/mtags/src/main/scala-2/scala/meta/internal/pc/WorkspaceSymbolSearch.scala +++ b/mtags/src/main/scala-2/scala/meta/internal/pc/WorkspaceSymbolSearch.scala @@ -47,41 +47,8 @@ trait WorkspaceSymbolSearch { compiler: MetalsGlobal => } def info(symbol: String): Option[PcSymbolInformation] = { - val index = symbol.lastIndexOf("/") - val pkgString = symbol.take(index + 1) - val pkg = packageSymbolFromString(pkgString) - - def loop( - symbol: String, - acc: List[(String, Boolean)] - ): List[(String, Boolean)] = - if (symbol.isEmpty()) acc.reverse - else { - val newSymbol = symbol.takeWhile(c => c != '.' && c != '#') - val rest = symbol.drop(newSymbol.size) - loop(rest.drop(1), (newSymbol, rest.headOption.exists(_ == '#')) :: acc) - } - - val names = - loop(symbol.drop(index + 1).takeWhile(_ != '('), List.empty) - - val compilerSymbols = names.foldLeft(pkg.toList) { - case (owners, (name, isClass)) => - owners.flatMap { owner => - val foundChild = - if (isClass) owner.info.member(TypeName(name)) - else owner.info.member(TermName(name)) - if (foundChild.exists) { - foundChild.info match { - case OverloadedType(_, alts) => alts - case _ => List(foundChild) - } - } else Nil - } - } - val (searchedSymbol, alternativeSymbols) = - compilerSymbols.partition(compilerSymbol => + compilerSymbols(symbol).partition(compilerSymbol => semanticdbSymbol(compilerSymbol) == symbol ) @@ -110,6 +77,43 @@ trait WorkspaceSymbolSearch { compiler: MetalsGlobal => } } + def compilerSymbol(symbol: String): Option[Symbol] = + compilerSymbols(symbol).find(sym => semanticdbSymbol(sym) == symbol) + + private def compilerSymbols(symbol: String) = { + val index = symbol.lastIndexOf("/") + val pkgString = symbol.take(index + 1) + val pkg = packageSymbolFromString(pkgString) + + def loop( + symbol: String, + acc: List[(String, Boolean)] + ): List[(String, Boolean)] = + if (symbol.isEmpty()) acc.reverse + else { + val newSymbol = symbol.takeWhile(c => c != '.' && c != '#') + val rest = symbol.drop(newSymbol.size) + loop(rest.drop(1), (newSymbol, rest.headOption.exists(_ == '#')) :: acc) + } + + val names = + loop(symbol.drop(index + 1).takeWhile(_ != '('), List.empty) + + names.foldLeft(pkg.toList) { case (owners, (name, isClass)) => + owners.flatMap { owner => + val foundChild = + if (isClass) owner.info.member(TypeName(name)) + else owner.info.member(TermName(name)) + if (foundChild.exists) { + foundChild.info match { + case OverloadedType(_, alts) => alts + case _ => List(foundChild) + } + } else Nil + } + } + } + private def getSymbolKind(sym: Symbol): PcSymbolKind = if (sym.isJavaInterface) PcSymbolKind.INTERFACE else if (sym.isTrait) PcSymbolKind.TRAIT diff --git a/mtags/src/main/scala-3/scala/meta/internal/pc/PcCollector.scala b/mtags/src/main/scala-3/scala/meta/internal/pc/PcCollector.scala index 98af2da1306..c9b077cd32a 100644 --- a/mtags/src/main/scala-3/scala/meta/internal/pc/PcCollector.scala +++ b/mtags/src/main/scala-3/scala/meta/internal/pc/PcCollector.scala @@ -1,403 +1,81 @@ package scala.meta.internal.pc -import java.nio.file.Paths - import scala.meta as m -import scala.meta.internal.metals.CompilerOffsetParams import scala.meta.internal.mtags.MtagsEnrichments.* import scala.meta.internal.pc.MetalsInteractive.ExtensionMethodCall +import scala.meta.internal.pc.PcSymbolSearch.* import scala.meta.pc.OffsetParams import scala.meta.pc.VirtualFileParams -import dotty.tools.dotc.ast.NavigateAST -import dotty.tools.dotc.ast.Positioned import dotty.tools.dotc.ast.tpd import dotty.tools.dotc.ast.tpd.* import dotty.tools.dotc.ast.untpd -import dotty.tools.dotc.ast.untpd.ExtMethods import dotty.tools.dotc.ast.untpd.ImportSelector import dotty.tools.dotc.core.Contexts.* import dotty.tools.dotc.core.Flags import dotty.tools.dotc.core.NameOps.* import dotty.tools.dotc.core.Names.* -import dotty.tools.dotc.core.StdNames.* import dotty.tools.dotc.core.Symbols.* import dotty.tools.dotc.core.Types.* -import dotty.tools.dotc.interactive.Interactive import dotty.tools.dotc.interactive.InteractiveDriver -import dotty.tools.dotc.util.SourceFile import dotty.tools.dotc.util.SourcePosition import dotty.tools.dotc.util.Spans.Span -abstract class PcCollector[T]( - driver: InteractiveDriver, - params: VirtualFileParams, -): - private val caseClassSynthetics: Set[Name] = Set(nme.apply, nme.copy) - val uri = params.uri() - val filePath = Paths.get(uri) - val sourceText = params.text - val text = sourceText.toCharArray() - val source = - SourceFile.virtual(filePath.toString, sourceText) - driver.run(uri, source) - given ctx: Context = driver.currentCtx +trait PcCollector[T]: + self: WithCompilationUnit => - val unit = driver.latestRun - val compilatonUnitContext = ctx.fresh.setCompilationUnit(unit) - val offset = params match - case op: OffsetParams => op.offset() - case _ => 0 - val offsetParams = - params match - case op: OffsetParams => op - case _ => - CompilerOffsetParams(params.uri(), params.text(), 0, params.token()) - val pos = driver.sourcePosition(offsetParams) - val rawPath = - Interactive - .pathTo(driver.openedTrees(uri), pos)(using driver.currentCtx) - .dropWhile(t => // NamedArg anyway doesn't have symbol - t.symbol == NoSymbol && !t.isInstanceOf[NamedArg] || - // same issue https://github.com/lampepfl/dotty/issues/15937 as below - t.isInstanceOf[TypeTree] - ) - - val path = rawPath match - // For type it will sometimes go into the wrong tree since TypeTree also contains the same span - // https://github.com/lampepfl/dotty/issues/15937 - case TypeApply(sel: Select, _) :: tail if sel.span.contains(pos.span) => - Interactive.pathTo(sel, pos.span) ::: rawPath - case _ => rawPath def collect( parent: Option[Tree] )(tree: Tree | EndMarker, pos: SourcePosition, symbol: Option[Symbol]): T - def symbolAlternatives(sym: Symbol) = - def member(parent: Symbol) = parent.info.member(sym.name).symbol - def primaryConstructorTypeParam(owner: Symbol) = - for - typeParams <- owner.primaryConstructor.paramSymss.headOption - param <- typeParams.find(_.name == sym.name) - if (param.isType) - yield param - def additionalForEnumTypeParam(enumClass: Symbol) = - if enumClass.is(Flags.Enum) then - val enumOwner = - if enumClass.is(Flags.Case) - then - // we check that the type parameter is the one from enum class - // and not an enum case type parameter with the same name - Option.when(member(enumClass).is(Flags.Synthetic))( - enumClass.maybeOwner.companionClass - ) - else Some(enumClass) - enumOwner.toSet.flatMap { enumOwner => - val symsInEnumCases = enumOwner.children.toSet.flatMap(enumCase => - if member(enumCase).is(Flags.Synthetic) - then primaryConstructorTypeParam(enumCase) - else None - ) - val symsInEnumOwner = - primaryConstructorTypeParam(enumOwner).toSet + member(enumOwner) - symsInEnumCases ++ symsInEnumOwner - } - else Set.empty - val all = - if sym.is(Flags.ModuleClass) then - Set(sym, sym.companionModule, sym.companionModule.companion) - else if sym.isClass then - Set(sym, sym.companionModule, sym.companion.moduleClass) - else if sym.is(Flags.Module) then - Set(sym, sym.companionClass, sym.moduleClass) - else if sym.isTerm && (sym.owner.isClass || sym.owner.isConstructor) - then - val info = - if sym.owner.isClass then sym.owner.info else sym.owner.owner.info - Set( - sym, - info.member(sym.asTerm.name.setterName).symbol, - info.member(sym.asTerm.name.getterName).symbol, - ) ++ sym.allOverriddenSymbols.toSet - // type used in primary constructor will not match the one used in the class - else if sym.isTypeParam && sym.owner.isPrimaryConstructor then - Set(sym, member(sym.maybeOwner.maybeOwner)) - ++ additionalForEnumTypeParam(sym.maybeOwner.maybeOwner) - else if sym.isTypeParam then - primaryConstructorTypeParam(sym.maybeOwner).toSet - ++ additionalForEnumTypeParam(sym.maybeOwner) + sym - else Set(sym) - all.filter(s => s != NoSymbol && !s.isError) - end symbolAlternatives - - private def isGeneratedGiven(df: NamedDefTree)(using Context) = - val nameSpan = df.nameSpan - df.symbol.is(Flags.Given) && sourceText.substring( - nameSpan.start, - nameSpan.end, - ) != df.name.toString() - - // First identify the symbol we are at, comments identify @@ as current cursor position - def soughtSymbols(path: List[Tree]): Option[(Set[Symbol], SourcePosition)] = - val sought = path match - /* reference of an extension paramter - * extension [EF](<>: List[EF]) - * def double(ys: List[EF]) = <> ++ ys - */ - case (id: Ident) :: _ - if id.symbol - .is(Flags.Param) && id.symbol.owner.is(Flags.ExtensionMethod) => - Some(findAllExtensionParamSymbols(id.sourcePos, id.name, id.symbol)) - /** - * Workaround for missing symbol in: - * class A[T](a: T) - * val x = new <>(1) - */ - case t :: (n: New) :: (sel: Select) :: _ - if t.symbol == NoSymbol && sel.symbol.isConstructor => - Some(symbolAlternatives(sel.symbol.owner), namePos(t)) - /** - * Workaround for missing symbol in: - * class A[T](a: T) - * val x = <>[Int](1) - */ - case (sel @ Select(New(t), _)) :: (_: TypeApply) :: _ - if sel.symbol.isConstructor => - Some(symbolAlternatives(sel.symbol.owner), namePos(t)) - - /* simple identifier: - * val a = val@@ue + value - */ - case (id: Ident) :: _ => - Some(symbolAlternatives(id.symbol), id.sourcePos) - /* simple selector: - * object.val@@ue - */ - case (sel: Select) :: _ if selectNameSpan(sel).contains(pos.span) => - Some(symbolAlternatives(sel.symbol), pos.withSpan(sel.nameSpan)) - /* named argument: - * foo(nam@@e = "123") - */ - case (arg: NamedArg) :: (appl: Apply) :: _ => - val realName = arg.name.stripModuleClassSuffix.lastPart - if pos.span.start > arg.span.start && pos.span.end < arg.span.point + realName.length - then - val length = realName.toString.backticked.length() - val pos = arg.sourcePos.withSpan( - arg.span - .withEnd(arg.span.start + length) - .withPoint(arg.span.start) - ) - appl.symbol.paramSymss.flatten.find(_.name == arg.name).map { s => - // if it's a case class we need to look for parameters also - if caseClassSynthetics(s.owner.name) && s.owner.is(Flags.Synthetic) - then - ( - Set( - s, - s.owner.owner.companion.info.member(s.name).symbol, - s.owner.owner.info.member(s.name).symbol, - ) - .filter(_ != NoSymbol), - pos, - ) - else (Set(s), pos) - } - else None - end if - /* all definitions: - * def fo@@o = ??? - * class Fo@@o = ??? - * etc. - */ - case (df: NamedDefTree) :: _ - if df.nameSpan.contains(pos.span) && !isGeneratedGiven(df) => - Some(symbolAlternatives(df.symbol), pos.withSpan(df.nameSpan)) - /* enum cases with params - * enum Foo: - * case B@@ar[A](i: A) - */ - case (df: NamedDefTree) :: Template(_, _, self, _) :: _ - if (df.name == nme.apply || df.name == nme.unapply) && df.nameSpan.isZeroExtent => - Some(symbolAlternatives(self.tpt.symbol), self.sourcePos) - /** - * For traversing annotations: - * @JsonNo@@tification("") - * def params() = ??? - */ - case (df: MemberDef) :: _ if df.span.contains(pos.span) => - val annotTree = df.mods.annotations.find { t => - t.span.contains(pos.span) - } - collectTrees(annotTree).flatMap { t => - soughtSymbols( - Interactive.pathTo(t, pos.span) - ) - }.headOption - - /* Import selectors: - * import scala.util.Tr@@y - */ - case (imp: Import) :: _ if imp.span.contains(pos.span) => - imp - .selector(pos.span) - .map(sym => (symbolAlternatives(sym), sym.sourcePos)) - - /* Workaround for missing span in: - * class MyIntOut(val value: Int) - * object MyIntOut: - * extension (i: MyIntOut) def <> = i.value % 2 == 1 - * - * val a = MyIntOut(1).<> - */ - case ExtensionMethodCall(id) :: _ if id.span.contains(pos.span) => - Some(symbolAlternatives(id.symbol), id.sourcePos) - case _ => None - - sought match - case None => seekInExtensionParameters() - case _ => sought - - end soughtSymbols - - lazy val extensionMethods = - NavigateAST - .untypedPath(pos.span)(using compilatonUnitContext) - .collectFirst { case em @ ExtMethods(_, _) => em } - - private def findAllExtensionParamSymbols( - pos: SourcePosition, - name: Name, - sym: Symbol, - ) = - val symbols = - for - methods <- extensionMethods.map(_.methods) - symbols <- collectAllExtensionParamSymbols( - unit.tpdTree, - ExtensionParamOccurence(name, pos, sym, methods), - ) - yield symbols - symbols.getOrElse((symbolAlternatives(sym), pos)) - end findAllExtensionParamSymbols - - private def seekInExtensionParameters() = - def collectParams( - extMethods: ExtMethods - ): Option[ExtensionParamOccurence] = - MetalsNavigateAST - .pathToExtensionParam(pos.span, extMethods)(using compilatonUnitContext) - .collectFirst { - case v: untpd.ValOrTypeDef => - ExtensionParamOccurence( - v.name, - v.namePos, - v.symbol, - extMethods.methods, - ) - case i: untpd.Ident => - ExtensionParamOccurence( - i.name, - i.sourcePos, - i.symbol, - extMethods.methods, - ) - } - - for - extensionMethodScope <- extensionMethods - occurence <- collectParams(extensionMethodScope) - symbols <- collectAllExtensionParamSymbols( - path.headOption.getOrElse(unit.tpdTree), - occurence, - ) - yield symbols - end seekInExtensionParameters - - private def collectAllExtensionParamSymbols( - tree: tpd.Tree, - occurrence: ExtensionParamOccurence, - ): Option[(Set[Symbol], SourcePosition)] = - occurrence match - case ExtensionParamOccurence(_, namePos, symbol, _) - if symbol != NoSymbol && !symbol.isError && !symbol.owner.is( - Flags.ExtensionMethod - ) => - Some((symbolAlternatives(symbol), namePos)) - case ExtensionParamOccurence(name, namePos, _, methods) => - val symbols = - for - method <- methods.toSet - symbol <- - Interactive.pathTo(tree, method.span) match - case (d: DefDef) :: _ => - d.paramss.flatten.collect { - case param if param.name.decoded == name.decoded => - param.symbol - } - case _ => Set.empty[Symbol] - if (symbol != NoSymbol && !symbol.isError) - withAlt <- symbolAlternatives(symbol) - yield withAlt - if symbols.nonEmpty then Some((symbols, namePos)) else None - end collectAllExtensionParamSymbols - - def result(): List[T] = - params match - case _: OffsetParams => resultWithSought() - case _ => resultAllOccurences().toList - def resultAllOccurences(): Set[T] = def noTreeFilter = (_: Tree) => true def noSoughtFilter = (_: Symbol => Boolean) => true traverseSought(noTreeFilter, noSoughtFilter) - def resultWithSought(): List[T] = - soughtSymbols(path) match - case Some((sought, _)) => - lazy val owners = sought - .flatMap { s => Set(s.owner, s.owner.companionModule) } - .filter(_ != NoSymbol) - lazy val soughtNames: Set[Name] = sought.map(_.name) - - /* - * For comprehensions have two owners, one for the enumerators and one for - * yield. This is a heuristic to find that out. - */ - def isForComprehensionOwner(named: NameTree) = - soughtNames(named.name) && - scala.util - .Try(named.symbol.owner) - .toOption - .exists(_.isAnonymousFunction) && - owners.exists(o => - o.span.exists && o.span.point == named.symbol.owner.span.point - ) - - def soughtOrOverride(sym: Symbol) = - sought(sym) || sym.allOverriddenSymbols.exists(sought(_)) + def resultWithSought(sought: Set[Symbol]): List[T] = + lazy val owners = sought + .flatMap { s => Set(s.owner, s.owner.companionModule) } + .filter(_ != NoSymbol) + lazy val soughtNames: Set[Name] = sought.map(_.name) + + /* + * For comprehensions have two owners, one for the enumerators and one for + * yield. This is a heuristic to find that out. + */ + def isForComprehensionOwner(named: NameTree) = + soughtNames(named.name) && + scala.util + .Try(named.symbol.owner) + .toOption + .exists(_.isAnonymousFunction) && + owners.exists(o => + o.span.exists && o.span.point == named.symbol.owner.span.point + ) - def soughtTreeFilter(tree: Tree): Boolean = - tree match - case ident: Ident - if soughtOrOverride(ident.symbol) || - isForComprehensionOwner(ident) => - true - case sel: Select if soughtOrOverride(sel.symbol) => true - case df: NamedDefTree - if soughtOrOverride(df.symbol) && !df.symbol.isSetter => - true - case imp: Import if owners(imp.expr.symbol) => true - case _ => false + def soughtOrOverride(sym: Symbol) = + sought(sym) || sym.allOverriddenSymbols.exists(sought(_)) - def soughtFilter(f: Symbol => Boolean): Boolean = - sought.exists(f) + def soughtTreeFilter(tree: Tree): Boolean = + tree match + case ident: Ident + if soughtOrOverride(ident.symbol) || + isForComprehensionOwner(ident) => + true + case sel: Select if soughtOrOverride(sel.symbol) => true + case df: NamedDefTree + if soughtOrOverride(df.symbol) && !df.symbol.isSetter => + true + case imp: Import if owners(imp.expr.symbol) => true + case _ => false - traverseSought(soughtTreeFilter, soughtFilter).toList + def soughtFilter(f: Symbol => Boolean): Boolean = + sought.exists(f) - case None => Nil + traverseSought(soughtTreeFilter, soughtFilter).toList + end resultWithSought extension (span: Span) def isCorrect = @@ -483,7 +161,7 @@ abstract class PcCollector[T]( */ case df: NamedDefTree if df.span.isCorrect && df.nameSpan.isCorrect && - filter(df) && !isGeneratedGiven(df) => + filter(df) && !isGeneratedGiven(df, sourceText) => def collectEndMarker = EndMarker.getPosition(df, pos, sourceText).map { collect(EndMarker(df.symbol), _) @@ -605,34 +283,9 @@ abstract class PcCollector[T]( val traverser = new PcCollector.DeepFolderWithParent[Set[T]](collectNamesWithParent) - val all = traverser(Set.empty[T], unit.tpdTree) - all + traverser(Set.empty[T], unit.tpdTree) end traverseSought - private def collectTrees(trees: Iterable[Positioned]): Iterable[Tree] = - trees.collect { case t: Tree => - t - } - - // NOTE: Connected to https://github.com/lampepfl/dotty/issues/16771 - // `sel.nameSpan` is calculated incorrectly in (1 + 2).toString - // See test DocumentHighlightSuite.select-parentheses - private def selectNameSpan(sel: Select): Span = - val span = sel.span - if span.exists then - val point = span.point - if sel.name.toTermName == nme.ERROR then Span(point) - else if sel.qualifier.span.start > span.point then // right associative - val realName = sel.name.stripModuleClassSuffix.lastPart - Span(span.start, span.start + realName.length, point) - else Span(point, span.end, point) - else span - - private def namePos(tree: Tree): SourcePosition = - tree match - case sel: Select => sel.sourcePos.withSpan(selectNameSpan(sel)) - case _ => tree.sourcePos - /** * Those have wrong spans and we special case for them. */ @@ -641,7 +294,6 @@ abstract class PcCollector[T]( case _: TypeApply | _: Apply => true case _ => false } - end PcCollector object PcCollector: @@ -697,3 +349,21 @@ object EndMarker: ) end getPosition end EndMarker + +abstract class WithSymbolSearchCollector[T]( + driver: InteractiveDriver, + params: OffsetParams, +) extends WithCompilationUnit(driver, params) + with PcSymbolSearch + with PcCollector[T]: + def result(): List[T] = + soughtSymbols.toList.flatMap { case (sought, _) => + resultWithSought(sought) + } + +abstract class SimpleCollector[T]( + driver: InteractiveDriver, + params: VirtualFileParams, +) extends WithCompilationUnit(driver, params) + with PcCollector[T]: + def result(): List[T] = resultAllOccurences().toList diff --git a/mtags/src/main/scala-3/scala/meta/internal/pc/PcDocumentHighlightProvider.scala b/mtags/src/main/scala-3/scala/meta/internal/pc/PcDocumentHighlightProvider.scala index 32954fd344e..fbb4d44a0eb 100644 --- a/mtags/src/main/scala-3/scala/meta/internal/pc/PcDocumentHighlightProvider.scala +++ b/mtags/src/main/scala-3/scala/meta/internal/pc/PcDocumentHighlightProvider.scala @@ -13,7 +13,8 @@ import org.eclipse.lsp4j.DocumentHighlightKind final class PcDocumentHighlightProvider( driver: InteractiveDriver, params: OffsetParams, -) extends PcCollector[DocumentHighlight](driver, params): +) extends WithSymbolSearchCollector[DocumentHighlight](driver, params): + def collect( parent: Option[Tree] )( @@ -29,4 +30,5 @@ final class PcDocumentHighlightProvider( def highlights: List[DocumentHighlight] = result().distinctBy(_.getRange()) + end PcDocumentHighlightProvider diff --git a/mtags/src/main/scala-3/scala/meta/internal/pc/PcInlineValueProviderImpl.scala b/mtags/src/main/scala-3/scala/meta/internal/pc/PcInlineValueProviderImpl.scala index 90ec8f921cd..fd8510d7bcb 100644 --- a/mtags/src/main/scala-3/scala/meta/internal/pc/PcInlineValueProviderImpl.scala +++ b/mtags/src/main/scala-3/scala/meta/internal/pc/PcInlineValueProviderImpl.scala @@ -17,9 +17,9 @@ import dotty.tools.dotc.util.SourcePosition import org.eclipse.lsp4j as l final class PcInlineValueProviderImpl( - val driver: InteractiveDriver, + driver: InteractiveDriver, val params: OffsetParams, -) extends PcCollector[Option[Occurence]](driver, params) +) extends WithSymbolSearchCollector[Option[Occurence]](driver, params) with InlineValueProvider: val position: l.Position = pos.toLsp.getStart() diff --git a/mtags/src/main/scala-3/scala/meta/internal/pc/PcReferencesProvider.scala b/mtags/src/main/scala-3/scala/meta/internal/pc/PcReferencesProvider.scala new file mode 100644 index 00000000000..96551aed769 --- /dev/null +++ b/mtags/src/main/scala-3/scala/meta/internal/pc/PcReferencesProvider.scala @@ -0,0 +1,64 @@ +package scala.meta.internal.pc + +import scala.collection.JavaConverters.* + +import scala.meta.internal.metals.CompilerOffsetParams +import scala.meta.internal.mtags.MtagsEnrichments.* +import scala.meta.pc.ReferencesRequest +import scala.meta.pc.ReferencesResult + +import dotty.tools.dotc.ast.tpd +import dotty.tools.dotc.ast.tpd.* +import dotty.tools.dotc.core.Symbols.* +import dotty.tools.dotc.interactive.InteractiveDriver +import dotty.tools.dotc.util.SourcePosition +import org.eclipse.lsp4j +import org.eclipse.lsp4j.Location + +class PcReferencesProvider( + driver: InteractiveDriver, + request: ReferencesRequest, +) extends WithCompilationUnit(driver, request.file()) with PcCollector[Option[(String, Option[lsp4j.Range])]]: + + private def soughtSymbols = + if(request.offsetOrSymbol().isLeft()) { + val offsetParams = CompilerOffsetParams( + request.file().uri(), + request.file().text(), + request.offsetOrSymbol().getLeft() + ) + val symbolSearch = new WithCompilationUnit(driver, offsetParams) with PcSymbolSearch + symbolSearch.soughtSymbols.map(_._1) + } else { + SymbolProvider.compilerSymbol(request.offsetOrSymbol().getRight()).map(symbolAlternatives(_)) + } + + def collect(parent: Option[Tree])( + tree: Tree | EndMarker, + toAdjust: SourcePosition, + symbol: Option[Symbol], + ): Option[(String, Option[lsp4j.Range])] = + val (pos, _) = toAdjust.adjust(text) + tree match + case t: DefTree if !request.includeDefinition() => + val sym = symbol.getOrElse(t.symbol) + Some(SemanticdbSymbols.symbolName(sym), None) + case t: Tree => + val sym = symbol.getOrElse(t.symbol) + Some(SemanticdbSymbols.symbolName(sym), Some(pos.toLsp)) + case _ => None + + def references(): List[ReferencesResult] = + soughtSymbols match + case Some(sought) if sought.nonEmpty => + resultWithSought(sought) + .flatten + .groupMap(_._1) { case (_, optRange) => + optRange.map(new Location(request.file().uri().toString(), _)) + } + .map { case (symbol, locs) => + PcReferencesResult(symbol, locs.flatten.asJava) + } + .toList + case _ => Nil +end PcReferencesProvider diff --git a/mtags/src/main/scala-3/scala/meta/internal/pc/PcRenameProvider.scala b/mtags/src/main/scala-3/scala/meta/internal/pc/PcRenameProvider.scala index e918a0aa9a5..d4c006b68a0 100644 --- a/mtags/src/main/scala-3/scala/meta/internal/pc/PcRenameProvider.scala +++ b/mtags/src/main/scala-3/scala/meta/internal/pc/PcRenameProvider.scala @@ -15,7 +15,7 @@ final class PcRenameProvider( driver: InteractiveDriver, params: OffsetParams, name: Option[String], -) extends PcCollector[l.TextEdit](driver, params): +) extends WithSymbolSearchCollector[l.TextEdit](driver, params): private val forbiddenMethods = Set("equals", "hashCode", "unapply", "unary_!", "!") def canRenameSymbol(sym: Symbol)(using Context): Boolean = @@ -24,7 +24,7 @@ final class PcRenameProvider( || sym.source.path.isWorksheet) def prepareRename(): Option[l.Range] = - soughtSymbols(path).flatMap((symbols, pos) => + soughtSymbols.flatMap((symbols, pos) => if symbols.forall(canRenameSymbol) then Some(pos.toLsp) else None ) @@ -45,13 +45,10 @@ final class PcRenameProvider( ) end collect - def rename( - ): List[l.TextEdit] = - val (symbols, _) = soughtSymbols(path).getOrElse(Set.empty, pos) + def rename(): List[l.TextEdit] = + val (symbols, _) = soughtSymbols.getOrElse(Set.empty, pos) if symbols.nonEmpty && symbols.forall(canRenameSymbol(_)) - then - val res = result() - res + then result() else Nil end rename end PcRenameProvider diff --git a/mtags/src/main/scala-3/scala/meta/internal/pc/PcSemanticTokensProvider.scala b/mtags/src/main/scala-3/scala/meta/internal/pc/PcSemanticTokensProvider.scala index f9762a8a729..b2d1bd4723d 100644 --- a/mtags/src/main/scala-3/scala/meta/internal/pc/PcSemanticTokensProvider.scala +++ b/mtags/src/main/scala-3/scala/meta/internal/pc/PcSemanticTokensProvider.scala @@ -58,7 +58,7 @@ final class PcSemanticTokensProvider( case _ => !df.rhs.isEmpty case _ => false - object Collector extends PcCollector[Option[Node]](driver, params): + object Collector extends SimpleCollector[Option[Node]](driver, params): override def collect( parent: Option[Tree] )( diff --git a/mtags/src/main/scala-3/scala/meta/internal/pc/PcSymbolSearch.scala b/mtags/src/main/scala-3/scala/meta/internal/pc/PcSymbolSearch.scala new file mode 100644 index 00000000000..0ab8bdce569 --- /dev/null +++ b/mtags/src/main/scala-3/scala/meta/internal/pc/PcSymbolSearch.scala @@ -0,0 +1,284 @@ +package scala.meta.internal.pc + +import scala.meta.internal.mtags.MtagsEnrichments.* +import scala.meta.internal.pc.MetalsInteractive.ExtensionMethodCall +import scala.meta.internal.pc.PcSymbolSearch.* + +import dotty.tools.dotc.ast.NavigateAST +import dotty.tools.dotc.ast.Positioned +import dotty.tools.dotc.ast.tpd +import dotty.tools.dotc.ast.tpd.* +import dotty.tools.dotc.ast.untpd +import dotty.tools.dotc.ast.untpd.ExtMethods +import dotty.tools.dotc.core.Contexts.* +import dotty.tools.dotc.core.Flags +import dotty.tools.dotc.core.NameOps.* +import dotty.tools.dotc.core.Names.* +import dotty.tools.dotc.core.StdNames.* +import dotty.tools.dotc.core.Symbols.* +import dotty.tools.dotc.core.Types.* +import dotty.tools.dotc.interactive.Interactive +import dotty.tools.dotc.util.SourcePosition +import dotty.tools.dotc.util.Spans.Span + +trait PcSymbolSearch: + self: WithCompilationUnit => + + private val caseClassSynthetics: Set[Name] = Set(nme.apply, nme.copy) + + lazy val rawPath = + Interactive + .pathTo(driver.openedTrees(uri), pos)(using driver.currentCtx) + .dropWhile(t => // NamedArg anyway doesn't have symbol + t.symbol == NoSymbol && !t.isInstanceOf[NamedArg] || + // same issue https://github.com/lampepfl/dotty/issues/15937 as below + t.isInstanceOf[TypeTree] + ) + + lazy val extensionMethods = + NavigateAST + .untypedPath(pos.span)(using compilatonUnitContext) + .collectFirst { case em @ ExtMethods(_, _) => em } + + lazy val path = rawPath match + // For type it will sometimes go into the wrong tree since TypeTree also contains the same span + // https://github.com/lampepfl/dotty/issues/15937 + case TypeApply(sel: Select, _) :: tail if sel.span.contains(pos.span) => + Interactive.pathTo(sel, pos.span) ::: rawPath + case _ => rawPath + + lazy val soughtSymbols: Option[(Set[Symbol], SourcePosition)] = + soughtSymbols(path) + + def soughtSymbols(path: List[Tree]): Option[(Set[Symbol], SourcePosition)] = + val sought = path match + /* reference of an extension paramter + * extension [EF](<>: List[EF]) + * def double(ys: List[EF]) = <> ++ ys + */ + case (id: Ident) :: _ + if id.symbol + .is(Flags.Param) && id.symbol.owner.is(Flags.ExtensionMethod) => + Some(findAllExtensionParamSymbols(id.sourcePos, id.name, id.symbol)) + /** + * Workaround for missing symbol in: + * class A[T](a: T) + * val x = new <>(1) + */ + case t :: (n: New) :: (sel: Select) :: _ + if t.symbol == NoSymbol && sel.symbol.isConstructor => + Some(symbolAlternatives(sel.symbol.owner), namePos(t)) + /** + * Workaround for missing symbol in: + * class A[T](a: T) + * val x = <>[Int](1) + */ + case (sel @ Select(New(t), _)) :: (_: TypeApply) :: _ + if sel.symbol.isConstructor => + Some(symbolAlternatives(sel.symbol.owner), namePos(t)) + + /* simple identifier: + * val a = val@@ue + value + */ + case (id: Ident) :: _ => + Some(symbolAlternatives(id.symbol), id.sourcePos) + /* simple selector: + * object.val@@ue + */ + case (sel: Select) :: _ if selectNameSpan(sel).contains(pos.span) => + Some(symbolAlternatives(sel.symbol), pos.withSpan(sel.nameSpan)) + /* named argument: + * foo(nam@@e = "123") + */ + case (arg: NamedArg) :: (appl: Apply) :: _ => + val realName = arg.name.stripModuleClassSuffix.lastPart + if pos.span.start > arg.span.start && pos.span.end < arg.span.point + realName.length + then + val length = realName.toString.backticked.length() + val pos = arg.sourcePos.withSpan( + arg.span + .withEnd(arg.span.start + length) + .withPoint(arg.span.start) + ) + appl.symbol.paramSymss.flatten.find(_.name == arg.name).map { s => + // if it's a case class we need to look for parameters also + if caseClassSynthetics(s.owner.name) && s.owner.is(Flags.Synthetic) + then + ( + Set( + s, + s.owner.owner.companion.info.member(s.name).symbol, + s.owner.owner.info.member(s.name).symbol, + ) + .filter(_ != NoSymbol), + pos, + ) + else (Set(s), pos) + } + else None + end if + /* all definitions: + * def fo@@o = ??? + * class Fo@@o = ??? + * etc. + */ + case (df: NamedDefTree) :: _ + if df.nameSpan + .contains(pos.span) && !isGeneratedGiven(df, sourceText) => + Some(symbolAlternatives(df.symbol), pos.withSpan(df.nameSpan)) + /* enum cases with params + * enum Foo: + * case B@@ar[A](i: A) + */ + case (df: NamedDefTree) :: Template(_, _, self, _) :: _ + if (df.name == nme.apply || df.name == nme.unapply) && df.nameSpan.isZeroExtent => + Some(symbolAlternatives(self.tpt.symbol), self.sourcePos) + /** + * For traversing annotations: + * @JsonNo@@tification("") + * def params() = ??? + */ + case (df: MemberDef) :: _ if df.span.contains(pos.span) => + val annotTree = df.mods.annotations.find { t => + t.span.contains(pos.span) + } + collectTrees(annotTree).flatMap { t => + soughtSymbols( + Interactive.pathTo(t, pos.span) + ) + }.headOption + + /* Import selectors: + * import scala.util.Tr@@y + */ + case (imp: Import) :: _ if imp.span.contains(pos.span) => + imp + .selector(pos.span) + .map(sym => (symbolAlternatives(sym), sym.sourcePos)) + + /* Workaround for missing span in: + * class MyIntOut(val value: Int) + * object MyIntOut: + * extension (i: MyIntOut) def <> = i.value % 2 == 1 + * + * val a = MyIntOut(1).<> + */ + case ExtensionMethodCall(id) :: _ if id.span.contains(pos.span) => + Some(symbolAlternatives(id.symbol), id.sourcePos) + case _ => None + + sought match + case None => seekInExtensionParameters() + case _ => sought + + end soughtSymbols + + private def seekInExtensionParameters() = + def collectParams( + extMethods: ExtMethods + ): Option[ExtensionParamOccurence] = + MetalsNavigateAST + .pathToExtensionParam(pos.span, extMethods)(using compilatonUnitContext) + .collectFirst { + case v: untpd.ValOrTypeDef => + ExtensionParamOccurence( + v.name, + v.namePos, + v.symbol, + extMethods.methods, + ) + case i: untpd.Ident => + ExtensionParamOccurence( + i.name, + i.sourcePos, + i.symbol, + extMethods.methods, + ) + } + + for + extensionMethodScope <- extensionMethods + occurence <- collectParams(extensionMethodScope) + symbols <- collectAllExtensionParamSymbols( + path.headOption.getOrElse(unit.tpdTree), + occurence, + ) + yield symbols + end seekInExtensionParameters + + private def collectAllExtensionParamSymbols( + tree: tpd.Tree, + occurrence: ExtensionParamOccurence, + ): Option[(Set[Symbol], SourcePosition)] = + occurrence match + case ExtensionParamOccurence(_, namePos, symbol, _) + if symbol != NoSymbol && !symbol.isError && !symbol.owner.is( + Flags.ExtensionMethod + ) => + Some((symbolAlternatives(symbol), namePos)) + case ExtensionParamOccurence(name, namePos, _, methods) => + val symbols = + for + method <- methods.toSet + symbol <- + Interactive.pathTo(tree, method.span) match + case (d: DefDef) :: _ => + d.paramss.flatten.collect { + case param if param.name.decoded == name.decoded => + param.symbol + } + case _ => Set.empty[Symbol] + if (symbol != NoSymbol && !symbol.isError) + withAlt <- symbolAlternatives(symbol) + yield withAlt + if symbols.nonEmpty then Some((symbols, namePos)) else None + end collectAllExtensionParamSymbols + + private def findAllExtensionParamSymbols( + pos: SourcePosition, + name: Name, + sym: Symbol, + ) = + val symbols = + for + methods <- extensionMethods.map(_.methods) + symbols <- collectAllExtensionParamSymbols( + unit.tpdTree, + ExtensionParamOccurence(name, pos, sym, methods), + ) + yield symbols + symbols.getOrElse((symbolAlternatives(sym), pos)) + end findAllExtensionParamSymbols +end PcSymbolSearch + +object PcSymbolSearch: + // NOTE: Connected to https://github.com/lampepfl/dotty/issues/16771 + // `sel.nameSpan` is calculated incorrectly in (1 + 2).toString + // See test DocumentHighlightSuite.select-parentheses + def selectNameSpan(sel: Select): Span = + val span = sel.span + if span.exists then + val point = span.point + if sel.name.toTermName == nme.ERROR then Span(point) + else if sel.qualifier.span.start > span.point then // right associative + val realName = sel.name.stripModuleClassSuffix.lastPart + Span(span.start, span.start + realName.length, point) + else Span(point, span.end, point) + else span + + def collectTrees(trees: Iterable[Positioned]): Iterable[Tree] = + trees.collect { case t: Tree => t } + + def namePos(tree: Tree)(using Context): SourcePosition = + tree match + case sel: Select => sel.sourcePos.withSpan(selectNameSpan(sel)) + case _ => tree.sourcePos + + def isGeneratedGiven(df: NamedDefTree, sourceText: String)(using Context) = + val nameSpan = df.nameSpan + df.symbol.is(Flags.Given) && sourceText.substring( + nameSpan.start, + nameSpan.end, + ) != df.name.toString() + +end PcSymbolSearch diff --git a/mtags/src/main/scala-3/scala/meta/internal/pc/ScalaPresentationCompiler.scala b/mtags/src/main/scala-3/scala/meta/internal/pc/ScalaPresentationCompiler.scala index 9119270d371..6433570af84 100644 --- a/mtags/src/main/scala-3/scala/meta/internal/pc/ScalaPresentationCompiler.scala +++ b/mtags/src/main/scala-3/scala/meta/internal/pc/ScalaPresentationCompiler.scala @@ -20,6 +20,7 @@ import scala.meta.internal.metals.ReportContext import scala.meta.internal.metals.ReportLevel import scala.meta.internal.metals.StdReportContext import scala.meta.internal.mtags.BuildInfo +import scala.meta.internal.mtags.BuildInfo.scalaCompilerVersion import scala.meta.internal.mtags.MtagsEnrichments.given import scala.meta.internal.pc.completions.CompletionProvider import scala.meta.internal.pc.completions.OverrideCompletions @@ -41,7 +42,7 @@ case class ScalaPresentationCompiler( sh: Option[ScheduledExecutorService] = None, config: PresentationCompilerConfig = PresentationCompilerConfigImpl(), folderPath: Option[Path] = None, - reportsLevel: ReportLevel = ReportLevel.Info, + reportsLevel: ReportLevel = ReportLevel.Info ) extends PresentationCompiler: def this() = this("", None, Nil, Nil) @@ -179,6 +180,19 @@ case class ScalaPresentationCompiler( PcDocumentHighlightProvider(driver, params).highlights.asJava } + override def references( + params: ReferencesRequest + ): CompletableFuture[ju.List[ReferencesResult]] = + compilerAccess.withNonInterruptableCompiler(Some(params.file()))( + List.empty[ReferencesResult].asJava, + params.file().token, + ) { access => + val driver = access.compiler() + PcReferencesProvider(driver, params) + .references() + .asJava + } + def shutdown(): Unit = compilerAccess.shutdown() diff --git a/mtags/src/main/scala-3/scala/meta/internal/pc/SymbolInformationProvider.scala b/mtags/src/main/scala-3/scala/meta/internal/pc/SymbolInformationProvider.scala index f45eb12c98f..6cb761d1a9c 100644 --- a/mtags/src/main/scala-3/scala/meta/internal/pc/SymbolInformationProvider.scala +++ b/mtags/src/main/scala-3/scala/meta/internal/pc/SymbolInformationProvider.scala @@ -15,58 +15,9 @@ import dotty.tools.dotc.core.StdNames.nme import dotty.tools.dotc.core.Symbols.* class SymbolInformationProvider(using Context): - private def toSymbols( - pkg: String, - parts: List[(String, Boolean)], - ): List[Symbol] = - def collectSymbols(denotation: Denotation): List[Symbol] = - denotation match - case MultiDenotation(denot1, denot2) => - collectSymbols(denot1) ++ collectSymbols(denot2) - case denot => List(denot.symbol) - - def loop( - owners: List[Symbol], - parts: List[(String, Boolean)], - ): List[Symbol] = - parts match - case (head, isClass) :: tl => - val foundSymbols = - owners.flatMap { owner => - val next = - if isClass then owner.info.member(typeName(head)) - else owner.info.member(termName(head)) - collectSymbols(next).filter(_.exists) - } - if foundSymbols.nonEmpty then loop(foundSymbols, tl) - else Nil - case Nil => owners - - val pkgSym = - if pkg == "_empty_" then requiredPackage(nme.EMPTY_PACKAGE) - else requiredPackage(pkg) - loop(List(pkgSym), parts) - end toSymbols def info(symbol: String): Option[PcSymbolInformation] = - val index = symbol.lastIndexOf("/") - val pkg = normalizePackage(symbol.take(index + 1)) - - def loop( - symbol: String, - acc: List[(String, Boolean)], - ): List[(String, Boolean)] = - if symbol.isEmpty() then acc.reverse - else - val newSymbol = symbol.takeWhile(c => c != '.' && c != '#') - val rest = symbol.drop(newSymbol.size) - loop(rest.drop(1), (newSymbol, rest.headOption.exists(_ == '#')) :: acc) - val names = - loop(symbol.drop(index + 1).takeWhile(_ != '('), List.empty) - - val foundSymbols = - try toSymbols(pkg, names) - catch case NonFatal(e) => Nil + val foundSymbols = SymbolProvider.compilerSymbols(symbol) val (searchedSymbol, alternativeSymbols) = foundSymbols.partition(compilerSymbol => @@ -120,7 +71,64 @@ class SymbolInformationProvider(using Context): else if sym.is(Flags.TypeParam) then PcSymbolKind.TYPE_PARAMETER else if sym.isType then PcSymbolKind.TYPE else PcSymbolKind.UNKNOWN_KIND +end SymbolInformationProvider + +object SymbolProvider: + + def compilerSymbol(symbol: String)(using Context): Option[Symbol] = + compilerSymbols(symbol).find(sym => SemanticdbSymbols.symbolName(sym) == symbol) + + def compilerSymbols(symbol: String)(using Context): List[Symbol] = + val index = symbol.lastIndexOf("/") + val pkg = normalizePackage(symbol.take(index + 1)) + + def loop( + symbol: String, + acc: List[(String, Boolean)], + ): List[(String, Boolean)] = + if symbol.isEmpty() then acc.reverse + else + val newSymbol = symbol.takeWhile(c => c != '.' && c != '#') + val rest = symbol.drop(newSymbol.size) + loop(rest.drop(1), (newSymbol, rest.headOption.exists(_ == '#')) :: acc) + val names = + loop(symbol.drop(index + 1).takeWhile(_ != '('), List.empty) + + try toSymbols(pkg, names) + catch case NonFatal(e) => Nil private def normalizePackage(pkg: String): String = pkg.replace("/", ".").stripSuffix(".") -end SymbolInformationProvider + + private def toSymbols( + pkg: String, + parts: List[(String, Boolean)], + )(using Context): List[Symbol] = + def collectSymbols(denotation: Denotation): List[Symbol] = + denotation match + case MultiDenotation(denot1, denot2) => + collectSymbols(denot1) ++ collectSymbols(denot2) + case denot => List(denot.symbol) + + def loop( + owners: List[Symbol], + parts: List[(String, Boolean)], + ): List[Symbol] = + parts match + case (head, isClass) :: tl => + val foundSymbols = + owners.flatMap { owner => + val next = + if isClass then owner.info.member(typeName(head)) + else owner.info.member(termName(head)) + collectSymbols(next).filter(_.exists) + } + if foundSymbols.nonEmpty then loop(foundSymbols, tl) + else Nil + case Nil => owners + + val pkgSym = + if pkg == "_empty_" then requiredPackage(nme.EMPTY_PACKAGE) + else requiredPackage(pkg) + loop(List(pkgSym), parts) + end toSymbols diff --git a/mtags/src/main/scala-3/scala/meta/internal/pc/WithCompilationUnit.scala b/mtags/src/main/scala-3/scala/meta/internal/pc/WithCompilationUnit.scala new file mode 100644 index 00000000000..8ce0fd390be --- /dev/null +++ b/mtags/src/main/scala-3/scala/meta/internal/pc/WithCompilationUnit.scala @@ -0,0 +1,101 @@ +package scala.meta.internal.pc +import java.nio.file.Paths + +import scala.meta as m + +import scala.meta.internal.metals.CompilerOffsetParams +import scala.meta.internal.mtags.MtagsEnrichments.* +import scala.meta.pc.OffsetParams +import scala.meta.pc.VirtualFileParams + +import dotty.tools.dotc.core.Contexts.* +import dotty.tools.dotc.core.Flags +import dotty.tools.dotc.core.NameOps.* +import dotty.tools.dotc.core.Symbols.* +import dotty.tools.dotc.interactive.InteractiveDriver +import dotty.tools.dotc.util.SourceFile + +class WithCompilationUnit( + val driver: InteractiveDriver, + params: VirtualFileParams, +): + val uri = params.uri() + val filePath = Paths.get(uri) + val sourceText = params.text + val text = sourceText.toCharArray() + val source = + SourceFile.virtual(filePath.toString, sourceText) + driver.run(uri, source) + given ctx: Context = driver.currentCtx + + val unit = driver.latestRun + val compilatonUnitContext = ctx.fresh.setCompilationUnit(unit) + val offset = params match + case op: OffsetParams => op.offset() + case _ => 0 + val offsetParams = + params match + case op: OffsetParams => op + case _ => + CompilerOffsetParams(params.uri(), params.text(), 0, params.token()) + val pos = driver.sourcePosition(offsetParams) + + // First identify the symbol we are at, comments identify @@ as current cursor position + def symbolAlternatives(sym: Symbol)(using Context) = + def member(parent: Symbol) = parent.info.member(sym.name).symbol + def primaryConstructorTypeParam(owner: Symbol) = + for + typeParams <- owner.primaryConstructor.paramSymss.headOption + param <- typeParams.find(_.name == sym.name) + if (param.isType) + yield param + def additionalForEnumTypeParam(enumClass: Symbol) = + if enumClass.is(Flags.Enum) then + val enumOwner = + if enumClass.is(Flags.Case) + then + // we check that the type parameter is the one from enum class + // and not an enum case type parameter with the same name + Option.when(member(enumClass).is(Flags.Synthetic))( + enumClass.maybeOwner.companionClass + ) + else Some(enumClass) + enumOwner.toSet.flatMap { enumOwner => + val symsInEnumCases = enumOwner.children.toSet.flatMap(enumCase => + if member(enumCase).is(Flags.Synthetic) + then primaryConstructorTypeParam(enumCase) + else None + ) + val symsInEnumOwner = + primaryConstructorTypeParam(enumOwner).toSet + member(enumOwner) + symsInEnumCases ++ symsInEnumOwner + } + else Set.empty + val all = + if sym.is(Flags.ModuleClass) then + Set(sym, sym.companionModule, sym.companionModule.companion) + else if sym.isClass then + Set(sym, sym.companionModule, sym.companion.moduleClass) + else if sym.is(Flags.Module) then + Set(sym, sym.companionClass, sym.moduleClass) + else if sym.isTerm && (sym.owner.isClass || sym.owner.isConstructor) + then + val info = + if sym.owner.isClass then sym.owner.info else sym.owner.owner.info + Set( + sym, + info.member(sym.asTerm.name.setterName).symbol, + info.member(sym.asTerm.name.getterName).symbol, + ) ++ sym.allOverriddenSymbols.toSet + // type used in primary constructor will not match the one used in the class + else if sym.isTypeParam && sym.owner.isPrimaryConstructor then + Set(sym, member(sym.maybeOwner.maybeOwner)) + ++ additionalForEnumTypeParam(sym.maybeOwner.maybeOwner) + else if sym.isTypeParam then + primaryConstructorTypeParam(sym.maybeOwner).toSet + ++ additionalForEnumTypeParam(sym.maybeOwner) + sym + else Set(sym) + all.filter(s => s != NoSymbol && !s.isError) + end symbolAlternatives + +end WithCompilationUnit diff --git a/mtags/src/main/scala/scala/meta/internal/metals/SemanticdbDefinition.scala b/mtags/src/main/scala/scala/meta/internal/metals/SemanticdbDefinition.scala index 21029813fb6..640509fb559 100644 --- a/mtags/src/main/scala/scala/meta/internal/metals/SemanticdbDefinition.scala +++ b/mtags/src/main/scala/scala/meta/internal/metals/SemanticdbDefinition.scala @@ -5,6 +5,7 @@ import scala.util.control.NonFatal import scala.meta.Dialect import scala.meta.inputs.Input import scala.meta.internal.mtags.JavaMtags +import scala.meta.internal.mtags.MtagsIndexer import scala.meta.internal.mtags.ScalaToplevelMtags import scala.meta.internal.mtags.ScalametaCommonEnrichments._ import scala.meta.internal.semanticdb.Language @@ -44,14 +45,30 @@ object SemanticdbDefinition { includeMembers: Boolean )( fn: SemanticdbDefinition => Unit - )(implicit rc: ReportContext): Unit = { + )(implicit rc: ReportContext): Unit = + foreachWithReturnMtags( + input, + dialect, + includeMembers, + collectIdentifiers = false + )(fn) + + def foreachWithReturnMtags( + input: Input.VirtualFile, + dialect: Dialect, + includeMembers: Boolean, + collectIdentifiers: Boolean + )( + fn: SemanticdbDefinition => Unit + )(implicit rc: ReportContext): Option[MtagsIndexer] = { input.toLanguage match { case Language.SCALA => val mtags = new ScalaToplevelMtags( input, includeInnerClasses = true, includeMembers = includeMembers, - dialect + dialect, + collectIdentifiers = collectIdentifiers ) { override def visitOccurrence( occ: SymbolOccurrence, @@ -66,6 +83,7 @@ object SemanticdbDefinition { case _: TokenizeException => () // ignore because we don't need to index untokenizable files. } + Some(mtags) case Language.JAVA => val mtags = new JavaMtags(input, includeMembers) { override def visitOccurrence( @@ -80,7 +98,8 @@ object SemanticdbDefinition { catch { case NonFatal(_) => } - case _ => + Some(mtags) + case _ => None } } } diff --git a/mtags/src/main/scala/scala/meta/internal/mtags/MtagsIndexer.scala b/mtags/src/main/scala/scala/meta/internal/mtags/MtagsIndexer.scala index 92694a7e703..930c54560c7 100644 --- a/mtags/src/main/scala/scala/meta/internal/mtags/MtagsIndexer.scala +++ b/mtags/src/main/scala/scala/meta/internal/mtags/MtagsIndexer.scala @@ -43,6 +43,8 @@ trait MtagsIndexer { private var myLastCurrentOwner: String = currentOwner def lastCurrentOwner: String = myLastCurrentOwner + def allIdentifiers: Set[String] = Set.empty + def owner: String = currentOwner def withOwner[A](owner: String = currentOwner)(thunk: => A): A = { val old = currentOwner diff --git a/mtags/src/main/scala/scala/meta/internal/mtags/ScalaToplevelMtags.scala b/mtags/src/main/scala/scala/meta/internal/mtags/ScalaToplevelMtags.scala index dea3c9a974c..257a8ca0f91 100644 --- a/mtags/src/main/scala/scala/meta/internal/mtags/ScalaToplevelMtags.scala +++ b/mtags/src/main/scala/scala/meta/internal/mtags/ScalaToplevelMtags.scala @@ -44,7 +44,8 @@ class ScalaToplevelMtags( val input: Input.VirtualFile, includeInnerClasses: Boolean, includeMembers: Boolean, - dialect: Dialect + dialect: Dialect, + collectIdentifiers: Boolean = false )(implicit rc: ReportContext) extends MtagsIndexer { @@ -58,6 +59,20 @@ class ScalaToplevelMtags( import ScalaToplevelMtags._ + private val identifiers = Set.newBuilder[String] + override def allIdentifiers: Set[String] = identifiers.result() + + implicit class XtensionScanner(scanner: LegacyScanner) { + def mtagsNextToken(): Unit = { + scanner.nextToken() + if (collectIdentifiers) + scanner.curr.token match { + case IDENTIFIER => identifiers += scanner.curr.name + case _ => + } + } + } + override def language: Language = Language.SCALA override def indexRoot(): Unit = @@ -201,7 +216,7 @@ class ScalaToplevelMtags( } owner } else currRegion.termOwner - scanner.nextToken() + scanner.mtagsNextToken() loop( indent, isAfterNewline = false, @@ -307,7 +322,7 @@ class ScalaToplevelMtags( withOwner(currRegion.termOwner) { emitTerm(currRegion, isImplicit) } - } else scanner.nextToken() + } else scanner.mtagsNextToken() loop( indent, isAfterNewline = false, @@ -319,7 +334,7 @@ class ScalaToplevelMtags( withOwner(currRegion.termOwner) { emitType(needEmitTermMember()) } - } else scanner.nextToken() + } else scanner.mtagsNextToken() loop(indent, isAfterNewline = false, currRegion, newExpectIgnoreBody) case IMPORT | EXPORT => // skip imports because they might have `given` kw @@ -327,7 +342,7 @@ class ScalaToplevelMtags( loop(indent, isAfterNewline = false, currRegion, expectTemplate) case COMMENT => // skip comment because they might break indentation - scanner.nextToken() + scanner.mtagsNextToken() loop(indent, isAfterNewline = false, currRegion, expectTemplate) case WHITESPACE if dialect.allowSignificantIndentation => if (isNewline) { @@ -338,7 +353,7 @@ class ScalaToplevelMtags( val next = expect.startIndentedRegion(currRegion, expect.isExtension) resetRegion(next) - scanner.nextToken() + scanner.mtagsNextToken() loop(0, isAfterNewline = true, next, None) // basically for braceless def case Some(expect) if expect.ignoreBody => @@ -350,7 +365,7 @@ class ScalaToplevelMtags( None ) case _ => - scanner.nextToken() + scanner.mtagsNextToken() loop( 0, isAfterNewline = true, @@ -361,7 +376,7 @@ class ScalaToplevelMtags( } else { val nextIndentLevel = if (isAfterNewline) indent + 1 else indent - scanner.nextToken() + scanner.mtagsNextToken() loop( nextIndentLevel, isAfterNewline, @@ -395,7 +410,7 @@ class ScalaToplevelMtags( isImplicitClass = expect.isImplicit ) resetRegion(next) - scanner.nextToken() + scanner.mtagsNextToken() loop(0, isAfterNewline = true, next, None) case (Some(expect), true) => val nextIndent = acceptWhileIndented(expect.indent) @@ -414,7 +429,7 @@ class ScalaToplevelMtags( expectTemplate ) case _ => - scanner.nextToken() + scanner.mtagsNextToken() loop( indent, isAfterNewline = false, @@ -431,7 +446,7 @@ class ScalaToplevelMtags( // e.g. class A(val foo: Foo { type T = Int }) // ^ acceptBalancedDelimeters(LBRACE, RBRACE) - scanner.nextToken() + scanner.mtagsNextToken() loop(indent, isAfterNewline = false, currRegion, expectTemplate) } else { val next = @@ -441,12 +456,12 @@ class ScalaToplevelMtags( expect.isImplicit ) resetRegion(next) - scanner.nextToken() + scanner.mtagsNextToken() loop(indent, isAfterNewline = false, next, None) } case _ => acceptBalancedDelimeters(LBRACE, RBRACE) - scanner.nextToken() + scanner.mtagsNextToken() loop(indent, isAfterNewline = false, currRegion, None) } case RBRACE => @@ -454,17 +469,17 @@ class ScalaToplevelMtags( case Region.InBrace(_, prev, _, _, _) => resetRegion(prev) case r => r } - scanner.nextToken() + scanner.mtagsNextToken() loop(indent, isAfterNewline, nextRegion, None) case LBRACKET => acceptBalancedDelimeters(LBRACKET, RBRACKET) - scanner.nextToken() + scanner.mtagsNextToken() loop(indent, isAfterNewline = false, currRegion, expectTemplate) case LPAREN => expectTemplate match { case Some(expect) if expect.isClassConstructor && includeInnerClasses => { - scanner.nextToken() + scanner.mtagsNextToken() loop( indent, isAfterNewline = false, @@ -477,7 +492,7 @@ class ScalaToplevelMtags( } case _ => { acceptBalancedDelimeters(LPAREN, RPAREN) - scanner.nextToken() + scanner.mtagsNextToken() loop(indent, isAfterNewline = false, currRegion, expectTemplate) } } @@ -486,7 +501,7 @@ class ScalaToplevelMtags( case _: Region.InParenClass => true case _ => false }) => - scanner.nextToken() + scanner.mtagsNextToken() loop( indent, isAfterNewline = false, @@ -498,7 +513,7 @@ class ScalaToplevelMtags( ) case COMMA => val nextExpectTemplate = expectTemplate.filter(!_.isPackageBody) - scanner.nextToken() + scanner.mtagsNextToken() loop( indent, isAfterNewline = false, @@ -550,7 +565,7 @@ class ScalaToplevelMtags( else newExpectClassTemplate() ) case IMPLICIT => - scanner.nextToken() + scanner.mtagsNextToken() loop( indent, isAfterNewline, @@ -560,7 +575,7 @@ class ScalaToplevelMtags( ) case t => val nextExpectTemplate = expectTemplate.filter(!_.isPackageBody) - scanner.nextToken() + scanner.mtagsNextToken() loop( indent, isAfterNewline = false, @@ -640,7 +655,7 @@ class ScalaToplevelMtags( val maybeNewIndent = acceptAllAfterOverriddenIdentifier() scanner.curr.token match { case DOT => - scanner.nextToken() + scanner.mtagsNextToken() getIdentifier() case _ => (currentIdentifier, maybeNewIndent) } @@ -686,7 +701,7 @@ class ScalaToplevelMtags( } } } - scanner.nextToken() + scanner.mtagsNextToken() } def emitType(emitTermMember: Boolean): Option[Unit] = { @@ -706,25 +721,25 @@ class ScalaToplevelMtags( case SEMI => name case _ if isNewline | isDone => name case EQUALS => - scanner.nextToken() + scanner.mtagsNextToken() loop(name, isAfterEq = true) case TYPELAMBDAARROW | WHITESPACE => - scanner.nextToken() + scanner.mtagsNextToken() loop(name, isAfterEq) case LBRACKET => acceptBalancedDelimeters(LBRACKET, RBRACKET) - scanner.nextToken() + scanner.mtagsNextToken() loop(name, isAfterEq) case LBRACE => acceptBalancedDelimeters(LBRACE, RBRACE) - scanner.nextToken() + scanner.mtagsNextToken() loop(name, isAfterEq) case IDENTIFIER if isAfterEq && scanner.curr.name != "|" && scanner.curr.name != "&" => loop(identOrSelectName(), isAfterEq) case _ if isAfterEq => None case _ => - scanner.nextToken() + scanner.mtagsNextToken() loop(name) } } @@ -848,7 +863,7 @@ class ScalaToplevelMtags( require(scanner.curr.token == Open, "open delimeter { or (") var count = 1 while (!isDone && count > 0) { - scanner.nextToken() + scanner.mtagsNextToken() scanner.curr.token match { case Open => count += 1 @@ -868,15 +883,16 @@ class ScalaToplevelMtags( if (!isDone) { scanner.curr.token match { case WHITESPACE => - if (isNewline) { scanner.nextToken; loop(0, true) } - else if (isAfterNL) { scanner.nextToken; loop(indent + 1, true) } - else { scanner.nextToken(); loop(indent, false) } + if (isNewline) { scanner.mtagsNextToken(); loop(0, true) } + else if (isAfterNL) { + scanner.mtagsNextToken(); loop(indent + 1, true) + } else { scanner.mtagsNextToken(); loop(indent, false) } case COMMENT => - scanner.nextToken() + scanner.mtagsNextToken() loop(indent, false) case _ if indent <= exitIndent => indent case _ => - scanner.nextToken() + scanner.mtagsNextToken() loop(indent, false) } } else indent @@ -885,7 +901,7 @@ class ScalaToplevelMtags( } def acceptToStatSep(): Unit = { - scanner.nextToken() + scanner.mtagsNextToken() while ( !isDone && (scanner.curr.token match { @@ -894,14 +910,14 @@ class ScalaToplevelMtags( case _ => true }) ) { - scanner.nextToken() + scanner.mtagsNextToken() } } private def acceptTrivia(): Option[Int] = { var includedNewline = false var indent = 0 - scanner.nextToken() + scanner.mtagsNextToken() while ( !isDone && (scanner.curr.token match { @@ -915,13 +931,13 @@ class ScalaToplevelMtags( } else if (scanner.curr.token == WHITESPACE) { indent += 1 } - scanner.nextToken() + scanner.mtagsNextToken() } if (includedNewline) Some(indent) else None } private def nextIsNL(): Boolean = { - scanner.nextToken() + scanner.mtagsNextToken() scanner.curr.token match { case WHITESPACE if isNewline => true case WHITESPACE => @@ -963,7 +979,7 @@ class ScalaToplevelMtags( nextIsNL() val newIdent = getName if (newIdent.exists(_ == "type")) { - scanner.nextToken() + scanner.mtagsNextToken() current } else identOrSelectName(newIdent) case _ => current @@ -1027,7 +1043,7 @@ class ScalaToplevelMtags( case WHITESPACE | COMMA => {} case _ => { isUnapply = true } } - scanner.nextToken() + scanner.mtagsNextToken() } if (isUnapply) resultList.filterNot(_.name.charAt(0).isUpper) else resultList diff --git a/tests/mtest/src/main/scala/tests/RangeReplace.scala b/tests/mtest/src/main/scala/tests/RangeReplace.scala index 866d6f0d89e..a1df5bcdd3e 100644 --- a/tests/mtest/src/main/scala/tests/RangeReplace.scala +++ b/tests/mtest/src/main/scala/tests/RangeReplace.scala @@ -11,14 +11,21 @@ trait RangeReplace { def renderHighlightsAsString( code: String, highlights: List[DocumentHighlight] + ): String = renderRangesAsString(code, highlights.map(_.getRange())) + + def renderRangesAsString( + code: String, + highlights: List[Range], + alreadyAddedMarkings: List[(Int, Int)] = Nil, + currentBase: Option[String] = None ): String = { highlights - .foldLeft((code, List.empty[(Int, Int)])) { - case ((base, alreadyAddedMarkings), location) => - replaceInRangeWithAdjustmens( + .foldLeft((currentBase.getOrElse(code), alreadyAddedMarkings)) { + case ((base, alreadyAddedMarkings), range) => + replaceInRangeWithAdjustments( code, base, - location.getRange, + range, alreadyAddedMarkings ) } @@ -31,9 +38,9 @@ trait RangeReplace { prefix: String = "<<", suffix: String = ">>" ): String = - replaceInRangeWithAdjustmens(base, base, range, List(), prefix, suffix)._1 + replaceInRangeWithAdjustments(base, base, range, List(), prefix, suffix)._1 - protected def replaceInRangeWithAdjustmens( + protected def replaceInRangeWithAdjustments( code: String, currentBase: String, range: Range, diff --git a/tests/slow/src/test/scala/tests/feature/PcReferencesLspSuite.scala b/tests/slow/src/test/scala/tests/feature/PcReferencesLspSuite.scala new file mode 100644 index 00000000000..6b0c4a682ba --- /dev/null +++ b/tests/slow/src/test/scala/tests/feature/PcReferencesLspSuite.scala @@ -0,0 +1,232 @@ +package tests.feature + +import scala.meta.internal.metals +import scala.meta.internal.metals.MetalsEnrichments._ + +import munit.TestOptions +import org.eclipse.lsp4j.Location +import org.eclipse.lsp4j.ReferenceContext +import org.eclipse.lsp4j.ReferenceParams +import tests.BaseLspSuite +import tests.BuildInfo +import tests.FileLayout +import tests.RangeReplace + +class PcReferencesLspSuite + extends BaseLspSuite("pc-references") + with RangeReplace { + + for { + scalaVersion <- List(metals.BuildInfo.scala213, metals.BuildInfo.scala3) + } { + check( + s"basic1_$scalaVersion", + """|/a/src/main/scala/Defn.scala + |package a + |object O { + | val <> = 1 + | val k = <> + |} + |/a/src/main/scala/Main.scala + |package a + |object Main { + | val g = O.<> + |} + |""".stripMargin, + scalaVersion, + ) + + check( + s"basic2_$scalaVersion", + """|/a/src/main/scala/Defn.scala + |package a + |object O { + | val <> = 1 + | val k = <> + |} + |/a/src/main/scala/Main.scala + |package a + |object Main { + | val g = O.<> + |} + |""".stripMargin, + scalaVersion, + ) + + check( + s"basic3_$scalaVersion", + """|/a/src/main/scala/Defn.scala + |package a + |object O { + | val <> = 1 + |} + |/a/src/main/scala/Main.scala + |package a + |object Main { + | val g = O.<> + |} + |""".stripMargin, + scalaVersion, + ) + + check( + s"basic4_$scalaVersion", + """|/a/src/main/scala/Defn.scala + |package a + |object Foo { + | val bar@@ = 1 + |} + |/a/src/main/scala/Main.scala + |package a + |object Main { + | val h = Foo.<> + |} + |""".stripMargin, + scalaVersion, + includeDefinition = false, + ) + + check( + s"local_$scalaVersion", + """|/a/src/main/scala/Defn.scala + |package a + |object O { + | val bar = 5 + | def f = bar + | def foo = { + | val <> = 3 + | <> + 4 + | } + |} + |""".stripMargin, + scalaVersion, + ) + } + + check( + s"apply-test_${metals.BuildInfo.scala3}", + """|/a/src/main/scala/Defn.scala + |package a + |class O(v: Int) { } + |object O { + | def <>() = new O(1) + |} + |/a/src/main/scala/Main.scala + |package a + |object Main { + | val g = <>() + |} + |""".stripMargin, + metals.BuildInfo.scala3, + ) + + check( + s"apply-test_${metals.BuildInfo.scala213}", + """|/a/src/main/scala/Defn.scala + |package a + |class O(v: Int) { } + |object O { + | def <>() = new O(1) + |} + |/a/src/main/scala/Main.scala + |package a + |object Main { + | val g = <> + |""".stripMargin, + metals.BuildInfo.scala213, + ) + + def check( + name: TestOptions, + input: String, + scalaVersion: String = BuildInfo.scalaVersion, + includeDefinition: Boolean = true, + defnFileName: String = "a/src/main/scala/Defn.scala", + ): Unit = + test(name) { + cleanWorkspace() + val fullInput = + s"""|$input + |/a/src/main/scala/Error.scala + |package a + |object Error { + | val foo: Int = "" + |} + |""".stripMargin + val files = FileLayout.mapFromString(fullInput) + val focusFile = + files.collectFirst { + case (pathStr, content) if content.contains("@@") => pathStr + }.get + + def paramsF = { + val content = files.get(focusFile).get + val actualContent = content.replaceAll("<<|>>", "") + val context = new ReferenceContext(includeDefinition) + server.offsetParams(focusFile, actualContent, workspace).map { + case (_, params) => + new ReferenceParams( + params.getTextDocument(), + params.getPosition(), + context, + ) + } + } + + val layout = fullInput.replaceAll("<<|>>|@@", "") + + def renderObtained(refs: List[Location]): String = { + val refsMap = refs + .groupMap(ref => + ref.getUri().toAbsolutePath.toRelative(workspace).toString + )(_.getRange()) + files + .map { case (pathStr, content) => + val actualContent = content.replaceAll("<<|>>", "") + val withMarkings = + if (pathStr == focusFile) { + val index = actualContent.indexOf("@@") + val code = actualContent.replaceAll("@@", "") + renderRangesAsString( + code, + refsMap.getOrElse(pathStr, Nil), + List((index, 2)), + Some(actualContent), + ) + } else { + renderRangesAsString( + actualContent, + refsMap.getOrElse(pathStr, Nil), + ) + } + s"""|/$pathStr + |$withMarkings""".stripMargin + } + .mkString("\n") + } + + for { + _ <- initialize( + s"""|/metals.json + |{ + | "a" : { + | "scalaVersion": "$scalaVersion" + | } + |} + |$layout""".stripMargin + ) + _ <- server.didOpen(defnFileName) + defnFileContent = files + .get(defnFileName) + .map(_.replaceAll("<<|>>|@@", "")) + .getOrElse("") + // load the file with definition to pc + _ <- server.hover(defnFileName, s"@@$defnFileContent", workspace) + _ <- server.didOpen(focusFile) + params <- paramsF + refs <- server.server.references(params).asScala + _ = assertNoDiff(renderObtained(refs.asScala.toList), fullInput) + } yield () + } +} diff --git a/tests/slow/src/test/scala/tests/gradle/GradleLspSuite.scala b/tests/slow/src/test/scala/tests/gradle/GradleLspSuite.scala index 5b620d0ffcb..e96e4a6288f 100644 --- a/tests/slow/src/test/scala/tests/gradle/GradleLspSuite.scala +++ b/tests/slow/src/test/scala/tests/gradle/GradleLspSuite.scala @@ -486,8 +486,9 @@ class GradleLspSuite extends BaseImportSuite("gradle-import") { """.stripMargin, ) // we should still have references despite fatal warning + refs <- server.workspaceReferences() _ = assertNoDiff( - server.workspaceReferences().references.map(_.symbol).mkString("\n"), + refs.references.map(_.symbol).sorted.mkString("\n"), """|_empty_/A. |_empty_/A.B. |_empty_/Warning. diff --git a/tests/slow/src/test/scala/tests/maven/MavenLspSuite.scala b/tests/slow/src/test/scala/tests/maven/MavenLspSuite.scala index 5a24ad3c21e..3909a918286 100644 --- a/tests/slow/src/test/scala/tests/maven/MavenLspSuite.scala +++ b/tests/slow/src/test/scala/tests/maven/MavenLspSuite.scala @@ -237,8 +237,9 @@ class MavenLspSuite extends BaseImportSuite("maven-import") { """.stripMargin, ) // we should still have references despite fatal warning + refs <- server.workspaceReferences() _ = assertNoDiff( - server.workspaceReferences().references.map(_.symbol).mkString("\n"), + refs.references.map(_.symbol).sorted.mkString("\n"), """|_empty_/A. |_empty_/A.B. |_empty_/Warning. diff --git a/tests/slow/src/test/scala/tests/mill/MillLspSuite.scala b/tests/slow/src/test/scala/tests/mill/MillLspSuite.scala index 7a0f49a67b7..6d2a14b9dcd 100644 --- a/tests/slow/src/test/scala/tests/mill/MillLspSuite.scala +++ b/tests/slow/src/test/scala/tests/mill/MillLspSuite.scala @@ -160,8 +160,9 @@ class MillLspSuite extends BaseImportSuite("mill-import") { """.stripMargin, ) // we should still have references despite fatal warning + refs <- server.workspaceReferences() _ = assertNoDiff( - server.workspaceReferences().references.map(_.symbol).mkString("\n"), + refs.references.map(_.symbol).sorted.mkString("\n"), """|_empty_/A. |_empty_/A.B. |_empty_/Warning. diff --git a/tests/slow/src/test/scala/tests/sbt/SbtBloopLspSuite.scala b/tests/slow/src/test/scala/tests/sbt/SbtBloopLspSuite.scala index 19c157bb870..e924f16f4d1 100644 --- a/tests/slow/src/test/scala/tests/sbt/SbtBloopLspSuite.scala +++ b/tests/slow/src/test/scala/tests/sbt/SbtBloopLspSuite.scala @@ -476,8 +476,9 @@ class SbtBloopLspSuite """.stripMargin, ) // we should still have references despite fatal warning + refs <- server.workspaceReferences() _ = assertNoDiff( - server.workspaceReferences().references.map(_.symbol).mkString("\n"), + refs.references.map(_.symbol).sorted.mkString("\n"), """|_empty_/A. |_empty_/A.B. |_empty_/Warning. diff --git a/tests/slow/src/test/scala/tests/sbt/SbtServerSuite.scala b/tests/slow/src/test/scala/tests/sbt/SbtServerSuite.scala index 9003241f8d9..9a5a5b4f98a 100644 --- a/tests/slow/src/test/scala/tests/sbt/SbtServerSuite.scala +++ b/tests/slow/src/test/scala/tests/sbt/SbtServerSuite.scala @@ -533,4 +533,50 @@ class SbtServerSuite fileContent => SbtBuildLayout(fileContent, V.scala213), ) + test("meta-build-references") { + cleanWorkspace() + + val buildSbt = + s"""|${SbtBuildLayout.commonSbtSettings} + |ThisBuild / scalaVersion := "${V.scala213}" + |val a = project.in(file(V.<>)) + |""".stripMargin + val buildSbtBase = buildSbt.replaceAll("<<|>>", "") + + val v = + s"""|object V { + | val <> = "a" + |} + |""".stripMargin + val vBase = v.replaceAll("<<|>>|@@", "") + + for { + _ <- initialize( + s"""|/project/build.properties + |sbt.version=${V.sbtVersion} + |/project/V.scala + |$vBase + |/build.sbt + |$buildSbtBase + |""".stripMargin + ) + _ <- server.server.indexingPromise.future + _ <- server.didOpen("project/V.scala") + _ <- + server.assertReferences( + "project/V.scala", + v.replaceAll("<<|>>", ""), + Map( + "project/V.scala" -> v.replaceAll("@@", ""), + "build.sbt" -> buildSbt, + ), + Map( + "project/V.scala" -> vBase, + "build.sbt" -> buildSbtBase, + ), + ) + + } yield () + } + } diff --git a/tests/unit/src/main/scala/tests/TestRanges.scala b/tests/unit/src/main/scala/tests/TestRanges.scala index 6bc359c5500..e07c96a6eed 100644 --- a/tests/unit/src/main/scala/tests/TestRanges.scala +++ b/tests/unit/src/main/scala/tests/TestRanges.scala @@ -20,7 +20,7 @@ object TestRanges extends RangeReplace { } yield file -> validLocations .foldLeft((code, List.empty[(Int, Int)])) { case ((base, alreadyAddedMarkings), location) => - replaceInRangeWithAdjustmens( + replaceInRangeWithAdjustments( code, base, location.getRange, diff --git a/tests/unit/src/main/scala/tests/TestingServer.scala b/tests/unit/src/main/scala/tests/TestingServer.scala index 20c7fd67ec5..b80373e3b23 100644 --- a/tests/unit/src/main/scala/tests/TestingServer.scala +++ b/tests/unit/src/main/scala/tests/TestingServer.scala @@ -394,8 +394,7 @@ final case class TestingServer( def assertReferenceDefinitionBijection()(implicit loc: munit.Location - ): Unit = { - val compare = workspaceReferences() + ): Future[Unit] = workspaceReferences().map { compare => assert(compare.definition.nonEmpty, "Definitions should not be empty") assert(compare.references.nonEmpty, "References should not be empty") Assertions.assertNoDiff( @@ -406,13 +405,11 @@ final case class TestingServer( def assertReferenceDefinitionDiff( expectedDiff: String - )(implicit loc: munit.Location): Unit = { - Assertions.assertNoDiff( - workspaceReferences().diff, - expectedDiff, + )(implicit loc: munit.Location): Future[Unit] = + workspaceReferences().map(refs => + Assertions.assertNoDiff(refs.diff, expectedDiff) ) - } - def workspaceReferences(): WorkspaceSymbolReferences = { + def workspaceReferences(): Future[WorkspaceSymbolReferences] = { val inverse = mutable.Map.empty[SymbolReference, mutable.ListBuffer[Location]] val inputsCache = mutable.Map.empty[String, Input] @@ -459,27 +456,35 @@ final case class TestingServer( } val definition = Seq.newBuilder[SymbolReference] val references = Seq.newBuilder[SymbolReference] - for { - (ref, expectedLocations) <- inverse.toSeq.sortBy(_._1.symbol) - } { - val params = new ReferenceParams( - new TextDocumentIdentifier( - ref.location.getUri - ), - ref.location.getRange.getStart, - new ReferenceContext(true), + val resultFuture: Future[Unit] = + Future + .sequence( + for { + (ref, expectedLocations) <- inverse.toSeq.sortBy(_._1.symbol) + } yield { + val params = new ReferenceParams( + new TextDocumentIdentifier( + ref.location.getUri + ), + ref.location.getRange.getStart, + new ReferenceContext(true), + ) + server.referencesResult(params).map { obtainedLocations => + references ++= obtainedLocations.flatMap { result => + result.locations.map { l => + newRef(result.symbol, l) + } + } + definition ++= expectedLocations.map(l => newRef(ref.symbol, l)) + } + } + ) + .ignoreValue + resultFuture.map(_ => + WorkspaceSymbolReferences( + references.result().distinct, + definition.result().distinct, ) - val obtainedLocations = server.referencesResult(params) - references ++= obtainedLocations.flatMap { result => - result.locations.map { l => - newRef(result.symbol, l) - } - } - definition ++= expectedLocations.map(l => newRef(ref.symbol, l)) - } - WorkspaceSymbolReferences( - references.result().distinct, - definition.result().distinct, ) } @@ -1216,7 +1221,7 @@ final case class TestingServer( } } - private def offsetParams( + def offsetParams( filename: String, original: String, root: AbsolutePath, diff --git a/tests/unit/src/test/scala/tests/FingerprintsLspSuite.scala b/tests/unit/src/test/scala/tests/FingerprintsLspSuite.scala index adb8354bce0..c9fc508b272 100644 --- a/tests/unit/src/test/scala/tests/FingerprintsLspSuite.scala +++ b/tests/unit/src/test/scala/tests/FingerprintsLspSuite.scala @@ -67,8 +67,9 @@ class FingerprintsLspSuite extends BaseLspSuite("fingerprints") { | ^^^^^^^^ |""".stripMargin, ) + workspaceRefs <- server.workspaceReferences() _ = assertNoDiff( - newServer.workspaceReferences().referencesFormat, + workspaceRefs.referencesFormat, """|============= |= a/Adresses# |============= diff --git a/tests/unit/src/test/scala/tests/RenameLspSuite.scala b/tests/unit/src/test/scala/tests/RenameLspSuite.scala index dd2b7a43b11..a2d0f1af639 100644 --- a/tests/unit/src/test/scala/tests/RenameLspSuite.scala +++ b/tests/unit/src/test/scala/tests/RenameLspSuite.scala @@ -894,6 +894,34 @@ class RenameLspSuite extends BaseRenameLspSuite(s"rename") { expectedError = true, ) + renamed( + "renames-in-related-build-targets", + """|/a/src/main/scala/a/Main.scala + |package a + |trait <<@@A>> + |/a/src/main/scala/a/B.scala + |package a + |trait B extends <> + |/b/src/main/scala/b/B.scala + |package b + |import a.A + |trait B extends <> + |""".stripMargin, + metalsJson = Some( + s"""|{ + | "a" : { + | "scalaVersion": "${BuildInfo.scalaVersion}" + | }, + | "b" : { + | "scalaVersion": "${BuildInfo.scalaVersion}", + | "dependsOn": ["a"] + | } + |}""".stripMargin + ), + newName = "C", + nonOpened = Set("b/src/main/scala/b/B.scala"), + ) + override protected def libraryDependencies: List[String] = List("org.scalatest::scalatest:3.2.12", "io.circe::circe-generic:0.14.1")