Skip to content

Commit

Permalink
add preferred dialect to workspace symbol search
Browse files Browse the repository at this point in the history
  • Loading branch information
kasiaMarek committed Nov 21, 2023
1 parent 7062b8c commit 6c67b4b
Show file tree
Hide file tree
Showing 8 changed files with 44 additions and 13 deletions.
2 changes: 1 addition & 1 deletion metals-bench/src/main/scala/bench/ClasspathFuzzBench.scala
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class ClasspathFuzzBench {
@BenchmarkMode(Array(Mode.SingleShotTime))
@OutputTimeUnit(TimeUnit.MILLISECONDS)
def run(): Seq[SymbolInformation] = {
symbols.search(query)
symbols.search(query, None)
}

}
2 changes: 1 addition & 1 deletion metals-bench/src/main/scala/bench/WorkspaceFuzzBench.scala
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class WorkspaceFuzzBench {
@BenchmarkMode(Array(Mode.SingleShotTime))
@OutputTimeUnit(TimeUnit.MILLISECONDS)
def upper(): Seq[SymbolInformation] = {
symbols.search(query)
symbols.search(query, None)
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -175,8 +175,9 @@ final class DefinitionProvider(
else true
}

val dialect = scalaVersionSelector.dialectFromBuildTarget(path)
val locs = workspaceSearch
.searchExactFrom(ident.value, path, token)
.searchExactFrom(ident.value, path, token, dialect)

val reducedGuesses =
if (locs.size > 1)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1677,7 +1677,8 @@ class MetalsLspService(
): Future[List[SymbolInformation]] =
indexingPromise.future.map { _ =>
val timer = new Timer(time)
val result = workspaceSymbols.search(params.getQuery, token).toList
val result =
workspaceSymbols.search(params.getQuery, token, currentDialect).toList
if (clientConfig.initialConfig.statistics.isWorkspaceSymbol) {
scribe.info(
s"time: found ${result.length} results for query '${params.getQuery}' in $timer"
Expand All @@ -1687,9 +1688,12 @@ class MetalsLspService(
}

def workspaceSymbol(query: String): Seq[SymbolInformation] = {
workspaceSymbols.search(query)
workspaceSymbols.search(query, currentDialect)
}

private def currentDialect =
focusedDocument().flatMap(scalaVersionSelector.dialectFromBuildTarget)

def indexSources(): Future[Unit] = Future {
indexer.indexWorkspaceSources(buildTargets.allWritableData)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import java.{util => ju}

import scala.collection.mutable

import scala.meta.Dialect
import scala.meta.dialects
import scala.meta.internal.metals.MetalsEnrichments._
import scala.meta.internal.mtags.GlobalSymbolIndex
import scala.meta.internal.mtags.Symbol
Expand Down Expand Up @@ -32,6 +34,7 @@ class WorkspaceSearchVisitor(
token: CancelChecker,
index: GlobalSymbolIndex,
saveClassFileToDisk: Boolean,
preferredDialect: Option[Dialect],
)(implicit rc: ReportContext)
extends SymbolSearchVisitor {
private val fromWorkspace = new ju.ArrayList[l.SymbolInformation]()
Expand Down Expand Up @@ -88,14 +91,21 @@ class WorkspaceSearchVisitor(
): Option[SymbolDefinition] = {
val nme = Classfile.name(filename)
val tpe = Symbol(Symbols.Global(pkg, Descriptor.Type(nme)))

val preferredDialects = preferredDialect match {
case Some(dialects.Scala213) =>
Set(dialects.Scala213, dialects.Scala213Source3)
case Some(dialects.Scala212) =>
Set(dialects.Scala212, dialects.Scala212Source3)
case opt => opt.toSet
}
val forTpe = index.definitions(tpe)
val defs = if (forTpe.isEmpty) {
val term = Symbol(Symbols.Global(pkg, Descriptor.Term(nme)))
index.definitions(term)
} else forTpe

defs.sortBy(_.path.toURI.toString).headOption
defs.sortBy { defn =>
(!preferredDialects(defn.dialect), defn.path.toURI.toString)
}.headOption
}
override def shouldVisitPackage(pkg: String): Boolean = true
override def visitWorkspaceSymbol(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import java.nio.file.Path
import scala.collection.concurrent.TrieMap
import scala.util.control.NonFatal

import scala.meta.Dialect
import scala.meta.internal.mtags.GlobalSymbolIndex
import scala.meta.internal.pc.InterruptException
import scala.meta.io.AbsolutePath
Expand Down Expand Up @@ -39,14 +40,21 @@ final class WorkspaceSymbolProvider(
var inDependencies: ClasspathSearch =
ClasspathSearch.empty

def search(query: String): Seq[l.SymbolInformation] = {
search(query, () => ())
def search(
query: String,
preferredDialect: Option[Dialect],
): Seq[l.SymbolInformation] = {
search(query, () => (), preferredDialect)
}

def search(query: String, token: CancelChecker): Seq[l.SymbolInformation] = {
def search(
query: String,
token: CancelChecker,
preferredDialect: Option[Dialect],
): Seq[l.SymbolInformation] = {
if (query.isEmpty) return Nil
try {
searchUnsafe(query, token)
searchUnsafe(query, token, preferredDialect)
} catch {
case InterruptException() =>
Nil
Expand All @@ -57,6 +65,7 @@ final class WorkspaceSymbolProvider(
queryString: String,
path: AbsolutePath,
token: CancelToken,
preferredDialect: Option[Dialect],
): Seq[l.SymbolInformation] = {
val query = WorkspaceSymbolQuery.exact(queryString)
val visistor =
Expand All @@ -66,6 +75,7 @@ final class WorkspaceSymbolProvider(
token,
index,
saveClassFileToDisk,
preferredDialect,
)
val targetId = buildTargets.inverseSources(path)
search(query, visistor, targetId)
Expand Down Expand Up @@ -205,6 +215,7 @@ final class WorkspaceSymbolProvider(
private def searchUnsafe(
textQuery: String,
token: CancelChecker,
preferredDialect: Option[Dialect],
): Seq[l.SymbolInformation] = {
val query = WorkspaceSymbolQuery.fromTextQuery(textQuery)
val visitor =
Expand All @@ -214,6 +225,7 @@ final class WorkspaceSymbolProvider(
token,
index,
saveClassFileToDisk,
preferredDialect,
)
search(query, visitor, None)
visitor.allResults()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ abstract class BaseWorkspaceSymbolSuite extends BaseSuite {
expected: String,
)(implicit loc: Location): Unit = {
test(query) {
val result = symbols.search(query)
val result = symbols.search(query, None)
val obtained =
if (result.length > 100) s"${result.length} results"
else {
Expand Down
4 changes: 4 additions & 0 deletions tests/unit/src/main/scala/tests/debug/BaseStepDapSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ abstract class BaseStepDapSuite(
.at("a/src/main/scala/a/ScalaMain.scala", line = 5)(StepIn)
.at("a/src/main/java/a/JavaClass.java", line = 5)(StepOut)
.at("a/src/main/scala/a/ScalaMain.scala", line = 6)(Continue),
focusFile = "a/src/main/scala/a/ScalaMain.scala",
)

assertSteps("step-into-scala-lib", withoutVirtualDocs = true)(
Expand Down Expand Up @@ -162,12 +163,14 @@ abstract class BaseStepDapSuite(
steps
.at("a/src/main/scala/a/Main.scala", line = 6)(Continue)
.at("a/src/main/scala/a/Main.scala", line = 13)(Continue),
focusFile = "a/src/main/scala/a/Main.scala",
)

def assertSteps(name: TestOptions, withoutVirtualDocs: Boolean = false)(
sources: String,
main: String,
instrument: StepNavigator => StepNavigator,
focusFile: String = "a/src/main/scala/Main.scala",
)(implicit loc: Location): Unit = {
test(name, withoutVirtualDocs) {
cleanWorkspace()
Expand All @@ -176,6 +179,7 @@ abstract class BaseStepDapSuite(

for {
_ <- initialize(workspaceLayout)
_ <- server.didFocus(focusFile)
navigator = instrument(StepNavigator(workspace))
debugger <- debugMain("a", main, navigator)
_ <- debugger.initialize
Expand Down

0 comments on commit 6c67b4b

Please sign in to comment.