From b0f6da983ad114e662c2872062b60b8a0e5a4bac Mon Sep 17 00:00:00 2001 From: Katarzyna Marek Date: Mon, 22 Jan 2024 12:06:14 +0100 Subject: [PATCH] get file content on demand --- .../src/main/scala/bench/Inflated.scala | 9 +- .../meta/internal/metals/Compilers.scala | 57 +++++----- .../internal/metals/IdentifierIndex.scala | 2 +- .../internal/metals/ReferenceProvider.scala | 27 ++++- .../src/main/java/scala/meta/pc/Buffers.java | 9 ++ .../scala/meta/pc/PcAdjustFileParams.java | 8 ++ .../scala/meta/pc/PresentationCompiler.java | 6 +- .../java/scala/meta/pc/ReferencesRequest.java | 10 ++ .../java/scala/meta/pc/ReferencesResult.java | 9 ++ .../pc/JavaPresentationCompiler.scala | 8 +- .../scala/meta/internal/pc/NoopBuffers.scala | 15 +++ .../internal/pc/ReferencesResultImpl.scala | 17 +++ .../internal/pc/PcReferencesProvider.scala | 104 ++++++++++++------ .../pc/ScalaPresentationCompiler.scala | 26 +++-- .../ScalaPresentationCompiler.scala | 68 ++++++++---- .../internal/pc/PcReferencesProvider.scala | 80 +++++++++----- .../pc/ScalaPresentationCompiler.scala | 24 ++-- .../tests/feature/PcReferencesLspSuite.scala | 2 +- 18 files changed, 332 insertions(+), 149 deletions(-) create mode 100644 mtags-interfaces/src/main/java/scala/meta/pc/Buffers.java create mode 100644 mtags-interfaces/src/main/java/scala/meta/pc/PcAdjustFileParams.java create mode 100644 mtags-interfaces/src/main/java/scala/meta/pc/ReferencesRequest.java create mode 100644 mtags-interfaces/src/main/java/scala/meta/pc/ReferencesResult.java create mode 100644 mtags-shared/src/main/scala/scala/meta/internal/pc/NoopBuffers.scala create mode 100644 mtags-shared/src/main/scala/scala/meta/internal/pc/ReferencesResultImpl.scala diff --git a/metals-bench/src/main/scala/bench/Inflated.scala b/metals-bench/src/main/scala/bench/Inflated.scala index 9b255ee2725..a109269e443 100644 --- a/metals-bench/src/main/scala/bench/Inflated.scala +++ b/metals-bench/src/main/scala/bench/Inflated.scala @@ -7,9 +7,12 @@ import scala.meta.internal.io.FileIO import scala.meta.io.AbsolutePath import scala.meta.io.Classpath -case class Inflated(inputs: List[(Input.VirtualFile, AbsolutePath)], linesOfCode: Long) { +case class Inflated( + inputs: List[(Input.VirtualFile, AbsolutePath)], + linesOfCode: Long, +) { def filter(f: Input.VirtualFile => Boolean): Inflated = { - val newInputs = inputs.filter{case (input, _) => f(input)} + val newInputs = inputs.filter { case (input, _) => f(input) } val newLinesOfCode = newInputs.foldLeft(0) { case (accum, input) => accum + input._1.text.linesIterator.length } @@ -19,7 +22,7 @@ case class Inflated(inputs: List[(Input.VirtualFile, AbsolutePath)], linesOfCode Inflated(other.inputs ++ inputs, other.linesOfCode + linesOfCode) def foreach(f: Input.VirtualFile => Unit): Unit = - inputs.foreach{ case (file, _) => f(file)} + inputs.foreach { case (file, _) => f(file) } } object Inflated { diff --git a/metals/src/main/scala/scala/meta/internal/metals/Compilers.scala b/metals/src/main/scala/scala/meta/internal/metals/Compilers.scala index c6fb95289a6..505b7848ac8 100644 --- a/metals/src/main/scala/scala/meta/internal/metals/Compilers.scala +++ b/metals/src/main/scala/scala/meta/internal/metals/Compilers.scala @@ -1,5 +1,6 @@ package scala.meta.internal.metals +import java.net.URI import java.nio.file.Path import java.nio.file.Paths import java.util.Collections @@ -29,6 +30,7 @@ import scala.meta.internal.pc.ScalaPresentationCompiler import scala.meta.internal.worksheets.WorksheetPcData import scala.meta.internal.worksheets.WorksheetProvider import scala.meta.io.AbsolutePath +import scala.meta.pc import scala.meta.pc.AutoImportsResult import scala.meta.pc.CancelToken import scala.meta.pc.HoverSignature @@ -36,7 +38,6 @@ import scala.meta.pc.OffsetParams import scala.meta.pc.PresentationCompiler import scala.meta.pc.SymbolSearch import scala.meta.pc.SyntheticDecoration -import scala.meta.pc.VirtualFileParams import ch.epfl.scala.bsp4j.BuildTargetIdentifier import ch.epfl.scala.bsp4j.CompileReport @@ -47,6 +48,7 @@ import org.eclipse.lsp4j.CompletionParams import org.eclipse.lsp4j.Diagnostic import org.eclipse.lsp4j.DocumentHighlight import org.eclipse.lsp4j.InitializeParams +import org.eclipse.lsp4j.Location import org.eclipse.lsp4j.ReferenceParams import org.eclipse.lsp4j.RenameParams import org.eclipse.lsp4j.SelectionRange @@ -88,6 +90,18 @@ class Compilers( extends Cancelable { val plugins = new CompilerPlugins() + val pcBuffers: pc.Buffers = new pc.Buffers { + override def getFile( + uri: URI, + scalaVersion: String, + ): ju.Optional[pc.PcAdjustFileParams] = + Try(sourceAdjustments(uri.toString(), scalaVersion)).toOption + .map[pc.PcAdjustFileParams] { case (vFile, _, adjust) => + PcAdjustFileParams(CompilerVirtualFileParams(uri, vFile.text), adjust) + } + .asJava + } + // Not a TrieMap because we want to avoid loading duplicate compilers for the same build target. // Not a `j.u.c.ConcurrentHashMap` because it can deadlock in `computeIfAbsent` when the absent // function is expensive, which is the case here. @@ -687,36 +701,20 @@ class Compilers( def references( params: ReferenceParams, - targetFiles: Iterator[AbsolutePath], + targetFiles: List[AbsolutePath], token: CancelToken, ): Future[List[ReferencesResult]] = { - withPCAndAdjustLsp(params) { (pc, pos, adjust) => - val targets = targetFiles.map { target => - target.toURI.toString -> { - val (vFile, _, adjustLsp) = - sourceAdjustments( - target.toURI.toString(), - pc.scalaVersion(), - ) - val params = - CompilerVirtualFileParams(target.toURI, vFile.text, token) - (params, adjustLsp) - } - }.toMap - val targetFilesParams: List[VirtualFileParams] = - targets.values.map(_._1).toList - pc.references( + withPCAndAdjustLsp(params) { (pc, pos, _) => + val request = PcReferencesRequest( CompilerOffsetParamsUtils.fromPos(pos, token), - targetFilesParams.asJava, params.getContext().isIncludeDeclaration(), - ).asScala + targetFiles.map(_.toURI).asJava, + ) + pc.references(request) + .asScala .map( _.asScala.toList.map { defRes => - val locations = defRes - .locations() - .asScala - .toList - .map(loc => targets(loc.getUri())._2.adjustLocation(loc)) + val locations = defRes.locations().asScala.toList ReferencesResult(defRes.symbol(), locations) } ) @@ -1182,6 +1180,7 @@ class Compilers( .withWorkspace(workspace.toNIO) .withScheduledExecutorService(sh) .withReportsLoggerLevel(MetalsServerConfig.default.loglevel) + .withBuffers(pcBuffers) .withConfiguration { val options = InitializationOptions.from(initializeParams).compilerOptions @@ -1332,3 +1331,11 @@ object Compilers { case object Default extends PresentationCompilerKey } } + +case class PcAdjustFileParams( + params: pc.VirtualFileParams, + adjustLsp: AdjustLspData, +) extends pc.PcAdjustFileParams { + override def adjustLocation(location: Location): Location = + adjustLsp.adjustLocation(location) +} diff --git a/metals/src/main/scala/scala/meta/internal/metals/IdentifierIndex.scala b/metals/src/main/scala/scala/meta/internal/metals/IdentifierIndex.scala index 202a5153bbc..4f8c4c5c5e5 100644 --- a/metals/src/main/scala/scala/meta/internal/metals/IdentifierIndex.scala +++ b/metals/src/main/scala/scala/meta/internal/metals/IdentifierIndex.scala @@ -37,7 +37,7 @@ class IdentifierIndex { def collectIdentifiers( text: String, - dialect: Dialect + dialect: Dialect, ): Iterable[String] = { val identifiers = Set.newBuilder[String] diff --git a/metals/src/main/scala/scala/meta/internal/metals/ReferenceProvider.scala b/metals/src/main/scala/scala/meta/internal/metals/ReferenceProvider.scala index 44175f45648..6e8e5487dbc 100644 --- a/metals/src/main/scala/scala/meta/internal/metals/ReferenceProvider.scala +++ b/metals/src/main/scala/scala/meta/internal/metals/ReferenceProvider.scala @@ -1,7 +1,10 @@ package scala.meta.internal.metals +import java.net.URI import java.nio.charset.StandardCharsets import java.nio.file.Path +import java.util.concurrent.TimeUnit +import java.util.concurrent.TimeoutException import scala.collection.concurrent.TrieMap import scala.concurrent.ExecutionContext @@ -11,6 +14,7 @@ import scala.util.matching.Regex import scala.meta.Importee import scala.meta.inputs.Input +import scala.meta.internal.async.CompletableCancelToken import scala.meta.internal.metals.MetalsEnrichments._ import scala.meta.internal.metals.ResolvedSymbolOccurrence import scala.meta.internal.mtags.DefinitionAlternatives.GlobalSymbol @@ -26,6 +30,8 @@ import scala.meta.internal.semanticdb.TextDocument import scala.meta.internal.semanticdb.TextDocuments import scala.meta.internal.{semanticdb => s} import scala.meta.io.AbsolutePath +import scala.meta.pc.OffsetParams +import scala.meta.pc.ReferencesRequest import scala.meta.tokens.Token.Ident import ch.epfl.scala.bsp4j.BuildTargetIdentifier @@ -50,7 +56,9 @@ final class ReferenceProvider( val identifierIndex: IdentifierIndex = new IdentifierIndex def addIdentifiers(file: AbsolutePath, set: Iterable[String]): Unit = - buildTargets.inverseSources(file).map(id => identifierIndex.addIdentifiers(file, id, set)) + buildTargets + .inverseSources(file) + .map(id => identifierIndex.addIdentifiers(file, id, set)) def indexIdentifiers( path: AbsolutePath, @@ -304,8 +312,15 @@ final class ReferenceProvider( pathsForName(buildTarget, localPath, name).filter(filterTargetFiles) } if targetFiles.nonEmpty + cancelToken = new CompletableCancelToken } yield compilers - .references(params, targetFiles, EmptyCancelToken) + .references(params, targetFiles.toList, cancelToken) + .withTimeout(30, TimeUnit.SECONDS) + .recover { case _: TimeoutException => + cancelToken.cancel() + scribe.warn("pc references search timed out after 30 seconds") + Nil + } result.getOrElse(Future.successful(Nil)) } @@ -445,7 +460,7 @@ final class ReferenceProvider( val isLocal = occ.symbol.isLocal if (isLocal) compilers - .references(params, Iterator(source), EmptyCancelToken) + .references(params, List(source), EmptyCancelToken) .map(_.flatMap(_.locations)) else { /* search local in the following cases: @@ -663,3 +678,9 @@ object SyntheticPackageObject { def unapply(str: String): Option[String] = Option.when(regex.matches(str))(str) } + +case class PcReferencesRequest( + params: OffsetParams, + includeDefinition: Boolean, + targetUris: java.util.List[URI], +) extends ReferencesRequest diff --git a/mtags-interfaces/src/main/java/scala/meta/pc/Buffers.java b/mtags-interfaces/src/main/java/scala/meta/pc/Buffers.java new file mode 100644 index 00000000000..34aa567cfe1 --- /dev/null +++ b/mtags-interfaces/src/main/java/scala/meta/pc/Buffers.java @@ -0,0 +1,9 @@ +package scala.meta.pc; + +import java.util.concurrent.CompletableFuture; +import java.util.Optional; +import java.net.URI; + +public interface Buffers { + Optional getFile(URI uri, String scalaVersion); +} diff --git a/mtags-interfaces/src/main/java/scala/meta/pc/PcAdjustFileParams.java b/mtags-interfaces/src/main/java/scala/meta/pc/PcAdjustFileParams.java new file mode 100644 index 00000000000..25fee4aa039 --- /dev/null +++ b/mtags-interfaces/src/main/java/scala/meta/pc/PcAdjustFileParams.java @@ -0,0 +1,8 @@ +package scala.meta.pc; + +import org.eclipse.lsp4j.Location; + +public interface PcAdjustFileParams { + VirtualFileParams params(); + Location adjustLocation(Location location); +} diff --git a/mtags-interfaces/src/main/java/scala/meta/pc/PresentationCompiler.java b/mtags-interfaces/src/main/java/scala/meta/pc/PresentationCompiler.java index 21baac877b9..9530a29aaef 100644 --- a/mtags-interfaces/src/main/java/scala/meta/pc/PresentationCompiler.java +++ b/mtags-interfaces/src/main/java/scala/meta/pc/PresentationCompiler.java @@ -111,10 +111,14 @@ public CompletableFuture> semanticTokens(VirtualFileParams params) { */ public abstract CompletableFuture> documentHighlight(OffsetParams params); + public PresentationCompiler withBuffers(Buffers buffers) { + return this; + } + /** * Returns the references of the symbol under the current position in the target files. */ - public CompletableFuture> references(OffsetParams params, java.util.List targetFiles, boolean includeDefinition) { + public CompletableFuture> references(ReferencesRequest params) { return CompletableFuture.completedFuture(Collections.emptyList()); } diff --git a/mtags-interfaces/src/main/java/scala/meta/pc/ReferencesRequest.java b/mtags-interfaces/src/main/java/scala/meta/pc/ReferencesRequest.java new file mode 100644 index 00000000000..627788b252c --- /dev/null +++ b/mtags-interfaces/src/main/java/scala/meta/pc/ReferencesRequest.java @@ -0,0 +1,10 @@ +package scala.meta.pc; + +import java.util.List; +import java.net.URI; + +public interface ReferencesRequest { + OffsetParams params(); + boolean includeDefinition(); + List targetUris(); +} diff --git a/mtags-interfaces/src/main/java/scala/meta/pc/ReferencesResult.java b/mtags-interfaces/src/main/java/scala/meta/pc/ReferencesResult.java new file mode 100644 index 00000000000..cfab7391978 --- /dev/null +++ b/mtags-interfaces/src/main/java/scala/meta/pc/ReferencesResult.java @@ -0,0 +1,9 @@ +package scala.meta.pc; + +import java.util.List; +import org.eclipse.lsp4j.Location; + +public interface ReferencesResult { + String symbol(); + List locations(); +} diff --git a/mtags-java/src/main/scala/scala/meta/internal/pc/JavaPresentationCompiler.scala b/mtags-java/src/main/scala/scala/meta/internal/pc/JavaPresentationCompiler.scala index 4bd1fb6352a..edc04f8c9a1 100644 --- a/mtags-java/src/main/scala/scala/meta/internal/pc/JavaPresentationCompiler.scala +++ b/mtags-java/src/main/scala/scala/meta/internal/pc/JavaPresentationCompiler.scala @@ -20,6 +20,8 @@ import scala.meta.pc.OffsetParams import scala.meta.pc.PresentationCompiler import scala.meta.pc.PresentationCompilerConfig import scala.meta.pc.RangeParams +import scala.meta.pc.ReferencesRequest +import scala.meta.pc.ReferencesResult import scala.meta.pc.SymbolSearch import scala.meta.pc.SyntheticDecoration import scala.meta.pc.SyntheticDecorationsParams @@ -99,10 +101,8 @@ case class JavaPresentationCompiler( CompletableFuture.completedFuture(Nil.asJava) override def references( - params: OffsetParams, - targetFiles: util.List[VirtualFileParams], - includeDefinition: Boolean - ): CompletableFuture[util.List[DefinitionResult]] = + params: ReferencesRequest + ): CompletableFuture[util.List[ReferencesResult]] = CompletableFuture.completedFuture(Nil.asJava) override def getTasty( diff --git a/mtags-shared/src/main/scala/scala/meta/internal/pc/NoopBuffers.scala b/mtags-shared/src/main/scala/scala/meta/internal/pc/NoopBuffers.scala new file mode 100644 index 00000000000..1d6096274d8 --- /dev/null +++ b/mtags-shared/src/main/scala/scala/meta/internal/pc/NoopBuffers.scala @@ -0,0 +1,15 @@ +package scala.meta.internal.pc + +import java.net.URI +import java.util.Optional + +import scala.meta.pc.Buffers +import scala.meta.pc.PcAdjustFileParams + +object NoopBuffers extends Buffers { + override def getFile( + uri: URI, + scalaVersion: String + ): Optional[PcAdjustFileParams] = Optional.empty() + +} diff --git a/mtags-shared/src/main/scala/scala/meta/internal/pc/ReferencesResultImpl.scala b/mtags-shared/src/main/scala/scala/meta/internal/pc/ReferencesResultImpl.scala new file mode 100644 index 00000000000..1335b69bed8 --- /dev/null +++ b/mtags-shared/src/main/scala/scala/meta/internal/pc/ReferencesResultImpl.scala @@ -0,0 +1,17 @@ +package scala.meta.internal.pc + +import java.{util => ju} + +import scala.meta.pc.ReferencesResult + +import org.eclipse.lsp4j.Location + +case class ReferencesResultImpl( + symbol: String, + locations: ju.List[Location] +) extends ReferencesResult + +object ReferencesResultImpl { + def empty: ReferencesResult = + ReferencesResultImpl("", ju.Collections.emptyList()) +} diff --git a/mtags/src/main/scala-2/scala/meta/internal/pc/PcReferencesProvider.scala b/mtags/src/main/scala-2/scala/meta/internal/pc/PcReferencesProvider.scala index 98b591246a7..5f27237db83 100644 --- a/mtags/src/main/scala-2/scala/meta/internal/pc/PcReferencesProvider.scala +++ b/mtags/src/main/scala-2/scala/meta/internal/pc/PcReferencesProvider.scala @@ -2,55 +2,87 @@ package scala.meta.internal.pc import scala.meta.internal.jdk.CollectionConverters._ import scala.meta.internal.mtags.MtagsEnrichments._ -import scala.meta.pc.DefinitionResult -import scala.meta.pc.OffsetParams -import scala.meta.pc.VirtualFileParams +import scala.meta.pc.Buffers +import scala.meta.pc.PcAdjustFileParams +import scala.meta.pc.ReferencesRequest +import scala.meta.pc.ReferencesResult import org.eclipse.{lsp4j => l} class PcReferencesProvider( compiler: MetalsGlobal, - params: OffsetParams, - targetFiles: List[VirtualFileParams], - includeDefinition: Boolean -) extends WithCompilationUnit(compiler, params) + request: ReferencesRequest, + buffers: Buffers, + scalaVersion: String +) extends WithCompilationUnit(compiler, request.params()) with PcSymbolSearch { - def result(): List[DefinitionResult] = { - val result = for { - (sought, _) <- soughtSymbols.toList - params <- - if (sought.forall(_.isLocalToBlock)) - targetFiles.find(_.uri() == params.uri()).toList - else targetFiles - collected <- { - val collector = new WithCompilationUnit(compiler, params) - with PcCollector[Option[(String, l.Range)]] { - import compiler._ - override def collect(parent: Option[Tree])( - tree: Tree, - toAdjust: Position, - sym: Option[compiler.Symbol] - ): Option[(String, l.Range)] = { - val (pos, _) = toAdjust.adjust(text) - tree match { - case _: DefTree if !includeDefinition => None - case t => Some(compiler.semanticdbSymbol(t.symbol), pos.toLsp) - } + + @volatile var isCancelled: Boolean = false + def result(): List[ReferencesResult] = { + + val result: List[(String, l.Location)] = + soughtSymbols match { + case Some((sought, _)) if sought.nonEmpty => + def collect(adjustFileParams: PcAdjustFileParams) = { + val collector = + new WithCompilationUnit(compiler, adjustFileParams.params()) + with PcCollector[Option[(String, l.Range)]] { + import compiler._ + override def collect(parent: Option[Tree])( + tree: Tree, + toAdjust: Position, + sym: Option[compiler.Symbol] + ): Option[(String, l.Range)] = { + val (pos, _) = toAdjust.adjust(text) + tree match { + case _: DefTree if !request.includeDefinition() => None + case t => + Some(compiler.semanticdbSymbol(t.symbol), pos.toLsp) + } + } + } + + val result = + collector + .resultWithSought( + sought.asInstanceOf[Set[collector.compiler.Symbol]] + ) + .flatten + .map { case (symbol, range) => + ( + symbol, + adjustFileParams.adjustLocation( + new l.Location( + adjustFileParams.params.uri().toString(), + range + ) + ) + ) + } + compiler.unitOfFile.remove(unit.source.file) + result } - } - collector - .resultWithSought(sought.asInstanceOf[Set[collector.compiler.Symbol]]) - .flatten - .map { case (symbol, range) => - (symbol, new l.Location(params.uri().toString(), range)) + if (sought.forall(_.isLocalToBlock)) { + val file = buffers.getFile(params.uri(), scalaVersion) + if (file.isPresent() && !isCancelled) { + collect(file.get()) + } else Nil + } else { + val fileUris = request.targetUris().asScala.toList + fileUris.flatMap { uri => + val file = buffers.getFile(uri, scalaVersion) + if (file.isPresent()) collect(file.get()) + else Nil + } } + case _ => List.empty } - } yield collected + result .groupBy(_._1) .map { case (symbol, locs) => - DefinitionResultImpl(symbol, locs.map(_._2).asJava) + ReferencesResultImpl(symbol, locs.map(_._2).asJava) } .toList } diff --git a/mtags/src/main/scala-2/scala/meta/internal/pc/ScalaPresentationCompiler.scala b/mtags/src/main/scala-2/scala/meta/internal/pc/ScalaPresentationCompiler.scala index f201a1a10bd..11b72de0204 100644 --- a/mtags/src/main/scala-2/scala/meta/internal/pc/ScalaPresentationCompiler.scala +++ b/mtags/src/main/scala-2/scala/meta/internal/pc/ScalaPresentationCompiler.scala @@ -28,6 +28,7 @@ import scala.meta.internal.metals.StdReportContext import scala.meta.internal.mtags.BuildInfo import scala.meta.internal.mtags.MtagsEnrichments._ import scala.meta.pc.AutoImportsResult +import scala.meta.pc.Buffers import scala.meta.pc.DefinitionResult import scala.meta.pc.DisplayableException import scala.meta.pc.HoverSignature @@ -36,6 +37,8 @@ import scala.meta.pc.OffsetParams import scala.meta.pc.PresentationCompiler import scala.meta.pc.PresentationCompilerConfig import scala.meta.pc.RangeParams +import scala.meta.pc.ReferencesRequest +import scala.meta.pc.ReferencesResult import scala.meta.pc.SymbolSearch import scala.meta.pc.SyntheticDecoration import scala.meta.pc.SyntheticDecorationsParams @@ -60,7 +63,8 @@ case class ScalaPresentationCompiler( sh: Option[ScheduledExecutorService] = None, config: PresentationCompilerConfig = PresentationCompilerConfigImpl(), folderPath: Option[Path] = None, - reportsLevel: ReportLevel = ReportLevel.Info + reportsLevel: ReportLevel = ReportLevel.Info, + buffers: Buffers = NoopBuffers ) extends PresentationCompiler { implicit val executionContext: ExecutionContextExecutor = ec @@ -358,22 +362,24 @@ case class ScalaPresentationCompiler( new PcDocumentHighlightProvider(pc.compiler(), params).highlights().asJava } + override def withBuffers(buffers: Buffers): PresentationCompiler = + copy(buffers = buffers) + override def references( - params: OffsetParams, - targetFiles: ju.List[VirtualFileParams], - includeDefinition: Boolean - ): CompletableFuture[ju.List[DefinitionResult]] = - compilerAccess.withInterruptableCompiler(Some(params))( - List.empty[DefinitionResult].asJava, - params.token() + params: ReferencesRequest + ): CompletableFuture[ju.List[ReferencesResult]] = { + compilerAccess.withInterruptableCompiler(Some(params.params()))( + List.empty[ReferencesResult].asJava, + params.params.token() ) { pc => new PcReferencesProvider( pc.compiler(), params, - targetFiles.asScala.toList, - includeDefinition + buffers, + scalaVersion ).result().asJava } + } override def semanticdbTextDocument( fileUri: URI, diff --git a/mtags/src/main/scala-3-wrapper/ScalaPresentationCompiler.scala b/mtags/src/main/scala-3-wrapper/ScalaPresentationCompiler.scala index b2b364617cd..a011e3b011c 100644 --- a/mtags/src/main/scala-3-wrapper/ScalaPresentationCompiler.scala +++ b/mtags/src/main/scala-3-wrapper/ScalaPresentationCompiler.scala @@ -2,6 +2,7 @@ package scala.meta.internal.pc import java.net.URI import java.nio.file.Path +import java.util.Optional import java.util.concurrent.CompletableFuture import java.util.concurrent.ExecutorService import java.util.concurrent.ScheduledExecutorService @@ -15,13 +16,17 @@ import scala.jdk.CollectionConverters.* import scala.meta.internal.metals.ReportLevel import scala.meta.internal.mtags.CommonMtagsEnrichments.* import scala.meta.pc.AutoImportsResult +import scala.meta.pc.Buffers import scala.meta.pc.DefinitionResult import scala.meta.pc.HoverSignature import scala.meta.pc.Node import scala.meta.pc.OffsetParams +import scala.meta.pc.PcAdjustFileParams import scala.meta.pc.PresentationCompiler import scala.meta.pc.PresentationCompilerConfig import scala.meta.pc.RangeParams +import scala.meta.pc.ReferencesRequest +import scala.meta.pc.ReferencesResult import scala.meta.pc.SymbolSearch import scala.meta.pc.SyntheticDecoration import scala.meta.pc.SyntheticDecorationsParams @@ -59,6 +64,7 @@ case class ScalaPresentationCompiler( config: PresentationCompilerConfig = PresentationCompilerConfigImpl(), folderPath: Option[Path] = None, reportsLevel: ReportLevel = ReportLevel.Info, + buffers: Buffers = NoopBuffers, ) extends PresentationCompiler: val underlying: DottyPresentationCompiler = new DottyPresentationCompiler( buildTargetIdentifier = buildTargetIdentifier, @@ -178,32 +184,35 @@ case class ScalaPresentationCompiler( ): CompletableFuture[ju.List[DocumentHighlight]] = underlying.documentHighlight(params) + override def withBuffers(buffers: Buffers): PresentationCompiler = + copy(buffers = buffers) + override def references( - params: OffsetParams, - targetFiles: ju.List[VirtualFileParams], - includeDefinition: Boolean, - ): CompletableFuture[ju.List[DefinitionResult]] = - targetFiles.asScala.toList match - case file :: Nil if file.uri() == params.uri() => + params: ReferencesRequest + ): CompletableFuture[ju.List[ReferencesResult]] = + FutureConverters + .toJava( FutureConverters - .toJava( - FutureConverters - .toScala(documentHighlight(params)) - .map { hightLightResult => - val locations = hightLightResult.asScala.collect { - case highlight - if highlight.getKind() == DocumentHighlightKind.Read || includeDefinition => - new Location( - params.uri().toString(), - highlight.getRange(), - ) - }.asJava - List(new DefinitionResultImpl("", locations)).asJava - } - ) - .toCompletableFuture - case _ => - CompletableFuture.completedFuture(Nil.asJava) + .toScala(documentHighlight(params.params())) + .map { hightLightResult => + val locations = hightLightResult.asScala.collect { + case highlight + if highlight.getKind() == DocumentHighlightKind.Read || params + .includeDefinition() => + val location = + new Location( + params.params().uri().toString(), + highlight.getRange(), + ) + val adjust = + buffers.getFile(params.params().uri(), scalaVersion()) + if adjust.isPresent() then adjust.get().adjustLocation(location) + else location + }.asJava + List(ReferencesResultImpl("", locations)).asJava + } + ) + .toCompletableFuture override def rename( params: OffsetParams, @@ -276,3 +285,14 @@ case class ScalaPresentationCompiler( underlying.inlineValue(params) end ScalaPresentationCompiler + +case class ReferencesResultImpl( + symbol: String, + locations: ju.List[Location], +) extends ReferencesResult + +object NoopBuffers extends Buffers: + override def getFile( + uri: URI, + scalaVersion: String, + ): Optional[PcAdjustFileParams] = Optional.empty() diff --git a/mtags/src/main/scala-3/scala/meta/internal/pc/PcReferencesProvider.scala b/mtags/src/main/scala-3/scala/meta/internal/pc/PcReferencesProvider.scala index 935f5a54885..27eb0b75e6b 100644 --- a/mtags/src/main/scala-3/scala/meta/internal/pc/PcReferencesProvider.scala +++ b/mtags/src/main/scala-3/scala/meta/internal/pc/PcReferencesProvider.scala @@ -3,9 +3,10 @@ package scala.meta.internal.pc import scala.collection.JavaConverters.* import scala.meta.internal.mtags.MtagsEnrichments.* -import scala.meta.pc.DefinitionResult -import scala.meta.pc.OffsetParams -import scala.meta.pc.VirtualFileParams +import scala.meta.pc.Buffers +import scala.meta.pc.PcAdjustFileParams +import scala.meta.pc.ReferencesRequest +import scala.meta.pc.ReferencesResult import dotty.tools.dotc.ast.tpd import dotty.tools.dotc.ast.tpd.* @@ -19,34 +20,47 @@ import org.eclipse.lsp4j.Location class PcReferencesProvider( driver: InteractiveDriver, - params: OffsetParams, - targetFiles: List[VirtualFileParams], - includeDefinition: Boolean, + request: ReferencesRequest, + buffers: Buffers, + scalaVersion: String, ): - val symbolSearch = new WithCompilationUnit(driver, params) with PcSymbolSearch - def result(): List[DefinitionResult] = - val result = - for - (sought, _) <- symbolSearch.soughtSymbols.toList - params <- + val symbolSearch = new WithCompilationUnit(driver, request.params()) + with PcSymbolSearch + @volatile var isCancelled: Boolean = false + def result(): List[ReferencesResult] = + val allLocations = + symbolSearch.soughtSymbols match + case Some((sought, _)) if sought.nonEmpty => given Context = symbolSearch.ctx + if sought.forall(_.is(Flags.Local)) then + val file = buffers.getFile(request.params().uri(), scalaVersion) + if file.isPresent() + then collectForFile(sought, file.get()) + else List.empty + else + val fileUris = request.targetUris().asScala.toList + fileUris.flatMap { uri => + val file = buffers.getFile(uri, scalaVersion) + if file.isPresent() && !isCancelled then + collectForFile(sought, file.get()) + else List.empty + } + end if + case _ => List.empty - if sought.forall(_.is(Flags.Local)) - then targetFiles.find(_.uri() == params.uri()).toList - else targetFiles - - collected <- collectForFile(sought, params) - yield collected - result + allLocations .groupBy(_._1) .map { case (symbol, locs) => - DefinitionResultImpl(symbol, locs.map(_._2).asJava) + ReferencesResultImpl(symbol, locs.map(_._2).asJava) } .toList end result - private def collectForFile(sought: Set[Symbol], params: VirtualFileParams) = - new WithCompilationUnit(driver, params) + private def collectForFile( + sought: Set[Symbol], + adjustParams: PcAdjustFileParams, + ) = + val collector = new WithCompilationUnit(driver, adjustParams.params()) with PcCollector[Option[(String, lsp4j.Range)]]: def collect(parent: Option[Tree])( tree: Tree | EndMarker, @@ -55,14 +69,24 @@ class PcReferencesProvider( ): Option[(String, lsp4j.Range)] = val (pos, _) = toAdjust.adjust(text) tree match - case _: DefTree if !includeDefinition => None + case _: DefTree if !request.includeDefinition() => None case t: Tree => val sym = symbol.getOrElse(t.symbol) Some(SemanticdbSymbols.symbolName(sym), pos.toLsp) case _ => None - .resultWithSought(sought) - .flatten - .map { case (symbol, range) => - (symbol, new Location(params.uri().toString(), range)) - } + val results = + collector + .resultWithSought(sought) + .flatten + .map { case (symbol, range) => + ( + symbol, + adjustParams.adjustLocation( + new Location(collector.uri.toString(), range) + ), + ) + } + driver.close(collector.uri) + results + end collectForFile end PcReferencesProvider diff --git a/mtags/src/main/scala-3/scala/meta/internal/pc/ScalaPresentationCompiler.scala b/mtags/src/main/scala-3/scala/meta/internal/pc/ScalaPresentationCompiler.scala index 8a5dce5fdac..2a78d9edc45 100644 --- a/mtags/src/main/scala-3/scala/meta/internal/pc/ScalaPresentationCompiler.scala +++ b/mtags/src/main/scala-3/scala/meta/internal/pc/ScalaPresentationCompiler.scala @@ -20,6 +20,7 @@ import scala.meta.internal.metals.ReportContext import scala.meta.internal.metals.ReportLevel import scala.meta.internal.metals.StdReportContext import scala.meta.internal.mtags.BuildInfo +import scala.meta.internal.mtags.BuildInfo.scalaCompilerVersion import scala.meta.internal.pc.completions.CompletionProvider import scala.meta.internal.pc.completions.OverrideCompletions import scala.meta.pc.* @@ -40,6 +41,7 @@ case class ScalaPresentationCompiler( config: PresentationCompilerConfig = PresentationCompilerConfigImpl(), folderPath: Option[Path] = None, reportsLevel: ReportLevel = ReportLevel.Info, + buffers: Buffers = NoopBuffers, ) extends PresentationCompiler: def this() = this("", None, Nil, Nil) @@ -168,22 +170,18 @@ case class ScalaPresentationCompiler( PcDocumentHighlightProvider(driver, params).highlights.asJava } + override def withBuffers(buffers: Buffers): PresentationCompiler = + copy(buffers = buffers) + override def references( - params: OffsetParams, - targetFiles: ju.List[VirtualFileParams], - includeDefinition: Boolean, - ): CompletableFuture[ju.List[DefinitionResult]] = - compilerAccess.withNonInterruptableCompiler(Some(params))( - List.empty[DefinitionResult].asJava, - params.token, + params: ReferencesRequest + ): CompletableFuture[ju.List[ReferencesResult]] = + compilerAccess.withNonInterruptableCompiler(Some(params.params()))( + List.empty[ReferencesResult].asJava, + params.params().token, ) { access => val driver = access.compiler() - PcReferencesProvider( - driver, - params, - targetFiles.asScala.toList, - includeDefinition, - ) + PcReferencesProvider(driver, params, buffers, scalaVersion) .result() .asJava } diff --git a/tests/slow/src/test/scala/tests/feature/PcReferencesLspSuite.scala b/tests/slow/src/test/scala/tests/feature/PcReferencesLspSuite.scala index 4b155ffc1db..6702c0a55b0 100644 --- a/tests/slow/src/test/scala/tests/feature/PcReferencesLspSuite.scala +++ b/tests/slow/src/test/scala/tests/feature/PcReferencesLspSuite.scala @@ -73,7 +73,7 @@ class PcReferencesLspSuite } } - def refFiles = files.keysIterator.map(server.toPath(_)) + def refFiles = files.keysIterator.map(server.toPath(_)).toList val layout = input.replaceAll("<<|>>|@@", "")