diff --git a/metals/src/main/scala/scala/meta/internal/metals/Indexer.scala b/metals/src/main/scala/scala/meta/internal/metals/Indexer.scala index 7b735d91d93..48403984ee8 100644 --- a/metals/src/main/scala/scala/meta/internal/metals/Indexer.scala +++ b/metals/src/main/scala/scala/meta/internal/metals/Indexer.scala @@ -554,7 +554,7 @@ final case class Indexer( input, dialect, includeMembers = true, - includeIdentifiers = true, + collectIdentifiers = true, ) { case SemanticdbDefinition(info, occ, owner) => if (info.isExtension) { occ.range.foreach { range => diff --git a/metals/src/main/scala/scala/meta/internal/rename/RenameProvider.scala b/metals/src/main/scala/scala/meta/internal/rename/RenameProvider.scala index 87c2e481900..144320c180a 100644 --- a/metals/src/main/scala/scala/meta/internal/rename/RenameProvider.scala +++ b/metals/src/main/scala/scala/meta/internal/rename/RenameProvider.scala @@ -240,7 +240,7 @@ final class RenameProvider( List(implReferences, currentReferences, companionRefs) ) .map( - _.reduce(_ ++ _) ++ definitionLocation + _.flatten ++ definitionLocation ) Future .sequence(allReferences) diff --git a/mtags/src/main/scala-3/scala/meta/internal/pc/PcReferencesProvider.scala b/mtags/src/main/scala-3/scala/meta/internal/pc/PcReferencesProvider.scala index f52409795f0..935f5a54885 100644 --- a/mtags/src/main/scala-3/scala/meta/internal/pc/PcReferencesProvider.scala +++ b/mtags/src/main/scala-3/scala/meta/internal/pc/PcReferencesProvider.scala @@ -55,7 +55,7 @@ class PcReferencesProvider( ): Option[(String, lsp4j.Range)] = val (pos, _) = toAdjust.adjust(text) tree match - case (_: DefTree) if !includeDefinition => None + case _: DefTree if !includeDefinition => None case t: Tree => val sym = symbol.getOrElse(t.symbol) Some(SemanticdbSymbols.symbolName(sym), pos.toLsp) diff --git a/mtags/src/main/scala-3/scala/meta/internal/pc/PcRenameProvider.scala b/mtags/src/main/scala-3/scala/meta/internal/pc/PcRenameProvider.scala index 2e488449a19..d9a5c5bfa5e 100644 --- a/mtags/src/main/scala-3/scala/meta/internal/pc/PcRenameProvider.scala +++ b/mtags/src/main/scala-3/scala/meta/internal/pc/PcRenameProvider.scala @@ -48,9 +48,7 @@ final class PcRenameProvider( def rename(): List[l.TextEdit] = val (symbols, _) = soughtSymbols.getOrElse(Set.empty, pos) if symbols.nonEmpty && symbols.forall(canRenameSymbol(_)) - then - val res = result() - res + then result() else Nil end rename end PcRenameProvider diff --git a/mtags/src/main/scala-3/scala/meta/internal/pc/PcSymbolSearch.scala b/mtags/src/main/scala-3/scala/meta/internal/pc/PcSymbolSearch.scala index 34632a1a45c..bccc65e3dc6 100644 --- a/mtags/src/main/scala-3/scala/meta/internal/pc/PcSymbolSearch.scala +++ b/mtags/src/main/scala-3/scala/meta/internal/pc/PcSymbolSearch.scala @@ -47,9 +47,8 @@ trait PcSymbolSearch: Interactive.pathTo(sel, pos.span) ::: rawPath case _ => rawPath - lazy val soughtSymbols: Option[(Set[Symbol], SourcePosition)] = soughtSymbols( - path - ) + lazy val soughtSymbols: Option[(Set[Symbol], SourcePosition)] = + soughtSymbols(path) def soughtSymbols(path: List[Tree]): Option[(Set[Symbol], SourcePosition)] = val sought = path match @@ -174,64 +173,6 @@ trait PcSymbolSearch: end soughtSymbols - // First identify the symbol we are at, comments identify @@ as current cursor position - def symbolAlternatives(sym: Symbol) = - def member(parent: Symbol) = parent.info.member(sym.name).symbol - def primaryConstructorTypeParam(owner: Symbol) = - for - typeParams <- owner.primaryConstructor.paramSymss.headOption - param <- typeParams.find(_.name == sym.name) - if (param.isType) - yield param - def additionalForEnumTypeParam(enumClass: Symbol) = - if enumClass.is(Flags.Enum) then - val enumOwner = - if enumClass.is(Flags.Case) - then - // we check that the type parameter is the one from enum class - // and not an enum case type parameter with the same name - Option.when(member(enumClass).is(Flags.Synthetic))( - enumClass.maybeOwner.companionClass - ) - else Some(enumClass) - enumOwner.toSet.flatMap { enumOwner => - val symsInEnumCases = enumOwner.children.toSet.flatMap(enumCase => - if member(enumCase).is(Flags.Synthetic) - then primaryConstructorTypeParam(enumCase) - else None - ) - val symsInEnumOwner = - primaryConstructorTypeParam(enumOwner).toSet + member(enumOwner) - symsInEnumCases ++ symsInEnumOwner - } - else Set.empty - val all = - if sym.is(Flags.ModuleClass) then - Set(sym, sym.companionModule, sym.companionModule.companion) - else if sym.isClass then - Set(sym, sym.companionModule, sym.companion.moduleClass) - else if sym.is(Flags.Module) then - Set(sym, sym.companionClass, sym.moduleClass) - else if sym.isTerm && (sym.owner.isClass || sym.owner.isConstructor) - then - val info = - if sym.owner.isClass then sym.owner.info else sym.owner.owner.info - Set( - sym, - info.member(sym.asTerm.name.setterName).symbol, - info.member(sym.asTerm.name.getterName).symbol, - ) ++ sym.allOverriddenSymbols.toSet - // type used in primary constructor will not match the one used in the class - else if sym.isTypeParam && sym.owner.isPrimaryConstructor then - Set(sym, member(sym.maybeOwner.maybeOwner)) - ++ additionalForEnumTypeParam(sym.maybeOwner.maybeOwner) - else if sym.isTypeParam then - primaryConstructorTypeParam(sym.maybeOwner).toSet - ++ additionalForEnumTypeParam(sym.maybeOwner) + sym - else Set(sym) - all.filter(s => s != NoSymbol && !s.isError) - end symbolAlternatives - private def seekInExtensionParameters() = def collectParams( extMethods: ExtMethods diff --git a/mtags/src/main/scala-3/scala/meta/internal/pc/WithCompilationUnit.scala b/mtags/src/main/scala-3/scala/meta/internal/pc/WithCompilationUnit.scala index 63e25c5450e..72a4de82ed8 100644 --- a/mtags/src/main/scala-3/scala/meta/internal/pc/WithCompilationUnit.scala +++ b/mtags/src/main/scala-3/scala/meta/internal/pc/WithCompilationUnit.scala @@ -9,6 +9,9 @@ import scala.meta.pc.OffsetParams import scala.meta.pc.VirtualFileParams import dotty.tools.dotc.core.Contexts.* +import dotty.tools.dotc.core.Flags +import dotty.tools.dotc.core.NameOps.* +import dotty.tools.dotc.core.Symbols.* import dotty.tools.dotc.interactive.InteractiveDriver import dotty.tools.dotc.util.SourceFile @@ -36,4 +39,63 @@ class WithCompilationUnit( case _ => CompilerOffsetParams(params.uri(), params.text(), 0, params.token()) val pos = driver.sourcePosition(offsetParams) + + // First identify the symbol we are at, comments identify @@ as current cursor position + def symbolAlternatives(sym: Symbol)(using Context) = + def member(parent: Symbol) = parent.info.member(sym.name).symbol + def primaryConstructorTypeParam(owner: Symbol) = + for + typeParams <- owner.primaryConstructor.paramSymss.headOption + param <- typeParams.find(_.name == sym.name) + if (param.isType) + yield param + def additionalForEnumTypeParam(enumClass: Symbol) = + if enumClass.is(Flags.Enum) then + val enumOwner = + if enumClass.is(Flags.Case) + then + // we check that the type parameter is the one from enum class + // and not an enum case type parameter with the same name + Option.when(member(enumClass).is(Flags.Synthetic))( + enumClass.maybeOwner.companionClass + ) + else Some(enumClass) + enumOwner.toSet.flatMap { enumOwner => + val symsInEnumCases = enumOwner.children.toSet.flatMap(enumCase => + if member(enumCase).is(Flags.Synthetic) + then primaryConstructorTypeParam(enumCase) + else None + ) + val symsInEnumOwner = + primaryConstructorTypeParam(enumOwner).toSet + member(enumOwner) + symsInEnumCases ++ symsInEnumOwner + } + else Set.empty + val all = + if sym.is(Flags.ModuleClass) then + Set(sym, sym.companionModule, sym.companionModule.companion) + else if sym.isClass then + Set(sym, sym.companionModule, sym.companion.moduleClass) + else if sym.is(Flags.Module) then + Set(sym, sym.companionClass, sym.moduleClass) + else if sym.isTerm && (sym.owner.isClass || sym.owner.isConstructor) + then + val info = + if sym.owner.isClass then sym.owner.info else sym.owner.owner.info + Set( + sym, + info.member(sym.asTerm.name.setterName).symbol, + info.member(sym.asTerm.name.getterName).symbol, + ) ++ sym.allOverriddenSymbols.toSet + // type used in primary constructor will not match the one used in the class + else if sym.isTypeParam && sym.owner.isPrimaryConstructor then + Set(sym, member(sym.maybeOwner.maybeOwner)) + ++ additionalForEnumTypeParam(sym.maybeOwner.maybeOwner) + else if sym.isTypeParam then + primaryConstructorTypeParam(sym.maybeOwner).toSet + ++ additionalForEnumTypeParam(sym.maybeOwner) + sym + else Set(sym) + all.filter(s => s != NoSymbol && !s.isError) + end symbolAlternatives + end WithCompilationUnit diff --git a/mtags/src/main/scala/scala/meta/internal/metals/SemanticdbDefinition.scala b/mtags/src/main/scala/scala/meta/internal/metals/SemanticdbDefinition.scala index cfa3b8cd711..640509fb559 100644 --- a/mtags/src/main/scala/scala/meta/internal/metals/SemanticdbDefinition.scala +++ b/mtags/src/main/scala/scala/meta/internal/metals/SemanticdbDefinition.scala @@ -50,14 +50,14 @@ object SemanticdbDefinition { input, dialect, includeMembers, - includeIdentifiers = false + collectIdentifiers = false )(fn) def foreachWithReturnMtags( input: Input.VirtualFile, dialect: Dialect, includeMembers: Boolean, - includeIdentifiers: Boolean + collectIdentifiers: Boolean )( fn: SemanticdbDefinition => Unit )(implicit rc: ReportContext): Option[MtagsIndexer] = { @@ -68,7 +68,7 @@ object SemanticdbDefinition { includeInnerClasses = true, includeMembers = includeMembers, dialect, - includeIdentifiers = includeIdentifiers + collectIdentifiers = collectIdentifiers ) { override def visitOccurrence( occ: SymbolOccurrence, diff --git a/mtags/src/main/scala/scala/meta/internal/mtags/ScalaToplevelMtags.scala b/mtags/src/main/scala/scala/meta/internal/mtags/ScalaToplevelMtags.scala index e0d00044ccb..e12dec18141 100644 --- a/mtags/src/main/scala/scala/meta/internal/mtags/ScalaToplevelMtags.scala +++ b/mtags/src/main/scala/scala/meta/internal/mtags/ScalaToplevelMtags.scala @@ -45,7 +45,7 @@ class ScalaToplevelMtags( includeInnerClasses: Boolean, includeMembers: Boolean, dialect: Dialect, - includeIdentifiers: Boolean = false + collectIdentifiers: Boolean = false )(implicit rc: ReportContext) extends MtagsIndexer { @@ -57,7 +57,7 @@ class ScalaToplevelMtags( implicit class XtensionScanner(scanner: LegacyScanner) { def mtagsNextToken(): Any = { scanner.nextToken() - if (includeIdentifiers) + if (collectIdentifiers) scanner.curr.token match { case IDENTIFIER => identifiers += scanner.curr.name case _ => @@ -508,7 +508,7 @@ class ScalaToplevelMtags( else nextExpectTemplate ) case IMPLICIT => - scanner.nextToken() + scanner.mtagsNextToken() loop( indent, isAfterNewline,