diff --git a/metals/src/main/scala/scala/meta/internal/metals/AdjustLspData.scala b/metals/src/main/scala/scala/meta/internal/metals/AdjustLspData.scala index 5c2494e1aae..ce1b2f1dcaf 100644 --- a/metals/src/main/scala/scala/meta/internal/metals/AdjustLspData.scala +++ b/metals/src/main/scala/scala/meta/internal/metals/AdjustLspData.scala @@ -50,6 +50,9 @@ trait AdjustLspData { diag } + def adjustLocation(location: Location): Location = + new Location(location.getUri(), adjustRange(location.getRange())) + def adjustLocations( locations: java.util.List[Location] ): ju.List[Location] diff --git a/metals/src/main/scala/scala/meta/internal/metals/Compilers.scala b/metals/src/main/scala/scala/meta/internal/metals/Compilers.scala index 45d1600efa4..637ec186ab1 100644 --- a/metals/src/main/scala/scala/meta/internal/metals/Compilers.scala +++ b/metals/src/main/scala/scala/meta/internal/metals/Compilers.scala @@ -36,6 +36,7 @@ import scala.meta.pc.OffsetParams import scala.meta.pc.PresentationCompiler import scala.meta.pc.SymbolSearch import scala.meta.pc.SyntheticDecoration +import scala.meta.pc.VirtualFileParams import ch.epfl.scala.bsp4j.BuildTargetIdentifier import ch.epfl.scala.bsp4j.CompileReport @@ -46,6 +47,8 @@ import org.eclipse.lsp4j.CompletionParams import org.eclipse.lsp4j.Diagnostic import org.eclipse.lsp4j.DocumentHighlight import org.eclipse.lsp4j.InitializeParams +import org.eclipse.lsp4j.Location +import org.eclipse.lsp4j.ReferenceParams import org.eclipse.lsp4j.RenameParams import org.eclipse.lsp4j.SelectionRange import org.eclipse.lsp4j.SelectionRangeParams @@ -683,6 +686,39 @@ class Compilers( } }.getOrElse(Future.successful(Nil.asJava)) + def references( + params: ReferenceParams, + targetFiles: List[AbsolutePath], + token: CancelToken, + ): Future[List[Location]] = { + withPCAndAdjustLsp(params) { (pc, pos, adjust) => + val targets = targetFiles.map { target => + target.toURI.toString -> { + val (vFile, _, adjustLsp) = + sourceAdjustments( + target.toURI.toString(), + pc.scalaVersion(), + ) + val params = + CompilerVirtualFileParams(target.toURI, vFile.text, token) + (params, adjustLsp) + } + }.toMap + val targetFilesParams: List[VirtualFileParams] = + targets.values.map(_._1).toList + pc.references( + CompilerOffsetParamsUtils.fromPos(pos, token), + targetFilesParams.asJava, + params.getContext().isIncludeDeclaration(), + ).asScala + .map( + _.asScala.toList.map(loc => + targets(loc.getUri())._2.adjustLocation(loc) + ) + ) + } + }.getOrElse(Future.successful(Nil)) + def extractMethod( doc: TextDocumentIdentifier, range: LspRange, diff --git a/metals/src/main/scala/scala/meta/internal/metals/MetalsLspService.scala b/metals/src/main/scala/scala/meta/internal/metals/MetalsLspService.scala index 1625d5ae708..436057ab69d 100644 --- a/metals/src/main/scala/scala/meta/internal/metals/MetalsLspService.scala +++ b/metals/src/main/scala/scala/meta/internal/metals/MetalsLspService.scala @@ -549,6 +549,59 @@ class MetalsLspService( implementationProvider, ) + val worksheetProvider: WorksheetProvider = { + val worksheetPublisher = + if (clientConfig.isDecorationProvider) + new DecorationWorksheetPublisher( + clientConfig.isInlineDecorationProvider() + ) + else + new WorkspaceEditWorksheetPublisher(buffers, trees) + + register( + new WorksheetProvider( + folder, + buffers, + buildTargets, + languageClient, + () => userConfig, + statusBar, + diagnostics, + embedded, + worksheetPublisher, + compilations, + scalaVersionSelector, + ) + ) + } + + private val symbolSearch: MetalsSymbolSearch = new MetalsSymbolSearch( + symbolDocs, + workspaceSymbols, + definitionProvider, + ) + + val compilers: Compilers = register( + new Compilers( + folder, + clientConfig, + () => userConfig, + buildTargets, + buffers, + symbolSearch, + embedded, + statusBar, + sh, + initializeParams, + () => excludedPackageHandler, + scalaVersionSelector, + trees, + mtagsResolver, + sourceMapper, + worksheetProvider, + ) + ) + private val referencesProvider: ReferenceProvider = new ReferenceProvider( folder, semanticdbs, @@ -557,6 +610,7 @@ class MetalsLspService( remote, trees, buildTargets, + compilers, ) private val syntheticHoverProvider: SyntheticHoverProvider = @@ -640,59 +694,6 @@ class MetalsLspService( clientConfig.icons, ) - private val symbolSearch: MetalsSymbolSearch = new MetalsSymbolSearch( - symbolDocs, - workspaceSymbols, - definitionProvider, - ) - - val worksheetProvider: WorksheetProvider = { - val worksheetPublisher = - if (clientConfig.isDecorationProvider) - new DecorationWorksheetPublisher( - clientConfig.isInlineDecorationProvider() - ) - else - new WorkspaceEditWorksheetPublisher(buffers, trees) - - register( - new WorksheetProvider( - folder, - buffers, - buildTargets, - languageClient, - () => userConfig, - statusBar, - diagnostics, - embedded, - worksheetPublisher, - compilations, - scalaVersionSelector, - ) - ) - } - - private val compilers: Compilers = register( - new Compilers( - folder, - clientConfig, - () => userConfig, - buildTargets, - buffers, - symbolSearch, - embedded, - statusBar, - sh, - initializeParams, - () => excludedPackageHandler, - scalaVersionSelector, - trees, - mtagsResolver, - sourceMapper, - worksheetProvider, - ) - ) - private val renameProvider: RenameProvider = new RenameProvider( referencesProvider, implementationProvider, @@ -1519,7 +1520,9 @@ class MetalsLspService( override def references( params: ReferenceParams ): CompletableFuture[util.List[Location]] = - CancelTokens { _ => referencesResult(params).flatMap(_.locations).asJava } + CancelTokens.future { _ => + referencesResult(params).map(_.flatMap(_.locations).asJava) + } // Triggers a cascade compilation and tries to find new references to a given symbol. // It's not possible to stream reference results so if we find new symbols we notify the @@ -1527,10 +1530,10 @@ class MetalsLspService( private def compileAndLookForNewReferences( params: ReferenceParams, result: List[ReferencesResult], - ): Unit = { + ): Future[Unit] = { val path = params.getTextDocument.getUri.toAbsolutePath val old = path.toInputFromBuffers(buffers) - compilations.cascadeCompileFiles(Seq(path)).foreach { _ => + compilations.cascadeCompileFiles(Seq(path)).flatMap { _ => val newBuffer = path.toInputFromBuffers(buffers) val newParams: Option[ReferenceParams] = if (newBuffer.text == old.text) Some(params) @@ -1553,46 +1556,54 @@ class MetalsLspService( ) } newParams match { - case None => + case None => Future.unit case Some(p) => - val newResult = referencesProvider.references(p) - val diff = newResult - .flatMap(_.locations) - .length - result.flatMap(_.locations).length - val diffSyms: Set[String] = - newResult.map(_.symbol).toSet -- result.map(_.symbol).toSet - if (diffSyms.nonEmpty && diff > 0) { - import scala.meta.internal.semanticdb.Scala._ - val names = - diffSyms.map(sym => s"'${sym.desc.name.value}'").mkString(" and ") - val message = - s"Found new symbol references for $names, try running again." - scribe.info(message) - statusBar - .addMessage(clientConfig.icons.info + message) + for { + newResult <- referencesProvider.references(p) + } yield { + val diff = newResult + .flatMap(_.locations) + .length - result.flatMap(_.locations).length + val diffSyms: Set[String] = + newResult.map(_.symbol).toSet -- result.map(_.symbol).toSet + if (diffSyms.nonEmpty && diff > 0) { + import scala.meta.internal.semanticdb.Scala._ + val names = + diffSyms + .map(sym => s"'${sym.desc.name.value}'") + .mkString(" and ") + val message = + s"Found new symbol references for $names, try running again." + scribe.info(message) + statusBar + .addMessage(clientConfig.icons.info + message) + } } } } } - def referencesResult(params: ReferenceParams): List[ReferencesResult] = { + def referencesResult( + params: ReferenceParams + ): Future[List[ReferencesResult]] = { val timer = new Timer(time) - val results: List[ReferencesResult] = referencesProvider.references(params) - if (clientConfig.initialConfig.statistics.isReferences) { - if (results.forall(_.symbol.isEmpty)) { - scribe.info(s"time: found 0 references in $timer") - } else { - scribe.info( - s"time: found ${results.flatMap(_.locations).length} references to symbol '${results - .map(_.symbol) - .mkString("and")}' in $timer" - ) + referencesProvider.references(params).map { results => + if (clientConfig.initialConfig.statistics.isReferences) { + if (results.forall(_.symbol.isEmpty)) { + scribe.info(s"time: found 0 references in $timer") + } else { + scribe.info( + s"time: found ${results.flatMap(_.locations).length} references to symbol '${results + .map(_.symbol) + .mkString("and")}' in $timer" + ) + } } + if (results.nonEmpty) { + compileAndLookForNewReferences(params, results) + } + results } - if (results.nonEmpty) { - compileAndLookForNewReferences(params, results) - } - results } override def semanticTokensFull( @@ -2535,21 +2546,22 @@ class MetalsLspService( positionParams.getPosition(), new ReferenceContext(false), ) - val results = referencesResult(refParams) - if (results.flatMap(_.locations).isEmpty) { - // Fallback again to the original behavior that returns - // the definition location itself if no reference locations found, - // for avoiding the confusing messages like "No definition found ..." - definitionResult(positionParams, token) - } else { - Future.successful( - DefinitionResult( - locations = results.flatMap(_.locations).asJava, - symbol = results.head.symbol, - definition = None, - semanticdb = None, + referencesResult(refParams).flatMap { results => + if (results.flatMap(_.locations).isEmpty) { + // Fallback again to the original behavior that returns + // the definition location itself if no reference locations found, + // for avoiding the confusing messages like "No definition found ..." + definitionResult(positionParams, token) + } else { + Future.successful( + DefinitionResult( + locations = results.flatMap(_.locations).asJava, + symbol = results.head.symbol, + definition = None, + semanticdb = None, + ) ) - ) + } } } else { definitionResult(positionParams, token) diff --git a/metals/src/main/scala/scala/meta/internal/metals/ReferenceProvider.scala b/metals/src/main/scala/scala/meta/internal/metals/ReferenceProvider.scala index f90d7b78b2e..0a780727a05 100644 --- a/metals/src/main/scala/scala/meta/internal/metals/ReferenceProvider.scala +++ b/metals/src/main/scala/scala/meta/internal/metals/ReferenceProvider.scala @@ -41,7 +41,9 @@ final class ReferenceProvider( remote: RemoteLanguageServer, trees: Trees, buildTargets: BuildTargets, -) extends SemanticdbFeatureProvider { + compilers: Compilers, +)(implicit ec: ExecutionContext) + extends SemanticdbFeatureProvider { private var referencedPackages: BloomFilter[CharSequence] = BloomFilters.create(10000) @@ -133,7 +135,7 @@ final class ReferenceProvider( params: ReferenceParams, findRealRange: AdjustRange = noAdjustRange, includeSynthetics: Synthetic => Boolean = _ => true, - ): List[ReferencesResult] = { + ): Future[List[ReferencesResult]] = { val source = params.getTextDocument.getUri.toAbsolutePath semanticdbs.textDocument(source).documentIncludingStale match { case Some(doc) => @@ -142,7 +144,7 @@ final class ReferenceProvider( definition.positionOccurrences(source, params.getPosition, doc) if (posOccurrences.isEmpty) // handling case `import a.{A as @@B}` - occerencesForRenamedImport(source, params, doc) + occurrencesForRenamedImport(source, params, doc) else posOccurrences } if (results.isEmpty) { @@ -150,23 +152,25 @@ final class ReferenceProvider( s"No symbol found at ${params.getPosition()} for $source" ) } - results.map { result => - val occurrence = result.occurrence.get - val distance = result.distance - val alternatives = - referenceAlternatives(occurrence.symbol, source, doc) - val locations = references( - source, - params, - doc, - distance, - occurrence, - alternatives, - params.getContext.isIncludeDeclaration, - findRealRange, - includeSynthetics, - ) - ReferencesResult(occurrence.symbol, locations) + Future.sequence { + results.map { result => + val occurrence = result.occurrence.get + val distance = result.distance + val alternatives = + referenceAlternatives(occurrence.symbol, source, doc) + val locations = references( + source, + params, + doc, + distance, + occurrence, + alternatives, + params.getContext.isIncludeDeclaration, + findRealRange, + includeSynthetics, + ) + locations.map(ReferencesResult(occurrence.symbol, _)) + } } case None => scribe.debug(s"No semanticdb for $source") @@ -175,15 +179,17 @@ final class ReferenceProvider( // its dependencies (including rename provider) asynchronous. The remote // language server returns `Future.successful(None)` when it's disabled // so this isn't even blocking for normal usage of Metals. - List( - remote.referencesBlocking(params).getOrElse(ReferencesResult.empty) + Future.successful( + List( + remote.referencesBlocking(params).getOrElse(ReferencesResult.empty) + ) ) } } - // for `import package.{AA as B@@B}` we look for occurences at `import package.{@@AA as BB}`, - // since rename is not a position occurence in semanticDB - private def occerencesForRenamedImport( + // for `import package.{AA as B@@B}` we look for occurrences at `import package.{@@AA as BB}`, + // since rename is not a position occurrence in semanticDB + private def occurrencesForRenamedImport( source: AbsolutePath, params: ReferenceParams, document: TextDocument, @@ -363,12 +369,11 @@ final class ReferenceProvider( isIncludeDeclaration: Boolean, findRealRange: AdjustRange, includeSynthetics: Synthetic => Boolean, - ): Seq[Location] = { + ): Future[Seq[Location]] = { val isSymbol = alternatives + occ.symbol val isLocal = occ.symbol.isLocal /* search local in the following cases: - * - it's local symbol * - it's a dependency source. * We can't search references inside dependencies so at least show them in a source file. * - it's a standalone file that doesn't belong to any build target @@ -377,18 +382,22 @@ final class ReferenceProvider( isLocal || source.isDependencySource(workspace) || buildTargets.inverseSources(source).isEmpty val local = - if (searchLocal) - referenceLocations( - snapshot, - isSymbol, - distance, - params.getTextDocument.getUri, - isIncludeDeclaration, - findRealRange, - includeSynthetics, - source.isJava, + if (isLocal && !source.isJava) + compilers.references(params, List(source), EmptyCancelToken) + else if (searchLocal) + Future.successful( + referenceLocations( + snapshot, + isSymbol, + distance, + params.getTextDocument.getUri, + isIncludeDeclaration, + findRealRange, + includeSynthetics, + source.isJava, + ) ) - else Seq.empty + else Future.successful(Seq.empty) val workspaceRefs = if (!isLocal) @@ -402,7 +411,7 @@ final class ReferenceProvider( else Seq.empty - workspaceRefs ++ local + local.map(workspaceRefs ++ _) } private def referenceLocations( diff --git a/metals/src/main/scala/scala/meta/internal/rename/RenameProvider.scala b/metals/src/main/scala/scala/meta/internal/rename/RenameProvider.scala index 50ca3f5e634..87c2e481900 100644 --- a/metals/src/main/scala/scala/meta/internal/rename/RenameProvider.scala +++ b/metals/src/main/scala/scala/meta/internal/rename/RenameProvider.scala @@ -125,9 +125,9 @@ final class RenameProvider( ) .recoverWith { case _ => compilations.compilationFinished(source).flatMap { _ => - val defininionFuture = definitionProvider + val definitionFuture = definitionProvider .definition(source, params, token) - defininionFuture + definitionFuture .flatMap { definition => val textParams = new TextDocumentPositionParams( params.getTextDocument(), @@ -214,7 +214,7 @@ final class RenameProvider( findRealRange = findRealRange(newName), includeSynthetic, ) - .flatMap(_.locations) + .map(_.flatMap(_.locations)) definitionLocation = { if (parentSymbols.isEmpty) definition.locations.asScala @@ -235,9 +235,13 @@ final class RenameProvider( ), newName, ) - } yield implReferences.map(implLocs => - currentReferences ++ implLocs ++ companionRefs ++ definitionLocation - ) + } yield Future + .sequence( + List(implReferences, currentReferences, companionRefs) + ) + .map( + _.reduce(_ ++ _) ++ definitionLocation + ) Future .sequence(allReferences) .map(locs => @@ -390,7 +394,7 @@ final class RenameProvider( sym: String, source: AbsolutePath, newName: String, - ): Seq[Location] = { + ): Future[Seq[Location]] = { val results = for { companionSymbol <- companion(sym).toIterable loc <- @@ -399,15 +403,15 @@ final class RenameProvider( .asScala // no companion objects in Java files if loc.getUri().isScalaFilename - companionLocs <- - referenceProvider - .references( - toReferenceParams(loc, includeDeclaration = false), - findRealRange = findRealRange(newName), - ) - .flatMap(_.locations) :+ loc - } yield companionLocs - results.toList + } yield { + referenceProvider + .references( + toReferenceParams(loc, includeDeclaration = false), + findRealRange = findRealRange(newName), + ) + .map(_.flatMap(_.locations :+ loc)) + } + Future.sequence(results).map(_.flatten.toSeq) } private def companion(sym: String) = { @@ -450,20 +454,22 @@ final class RenameProvider( if (shouldCheckImplementation) { for { implLocs <- implementationProvider.implementations(textParams) - } yield { - for { - implLoc <- implLocs - locParams = toReferenceParams(implLoc, includeDeclaration = true) - loc <- + result <- { + val result = for { + implLoc <- implLocs + locParams = toReferenceParams(implLoc, includeDeclaration = true) + } yield { referenceProvider .references( locParams, findRealRange = findRealRange(newName), includeSynthetic, ) - .flatMap(_.locations) - } yield loc - } + .map(_.flatMap(_.locations)) + } + Future.sequence(result) + } + } yield result.flatten } else { Future.successful(Nil) } diff --git a/mtags-interfaces/src/main/java/scala/meta/pc/PresentationCompiler.java b/mtags-interfaces/src/main/java/scala/meta/pc/PresentationCompiler.java index af82134fa32..7a927eab34f 100644 --- a/mtags-interfaces/src/main/java/scala/meta/pc/PresentationCompiler.java +++ b/mtags-interfaces/src/main/java/scala/meta/pc/PresentationCompiler.java @@ -111,6 +111,13 @@ public CompletableFuture> semanticTokens(VirtualFileParams params) { */ public abstract CompletableFuture> documentHighlight(OffsetParams params); + /** + * Returns the references of the symbol under the current position in the target files. + */ + public CompletableFuture> references(OffsetParams params, java.util.List targetFiles, boolean includeDefinition) { + return CompletableFuture.completedFuture(Collections.emptyList()); + } + /** * Return decoded and pretty printed TASTy content for .scala or .tasty file. * diff --git a/mtags-java/src/main/scala/scala/meta/internal/pc/JavaPresentationCompiler.scala b/mtags-java/src/main/scala/scala/meta/internal/pc/JavaPresentationCompiler.scala index 6a4b7cb42c4..f94ef52d0d7 100644 --- a/mtags-java/src/main/scala/scala/meta/internal/pc/JavaPresentationCompiler.scala +++ b/mtags-java/src/main/scala/scala/meta/internal/pc/JavaPresentationCompiler.scala @@ -99,6 +99,13 @@ case class JavaPresentationCompiler( ): CompletableFuture[util.List[DocumentHighlight]] = CompletableFuture.completedFuture(Nil.asJava) + override def references( + params: OffsetParams, + targetFiles: util.List[VirtualFileParams], + includeDefinition: Boolean + ): CompletableFuture[util.List[lsp4j.Location]] = + CompletableFuture.completedFuture(Nil.asJava) + override def getTasty( targetUri: URI, isHttpEnabled: Boolean diff --git a/mtags/src/main/scala-2/scala/meta/internal/pc/PcCollector.scala b/mtags/src/main/scala-2/scala/meta/internal/pc/PcCollector.scala index 0fb2d6bc09c..b9a3dd5b941 100644 --- a/mtags/src/main/scala-2/scala/meta/internal/pc/PcCollector.scala +++ b/mtags/src/main/scala-2/scala/meta/internal/pc/PcCollector.scala @@ -5,291 +5,78 @@ import scala.reflect.internal.util.RangePosition import scala.meta.pc.OffsetParams import scala.meta.pc.VirtualFileParams -abstract class PcCollector[T]( - val compiler: MetalsGlobal, - params: VirtualFileParams -) { +trait PcCollector[T] { self: WithCompilationUnit => import compiler._ def collect( parent: Option[Tree] )(tree: Tree, pos: Position, sym: Option[Symbol]): T - val unit: RichCompilationUnit = addCompilationUnit( - code = params.text(), - filename = params.uri().toString(), - cursor = None - ) - val offset: Int = params match { - case p: OffsetParams => p.offset() - case _: VirtualFileParams => 0 - } - val pos: Position = unit.position(offset) - lazy val text = unit.source.content - private val caseClassSynthetics: Set[Name] = Set(nme.apply, nme.copy) - typeCheck(unit) - lazy val typedTree: Tree = locateTree(pos) match { - // Check actual object if apply is synthetic - case sel @ Select(qual, name) if name == nme.apply && qual.pos == sel.pos => - qual - case Import(expr, _) if expr.pos.includes(pos) => - // imports seem to be marked as transparent - locateTree(pos, expr, acceptTransparent = true) - case t => t - } - - /** - * Find all symbols that should be shown together. - * For example class we want to also show companion object. - * - * @param sym symbol to find the alternative candidates for - * @return set of possible symbols - */ - def symbolAlternatives(sym: Symbol): Set[Symbol] = { - val all = - if (sym.isClass) { - if (sym.owner.isMethod) Set(sym) ++ sym.localCompanion(pos) - else Set(sym, sym.companionModule, sym.companion.moduleClass) - } else if (sym.isModuleOrModuleClass) { - if (sym.owner.isMethod) Set(sym) ++ sym.localCompanion(pos) - else Set(sym, sym.companionClass, sym.moduleClass) - } else if (sym.isTerm && (sym.owner.isClass || sym.owner.isConstructor)) { - val info = - if (sym.owner.isClass) sym.owner.info - else sym.owner.owner.info - Set( - sym, - info.member(sym.getterName), - info.member(sym.setterName), - info.member(sym.localName) - ) ++ constructorParam(sym) ++ sym.allOverriddenSymbols.toSet - } else Set(sym) - all.filter(s => s != NoSymbol && !s.isError) - } - - private def constructorParam( - symbol: Symbol - ): Set[Symbol] = { - if (symbol.owner.isClass) { - val info = symbol.owner.info.member(nme.CONSTRUCTOR).info - info.paramss.flatten.find(_.name == symbol.name).toSet - } else Set.empty - } - private lazy val namedArgCache = { - val parsedTree = parseTree(unit.source) - parsedTree.collect { case arg @ AssignOrNamedArg(_, rhs) => - rhs.pos -> arg - }.toMap - } - def fallbackSymbol(name: Name, pos: Position): Option[Symbol] = { - val context = doLocateImportContext(pos) - context.lookupSymbol(name, sym => sym.isType) match { - case LookupSucceeded(_, symbol) => - Some(symbol) - case _ => None - } - } - - // First identify the symbol we are at, comments identify @@ as current cursor position - lazy val soughtSymbols: Option[(Set[Symbol], Position)] = typedTree match { - /* simple identifier: - * val a = val@@ue + value + def resultWithSought(sought: Set[Symbol]): List[T] = { + val owners = sought + .map(_.owner) + .flatMap(o => symbolAlternatives(o)) + .filter(_ != NoSymbol) + val soughtNames: Set[Name] = sought.map(_.name) + /* + * For comprehensions have two owners, one for the enumerators and one for + * yield. This is a heuristic to find that out. */ - case (id: Ident) => - // might happen in type trees - // also this doesn't seem to be picked up by semanticdb - if (id.symbol == NoSymbol) - fallbackSymbol(id.name, pos).map(sym => - (symbolAlternatives(sym), id.pos) + def isForComprehensionOwner(named: NameTree) = { + if (named.symbol.pos.isDefined) { + def alternativeSymbol = sought.exists(symbol => + symbol.name == named.name && + symbol.pos.isDefined && + symbol.pos.start == named.symbol.pos.start ) - else { - Some(symbolAlternatives(id.symbol), id.pos) - } - /* Anonynous function parameters such as: - * List(1).map{ <>: Int => abc} - * In this case, parameter has incorrect namePosition, so we need to handle it separately - */ - case (vd: ValDef) if isAnonFunctionParam(vd) => - val namePos = vd.pos.withEnd(vd.pos.start + vd.name.length) - if (namePos.includes(pos)) Some(symbolAlternatives(vd.symbol), namePos) - else None - - /* all definitions: - * def fo@@o = ??? - * class Fo@@o = ??? - * etc. - */ - case (df: DefTree) if df.namePosition.includes(pos) => - Some(symbolAlternatives(df.symbol), df.namePosition) - /* Import selectors: - * import scala.util.Tr@@y - */ - case (imp: Import) if imp.pos.includes(pos) => - imp.selector(pos).map(sym => (symbolAlternatives(sym), sym.pos)) - /* simple selector: - * object.val@@ue - */ - case (sel: NameTree) if sel.namePosition.includes(pos) => - Some(symbolAlternatives(sel.symbol), sel.namePosition) - - // needed for classOf[AB@@C]` - case lit @ Literal(Constant(TypeRef(_, sym, _))) if lit.pos.includes(pos) => - val posStart = text.indexOfSlice(sym.decodedName, lit.pos.start) - if (posStart == -1) None - else - Some( - ( - symbolAlternatives(sym), - new RangePosition( - lit.pos.source, - posStart, - lit.pos.point, - posStart + sym.decodedName.length - ) + def sameOwner = { + val owner = named.symbol.owner + owner.isAnonymousFunction && owners.exists(o => + pos.isDefined && o.pos.isDefined && o.pos.point == owner.pos.point ) - ) - /* named argument, which is a bit complex: - * foo(nam@@e = "123") - */ - case _ => - val apply = typedTree match { - case (apply: Apply) => Some(apply) - /** - * For methods with multiple parameter lists and default args, the tree looks like this: - * Block(List(val x&1, val x&2, ...), Apply(<>, List(x&1, x&2, ...))) - */ - case _ => - typedTree.children.collectFirst { - case Block(_, apply: Apply) if apply.pos.includes(pos) => - apply - } - } - apply - .collect { - case apply if apply.symbol != null => - collectArguments(apply) - .flatMap(arg => namedArgCache.find(_._1.includes(arg.pos))) - .collectFirst { - case (_, AssignOrNamedArg(id: Ident, _)) - if id.pos.includes(pos) => - apply.symbol.paramss.flatten.find(_.name == id.name).map { - s => - // if it's a case class we need to look for parameters also - if ( - caseClassSynthetics(s.owner.name) && s.owner.isSynthetic - ) { - val applyOwner = s.owner.owner - val constructorOwner = - if (applyOwner.isCaseClass) applyOwner - else - applyOwner.companion match { - case NoSymbol => - applyOwner - .localCompanion(pos) - .getOrElse(NoSymbol) - case comp => comp - } - val info = constructorOwner.info - val constructorParams = info.members - .filter(_.isConstructor) - .flatMap(_.paramss) - .flatten - .filter(_.name == id.name) - .toSet - ( - (constructorParams ++ Set( - s, - info.member(s.getterName), - info.member(s.setterName), - info.member(s.localName) - )).filter(_ != NoSymbol), - id.pos - ) - } else (Set(s), id.pos) - } - } - .flatten } - .getOrElse(None) - } - - def result(): List[T] = { - params match { - case _: OffsetParams => resultWithSought() - case _ => resultAllOccurences().toList + alternativeSymbol && sameOwner + } else false } - } - - def resultWithSought(): List[T] = { - soughtSymbols match { - case Some((sought, _)) => - val owners = sought - .map(_.owner) - .flatMap(o => symbolAlternatives(o)) - .filter(_ != NoSymbol) - val soughtNames: Set[Name] = sought.map(_.name) - /* - * For comprehensions have two owners, one for the enumerators and one for - * yield. This is a heuristic to find that out. - */ - def isForComprehensionOwner(named: NameTree) = { - if (named.symbol.pos.isDefined) { - def alternativeSymbol = sought.exists(symbol => - symbol.name == named.name && - symbol.pos.isDefined && - symbol.pos.start == named.symbol.pos.start - ) - def sameOwner = { - val owner = named.symbol.owner - owner.isAnonymousFunction && owners.exists(o => - pos.isDefined && o.pos.isDefined && o.pos.point == owner.pos.point - ) - } - alternativeSymbol && sameOwner - } else false - } - - def soughtOrOverride(sym: Symbol) = - sought(sym) || sym.allOverriddenSymbols.exists(sought(_)) - def soughtTreeFilter(tree: Tree): Boolean = - tree match { - case ident: Ident - if (soughtOrOverride(ident.symbol) || - isForComprehensionOwner(ident)) => - true - case tpe: TypeTree => - sought(tpe.original.symbol) - case sel: Select => - soughtOrOverride(sel.symbol) - case df: MemberDef => - (soughtOrOverride(df.symbol) || - isForComprehensionOwner(df)) - case appl: Apply => - appl.symbol != null && - (owners(appl.symbol) || - symbolAlternatives(appl.symbol.owner).exists(owners(_))) - case imp: Import => - owners(imp.expr.symbol) && imp.selectors - .exists(sel => soughtNames(sel.name)) - case bind: Bind => - (soughtOrOverride(bind.symbol)) || - isForComprehensionOwner(bind) - case _ => false - } - - // 1. In most cases, we try to compare by Symbol#==. - // 2. If there is NoSymbol at a node, we check if the identifier there has the same decoded name. - // It it does, we look up the symbol at this position using `fallbackSymbol` or `members`, - // and again check if this time we got symbol equal by Symbol#==. - def soughtFilter(f: Symbol => Boolean): Boolean = { - sought.exists(f) - } + def soughtOrOverride(sym: Symbol) = + sought(sym) || sym.allOverriddenSymbols.exists(sought(_)) - traverseSought(soughtTreeFilter, soughtFilter).toList + def soughtTreeFilter(tree: Tree): Boolean = + tree match { + case ident: Ident + if (soughtOrOverride(ident.symbol) || + isForComprehensionOwner(ident)) => + true + case tpe: TypeTree => + sought(tpe.original.symbol) + case sel: Select => + soughtOrOverride(sel.symbol) + case df: MemberDef => + (soughtOrOverride(df.symbol) || + isForComprehensionOwner(df)) + case appl: Apply => + appl.symbol != null && + (owners(appl.symbol) || + symbolAlternatives(appl.symbol.owner).exists(owners(_))) + case imp: Import => + owners(imp.expr.symbol) && imp.selectors + .exists(sel => soughtNames(sel.name)) + case bind: Bind => + (soughtOrOverride(bind.symbol)) || + isForComprehensionOwner(bind) + case _ => false + } - case None => Nil + // 1. In most cases, we try to compare by Symbol#==. + // 2. If there is NoSymbol at a node, we check if the identifier there has the same decoded name. + // It it does, we look up the symbol at this position using `fallbackSymbol` or `members`, + // and again check if this time we got symbol equal by Symbol#==. + def soughtFilter(f: Symbol => Boolean): Boolean = { + sought.exists(f) } + + traverseSought(soughtTreeFilter, soughtFilter).toList } def resultAllOccurences(): Set[T] = { @@ -518,8 +305,7 @@ abstract class PcCollector[T]( tree.children.foldLeft(acc)(traverse(_, _)) } } - val all = traverseWithParent(None)(Set.empty[T], unit.lastBody) - all + traverseWithParent(None)(Set.empty[T], unit.lastBody) } private def annotationChildren(mdef: MemberDef): List[Tree] = { @@ -540,14 +326,23 @@ abstract class PcCollector[T]( } } - private def collectArguments(apply: Apply): List[Tree] = { - apply.fun match { - case appl: Apply => collectArguments(appl) ++ apply.args - case _ => apply.args - } - } +} - private def isAnonFunctionParam(vd: ValDef): Boolean = - vd.symbol != null && vd.symbol.owner.isAnonymousFunction && vd.rhs.isEmpty +abstract class SimpleCollector[T]( + compiler: MetalsGlobal, + params: VirtualFileParams +) extends WithCompilationUnit(compiler, params) + with PcCollector[T] { + def result(): List[T] = resultAllOccurences().toList +} +abstract class WithSymbolSearchCollector[T]( + compiler: MetalsGlobal, + params: OffsetParams +) extends WithCompilationUnit(compiler, params) + with PcCollector[T] + with PcSymbolSearch { + def result(): List[T] = soughtSymbols + .map { case (sought, _) => resultWithSought(sought) } + .getOrElse(Nil) } diff --git a/mtags/src/main/scala-2/scala/meta/internal/pc/PcDocumentHighlightProvider.scala b/mtags/src/main/scala-2/scala/meta/internal/pc/PcDocumentHighlightProvider.scala index ad23b9c1dc7..9575fa2eb19 100644 --- a/mtags/src/main/scala-2/scala/meta/internal/pc/PcDocumentHighlightProvider.scala +++ b/mtags/src/main/scala-2/scala/meta/internal/pc/PcDocumentHighlightProvider.scala @@ -9,7 +9,7 @@ import org.eclipse.lsp4j.DocumentHighlightKind final class PcDocumentHighlightProvider( override val compiler: MetalsGlobal, params: OffsetParams -) extends PcCollector[DocumentHighlight](compiler, params) { +) extends WithSymbolSearchCollector[DocumentHighlight](compiler, params) { import compiler._ def collect( diff --git a/mtags/src/main/scala-2/scala/meta/internal/pc/PcInlineValueProviderImpl.scala b/mtags/src/main/scala-2/scala/meta/internal/pc/PcInlineValueProviderImpl.scala index f29a8397ae6..cc3fce542c0 100644 --- a/mtags/src/main/scala-2/scala/meta/internal/pc/PcInlineValueProviderImpl.scala +++ b/mtags/src/main/scala-2/scala/meta/internal/pc/PcInlineValueProviderImpl.scala @@ -15,8 +15,8 @@ final class PcInlineValueProviderImpl( ) extends InlineValueProvider { import cp._ - val pcCollector: PcCollector[Occurence] = - new PcCollector[Occurence](cp, params) { + val pcCollector: WithSymbolSearchCollector[Occurence] = + new WithSymbolSearchCollector[Occurence](cp, params) { def collect( parent: Option[compiler.Tree] )( diff --git a/mtags/src/main/scala-2/scala/meta/internal/pc/PcReferencesProvider.scala b/mtags/src/main/scala-2/scala/meta/internal/pc/PcReferencesProvider.scala new file mode 100644 index 00000000000..f4b55502f29 --- /dev/null +++ b/mtags/src/main/scala-2/scala/meta/internal/pc/PcReferencesProvider.scala @@ -0,0 +1,43 @@ +package scala.meta.internal.pc + +import scala.meta.internal.mtags.MtagsEnrichments._ +import scala.meta.pc.OffsetParams +import scala.meta.pc.VirtualFileParams + +import org.eclipse.{lsp4j => l} + +class PcReferencesProvider( + compiler: MetalsGlobal, + params: OffsetParams, + targetFiles: List[VirtualFileParams], + includeDefinition: Boolean +) extends WithCompilationUnit(compiler, params) + with PcSymbolSearch { + def result(): List[l.Location] = + for { + (sought, _) <- soughtSymbols.toList + params <- targetFiles + collected <- { + val collector = new WithCompilationUnit(compiler, params) + with PcCollector[Option[l.Range]] { + import compiler._ + override def collect(parent: Option[Tree])( + tree: Tree, + toAdjust: Position, + sym: Option[compiler.Symbol] + ): Option[l.Range] = { + val (pos, _) = toAdjust.adjust(text) + tree match { + case _: DefTree if !includeDefinition => None + case _ => Some(pos.toLsp) + } + } + } + + collector + .resultWithSought(sought.asInstanceOf[Set[collector.compiler.Symbol]]) + .flatten + .map(new l.Location(params.uri().toString(), _)) + } + } yield collected +} diff --git a/mtags/src/main/scala-2/scala/meta/internal/pc/PcRenameProvider.scala b/mtags/src/main/scala-2/scala/meta/internal/pc/PcRenameProvider.scala index 0cef6a61ce1..0b88de0fedb 100644 --- a/mtags/src/main/scala-2/scala/meta/internal/pc/PcRenameProvider.scala +++ b/mtags/src/main/scala-2/scala/meta/internal/pc/PcRenameProvider.scala @@ -9,7 +9,7 @@ class PcRenameProvider( override val compiler: MetalsGlobal, params: OffsetParams, name: Option[String] -) extends PcCollector[l.TextEdit](compiler, params) { +) extends WithSymbolSearchCollector[l.TextEdit](compiler, params) { import compiler._ private val forbiddenMethods = Set("equals", "hashCode", "unapply", "unary_!", "!") diff --git a/mtags/src/main/scala-2/scala/meta/internal/pc/PcSemanticTokensProvider.scala b/mtags/src/main/scala-2/scala/meta/internal/pc/PcSemanticTokensProvider.scala index 75179a44975..82c91c24b0a 100644 --- a/mtags/src/main/scala-2/scala/meta/internal/pc/PcSemanticTokensProvider.scala +++ b/mtags/src/main/scala-2/scala/meta/internal/pc/PcSemanticTokensProvider.scala @@ -16,7 +16,7 @@ final class PcSemanticTokensProvider( val params: VirtualFileParams ) { // Initialize Tree - object Collector extends PcCollector[Option[Node]](cp, params) { + object Collector extends SimpleCollector[Option[Node]](cp, params) { /** * Declaration is set for: diff --git a/mtags/src/main/scala-2/scala/meta/internal/pc/PcSymbolSearch.scala b/mtags/src/main/scala-2/scala/meta/internal/pc/PcSymbolSearch.scala new file mode 100644 index 00000000000..62b6655e955 --- /dev/null +++ b/mtags/src/main/scala-2/scala/meta/internal/pc/PcSymbolSearch.scala @@ -0,0 +1,145 @@ +package scala.meta.internal.pc + +import scala.reflect.internal.util.RangePosition + +trait PcSymbolSearch { self: WithCompilationUnit => + import compiler._ + + private val caseClassSynthetics: Set[Name] = Set(nme.apply, nme.copy) + + lazy val typedTree: Tree = locateTree(pos) match { + // Check actual object if apply is synthetic + case sel @ Select(qual, name) if name == nme.apply && qual.pos == sel.pos => + qual + case Import(expr, _) if expr.pos.includes(pos) => + // imports seem to be marked as transparent + locateTree(pos, expr, acceptTransparent = true) + case t => t + } + + // First identify the symbol we are at, comments identify @@ as current cursor position + lazy val soughtSymbols: Option[(Set[Symbol], Position)] = typedTree match { + /* simple identifier: + * val a = val@@ue + value + */ + case (id: Ident) => + // might happen in type trees + // also this doesn't seem to be picked up by semanticdb + if (id.symbol == NoSymbol) + fallbackSymbol(id.name, pos).map(sym => + (symbolAlternatives(sym), id.pos) + ) + else { + Some(symbolAlternatives(id.symbol), id.pos) + } + /* Anonynous function parameters such as: + * List(1).map{ <>: Int => abc} + * In this case, parameter has incorrect namePosition, so we need to handle it separately + */ + case (vd: ValDef) if isAnonFunctionParam(vd) => + val namePos = vd.pos.withEnd(vd.pos.start + vd.name.length) + if (namePos.includes(pos)) Some(symbolAlternatives(vd.symbol), namePos) + else None + + /* all definitions: + * def fo@@o = ??? + * class Fo@@o = ??? + * etc. + */ + case (df: DefTree) if df.namePosition.includes(pos) => + Some(symbolAlternatives(df.symbol), df.namePosition) + /* Import selectors: + * import scala.util.Tr@@y + */ + case (imp: Import) if imp.pos.includes(pos) => + imp.selector(pos).map(sym => (symbolAlternatives(sym), sym.pos)) + /* simple selector: + * object.val@@ue + */ + case (sel: NameTree) if sel.namePosition.includes(pos) => + Some(symbolAlternatives(sel.symbol), sel.namePosition) + + // needed for classOf[AB@@C]` + case lit @ Literal(Constant(TypeRef(_, sym, _))) if lit.pos.includes(pos) => + val posStart = text.indexOfSlice(sym.decodedName, lit.pos.start) + if (posStart == -1) None + else + Some( + ( + symbolAlternatives(sym), + new RangePosition( + lit.pos.source, + posStart, + lit.pos.point, + posStart + sym.decodedName.length + ) + ) + ) + /* named argument, which is a bit complex: + * foo(nam@@e = "123") + */ + case _ => + val apply = typedTree match { + case (apply: Apply) => Some(apply) + /** + * For methods with multiple parameter lists and default args, the tree looks like this: + * Block(List(val x&1, val x&2, ...), Apply(<>, List(x&1, x&2, ...))) + */ + case _ => + typedTree.children.collectFirst { + case Block(_, apply: Apply) if apply.pos.includes(pos) => + apply + } + } + apply + .collect { + case apply if apply.symbol != null => + collectArguments(apply) + .flatMap(arg => namedArgCache.find(_._1.includes(arg.pos))) + .collectFirst { + case (_, AssignOrNamedArg(id: Ident, _)) + if id.pos.includes(pos) => + apply.symbol.paramss.flatten.find(_.name == id.name).map { + s => + // if it's a case class we need to look for parameters also + if ( + caseClassSynthetics(s.owner.name) && s.owner.isSynthetic + ) { + val applyOwner = s.owner.owner + val constructorOwner = + if (applyOwner.isCaseClass) applyOwner + else + applyOwner.companion match { + case NoSymbol => + applyOwner + .localCompanion(pos) + .getOrElse(NoSymbol) + case comp => comp + } + val info = constructorOwner.info + val constructorParams = info.members + .filter(_.isConstructor) + .flatMap(_.paramss) + .flatten + .filter(_.name == id.name) + .toSet + ( + (constructorParams ++ Set( + s, + info.member(s.getterName), + info.member(s.setterName), + info.member(s.localName) + )).filter(_ != NoSymbol), + id.pos + ) + } else (Set(s), id.pos) + } + } + .flatten + } + .getOrElse(None) + } + + private def isAnonFunctionParam(vd: ValDef): Boolean = + vd.symbol != null && vd.symbol.owner.isAnonymousFunction && vd.rhs.isEmpty +} diff --git a/mtags/src/main/scala-2/scala/meta/internal/pc/ScalaPresentationCompiler.scala b/mtags/src/main/scala-2/scala/meta/internal/pc/ScalaPresentationCompiler.scala index 88e09c157d2..ef7e1f47933 100644 --- a/mtags/src/main/scala-2/scala/meta/internal/pc/ScalaPresentationCompiler.scala +++ b/mtags/src/main/scala-2/scala/meta/internal/pc/ScalaPresentationCompiler.scala @@ -45,6 +45,7 @@ import org.eclipse.lsp4j.CompletionItem import org.eclipse.lsp4j.CompletionList import org.eclipse.lsp4j.Diagnostic import org.eclipse.lsp4j.DocumentHighlight +import org.eclipse.lsp4j.Location import org.eclipse.lsp4j.Range import org.eclipse.lsp4j.SelectionRange import org.eclipse.lsp4j.SignatureHelp @@ -358,6 +359,23 @@ case class ScalaPresentationCompiler( new PcDocumentHighlightProvider(pc.compiler(), params).highlights().asJava } + override def references( + params: OffsetParams, + targetFiles: ju.List[VirtualFileParams], + includeDefinition: Boolean + ): CompletableFuture[ju.List[Location]] = + compilerAccess.withInterruptableCompiler(Some(params))( + List.empty[Location].asJava, + params.token() + ) { pc => + new PcReferencesProvider( + pc.compiler(), + params, + targetFiles.asScala.toList, + includeDefinition + ).result().asJava + } + override def semanticdbTextDocument( fileUri: URI, code: String diff --git a/mtags/src/main/scala-2/scala/meta/internal/pc/WithCompilationUnit.scala b/mtags/src/main/scala-2/scala/meta/internal/pc/WithCompilationUnit.scala new file mode 100644 index 00000000000..a3b49d0a924 --- /dev/null +++ b/mtags/src/main/scala-2/scala/meta/internal/pc/WithCompilationUnit.scala @@ -0,0 +1,84 @@ +package scala.meta.internal.pc + +import scala.meta.pc.OffsetParams +import scala.meta.pc.VirtualFileParams + +class WithCompilationUnit( + val compiler: MetalsGlobal, + val params: VirtualFileParams +) { + import compiler._ + val unit: RichCompilationUnit = addCompilationUnit( + code = params.text(), + filename = params.uri().toString(), + cursor = None + ) + val offset: Int = params match { + case p: OffsetParams => p.offset() + case _: VirtualFileParams => 0 + } + val pos: Position = unit.position(offset) + lazy val text = unit.source.content + typeCheck(unit) + + protected lazy val namedArgCache: Map[Position, Tree] = { + val parsedTree = parseTree(unit.source) + parsedTree.collect { case arg @ AssignOrNamedArg(_, rhs) => + rhs.pos -> arg + }.toMap + } + + /** + * Find all symbols that should be shown together. + * For example class we want to also show companion object. + * + * @param sym symbol to find the alternative candidates for + * @return set of possible symbols + */ + def symbolAlternatives(sym: Symbol): Set[Symbol] = { + val all = + if (sym.isClass) { + if (sym.owner.isMethod) Set(sym) ++ sym.localCompanion(pos) + else Set(sym, sym.companionModule, sym.companion.moduleClass) + } else if (sym.isModuleOrModuleClass) { + if (sym.owner.isMethod) Set(sym) ++ sym.localCompanion(pos) + else Set(sym, sym.companionClass, sym.moduleClass) + } else if (sym.isTerm && (sym.owner.isClass || sym.owner.isConstructor)) { + val info = + if (sym.owner.isClass) sym.owner.info + else sym.owner.owner.info + Set( + sym, + info.member(sym.getterName), + info.member(sym.setterName), + info.member(sym.localName) + ) ++ constructorParam(sym) ++ sym.allOverriddenSymbols.toSet + } else Set(sym) + all.filter(s => s != NoSymbol && !s.isError) + } + + private def constructorParam( + symbol: Symbol + ): Set[Symbol] = { + if (symbol.owner.isClass) { + val info = symbol.owner.info.member(nme.CONSTRUCTOR).info + info.paramss.flatten.find(_.name == symbol.name).toSet + } else Set.empty + } + + protected def collectArguments(apply: Apply): List[Tree] = { + apply.fun match { + case appl: Apply => collectArguments(appl) ++ apply.args + case _ => apply.args + } + } + + def fallbackSymbol(name: Name, pos: Position): Option[Symbol] = { + val context = doLocateImportContext(pos) + context.lookupSymbol(name, sym => sym.isType) match { + case LookupSucceeded(_, symbol) => + Some(symbol) + case _ => None + } + } +} diff --git a/mtags/src/main/scala-3-wrapper/ScalaPresentationCompiler.scala b/mtags/src/main/scala-3-wrapper/ScalaPresentationCompiler.scala index b24fc22255b..babddb017a9 100644 --- a/mtags/src/main/scala-3-wrapper/ScalaPresentationCompiler.scala +++ b/mtags/src/main/scala-3-wrapper/ScalaPresentationCompiler.scala @@ -7,11 +7,13 @@ import java.util.concurrent.ExecutorService import java.util.concurrent.ScheduledExecutorService import java.{util as ju} +import scala.compat.java8.FutureConverters import scala.concurrent.ExecutionContext import scala.concurrent.ExecutionContextExecutor import scala.jdk.CollectionConverters.* import scala.meta.internal.metals.ReportLevel +import scala.meta.internal.mtags.CommonMtagsEnrichments.* import scala.meta.pc.AutoImportsResult import scala.meta.pc.DefinitionResult import scala.meta.pc.HoverSignature @@ -30,6 +32,8 @@ import org.eclipse.lsp4j.CompletionItem import org.eclipse.lsp4j.CompletionList import org.eclipse.lsp4j.Diagnostic import org.eclipse.lsp4j.DocumentHighlight +import org.eclipse.lsp4j.DocumentHighlightKind +import org.eclipse.lsp4j.Location import org.eclipse.lsp4j.SelectionRange import org.eclipse.lsp4j.SignatureHelp import org.eclipse.lsp4j.TextEdit @@ -69,6 +73,8 @@ case class ScalaPresentationCompiler( reportsLevel = reportsLevel, ) + given ExecutionContext = ec + def this() = this("", None, Nil, Nil) override def syntheticDecorations( @@ -172,6 +178,34 @@ case class ScalaPresentationCompiler( ): CompletableFuture[ju.List[DocumentHighlight]] = underlying.documentHighlight(params) + override def references( + params: OffsetParams, + targetFiles: ju.List[VirtualFileParams], + includeDefinition: Boolean, + ): CompletableFuture[ju.List[Location]] = + targetFiles.asScala.toList match + case file :: Nil if file.uri() == params.uri() => + FutureConverters + .toJava( + FutureConverters + .toScala(documentHighlight(params)) + .map( + _.asScala + .collect { + case highlight + if highlight.getKind() == DocumentHighlightKind.Read || includeDefinition => + new Location( + params.uri().toString(), + highlight.getRange(), + ) + } + .asJava + ) + ) + .toCompletableFuture + case _ => + CompletableFuture.completedFuture(Nil.asJava) + override def rename( params: OffsetParams, name: String, diff --git a/mtags/src/main/scala-3/scala/meta/internal/pc/PcCollector.scala b/mtags/src/main/scala-3/scala/meta/internal/pc/PcCollector.scala index 032bf860b08..15ee0fe6c4e 100644 --- a/mtags/src/main/scala-3/scala/meta/internal/pc/PcCollector.scala +++ b/mtags/src/main/scala-3/scala/meta/internal/pc/PcCollector.scala @@ -1,406 +1,80 @@ package scala.meta.internal.pc -import java.nio.file.Paths - import scala.meta as m -import scala.meta.internal.metals.CompilerOffsetParams import scala.meta.internal.mtags.MtagsEnrichments.* -import scala.meta.internal.pc.MetalsInteractive.ExtensionMethodCall import scala.meta.internal.pc.MetalsInteractive.ExtensionMethodCallSymbol +import scala.meta.internal.pc.PcSymbolSearch.* import scala.meta.pc.OffsetParams import scala.meta.pc.VirtualFileParams -import dotty.tools.dotc.ast.NavigateAST -import dotty.tools.dotc.ast.Positioned import dotty.tools.dotc.ast.tpd import dotty.tools.dotc.ast.tpd.* import dotty.tools.dotc.ast.untpd -import dotty.tools.dotc.ast.untpd.ExtMethods import dotty.tools.dotc.ast.untpd.ImportSelector import dotty.tools.dotc.core.Contexts.* import dotty.tools.dotc.core.Flags import dotty.tools.dotc.core.NameOps.* import dotty.tools.dotc.core.Names.* -import dotty.tools.dotc.core.StdNames.* import dotty.tools.dotc.core.Symbols.* import dotty.tools.dotc.core.Types.* -import dotty.tools.dotc.interactive.Interactive import dotty.tools.dotc.interactive.InteractiveDriver -import dotty.tools.dotc.util.SourceFile import dotty.tools.dotc.util.SourcePosition import dotty.tools.dotc.util.Spans.Span -abstract class PcCollector[T]( - driver: InteractiveDriver, - params: VirtualFileParams, -): - private val caseClassSynthetics: Set[Name] = Set(nme.apply, nme.copy) - val uri = params.uri() - val filePath = Paths.get(uri) - val sourceText = params.text - val text = sourceText.toCharArray() - val source = - SourceFile.virtual(filePath.toString, sourceText) - driver.run(uri, source) - given ctx: Context = driver.currentCtx - - val unit = driver.currentCtx.run.units.head - val compilatonUnitContext = ctx.fresh.setCompilationUnit(unit) - val offset = params match - case op: OffsetParams => op.offset() - case _ => 0 - val offsetParams = - params match - case op: OffsetParams => op - case _ => - CompilerOffsetParams(params.uri(), params.text(), 0, params.token()) - val pos = driver.sourcePosition(offsetParams) - val rawPath = - Interactive - .pathTo(driver.openedTrees(uri), pos)(using driver.currentCtx) - .dropWhile(t => // NamedArg anyway doesn't have symbol - t.symbol == NoSymbol && !t.isInstanceOf[NamedArg] || - // same issue https://github.com/lampepfl/dotty/issues/15937 as below - t.isInstanceOf[TypeTree] - ) - - val path = rawPath match - // For type it will sometimes go into the wrong tree since TypeTree also contains the same span - // https://github.com/lampepfl/dotty/issues/15937 - case TypeApply(sel: Select, _) :: tail if sel.span.contains(pos.span) => - Interactive.pathTo(sel, pos.span) ::: rawPath - case _ => rawPath +trait PcCollector[T]: + self: WithCompilationUnit => def collect( parent: Option[Tree] )(tree: Tree | EndMarker, pos: SourcePosition, symbol: Option[Symbol]): T - def symbolAlternatives(sym: Symbol) = - def member(parent: Symbol) = parent.info.member(sym.name).symbol - def primaryConstructorTypeParam(owner: Symbol) = - for - typeParams <- owner.primaryConstructor.paramSymss.headOption - param <- typeParams.find(_.name == sym.name) - if (param.isType) - yield param - def additionalForEnumTypeParam(enumClass: Symbol) = - if enumClass.is(Flags.Enum) then - val enumOwner = - if enumClass.is(Flags.Case) - then - // we check that the type parameter is the one from enum class - // and not an enum case type parameter with the same name - Option.when(member(enumClass).is(Flags.Synthetic))( - enumClass.maybeOwner.companionClass - ) - else Some(enumClass) - enumOwner.toSet.flatMap { enumOwner => - val symsInEnumCases = enumOwner.children.toSet.flatMap(enumCase => - if member(enumCase).is(Flags.Synthetic) - then primaryConstructorTypeParam(enumCase) - else None - ) - val symsInEnumOwner = - primaryConstructorTypeParam(enumOwner).toSet + member(enumOwner) - symsInEnumCases ++ symsInEnumOwner - } - else Set.empty - val all = - if sym.is(Flags.ModuleClass) then - Set(sym, sym.companionModule, sym.companionModule.companion) - else if sym.isClass then - Set(sym, sym.companionModule, sym.companion.moduleClass) - else if sym.is(Flags.Module) then - Set(sym, sym.companionClass, sym.moduleClass) - else if sym.isTerm && (sym.owner.isClass || sym.owner.isConstructor) - then - val info = - if sym.owner.isClass then sym.owner.info else sym.owner.owner.info - Set( - sym, - info.member(sym.asTerm.name.setterName).symbol, - info.member(sym.asTerm.name.getterName).symbol, - ) ++ sym.allOverriddenSymbols.toSet - // type used in primary constructor will not match the one used in the class - else if sym.isTypeParam && sym.owner.isPrimaryConstructor then - Set(sym, member(sym.maybeOwner.maybeOwner)) - ++ additionalForEnumTypeParam(sym.maybeOwner.maybeOwner) - else if sym.isTypeParam then - primaryConstructorTypeParam(sym.maybeOwner).toSet - ++ additionalForEnumTypeParam(sym.maybeOwner) + sym - else Set(sym) - all.filter(s => s != NoSymbol && !s.isError) - end symbolAlternatives - - private def isGeneratedGiven(df: NamedDefTree)(using Context) = - val nameSpan = df.nameSpan - df.symbol.is(Flags.Given) && sourceText.substring( - nameSpan.start, - nameSpan.end, - ) != df.name.toString() - - // First identify the symbol we are at, comments identify @@ as current cursor position - def soughtSymbols(path: List[Tree]): Option[(Set[Symbol], SourcePosition)] = - val sought = path match - /* reference of an extension paramter - * extension [EF](<>: List[EF]) - * def double(ys: List[EF]) = <> ++ ys - */ - case (id: Ident) :: _ - if id.symbol - .is(Flags.Param) && id.symbol.owner.is(Flags.ExtensionMethod) => - Some(findAllExtensionParamSymbols(id.sourcePos, id.name, id.symbol)) - /** - * Workaround for missing symbol in: - * class A[T](a: T) - * val x = new <>(1) - */ - case t :: (n: New) :: (sel: Select) :: _ - if t.symbol == NoSymbol && sel.symbol.isConstructor => - Some(symbolAlternatives(sel.symbol.owner), namePos(t)) - /** - * Workaround for missing symbol in: - * class A[T](a: T) - * val x = <>[Int](1) - */ - case (sel @ Select(New(t), _)) :: (_: TypeApply) :: _ - if sel.symbol.isConstructor => - Some(symbolAlternatives(sel.symbol.owner), namePos(t)) - - /* simple identifier: - * val a = val@@ue + value - */ - case (id: Ident) :: _ => - Some(symbolAlternatives(id.symbol), id.sourcePos) - /* simple selector: - * object.val@@ue - */ - case (sel: Select) :: _ if selectNameSpan(sel).contains(pos.span) => - Some(symbolAlternatives(sel.symbol), pos.withSpan(sel.nameSpan)) - /* named argument: - * foo(nam@@e = "123") - */ - case (arg: NamedArg) :: (appl: Apply) :: _ => - val realName = arg.name.stripModuleClassSuffix.lastPart - if pos.span.start > arg.span.start && pos.span.end < arg.span.point + realName.length - then - val length = realName.toString.backticked.length() - val pos = arg.sourcePos.withSpan( - arg.span - .withEnd(arg.span.start + length) - .withPoint(arg.span.start) - ) - appl.symbol.paramSymss.flatten.find(_.name == arg.name).map { s => - // if it's a case class we need to look for parameters also - if caseClassSynthetics(s.owner.name) && s.owner.is(Flags.Synthetic) - then - ( - Set( - s, - s.owner.owner.companion.info.member(s.name).symbol, - s.owner.owner.info.member(s.name).symbol, - ) - .filter(_ != NoSymbol), - pos, - ) - else (Set(s), pos) - } - else None - end if - /* all definitions: - * def fo@@o = ??? - * class Fo@@o = ??? - * etc. - */ - case (df: NamedDefTree) :: _ - if df.nameSpan.contains(pos.span) && !isGeneratedGiven(df) => - Some(symbolAlternatives(df.symbol), pos.withSpan(df.nameSpan)) - /* enum cases with params - * enum Foo: - * case B@@ar[A](i: A) - */ - case (df: NamedDefTree) :: Template(_, _, self, _) :: _ - if (df.name == nme.apply || df.name == nme.unapply) && df.nameSpan.isZeroExtent => - Some(symbolAlternatives(self.tpt.symbol), self.sourcePos) - /** - * For traversing annotations: - * @JsonNo@@tification("") - * def params() = ??? - */ - case (df: MemberDef) :: _ if df.span.contains(pos.span) => - val annotTree = df.mods.annotations.find { t => - t.span.contains(pos.span) - } - collectTrees(annotTree).flatMap { t => - soughtSymbols( - Interactive.pathTo(t, pos.span) - ) - }.headOption - - /* Import selectors: - * import scala.util.Tr@@y - */ - case (imp: Import) :: _ if imp.span.contains(pos.span) => - imp - .selector(pos.span) - .map(sym => (symbolAlternatives(sym), sym.sourcePos)) - - /* Workaround for missing span in: - * class MyIntOut(val value: Int) - * object MyIntOut: - * extension (i: MyIntOut) def <> = i.value % 2 == 1 - * - * val a = MyIntOut(1).<> - */ - case ExtensionMethodCall(sym, app) :: _ - if app.span.withStart(app.span.point).contains(pos.span) => - val span = app.span.withStart(app.span.point) - Some(symbolAlternatives(sym), pos.withSpan(span)) - case _ => None - - sought match - case None => seekInExtensionParameters() - case _ => sought - - end soughtSymbols - - lazy val extensionMethods = - NavigateAST - .untypedPath(pos.span)(using compilatonUnitContext) - .collectFirst { case em @ ExtMethods(_, _) => em } - - private def findAllExtensionParamSymbols( - pos: SourcePosition, - name: Name, - sym: Symbol, - ) = - val symbols = - for - methods <- extensionMethods.map(_.methods) - symbols <- collectAllExtensionParamSymbols( - unit.tpdTree, - ExtensionParamOccurence(name, pos, sym, methods), - ) - yield symbols - symbols.getOrElse((symbolAlternatives(sym), pos)) - end findAllExtensionParamSymbols - - private def seekInExtensionParameters() = - def collectParams( - extMethods: ExtMethods - ): Option[ExtensionParamOccurence] = - MetalsNavigateAST - .pathToExtensionParam(pos.span, extMethods)(using compilatonUnitContext) - .collectFirst { - case v: untpd.ValOrTypeDef => - ExtensionParamOccurence( - v.name, - v.namePos, - v.symbol, - extMethods.methods, - ) - case i: untpd.Ident => - ExtensionParamOccurence( - i.name, - i.sourcePos, - i.symbol, - extMethods.methods, - ) - } - - for - extensionMethodScope <- extensionMethods - occurence <- collectParams(extensionMethodScope) - symbols <- collectAllExtensionParamSymbols( - path.headOption.getOrElse(unit.tpdTree), - occurence, - ) - yield symbols - end seekInExtensionParameters - - private def collectAllExtensionParamSymbols( - tree: tpd.Tree, - occurrence: ExtensionParamOccurence, - ): Option[(Set[Symbol], SourcePosition)] = - occurrence match - case ExtensionParamOccurence(_, namePos, symbol, _) - if symbol != NoSymbol && !symbol.isError && !symbol.owner.is( - Flags.ExtensionMethod - ) => - Some((symbolAlternatives(symbol), namePos)) - case ExtensionParamOccurence(name, namePos, _, methods) => - val symbols = - for - method <- methods.toSet - symbol <- - Interactive.pathTo(tree, method.span) match - case (d: DefDef) :: _ => - d.paramss.flatten.collect { - case param if param.name.decoded == name.decoded => - param.symbol - } - case _ => Set.empty[Symbol] - if (symbol != NoSymbol && !symbol.isError) - withAlt <- symbolAlternatives(symbol) - yield withAlt - if symbols.nonEmpty then Some((symbols, namePos)) else None - end collectAllExtensionParamSymbols - - def result(): List[T] = - params match - case _: OffsetParams => resultWithSought() - case _ => resultAllOccurences().toList - def resultAllOccurences(): Set[T] = def noTreeFilter = (_: Tree) => true def noSoughtFilter = (_: Symbol => Boolean) => true traverseSought(noTreeFilter, noSoughtFilter) - def resultWithSought(): List[T] = - soughtSymbols(path) match - case Some((sought, _)) => - lazy val owners = sought - .flatMap { s => Set(s.owner, s.owner.companionModule) } - .filter(_ != NoSymbol) - lazy val soughtNames: Set[Name] = sought.map(_.name) - - /* - * For comprehensions have two owners, one for the enumerators and one for - * yield. This is a heuristic to find that out. - */ - def isForComprehensionOwner(named: NameTree) = - soughtNames(named.name) && - scala.util - .Try(named.symbol.owner) - .toOption - .exists(_.isAnonymousFunction) && - owners.exists(o => - o.span.exists && o.span.point == named.symbol.owner.span.point - ) - - def soughtOrOverride(sym: Symbol) = - sought(sym) || sym.allOverriddenSymbols.exists(sought(_)) + def resultWithSought(sought: Set[Symbol]): List[T] = + lazy val owners = sought + .flatMap { s => Set(s.owner, s.owner.companionModule) } + .filter(_ != NoSymbol) + lazy val soughtNames: Set[Name] = sought.map(_.name) + + /* + * For comprehensions have two owners, one for the enumerators and one for + * yield. This is a heuristic to find that out. + */ + def isForComprehensionOwner(named: NameTree) = + soughtNames(named.name) && + scala.util + .Try(named.symbol.owner) + .toOption + .exists(_.isAnonymousFunction) && + owners.exists(o => + o.span.exists && o.span.point == named.symbol.owner.span.point + ) - def soughtTreeFilter(tree: Tree): Boolean = - tree match - case ident: Ident - if soughtOrOverride(ident.symbol) || - isForComprehensionOwner(ident) => - true - case sel: Select if soughtOrOverride(sel.symbol) => true - case df: NamedDefTree - if soughtOrOverride(df.symbol) && !df.symbol.isSetter => - true - case imp: Import if owners(imp.expr.symbol) => true - case _ => false + def soughtOrOverride(sym: Symbol) = + sought(sym) || sym.allOverriddenSymbols.exists(sought(_)) - def soughtFilter(f: Symbol => Boolean): Boolean = - sought.exists(f) + def soughtTreeFilter(tree: Tree): Boolean = + tree match + case ident: Ident + if soughtOrOverride(ident.symbol) || + isForComprehensionOwner(ident) => + true + case sel: Select if soughtOrOverride(sel.symbol) => true + case df: NamedDefTree + if soughtOrOverride(df.symbol) && !df.symbol.isSetter => + true + case imp: Import if owners(imp.expr.symbol) => true + case _ => false - traverseSought(soughtTreeFilter, soughtFilter).toList + def soughtFilter(f: Symbol => Boolean): Boolean = + sought.exists(f) - case None => Nil + traverseSought(soughtTreeFilter, soughtFilter).toList + end resultWithSought extension (span: Span) def isCorrect = @@ -493,7 +167,7 @@ abstract class PcCollector[T]( */ case df: NamedDefTree if df.span.isCorrect && df.nameSpan.isCorrect && - filter(df) && !isGeneratedGiven(df) => + filter(df) && !isGeneratedGiven(df, sourceText) => def collectEndMarker = EndMarker.getPosition(df, pos, sourceText).map { collect(EndMarker(df.symbol), _) @@ -615,36 +289,9 @@ abstract class PcCollector[T]( val traverser = new PcCollector.DeepFolderWithParent[Set[T]](collectNamesWithParent) - val all = traverser(Set.empty[T], unit.tpdTree) - all + traverser(Set.empty[T], unit.tpdTree) end traverseSought - // @note (tgodzik) Not sure currently how to get rid of the warning, but looks to correctly - // @nowarn - private def collectTrees(trees: Iterable[Positioned]): Iterable[Tree] = - trees.collect { case t: Tree => - t - } - - // NOTE: Connected to https://github.com/lampepfl/dotty/issues/16771 - // `sel.nameSpan` is calculated incorrectly in (1 + 2).toString - // See test DocumentHighlightSuite.select-parentheses - private def selectNameSpan(sel: Select): Span = - val span = sel.span - if span.exists then - val point = span.point - if sel.name.toTermName == nme.ERROR then Span(point) - else if sel.qualifier.span.start > span.point then // right associative - val realName = sel.name.stripModuleClassSuffix.lastPart - Span(span.start, span.start + realName.length, point) - else Span(point, span.end, point) - else span - - private def namePos(tree: Tree): SourcePosition = - tree match - case sel: Select => sel.sourcePos.withSpan(selectNameSpan(sel)) - case _ => tree.sourcePos - /** * Those have wrong spans and we special case for them. */ @@ -653,7 +300,6 @@ abstract class PcCollector[T]( case _: TypeApply | _: Apply => true case _ => false } - end PcCollector object PcCollector: @@ -709,3 +355,21 @@ object EndMarker: ) end getPosition end EndMarker + +abstract class WithSymbolSearchCollector[T]( + driver: InteractiveDriver, + params: OffsetParams, +) extends WithCompilationUnit(driver, params) + with PcSymbolSearch + with PcCollector[T]: + def result(): List[T] = + soughtSymbols.toList.flatMap { case (sought, _) => + resultWithSought(sought) + } + +abstract class SimpleCollector[T]( + driver: InteractiveDriver, + params: VirtualFileParams, +) extends WithCompilationUnit(driver, params) + with PcCollector[T]: + def result(): List[T] = resultAllOccurences().toList diff --git a/mtags/src/main/scala-3/scala/meta/internal/pc/PcDocumentHighlightProvider.scala b/mtags/src/main/scala-3/scala/meta/internal/pc/PcDocumentHighlightProvider.scala index 72558934bfd..fbb4d44a0eb 100644 --- a/mtags/src/main/scala-3/scala/meta/internal/pc/PcDocumentHighlightProvider.scala +++ b/mtags/src/main/scala-3/scala/meta/internal/pc/PcDocumentHighlightProvider.scala @@ -13,7 +13,7 @@ import org.eclipse.lsp4j.DocumentHighlightKind final class PcDocumentHighlightProvider( driver: InteractiveDriver, params: OffsetParams, -) extends PcCollector[DocumentHighlight](driver, params): +) extends WithSymbolSearchCollector[DocumentHighlight](driver, params): def collect( parent: Option[Tree] @@ -30,4 +30,5 @@ final class PcDocumentHighlightProvider( def highlights: List[DocumentHighlight] = result().distinctBy(_.getRange()) + end PcDocumentHighlightProvider diff --git a/mtags/src/main/scala-3/scala/meta/internal/pc/PcInlineValueProviderImpl.scala b/mtags/src/main/scala-3/scala/meta/internal/pc/PcInlineValueProviderImpl.scala index 5cabd099253..1b36441bb20 100644 --- a/mtags/src/main/scala-3/scala/meta/internal/pc/PcInlineValueProviderImpl.scala +++ b/mtags/src/main/scala-3/scala/meta/internal/pc/PcInlineValueProviderImpl.scala @@ -17,9 +17,9 @@ import dotty.tools.dotc.util.SourcePosition import org.eclipse.{lsp4j as l} final class PcInlineValueProviderImpl( - val driver: InteractiveDriver, + driver: InteractiveDriver, val params: OffsetParams, -) extends PcCollector[Option[Occurence]](driver, params) +) extends WithSymbolSearchCollector[Option[Occurence]](driver, params) with InlineValueProvider: val position: l.Position = pos.toLsp.getStart() diff --git a/mtags/src/main/scala-3/scala/meta/internal/pc/PcReferencesProvider.scala b/mtags/src/main/scala-3/scala/meta/internal/pc/PcReferencesProvider.scala new file mode 100644 index 00000000000..01fca8e1d90 --- /dev/null +++ b/mtags/src/main/scala-3/scala/meta/internal/pc/PcReferencesProvider.scala @@ -0,0 +1,45 @@ +package scala.meta.internal.pc + +import scala.meta.internal.mtags.MtagsEnrichments.* +import scala.meta.pc.OffsetParams +import scala.meta.pc.VirtualFileParams + +import dotty.tools.dotc.ast.tpd +import dotty.tools.dotc.ast.tpd.* +import dotty.tools.dotc.core.Symbols.Symbol +import dotty.tools.dotc.interactive.InteractiveDriver +import dotty.tools.dotc.util.SourcePosition +import org.eclipse.lsp4j +import org.eclipse.lsp4j.Location + +class PcReferencesProvider( + driver: InteractiveDriver, + params: OffsetParams, + targetFiles: List[VirtualFileParams], + includeDefinition: Boolean, +): + val symbolSearch = new WithCompilationUnit(driver, params) with PcSymbolSearch + def result(): List[Location] = + for + (sought, _) <- symbolSearch.soughtSymbols.toList + params <- targetFiles + collected <- collectForFile(sought, params) + yield collected + + private def collectForFile(sought: Set[Symbol], params: VirtualFileParams) = + new WithCompilationUnit(driver, params) + with PcCollector[Option[lsp4j.Range]]: + def collect(parent: Option[Tree])( + tree: Tree | EndMarker, + toAdjust: SourcePosition, + symbol: Option[Symbol], + ): Option[lsp4j.Range] = + val (pos, _) = toAdjust.adjust(text) + tree match + case (_: DefTree) if !includeDefinition => None + case _: Tree => Some(pos.toLsp) + case _ => None + .resultWithSought(sought) + .flatten + .map(new Location(params.uri().toString(), _)) +end PcReferencesProvider diff --git a/mtags/src/main/scala-3/scala/meta/internal/pc/PcRenameProvider.scala b/mtags/src/main/scala-3/scala/meta/internal/pc/PcRenameProvider.scala index 61041c68080..2e488449a19 100644 --- a/mtags/src/main/scala-3/scala/meta/internal/pc/PcRenameProvider.scala +++ b/mtags/src/main/scala-3/scala/meta/internal/pc/PcRenameProvider.scala @@ -15,7 +15,7 @@ final class PcRenameProvider( driver: InteractiveDriver, params: OffsetParams, name: Option[String], -) extends PcCollector[l.TextEdit](driver, params): +) extends WithSymbolSearchCollector[l.TextEdit](driver, params): private val forbiddenMethods = Set("equals", "hashCode", "unapply", "unary_!", "!") def canRenameSymbol(sym: Symbol)(using Context): Boolean = @@ -24,7 +24,7 @@ final class PcRenameProvider( || sym.source.path.isWorksheet) def prepareRename(): Option[l.Range] = - soughtSymbols(path).flatMap((symbols, pos) => + soughtSymbols.flatMap((symbols, pos) => if symbols.forall(canRenameSymbol) then Some(pos.toLsp) else None ) @@ -45,9 +45,8 @@ final class PcRenameProvider( ) end collect - def rename( - ): List[l.TextEdit] = - val (symbols, _) = soughtSymbols(path).getOrElse(Set.empty, pos) + def rename(): List[l.TextEdit] = + val (symbols, _) = soughtSymbols.getOrElse(Set.empty, pos) if symbols.nonEmpty && symbols.forall(canRenameSymbol(_)) then val res = result() diff --git a/mtags/src/main/scala-3/scala/meta/internal/pc/PcSemanticTokensProvider.scala b/mtags/src/main/scala-3/scala/meta/internal/pc/PcSemanticTokensProvider.scala index baeae4508a3..d6267e4961f 100644 --- a/mtags/src/main/scala-3/scala/meta/internal/pc/PcSemanticTokensProvider.scala +++ b/mtags/src/main/scala-3/scala/meta/internal/pc/PcSemanticTokensProvider.scala @@ -58,7 +58,7 @@ final class PcSemanticTokensProvider( case _ => !df.rhs.isEmpty case _ => false - object Collector extends PcCollector[Option[Node]](driver, params): + object Collector extends SimpleCollector[Option[Node]](driver, params): override def collect( parent: Option[Tree] )( diff --git a/mtags/src/main/scala-3/scala/meta/internal/pc/PcSymbolSearch.scala b/mtags/src/main/scala-3/scala/meta/internal/pc/PcSymbolSearch.scala new file mode 100644 index 00000000000..48730c9d2d4 --- /dev/null +++ b/mtags/src/main/scala-3/scala/meta/internal/pc/PcSymbolSearch.scala @@ -0,0 +1,346 @@ +package scala.meta.internal.pc + +import scala.meta.internal.mtags.MtagsEnrichments.* +import scala.meta.internal.pc.MetalsInteractive.ExtensionMethodCall +import scala.meta.internal.pc.PcSymbolSearch.* + +import dotty.tools.dotc.ast.NavigateAST +import dotty.tools.dotc.ast.Positioned +import dotty.tools.dotc.ast.tpd +import dotty.tools.dotc.ast.tpd.* +import dotty.tools.dotc.ast.untpd +import dotty.tools.dotc.ast.untpd.ExtMethods +import dotty.tools.dotc.core.Contexts.* +import dotty.tools.dotc.core.Flags +import dotty.tools.dotc.core.NameOps.* +import dotty.tools.dotc.core.Names.* +import dotty.tools.dotc.core.StdNames.* +import dotty.tools.dotc.core.Symbols.* +import dotty.tools.dotc.core.Types.* +import dotty.tools.dotc.interactive.Interactive +import dotty.tools.dotc.util.SourcePosition +import dotty.tools.dotc.util.Spans.Span + +trait PcSymbolSearch: + self: WithCompilationUnit => + + private val caseClassSynthetics: Set[Name] = Set(nme.apply, nme.copy) + + lazy val rawPath = + Interactive + .pathTo(driver.openedTrees(uri), pos)(using driver.currentCtx) + .dropWhile(t => // NamedArg anyway doesn't have symbol + t.symbol == NoSymbol && !t.isInstanceOf[NamedArg] || + // same issue https://github.com/lampepfl/dotty/issues/15937 as below + t.isInstanceOf[TypeTree] + ) + + lazy val extensionMethods = + NavigateAST + .untypedPath(pos.span)(using compilatonUnitContext) + .collectFirst { case em @ ExtMethods(_, _) => em } + + lazy val path = rawPath match + // For type it will sometimes go into the wrong tree since TypeTree also contains the same span + // https://github.com/lampepfl/dotty/issues/15937 + case TypeApply(sel: Select, _) :: tail if sel.span.contains(pos.span) => + Interactive.pathTo(sel, pos.span) ::: rawPath + case _ => rawPath + + lazy val soughtSymbols: Option[(Set[Symbol], SourcePosition)] = soughtSymbols( + path + ) + + def soughtSymbols(path: List[Tree]): Option[(Set[Symbol], SourcePosition)] = + val sought = path match + /* reference of an extension paramter + * extension [EF](<>: List[EF]) + * def double(ys: List[EF]) = <> ++ ys + */ + case (id: Ident) :: _ + if id.symbol + .is(Flags.Param) && id.symbol.owner.is(Flags.ExtensionMethod) => + Some(findAllExtensionParamSymbols(id.sourcePos, id.name, id.symbol)) + /** + * Workaround for missing symbol in: + * class A[T](a: T) + * val x = new <>(1) + */ + case t :: (n: New) :: (sel: Select) :: _ + if t.symbol == NoSymbol && sel.symbol.isConstructor => + Some(symbolAlternatives(sel.symbol.owner), namePos(t)) + /** + * Workaround for missing symbol in: + * class A[T](a: T) + * val x = <>[Int](1) + */ + case (sel @ Select(New(t), _)) :: (_: TypeApply) :: _ + if sel.symbol.isConstructor => + Some(symbolAlternatives(sel.symbol.owner), namePos(t)) + + /* simple identifier: + * val a = val@@ue + value + */ + case (id: Ident) :: _ => + Some(symbolAlternatives(id.symbol), id.sourcePos) + /* simple selector: + * object.val@@ue + */ + case (sel: Select) :: _ if selectNameSpan(sel).contains(pos.span) => + Some(symbolAlternatives(sel.symbol), pos.withSpan(sel.nameSpan)) + /* named argument: + * foo(nam@@e = "123") + */ + case (arg: NamedArg) :: (appl: Apply) :: _ => + val realName = arg.name.stripModuleClassSuffix.lastPart + if pos.span.start > arg.span.start && pos.span.end < arg.span.point + realName.length + then + val length = realName.toString.backticked.length() + val pos = arg.sourcePos.withSpan( + arg.span + .withEnd(arg.span.start + length) + .withPoint(arg.span.start) + ) + appl.symbol.paramSymss.flatten.find(_.name == arg.name).map { s => + // if it's a case class we need to look for parameters also + if caseClassSynthetics(s.owner.name) && s.owner.is(Flags.Synthetic) + then + ( + Set( + s, + s.owner.owner.companion.info.member(s.name).symbol, + s.owner.owner.info.member(s.name).symbol, + ) + .filter(_ != NoSymbol), + pos, + ) + else (Set(s), pos) + } + else None + end if + /* all definitions: + * def fo@@o = ??? + * class Fo@@o = ??? + * etc. + */ + case (df: NamedDefTree) :: _ + if df.nameSpan + .contains(pos.span) && !isGeneratedGiven(df, sourceText) => + Some(symbolAlternatives(df.symbol), pos.withSpan(df.nameSpan)) + /* enum cases with params + * enum Foo: + * case B@@ar[A](i: A) + */ + case (df: NamedDefTree) :: Template(_, _, self, _) :: _ + if (df.name == nme.apply || df.name == nme.unapply) && df.nameSpan.isZeroExtent => + Some(symbolAlternatives(self.tpt.symbol), self.sourcePos) + /** + * For traversing annotations: + * @JsonNo@@tification("") + * def params() = ??? + */ + case (df: MemberDef) :: _ if df.span.contains(pos.span) => + val annotTree = df.mods.annotations.find { t => + t.span.contains(pos.span) + } + collectTrees(annotTree).flatMap { t => + soughtSymbols( + Interactive.pathTo(t, pos.span) + ) + }.headOption + + /* Import selectors: + * import scala.util.Tr@@y + */ + case (imp: Import) :: _ if imp.span.contains(pos.span) => + imp + .selector(pos.span) + .map(sym => (symbolAlternatives(sym), sym.sourcePos)) + + /* Workaround for missing span in: + * class MyIntOut(val value: Int) + * object MyIntOut: + * extension (i: MyIntOut) def <> = i.value % 2 == 1 + * + * val a = MyIntOut(1).<> + */ + case ExtensionMethodCall(sym, app) :: _ + if app.span.withStart(app.span.point).contains(pos.span) => + val span = app.span.withStart(app.span.point) + Some(symbolAlternatives(sym), pos.withSpan(span)) + case _ => None + + sought match + case None => seekInExtensionParameters() + case _ => sought + + end soughtSymbols + + // First identify the symbol we are at, comments identify @@ as current cursor position + def symbolAlternatives(sym: Symbol) = + def member(parent: Symbol) = parent.info.member(sym.name).symbol + def primaryConstructorTypeParam(owner: Symbol) = + for + typeParams <- owner.primaryConstructor.paramSymss.headOption + param <- typeParams.find(_.name == sym.name) + if (param.isType) + yield param + def additionalForEnumTypeParam(enumClass: Symbol) = + if enumClass.is(Flags.Enum) then + val enumOwner = + if enumClass.is(Flags.Case) + then + // we check that the type parameter is the one from enum class + // and not an enum case type parameter with the same name + Option.when(member(enumClass).is(Flags.Synthetic))( + enumClass.maybeOwner.companionClass + ) + else Some(enumClass) + enumOwner.toSet.flatMap { enumOwner => + val symsInEnumCases = enumOwner.children.toSet.flatMap(enumCase => + if member(enumCase).is(Flags.Synthetic) + then primaryConstructorTypeParam(enumCase) + else None + ) + val symsInEnumOwner = + primaryConstructorTypeParam(enumOwner).toSet + member(enumOwner) + symsInEnumCases ++ symsInEnumOwner + } + else Set.empty + val all = + if sym.is(Flags.ModuleClass) then + Set(sym, sym.companionModule, sym.companionModule.companion) + else if sym.isClass then + Set(sym, sym.companionModule, sym.companion.moduleClass) + else if sym.is(Flags.Module) then + Set(sym, sym.companionClass, sym.moduleClass) + else if sym.isTerm && (sym.owner.isClass || sym.owner.isConstructor) + then + val info = + if sym.owner.isClass then sym.owner.info else sym.owner.owner.info + Set( + sym, + info.member(sym.asTerm.name.setterName).symbol, + info.member(sym.asTerm.name.getterName).symbol, + ) ++ sym.allOverriddenSymbols.toSet + // type used in primary constructor will not match the one used in the class + else if sym.isTypeParam && sym.owner.isPrimaryConstructor then + Set(sym, member(sym.maybeOwner.maybeOwner)) + ++ additionalForEnumTypeParam(sym.maybeOwner.maybeOwner) + else if sym.isTypeParam then + primaryConstructorTypeParam(sym.maybeOwner).toSet + ++ additionalForEnumTypeParam(sym.maybeOwner) + sym + else Set(sym) + all.filter(s => s != NoSymbol && !s.isError) + end symbolAlternatives + + private def seekInExtensionParameters() = + def collectParams( + extMethods: ExtMethods + ): Option[ExtensionParamOccurence] = + MetalsNavigateAST + .pathToExtensionParam(pos.span, extMethods)(using compilatonUnitContext) + .collectFirst { + case v: untpd.ValOrTypeDef => + ExtensionParamOccurence( + v.name, + v.namePos, + v.symbol, + extMethods.methods, + ) + case i: untpd.Ident => + ExtensionParamOccurence( + i.name, + i.sourcePos, + i.symbol, + extMethods.methods, + ) + } + + for + extensionMethodScope <- extensionMethods + occurence <- collectParams(extensionMethodScope) + symbols <- collectAllExtensionParamSymbols( + path.headOption.getOrElse(unit.tpdTree), + occurence, + ) + yield symbols + end seekInExtensionParameters + + private def collectAllExtensionParamSymbols( + tree: tpd.Tree, + occurrence: ExtensionParamOccurence, + ): Option[(Set[Symbol], SourcePosition)] = + occurrence match + case ExtensionParamOccurence(_, namePos, symbol, _) + if symbol != NoSymbol && !symbol.isError && !symbol.owner.is( + Flags.ExtensionMethod + ) => + Some((symbolAlternatives(symbol), namePos)) + case ExtensionParamOccurence(name, namePos, _, methods) => + val symbols = + for + method <- methods.toSet + symbol <- + Interactive.pathTo(tree, method.span) match + case (d: DefDef) :: _ => + d.paramss.flatten.collect { + case param if param.name.decoded == name.decoded => + param.symbol + } + case _ => Set.empty[Symbol] + if (symbol != NoSymbol && !symbol.isError) + withAlt <- symbolAlternatives(symbol) + yield withAlt + if symbols.nonEmpty then Some((symbols, namePos)) else None + end collectAllExtensionParamSymbols + + private def findAllExtensionParamSymbols( + pos: SourcePosition, + name: Name, + sym: Symbol, + ) = + val symbols = + for + methods <- extensionMethods.map(_.methods) + symbols <- collectAllExtensionParamSymbols( + unit.tpdTree, + ExtensionParamOccurence(name, pos, sym, methods), + ) + yield symbols + symbols.getOrElse((symbolAlternatives(sym), pos)) + end findAllExtensionParamSymbols +end PcSymbolSearch + +object PcSymbolSearch: + // NOTE: Connected to https://github.com/lampepfl/dotty/issues/16771 + // `sel.nameSpan` is calculated incorrectly in (1 + 2).toString + // See test DocumentHighlightSuite.select-parentheses + def selectNameSpan(sel: Select): Span = + val span = sel.span + if span.exists then + val point = span.point + if sel.name.toTermName == nme.ERROR then Span(point) + else if sel.qualifier.span.start > span.point then // right associative + val realName = sel.name.stripModuleClassSuffix.lastPart + Span(span.start, span.start + realName.length, point) + else Span(point, span.end, point) + else span + + // @note (tgodzik) Not sure currently how to get rid of the warning, but looks to correctly + // @nowarn + def collectTrees(trees: Iterable[Positioned]): Iterable[Tree] = + trees.collect { case t: Tree => t } + + def namePos(tree: Tree)(using Context): SourcePosition = + tree match + case sel: Select => sel.sourcePos.withSpan(selectNameSpan(sel)) + case _ => tree.sourcePos + + def isGeneratedGiven(df: NamedDefTree, sourceText: String)(using Context) = + val nameSpan = df.nameSpan + df.symbol.is(Flags.Given) && sourceText.substring( + nameSpan.start, + nameSpan.end, + ) != df.name.toString() +end PcSymbolSearch diff --git a/mtags/src/main/scala-3/scala/meta/internal/pc/ScalaPresentationCompiler.scala b/mtags/src/main/scala-3/scala/meta/internal/pc/ScalaPresentationCompiler.scala index 082eae1aab3..d3acbb2e742 100644 --- a/mtags/src/main/scala-3/scala/meta/internal/pc/ScalaPresentationCompiler.scala +++ b/mtags/src/main/scala-3/scala/meta/internal/pc/ScalaPresentationCompiler.scala @@ -168,6 +168,26 @@ case class ScalaPresentationCompiler( PcDocumentHighlightProvider(driver, params).highlights.asJava } + override def references( + params: OffsetParams, + targetFiles: ju.List[VirtualFileParams], + includeDefinition: Boolean, + ): CompletableFuture[ju.List[l.Location]] = + compilerAccess.withNonInterruptableCompiler(Some(params))( + List.empty[l.Location].asJava, + params.token, + ) { access => + val driver = access.compiler() + PcReferencesProvider( + driver, + params, + targetFiles.asScala.toList, + includeDefinition, + ) + .result() + .asJava + } + def shutdown(): Unit = compilerAccess.shutdown() diff --git a/mtags/src/main/scala-3/scala/meta/internal/pc/WithCompilationUnit.scala b/mtags/src/main/scala-3/scala/meta/internal/pc/WithCompilationUnit.scala new file mode 100644 index 00000000000..63e25c5450e --- /dev/null +++ b/mtags/src/main/scala-3/scala/meta/internal/pc/WithCompilationUnit.scala @@ -0,0 +1,39 @@ +package scala.meta.internal.pc +import java.nio.file.Paths + +import scala.meta as m + +import scala.meta.internal.metals.CompilerOffsetParams +import scala.meta.internal.mtags.MtagsEnrichments.* +import scala.meta.pc.OffsetParams +import scala.meta.pc.VirtualFileParams + +import dotty.tools.dotc.core.Contexts.* +import dotty.tools.dotc.interactive.InteractiveDriver +import dotty.tools.dotc.util.SourceFile + +class WithCompilationUnit( + val driver: InteractiveDriver, + params: VirtualFileParams, +): + val uri = params.uri() + val filePath = Paths.get(uri) + val sourceText = params.text + val text = sourceText.toCharArray() + val source = + SourceFile.virtual(filePath.toString, sourceText) + driver.run(uri, source) + given ctx: Context = driver.currentCtx + + val unit = driver.currentCtx.run.units.head + val compilatonUnitContext = ctx.fresh.setCompilationUnit(unit) + val offset = params match + case op: OffsetParams => op.offset() + case _ => 0 + val offsetParams = + params match + case op: OffsetParams => op + case _ => + CompilerOffsetParams(params.uri(), params.text(), 0, params.token()) + val pos = driver.sourcePosition(offsetParams) +end WithCompilationUnit diff --git a/tests/mtest/src/main/scala/tests/RangeReplace.scala b/tests/mtest/src/main/scala/tests/RangeReplace.scala index 866d6f0d89e..a1df5bcdd3e 100644 --- a/tests/mtest/src/main/scala/tests/RangeReplace.scala +++ b/tests/mtest/src/main/scala/tests/RangeReplace.scala @@ -11,14 +11,21 @@ trait RangeReplace { def renderHighlightsAsString( code: String, highlights: List[DocumentHighlight] + ): String = renderRangesAsString(code, highlights.map(_.getRange())) + + def renderRangesAsString( + code: String, + highlights: List[Range], + alreadyAddedMarkings: List[(Int, Int)] = Nil, + currentBase: Option[String] = None ): String = { highlights - .foldLeft((code, List.empty[(Int, Int)])) { - case ((base, alreadyAddedMarkings), location) => - replaceInRangeWithAdjustmens( + .foldLeft((currentBase.getOrElse(code), alreadyAddedMarkings)) { + case ((base, alreadyAddedMarkings), range) => + replaceInRangeWithAdjustments( code, base, - location.getRange, + range, alreadyAddedMarkings ) } @@ -31,9 +38,9 @@ trait RangeReplace { prefix: String = "<<", suffix: String = ">>" ): String = - replaceInRangeWithAdjustmens(base, base, range, List(), prefix, suffix)._1 + replaceInRangeWithAdjustments(base, base, range, List(), prefix, suffix)._1 - protected def replaceInRangeWithAdjustmens( + protected def replaceInRangeWithAdjustments( code: String, currentBase: String, range: Range, diff --git a/tests/slow/src/test/scala/tests/feature/PcReferencesLspSuite.scala b/tests/slow/src/test/scala/tests/feature/PcReferencesLspSuite.scala new file mode 100644 index 00000000000..ae6ca1efc57 --- /dev/null +++ b/tests/slow/src/test/scala/tests/feature/PcReferencesLspSuite.scala @@ -0,0 +1,130 @@ +package tests.feature + +import scala.meta.internal.metals +import scala.meta.internal.metals.EmptyCancelToken +import scala.meta.internal.metals.MetalsEnrichments._ + +import munit.TestOptions +import org.eclipse.lsp4j.Location +import org.eclipse.lsp4j.ReferenceContext +import org.eclipse.lsp4j.ReferenceParams +import tests.BaseLspSuite +import tests.BuildInfo +import tests.FileLayout +import tests.RangeReplace + +class PcReferencesLspSuite + extends BaseLspSuite("pc-references") + with RangeReplace { + + check( + "basic2", + """|/a/src/main/scala/O.scala + |object O { + | val <> = 1 + | val k = <> + |} + |/a/src/main/scala/Main.scala + |object Main { + | val g = O.<> + |} + |""".stripMargin, + metals.BuildInfo.scala213, + ) + + check( + "basic3", + """|/a/src/main/scala/O.scala + |object O { + | val <> = 1 + | val k = <> + |} + |/a/src/main/scala/Main.scala + |object Main { + | val g = O.<> + |} + |""".stripMargin, + metals.BuildInfo.scala3, + ) + + def check( + name: TestOptions, + input: String, + scalaVersion: String = BuildInfo.scalaVersion, + ): Unit = + test(name) { + cleanWorkspace() + val files = FileLayout.mapFromString(input) + val defFile = + files.collectFirst { + case (pathStr, content) if content.contains("@@") => pathStr + }.get + + def paramsF = { + val content = files.get(defFile).get + val actualContent = content.replaceAll("<<|>>", "") + val context = new ReferenceContext(true) + server.offsetParams(defFile, actualContent, workspace).map { + case (_, params) => + new ReferenceParams( + params.getTextDocument(), + params.getPosition(), + context, + ) + } + } + + def refFiles = files.keysIterator.map(server.toPath(_)).toList + + val layout = input.replaceAll("<<|>>|@@", "") + + def renderObtained(refs: List[Location]): String = { + val refsMap = refs.groupMap(ref => + ref.getUri().toAbsolutePath.toRelative(workspace).toString + )(_.getRange()) + files + .map { case (pathStr, content) => + val actualContent = content.replaceAll("<<|>>", "") + val withMarkings = + if (pathStr == defFile) { + val index = actualContent.indexOf("@@") + val code = actualContent.replaceAll("@@", "") + renderRangesAsString( + code, + refsMap.getOrElse(pathStr, Nil), + List((index, 2)), + Some(actualContent), + ) + } else { + renderRangesAsString( + actualContent, + refsMap.getOrElse(pathStr, Nil), + ) + } + s"""|/$pathStr + |$withMarkings""".stripMargin + } + .mkString("\n") + } + + for { + _ <- initialize( + s"""|/metals.json + |{ + | "a" : { + | "scalaVersion": "$scalaVersion" + | } + |} + |$layout""".stripMargin + ) + _ <- server.didOpen(defFile) + params <- paramsF + refs <- server.server.compilers.references( + params, + refFiles, + EmptyCancelToken, + ) + _ = assertNoDiff(renderObtained(refs), input) + } yield () + } +} diff --git a/tests/slow/src/test/scala/tests/gradle/GradleLspSuite.scala b/tests/slow/src/test/scala/tests/gradle/GradleLspSuite.scala index 0b0bddf9160..1aaa6c12fee 100644 --- a/tests/slow/src/test/scala/tests/gradle/GradleLspSuite.scala +++ b/tests/slow/src/test/scala/tests/gradle/GradleLspSuite.scala @@ -502,8 +502,9 @@ class GradleLspSuite extends BaseImportSuite("gradle-import") { """.stripMargin, ) // we should still have references despite fatal warning + refs <- server.workspaceReferences() _ = assertNoDiff( - server.workspaceReferences().references.map(_.symbol).mkString("\n"), + refs.references.map(_.symbol).mkString("\n"), """|_empty_/A. |_empty_/A.B. |_empty_/Warning. diff --git a/tests/slow/src/test/scala/tests/maven/MavenLspSuite.scala b/tests/slow/src/test/scala/tests/maven/MavenLspSuite.scala index 80f025c01f7..d1d124ad00b 100644 --- a/tests/slow/src/test/scala/tests/maven/MavenLspSuite.scala +++ b/tests/slow/src/test/scala/tests/maven/MavenLspSuite.scala @@ -255,8 +255,11 @@ class MavenLspSuite extends BaseImportSuite("maven-import") { """.stripMargin, ) // we should still have references despite fatal warning + refs <- server + .workspaceReferences() + .map(_.references.map(_.symbol).mkString("\n")) _ = assertNoDiff( - server.workspaceReferences().references.map(_.symbol).mkString("\n"), + refs, """|_empty_/A. |_empty_/A.B. |_empty_/Warning. diff --git a/tests/slow/src/test/scala/tests/mill/MillLspSuite.scala b/tests/slow/src/test/scala/tests/mill/MillLspSuite.scala index 01b87b31b47..ad558b7ed74 100644 --- a/tests/slow/src/test/scala/tests/mill/MillLspSuite.scala +++ b/tests/slow/src/test/scala/tests/mill/MillLspSuite.scala @@ -174,8 +174,9 @@ class MillLspSuite extends BaseImportSuite("mill-import") { """.stripMargin, ) // we should still have references despite fatal warning + refs <- server.workspaceReferences() _ = assertNoDiff( - server.workspaceReferences().references.map(_.symbol).mkString("\n"), + refs.references.map(_.symbol).mkString("\n"), """|_empty_/A. |_empty_/A.B. |_empty_/Warning. diff --git a/tests/slow/src/test/scala/tests/sbt/SbtBloopLspSuite.scala b/tests/slow/src/test/scala/tests/sbt/SbtBloopLspSuite.scala index 3d3462e79c1..90887adbf17 100644 --- a/tests/slow/src/test/scala/tests/sbt/SbtBloopLspSuite.scala +++ b/tests/slow/src/test/scala/tests/sbt/SbtBloopLspSuite.scala @@ -498,8 +498,9 @@ class SbtBloopLspSuite """.stripMargin, ) // we should still have references despite fatal warning + refs <- server.workspaceReferences() _ = assertNoDiff( - server.workspaceReferences().references.map(_.symbol).mkString("\n"), + refs.references.map(_.symbol).mkString("\n"), """|_empty_/A. |_empty_/A.B. |_empty_/Warning. diff --git a/tests/unit/src/main/scala/tests/TestRanges.scala b/tests/unit/src/main/scala/tests/TestRanges.scala index 6bc359c5500..e07c96a6eed 100644 --- a/tests/unit/src/main/scala/tests/TestRanges.scala +++ b/tests/unit/src/main/scala/tests/TestRanges.scala @@ -20,7 +20,7 @@ object TestRanges extends RangeReplace { } yield file -> validLocations .foldLeft((code, List.empty[(Int, Int)])) { case ((base, alreadyAddedMarkings), location) => - replaceInRangeWithAdjustmens( + replaceInRangeWithAdjustments( code, base, location.getRange, diff --git a/tests/unit/src/main/scala/tests/TestingServer.scala b/tests/unit/src/main/scala/tests/TestingServer.scala index ed664d0c72f..e7f8fc46c45 100644 --- a/tests/unit/src/main/scala/tests/TestingServer.scala +++ b/tests/unit/src/main/scala/tests/TestingServer.scala @@ -386,8 +386,7 @@ final case class TestingServer( def assertReferenceDefinitionBijection()(implicit loc: munit.Location - ): Unit = { - val compare = workspaceReferences() + ): Future[Unit] = workspaceReferences().map { compare => assert(compare.definition.nonEmpty, "Definitions should not be empty") assert(compare.references.nonEmpty, "References should not be empty") Assertions.assertNoDiff( @@ -398,13 +397,11 @@ final case class TestingServer( def assertReferenceDefinitionDiff( expectedDiff: String - )(implicit loc: munit.Location): Unit = { - Assertions.assertNoDiff( - workspaceReferences().diff, - expectedDiff, + )(implicit loc: munit.Location): Future[Unit] = + workspaceReferences().map(refs => + Assertions.assertNoDiff(refs.diff, expectedDiff) ) - } - def workspaceReferences(): WorkspaceSymbolReferences = { + def workspaceReferences(): Future[WorkspaceSymbolReferences] = { val inverse = mutable.Map.empty[SymbolReference, mutable.ListBuffer[Location]] val inputsCache = mutable.Map.empty[String, Input] @@ -451,27 +448,35 @@ final case class TestingServer( } val definition = Seq.newBuilder[SymbolReference] val references = Seq.newBuilder[SymbolReference] - for { - (ref, expectedLocations) <- inverse.toSeq.sortBy(_._1.symbol) - } { - val params = new ReferenceParams( - new TextDocumentIdentifier( - ref.location.getUri - ), - ref.location.getRange.getStart, - new ReferenceContext(true), + val resultFuture: Future[Unit] = + Future + .sequence( + for { + (ref, expectedLocations) <- inverse.toSeq.sortBy(_._1.symbol) + } yield { + val params = new ReferenceParams( + new TextDocumentIdentifier( + ref.location.getUri + ), + ref.location.getRange.getStart, + new ReferenceContext(true), + ) + server.referencesResult(params).map { obtainedLocations => + references ++= obtainedLocations.flatMap { result => + result.locations.map { l => + newRef(result.symbol, l) + } + } + definition ++= expectedLocations.map(l => newRef(ref.symbol, l)) + } + } + ) + .ignoreValue + resultFuture.map(_ => + WorkspaceSymbolReferences( + references.result().distinct, + definition.result().distinct, ) - val obtainedLocations = server.referencesResult(params) - references ++= obtainedLocations.flatMap { result => - result.locations.map { l => - newRef(result.symbol, l) - } - } - definition ++= expectedLocations.map(l => newRef(ref.symbol, l)) - } - WorkspaceSymbolReferences( - references.result().distinct, - definition.result().distinct, ) } @@ -1172,7 +1177,7 @@ final case class TestingServer( } } - private def offsetParams( + def offsetParams( filename: String, original: String, root: AbsolutePath, diff --git a/tests/unit/src/test/scala/tests/FingerprintsLspSuite.scala b/tests/unit/src/test/scala/tests/FingerprintsLspSuite.scala index adb8354bce0..c9fc508b272 100644 --- a/tests/unit/src/test/scala/tests/FingerprintsLspSuite.scala +++ b/tests/unit/src/test/scala/tests/FingerprintsLspSuite.scala @@ -67,8 +67,9 @@ class FingerprintsLspSuite extends BaseLspSuite("fingerprints") { | ^^^^^^^^ |""".stripMargin, ) + workspaceRefs <- server.workspaceReferences() _ = assertNoDiff( - newServer.workspaceReferences().referencesFormat, + workspaceRefs.referencesFormat, """|============= |= a/Adresses# |=============