Skip to content

Commit

Permalink
improvement: use pc for finding references of local symbols and when …
Browse files Browse the repository at this point in the history
…semanticdb is missing (#5940)

* improvement: use pc for finding references of local symbols

* delete unused `referencedPackages`

* use pc for references as fallback when missing semanticdb

* move collecting identifiers to mtags

* fixes

* small review changes

* add back `compileAndLookForNewReferences`

* benchmarks

* get file content on demand

* small fix

* scalafix

* fixes after rebase

* add test for rename with un-compiled build target

* post rebase fixes

* refactor: drop using buffers in pc

* review fixes

* filter if empty locations

* post rebase fixes
  • Loading branch information
kasiaMarek authored May 22, 2024
1 parent df933f2 commit d12e21e
Show file tree
Hide file tree
Showing 50 changed files with 2,109 additions and 1,093 deletions.
16 changes: 11 additions & 5 deletions metals-bench/src/main/scala/bench/Inflated.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,22 @@ import scala.meta.internal.io.FileIO
import scala.meta.io.AbsolutePath
import scala.meta.io.Classpath

case class Inflated(inputs: List[Input.VirtualFile], linesOfCode: Long) {
case class Inflated(
inputs: List[(Input.VirtualFile, AbsolutePath)],
linesOfCode: Long,
) {
def filter(f: Input.VirtualFile => Boolean): Inflated = {
val newInputs = inputs.filter(input => f(input))
val newInputs = inputs.filter { case (input, _) => f(input) }
val newLinesOfCode = newInputs.foldLeft(0) { case (accum, input) =>
accum + input.text.linesIterator.length
accum + input._1.text.linesIterator.length
}
Inflated(newInputs, newLinesOfCode)
}
def +(other: Inflated): Inflated =
Inflated(other.inputs ++ inputs, other.linesOfCode + linesOfCode)

def foreach(f: Input.VirtualFile => Unit): Unit =
inputs.foreach { case (file, _) => f(file) }
}

object Inflated {
Expand All @@ -33,12 +39,12 @@ object Inflated {
close = true,
) { root =>
var lines = 0L
val buf = List.newBuilder[Input.VirtualFile]
val buf = List.newBuilder[(Input.VirtualFile, AbsolutePath)]
FileIO.listAllFilesRecursively(root).foreach { file =>
val path = file.toURI.toString()
val text = FileIO.slurp(file, StandardCharsets.UTF_8)
lines += text.linesIterator.length
buf += Input.VirtualFile(path, text)
buf += ((Input.VirtualFile(path, text), file))
}
val inputs = buf.result()
Inflated(inputs, lines)
Expand Down
72 changes: 61 additions & 11 deletions metals-bench/src/main/scala/bench/MetalsBench.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import scala.tools.nsc.interactive.Global

import scala.meta.dialects
import scala.meta.interactive.InteractiveSemanticdb
import scala.meta.internal.metals.EmptyReportContext
import scala.meta.internal.metals.IdentifierIndex
import scala.meta.internal.metals.JdkSources
import scala.meta.internal.metals.LoggerReportContext
import scala.meta.internal.metals.ReportContext
Expand All @@ -24,6 +26,7 @@ import scala.meta.internal.tokenizers.LegacyToken
import scala.meta.io.AbsolutePath
import scala.meta.io.Classpath

import ch.epfl.scala.bsp4j.BuildTargetIdentifier
import org.openjdk.jmh.annotations.Benchmark
import org.openjdk.jmh.annotations.BenchmarkMode
import org.openjdk.jmh.annotations.Mode
Expand Down Expand Up @@ -78,18 +81,19 @@ class MetalsBench {
.flatMap(_.sources.entries)
.filter(_.toNIO.getFileName.toString.endsWith(".jar"))
)

@Benchmark
@BenchmarkMode(Array(Mode.SingleShotTime))
def mtagsScalaIndex(): Unit = {
scalaDependencySources.inputs.foreach { input =>
scalaDependencySources.foreach { input =>
ScalaMtags.index(input, dialects.Scala213).index()
}
}

@Benchmark
@BenchmarkMode(Array(Mode.SingleShotTime))
def toplevelsScalaIndex(): Unit = {
scalaDependencySources.inputs.foreach { input =>
scalaDependencySources.inputs.foreach { case (input, _) =>
implicit val rc: ReportContext = LoggerReportContext
new ScalaToplevelMtags(
input,
Expand All @@ -103,7 +107,7 @@ class MetalsBench {
@Benchmark
@BenchmarkMode(Array(Mode.SingleShotTime))
def typeHierarchyIndex(): Unit = {
scalaDependencySources.inputs.foreach { input =>
scalaDependencySources.inputs.foreach { case (input, _) =>
implicit val rc: ReportContext = LoggerReportContext
new ScalaToplevelMtags(
input,
Expand All @@ -117,7 +121,7 @@ class MetalsBench {
@Benchmark
@BenchmarkMode(Array(Mode.SingleShotTime))
def scalaTokenize(): Unit = {
scalaDependencySources.inputs.foreach { input =>
scalaDependencySources.foreach { input =>
val scanner = new LegacyScanner(input, Trees.defaultTokenizerDialect)
var i = 0
scanner.foreach(_ => i += 1)
Expand All @@ -128,7 +132,7 @@ class MetalsBench {
@BenchmarkMode(Array(Mode.SingleShotTime))
def scalacTokenize(): Unit = {
val g = global
scalaDependencySources.inputs.foreach { input =>
scalaDependencySources.foreach { input =>
val unit = new g.CompilationUnit(
new BatchSourceFile(new VirtualFile(input.path), input.chars)
)
Expand All @@ -143,7 +147,7 @@ class MetalsBench {
@Benchmark
@BenchmarkMode(Array(Mode.SingleShotTime))
def scalametaParse(): Unit = {
scalaDependencySources.inputs.foreach { input =>
scalaDependencySources.foreach { input =>
import scala.meta._
Trees.defaultTokenizerDialect(input).parse[Source].get
}
Expand All @@ -155,7 +159,7 @@ class MetalsBench {
@BenchmarkMode(Array(Mode.SingleShotTime))
def scalacParse(): Unit = {
val g = global
scalaDependencySources.inputs.foreach { input =>
scalaDependencySources.foreach { input =>
val unit = new g.CompilationUnit(
new BatchSourceFile(new VirtualFile(input.path), input.chars)
)
Expand All @@ -173,7 +177,7 @@ class MetalsBench {
@Benchmark
@BenchmarkMode(Array(Mode.SingleShotTime))
def mtagsJavaParse(): Unit = {
javaDependencySources.inputs.foreach { input =>
javaDependencySources.foreach { input =>
JavaMtags
.index(input, includeMembers = true)
.index()
Expand All @@ -183,7 +187,7 @@ class MetalsBench {
@Benchmark
@BenchmarkMode(Array(Mode.SingleShotTime))
def toplevelJavaMtags(): Unit = {
javaDependencySources.inputs.foreach { input =>
javaDependencySources.inputs.foreach { case (input, _) =>
new JavaToplevelMtags(input, includeInnerClasses = true)(
LoggerReportContext
).index()
Expand All @@ -202,8 +206,54 @@ class MetalsBench {
@Benchmark
@BenchmarkMode(Array(Mode.SingleShotTime))
def alltoplevelsScalaIndex(): Unit = {
scalaDependencySources.inputs.foreach { input =>
Mtags.allToplevels(input, dialects.Scala3)
scalaDependencySources.foreach { input =>
Mtags.allToplevels(input, dialects.Scala213)
}
}

@Benchmark
@BenchmarkMode(Array(Mode.SingleShotTime))
def alltoplevelsScalaIndexWithCollectIdents(): Unit = {
scalaDependencySources.foreach { input =>
new ScalaToplevelMtags(
input,
includeInnerClasses = true,
includeMembers = true,
dialects.Scala213,
collectIdentifiers = true,
)(EmptyReportContext).indexRoot()
}
}

@Benchmark
@BenchmarkMode(Array(Mode.SingleShotTime))
def alltoplevelsScalaIndexWithBuildIdentifierIndex(): Unit = {
val buildTargetIdent = List(
new BuildTargetIdentifier("id1"),
new BuildTargetIdentifier("id2"),
new BuildTargetIdentifier("id3"),
)
var btIndex = 0
val index = new IdentifierIndex
scalaDependencySources.inputs.foreach { case (input, path) =>
val mtags = new ScalaToplevelMtags(
input,
includeInnerClasses = true,
includeMembers = true,
dialects.Scala213,
collectIdentifiers = true,
)(EmptyReportContext)

mtags.indexRoot()

val identifiers = mtags.allIdentifiers
if (identifiers.nonEmpty)
index.addIdentifiers(
path,
buildTargetIdent(btIndex),
identifiers,
)
btIndex = (btIndex + 1) % 3
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package scala.meta.internal.metals
import java.{util => ju}

import scala.meta.internal.metals.MetalsEnrichments._
import scala.meta.pc
import scala.meta.pc.AutoImportsResult
import scala.meta.pc.HoverSignature

Expand Down Expand Up @@ -50,6 +51,17 @@ trait AdjustLspData {
diag
}

def adjustLocation(location: Location): Location =
new Location(location.getUri(), adjustRange(location.getRange()))

def adjustReferencesResult(
referencesResult: pc.ReferencesResult
): ReferencesResult =
new ReferencesResult(
referencesResult.symbol,
referencesResult.locations().asScala.map(adjustLocation).toList,
)

def adjustLocations(
locations: java.util.List[Location]
): ju.List[Location]
Expand Down
41 changes: 41 additions & 0 deletions metals/src/main/scala/scala/meta/internal/metals/Compilers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import scala.util.control.NonFatal

import scala.meta.inputs.Input
import scala.meta.inputs.Position
import scala.meta.internal
import scala.meta.internal.builds.SbtBuildTool
import scala.meta.internal.metals.CompilerOffsetParamsUtils
import scala.meta.internal.metals.CompilerRangeParamsUtils
Expand Down Expand Up @@ -48,6 +49,7 @@ import org.eclipse.lsp4j.InitializeParams
import org.eclipse.lsp4j.InlayHint
import org.eclipse.lsp4j.InlayHintKind
import org.eclipse.lsp4j.InlayHintParams
import org.eclipse.lsp4j.ReferenceParams
import org.eclipse.lsp4j.RenameParams
import org.eclipse.lsp4j.SelectionRange
import org.eclipse.lsp4j.SelectionRangeParams
Expand All @@ -57,6 +59,7 @@ import org.eclipse.lsp4j.SignatureHelp
import org.eclipse.lsp4j.TextDocumentIdentifier
import org.eclipse.lsp4j.TextDocumentPositionParams
import org.eclipse.lsp4j.TextEdit
import org.eclipse.lsp4j.jsonrpc.messages.{Either => JEither}
import org.eclipse.lsp4j.{Position => LspPosition}
import org.eclipse.lsp4j.{Range => LspRange}
import org.eclipse.lsp4j.{debug => d}
Expand Down Expand Up @@ -727,6 +730,44 @@ class Compilers(
}
}.getOrElse(Future.successful(Nil.asJava))

def references(
params: ReferenceParams,
token: CancelToken,
): Future[List[ReferencesResult]] = {
withPCAndAdjustLsp(params) { case (pc, pos, adjust) =>
val requestParams = new internal.pc.PcReferencesRequest(
CompilerOffsetParamsUtils.fromPos(pos, token),
params.getContext().isIncludeDeclaration(),
JEither.forLeft(pos.start),
)
pc.references(requestParams)
.asScala
.map(_.asScala.map(adjust.adjustReferencesResult).toList)
}
}.getOrElse(Future.successful(Nil))

def references(
searchFile: AbsolutePath,
includeDefinition: Boolean,
symbol: String,
): Future[List[ReferencesResult]] =
loadCompiler(searchFile)
.map { compiler =>
val uri = searchFile.toURI
val (input, _, adjust) =
sourceAdjustments(uri.toString(), compiler.scalaVersion())
val requestParams = new internal.pc.PcReferencesRequest(
CompilerVirtualFileParams(uri, input.text),
includeDefinition,
JEither.forRight(symbol),
)
compiler
.references(requestParams)
.asScala
.map(_.asScala.map(adjust.adjustReferencesResult).toList)
}
.getOrElse(Future.successful(Nil))

def extractMethod(
doc: TextDocumentIdentifier,
range: LspRange,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
package scala.meta.internal.metals

import java.nio.charset.StandardCharsets
import java.nio.file.Path

import scala.collection.concurrent.TrieMap
import scala.util.control.NonFatal

import scala.meta.Dialect
import scala.meta.inputs.Input
import scala.meta.internal.tokenizers.LegacyScanner
import scala.meta.internal.tokenizers.LegacyToken._
import scala.meta.io.AbsolutePath

import ch.epfl.scala.bsp4j.BuildTargetIdentifier
import com.google.common.hash.BloomFilter
import com.google.common.hash.Funnels

class IdentifierIndex {
val index: TrieMap[Path, IdentifierIndex.IndexEntry] = TrieMap.empty

def addIdentifiers(
file: AbsolutePath,
id: BuildTargetIdentifier,
set: Iterable[String],
): Unit = {
val bloom = BloomFilter.create(
Funnels.stringFunnel(StandardCharsets.UTF_8),
Integer.valueOf(set.size * 2),
0.01,
)

val entry = IdentifierIndex.IndexEntry(id, bloom)
index(file.toNIO) = entry
set.foreach(bloom.put)
}

def collectIdentifiers(
text: String,
dialect: Dialect,
): Iterable[String] = {
val identifiers = Set.newBuilder[String]

try {
new LegacyScanner(Input.String(text), dialect).foreach {
case ident if ident.token == IDENTIFIER => identifiers += ident.name
case _ =>
}
} catch {
case NonFatal(_) =>
}

identifiers.result()
}
}

object IdentifierIndex {
case class IndexEntry(
id: BuildTargetIdentifier,
bloom: BloomFilter[CharSequence],
)
}
Loading

0 comments on commit d12e21e

Please sign in to comment.