diff --git a/metals-bench/src/main/scala/bench/MetalsBench.scala b/metals-bench/src/main/scala/bench/MetalsBench.scala index b32c5ef6d7d..5715745a5de 100644 --- a/metals-bench/src/main/scala/bench/MetalsBench.scala +++ b/metals-bench/src/main/scala/bench/MetalsBench.scala @@ -99,6 +99,20 @@ class MetalsBench { } } + @Benchmark + @BenchmarkMode(Array(Mode.SingleShotTime)) + def typeHierarchyIndex(): Unit = { + scalaDependencySources.inputs.foreach { input => + implicit val rc: ReportContext = EmptyReportContext + new ScalaToplevelMtags( + input, + includeInnerClasses = true, + includeMembers = false, + dialects.Scala213, + ).index() + } + } + @Benchmark @BenchmarkMode(Array(Mode.SingleShotTime)) def scalaTokenize(): Unit = { diff --git a/metals/src/main/resources/db/migration/V5__Jar_type_hierarchy.sql b/metals/src/main/resources/db/migration/V5__Jar_type_hierarchy.sql new file mode 100644 index 00000000000..080bd9dcff5 --- /dev/null +++ b/metals/src/main/resources/db/migration/V5__Jar_type_hierarchy.sql @@ -0,0 +1,14 @@ +-- Type hierarchy information, e.g. symbol: "a/MyException#", extended_name: "Exception" +create table type_hierarchy( + symbol varchar not null, + parent_name varchar not null, + path varchar not null, + jar int, + is_resolved bit, + foreign key (jar) references indexed_jar (id) on delete cascade +); + +create index type_hierarchy_jar on type_hierarchy(jar); + +alter table indexed_jar +add type_hierarchy_indexed bit diff --git a/metals/src/main/scala/scala/meta/internal/implementation/GlobalClassTable.scala b/metals/src/main/scala/scala/meta/internal/implementation/GlobalClassTable.scala deleted file mode 100644 index b50c2c074d0..00000000000 --- a/metals/src/main/scala/scala/meta/internal/implementation/GlobalClassTable.scala +++ /dev/null @@ -1,99 +0,0 @@ -package scala.meta.internal.implementation -import java.nio.file.Path - -import scala.collection.concurrent.TrieMap -import scala.collection.mutable - -import scala.meta.internal.metals.BuildTargets -import scala.meta.internal.semanticdb.SymbolInformation -import scala.meta.internal.symtab.GlobalSymbolTable -import scala.meta.io.AbsolutePath -import scala.meta.io.Classpath - -import ch.epfl.scala.bsp4j.BuildTargetIdentifier - -final class GlobalClassTable( - buildTargets: BuildTargets -) { - - import ImplementationProvider._ - - type ImplementationCache = Map[Path, Map[String, Set[ClassLocation]]] - - private val buildTargetsIndexes = - TrieMap.empty[BuildTargetIdentifier, GlobalSymbolTable] - - def globalContextFor( - source: AbsolutePath, - implementationsInPath: ImplementationCache, - ): Option[InheritanceContext] = { - for { - symtab <- globalSymbolTableFor(source) - } yield { - calculateIndex(symtab, implementationsInPath) - } - } - - def globalSymbolTableFor( - source: AbsolutePath - ): Option[GlobalSymbolTable] = - synchronized { - for { - buildTargetId <- buildTargets.inverseSources(source) - jarClasspath <- buildTargets.targetJarClasspath(buildTargetId) - classpath = new Classpath(jarClasspath) - } yield { - buildTargetsIndexes.getOrElseUpdate( - buildTargetId, - GlobalSymbolTable(classpath, includeJdk = true), - ) - } - } - - private def calculateIndex( - symTab: GlobalSymbolTable, - implementationsInPath: ImplementationCache, - ): InheritanceContext = { - val context = InheritanceContext.fromDefinitions( - symTab.safeInfo, - implementationsInPath.toMap, - ) - val symbolsInformation = for { - classSymbol <- context.allClassSymbols - classInfo <- symTab.safeInfo(classSymbol) - } yield classInfo - - calculateInheritance(symbolsInformation, context, symTab) - - } - - private def calculateInheritance( - classpathClassInfos: Set[SymbolInformation], - context: InheritanceContext, - symTab: GlobalSymbolTable, - ): InheritanceContext = { - val results = new mutable.ListBuffer[(String, ClassLocation)] - val calculated = mutable.Set.empty[String] - var infos = classpathClassInfos - - while (infos.nonEmpty) { - calculated ++= infos.map(_.symbol) - - val allParents = infos.flatMap { info => - ImplementationProvider.parentsFromSignature( - info.symbol, - info.signature, - None, - ) - } - results ++= allParents - infos = (allParents.map(_._1) -- calculated).flatMap(symTab.safeInfo) - } - - val inheritance = results.groupBy(_._1).map { case (symbol, locations) => - symbol -> locations.map(_._2).toSet - } - context.withClasspathContext(inheritance) - } - -} diff --git a/metals/src/main/scala/scala/meta/internal/implementation/ImplementationProvider.scala b/metals/src/main/scala/scala/meta/internal/implementation/ImplementationProvider.scala index 66940c8604a..5c4a9d40648 100644 --- a/metals/src/main/scala/scala/meta/internal/implementation/ImplementationProvider.scala +++ b/metals/src/main/scala/scala/meta/internal/implementation/ImplementationProvider.scala @@ -1,5 +1,6 @@ package scala.meta.internal.implementation +import java.nio.charset.StandardCharsets import java.nio.file.Path import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.ConcurrentLinkedQueue @@ -7,33 +8,43 @@ import java.util.concurrent.ConcurrentLinkedQueue import scala.collection.mutable import scala.concurrent.ExecutionContext import scala.concurrent.Future -import scala.util.control.NonFatal +import scala.meta.internal.io.FileIO import scala.meta.internal.metals.Buffers import scala.meta.internal.metals.BuildTargets +import scala.meta.internal.metals.Compilers import scala.meta.internal.metals.DefinitionProvider import scala.meta.internal.metals.MetalsEnrichments._ import scala.meta.internal.metals.ReportContext import scala.meta.internal.metals.ScalaVersionSelector +import scala.meta.internal.metals.ScalaVersions import scala.meta.internal.metals.SemanticdbFeatureProvider import scala.meta.internal.mtags.GlobalSymbolIndex +import scala.meta.internal.mtags.IndexingResult import scala.meta.internal.mtags.Mtags +import scala.meta.internal.mtags.OverriddenSymbol +import scala.meta.internal.mtags.ResolvedOverriddenSymbol import scala.meta.internal.mtags.Semanticdbs import scala.meta.internal.mtags.SymbolDefinition +import scala.meta.internal.mtags.UnresolvedOverriddenSymbol import scala.meta.internal.mtags.{Symbol => MSymbol} import scala.meta.internal.parsing.Trees +import scala.meta.internal.pc.PcSymbolInformation +import scala.meta.internal.search.SymbolHierarchyOps._ import scala.meta.internal.semanticdb.ClassSignature +import scala.meta.internal.semanticdb.Scala.Descriptor.Method import scala.meta.internal.semanticdb.Scala._ import scala.meta.internal.semanticdb.Signature import scala.meta.internal.semanticdb.SymbolInformation -import scala.meta.internal.semanticdb.SymbolOccurrence import scala.meta.internal.semanticdb.TextDocument import scala.meta.internal.semanticdb.TextDocuments import scala.meta.internal.semanticdb.TypeRef import scala.meta.internal.semanticdb.TypeSignature -import scala.meta.internal.symtab.GlobalSymbolTable import scala.meta.io.AbsolutePath +import scala.meta.pc.PcSymbolKind +import scala.meta.pc.PcSymbolProperty +import ch.epfl.scala.bsp4j.BuildTargetIdentifier import org.eclipse.lsp4j.Location import org.eclipse.lsp4j.TextDocumentPositionParams @@ -41,18 +52,19 @@ final class ImplementationProvider( semanticdbs: Semanticdbs, workspace: AbsolutePath, index: GlobalSymbolIndex, - buildTargets: BuildTargets, buffer: Buffers, definitionProvider: DefinitionProvider, trees: Trees, scalaVersionSelector: ScalaVersionSelector, + compilers: Compilers, + buildTargets: BuildTargets, )(implicit ec: ExecutionContext, rc: ReportContext) extends SemanticdbFeatureProvider { import ImplementationProvider._ - - private val globalTable = new GlobalClassTable(buildTargets) private val implementationsInPath = new ConcurrentHashMap[Path, Map[String, Set[ClassLocation]]] + private val implementationsInDependencySources = + new ConcurrentHashMap[String, Set[ClassLocation]] override def reset(): Unit = { implementationsInPath.clear() @@ -70,6 +82,44 @@ final class ImplementationProvider( ) } + def addTypeHierarchy(results: List[IndexingResult]): Unit = for { + IndexingResult(path, _, overrides) <- results + (overridesSymbol, overriddenSymbols) <- overrides + overridden <- overriddenSymbols + } addTypeHierarchyElement(path, overridesSymbol, overridden) + + def addTypeHierarchyElements( + elements: List[(AbsolutePath, String, OverriddenSymbol)] + ): Unit = elements.foreach { case (path, overridesSymbol, overridden) => + addTypeHierarchyElement(path, overridesSymbol, overridden) + } + + private def addTypeHierarchyElement( + path: AbsolutePath, + overridesSymbol: String, + overridden: OverriddenSymbol, + ): Unit = { + def createUpdate( + newSymbol: ClassLocation + ): (String, Set[ClassLocation]) => Set[ClassLocation] = { + case (_, null) => Set(newSymbol) + case (_, previous) => previous + newSymbol + } + overridden match { + case ResolvedOverriddenSymbol(symbol) => + val update = createUpdate( + ClassLocation(overridesSymbol, Some(path.toNIO)) + ) + implementationsInDependencySources.compute(symbol, update(_, _)) + case UnresolvedOverriddenSymbol(name) => + val update = + createUpdate( + ClassLocation(overridesSymbol, Some(path.toNIO)) + ) + implementationsInDependencySources.compute(name, update(_, _)) + } + } + private def computeInheritance( documents: TextDocuments ): Map[String, Set[ClassLocation]] = { @@ -92,20 +142,6 @@ final class ImplementationProvider( } } - def defaultSymbolSearch( - anyWorkspacePath: AbsolutePath, - textDocument: TextDocument, - ): String => Option[SymbolInformation] = { - lazy val global = - globalTable.globalSymbolTableFor(anyWorkspacePath) - symbol => { - textDocument.symbols - .find(_.symbol == symbol) - .orElse(findSymbolInformation(symbol)) - .orElse(global.flatMap(_.safeInfo(symbol))) - } - } - def implementations( params: TextDocumentPositionParams ): Future[List[Location]] = { @@ -122,165 +158,114 @@ final class ImplementationProvider( // 1. Search locally for symbol // 2. Search inside workspace // 3. Search classpath via GlobalSymbolTable - val symbolSearch = defaultSymbolSearch(source, currentDocument) val sym = symbolOccurrence.symbol val dealiased = - if (sym.desc.isType) - dealiasClass(sym, symbolSearch) - else sym - - val definitionDocument = - if (currentDocument.definesSymbol(dealiased)) { - Some(currentDocument) - } else { - findSemanticDbForSymbol(dealiased) - } - - val inheritanceContext = definitionDocument match { - // symbol is not in workspace, we only search classpath for it - case None => - globalTable.globalContextFor( - source, - implementationsInPath.asScala.toMap, + if (sym.desc.isType) { + symbolInfo(currentDocument, source, sym).map( + _.map(_.dealiasedSymbol).getOrElse(sym) ) - // symbol is in workspace, - // we might need to search different places for related symbols - case Some(_) => - Some( - InheritanceContext.fromDefinitions( - symbolSearch, - implementationsInPath.asScala.toMap, + } else Future.successful(sym) + dealiased.flatMap { dealisedSymbol => + val isWorkspaceSymbol = + (source.isWorkspaceSource(workspace) && + currentDocument.definesSymbol(dealisedSymbol)) || + findSymbolDefinition(dealisedSymbol).exists( + _.path.isWorkspaceSource(workspace) ) + + val workspaceInheritanceContext: InheritanceContext = + InheritanceContext.fromDefinitions( + implementationsInPath.asScala.toMap ) - } - symbolLocationsFromContext( - dealiased, - source, - inheritanceContext, - ) - } - Future.sequence(locations).map { - _.flatten.toList - } - } - def topMethodParents( - symbol: String, - textDocument: TextDocument, - ): Seq[Location] = { - - def findClassInfo(owner: String) = { - if (owner.nonEmpty) { - findSymbol(textDocument, owner) - } else { - textDocument.symbols.find { sym => - sym.signature match { - case sig: ClassSignature => - sig.declarations.exists(_.symlinks.contains(symbol)) - case _ => false - } - } - } - } + val inheritanceContext: InheritanceContext = + if (isWorkspaceSymbol) workspaceInheritanceContext + else + // symbol is not defined in the workspace, we search both workspace and dependencies for it + workspaceInheritanceContext + .toGlobal( + compilers, + implementationsInDependencySources.asScala.toMap, + source, + ) - val results = for { - currentInfo <- findSymbol(textDocument, symbol) - if !isClassLike(currentInfo) - classInfo <- findClassInfo(symbol.owner) - } yield { - classInfo.signature match { - case sig: ClassSignature => - methodInParentSignature(sig, currentInfo, sig) - case _ => Nil + symbolLocationsFromContext( + dealisedSymbol, + currentDocument, + source, + inheritanceContext, + ) } } - results.getOrElse(Seq.empty) - } - - private def methodInParentSignature( - currentClassSig: ClassSignature, - bottomSymbol: SymbolInformation, - bottomClassSig: ClassSignature, - ): Seq[Location] = { - currentClassSig.parents.flatMap { - case parentSym: TypeRef => - val parentTextDocument = findSemanticDbForSymbol(parentSym.symbol) - def search(symbol: String) = - parentTextDocument.flatMap(findSymbol(_, symbol)) - search(parentSym.symbol).map(_.signature) match { - case Some(parenClassSig: ClassSignature) => - val fromParent = methodInParentSignature( - parenClassSig, - bottomSymbol, - bottomClassSig, - ) - if (fromParent.isEmpty) { - locationFromClass( - bottomSymbol, - parenClassSig, - search, - parentTextDocument, - ) - } else { - fromParent - } - case _ => Nil - } - case _ => Nil + Future.sequence(locations).map { + _.flatten.toList } } - private def locationFromClass( - bottomSymbolInformation: SymbolInformation, - parentClassSig: ClassSignature, - search: String => Option[SymbolInformation], - parentTextDocument: Option[TextDocument], - ): Option[Location] = { - val matchingSymbol = MethodImplementation.findParentSymbol( - bottomSymbolInformation, - parentClassSig, - search, - ) - for { - symbol <- matchingSymbol - parentDoc <- parentTextDocument - source = workspace.resolve(parentDoc.uri) - implOccurrence <- findDefOccurrence( - parentDoc, - symbol, - source, - ) - range <- implOccurrence.range - distance = buffer.tokenEditDistance( - source, - parentDoc.text, - trees, - ) - revised <- distance.toRevised(range.toLsp) - } yield new Location(source.toNIO.toUri().toString(), revised) - } - private def symbolLocationsFromContext( - symbol: String, + dealiased: String, + textDocument: TextDocument, source: AbsolutePath, - inheritanceContext: Option[InheritanceContext], + inheritanceContext: InheritanceContext, ): Future[Seq[Location]] = { def findImplementationSymbol( - parentSymbolInfo: SymbolInformation, - implDocument: TextDocument, + info: PcSymbolInformation, + classSymbol: String, + textDocument: TextDocument, implReal: ClassLocation, - ): Option[String] = { - if (isClassLike(parentSymbolInfo)) - Some(implReal.symbol) + source: AbsolutePath, + ): Option[Future[String]] = { + if (classLikeKinds(info.kind)) Some(Future(implReal.symbol)) else { - val symbolSearch = defaultSymbolSearch(source, implDocument) - MethodImplementation.findInherited( - parentSymbolInfo, - implReal, - symbolSearch, - ) + def tryFromDoc = + for { + classInfo <- findSymbol(textDocument, implReal.symbol) + declarations <- classInfo.signature match { + case ClassSignature(_, _, _, declarations) => declarations + case _ => None + } + found <- declarations.symlinks.collectFirst { sym => + findSymbol(textDocument, sym) match { + case Some(implInfo) + if implInfo.overriddenSymbols.contains(info.symbol) => + sym + } + } + } yield Future.successful(found) + def pcSearch = { + val symbol = { + val inferredSymbol = + s"${implReal.symbol}${info.symbol.stripPrefix(classSymbol)}" + inferredSymbol.desc match { + case Method(value, _) => new Method(value, "()").toString() + case _ => inferredSymbol + } + } + def overridesSym(info: PcSymbolInformation) = + info.overriddenSymbols.contains(info.symbol) + compilers.info(source, symbol).flatMap { + case Some(info) if overridesSym(info) => + Future.successful(info.symbol) + case Some(info) => + // look if one of the alternatives overrides `info.symbol` + Future + .sequence( + info.alternativeSymbols.map(compilers.info(source, _)) + ) + .map { + _.collectFirst { + case Some(info) if overridesSym(info) => info.symbol + }.getOrElse(symbol) + } + case None => Future.successful(symbol) + } + } + tryFromDoc.orElse { + if (implReal.symbol.isLocal) None + else Some(pcSearch) + } } } @@ -289,215 +274,241 @@ final class ImplementationProvider( def findImplementationLocations( files: Set[Path], locationsByFile: Map[Path, Set[ClassLocation]], - parentSymbol: SymbolInformation, - ) = - Future { + parentSymbol: PcSymbolInformation, + classSymbol: String, + buildTarget: BuildTargetIdentifier, + ) = Future.sequence({ + for { + file <- files + locations = locationsByFile(file) + implPath = AbsolutePath(file) + if (buildTargets.belongsToBuildTarget(buildTarget, implPath)) + implDocument <- findSemanticdb(implPath).toList + } yield { for { - file <- files - locations = locationsByFile(file) - implPath = AbsolutePath(file) - implDocument <- findSemanticdb(implPath).toIterable - distance = buffer.tokenEditDistance( - implPath, - implDocument.text, - trees, - ) - implLocation <- locations - implSymbol <- findImplementationSymbol( - parentSymbol, - implDocument, - implLocation, - ) - if !findSymbol(implDocument, implSymbol).exists( - // we should not show types if we are looking for a class implementations - sym => sym.isType && !parentSymbol.isType - ) - implOccurrence <- findDefOccurrence( - implDocument, - implSymbol, - source, + symbols <- Future.sequence( + locations.flatMap( + findImplementationSymbol( + parentSymbol, + classSymbol, + implDocument, + _, + source, + ) + ) ) - range <- implOccurrence.range - revised <- distance.toRevised(range.toLsp) - } { allLocations.add(new Location(file.toUri.toString, revised)) } + } yield { + for { + sym <- symbols + symInfo <- implDocument.symbols.find(_.symbol == sym) + if (!symInfo.isType || parentSymbol.kind == PcSymbolKind.TYPE) + implOccurrence <- findDefOccurrence( + implDocument, + sym, + implPath, + scalaVersionSelector, + ).toList + range <- implOccurrence.range + revised <- + if (implPath.isJarFileSystem) { + Some(range.toLsp) + } else { + val distance = buffer.tokenEditDistance( + implPath, + implDocument.text, + trees, + ) + distance.toRevised(range.toLsp) + } + } { allLocations.add(new Location(file.toUri.toString, revised)) } + } } + }) lazy val cores = Runtime.getRuntime().availableProcessors() - val splitJobs = for { - classContext <- inheritanceContext.toIterable - parentSymbol <- classContext.findSymbol(symbol).toIterable - symbolClass <- classFromSymbol(parentSymbol, classContext.findSymbol) - locationsByFile = findImplementation( - symbolClass.symbol, - classContext, - source.toNIO, - ) - files <- locationsByFile.keySet.grouped( - Math.max(locationsByFile.size / cores, 1) - ) - } yield findImplementationLocations(files, locationsByFile, parentSymbol) - Future.sequence(splitJobs).map { _ => - allLocations.asScala.toSeq - } + val splitJobs = + symbolInfo(textDocument, source, dealiased).flatMap { optSymbolInfo => + (for { + symbolInfo <- optSymbolInfo + symbolClass <- classFromSymbol(symbolInfo) + target <- buildTargets.inverseSources(source) + } yield { + for { + locationsByFile <- findImplementation( + symbolClass, + inheritanceContext, + source.toNIO, + ) + files = locationsByFile.keySet.grouped( + Math.max(locationsByFile.size / cores, 1) + ) + collected <- + Future.sequence( + files.map( + findImplementationLocations( + _, + locationsByFile, + symbolInfo, + symbolClass, + target, + ) + ) + ) + } yield collected + }).getOrElse(Future.successful(Iterator.empty)) + } + splitJobs.map(_ => allLocations.asScala.toSeq) } private def findSemanticdb(fileSource: AbsolutePath): Option[TextDocument] = { if (fileSource.isJarFileSystem) - None + Some(semanticdbForJarFile(fileSource)) else semanticdbs .textDocument(fileSource) .documentIncludingStale } + private def semanticdbForJarFile(fileSource: AbsolutePath) = { + val dialect = ScalaVersions.dialectForDependencyJar(fileSource.filename) + FileIO.slurp(fileSource, StandardCharsets.UTF_8) + val textDocument = Mtags.index(fileSource, dialect) + textDocument + } + private def findImplementation( symbol: String, classContext: InheritanceContext, file: Path, - ): Map[Path, Set[ClassLocation]] = { - - def loop(symbol: String, currentPath: Option[Path]): Set[ClassLocation] = { + ): Future[Map[Path, Set[ClassLocation]]] = { + val visited = mutable.Set.empty[String] + + def loop( + symbol: String, + currentPath: Option[Path], + ): Future[Set[ClassLocation]] = { + visited.add(symbol) + scribe.debug(s"searching for implementations for symbol $symbol") val directImplementations = - classContext.getLocations(symbol).filterNot { loc => - // we are not interested in local symbols from outside the workspace - (loc.symbol.isLocal && loc.file.isEmpty) || - // for local symbols, inheritance should only be picked up in the same file - // otherwise there can be a name collision between files - // local1' is from file A, local2 extends local1'' - // but both local2 and local1'' are from file B - // clearly, local2 shouldn't be considered for local1' - (symbol.isLocal && loc.symbol.isLocal && loc.file != currentPath) - } - directImplementations ++ directImplementations - .flatMap { loc => loop(loc.symbol, loc.file) } + classContext + .getLocations(symbol) + .map(_.filterNot { loc => + // we are not interested in local symbols from outside the workspace + (loc.symbol.isLocal && loc.file.isEmpty) || + // for local symbols, inheritance should only be picked up in the same file + // otherwise there can be a name collision between files + // local1' is from file A, local2 extends local1'' + // but both local2 and local1'' are from file B + // clearly, local2 shouldn't be considered for local1' + (symbol.isLocal && loc.symbol.isLocal && loc.file != currentPath) + }) + directImplementations.flatMap { directImplementations => + Future + .sequence( + directImplementations + .withFilter(loc => + (!visited( + loc.symbol + ) && loc.symbol.desc.isType) || loc.symbol.isLocal + ) + .map { loc => + loop(loc.symbol, loc.file) + } + ) + .map(rec => directImplementations ++ rec.flatten) + } } - loop(symbol, Some(file)).groupBy(_.file).collect { + loop(symbol, Some(file)).map(_.groupBy(_.file).collect { case (Some(path), locs) => path -> locs - } - } - - private def findSymbolInformation( - symbol: String - ): Option[SymbolInformation] = { - findSemanticDbForSymbol(symbol).flatMap(findSymbol(_, symbol)) - } - - def findSemanticDbWithPathForSymbol( - symbol: String - ): Option[TextDocumentWithPath] = { - for { - symbolDefinition <- findSymbolDefinition(symbol) - document <- findSemanticdb(symbolDefinition.path) - } yield TextDocumentWithPath(document, symbolDefinition.path) + }) } private def findSymbolDefinition(symbol: String): Option[SymbolDefinition] = { index.definition(MSymbol(symbol)) } - private def findSemanticDbForSymbol(symbol: String): Option[TextDocument] = { - for { - symbolDefinition <- findSymbolDefinition(symbol) - document <- findSemanticdb(symbolDefinition.path) - } yield { - document - } - } + private def classFromSymbol(info: PcSymbolInformation): Option[String] = + if (classLikeKinds(info.kind)) Some(info.dealiasedSymbol) + else info.classOwner - private def classFromSymbol( - info: SymbolInformation, - findSymbol: String => Option[SymbolInformation], - ): Iterable[SymbolInformation] = { - val classInfo = if (isClassLike(info)) { - Some(info) - } else { - findSymbol(info.symbol.owner) - .filter(info => isClassLike(info)) - } - classInfo.map(inf => dealiasClass(inf, findSymbol)) - } - - private def findDefOccurrence( - semanticDb: TextDocument, - symbol: String, + private def symbolInfo( + textDocument: TextDocument, source: AbsolutePath, - ): Option[SymbolOccurrence] = { - def isDefinitionOccurrence(occ: SymbolOccurrence) = - occ.role.isDefinition && occ.symbol == symbol - - semanticDb.occurrences - .find(isDefinitionOccurrence) - .orElse( - Mtags - .allToplevels(source.toInput, scalaVersionSelector.getDialect(source)) - .occurrences - .find(isDefinitionOccurrence) - ) - } -} - -object ImplementationProvider { - - implicit class XtensionGlobalSymbolTable(symtab: GlobalSymbolTable) { - def safeInfo(symbol: String): Option[SymbolInformation] = - try { - symtab.info(symbol) - } catch { - case NonFatal(_) => None - } - } - - def dealiasClass( symbol: String, - findSymbol: String => Option[SymbolInformation], - ): String = { - if (symbol.desc.isType) { - findSymbol(symbol) - .map { inf => - val isAbstractType = inf.isAbstract && inf.isType - // abstract type will always have Any as upper bound - if (isAbstractType) symbol - else dealiasClass(inf, findSymbol).symbol - + ): Future[Option[PcSymbolInformation]] = + if (symbol.isLocal) { + (for { + info <- findSymbol(textDocument, symbol) + } yield { + info.signature match { + case typeSig: TypeSignature => + typeSig.upperBound match { + case tr: TypeRef => + symbolInfo(textDocument, source, tr.symbol).map( + _.map(_.copy(symbol = symbol)) + ) + case _ => Future.successful(None) + } + case _ => Future.successful(Some(toPcSymbolInfo(textDocument, info))) } - .getOrElse(symbol) - } else { - symbol - } - } + }).getOrElse(Future.successful(None)) + } else compilers.info(source, symbol) - def dealiasClass( + private def toPcSymbolInfo( + textDocument: TextDocument, info: SymbolInformation, - findSymbol: String => Option[SymbolInformation], - ): SymbolInformation = { - if (info.isType) { + ): PcSymbolInformation = { + val parents = info.signature match { - case ts: TypeSignature => - ts.upperBound match { - case tr: TypeRef => - findSymbol(tr.symbol) - .map(dealiasClass(_, findSymbol)) - .getOrElse(info) - case _ => - info - } - case _ => info + case ClassSignature(_, parents, _, _) => + parents.collect { case t: TypeRef => t.symbol }.toList + case _ => Nil } - } else { - info + + val classOwnerInfoOpt = + textDocument.symbols.collectFirst { classInfo => + classInfo.signature match { + case ClassSignature(_, _, _, declarations) + if declarations.exists(_.symlinks.contains(info.symbol)) => + classInfo + } + } + + def getMethodPrefix(symbol: String) = symbol.desc match { + case Method(searchedSym, _) => Some(searchedSym) + case _ => None } - } - private def findSymbol( - semanticDb: TextDocument, - symbol: String, - ): Option[SymbolInformation] = { - semanticDb.symbols - .find(sym => sym.symbol == symbol) + val alternativeSymbols = + for { + classOwnerInfo <- classOwnerInfoOpt.toList + searchedSym <- getMethodPrefix(info.symbol).toList + decl <- classOwnerInfo.signature match { + case ClassSignature(_, _, _, declarations) => declarations + case _ => Nil + } + sym <- decl.symlinks + if (sym != searchedSym && getMethodPrefix(sym).contains(searchedSym)) + } yield sym + + PcSymbolInformation( + symbol = info.symbol, + kind = PcSymbolKind.values + .find(_.getValue == info.kind.value) + .getOrElse(PcSymbolKind.UNKNOWN_KIND), + parents = parents, + dealiasedSymbol = info.symbol, + classOwner = classOwnerInfoOpt.map(_.symbol), + alternativeSymbols = alternativeSymbols.toList, + overriddenSymbols = info.overriddenSymbols.toList, + properties = if (info.isAbstract) List(PcSymbolProperty.ABSTRACT) else Nil, + ) } +} +object ImplementationProvider { def parentsFromSignature( symbol: String, signature: Signature, @@ -535,7 +546,11 @@ object ImplementationProvider { } } - def isClassLike(info: SymbolInformation): Boolean = - info.isObject || info.isClass || info.isTrait || info.isInterface + val classLikeKinds: Set[PcSymbolKind] = Set( + PcSymbolKind.OBJECT, + PcSymbolKind.CLASS, + PcSymbolKind.TRAIT, + PcSymbolKind.INTERFACE, + ) } diff --git a/metals/src/main/scala/scala/meta/internal/implementation/InheritanceContext.scala b/metals/src/main/scala/scala/meta/internal/implementation/InheritanceContext.scala index 0b112fe7a62..5b8b14d088e 100644 --- a/metals/src/main/scala/scala/meta/internal/implementation/InheritanceContext.scala +++ b/metals/src/main/scala/scala/meta/internal/implementation/InheritanceContext.scala @@ -3,40 +3,75 @@ package scala.meta.internal.implementation import java.nio.file.Path import scala.collection.mutable +import scala.concurrent.ExecutionContext +import scala.concurrent.Future -import scala.meta.internal.semanticdb.SymbolInformation +import scala.meta.internal.metals.Compilers +import scala.meta.internal.semanticdb.Scala._ +import scala.meta.io.AbsolutePath -case class InheritanceContext( - findSymbol: String => Option[SymbolInformation], - private val inheritance: Map[String, Set[ClassLocation]], -) { +class InheritanceContext(inheritance: Map[String, Set[ClassLocation]]) { def allClassSymbols = inheritance.keySet - def getLocations(symbol: String): Set[ClassLocation] = { + def getLocations(symbol: String)(implicit + ec: ExecutionContext + ): Future[Set[ClassLocation]] = + Future.successful(getWorkspaceLocations(symbol)) + + protected def getWorkspaceLocations(symbol: String): Set[ClassLocation] = inheritance.getOrElse(symbol, Set.empty) - } - def withClasspathContext( - classpathInheritance: Map[String, Set[ClassLocation]] - ): InheritanceContext = { - val newInheritance = mutable.Map(inheritance.toSeq: _*) - for { (symbol, locations) <- classpathInheritance } { - val newLocations = - newInheritance.getOrElse(symbol, Set.empty) ++ locations - newInheritance += symbol -> newLocations - } - this.copy( - inheritance = newInheritance.toMap - ) + def toGlobal( + compilers: Compilers, + implementationsInDependencySources: Map[String, Set[ClassLocation]], + source: AbsolutePath, + ) = new GlobalInheritanceContext( + compilers, + implementationsInDependencySources, + inheritance, + source, + ) +} + +class GlobalInheritanceContext( + compilers: Compilers, + implementationsInDependencySources: Map[String, Set[ClassLocation]], + localInheritance: Map[String, Set[ClassLocation]], + source: AbsolutePath, +) extends InheritanceContext(localInheritance) { + override def getLocations( + symbol: String + )(implicit ec: ExecutionContext): Future[Set[ClassLocation]] = { + val workspaceImplementations = getWorkspaceLocations(symbol) + // for enum class we resolve all cases as implementations while indexing + val enumCasesImplementations = + implementationsInDependencySources.getOrElse(symbol, Set.empty) + val shortName = symbol.desc.name.value + val resolveGlobal = + implementationsInDependencySources + .getOrElse(shortName, Set.empty) + .collect { case loc @ ClassLocation(sym, _) => + compilers.info(source, sym).map { + case Some(symInfo) if symInfo.parents.contains(symbol) => Some(loc) + case Some(symInfo) + if symInfo.dealiasedSymbol == symbol && symInfo.symbol != symbol => + Some(loc) + case _ => None + } + } + Future + .sequence(resolveGlobal) + .map { globalImplementations => + workspaceImplementations ++ globalImplementations.flatten ++ enumCasesImplementations + } } } object InheritanceContext { def fromDefinitions( - findSymbol: String => Option[SymbolInformation], - localDefinitions: Map[Path, Map[String, Set[ClassLocation]]], + localDefinitions: Map[Path, Map[String, Set[ClassLocation]]] ): InheritanceContext = { val inheritance = mutable.Map .empty[String, Set[ClassLocation]] @@ -47,6 +82,6 @@ object InheritanceContext { val updated = inheritance.getOrElse(symbol, Set.empty) ++ locations inheritance += symbol -> updated } - InheritanceContext(findSymbol, inheritance.toMap) + new InheritanceContext(inheritance.toMap) } } diff --git a/metals/src/main/scala/scala/meta/internal/implementation/Supermethods.scala b/metals/src/main/scala/scala/meta/internal/implementation/Supermethods.scala index 94d1dce1154..2bef1344c08 100644 --- a/metals/src/main/scala/scala/meta/internal/implementation/Supermethods.scala +++ b/metals/src/main/scala/scala/meta/internal/implementation/Supermethods.scala @@ -11,6 +11,7 @@ import scala.meta.internal.metals.ReportContext import scala.meta.internal.metals.clients.language.MetalsLanguageClient import scala.meta.internal.metals.clients.language.MetalsQuickPickItem import scala.meta.internal.metals.clients.language.MetalsQuickPickParams +import scala.meta.internal.search.SymbolHierarchyOps import scala.meta.internal.semanticdb.SymbolInformation import scala.meta.io.AbsolutePath @@ -22,7 +23,7 @@ import org.eclipse.lsp4j.TextDocumentPositionParams class Supermethods( client: MetalsLanguageClient, definitionProvider: DefinitionProvider, - implementationProvider: ImplementationProvider, + symbolHierarchyOps: SymbolHierarchyOps, )(implicit ec: ExecutionContext, reports: ReportContext, @@ -76,7 +77,7 @@ class Supermethods( filePath, params.getPosition(), ) - findSymbol = implementationProvider.defaultSymbolSearch( + findSymbol = symbolHierarchyOps.defaultSymbolSearch( filePath, textDocument, ) @@ -131,7 +132,7 @@ class Supermethods( filePath, position, ) - findSymbol = implementationProvider.defaultSymbolSearch( + findSymbol = symbolHierarchyOps.defaultSymbolSearch( filePath, textDocument, ) diff --git a/metals/src/main/scala/scala/meta/internal/metals/BuildTargets.scala b/metals/src/main/scala/scala/meta/internal/metals/BuildTargets.scala index 8be3bb12cff..3ef24181868 100644 --- a/metals/src/main/scala/scala/meta/internal/metals/BuildTargets.scala +++ b/metals/src/main/scala/scala/meta/internal/metals/BuildTargets.scala @@ -375,6 +375,15 @@ final class BuildTargets private ( } } + def belongsToBuildTarget( + target: BuildTargetIdentifier, + path: AbsolutePath, + ): Boolean = { + val possibleBuildTargets = + buildTargetTransitiveDependencies(target).toSet + target + inverseSourcesAll(path).exists(possibleBuildTargets(_)) + } + def inferBuildTarget( source: AbsolutePath ): Option[BuildTargetIdentifier] = diff --git a/metals/src/main/scala/scala/meta/internal/metals/Compilers.scala b/metals/src/main/scala/scala/meta/internal/metals/Compilers.scala index be8b2ffb014..127cff2df38 100644 --- a/metals/src/main/scala/scala/meta/internal/metals/Compilers.scala +++ b/metals/src/main/scala/scala/meta/internal/metals/Compilers.scala @@ -24,6 +24,7 @@ import scala.meta.internal.parsing.Trees import scala.meta.internal.pc.EmptySymbolSearch import scala.meta.internal.pc.JavaPresentationCompiler import scala.meta.internal.pc.LogMessages +import scala.meta.internal.pc.PcSymbolInformation import scala.meta.internal.pc.ScalaPresentationCompiler import scala.meta.internal.worksheets.WorksheetPcData import scala.meta.internal.worksheets.WorksheetProvider @@ -774,6 +775,18 @@ class Compilers( definition(params = params, token = token, findTypeDef = true) } + def info( + path: AbsolutePath, + symbol: String, + ): Future[Option[PcSymbolInformation]] = { + loadCompiler(path, forceScala = true) + .map( + _.info(symbol).asScala + .map(_.asScala.map(PcSymbolInformation.from)) + ) + .getOrElse(Future(None)) + } + private def definition( params: TextDocumentPositionParams, token: CancelToken, @@ -829,8 +842,16 @@ class Compilers( }.getOrElse(Future.successful(Nil.asJava)) } + /** + * Gets presentation compiler for a file. + * @param path for which presentation compiler should be loaded, + * resolves build target based on this file + * @param forceScala if should use Scala pc for `.java` files that are in a Scala build target, + * useful when Scala pc can handle Java files and Java pc implementation of a feature is missing + */ def loadCompiler( - path: AbsolutePath + path: AbsolutePath, + forceScala: Boolean = false, ): Option[PresentationCompiler] = { def fromBuildTarget: Option[PresentationCompiler] = { @@ -841,6 +862,8 @@ class Compilers( case None => Some(fallbackCompiler(path)) case Some(value) => if (path.isScalaFilename) loadCompiler(value) + else if (path.isJavaFilename && forceScala) + loadCompiler(value).orElse(loadJavaCompiler(value)) else if (path.isJavaFilename) loadJavaCompiler(value) else None } diff --git a/metals/src/main/scala/scala/meta/internal/metals/DefinitionProvider.scala b/metals/src/main/scala/scala/meta/internal/metals/DefinitionProvider.scala index 4cbf3f2c0f7..0850e5694d4 100644 --- a/metals/src/main/scala/scala/meta/internal/metals/DefinitionProvider.scala +++ b/metals/src/main/scala/scala/meta/internal/metals/DefinitionProvider.scala @@ -8,6 +8,7 @@ import scala.concurrent.Future import scala.meta.Term import scala.meta.Type import scala.meta.inputs.Input +import scala.meta.inputs.Position.Range import scala.meta.internal.metals.MetalsEnrichments._ import scala.meta.internal.mtags.GlobalSymbolIndex import scala.meta.internal.mtags.Mtags @@ -33,6 +34,7 @@ import org.eclipse.lsp4j.Location import org.eclipse.lsp4j.Position import org.eclipse.lsp4j.SymbolInformation import org.eclipse.lsp4j.SymbolKind +import org.eclipse.lsp4j.TextDocumentIdentifier import org.eclipse.lsp4j.TextDocumentPositionParams /** @@ -120,6 +122,23 @@ final class DefinitionProvider( } } + def definition( + path: AbsolutePath, + pos: Int, + ): Future[DefinitionResult] = { + val text = path.readText + val input = new Input.VirtualFile(path.toURI.toString(), text) + val range = Range(input, pos, pos) + definition( + path, + new TextDocumentPositionParams( + new TextDocumentIdentifier(path.toURI.toString()), + range.toLsp.getStart(), + ), + EmptyCancelToken, + ) + } + /** * Tries to find an identifier token at the current position * to use it for symbol search. This is the last possibility for diff --git a/metals/src/main/scala/scala/meta/internal/metals/Indexer.scala b/metals/src/main/scala/scala/meta/internal/metals/Indexer.scala index 28edbf47c06..0e4673aa516 100644 --- a/metals/src/main/scala/scala/meta/internal/metals/Indexer.scala +++ b/metals/src/main/scala/scala/meta/internal/metals/Indexer.scala @@ -24,11 +24,13 @@ import scala.meta.internal.builds.BuildTool import scala.meta.internal.builds.BuildTools import scala.meta.internal.builds.Digest.Status import scala.meta.internal.builds.WorkspaceReload +import scala.meta.internal.implementation.ImplementationProvider import scala.meta.internal.metals.MetalsEnrichments._ import scala.meta.internal.metals.clients.language.DelegatingLanguageClient import scala.meta.internal.metals.clients.language.ForwardingMetalsBuildClient import scala.meta.internal.metals.debug.BuildTargetClasses import scala.meta.internal.metals.watcher.FileWatcher +import scala.meta.internal.mtags.IndexingResult import scala.meta.internal.mtags.OnDemandSymbolIndex import scala.meta.internal.semanticdb.Scala._ import scala.meta.internal.tvp.FolderTreeViewProvider @@ -80,6 +82,7 @@ final case class Indexer( scalaVersionSelector: ScalaVersionSelector, sourceMapper: SourceMapper, workspaceFolder: AbsolutePath, + implementationProvider: ImplementationProvider, )(implicit rc: ReportContext) { private implicit def ec: ExecutionContextExecutorService = executionContext @@ -479,6 +482,7 @@ final case class Indexer( ) ) .getOrElse(Scala213) + definitionIndex.addSourceDirectory(path, dialect) } else { scribe.warn(s"unexpected dependency: $path") @@ -608,14 +612,36 @@ final case class Indexer( * @param path JAR path */ private def addSourceJarSymbols(path: AbsolutePath): Unit = { + val dialect = ScalaVersions.dialectForDependencyJar(path.filename) + def indexJar() = { + val indexResult = definitionIndex.addSourceJar(path, dialect) + val toplevels = indexResult.flatMap { + case IndexingResult(path, toplevels, _) => + toplevels.map((_, path)) + } + val overrides = indexResult.flatMap { + case IndexingResult(path, _, list) => + list.flatMap { case (symbol, overridden) => + overridden.map((path, symbol, _)) + } + } + implementationProvider.addTypeHierarchyElements(overrides) + (toplevels, overrides) + } + tables.jarSymbols.getTopLevels(path) match { case Some(toplevels) => - val dialect = ScalaVersions.dialectForDependencyJar(path.filename) - definitionIndex.addIndexedSourceJar(path, toplevels, dialect) + tables.jarSymbols.getTypeHierarchy(path) match { + case Some(overrides) => + definitionIndex.addIndexedSourceJar(path, toplevels, dialect) + implementationProvider.addTypeHierarchyElements(overrides) + case None => + val (_, overrides) = indexJar() + tables.jarSymbols.addTypeHierarchyInfo(path, overrides) + } case None => - val dialect = ScalaVersions.dialectForDependencyJar(path.filename) - val toplevels = definitionIndex.addSourceJar(path, dialect) - tables.jarSymbols.putTopLevels(path, toplevels) + val (toplevels, overrides) = indexJar() + tables.jarSymbols.putJarIndexingInfo(path, toplevels, overrides) } } diff --git a/metals/src/main/scala/scala/meta/internal/metals/InteractiveSemanticdbs.scala b/metals/src/main/scala/scala/meta/internal/metals/InteractiveSemanticdbs.scala index e49bc9754e0..34c96bcba45 100644 --- a/metals/src/main/scala/scala/meta/internal/metals/InteractiveSemanticdbs.scala +++ b/metals/src/main/scala/scala/meta/internal/metals/InteractiveSemanticdbs.scala @@ -165,8 +165,12 @@ final class InteractiveSemanticdbs( .getOrElse(compilers().fallbackCompiler(source)) val (prependedLinesSize, modifiedText) = - buildTargets - .sbtAutoImports(source) + Option + .when(source.isSbt)( + buildTargets + .sbtAutoImports(source) + ) + .flatten .fold((0, text))(imports => (imports.size, SbtBuildTool.prependAutoImports(text, imports)) ) diff --git a/metals/src/main/scala/scala/meta/internal/metals/JarTopLevels.scala b/metals/src/main/scala/scala/meta/internal/metals/JarTopLevels.scala index 6fae84713fd..bdecc52aec6 100644 --- a/metals/src/main/scala/scala/meta/internal/metals/JarTopLevels.scala +++ b/metals/src/main/scala/scala/meta/internal/metals/JarTopLevels.scala @@ -8,10 +8,12 @@ import java.sql.Statement import java.util.zip.ZipError import java.util.zip.ZipException -import scala.meta.internal.io.PlatformFileIO import scala.meta.internal.metals.JdbcEnrichments._ import scala.meta.internal.metals.MetalsEnrichments._ import scala.meta.internal.mtags.MD5 +import scala.meta.internal.mtags.OverriddenSymbol +import scala.meta.internal.mtags.ResolvedOverriddenSymbol +import scala.meta.internal.mtags.UnresolvedOverriddenSymbol import scala.meta.io.AbsolutePath /** @@ -28,17 +30,9 @@ final class JarTopLevels(conn: () => Connection) { * @return the top level Scala symbols in the jar */ def getTopLevels( - path: AbsolutePath + jar: AbsolutePath ): Option[List[(String, AbsolutePath)]] = try { - val fs = path.jarPath - .map(jarPath => - PlatformFileIO.newFileSystem( - jarPath.toURI, - new java.util.HashMap[String, String](), - ) - ) - .getOrElse(PlatformFileIO.newJarFileSystem(path, create = false)) val toplevels = List.newBuilder[(String, AbsolutePath)] conn() .query( @@ -47,20 +41,59 @@ final class JarTopLevels(conn: () => Connection) { |left join toplevel_symbol ts |on ij.id=ts.jar |where ij.md5=?""".stripMargin - ) { _.setString(1, getMD5Digest(path)) } { rs => + ) { _.setString(1, getMD5Digest(jar)) } { rs => if (rs.getString(1) != null && rs.getString(2) != null) { val symbol = rs.getString(1) - val path = AbsolutePath(fs.getPath(rs.getString(2))) + val path = toPath(jar, rs.getString(2)) toplevels += (symbol -> path) } } .headOption .map(_ => toplevels.result) } catch { - case _: ZipError | _: ZipException => + case error @ (_: ZipError | _: ZipException) => + scribe.warn(s"corrupted jar $jar: $error") + None + } + + def getTypeHierarchy( + jar: AbsolutePath + ): Option[List[(AbsolutePath, String, OverriddenSymbol)]] = + try { + val toplevels = List.newBuilder[(AbsolutePath, String, OverriddenSymbol)] + conn() + .query( + """select th.symbol, th.parent_name, th.path, th.is_resolved + |from indexed_jar ij + |left join type_hierarchy th + |on ij.id=th.jar + |where ij.type_hierarchy_indexed=true and ij.md5=?""".stripMargin + ) { _.setString(1, getMD5Digest(jar)) } { rs => + if ( + rs.getString(1) != null && rs + .getString(2) != null && rs.getString(4) != null + ) { + val symbol = rs.getString(1) + val parentName = rs.getString(2) + val path = toPath(jar, rs.getString(3)) + val isResolved = rs.getBoolean(4) + val overridden = + if (isResolved) ResolvedOverriddenSymbol(parentName) + else UnresolvedOverriddenSymbol(parentName) + toplevels += ((path, symbol, overridden)) + } + } + .headOption + .map(_ => toplevels.result) + } catch { + case error @ (_: ZipError | _: ZipException) => + scribe.warn(s"corrupted jar $jar: $error") None } + private def toPath(jar: AbsolutePath, path: String) = + ("jar:" ++ jar.toNIO.toUri.toString() ++ "!" ++ path).toAbsolutePath + /** * Stores the top level symbols for the Jar * @@ -68,21 +101,23 @@ final class JarTopLevels(conn: () => Connection) { * @param toplevels toplevel symbols in the jar * @return the number of toplevel symbols inserted */ - def putTopLevels( + def putJarIndexingInfo( path: AbsolutePath, toplevels: List[(String, AbsolutePath)], + type_hierarchy: List[(AbsolutePath, String, OverriddenSymbol)], ): Int = { - if (toplevels.isEmpty) 0 + if (toplevels.isEmpty && type_hierarchy.isEmpty) 0 else { // Add jar to H2 var jarStmt: PreparedStatement = null val jar = try { jarStmt = conn().prepareStatement( - s"insert into indexed_jar (md5) values (?)", + s"insert into indexed_jar (md5, type_hierarchy_indexed) values (?, ?)", Statement.RETURN_GENERATED_KEYS, ) jarStmt.setString(1, getMD5Digest(path)) + jarStmt.setBoolean(2, true) jarStmt.executeUpdate() val rs = jarStmt.getGeneratedKeys rs.next() @@ -90,7 +125,42 @@ final class JarTopLevels(conn: () => Connection) { } finally { if (jarStmt != null) jarStmt.close() } + putToplevels(jar, toplevels) + putTypeHierarchyInfo(jar, type_hierarchy) + } + } + + def addTypeHierarchyInfo( + path: AbsolutePath, + type_hierarchy: List[(AbsolutePath, String, OverriddenSymbol)], + ): Int = { + var jarStmt: PreparedStatement = null + val jar = + try { + val digest = getMD5Digest(path) + jarStmt = conn().prepareStatement( + s"update indexed_jar set type_hierarchy_indexed = true where (md5) = (?)" + ) + jarStmt.setString(1, digest) + jarStmt.executeUpdate() + conn() + .query( + """select id + |from indexed_jar + |where md5=?""".stripMargin + ) { _.setString(1, digest) } { _.getInt(1) } + .head + } finally { + if (jarStmt != null) jarStmt.close() + } + putTypeHierarchyInfo(jar, type_hierarchy) + } + + def putToplevels( + jar: Int, + toplevels: List[(String, AbsolutePath)], + ): Int = + if (toplevels.nonEmpty) { // Add symbols for jar to H2 var symbolStmt: PreparedStatement = null try { @@ -108,8 +178,40 @@ final class JarTopLevels(conn: () => Connection) { } finally { if (symbolStmt != null) symbolStmt.close() } - } - } + } else 0 + + private def putTypeHierarchyInfo( + jar: Int, + type_hierarchy: List[(AbsolutePath, String, OverriddenSymbol)], + ): Int = + if (type_hierarchy.nonEmpty) { + // Add symbols for jar to H2 + var symbolStmt: PreparedStatement = null + try { + symbolStmt = conn().prepareStatement( + s"insert into type_hierarchy (symbol, parent_name, path, jar, is_resolved) values (?, ?, ?, ?, ?)" + ) + type_hierarchy.foreach { case (path, symbol, overridden) => + symbolStmt.setString(1, symbol) + overridden match { + case ResolvedOverriddenSymbol(name) => + symbolStmt.setString(2, name) + symbolStmt.setInt(3, 0) + symbolStmt.setBoolean(5, true) + case UnresolvedOverriddenSymbol(name) => + symbolStmt.setString(2, name) + symbolStmt.setBoolean(5, false) + } + symbolStmt.setString(3, path.toString()) + symbolStmt.setInt(4, jar) + symbolStmt.addBatch() + } + // Return number of rows inserted + symbolStmt.executeBatch().sum + } finally { + if (symbolStmt != null) symbolStmt.close() + } + } else 0 /** * Delete the jars that are not used and their top level symbols @@ -124,7 +226,17 @@ final class JarTopLevels(conn: () => Connection) { } { _ => () } } - private def getMD5Digest(path: AbsolutePath) = { + def clearAll(): Unit = { + val statement1 = conn().prepareStatement("truncate table toplevel_symbol") + statement1.execute() + val statement2 = + conn().prepareStatement("truncate table type_hierarchy_jar") + statement2.execute() + val statement3 = conn().prepareStatement("delete from indexed_jar") + statement3.execute() + } + + def getMD5Digest(path: AbsolutePath): String = { val attributes = Files .getFileAttributeView(path.toNIO, classOf[BasicFileAttributeView]) .readAttributes() @@ -134,11 +246,4 @@ final class JarTopLevels(conn: () => Connection) { .toMillis + ":" + attributes.size() ) } - - def clearAll(): Unit = { - val statement1 = conn().prepareStatement("truncate table toplevel_symbol") - statement1.execute() - val statement2 = conn().prepareStatement("delete from indexed_jar") - statement2.execute() - } } diff --git a/metals/src/main/scala/scala/meta/internal/metals/MetalsEnrichments.scala b/metals/src/main/scala/scala/meta/internal/metals/MetalsEnrichments.scala index 4baeccf62e4..17a63e4b4d4 100644 --- a/metals/src/main/scala/scala/meta/internal/metals/MetalsEnrichments.scala +++ b/metals/src/main/scala/scala/meta/internal/metals/MetalsEnrichments.scala @@ -1268,6 +1268,14 @@ object MetalsEnrichments } } + implicit class XtensionLocation(location: l.Location) { + def toTextDocumentPositionParams = + new l.TextDocumentPositionParams( + new l.TextDocumentIdentifier(location.getUri()), + location.getRange().getStart(), + ) + } + /** * Strips ANSI colors. * As long as the color codes are valid this should correctly strip diff --git a/metals/src/main/scala/scala/meta/internal/metals/MetalsLspService.scala b/metals/src/main/scala/scala/meta/internal/metals/MetalsLspService.scala index 948f119865c..8f8667ed3e8 100644 --- a/metals/src/main/scala/scala/meta/internal/metals/MetalsLspService.scala +++ b/metals/src/main/scala/scala/meta/internal/metals/MetalsLspService.scala @@ -78,6 +78,7 @@ import scala.meta.internal.parsing.FoldingRangeProvider import scala.meta.internal.parsing.TokenEditDistance import scala.meta.internal.parsing.Trees import scala.meta.internal.rename.RenameProvider +import scala.meta.internal.search.SymbolHierarchyOps import scala.meta.internal.semver.SemVer import scala.meta.internal.tvp._ import scala.meta.internal.worksheets.DecorationWorksheetPublisher @@ -517,24 +518,6 @@ class MetalsLspService( ) } - private val implementationProvider: ImplementationProvider = - new ImplementationProvider( - semanticdbs, - folder, - definitionIndex, - buildTargets, - buffers, - definitionProvider, - trees, - scalaVersionSelector, - ) - - private val supermethods: Supermethods = new Supermethods( - languageClient, - definitionProvider, - implementationProvider, - ) - private val referencesProvider: ReferenceProvider = new ReferenceProvider( folder, semanticdbs, @@ -544,16 +527,6 @@ class MetalsLspService( buildTargets, ) - private val semanticDBIndexer: SemanticdbIndexer = new SemanticdbIndexer( - List( - referencesProvider, - implementationProvider, - testProvider, - ), - buildTargets, - folder, - ) - private val formattingProvider: FormattingProvider = new FormattingProvider( folder, buffers, @@ -566,26 +539,6 @@ class MetalsLspService( buildTargets, ) - private val javaFormattingProvider: JavaFormattingProvider = - new JavaFormattingProvider( - buffers, - () => userConfig, - buildTargets, - ) - - private val callHierarchyProvider: CallHierarchyProvider = - new CallHierarchyProvider( - folder, - semanticdbs, - definitionProvider, - referencesProvider, - clientConfig.icons, - () => compilers, - trees, - buildTargets, - supermethods, - ) - private val javaHighlightProvider: JavaDocumentHighlightProvider = new JavaDocumentHighlightProvider( definitionProvider, @@ -665,9 +618,70 @@ class MetalsLspService( ) ) + private val javaFormattingProvider: JavaFormattingProvider = + new JavaFormattingProvider( + buffers, + () => userConfig, + buildTargets, + ) + + private val implementationProvider: ImplementationProvider = + new ImplementationProvider( + semanticdbs, + folder, + definitionIndex, + buffers, + definitionProvider, + trees, + scalaVersionSelector, + compilers, + buildTargets, + ) + + private val symbolHierarchyOps: SymbolHierarchyOps = + new SymbolHierarchyOps( + folder, + buildTargets, + semanticdbs, + definitionIndex, + scalaVersionSelector, + buffers, + trees, + ) + + private val supermethods: Supermethods = new Supermethods( + languageClient, + definitionProvider, + symbolHierarchyOps, + ) + + private val semanticDBIndexer: SemanticdbIndexer = new SemanticdbIndexer( + List( + referencesProvider, + implementationProvider, + testProvider, + ), + buildTargets, + folder, + ) + + private val callHierarchyProvider: CallHierarchyProvider = + new CallHierarchyProvider( + folder, + semanticdbs, + definitionProvider, + referencesProvider, + clientConfig.icons, + () => compilers, + trees, + buildTargets, + supermethods, + ) + private val renameProvider: RenameProvider = new RenameProvider( referencesProvider, implementationProvider, + symbolHierarchyOps, definitionProvider, folder, languageClient, @@ -2418,6 +2432,7 @@ class MetalsLspService( scalaVersionSelector, sourceMapper, folder, + implementationProvider, ) private def checkRunningBloopVersion(bspServerVersion: String) = { diff --git a/metals/src/main/scala/scala/meta/internal/metals/Tables.scala b/metals/src/main/scala/scala/meta/internal/metals/Tables.scala index b51a7d9680a..7b8f10f8f35 100644 --- a/metals/src/main/scala/scala/meta/internal/metals/Tables.scala +++ b/metals/src/main/scala/scala/meta/internal/metals/Tables.scala @@ -150,7 +150,8 @@ final class Tables( } private def tryUrl(url: String): Connection = { - val flyway = Flyway.configure.dataSource(url, user, null).load() + val flyway = + Flyway.configure.dataSource(url, user, null).cleanDisabled(false).load() migrateOrRestart(flyway) DriverManager.getConnection(url, user, null) } diff --git a/metals/src/main/scala/scala/meta/internal/rename/RenameProvider.scala b/metals/src/main/scala/scala/meta/internal/rename/RenameProvider.scala index 50ca3f5e634..dd4159e291d 100644 --- a/metals/src/main/scala/scala/meta/internal/rename/RenameProvider.scala +++ b/metals/src/main/scala/scala/meta/internal/rename/RenameProvider.scala @@ -20,6 +20,7 @@ import scala.meta.internal.metals.TextEdits import scala.meta.internal.metals.clients.language.MetalsLanguageClient import scala.meta.internal.parsing.Trees import scala.meta.internal.pc.Identifier +import scala.meta.internal.search.SymbolHierarchyOps import scala.meta.internal.semanticdb.Scala._ import scala.meta.internal.semanticdb.SelectTree import scala.meta.internal.semanticdb.SymbolOccurrence @@ -50,6 +51,7 @@ import org.eclipse.lsp4j.{Range => LSPRange} final class RenameProvider( referenceProvider: ReferenceProvider, implementationProvider: ImplementationProvider, + symbolHierarchyOps: SymbolHierarchyOps, definitionProvider: DefinitionProvider, workspace: AbsolutePath, client: MetalsLanguageClient, @@ -177,7 +179,7 @@ final class RenameProvider( path: AbsolutePath, textDocument: TextDocument, ) = - !symbol.desc.isType && !(symbol.isLocal && implementationProvider + !symbol.desc.isType && !(symbol.isLocal && symbolHierarchyOps .defaultSymbolSearch(path, textDocument)(symbol) .exists(info => info.isTrait || info.isClass)) @@ -192,7 +194,7 @@ final class RenameProvider( isWorkspaceSymbol(occurence.symbol, definitionPath) && isNotRenamedSymbol(semanticDb, occurence) parentSymbols = - implementationProvider + symbolHierarchyOps .topMethodParents(occurence.symbol, defSemanticdb) txtParams <- { if (parentSymbols.nonEmpty) parentSymbols.map(toTextParams) diff --git a/metals/src/main/scala/scala/meta/internal/search/GlobalClassTable.scala b/metals/src/main/scala/scala/meta/internal/search/GlobalClassTable.scala new file mode 100644 index 00000000000..f74aad8e946 --- /dev/null +++ b/metals/src/main/scala/scala/meta/internal/search/GlobalClassTable.scala @@ -0,0 +1,33 @@ +package scala.meta.internal.search +import scala.collection.concurrent.TrieMap + +import scala.meta.internal.metals.BuildTargets +import scala.meta.internal.symtab.GlobalSymbolTable +import scala.meta.io.AbsolutePath +import scala.meta.io.Classpath + +import ch.epfl.scala.bsp4j.BuildTargetIdentifier + +final class GlobalClassTable( + buildTargets: BuildTargets +) { + private val buildTargetsIndexes = + TrieMap.empty[BuildTargetIdentifier, GlobalSymbolTable] + + def globalSymbolTableFor( + source: AbsolutePath + ): Option[GlobalSymbolTable] = + synchronized { + for { + buildTargetId <- buildTargets.inverseSources(source) + jarClasspath <- buildTargets.targetJarClasspath(buildTargetId) + classpath = new Classpath(jarClasspath) + } yield { + buildTargetsIndexes.getOrElseUpdate( + buildTargetId, + GlobalSymbolTable(classpath, includeJdk = true), + ) + } + } + +} diff --git a/metals/src/main/scala/scala/meta/internal/search/SymbolHierarchyOps.scala b/metals/src/main/scala/scala/meta/internal/search/SymbolHierarchyOps.scala new file mode 100644 index 00000000000..dd797ffe8cb --- /dev/null +++ b/metals/src/main/scala/scala/meta/internal/search/SymbolHierarchyOps.scala @@ -0,0 +1,227 @@ +package scala.meta.internal.search + +import scala.util.control.NonFatal + +import scala.meta.internal.implementation.MethodImplementation +import scala.meta.internal.implementation.TextDocumentWithPath +import scala.meta.internal.metals.Buffers +import scala.meta.internal.metals.BuildTargets +import scala.meta.internal.metals.MetalsEnrichments._ +import scala.meta.internal.metals.ScalaVersionSelector +import scala.meta.internal.mtags.GlobalSymbolIndex +import scala.meta.internal.mtags.Mtags +import scala.meta.internal.mtags.Semanticdbs +import scala.meta.internal.mtags.SymbolDefinition +import scala.meta.internal.mtags.{Symbol => MSymbol} +import scala.meta.internal.parsing.Trees +import scala.meta.internal.search.SymbolHierarchyOps._ +import scala.meta.internal.semanticdb.ClassSignature +import scala.meta.internal.semanticdb.Scala._ +import scala.meta.internal.semanticdb.SymbolInformation +import scala.meta.internal.semanticdb.SymbolOccurrence +import scala.meta.internal.semanticdb.TextDocument +import scala.meta.internal.semanticdb.TypeRef +import scala.meta.internal.symtab.GlobalSymbolTable +import scala.meta.io.AbsolutePath + +import org.eclipse.lsp4j.Location + +class SymbolHierarchyOps( + workspace: AbsolutePath, + buildTargets: BuildTargets, + semanticdbs: Semanticdbs, + index: GlobalSymbolIndex, + scalaVersionSelector: ScalaVersionSelector, + buffer: Buffers, + trees: Trees, +) { + private val globalTable = new GlobalClassTable(buildTargets) + def defaultSymbolSearch( + anyWorkspacePath: AbsolutePath, + textDocument: TextDocument, + ): String => Option[SymbolInformation] = { + lazy val global = + globalTable.globalSymbolTableFor(anyWorkspacePath) + symbol => { + textDocument.symbols + .find(_.symbol == symbol) + .orElse(findSymbolInformation(symbol)) + .orElse(global.flatMap(_.safeInfo(symbol))) + } + } + + private def findSymbolInformation( + symbol: String + ): Option[SymbolInformation] = { + findSemanticDbForSymbol(symbol).flatMap(findSymbol(_, symbol)) + } + + def findSemanticDbWithPathForSymbol( + symbol: String + ): Option[TextDocumentWithPath] = { + for { + symbolDefinition <- findSymbolDefinition(symbol) + document <- findSemanticdb(symbolDefinition.path) + } yield TextDocumentWithPath(document, symbolDefinition.path) + } + + private def findSemanticdb(fileSource: AbsolutePath): Option[TextDocument] = { + if (fileSource.isJarFileSystem) None + else + semanticdbs + .textDocument(fileSource) + .documentIncludingStale + } + + private def findSymbolDefinition(symbol: String): Option[SymbolDefinition] = { + index.definition(MSymbol(symbol)) + } + + private def findSemanticDbForSymbol(symbol: String): Option[TextDocument] = { + for { + symbolDefinition <- findSymbolDefinition(symbol) + document <- findSemanticdb(symbolDefinition.path) + } yield { + document + } + } + + def topMethodParents( + symbol: String, + textDocument: TextDocument, + ): Seq[Location] = { + + def findClassInfo(owner: String) = { + if (owner.nonEmpty) { + findSymbol(textDocument, owner) + } else { + textDocument.symbols.find { sym => + sym.signature match { + case sig: ClassSignature => + sig.declarations.exists(_.symlinks.contains(symbol)) + case _ => false + } + } + } + } + + val results = for { + currentInfo <- findSymbol(textDocument, symbol) + if !isClassLike(currentInfo) + classInfo <- findClassInfo(symbol.owner) + } yield { + classInfo.signature match { + case sig: ClassSignature => + methodInParentSignature(sig, currentInfo, sig) + case _ => Nil + } + } + results.getOrElse(Seq.empty) + } + + private def methodInParentSignature( + currentClassSig: ClassSignature, + bottomSymbol: SymbolInformation, + bottomClassSig: ClassSignature, + ): Seq[Location] = { + currentClassSig.parents.flatMap { + case parentSym: TypeRef => + val parentTextDocument = findSemanticDbForSymbol(parentSym.symbol) + def search(symbol: String) = + parentTextDocument.flatMap(findSymbol(_, symbol)) + search(parentSym.symbol).map(_.signature) match { + case Some(parenClassSig: ClassSignature) => + val fromParent = methodInParentSignature( + parenClassSig, + bottomSymbol, + bottomClassSig, + ) + if (fromParent.isEmpty) { + locationFromClass( + bottomSymbol, + parenClassSig, + search, + parentTextDocument, + ) + } else { + fromParent + } + case _ => Nil + } + + case _ => Nil + } + } + + private def locationFromClass( + bottomSymbolInformation: SymbolInformation, + parentClassSig: ClassSignature, + search: String => Option[SymbolInformation], + parentTextDocument: Option[TextDocument], + ): Option[Location] = { + val matchingSymbol = MethodImplementation.findParentSymbol( + bottomSymbolInformation, + parentClassSig, + search, + ) + for { + symbol <- matchingSymbol + parentDoc <- parentTextDocument + source = workspace.resolve(parentDoc.uri) + implOccurrence <- findDefOccurrence( + parentDoc, + symbol, + source, + scalaVersionSelector, + ) + range <- implOccurrence.range + distance = buffer.tokenEditDistance( + source, + parentDoc.text, + trees, + ) + revised <- distance.toRevised(range.toLsp) + } yield new Location(source.toNIO.toUri().toString(), revised) + } +} + +object SymbolHierarchyOps { + def findSymbol( + semanticDb: TextDocument, + symbol: String, + ): Option[SymbolInformation] = { + semanticDb.symbols + .find(sym => sym.symbol == symbol) + } + + implicit class XtensionGlobalSymbolTable(symtab: GlobalSymbolTable) { + def safeInfo(symbol: String): Option[SymbolInformation] = + try { + symtab.info(symbol) + } catch { + case NonFatal(_) => None + } + } + + def isClassLike(info: SymbolInformation): Boolean = + info.isObject || info.isClass || info.isTrait || info.isInterface + + def findDefOccurrence( + semanticDb: TextDocument, + symbol: String, + source: AbsolutePath, + scalaVersionSelector: ScalaVersionSelector, + ): Option[SymbolOccurrence] = { + def isDefinitionOccurrence(occ: SymbolOccurrence) = + occ.role.isDefinition && occ.symbol == symbol + + semanticDb.occurrences + .find(isDefinitionOccurrence) + .orElse( + Mtags + .allToplevels(source.toInput, scalaVersionSelector.getDialect(source)) + .occurrences + .find(isDefinitionOccurrence) + ) + } +} diff --git a/mtags-interfaces/src/main/java/scala/meta/pc/PcSymbolInformation.java b/mtags-interfaces/src/main/java/scala/meta/pc/PcSymbolInformation.java new file mode 100644 index 00000000000..43e57233420 --- /dev/null +++ b/mtags-interfaces/src/main/java/scala/meta/pc/PcSymbolInformation.java @@ -0,0 +1,15 @@ +package scala.meta.pc; + +import java.util.List; + +public interface PcSymbolInformation { + String symbol(); + PcSymbolKind kind(); + List parents(); + String dealiasedSymbol(); + String classOwner(); + List overriddenSymbols(); + // overloaded methods + List alternativeSymbols(); + List properties(); +} diff --git a/mtags-interfaces/src/main/java/scala/meta/pc/PcSymbolKind.java b/mtags-interfaces/src/main/java/scala/meta/pc/PcSymbolKind.java new file mode 100644 index 00000000000..345234a108b --- /dev/null +++ b/mtags-interfaces/src/main/java/scala/meta/pc/PcSymbolKind.java @@ -0,0 +1,28 @@ +package scala.meta.pc; + +public enum PcSymbolKind { + UNKNOWN_KIND(0), + METHOD(3), + MACRO(6), + TYPE(7), + PARAMETER(8), + TYPE_PARAMETER(9), + OBJECT(10), + PACKAGE(11), + PACKAGE_OBJECT(12), + CLASS(13), + TRAIT(14), + SELF_PARAMETER(17), + INTERFACE(18), + LOCAL(19), + FIELD(20), + CONSTRUCTOR(21); + + private int value; + + public int getValue(){return value;} + + private PcSymbolKind (int value) { + this.value = value; + } +} diff --git a/mtags-interfaces/src/main/java/scala/meta/pc/PcSymbolProperty.java b/mtags-interfaces/src/main/java/scala/meta/pc/PcSymbolProperty.java new file mode 100644 index 00000000000..0c85d0043ce --- /dev/null +++ b/mtags-interfaces/src/main/java/scala/meta/pc/PcSymbolProperty.java @@ -0,0 +1,5 @@ +package scala.meta.pc; + +public enum PcSymbolProperty { + ABSTRACT; +} diff --git a/mtags-interfaces/src/main/java/scala/meta/pc/PresentationCompiler.java b/mtags-interfaces/src/main/java/scala/meta/pc/PresentationCompiler.java index eab3c2f82de..734cfc26e91 100644 --- a/mtags-interfaces/src/main/java/scala/meta/pc/PresentationCompiler.java +++ b/mtags-interfaces/src/main/java/scala/meta/pc/PresentationCompiler.java @@ -179,6 +179,10 @@ public CompletableFuture> syntheticDecorations(Synthet return CompletableFuture.completedFuture(Collections.emptyList()); } + public CompletableFuture> info(String symbol) { + return CompletableFuture.completedFuture(Optional.empty()); + } + /** * File was closed. */ diff --git a/mtags-shared/src/main/scala/scala/meta/internal/metals/ReportContext.scala b/mtags-shared/src/main/scala/scala/meta/internal/metals/ReportContext.scala index cba45550521..0f8131a1c29 100644 --- a/mtags-shared/src/main/scala/scala/meta/internal/metals/ReportContext.scala +++ b/mtags-shared/src/main/scala/scala/meta/internal/metals/ReportContext.scala @@ -6,6 +6,7 @@ import java.nio.file.Paths import java.util.concurrent.atomic.AtomicBoolean import java.util.concurrent.atomic.AtomicReference +import scala.util.Try import scala.util.matching.Regex import scala.meta.internal.metals.utils.LimitedFilesManager @@ -139,9 +140,11 @@ class StdReporter( } yield duplicate optDuplicate.orElse { - path.createDirectories() - path.writeText(sanitize(report.fullText(withIdAndSummary = true))) - Some(path) + Try { + path.createDirectories() + path.writeText(sanitize(report.fullText(withIdAndSummary = true))) + path + }.toOption } } diff --git a/mtags-shared/src/main/scala/scala/meta/internal/pc/PcSymbolInformation.scala b/mtags-shared/src/main/scala/scala/meta/internal/pc/PcSymbolInformation.scala new file mode 100644 index 00000000000..aa35cbabc47 --- /dev/null +++ b/mtags-shared/src/main/scala/scala/meta/internal/pc/PcSymbolInformation.scala @@ -0,0 +1,57 @@ +package scala.meta.internal.pc + +import java.{util => ju} + +import scala.meta.internal.jdk.CollectionConverters._ +import scala.meta.pc +import scala.meta.pc.PcSymbolProperty +import scala.meta.pc.{PcSymbolInformation => IPcSymbolInformation} + +case class PcSymbolInformation( + symbol: String, + kind: pc.PcSymbolKind, + parents: List[String], + dealiasedSymbol: String, + classOwner: Option[String], + overriddenSymbols: List[String], + alternativeSymbols: List[String], + properties: List[PcSymbolProperty] +) { + def asJava: PcSymbolInformationJava = + PcSymbolInformationJava( + symbol, + kind, + parents.asJava, + dealiasedSymbol, + classOwner.getOrElse(""), + overriddenSymbols.asJava, + alternativeSymbols.asJava, + properties.asJava + ) +} + +case class PcSymbolInformationJava( + symbol: String, + kind: pc.PcSymbolKind, + parents: ju.List[String], + dealiasedSymbol: String, + classOwner: String, + overriddenSymbols: ju.List[String], + alternativeSymbols: ju.List[String], + properties: ju.List[PcSymbolProperty] +) extends IPcSymbolInformation + +object PcSymbolInformation { + def from(info: IPcSymbolInformation): PcSymbolInformation = + PcSymbolInformation( + info.symbol(), + info.kind(), + info.parents().asScala.toList, + info.dealiasedSymbol(), + if (info.classOwner().nonEmpty) Some(info.classOwner()) + else None, + info.overriddenSymbols().asScala.toList, + info.alternativeSymbols().asScala.toList, + info.properties().asScala.toList + ) +} diff --git a/mtags/src/main/scala-2/scala/meta/internal/pc/ScalaPresentationCompiler.scala b/mtags/src/main/scala-2/scala/meta/internal/pc/ScalaPresentationCompiler.scala index 067ca0f4db2..7fad8e3fc10 100644 --- a/mtags/src/main/scala-2/scala/meta/internal/pc/ScalaPresentationCompiler.scala +++ b/mtags/src/main/scala-2/scala/meta/internal/pc/ScalaPresentationCompiler.scala @@ -39,6 +39,7 @@ import scala.meta.pc.PresentationCompilerConfig import scala.meta.pc.RangeParams import scala.meta.pc.SymbolSearch import scala.meta.pc.VirtualFileParams +import scala.meta.pc.{PcSymbolInformation => IPcSymbolInformation} import org.eclipse.lsp4j.CompletionItem import org.eclipse.lsp4j.CompletionList @@ -349,6 +350,21 @@ case class ScalaPresentationCompiler( ) { pc => new PcDefinitionProvider(pc.compiler(), params).definition() } } + override def info( + symbol: String + ): CompletableFuture[Optional[IPcSymbolInformation]] = { + compilerAccess.withNonInterruptableCompiler[Optional[IPcSymbolInformation]]( + None + )( + Optional.empty(), + EmptyCancelToken + ) { pc => + val result: Option[IPcSymbolInformation] = + pc.compiler().info(symbol).map(_.asJava) + result.asJava + } + } + override def typeDefinition( params: OffsetParams ): CompletableFuture[DefinitionResult] = { diff --git a/mtags/src/main/scala-2/scala/meta/internal/pc/WorkspaceSymbolSearch.scala b/mtags/src/main/scala-2/scala/meta/internal/pc/WorkspaceSymbolSearch.scala index 97a5a26d234..9f4ab5e53bb 100644 --- a/mtags/src/main/scala-2/scala/meta/internal/pc/WorkspaceSymbolSearch.scala +++ b/mtags/src/main/scala-2/scala/meta/internal/pc/WorkspaceSymbolSearch.scala @@ -4,16 +4,98 @@ import java.nio.file.Path import scala.util.control.NonFatal +import scala.meta.pc.PcSymbolKind +import scala.meta.pc.PcSymbolProperty import scala.meta.pc.SymbolSearchVisitor import org.eclipse.{lsp4j => l} trait WorkspaceSymbolSearch { this: MetalsGlobal => + def info(symbol: String): Option[PcSymbolInformation] = { + val index = symbol.lastIndexOf("/") + val pkgString = symbol.take(index + 1) + val pkg = packageSymbolFromString(pkgString) + + def loop( + symbol: String, + acc: List[(String, Boolean)] + ): List[(String, Boolean)] = + if (symbol.isEmpty()) acc.reverse + else { + val newSymbol = symbol.takeWhile(c => c != '.' && c != '#') + val rest = symbol.drop(newSymbol.size) + loop(rest.drop(1), (newSymbol, rest.headOption.exists(_ == '#')) :: acc) + } + + val names = + loop(symbol.drop(index + 1).takeWhile(_ != '('), List.empty) + + val compilerSymbols = names.foldLeft(pkg.toList) { + case (owners, (name, isClass)) => + owners.flatMap { owner => + val foundChild = + if (isClass) owner.info.member(TypeName(name)) + else owner.info.member(TermName(name)) + if (foundChild.exists) { + foundChild.info match { + case OverloadedType(_, alts) => alts + case _ => List(foundChild) + } + } else Nil + } + } + + val (searchedSymbol, alternativeSymbols) = + compilerSymbols.partition(compilerSymbol => + semanticdbSymbol(compilerSymbol) == symbol + ) + + searchedSymbol match { + case compilerSymbol :: _ => + Some( + PcSymbolInformation( + symbol = symbol, + kind = getSymbolKind(compilerSymbol), + parents = compilerSymbol.parentSymbols.map(semanticdbSymbol), + dealiasedSymbol = semanticdbSymbol(compilerSymbol.dealiased), + classOwner = compilerSymbol.ownerChain + .find(c => c.isClass || c.isModule) + .map(semanticdbSymbol), + overriddenSymbols = compilerSymbol.overrides.map(semanticdbSymbol), + alternativeSymbols = alternativeSymbols.map(semanticdbSymbol), + properties = + if ( + compilerSymbol.isAbstractClass || compilerSymbol.isAbstractType + ) + List(PcSymbolProperty.ABSTRACT) + else Nil + ) + ) + case _ => None + } + } + + private def getSymbolKind(sym: Symbol): PcSymbolKind = + if (sym.isJavaInterface) PcSymbolKind.INTERFACE + else if (sym.isTrait) PcSymbolKind.TRAIT + else if (sym.isConstructor) PcSymbolKind.CONSTRUCTOR + else if (sym.isPackageObject) PcSymbolKind.PACKAGE_OBJECT + else if (sym.isClass) PcSymbolKind.CLASS + else if (sym.isMacro) PcSymbolKind.MACRO + else if (sym.isLocalToBlock) PcSymbolKind.LOCAL + else if (sym.isMethod) PcSymbolKind.METHOD + else if (sym.isParameter) PcSymbolKind.PARAMETER + else if (sym.hasPackageFlag) PcSymbolKind.PACKAGE + else if (sym.isTypeParameter) PcSymbolKind.TYPE_PARAMETER + else if (sym.isType) PcSymbolKind.TYPE + else PcSymbolKind.UNKNOWN_KIND + class CompilerSearchVisitor( context: Context, visitMember: Symbol => Boolean ) extends SymbolSearchVisitor { + def visit(top: SymbolSearchCandidate): Int = { var added = 0 for { diff --git a/mtags/src/main/scala-2/scala/meta/internal/pc/completions/Completions.scala b/mtags/src/main/scala-2/scala/meta/internal/pc/completions/Completions.scala index 3a586d82c0c..a0671887dcb 100644 --- a/mtags/src/main/scala-2/scala/meta/internal/pc/completions/Completions.scala +++ b/mtags/src/main/scala-2/scala/meta/internal/pc/completions/Completions.scala @@ -76,19 +76,21 @@ trait Completions { this: MetalsGlobal => val packageSymbols: mutable.Map[String, Option[Symbol]] = mutable.Map.empty[String, Option[Symbol]] - def packageSymbolFromString(symbol: String): Option[Symbol] = { - packageSymbols.getOrElseUpdate( - symbol, { - val fqn = symbol.stripSuffix("/").replace('/', '.') - try { - Some(rootMirror.staticPackage(fqn)) - } catch { - case NonFatal(_) => - None + def packageSymbolFromString(symbol: String): Option[Symbol] = + if (symbol == "_empty_/") Some(rootMirror.EmptyPackage) + else { + packageSymbols.getOrElseUpdate( + symbol, { + val fqn = symbol.stripSuffix("/").replace('/', '.') + try { + Some(rootMirror.staticPackage(fqn)) + } catch { + case NonFatal(_) => + None + } } - } - ) - } + ) + } /** * Returns a high number for less relevant symbols and low number for relevant numbers. diff --git a/mtags/src/main/scala-3/scala/meta/internal/pc/ScalaPresentationCompiler.scala b/mtags/src/main/scala-3/scala/meta/internal/pc/ScalaPresentationCompiler.scala index 3d5c7c73e41..0fe3963e477 100644 --- a/mtags/src/main/scala-3/scala/meta/internal/pc/ScalaPresentationCompiler.scala +++ b/mtags/src/main/scala-3/scala/meta/internal/pc/ScalaPresentationCompiler.scala @@ -24,6 +24,7 @@ import scala.meta.internal.mtags.MtagsEnrichments.given import scala.meta.internal.pc.completions.CompletionProvider import scala.meta.internal.pc.completions.OverrideCompletions import scala.meta.pc.* +import scala.meta.pc.{PcSymbolInformation as IPcSymbolInformation} import dotty.tools.dotc.reporting.StoreReporter import org.eclipse.lsp4j.DocumentHighlight @@ -187,6 +188,21 @@ case class ScalaPresentationCompiler( def diagnosticsForDebuggingPurposes(): ju.List[String] = List[String]().asJava + override def info( + symbol: String + ): CompletableFuture[Optional[IPcSymbolInformation]] = + compilerAccess.withNonInterruptableCompiler[Optional[IPcSymbolInformation]]( + None + )( + Optional.empty(), + EmptyCancelToken, + ) { access => + SymbolInformationProvider(using access.compiler().currentCtx) + .info(symbol) + .map(_.asJava) + .asJava + } + def semanticdbTextDocument( filename: URI, code: String, diff --git a/mtags/src/main/scala-3/scala/meta/internal/pc/SymbolInformationProvider.scala b/mtags/src/main/scala-3/scala/meta/internal/pc/SymbolInformationProvider.scala new file mode 100644 index 00000000000..f45eb12c98f --- /dev/null +++ b/mtags/src/main/scala-3/scala/meta/internal/pc/SymbolInformationProvider.scala @@ -0,0 +1,126 @@ +package scala.meta.internal.pc + +import scala.util.control.NonFatal + +import scala.meta.internal.mtags.MtagsEnrichments.metalsDealias +import scala.meta.pc.PcSymbolKind +import scala.meta.pc.PcSymbolProperty + +import dotty.tools.dotc.core.Contexts.Context +import dotty.tools.dotc.core.Denotations.Denotation +import dotty.tools.dotc.core.Denotations.MultiDenotation +import dotty.tools.dotc.core.Flags +import dotty.tools.dotc.core.Names.* +import dotty.tools.dotc.core.StdNames.nme +import dotty.tools.dotc.core.Symbols.* + +class SymbolInformationProvider(using Context): + private def toSymbols( + pkg: String, + parts: List[(String, Boolean)], + ): List[Symbol] = + def collectSymbols(denotation: Denotation): List[Symbol] = + denotation match + case MultiDenotation(denot1, denot2) => + collectSymbols(denot1) ++ collectSymbols(denot2) + case denot => List(denot.symbol) + + def loop( + owners: List[Symbol], + parts: List[(String, Boolean)], + ): List[Symbol] = + parts match + case (head, isClass) :: tl => + val foundSymbols = + owners.flatMap { owner => + val next = + if isClass then owner.info.member(typeName(head)) + else owner.info.member(termName(head)) + collectSymbols(next).filter(_.exists) + } + if foundSymbols.nonEmpty then loop(foundSymbols, tl) + else Nil + case Nil => owners + + val pkgSym = + if pkg == "_empty_" then requiredPackage(nme.EMPTY_PACKAGE) + else requiredPackage(pkg) + loop(List(pkgSym), parts) + end toSymbols + + def info(symbol: String): Option[PcSymbolInformation] = + val index = symbol.lastIndexOf("/") + val pkg = normalizePackage(symbol.take(index + 1)) + + def loop( + symbol: String, + acc: List[(String, Boolean)], + ): List[(String, Boolean)] = + if symbol.isEmpty() then acc.reverse + else + val newSymbol = symbol.takeWhile(c => c != '.' && c != '#') + val rest = symbol.drop(newSymbol.size) + loop(rest.drop(1), (newSymbol, rest.headOption.exists(_ == '#')) :: acc) + val names = + loop(symbol.drop(index + 1).takeWhile(_ != '('), List.empty) + + val foundSymbols = + try toSymbols(pkg, names) + catch case NonFatal(e) => Nil + + val (searchedSymbol, alternativeSymbols) = + foundSymbols.partition(compilerSymbol => + SemanticdbSymbols.symbolName(compilerSymbol) == symbol + ) + + searchedSymbol match + case Nil => None + case sym :: _ => + val classSym = if sym.isClass then sym else sym.moduleClass + val parents = + if classSym.isClass + then classSym.asClass.parentSyms.map(SemanticdbSymbols.symbolName) + else Nil + val dealisedSymbol = + if sym.isAliasType then sym.info.metalsDealias.typeSymbol else sym + val classOwner = + sym.ownersIterator.drop(1).find(s => s.isClass || s.is(Flags.Module)) + val overridden = sym.denot.allOverriddenSymbols.toList + + val pcSymbolInformation = + PcSymbolInformation( + symbol = SemanticdbSymbols.symbolName(sym), + kind = getSymbolKind(sym), + parents = parents, + dealiasedSymbol = SemanticdbSymbols.symbolName(dealisedSymbol), + classOwner = classOwner.map(SemanticdbSymbols.symbolName), + overriddenSymbols = overridden.map(SemanticdbSymbols.symbolName), + alternativeSymbols = + alternativeSymbols.map(SemanticdbSymbols.symbolName), + properties = + if sym.is(Flags.Abstract) then List(PcSymbolProperty.ABSTRACT) + else Nil, + ) + + Some(pcSymbolInformation) + end match + end info + + private def getSymbolKind(sym: Symbol): PcSymbolKind = + if sym.isAllOf(Flags.JavaInterface) then PcSymbolKind.INTERFACE + else if sym.is(Flags.Trait) then PcSymbolKind.TRAIT + else if sym.isConstructor then PcSymbolKind.CONSTRUCTOR + else if sym.isPackageObject then PcSymbolKind.PACKAGE_OBJECT + else if sym.isClass then PcSymbolKind.CLASS + else if sym.is(Flags.Macro) then PcSymbolKind.MACRO + else if sym.is(Flags.Local) then PcSymbolKind.LOCAL + else if sym.is(Flags.Method) then PcSymbolKind.METHOD + else if sym.is(Flags.Param) then PcSymbolKind.PARAMETER + else if sym.is(Flags.Package) then PcSymbolKind.PACKAGE + else if sym.is(Flags.TypeParam) then PcSymbolKind.TYPE_PARAMETER + else if sym.isType then PcSymbolKind.TYPE + else PcSymbolKind.UNKNOWN_KIND + + private def normalizePackage(pkg: String): String = + pkg.replace("/", ".").stripSuffix(".") +end SymbolInformationProvider diff --git a/mtags/src/main/scala/scala/meta/internal/mtags/GlobalSymbolIndex.scala b/mtags/src/main/scala/scala/meta/internal/mtags/GlobalSymbolIndex.scala index be54dca8f34..27cc008ba38 100644 --- a/mtags/src/main/scala/scala/meta/internal/mtags/GlobalSymbolIndex.scala +++ b/mtags/src/main/scala/scala/meta/internal/mtags/GlobalSymbolIndex.scala @@ -55,7 +55,7 @@ trait GlobalSymbolIndex { file: AbsolutePath, sourceDirectory: Option[AbsolutePath], dialect: Dialect - ): List[String] + ): Option[IndexingResult] /** * Index a jar or zip file containing Scala and Java source files. @@ -88,7 +88,7 @@ trait GlobalSymbolIndex { def addSourceJar( jar: AbsolutePath, dialect: Dialect - ): List[(String, AbsolutePath)] + ): List[IndexingResult] /** * The same as `addSourceJar` except for directories @@ -96,7 +96,7 @@ trait GlobalSymbolIndex { def addSourceDirectory( dir: AbsolutePath, dialect: Dialect - ): List[(String, AbsolutePath)] + ): List[IndexingResult] } @@ -112,3 +112,9 @@ case class SymbolDefinition( def isExact: Boolean = querySymbol == definitionSymbol } + +case class IndexingResult( + path: AbsolutePath, + topLevels: List[String], + overrides: List[(String, List[OverriddenSymbol])] +) diff --git a/mtags/src/main/scala/scala/meta/internal/mtags/Mtags.scala b/mtags/src/main/scala/scala/meta/internal/mtags/Mtags.scala index 3a80eca84bc..e52b9fcfa11 100644 --- a/mtags/src/main/scala/scala/meta/internal/mtags/Mtags.scala +++ b/mtags/src/main/scala/scala/meta/internal/mtags/Mtags.scala @@ -49,6 +49,35 @@ final class Mtags(implicit rc: ReportContext) { } } + def indexWithOverrides( + path: AbsolutePath, + dialect: Dialect = dialects.Scala213, + includeMembers: Boolean = false + ): (TextDocument, MtagsIndexer.AllOverrides) = { + val input = path.toInput + val language = input.toLanguage + if (language.isJava || language.isScala) { + val mtags = + if (language.isJava) + new JavaToplevelMtags(input) + else + new ScalaToplevelMtags( + input, + includeInnerClasses = true, + includeMembers, + dialect + ) + addLines(language, input.text) + val doc = + Mtags.stdLibPatches.patchDocument( + path, + mtags.index() + ) + val overrides = mtags.overrides() + (doc, overrides) + } else (TextDocument(), Nil) + } + def topLevelSymbols( path: AbsolutePath, dialect: Dialect = dialects.Scala213 @@ -115,7 +144,7 @@ object Mtags { input: Input.VirtualFile, dialect: Dialect, includeMembers: Boolean = true - )(implicit rc: ReportContext = EmptyReportContext): TextDocument = { + )(implicit rc: ReportContext = EmptyReportContext): TextDocument = input.toLanguage match { case Language.JAVA => new JavaMtags(input, includeMembers = true).index() @@ -126,7 +155,6 @@ object Mtags { case _ => TextDocument() } - } def toplevels( path: AbsolutePath, @@ -135,6 +163,16 @@ object Mtags { new Mtags().toplevels(path, dialect) } + def indexWithOverrides( + path: AbsolutePath, + dialect: Dialect = dialects.Scala213, + includeMembers: Boolean = false + )(implicit + rc: ReportContext = EmptyReportContext + ): (TextDocument, MtagsIndexer.AllOverrides) = { + new Mtags().indexWithOverrides(path, dialect, includeMembers) + } + def topLevelSymbols( path: AbsolutePath, dialect: Dialect = dialects.Scala213 @@ -175,4 +213,5 @@ object Mtags { } } + } diff --git a/mtags/src/main/scala/scala/meta/internal/mtags/MtagsIndexer.scala b/mtags/src/main/scala/scala/meta/internal/mtags/MtagsIndexer.scala index 590a7b3a1f1..92694a7e703 100644 --- a/mtags/src/main/scala/scala/meta/internal/mtags/MtagsIndexer.scala +++ b/mtags/src/main/scala/scala/meta/internal/mtags/MtagsIndexer.scala @@ -15,6 +15,8 @@ trait MtagsIndexer { def language: Language def indexRoot(): Unit def input: Input.VirtualFile + // should only be called after `index`/`indexRoot` + def overrides(): MtagsIndexer.AllOverrides = Nil def index(): s.TextDocument = { indexRoot() s.TextDocument( @@ -170,3 +172,7 @@ trait MtagsIndexer { Symbols.Global(Symbols.RootPackage, signature) else Symbols.Global(currentOwner, signature) } + +object MtagsIndexer { + type AllOverrides = List[(String, List[OverriddenSymbol])] +} diff --git a/mtags/src/main/scala/scala/meta/internal/mtags/OnDemandSymbolIndex.scala b/mtags/src/main/scala/scala/meta/internal/mtags/OnDemandSymbolIndex.scala index 8f5beeca2c2..a7b47aa3ba3 100644 --- a/mtags/src/main/scala/scala/meta/internal/mtags/OnDemandSymbolIndex.scala +++ b/mtags/src/main/scala/scala/meta/internal/mtags/OnDemandSymbolIndex.scala @@ -65,7 +65,7 @@ final class OnDemandSymbolIndex( override def addSourceDirectory( dir: AbsolutePath, dialect: Dialect - ): List[(String, AbsolutePath)] = + ): List[IndexingResult] = tryRun( dir, List.empty, @@ -77,7 +77,7 @@ final class OnDemandSymbolIndex( override def addSourceJar( jar: AbsolutePath, dialect: Dialect - ): List[(String, AbsolutePath)] = + ): List[IndexingResult] = tryRun( jar, List.empty, { @@ -109,10 +109,10 @@ final class OnDemandSymbolIndex( source: AbsolutePath, sourceDirectory: Option[AbsolutePath], dialect: Dialect - ): List[String] = + ): Option[IndexingResult] = tryRun( source, - List.empty, { + None, { indexedSources += 1 getOrCreateBucket(dialect).addSourceFile(source, sourceDirectory) } diff --git a/mtags/src/main/scala/scala/meta/internal/mtags/OverriddenSymbol.scala b/mtags/src/main/scala/scala/meta/internal/mtags/OverriddenSymbol.scala new file mode 100644 index 00000000000..9f75858840c --- /dev/null +++ b/mtags/src/main/scala/scala/meta/internal/mtags/OverriddenSymbol.scala @@ -0,0 +1,5 @@ +package scala.meta.internal.mtags + +sealed trait OverriddenSymbol +case class UnresolvedOverriddenSymbol(name: String) extends OverriddenSymbol +case class ResolvedOverriddenSymbol(symbol: String) extends OverriddenSymbol diff --git a/mtags/src/main/scala/scala/meta/internal/mtags/ScalaToplevelMtags.scala b/mtags/src/main/scala/scala/meta/internal/mtags/ScalaToplevelMtags.scala index 46c4ba40a72..2e7484e05ed 100644 --- a/mtags/src/main/scala/scala/meta/internal/mtags/ScalaToplevelMtags.scala +++ b/mtags/src/main/scala/scala/meta/internal/mtags/ScalaToplevelMtags.scala @@ -48,6 +48,14 @@ class ScalaToplevelMtags( )(implicit rc: ReportContext) extends MtagsIndexer { + override def overrides(): List[(String, List[OverriddenSymbol])] = + overridden.result + + private val overridden = List.newBuilder[(String, List[OverriddenSymbol])] + + private def addOverridden(symbols: List[OverriddenSymbol]) = + overridden += ((currentOwner, symbols)) + import ScalaToplevelMtags._ override def language: Language = Language.SCALA @@ -89,7 +97,7 @@ class ScalaToplevelMtags( isImplicit = isImplicit ) ) - def newExpectCaseClassTemplate: Some[ExpectTemplate] = + def newExpectCaseClassTemplate(): Some[ExpectTemplate] = Some( ExpectTemplate( indent, @@ -213,7 +221,7 @@ class ScalaToplevelMtags( emitMember(isPackageObject = false, owner) val template = expectTemplate match { case Some(expect) if expect.isCaseClassConstructor => - newExpectCaseClassTemplate + newExpectCaseClassTemplate() case Some(expect) => newExpectClassTemplate(expect.isImplicit) case _ => @@ -280,7 +288,7 @@ class ScalaToplevelMtags( currRegion.withTermOwner(owner), expectTemplate ) - case DEF | VAL | VAR | GIVEN | TYPE + case DEF | VAL | VAR | GIVEN if expectTemplate.map(!_.isExtension).getOrElse(true) => if (needEmitTermMember()) { withOwner(currRegion.termOwner) { @@ -288,6 +296,13 @@ class ScalaToplevelMtags( } } else scanner.nextToken() loop(indent, isAfterNewline = false, currRegion, newExpectIgnoreBody) + case TYPE if expectTemplate.map(!_.isExtension).getOrElse(true) => + if (needEmitMember(currRegion) && !prevWasDot) { + withOwner(currRegion.termOwner) { + emitType(needEmitTermMember()) + } + } else scanner.nextToken() + loop(indent, isAfterNewline = false, currRegion, newExpectIgnoreBody) case IMPORT | EXPORT => // skip imports because they might have `given` kw acceptToStatSep() @@ -463,6 +478,24 @@ class ScalaToplevelMtags( currRegion.changeCaseClassState(true), nextExpectTemplate ) + case EXTENDS => + val (overridden, maybeNewIndent) = findOverridden(List.empty) + expectTemplate.map(tmpl => + withOwner(tmpl.owner) { + addOverridden( + overridden.reverse + .map(_.name) + .distinct + .map(UnresolvedOverriddenSymbol(_)) + ) + } + ) + loop( + maybeNewIndent.getOrElse(indent), + isAfterNewline = maybeNewIndent.isDefined, + currRegion, + expectTemplate + ) case IDENTIFIER if currRegion.emitIdentifier && includeMembers => withOwner(currRegion.owner) { term( @@ -480,17 +513,14 @@ class ScalaToplevelMtags( ) case CASE => val nextIsNewLine = nextIsNL() - val (shouldCreateClassTemplate, isAfterNewline) = + val isAfterNewline = emitEnumCases(region, nextIsNewLine) - val nextExpectTemplate = - if (shouldCreateClassTemplate) newExpectClassTemplate() - else expectTemplate.filter(!_.isPackageBody) loop( indent, isAfterNewline, currRegion, - if (scanner.curr.token == CLASS) newExpectCaseClassTemplate - else nextExpectTemplate + if (scanner.curr.token == CLASS) newExpectCaseClassTemplate() + else newExpectClassTemplate() ) case IMPLICIT => scanner.nextToken() @@ -555,6 +585,57 @@ class ScalaToplevelMtags( buf.result() } + @tailrec + private def acceptAllAfterOverriddenIdentifier(): Option[Int] = { + val maybeNewIndent = acceptTrivia() + scanner.curr.token match { + case LPAREN => + acceptBalancedDelimeters(LPAREN, RPAREN) + acceptAllAfterOverriddenIdentifier() + case LBRACKET => + acceptBalancedDelimeters(LBRACKET, RBRACKET) + acceptAllAfterOverriddenIdentifier() + case _ => maybeNewIndent + } + + } + + @tailrec + private def findOverridden( + acc0: List[Identifier] + ): (List[Identifier], Option[Int]) = { + val maybeNewIndent0 = acceptTrivia() + scanner.curr.token match { + case IDENTIFIER => + @tailrec + def getIdentifier(): (Option[Identifier], Option[Int]) = { + val currentIdentifier = newIdentifier + val maybeNewIndent = acceptAllAfterOverriddenIdentifier() + scanner.curr.token match { + case DOT => + scanner.nextToken() + getIdentifier() + case _ => (currentIdentifier, maybeNewIndent) + } + } + val (identifier, maybeNewIndent) = getIdentifier() + val acc = identifier.toList ++ acc0 + scanner.curr.token match { + case WITH => findOverridden(acc) + case COMMA => findOverridden(acc) + case _ => (acc, maybeNewIndent) + } + case LBRACE => + acceptBalancedDelimeters(LBRACE, RBRACE) + val maybeNewIndent = acceptTrivia() + scanner.curr.token match { + case WITH => findOverridden(acc0) + case _ => (acc0, maybeNewIndent) + } + case _ => (acc0, maybeNewIndent0) + } + } + /** * Enters a toplevel symbol such as class, trait or object */ @@ -581,8 +662,59 @@ class ScalaToplevelMtags( scanner.nextToken() } + def emitType(emitTermMember: Boolean): Option[Unit] = { + acceptTrivia() + newIdentifier + .map { ident => + val typeSymbol = symbol(Descriptor.Type(ident.name)) + if (emitTermMember) { + tpe(ident.name, ident.pos, Kind.TYPE, 0) + } + nextIsNL() + @tailrec + def loop( + name: Option[String], + isAfterEq: Boolean = false + ): Option[String] = { + scanner.curr.token match { + case SEMI => name + case _ if isNewline | isDone => name + case EQUALS => + scanner.nextToken() + loop(name, isAfterEq = true) + case TYPELAMBDAARROW | WHITESPACE => + scanner.nextToken() + loop(name, isAfterEq) + case LBRACKET => + acceptBalancedDelimeters(LBRACKET, RBRACKET) + scanner.nextToken() + loop(name, isAfterEq) + case LBRACE => + acceptBalancedDelimeters(LBRACE, RBRACE) + scanner.nextToken() + loop(name, isAfterEq) + case IDENTIFIER + if isAfterEq && scanner.curr.name != "|" && scanner.curr.name != "&" => + val optName = selectName() + loop(optName, isAfterEq) + case _ if isAfterEq => None + case _ => + scanner.nextToken() + loop(name) + } + } + + loop(name = None).foreach { rhsName => + overridden += (( + typeSymbol, + List(UnresolvedOverriddenSymbol(rhsName)) + )) + } + } + } + /** - * Enters a global element (def/val/var/type) + * Enters a global element (def/val/var/given) */ def emitTerm(region: Region): Unit = { val kind = scanner.curr.token @@ -608,10 +740,6 @@ class ScalaToplevelMtags( ) resetRegion(region) }) - case TYPE => - newIdentifier.foreach { name => - tpe(name.name, name.pos, Kind.TYPE, 0) - } case DEF => methodIdentifier.foreach(name => method( @@ -638,7 +766,7 @@ class ScalaToplevelMtags( private def emitEnumCases( region: Region, nextIsNewLine: Boolean - ): (Boolean, Boolean) = { + ): Boolean = { def ownerCompanionObject = if (currentOwner.endsWith("#")) s"${currentOwner.stripSuffix("#")}." @@ -648,19 +776,22 @@ class ScalaToplevelMtags( val pos = newPosition val name = scanner.curr.name def emitEnumCaseObject() = { - withOwner(ownerCompanionObject) { - term( - name, - pos, - Kind.METHOD, - SymbolInformation.Property.VAL.value - ) - } + currentOwner = ownerCompanionObject + term( + name, + pos, + Kind.METHOD, + SymbolInformation.Property.VAL.value + ) } + def emitOverridden() = addOverridden( + List(ResolvedOverriddenSymbol(region.owner)) + ) val nextIsNewLine0 = nextIsNL() scanner.curr.token match { case COMMA => emitEnumCaseObject() + emitOverridden() resetRegion(region) val nextIsNewLine1 = nextIsNL() emitEnumCases(region, nextIsNewLine1) @@ -672,12 +803,15 @@ class ScalaToplevelMtags( Kind.CLASS, SymbolInformation.Property.VAL.value ) - (true, false) - case _ => + false + case tok => emitEnumCaseObject() - (false, nextIsNewLine0) + if (tok != EXTENDS) { + emitOverridden() + } + nextIsNewLine0 } - case _ => (false, nextIsNewLine) + case _ => nextIsNewLine } } @@ -738,7 +872,9 @@ class ScalaToplevelMtags( } } - private def acceptTrivia(): Unit = { + private def acceptTrivia(): Option[Int] = { + var includedNewline = false + var indent = 0 scanner.nextToken() while ( !isDone && @@ -747,8 +883,15 @@ class ScalaToplevelMtags( case _ => false }) ) { + if (isNewline) { + includedNewline = true + indent = 0 + } else if (scanner.curr.token == WHITESPACE) { + indent += 1 + } scanner.nextToken() } + if (includedNewline) Some(indent) else None } private def nextIsNL(): Boolean = { @@ -776,7 +919,24 @@ class ScalaToplevelMtags( reportError("identifier") None } + } + def selectName(): Option[String] = { + @tailrec + def loop(last: Option[String]): Option[String] = { + scanner.curr.token match { + case IDENTIFIER => + val name = scanner.curr.name + scanner.nextToken() + loop(Some(name)) + case DOT => + scanner.nextToken() + loop(last) + case _ => + last + } + } + loop(last = None) } /** diff --git a/mtags/src/main/scala/scala/meta/internal/mtags/SymbolIndexBucket.scala b/mtags/src/main/scala/scala/meta/internal/mtags/SymbolIndexBucket.scala index 6e3d5cee52a..ffac79cce78 100644 --- a/mtags/src/main/scala/scala/meta/internal/mtags/SymbolIndexBucket.scala +++ b/mtags/src/main/scala/scala/meta/internal/mtags/SymbolIndexBucket.scala @@ -47,29 +47,36 @@ class SymbolIndexBucket( def close(): Unit = sourceJars.close() - def addSourceDirectory(dir: AbsolutePath): List[(String, AbsolutePath)] = { + def addSourceDirectory( + dir: AbsolutePath + ): List[IndexingResult] = { if (sourceJars.addEntry(dir.toNIO)) { dir.listRecursive.toList.flatMap { case source if source.isScala => - addSourceFile(source, Some(dir)).map(sym => (sym, source)) + addSourceFile(source, Some(dir)) case _ => - List.empty + None } - } else - List.empty + } else List.empty } - def addSourceJar(jar: AbsolutePath): List[(String, AbsolutePath)] = { + def addSourceJar( + jar: AbsolutePath + ): List[IndexingResult] = { if (sourceJars.addEntry(jar.toNIO)) { FileIO.withJarFileSystem(jar, create = false) { root => try { root.listRecursive.toList.flatMap { case source if source.isScala => - addSourceFile(source, None).map(sym => (sym, source)) + addSourceFile(source, None) case source if source.isJava => - addJavaSourceFile(source).map(sym => (sym, source)) + addJavaSourceFile(source) match { + case Nil => None + case topLevels => + Some(IndexingResult(source, topLevels, overrides = Nil)) + } case _ => - List.empty + None } } catch { // this happens in broken jars since file from FileWalker should exists @@ -124,28 +131,34 @@ class SymbolIndexBucket( def addSourceFile( source: AbsolutePath, sourceDirectory: Option[AbsolutePath] - ): List[String] = { - val symbols = indexSource(source, dialect, sourceDirectory) - symbols.foreach { symbol => + ): Option[IndexingResult] = { + val IndexingResult(path, topLevels, overrides) = + indexSource(source, dialect, sourceDirectory) + topLevels.foreach { symbol => toplevels.updateWith(symbol) { case Some(acc) => Some(acc + source) case None => Some(Set(source)) } } - symbols + Some(IndexingResult(path, topLevels, overrides)) } private def indexSource( source: AbsolutePath, dialect: Dialect, sourceDirectory: Option[AbsolutePath] - ): List[String] = { + ): IndexingResult = { val uri = source.toIdeallyRelativeURI(sourceDirectory) - val sourceToplevels = mtags.topLevelSymbols(source, dialect) - if (source.isAmmoniteScript) - sourceToplevels - else - sourceToplevels.filter(sym => !isTrivialToplevelSymbol(uri, sym)) + val (doc, overrides) = mtags.indexWithOverrides(source, dialect) + val sourceTopLevels = + doc.occurrences.iterator + .filterNot(_.symbol.isPackage) + .map(_.symbol) + val topLevels = + if (source.isAmmoniteScript) sourceTopLevels.toList + else + sourceTopLevels.filter(sym => !isTrivialToplevelSymbol(uri, sym)).toList + IndexingResult(source, topLevels, overrides) } // Returns true if symbol is com/foo/Bar# and path is /com/foo/Bar.scala diff --git a/tests/mtest/src/main/scala/tests/DelegatingGlobalSymbolIndex.scala b/tests/mtest/src/main/scala/tests/DelegatingGlobalSymbolIndex.scala index 6d439242638..782503ac6ce 100644 --- a/tests/mtest/src/main/scala/tests/DelegatingGlobalSymbolIndex.scala +++ b/tests/mtest/src/main/scala/tests/DelegatingGlobalSymbolIndex.scala @@ -26,19 +26,19 @@ class DelegatingGlobalSymbolIndex( file: AbsolutePath, sourceDirectory: Option[AbsolutePath], dialect: Dialect - ): List[String] = { + ): Option[mtags.IndexingResult] = { underlying.addSourceFile(file, sourceDirectory, dialect) } def addSourceJar( jar: AbsolutePath, dialect: Dialect - ): List[(String, AbsolutePath)] = { + ): List[mtags.IndexingResult] = { underlying.addSourceJar(jar, dialect) } def addSourceDirectory( dir: AbsolutePath, dialect: Dialect - ): List[(String, AbsolutePath)] = { + ): List[mtags.IndexingResult] = { underlying.addSourceDirectory(dir, dialect) } } diff --git a/tests/slow/src/test/scala/tests/feature/ImplementationCrossLspSuite.scala b/tests/slow/src/test/scala/tests/feature/ImplementationCrossLspSuite.scala new file mode 100644 index 00000000000..8953492c41c --- /dev/null +++ b/tests/slow/src/test/scala/tests/feature/ImplementationCrossLspSuite.scala @@ -0,0 +1,88 @@ +package tests.feature + +import scala.meta.internal.metals.{BuildInfo => V} + +import tests.BaseImplementationSuite + +class ImplementationCrossLspSuite + extends BaseImplementationSuite("implementation-cross") { + + checkSymbols( + "seqFactory", + """package a + |import scala.collection.SeqFactory + |import scala.collection.mutable + | + |object A extends Seq@@Factory[List] { + | + | override def from[A](source: IterableOnce[A]): List[A] = ??? + | + | override def empty[A]: List[A] = ??? + | + | override def newBuilder[A]: mutable.Builder[A, List[A]] = ??? + | + | + |} + |""".stripMargin, + """|a/A. + |scala/collection/ClassTagSeqFactory.AnySeqDelegate# + |scala/collection/IndexedSeq. + |scala/collection/LinearSeq. + |scala/collection/Seq. + |scala/collection/SeqFactory.Delegate# + |scala/collection/StrictOptimizedSeqFactory# + |scala/collection/immutable/IndexedSeq. + |scala/collection/immutable/LazyList. + |scala/collection/immutable/LinearSeq. + |scala/collection/immutable/List. + |scala/collection/immutable/Queue. + |scala/collection/immutable/Seq. + |scala/collection/immutable/Stream. + |scala/collection/immutable/Vector. + |scala/collection/mutable/ArrayBuffer. + |scala/collection/mutable/ArrayDeque. + |scala/collection/mutable/Buffer. + |scala/collection/mutable/IndexedBuffer. + |scala/collection/mutable/IndexedSeq. + |scala/collection/mutable/ListBuffer. + |scala/collection/mutable/Queue. + |scala/collection/mutable/Seq. + |scala/collection/mutable/Stack. + |scala/jdk/AnyAccumulator. + |""".stripMargin, + scalaVersion = V.scala3, + ) + + check( + "basic-method-params", + """|/a/src/main/scala/a/Main.scala + |package a + |trait LivingBeing{ + | def sound: Int + | def s@@ound(times : Int): Int = 1 + | def sound(start : Long): Int = 1 + |} + |abstract class Animal extends LivingBeing{} + |class Dog extends Animal{ + | def sound = 1 + | override def sound(times : Long) = 1 + | override def <>(times : Int) = 1 + |} + |class Cat extends Animal{ + | override def <>(times : Int) = 1 + | override def sound = 1 + |} + |""".stripMargin, + scalaVersion = Some(V.scala3), + ) + + check( + "empty-pkg", + """|/a/src/main/scala/a/Main.scala + |trait A@@A + |class <> extends A@@A + |""".stripMargin, + scalaVersion = Some(V.scala3), + ) + +} diff --git a/tests/unit/src/main/scala/tests/BaseImplementationSuite.scala b/tests/unit/src/main/scala/tests/BaseImplementationSuite.scala new file mode 100644 index 00000000000..3626bbcb383 --- /dev/null +++ b/tests/unit/src/main/scala/tests/BaseImplementationSuite.scala @@ -0,0 +1,60 @@ +package tests + +import scala.concurrent.Future + +import scala.meta.internal.metals.MetalsEnrichments._ + +abstract class BaseImplementationSuite(name: String) + extends BaseRangesSuite(name) { + + override def assertCheck( + filename: String, + edit: String, + expected: Map[String, String], + base: Map[String, String], + ): Future[Unit] = { + server.assertImplementation( + filename, + edit, + expected.toMap, + base.toMap, + ) + } + + def checkSymbols( + name: String, + fileContents: String, + expectedSymbols: String, + scalaVersion: String = BuildInfo.scalaVersion, + ): Unit = + test(name) { + val fileName = "a/src/main/scala/a/Main.scala" + cleanWorkspace() + for { + _ <- initialize( + s"""/metals.json + |{"a": + | { + | "scalaVersion" : "$scalaVersion" + | } + |} + |/$fileName + |${fileContents.replace("@@", "")} + """.stripMargin + ) + _ <- server.didOpen(fileName) + locations <- server.implementation(fileName, fileContents) + definitions <- + Future.sequence( + locations.map(location => + server.server.definitionResult( + location.toTextDocumentPositionParams + ) + ) + ) + symbols = definitions.map(_.symbol).sorted + _ = assertNoDiff(symbols.mkString("\n"), expectedSymbols) + _ <- server.shutdown() + } yield () + } +} diff --git a/tests/unit/src/main/scala/tests/BaseRangesSuite.scala b/tests/unit/src/main/scala/tests/BaseRangesSuite.scala index 8304e59eabb..7b96b99f25f 100644 --- a/tests/unit/src/main/scala/tests/BaseRangesSuite.scala +++ b/tests/unit/src/main/scala/tests/BaseRangesSuite.scala @@ -20,6 +20,8 @@ abstract class BaseRangesSuite(name: String) extends BaseLspSuite(name) { name: TestOptions, input: String, scalaVersion: Option[String] = None, + additionalLibraryDependencies: List[String] = Nil, + scalacOptions: List[String] = Nil, )(implicit loc: Location ): Unit = { @@ -51,7 +53,8 @@ abstract class BaseRangesSuite(name: String) extends BaseLspSuite(name) { |{"a": | { | "scalaVersion" : "$actualScalaVersion", - | "libraryDependencies": ${toJsonArray(libraryDependencies)} + | "libraryDependencies": ${toJsonArray(libraryDependencies ++ additionalLibraryDependencies)}, + | "scalacOptions": ${toJsonArray(scalacOptions)} | } |} |${input diff --git a/tests/unit/src/main/scala/tests/TestingServer.scala b/tests/unit/src/main/scala/tests/TestingServer.scala index e18df026ba1..29fadfd13f3 100644 --- a/tests/unit/src/main/scala/tests/TestingServer.scala +++ b/tests/unit/src/main/scala/tests/TestingServer.scala @@ -1610,12 +1610,19 @@ final case class TestingServer( base: Map[String, String], ): Future[Map[String, String]] = { Debug.printEnclosing() + implementation(filename, query).map( + TestRanges.renderLocationsAsString(base, _) + ) + } + + def implementation( + filename: String, + query: String, + ): Future[List[Location]] = { for { (_, params) <- offsetParams(filename, query, workspace) implementations <- fullServer.implementation(params).asScala - } yield { - TestRanges.renderLocationsAsString(base, implementations.asScala.toList) - } + } yield implementations.asScala.toList } def getReferenceLocations( diff --git a/tests/unit/src/test/scala/tests/ImplementationLspSuite.scala b/tests/unit/src/test/scala/tests/ImplementationLspSuite.scala index 771698e3461..0411b7c22e2 100644 --- a/tests/unit/src/test/scala/tests/ImplementationLspSuite.scala +++ b/tests/unit/src/test/scala/tests/ImplementationLspSuite.scala @@ -1,7 +1,6 @@ package tests -import scala.concurrent.Future -class ImplementationLspSuite extends BaseRangesSuite("implementation") { +class ImplementationLspSuite extends BaseImplementationSuite("implementation") { check( "basic", @@ -255,8 +254,9 @@ class ImplementationLspSuite extends BaseRangesSuite("implementation") { |""".stripMargin, ) + // we currently don't collect information about overridden symbols in JavaMtags check( - "java-classes", + "java-classes".ignore, """|/a/src/main/scala/a/Main.scala |package a |class <> extends Exce@@ption @@ -521,6 +521,9 @@ class ImplementationLspSuite extends BaseRangesSuite("implementation") { | case object <> extends Animal |} |""".stripMargin, + additionalLibraryDependencies = + List("io.circe::circe-generic-extras:0.14.0"), + scalacOptions = List("-Ymacro-annotations"), ) check( @@ -536,6 +539,23 @@ class ImplementationLspSuite extends BaseRangesSuite("implementation") { |} |""".stripMargin, ) + + check( + "local-methods", + """|/a/src/main/scala/a/Main.scala + |object Test { + | def main { + | trait A { + | def f@@oo(): Int + | } + | class B extends A { + | def <>(): Int = 1 + | } + | } + |} + |""".stripMargin, + ) + check( "type-implementation", """|/a/src/main/scala/a/Main.scala @@ -553,7 +573,9 @@ class ImplementationLspSuite extends BaseRangesSuite("implementation") { check( "java-implementation", - """|/a/src/main/scala/a/Main.java + """|/a/src/main/scala/a/Main.scala + |// empty scala file, so Scala pc is loaded + |/a/src/main/scala/a/Main.java |package a; |public class Main { | abstract class A { @@ -585,20 +607,59 @@ class ImplementationLspSuite extends BaseRangesSuite("implementation") { |""".stripMargin, ) + checkSymbols( + "set", + """|package a + |class MySet[A] extends S@@et[A] { + | override def iterator: Iterator[A] = ??? + | override def contains(elem: A): Boolean = ??? + | override def incl(elem: A): Set[A] = ??? + | override def excl(elem: A): Set[A] = ??? + |} + |""".stripMargin, + """|a/MySet# + |scala/Enumeration#ValueSet# + |scala/collection/immutable/AbstractSet# + |scala/collection/immutable/BitSet# + |scala/collection/immutable/BitSet.BitSet1# + |scala/collection/immutable/BitSet.BitSet2# + |scala/collection/immutable/BitSet.BitSetN# + |scala/collection/immutable/HashMap#HashKeySet# + |scala/collection/immutable/HashSet# + |scala/collection/immutable/ListSet# + |scala/collection/immutable/ListSet#Node# + |scala/collection/immutable/ListSet.EmptyListSet. + |scala/collection/immutable/MapOps#ImmutableKeySet# + |scala/collection/immutable/Set.EmptySet. + |scala/collection/immutable/Set.Set1# + |scala/collection/immutable/Set.Set2# + |scala/collection/immutable/Set.Set3# + |scala/collection/immutable/Set.Set4# + |scala/collection/immutable/SortedMapOps#ImmutableKeySortedSet# + |scala/collection/immutable/SortedSet# + |scala/collection/immutable/TreeSet# + |""".stripMargin, + ) + + checkSymbols( + "exception", + """package a + |class MyException extends Excep@@tion + |""".stripMargin, + """|a/MyException# + |scala/ScalaReflectionException# + |scala/reflect/internal/FatalError# + |scala/reflect/internal/MissingRequirementError# + |scala/reflect/internal/Positions#ValidateException# + |scala/reflect/macros/Enclosures#EnclosureException# + |scala/reflect/macros/ParseException# + |scala/reflect/macros/ReificationException# + |scala/reflect/macros/TypecheckException# + |scala/reflect/macros/UnexpectedReificationException# + |""".stripMargin, + ) + override protected def libraryDependencies: List[String] = List("org.scalatest::scalatest:3.2.16", "io.circe::circe-generic:0.12.0") - override def assertCheck( - filename: String, - edit: String, - expected: Map[String, String], - base: Map[String, String], - ): Future[Unit] = { - server.assertImplementation( - filename, - edit, - expected.toMap, - base.toMap, - ) - } } diff --git a/tests/unit/src/test/scala/tests/JarTopLevelsSuite.scala b/tests/unit/src/test/scala/tests/JarTopLevelsSuite.scala index d67afa3d4ed..c6444aee90d 100644 --- a/tests/unit/src/test/scala/tests/JarTopLevelsSuite.scala +++ b/tests/unit/src/test/scala/tests/JarTopLevelsSuite.scala @@ -2,10 +2,13 @@ package tests import java.nio.file.Files import java.nio.file.Path +import java.sql.PreparedStatement +import java.sql.Statement import scala.meta.internal.io.FileIO import scala.meta.internal.io.PlatformFileIO import scala.meta.internal.metals.JarTopLevels +import scala.meta.internal.mtags.UnresolvedOverriddenSymbol import scala.meta.io.AbsolutePath class JarTopLevelsSuite extends BaseTablesSuite { @@ -19,7 +22,7 @@ class JarTopLevelsSuite extends BaseTablesSuite { FileIO.withJarFileSystem(zip, create = true, close = true) { root => FileLayout.fromString( """|/foo.scala - |object Hello { + |case class Hello(i: Int) extends AnyVal { |}""".stripMargin, root, ) @@ -31,13 +34,20 @@ class JarTopLevelsSuite extends BaseTablesSuite { val fs = PlatformFileIO.newJarFileSystem(jar1, create = false) val filePath = AbsolutePath(fs.getPath("/foo.scala")) val toplevels = List("foo" -> filePath) - jarSymbols.putTopLevels(jar1, toplevels) + val overrides = + List((filePath, "foo/Hello#", UnresolvedOverriddenSymbol("AnyVal"))) + jarSymbols.putJarIndexingInfo(jar1, toplevels, overrides) val resultOption = jarSymbols.getTopLevels(jar1) assert(resultOption.isDefined) val result = resultOption.get assert(toplevels == result) + val resultOption1 = jarSymbols.getTypeHierarchy(jar1) + assert(resultOption1.isDefined) + val result1 = resultOption1.get + assert(overrides == result1) val noOption = jarSymbols.getTopLevels(jar2) assert(noOption.isEmpty) + assert(jarSymbols.getTypeHierarchy(jar2).isEmpty) } test("deleteNotUsed") { @@ -45,16 +55,58 @@ class JarTopLevelsSuite extends BaseTablesSuite { val fs = PlatformFileIO.newJarFileSystem(jar, create = false) val filePath = AbsolutePath(fs.getPath("/foo.scala")) val toplevels = List("foo" -> filePath) - jarSymbols.putTopLevels(jar, toplevels) + val overrides = + List((filePath, "foo/Hello#", UnresolvedOverriddenSymbol("AnyVal"))) + jarSymbols.putJarIndexingInfo(jar, toplevels, overrides) } jarSymbols.deleteNotUsedTopLevels(Array(jar1, jar1)) assert(jarSymbols.getTopLevels(jar1).isDefined) + assert(jarSymbols.getTypeHierarchy(jar1).isDefined) assert(jarSymbols.getTopLevels(jar2).isEmpty) + assert(jarSymbols.getTypeHierarchy(jar2).isEmpty) } test("noSymbols") { - jarSymbols.putTopLevels(jar1, List.empty) + jarSymbols.putJarIndexingInfo(jar1, List.empty, List.empty) val result = jarSymbols.getTopLevels(jar1) assert(result.isEmpty) } + + test("addTypeHierarchy") { + val fs = PlatformFileIO.newJarFileSystem(jar1, create = false) + val filePath = AbsolutePath(fs.getPath("/foo.scala")) + val toplevels = List("foo" -> filePath) + val overrides = + List((filePath, "foo/Hello#", UnresolvedOverriddenSymbol("AnyVal"))) + + var jarStmt: PreparedStatement = null + val jar = + try { + jarStmt = tables + .connect() + .prepareStatement( + s"insert into indexed_jar (md5) values (?)", + Statement.RETURN_GENERATED_KEYS, + ) + jarStmt.setString(1, tables.jarSymbols.getMD5Digest(jar1)) + jarStmt.executeUpdate() + val rs = jarStmt.getGeneratedKeys + rs.next() + rs.getInt("id") + } finally { + if (jarStmt != null) jarStmt.close() + } + tables.jarSymbols.putToplevels(jar, toplevels) + + assert(jarSymbols.getTopLevels(jar1).nonEmpty) + assert(jarSymbols.getTypeHierarchy(jar1).isEmpty) + + jarSymbols.addTypeHierarchyInfo(jar1, overrides) + val obtainedTopLevels = jarSymbols.getTopLevels(jar1) + assert(obtainedTopLevels.nonEmpty) + assert(obtainedTopLevels.get == toplevels) + val obtainedTypeHierarchy = jarSymbols.getTypeHierarchy(jar1) + assert(obtainedTypeHierarchy.nonEmpty) + assert(obtainedTypeHierarchy.get == overrides) + } } diff --git a/tests/unit/src/test/scala/tests/ScalaToplevelSuite.scala b/tests/unit/src/test/scala/tests/ScalaToplevelSuite.scala index 169f1de83fc..90bf52ca242 100644 --- a/tests/unit/src/test/scala/tests/ScalaToplevelSuite.scala +++ b/tests/unit/src/test/scala/tests/ScalaToplevelSuite.scala @@ -4,9 +4,10 @@ import java.nio.file.Files import scala.meta.Dialect import scala.meta.dialects -import scala.meta.inputs.Input import scala.meta.internal.metals.MetalsEnrichments._ import scala.meta.internal.mtags.Mtags +import scala.meta.internal.mtags.ResolvedOverriddenSymbol +import scala.meta.internal.mtags.UnresolvedOverriddenSymbol import scala.meta.io.AbsolutePath import munit.TestOptions @@ -55,7 +56,7 @@ class ScalaToplevelSuite extends BaseSuite { List( "_empty_/A.", "_empty_/A.foo().", "_empty_/A.Z#", "_empty_/B#", "_empty_/B#X#", "_empty_/B#foo().", "_empty_/B#v.", "_empty_/C#", - "_empty_/C#i.", "_empty_/D#", "_empty_/D.Da.", "_empty_/D.Db.", + "_empty_/C#i.", "_empty_/D#", "_empty_/D.Da. -> D", "_empty_/D.Db. -> D", "_empty_/D#getI().", "_empty_/D#i.", ), mode = All, @@ -103,7 +104,7 @@ class ScalaToplevelSuite extends BaseSuite { List( "_empty_/A.", "_empty_/A.foo().", "_empty_/A.Z#", "_empty_/B#", "_empty_/B#X#", "_empty_/B#foo().", "_empty_/C#", "_empty_/D#", - "_empty_/D.Da.", "_empty_/D.Db.", + "_empty_/D.Da. -> _empty_/D#", "_empty_/D.Db. -> _empty_/D#", ), mode = All, ) @@ -475,9 +476,10 @@ class ScalaToplevelSuite extends BaseSuite { | |enum NotPlanets{ case Vase } |""".stripMargin, - List("a/", "a/Planets#", "a/Planets.Earth.", "a/Planets.Mercury.", - "a/Planets#num.", "a/Planets.Venus.", "a/NotPlanets#", - "a/NotPlanets.Vase."), + List("a/", "a/Planets#", "a/Planets.Earth. -> Planets", + "a/Planets.Mercury. -> Planets", "a/Planets#num.", + "a/Planets.Venus. -> Planets", "a/NotPlanets#", + "a/NotPlanets.Vase. -> a/NotPlanets#"), dialect = dialects.Scala3, mode = All, ) @@ -498,9 +500,10 @@ class ScalaToplevelSuite extends BaseSuite { |enum NotPlanets: | case Vase |""".stripMargin, - List("a/", "a/Planets#", "a/Planets.Earth.", "a/Planets.Mercury.", - "a/Planets#num.", "a/Planets.Venus.", "a/NotPlanets#", - "a/NotPlanets.Vase."), + List("a/", "a/Planets#", "a/Planets.Earth. -> Planets", + "a/Planets.Mercury. -> Planets", "a/Planets#num.", + "a/Planets.Venus. -> Planets", "a/NotPlanets#", + "a/NotPlanets.Vase. -> a/NotPlanets#"), dialect = dialects.Scala3, mode = All, ) @@ -517,9 +520,10 @@ class ScalaToplevelSuite extends BaseSuite { |enum NotPlanets: | case Vase |""".stripMargin, - List("a/", "a/Planets#", "a/Planets#mmm().", "a/Planets.Earth#", - "a/Planets.Earth#v.", "a/Planets.Mercury#", "a/Planets#num.", - "a/Planets.Venus#", "a/NotPlanets#", "a/NotPlanets.Vase."), + List("a/", "a/Planets#", "a/Planets#mmm().", "a/Planets.Earth# -> Planets", + "a/Planets.Earth#v.", "a/Planets.Mercury# -> Planets", "a/Planets#num.", + "a/Planets.Venus# -> Planets", "a/NotPlanets#", + "a/NotPlanets.Vase. -> a/NotPlanets#"), dialect = dialects.Scala3, mode = All, ) @@ -534,7 +538,7 @@ class ScalaToplevelSuite extends BaseSuite { | } |} |""".stripMargin, - List("a/", "a/TypeProxy#"), + List("a/", "a/TypeProxy# -> Type"), dialect = dialects.Scala3, mode = ToplevelWithInner, ) @@ -595,12 +599,62 @@ class ScalaToplevelSuite extends BaseSuite { // It is easier to work around this inconstancy in `SemanticdbSymbols.inverseSemanticdbSymbol` // than to change symbols emitted by `ScalaTopLevelMtags`, // since the object could be placed before type definition. - List("s/", "s/Test$package.", "s/Test$package.Cow#", "s/Cow.", + List("s/", "s/Test$package.", "s/Test$package.Cow# -> Long", "s/Cow.", "s/Cow.apply()."), dialect = dialects.Scala3, mode = All, ) + check( + "overridden", + """|package a + |case class A[T](v: Int)(using Context) extends B[Int](2) with C: + | object O extends H + |class M(ctx: Context) extends W(1)(ctx) + |""".stripMargin, + List("a/", "a/A# -> B, C", "a/A#v.", "a/A#O. -> H", "a/M# -> W"), + dialect = dialects.Scala3, + mode = All, + ) + + check( + "overridden2", + """|package a + |class A extends b.B + |""".stripMargin, + List("a/", "a/A# -> B"), + mode = All, + ) + + check( + "overridden3", + """|package a + |class A extends B, C + |""".stripMargin, + List("a/", "a/A# -> B, C"), + mode = All, + ) + + check( + "overridden-type-alias", + """|package a + |object O { + | type A[X] = Set[X] + | type W[X] = mutable.Set[X] + | type H = [X] =>> List[X] + | type R = Set[Int] { def a: Int } + | opaque type L <: mutable.List[Int] = mutable.List[Int] + | type Elem[X] = X match + | case String => Char + | case Array[t] => t + | case Iterable[t] => t + |} + |""".stripMargin, + List("a/", "a/O.", "a/O.A# -> Set", "a/O.H# -> List", "a/O.W# -> Set", + "a/O.R# -> Set", "a/O.L# -> List", "a/O.Elem#"), + mode = All, + ) + def check( options: TestOptions, code: String, @@ -609,25 +663,35 @@ class ScalaToplevelSuite extends BaseSuite { dialect: Dialect = dialects.Scala3, )(implicit location: munit.Location): Unit = { test(options) { + val dir = AbsolutePath(Files.createTempDirectory("mtags")) + val input = dir.resolve("Test.scala") + input.writeText(code) val obtained = mode match { case All | ToplevelWithInner => - val input = Input.VirtualFile("Test.scala", code) val includeMembers = mode == All - Mtags - .allToplevels(input, dialect, includeMembers) - .occurrences - .map(_.symbol) - .toList - case Toplevel => - val dir = AbsolutePath(Files.createTempDirectory("mtags")) - val input = dir.resolve("Test.scala") - input.writeText(code) - val obtained = Mtags.topLevelSymbols(input, dialect) - input.delete() - dir.delete() - obtained + val (doc, overrides) = + Mtags.indexWithOverrides(input, dialect, includeMembers) + val symbols = doc.occurrences.map(_.symbol).toList + val overriddenMap = overrides.toMap + symbols.map { symbol => + overriddenMap.get(symbol) match { + case None => symbol + case Some(symbols) => + val overridden = + symbols + .map { + case ResolvedOverriddenSymbol(symbol) => symbol + case UnresolvedOverriddenSymbol(name) => name + } + .mkString(", ") + s"$symbol -> $overridden" + } + } + case Toplevel => Mtags.topLevelSymbols(input, dialect) } + input.delete() + dir.delete() assertNoDiff( obtained.sorted.mkString("\n"), expected.sorted.mkString("\n"),