diff --git a/presentation-compiler/src/main/dotty/tools/pc/PcCollector.scala b/presentation-compiler/src/main/dotty/tools/pc/PcCollector.scala index c447123c8725..5de80cda4ddf 100644 --- a/presentation-compiler/src/main/dotty/tools/pc/PcCollector.scala +++ b/presentation-compiler/src/main/dotty/tools/pc/PcCollector.scala @@ -2,6 +2,7 @@ package dotty.tools.pc import java.nio.file.Paths +import dotty.tools.pc.PcSymbolSearch.* import scala.meta.internal.metals.CompilerOffsetParams import scala.meta.pc.OffsetParams import scala.meta.pc.VirtualFileParams @@ -28,363 +29,59 @@ import dotty.tools.dotc.util.SourcePosition import dotty.tools.dotc.util.Spans.Span import dotty.tools.pc.utils.InteractiveEnrichments.* -abstract class PcCollector[T]( - driver: InteractiveDriver, - params: VirtualFileParams -): - private val caseClassSynthetics: Set[Name] = Set(nme.apply, nme.copy) - val uri = params.uri().nn - val filePath = Paths.get(uri).nn - val sourceText = params.text().nn - val text = sourceText.toCharArray().nn - val source = - SourceFile.virtual(filePath.toString(), sourceText) - driver.run(uri, source) - given ctx: Context = driver.currentCtx - - val unit = driver.currentCtx.run.nn.units.head - val compilatonUnitContext = ctx.fresh.setCompilationUnit(unit) - val offset = params match - case op: OffsetParams => op.offset() - case _ => 0 - val offsetParams = - params match - case op: OffsetParams => op - case _ => CompilerOffsetParams(uri, sourceText, 0, params.token().nn) - val pos = driver.sourcePosition(offsetParams) - val rawPath = - Interactive - .pathTo(driver.openedTrees(uri), pos)(using driver.currentCtx) - .dropWhile(t => // NamedArg anyway doesn't have symbol - t.symbol == NoSymbol && !t.isInstanceOf[NamedArg] || - // same issue https://github.com/scala/scala3/issues/15937 as below - t.isInstanceOf[TypeTree] - ) - - val path = rawPath match - // For type it will sometimes go into the wrong tree since TypeTree also contains the same span - // https://github.com/scala/scala3/issues/15937 - case TypeApply(sel: Select, _) :: tail if sel.span.contains(pos.span) => - Interactive.pathTo(sel, pos.span) ::: rawPath - case _ => rawPath +trait PcCollector[T]: + self: WithCompilationUnit => def collect( parent: Option[Tree] )(tree: Tree| EndMarker, pos: SourcePosition, symbol: Option[Symbol]): T - def symbolAlternatives(sym: Symbol) = - def member(parent: Symbol) = parent.info.member(sym.name).symbol - def primaryConstructorTypeParam(owner: Symbol) = - for - typeParams <- owner.primaryConstructor.paramSymss.headOption - param <- typeParams.find(_.name == sym.name) - if (param.isType) - yield param - def additionalForEnumTypeParam(enumClass: Symbol) = - if enumClass.is(Flags.Enum) then - val enumOwner = - if enumClass.is(Flags.Case) - then - Option.when(member(enumClass).is(Flags.Synthetic))( - enumClass.maybeOwner.companionClass - ) - else Some(enumClass) - enumOwner.toSet.flatMap { enumOwner => - val symsInEnumCases = enumOwner.children.toSet.flatMap(enumCase => - if member(enumCase).is(Flags.Synthetic) - then primaryConstructorTypeParam(enumCase) - else None - ) - val symsInEnumOwner = - primaryConstructorTypeParam(enumOwner).toSet + member(enumOwner) - symsInEnumCases ++ symsInEnumOwner - } - else Set.empty - val all = - if sym.is(Flags.ModuleClass) then - Set(sym, sym.companionModule, sym.companionModule.companion) - else if sym.isClass then - Set(sym, sym.companionModule, sym.companion.moduleClass) - else if sym.is(Flags.Module) then - Set(sym, sym.companionClass, sym.moduleClass) - else if sym.isTerm && (sym.owner.isClass || sym.owner.isConstructor) - then - val info = - if sym.owner.isClass then sym.owner.info else sym.owner.owner.info - Set( - sym, - info.member(sym.asTerm.name.setterName).symbol, - info.member(sym.asTerm.name.getterName).symbol - ) ++ sym.allOverriddenSymbols.toSet - // type used in primary constructor will not match the one used in the class - else if sym.isTypeParam && sym.owner.isPrimaryConstructor then - Set(sym, member(sym.maybeOwner.maybeOwner)) - ++ additionalForEnumTypeParam(sym.maybeOwner.maybeOwner) - else if sym.isTypeParam then - primaryConstructorTypeParam(sym.maybeOwner).toSet - ++ additionalForEnumTypeParam(sym.maybeOwner) + sym - else Set(sym) - all.filter(s => s != NoSymbol && !s.isError) - end symbolAlternatives - - private def isGeneratedGiven(df: NamedDefTree)(using Context) = - val nameSpan = df.nameSpan - df.symbol.is(Flags.Given) && sourceText.substring( - nameSpan.start, - nameSpan.end - ) != df.name.toString() - - // First identify the symbol we are at, comments identify @@ as current cursor position - def soughtSymbols(path: List[Tree]): Option[(Set[Symbol], SourcePosition)] = - val sought = path match - /* reference of an extension paramter - * extension [EF](<>: List[EF]) - * def double(ys: List[EF]) = <> ++ ys - */ - case (id: Ident) :: _ - if id.symbol - .is(Flags.Param) && id.symbol.owner.is(Flags.ExtensionMethod) => - Some(findAllExtensionParamSymbols(id.sourcePos, id.name, id.symbol)) - /** - * Workaround for missing symbol in: - * class A[T](a: T) - * val x = new <>(1) - */ - case t :: (n: New) :: (sel: Select) :: _ - if t.symbol == NoSymbol && sel.symbol.isConstructor => - Some(symbolAlternatives(sel.symbol.owner), namePos(t)) - /** - * Workaround for missing symbol in: - * class A[T](a: T) - * val x = <>[Int](1) - */ - case (sel @ Select(New(t), _)) :: (_: TypeApply) :: _ - if sel.symbol.isConstructor => - Some(symbolAlternatives(sel.symbol.owner), namePos(t)) - /* simple identifier: - * val a = val@@ue + value - */ - case (id: Ident) :: _ => - Some(symbolAlternatives(id.symbol), id.sourcePos) - /* simple selector: - * object.val@@ue - */ - case (sel: Select) :: _ if selectNameSpan(sel).contains(pos.span) => - Some(symbolAlternatives(sel.symbol), pos.withSpan(sel.nameSpan)) - /* named argument: - * foo(nam@@e = "123") - */ - case (arg: NamedArg) :: (appl: Apply) :: _ => - val realName = arg.name.stripModuleClassSuffix.lastPart - if pos.span.start > arg.span.start && pos.span.end < arg.span.point + realName.length - then - val length = realName.toString.backticked.length() - val pos = arg.sourcePos.withSpan( - arg.span - .withEnd(arg.span.start + length) - .withPoint(arg.span.start) - ) - appl.symbol.paramSymss.flatten.find(_.name == arg.name).map { s => - // if it's a case class we need to look for parameters also - if caseClassSynthetics(s.owner.name) && s.owner.is(Flags.Synthetic) - then - ( - Set( - s, - s.owner.owner.companion.info.member(s.name).symbol, - s.owner.owner.info.member(s.name).symbol - ) - .filter(_ != NoSymbol), - pos, - ) - else (Set(s), pos) - } - else None - end if - /* all definitions: - * def fo@@o = ??? - * class Fo@@o = ??? - * etc. - */ - case (df: NamedDefTree) :: _ - if df.nameSpan.contains(pos.span) && !isGeneratedGiven(df) => - Some(symbolAlternatives(df.symbol), pos.withSpan(df.nameSpan)) - /* enum cases with params - * enum Foo: - * case B@@ar[A](i: A) - */ - case (df: NamedDefTree) :: Template(_, _, self, _) :: _ - if (df.name == nme.apply || df.name == nme.unapply) && df.nameSpan.isZeroExtent => - Some(symbolAlternatives(self.tpt.symbol), self.sourcePos) - /** - * For traversing annotations: - * @JsonNo@@tification("") - * def params() = ??? - */ - case (df: MemberDef) :: _ if df.span.contains(pos.span) => - val annotTree = df.mods.annotations.find { t => - t.span.contains(pos.span) - } - collectTrees(annotTree).flatMap { t => - soughtSymbols( - Interactive.pathTo(t, pos.span) - ) - }.headOption - - /* Import selectors: - * import scala.util.Tr@@y - */ - case (imp: Import) :: _ if imp.span.contains(pos.span) => - imp - .selector(pos.span) - .map(sym => (symbolAlternatives(sym), sym.sourcePos)) - - case _ => None - - sought match - case None => seekInExtensionParameters() - case _ => sought - - end soughtSymbols - - lazy val extensionMethods = - NavigateAST - .untypedPath(pos.span)(using compilatonUnitContext) - .collectFirst { case em @ ExtMethods(_, _) => em } - - private def findAllExtensionParamSymbols( - pos: SourcePosition, - name: Name, - sym: Symbol - ) = - val symbols = - for - methods <- extensionMethods.map(_.methods) - symbols <- collectAllExtensionParamSymbols( - unit.tpdTree, - ExtensionParamOccurence(name, pos, sym, methods) - ) - yield symbols - symbols.getOrElse((symbolAlternatives(sym), pos)) - end findAllExtensionParamSymbols - - private def seekInExtensionParameters() = - def collectParams( - extMethods: ExtMethods - ): Option[ExtensionParamOccurence] = - NavigateAST - .pathTo(pos.span, extMethods.paramss.flatten)(using - compilatonUnitContext - ) - .collectFirst { - case v: untpd.ValOrTypeDef => - ExtensionParamOccurence( - v.name, - v.namePos, - v.symbol, - extMethods.methods - ) - case i: untpd.Ident => - ExtensionParamOccurence( - i.name, - i.sourcePos, - i.symbol, - extMethods.methods - ) - } - - for - extensionMethodScope <- extensionMethods - occurrence <- collectParams(extensionMethodScope) - symbols <- collectAllExtensionParamSymbols( - path.headOption.getOrElse(unit.tpdTree), - occurrence - ) - yield symbols - end seekInExtensionParameters - - private def collectAllExtensionParamSymbols( - tree: tpd.Tree, - occurrence: ExtensionParamOccurence - ): Option[(Set[Symbol], SourcePosition)] = - occurrence match - case ExtensionParamOccurence(_, namePos, symbol, _) - if symbol != NoSymbol && !symbol.isError && !symbol.owner.is( - Flags.ExtensionMethod - ) => - Some((symbolAlternatives(symbol), namePos)) - case ExtensionParamOccurence(name, namePos, _, methods) => - val symbols = - for - method <- methods.toSet - symbol <- - Interactive.pathTo(tree, method.span) match - case (d: DefDef) :: _ => - d.paramss.flatten.collect { - case param if param.name.decoded == name.decoded => - param.symbol - } - case _ => Set.empty[Symbol] - if (symbol != NoSymbol && !symbol.isError) - withAlt <- symbolAlternatives(symbol) - yield withAlt - if symbols.nonEmpty then Some((symbols, namePos)) else None - end collectAllExtensionParamSymbols - - def result(): List[T] = - params match - case _: OffsetParams => resultWithSought() - case _ => resultAllOccurences().toList - def resultAllOccurences(): Set[T] = def noTreeFilter = (_: Tree) => true def noSoughtFilter = (_: Symbol => Boolean) => true traverseSought(noTreeFilter, noSoughtFilter) - def resultWithSought(): List[T] = - soughtSymbols(path) match - case Some((sought, _)) => - lazy val owners = sought - .flatMap { s => Set(s.owner, s.owner.companionModule) } - .filter(_ != NoSymbol) - lazy val soughtNames: Set[Name] = sought.map(_.name) - - /* - * For comprehensions have two owners, one for the enumerators and one for - * yield. This is a heuristic to find that out. - */ - def isForComprehensionOwner(named: NameTree) = - soughtNames(named.name) && - scala.util - .Try(named.symbol.owner) - .toOption - .exists(_.isAnonymousFunction) && - owners.exists(o => - o.span.exists && o.span.point == named.symbol.owner.span.point - ) - - def soughtOrOverride(sym: Symbol) = - sought(sym) || sym.allOverriddenSymbols.exists(sought(_)) + def resultWithSought(sought: Set[Symbol]): List[T] = + lazy val owners = sought + .flatMap { s => Set(s.owner, s.owner.companionModule) } + .filter(_ != NoSymbol) + lazy val soughtNames: Set[Name] = sought.map(_.name) + + /* + * For comprehensions have two owners, one for the enumerators and one for + * yield. This is a heuristic to find that out. + */ + def isForComprehensionOwner(named: NameTree) = + soughtNames(named.name) && + scala.util + .Try(named.symbol.owner) + .toOption + .exists(_.isAnonymousFunction) && + owners.exists(o => + o.span.exists && o.span.point == named.symbol.owner.span.point + ) - def soughtTreeFilter(tree: Tree): Boolean = - tree match - case ident: Ident - if soughtOrOverride(ident.symbol) || - isForComprehensionOwner(ident) => - true - case sel: Select if soughtOrOverride(sel.symbol) => true - case df: NamedDefTree - if soughtOrOverride(df.symbol) && !df.symbol.isSetter => - true - case imp: Import if owners(imp.expr.symbol) => true - case _ => false + def soughtOrOverride(sym: Symbol) = + sought(sym) || sym.allOverriddenSymbols.exists(sought(_)) - def soughtFilter(f: Symbol => Boolean): Boolean = - sought.exists(f) + def soughtTreeFilter(tree: Tree): Boolean = + tree match + case ident: Ident + if soughtOrOverride(ident.symbol) || + isForComprehensionOwner(ident) => + true + case sel: Select if soughtOrOverride(sel.symbol) => true + case df: NamedDefTree + if soughtOrOverride(df.symbol) && !df.symbol.isSetter => + true + case imp: Import if owners(imp.expr.symbol) => true + case _ => false - traverseSought(soughtTreeFilter, soughtFilter).toList + def soughtFilter(f: Symbol => Boolean): Boolean = + sought.exists(f) - case None => Nil + traverseSought(soughtTreeFilter, soughtFilter).toList + end resultWithSought extension (span: Span) def isCorrect = @@ -453,7 +150,7 @@ abstract class PcCollector[T]( */ case df: NamedDefTree if df.span.isCorrect && df.nameSpan.isCorrect && - filter(df) && !isGeneratedGiven(df) => + filter(df) && !isGeneratedGiven(df, sourceText) => def collectEndMarker = EndMarker.getPosition(df, pos, sourceText).map: collect(EndMarker(df.symbol), _) @@ -572,35 +269,9 @@ abstract class PcCollector[T]( val traverser = new PcCollector.DeepFolderWithParent[Set[T]](collectNamesWithParent) - val all = traverser(Set.empty[T], unit.tpdTree) - all + traverser(Set.empty[T], unit.tpdTree) end traverseSought - // @note (tgodzik) Not sure currently how to get rid of the warning, but looks to correctly - // @nowarn - private def collectTrees(trees: Iterable[Positioned]): Iterable[Tree] = - trees.collect { case t: Tree => - t - } - - // NOTE: Connected to https://github.com/scala/scala3/issues/16771 - // `sel.nameSpan` is calculated incorrectly in (1 + 2).toString - // See test DocumentHighlightSuite.select-parentheses - private def selectNameSpan(sel: Select): Span = - val span = sel.span - if span.exists then - val point = span.point - if sel.name.toTermName == nme.ERROR then Span(point) - else if sel.qualifier.span.start > span.point then // right associative - val realName = sel.name.stripModuleClassSuffix.lastPart - Span(span.start, span.start + realName.length, point) - else Span(point, span.end, point) - else span - - private def namePos(tree: Tree): SourcePosition = - tree match - case sel: Select => sel.sourcePos.withSpan(selectNameSpan(sel)) - case _ => tree.sourcePos end PcCollector object PcCollector: @@ -656,3 +327,21 @@ object EndMarker: ) end getPosition end EndMarker + +abstract class WithSymbolSearchCollector[T]( + driver: InteractiveDriver, + params: OffsetParams, +) extends WithCompilationUnit(driver, params) + with PcSymbolSearch + with PcCollector[T]: + def result(): List[T] = + soughtSymbols.toList.flatMap { case (sought, _) => + resultWithSought(sought) + } + +abstract class SimpleCollector[T]( + driver: InteractiveDriver, + params: VirtualFileParams, +) extends WithCompilationUnit(driver, params) + with PcCollector[T]: + def result(): List[T] = resultAllOccurences().toList diff --git a/presentation-compiler/src/main/dotty/tools/pc/PcDocumentHighlightProvider.scala b/presentation-compiler/src/main/dotty/tools/pc/PcDocumentHighlightProvider.scala index d9b94ebb82a3..0c1af215b7f7 100644 --- a/presentation-compiler/src/main/dotty/tools/pc/PcDocumentHighlightProvider.scala +++ b/presentation-compiler/src/main/dotty/tools/pc/PcDocumentHighlightProvider.scala @@ -14,7 +14,7 @@ import org.eclipse.lsp4j.DocumentHighlightKind final class PcDocumentHighlightProvider( driver: InteractiveDriver, params: OffsetParams -) extends PcCollector[DocumentHighlight](driver, params): +) extends WithSymbolSearchCollector[DocumentHighlight](driver, params): def collect( parent: Option[Tree] diff --git a/presentation-compiler/src/main/dotty/tools/pc/PcInlineValueProviderImpl.scala b/presentation-compiler/src/main/dotty/tools/pc/PcInlineValueProviderImpl.scala index 38b5e8d0069b..bbba44d0d84f 100644 --- a/presentation-compiler/src/main/dotty/tools/pc/PcInlineValueProviderImpl.scala +++ b/presentation-compiler/src/main/dotty/tools/pc/PcInlineValueProviderImpl.scala @@ -22,9 +22,9 @@ import dotty.tools.pc.utils.InteractiveEnrichments.* import org.eclipse.lsp4j as l final class PcInlineValueProviderImpl( - val driver: InteractiveDriver, + driver: InteractiveDriver, val params: OffsetParams -) extends PcCollector[Option[Occurence]](driver, params) +) extends WithSymbolSearchCollector[Option[Occurence]](driver, params) with InlineValueProvider: val position: l.Position = pos.toLsp.getStart().nn diff --git a/presentation-compiler/src/main/dotty/tools/pc/PcReferencesProvider.scala b/presentation-compiler/src/main/dotty/tools/pc/PcReferencesProvider.scala new file mode 100644 index 000000000000..8d22ce320eee --- /dev/null +++ b/presentation-compiler/src/main/dotty/tools/pc/PcReferencesProvider.scala @@ -0,0 +1,67 @@ +package dotty.tools.pc + +import scala.language.unsafeNulls + +import scala.jdk.CollectionConverters.* + +import scala.meta.internal.metals.CompilerOffsetParams +import scala.meta.pc.ReferencesRequest +import scala.meta.pc.ReferencesResult + +import dotty.tools.dotc.ast.tpd +import dotty.tools.dotc.ast.tpd.* +import dotty.tools.dotc.core.Symbols.* +import dotty.tools.dotc.interactive.InteractiveDriver +import dotty.tools.dotc.util.SourcePosition +import org.eclipse.lsp4j +import org.eclipse.lsp4j.Location +import dotty.tools.pc.utils.InteractiveEnrichments.* +import scala.meta.internal.pc.PcReferencesResult + +class PcReferencesProvider( + driver: InteractiveDriver, + request: ReferencesRequest, +) extends WithCompilationUnit(driver, request.file()) with PcCollector[Option[(String, Option[lsp4j.Range])]]: + + private def soughtSymbols = + if(request.offsetOrSymbol().isLeft()) { + val offsetParams = CompilerOffsetParams( + request.file().uri(), + request.file().text(), + request.offsetOrSymbol().getLeft() + ) + val symbolSearch = new WithCompilationUnit(driver, offsetParams) with PcSymbolSearch + symbolSearch.soughtSymbols.map(_._1) + } else { + SymbolProvider.compilerSymbol(request.offsetOrSymbol().getRight()).map(symbolAlternatives(_)) + } + + def collect(parent: Option[Tree])( + tree: Tree | EndMarker, + toAdjust: SourcePosition, + symbol: Option[Symbol], + ): Option[(String, Option[lsp4j.Range])] = + val (pos, _) = toAdjust.adjust(text) + tree match + case t: DefTree if !request.includeDefinition() => + val sym = symbol.getOrElse(t.symbol) + Some(SemanticdbSymbols.symbolName(sym), None) + case t: Tree => + val sym = symbol.getOrElse(t.symbol) + Some(SemanticdbSymbols.symbolName(sym), Some(pos.toLsp)) + case _ => None + + def references(): List[ReferencesResult] = + soughtSymbols match + case Some(sought) if sought.nonEmpty => + resultWithSought(sought) + .flatten + .groupMap(_._1) { case (_, optRange) => + optRange.map(new Location(request.file().uri().toString(), _)) + } + .map { case (symbol, locs) => + PcReferencesResult(symbol, locs.flatten.asJava) + } + .toList + case _ => Nil +end PcReferencesProvider \ No newline at end of file diff --git a/presentation-compiler/src/main/dotty/tools/pc/PcRenameProvider.scala b/presentation-compiler/src/main/dotty/tools/pc/PcRenameProvider.scala index 94482767f917..666ccf9c614f 100644 --- a/presentation-compiler/src/main/dotty/tools/pc/PcRenameProvider.scala +++ b/presentation-compiler/src/main/dotty/tools/pc/PcRenameProvider.scala @@ -16,7 +16,7 @@ final class PcRenameProvider( driver: InteractiveDriver, params: OffsetParams, name: Option[String] -) extends PcCollector[l.TextEdit](driver, params): +) extends WithSymbolSearchCollector[l.TextEdit](driver, params): private val forbiddenMethods = Set("equals", "hashCode", "unapply", "unary_!", "!") def canRenameSymbol(sym: Symbol)(using Context): Boolean = @@ -25,7 +25,7 @@ final class PcRenameProvider( || sym.source.path.isWorksheet) def prepareRename(): Option[l.Range] = - soughtSymbols(path).flatMap((symbols, pos) => + soughtSymbols.flatMap((symbols, pos) => if symbols.forall(canRenameSymbol) then Some(pos.toLsp) else None ) @@ -42,13 +42,10 @@ final class PcRenameProvider( ) end collect - def rename( - ): List[l.TextEdit] = - val (symbols, _) = soughtSymbols(path).getOrElse(Set.empty, pos) + def rename(): List[l.TextEdit] = + val (symbols, _) = soughtSymbols.getOrElse(Set.empty, pos) if symbols.nonEmpty && symbols.forall(canRenameSymbol(_)) - then - val res = result() - res + then result() else Nil end rename end PcRenameProvider diff --git a/presentation-compiler/src/main/dotty/tools/pc/PcSemanticTokensProvider.scala b/presentation-compiler/src/main/dotty/tools/pc/PcSemanticTokensProvider.scala index a5332f1e4ff6..216d9318197b 100644 --- a/presentation-compiler/src/main/dotty/tools/pc/PcSemanticTokensProvider.scala +++ b/presentation-compiler/src/main/dotty/tools/pc/PcSemanticTokensProvider.scala @@ -60,7 +60,7 @@ final class PcSemanticTokensProvider( case _ => !df.rhs.isEmpty case _ => false - object Collector extends PcCollector[Option[Node]](driver, params): + object Collector extends SimpleCollector[Option[Node]](driver, params): override def collect( parent: Option[Tree] )(tree: Tree | EndMarker, pos: SourcePosition, symbol: Option[Symbol]): Option[Node] = diff --git a/presentation-compiler/src/main/dotty/tools/pc/PcSymbolSearch.scala b/presentation-compiler/src/main/dotty/tools/pc/PcSymbolSearch.scala new file mode 100644 index 000000000000..fd3d74f16c16 --- /dev/null +++ b/presentation-compiler/src/main/dotty/tools/pc/PcSymbolSearch.scala @@ -0,0 +1,275 @@ +package dotty.tools.pc + +import dotty.tools.pc.PcSymbolSearch.* + +import dotty.tools.dotc.ast.NavigateAST +import dotty.tools.dotc.ast.Positioned +import dotty.tools.dotc.ast.tpd +import dotty.tools.dotc.ast.tpd.* +import dotty.tools.dotc.ast.untpd +import dotty.tools.dotc.ast.untpd.ExtMethods +import dotty.tools.dotc.core.Contexts.* +import dotty.tools.dotc.core.Flags +import dotty.tools.dotc.core.NameOps.* +import dotty.tools.dotc.core.Names.* +import dotty.tools.dotc.core.StdNames.* +import dotty.tools.dotc.core.Symbols.* +import dotty.tools.dotc.core.Types.* +import dotty.tools.dotc.interactive.Interactive +import dotty.tools.dotc.util.SourcePosition +import dotty.tools.dotc.util.Spans.Span +import dotty.tools.pc.utils.InteractiveEnrichments.* + +trait PcSymbolSearch: + self: WithCompilationUnit => + + private val caseClassSynthetics: Set[Name] = Set(nme.apply, nme.copy) + + lazy val rawPath = + Interactive + .pathTo(driver.openedTrees(uri), pos)(using driver.currentCtx) + .dropWhile(t => // NamedArg anyway doesn't have symbol + t.symbol == NoSymbol && !t.isInstanceOf[NamedArg] || + // same issue https://github.com/lampepfl/dotty/issues/15937 as below + t.isInstanceOf[TypeTree] + ) + + lazy val extensionMethods = + NavigateAST + .untypedPath(pos.span)(using compilatonUnitContext) + .collectFirst { case em @ ExtMethods(_, _) => em } + + lazy val path = rawPath match + // For type it will sometimes go into the wrong tree since TypeTree also contains the same span + // https://github.com/lampepfl/dotty/issues/15937 + case TypeApply(sel: Select, _) :: tail if sel.span.contains(pos.span) => + Interactive.pathTo(sel, pos.span) ::: rawPath + case _ => rawPath + + lazy val soughtSymbols: Option[(Set[Symbol], SourcePosition)] = + soughtSymbols(path) + + def soughtSymbols(path: List[Tree]): Option[(Set[Symbol], SourcePosition)] = + val sought = path match + /* reference of an extension paramter + * extension [EF](<>: List[EF]) + * def double(ys: List[EF]) = <> ++ ys + */ + case (id: Ident) :: _ + if id.symbol + .is(Flags.Param) && id.symbol.owner.is(Flags.ExtensionMethod) => + Some(findAllExtensionParamSymbols(id.sourcePos, id.name, id.symbol)) + /** + * Workaround for missing symbol in: + * class A[T](a: T) + * val x = new <>(1) + */ + case t :: (n: New) :: (sel: Select) :: _ + if t.symbol == NoSymbol && sel.symbol.isConstructor => + Some(symbolAlternatives(sel.symbol.owner), namePos(t)) + /** + * Workaround for missing symbol in: + * class A[T](a: T) + * val x = <>[Int](1) + */ + case (sel @ Select(New(t), _)) :: (_: TypeApply) :: _ + if sel.symbol.isConstructor => + Some(symbolAlternatives(sel.symbol.owner), namePos(t)) + /* simple identifier: + * val a = val@@ue + value + */ + case (id: Ident) :: _ => + Some(symbolAlternatives(id.symbol), id.sourcePos) + /* simple selector: + * object.val@@ue + */ + case (sel: Select) :: _ if selectNameSpan(sel).contains(pos.span) => + Some(symbolAlternatives(sel.symbol), pos.withSpan(sel.nameSpan)) + /* named argument: + * foo(nam@@e = "123") + */ + case (arg: NamedArg) :: (appl: Apply) :: _ => + val realName = arg.name.stripModuleClassSuffix.lastPart + if pos.span.start > arg.span.start && pos.span.end < arg.span.point + realName.length + then + val length = realName.toString.backticked.length() + val pos = arg.sourcePos.withSpan( + arg.span + .withEnd(arg.span.start + length) + .withPoint(arg.span.start) + ) + appl.symbol.paramSymss.flatten.find(_.name == arg.name).map { s => + // if it's a case class we need to look for parameters also + if caseClassSynthetics(s.owner.name) && s.owner.is(Flags.Synthetic) + then + ( + Set( + s, + s.owner.owner.companion.info.member(s.name).symbol, + s.owner.owner.info.member(s.name).symbol + ) + .filter(_ != NoSymbol), + pos, + ) + else (Set(s), pos) + } + else None + end if + /* all definitions: + * def fo@@o = ??? + * class Fo@@o = ??? + * etc. + */ + case (df: NamedDefTree) :: _ + if df.nameSpan.contains(pos.span) && !isGeneratedGiven(df, sourceText) => + Some(symbolAlternatives(df.symbol), pos.withSpan(df.nameSpan)) + /* enum cases with params + * enum Foo: + * case B@@ar[A](i: A) + */ + case (df: NamedDefTree) :: Template(_, _, self, _) :: _ + if (df.name == nme.apply || df.name == nme.unapply) && df.nameSpan.isZeroExtent => + Some(symbolAlternatives(self.tpt.symbol), self.sourcePos) + /** + * For traversing annotations: + * @JsonNo@@tification("") + * def params() = ??? + */ + case (df: MemberDef) :: _ if df.span.contains(pos.span) => + val annotTree = df.mods.annotations.find { t => + t.span.contains(pos.span) + } + collectTrees(annotTree).flatMap { t => + soughtSymbols( + Interactive.pathTo(t, pos.span) + ) + }.headOption + + /* Import selectors: + * import scala.util.Tr@@y + */ + case (imp: Import) :: _ if imp.span.contains(pos.span) => + imp + .selector(pos.span) + .map(sym => (symbolAlternatives(sym), sym.sourcePos)) + + case _ => None + + sought match + case None => seekInExtensionParameters() + case _ => sought + + end soughtSymbols + + private def seekInExtensionParameters() = + def collectParams( + extMethods: ExtMethods + ): Option[ExtensionParamOccurence] = + NavigateAST + .pathTo(pos.span, extMethods.paramss.flatten)(using + compilatonUnitContext + ) + .collectFirst { + case v: untpd.ValOrTypeDef => + ExtensionParamOccurence( + v.name, + v.namePos, + v.symbol, + extMethods.methods + ) + case i: untpd.Ident => + ExtensionParamOccurence( + i.name, + i.sourcePos, + i.symbol, + extMethods.methods + ) + } + + for + extensionMethodScope <- extensionMethods + occurrence <- collectParams(extensionMethodScope) + symbols <- collectAllExtensionParamSymbols( + path.headOption.getOrElse(unit.tpdTree), + occurrence + ) + yield symbols + end seekInExtensionParameters + + private def collectAllExtensionParamSymbols( + tree: tpd.Tree, + occurrence: ExtensionParamOccurence, + ): Option[(Set[Symbol], SourcePosition)] = + occurrence match + case ExtensionParamOccurence(_, namePos, symbol, _) + if symbol != NoSymbol && !symbol.isError && !symbol.owner.is( + Flags.ExtensionMethod + ) => + Some((symbolAlternatives(symbol), namePos)) + case ExtensionParamOccurence(name, namePos, _, methods) => + val symbols = + for + method <- methods.toSet + symbol <- + Interactive.pathTo(tree, method.span) match + case (d: DefDef) :: _ => + d.paramss.flatten.collect { + case param if param.name.decoded == name.decoded => + param.symbol + } + case _ => Set.empty[Symbol] + if (symbol != NoSymbol && !symbol.isError) + withAlt <- symbolAlternatives(symbol) + yield withAlt + if symbols.nonEmpty then Some((symbols, namePos)) else None + end collectAllExtensionParamSymbols + + private def findAllExtensionParamSymbols( + pos: SourcePosition, + name: Name, + sym: Symbol, + ) = + val symbols = + for + methods <- extensionMethods.map(_.methods) + symbols <- collectAllExtensionParamSymbols( + unit.tpdTree, + ExtensionParamOccurence(name, pos, sym, methods), + ) + yield symbols + symbols.getOrElse((symbolAlternatives(sym), pos)) + end findAllExtensionParamSymbols +end PcSymbolSearch + +object PcSymbolSearch: + // NOTE: Connected to https://github.com/lampepfl/dotty/issues/16771 + // `sel.nameSpan` is calculated incorrectly in (1 + 2).toString + // See test DocumentHighlightSuite.select-parentheses + def selectNameSpan(sel: Select): Span = + val span = sel.span + if span.exists then + val point = span.point + if sel.name.toTermName == nme.ERROR then Span(point) + else if sel.qualifier.span.start > span.point then // right associative + val realName = sel.name.stripModuleClassSuffix.lastPart + Span(span.start, span.start + realName.length, point) + else Span(point, span.end, point) + else span + + def collectTrees(trees: Iterable[Positioned]): Iterable[Tree] = + trees.collect { case t: Tree => t } + + def namePos(tree: Tree)(using Context): SourcePosition = + tree match + case sel: Select => sel.sourcePos.withSpan(selectNameSpan(sel)) + case _ => tree.sourcePos + + def isGeneratedGiven(df: NamedDefTree, sourceText: String)(using Context) = + val nameSpan = df.nameSpan + df.symbol.is(Flags.Given) && sourceText.substring( + nameSpan.start, + nameSpan.end, + ) != df.name.toString() + +end PcSymbolSearch + diff --git a/presentation-compiler/src/main/dotty/tools/pc/ScalaPresentationCompiler.scala b/presentation-compiler/src/main/dotty/tools/pc/ScalaPresentationCompiler.scala index 86aa895cb4fc..ad8ac02ec811 100644 --- a/presentation-compiler/src/main/dotty/tools/pc/ScalaPresentationCompiler.scala +++ b/presentation-compiler/src/main/dotty/tools/pc/ScalaPresentationCompiler.scala @@ -36,7 +36,7 @@ import dotty.tools.pc.buildinfo.BuildInfo import org.eclipse.lsp4j.DocumentHighlight import org.eclipse.lsp4j.TextEdit import org.eclipse.lsp4j as l -import scala.meta.internal.pc.SymbolInformationProvider +import dotty.tools.pc.SymbolInformationProvider case class ScalaPresentationCompiler( buildTargetIdentifier: String = "", @@ -178,6 +178,19 @@ case class ScalaPresentationCompiler( PcDocumentHighlightProvider(driver, params).highlights.asJava } + override def references( + params: ReferencesRequest + ): CompletableFuture[ju.List[ReferencesResult]] = + compilerAccess.withNonInterruptableCompiler(Some(params.file()))( + List.empty[ReferencesResult].asJava, + params.file().token, + ) { access => + val driver = access.compiler() + PcReferencesProvider(driver, params) + .references() + .asJava + } + def shutdown(): Unit = compilerAccess.shutdown() diff --git a/presentation-compiler/src/main/dotty/tools/pc/SymbolInformationProvider.scala b/presentation-compiler/src/main/dotty/tools/pc/SymbolInformationProvider.scala index 0743361f255d..18d6a4ec8621 100644 --- a/presentation-compiler/src/main/dotty/tools/pc/SymbolInformationProvider.scala +++ b/presentation-compiler/src/main/dotty/tools/pc/SymbolInformationProvider.scala @@ -1,4 +1,4 @@ -package scala.meta.internal.pc +package dotty.tools.pc import scala.util.control.NonFatal @@ -15,58 +15,19 @@ import dotty.tools.dotc.core.Symbols.* import dotty.tools.pc.utils.InteractiveEnrichments.deepDealias import dotty.tools.pc.SemanticdbSymbols import dotty.tools.pc.utils.InteractiveEnrichments.allSymbols +import dotty.tools.pc.utils.InteractiveEnrichments.stripBackticks +import scala.meta.internal.pc.PcSymbolInformation +import scala.meta.internal.pc.SymbolInfo class SymbolInformationProvider(using Context): - private def toSymbols( - pkg: String, - parts: List[(String, Boolean)], - ): List[Symbol] = - def loop( - owners: List[Symbol], - parts: List[(String, Boolean)], - ): List[Symbol] = - parts match - case (head, isClass) :: tl => - val foundSymbols = - owners.flatMap { owner => - val next = - if isClass then owner.info.member(typeName(head)) - else owner.info.member(termName(head)) - next.allSymbols - } - if foundSymbols.nonEmpty then loop(foundSymbols, tl) - else Nil - case Nil => owners - - val pkgSym = - if pkg == "_empty_" then requiredPackage(nme.EMPTY_PACKAGE) - else requiredPackage(pkg) - loop(List(pkgSym), parts) - end toSymbols def info(symbol: String): Option[PcSymbolInformation] = - val index = symbol.lastIndexOf("/") - val pkg = normalizePackage(symbol.take(index + 1)) - - def loop( - symbol: String, - acc: List[(String, Boolean)], - ): List[(String, Boolean)] = - if symbol.isEmpty() then acc.reverse - else - val newSymbol = symbol.takeWhile(c => c != '.' && c != '#') - val rest = symbol.drop(newSymbol.size) - loop(rest.drop(1), (newSymbol, rest.headOption.exists(_ == '#')) :: acc) - val names = - loop(symbol.drop(index + 1).takeWhile(_ != '('), List.empty) - - val foundSymbols = - try toSymbols(pkg, names) - catch case NonFatal(e) => Nil + val foundSymbols = SymbolProvider.compilerSymbols(symbol) val (searchedSymbol, alternativeSymbols) = - foundSymbols.partition: compilerSymbol => + foundSymbols.partition(compilerSymbol => SemanticdbSymbols.symbolName(compilerSymbol) == symbol + ) searchedSymbol match case Nil => None @@ -115,8 +76,50 @@ class SymbolInformationProvider(using Context): else if sym.is(Flags.TypeParam) then PcSymbolKind.TYPE_PARAMETER else if sym.isType then PcSymbolKind.TYPE else PcSymbolKind.UNKNOWN_KIND +end SymbolInformationProvider + +object SymbolProvider: + + def compilerSymbol(symbol: String)(using Context): Option[Symbol] = + compilerSymbols(symbol).find(sym => SemanticdbSymbols.symbolName(sym) == symbol) + + def compilerSymbols(symbol: String)(using Context): List[Symbol] = + try toSymbols(SymbolInfo.getPartsFromSymbol(symbol)) + catch case NonFatal(e) => Nil private def normalizePackage(pkg: String): String = pkg.replace("/", ".").nn.stripSuffix(".") -end SymbolInformationProvider + private def toSymbols(info: SymbolInfo.SymbolParts)(using Context): List[Symbol] = + def collectSymbols(denotation: Denotation): List[Symbol] = + denotation match + case MultiDenotation(denot1, denot2) => + collectSymbols(denot1) ++ collectSymbols(denot2) + case denot => List(denot.symbol) + + def loop( + owners: List[Symbol], + parts: List[(String, Boolean)], + ): List[Symbol] = + parts match + case (head, isClass) :: tl => + val foundSymbols = + owners.flatMap { owner => + val name = head.stripBackticks + val next = + if isClass then owner.info.member(typeName(name)) + else owner.info.member(termName(name)) + collectSymbols(next).filter(_.exists) + } + if foundSymbols.nonEmpty then loop(foundSymbols, tl) + else Nil + case Nil => owners + + val pkgSym = + if info.packagePart == "_empty_/" then requiredPackage(nme.EMPTY_PACKAGE) + else requiredPackage(normalizePackage(info.packagePart)) + val found = loop(List(pkgSym), info.names) + info.paramName match + case Some(name) => found.flatMap(_.paramSymss.flatten.find(_.showName == name)) + case _ => found + end toSymbols diff --git a/presentation-compiler/src/main/dotty/tools/pc/WithCompilationUnit.scala b/presentation-compiler/src/main/dotty/tools/pc/WithCompilationUnit.scala new file mode 100644 index 000000000000..8110db269b3b --- /dev/null +++ b/presentation-compiler/src/main/dotty/tools/pc/WithCompilationUnit.scala @@ -0,0 +1,105 @@ +package dotty.tools.pc + +import scala.language.unsafeNulls + +import java.nio.file.Paths + +import scala.meta as m + +import scala.meta.internal.metals.CompilerOffsetParams +import scala.meta.pc.OffsetParams +import scala.meta.pc.VirtualFileParams + +import dotty.tools.dotc.core.Contexts.* +import dotty.tools.dotc.core.Flags +import dotty.tools.dotc.core.NameOps.* +import dotty.tools.dotc.core.Symbols.* +import dotty.tools.dotc.interactive.InteractiveDriver +import dotty.tools.dotc.util.SourceFile +import dotty.tools.pc.utils.InteractiveEnrichments.* + +class WithCompilationUnit( + val driver: InteractiveDriver, + params: VirtualFileParams, +): + val uri = params.uri() + val filePath = Paths.get(uri) + val sourceText = params.text + val text = sourceText.toCharArray() + val source = + SourceFile.virtual(filePath.toString, sourceText) + driver.run(uri, source) + given ctx: Context = driver.currentCtx + + private val run = driver.currentCtx.run + val unit = run.units.head + val compilatonUnitContext = ctx.fresh.setCompilationUnit(unit) + val offset = params match + case op: OffsetParams => op.offset() + case _ => 0 + val offsetParams = + params match + case op: OffsetParams => op + case _ => + CompilerOffsetParams(params.uri(), params.text(), 0, params.token()) + val pos = driver.sourcePosition(offsetParams) + + // First identify the symbol we are at, comments identify @@ as current cursor position + def symbolAlternatives(sym: Symbol)(using Context) = + def member(parent: Symbol) = parent.info.member(sym.name).symbol + def primaryConstructorTypeParam(owner: Symbol) = + for + typeParams <- owner.primaryConstructor.paramSymss.headOption + param <- typeParams.find(_.name == sym.name) + if (param.isType) + yield param + def additionalForEnumTypeParam(enumClass: Symbol) = + if enumClass.is(Flags.Enum) then + val enumOwner = + if enumClass.is(Flags.Case) + then + // we check that the type parameter is the one from enum class + // and not an enum case type parameter with the same name + Option.when(member(enumClass).is(Flags.Synthetic))( + enumClass.maybeOwner.companionClass + ) + else Some(enumClass) + enumOwner.toSet.flatMap { enumOwner => + val symsInEnumCases = enumOwner.children.toSet.flatMap(enumCase => + if member(enumCase).is(Flags.Synthetic) + then primaryConstructorTypeParam(enumCase) + else None + ) + val symsInEnumOwner = + primaryConstructorTypeParam(enumOwner).toSet + member(enumOwner) + symsInEnumCases ++ symsInEnumOwner + } + else Set.empty + val all = + if sym.is(Flags.ModuleClass) then + Set(sym, sym.companionModule, sym.companionModule.companion) + else if sym.isClass then + Set(sym, sym.companionModule, sym.companion.moduleClass) + else if sym.is(Flags.Module) then + Set(sym, sym.companionClass, sym.moduleClass) + else if sym.isTerm && (sym.owner.isClass || sym.owner.isConstructor) + then + val info = + if sym.owner.isClass then sym.owner.info else sym.owner.owner.info + Set( + sym, + info.member(sym.asTerm.name.setterName).symbol, + info.member(sym.asTerm.name.getterName).symbol, + ) ++ sym.allOverriddenSymbols.toSet + // type used in primary constructor will not match the one used in the class + else if sym.isTypeParam && sym.owner.isPrimaryConstructor then + Set(sym, member(sym.maybeOwner.maybeOwner)) + ++ additionalForEnumTypeParam(sym.maybeOwner.maybeOwner) + else if sym.isTypeParam then + primaryConstructorTypeParam(sym.maybeOwner).toSet + ++ additionalForEnumTypeParam(sym.maybeOwner) + sym + else Set(sym) + all.filter(s => s != NoSymbol && !s.isError) + end symbolAlternatives + +end WithCompilationUnit diff --git a/presentation-compiler/test/dotty/tools/pc/utils/DefSymbolCollector.scala b/presentation-compiler/test/dotty/tools/pc/utils/DefSymbolCollector.scala index 0171d2a0d76d..a37801b3c48c 100644 --- a/presentation-compiler/test/dotty/tools/pc/utils/DefSymbolCollector.scala +++ b/presentation-compiler/test/dotty/tools/pc/utils/DefSymbolCollector.scala @@ -7,13 +7,13 @@ import dotty.tools.dotc.ast.{Trees, tpd} import dotty.tools.dotc.core.Symbols.* import dotty.tools.dotc.interactive.InteractiveDriver import dotty.tools.dotc.util.SourcePosition -import dotty.tools.pc.PcCollector +import dotty.tools.pc.SimpleCollector import dotty.tools.pc.EndMarker final class DefSymbolCollector( driver: InteractiveDriver, params: VirtualFileParams -) extends PcCollector[Option[Symbol]](driver, params): +) extends SimpleCollector[Option[Symbol]](driver, params): def collect(parent: Option[Tree])( tree: Tree | EndMarker, diff --git a/project/Build.scala b/project/Build.scala index c1a8800421a6..82417df75756 100644 --- a/project/Build.scala +++ b/project/Build.scala @@ -1351,12 +1351,13 @@ object Build { BuildInfoPlugin.buildInfoDefaultSettings def presentationCompilerSettings(implicit mode: Mode) = { - val mtagsVersion = "1.3.1" + val mtagsVersion = "1.3.2" Seq( libraryDependencies ++= Seq( "org.lz4" % "lz4-java" % "1.8.0", "io.get-coursier" % "interface" % "1.0.18", "org.scalameta" % "mtags-interfaces" % mtagsVersion, + "com.google.guava" % "guava" % "33.2.1-jre" ), libraryDependencies += ("org.scalameta" % "mtags-shared_2.13.12" % mtagsVersion % SourceDeps), ivyConfigurations += SourceDeps.hide,