From 276a02c53a2e8b47972086985f837397b54daa89 Mon Sep 17 00:00:00 2001 From: Jakub Ciesluk <323892@uwr.edu.pl> Date: Wed, 13 Dec 2023 10:40:34 +0100 Subject: [PATCH] improvement: Add atomix update on TrieMap --- .../meta/internal/mtags/AtomicTrieMap.scala | 41 ++++++++++++++++ .../internal/mtags/SymbolIndexBucket.scala | 49 +++++++++++-------- 2 files changed, 69 insertions(+), 21 deletions(-) create mode 100644 mtags/src/main/scala/scala/meta/internal/mtags/AtomicTrieMap.scala diff --git a/mtags/src/main/scala/scala/meta/internal/mtags/AtomicTrieMap.scala b/mtags/src/main/scala/scala/meta/internal/mtags/AtomicTrieMap.scala new file mode 100644 index 00000000000..526e08c3603 --- /dev/null +++ b/mtags/src/main/scala/scala/meta/internal/mtags/AtomicTrieMap.scala @@ -0,0 +1,41 @@ +package scala.meta.internal.mtags + +import java.util.concurrent.ConcurrentHashMap + +import scala.collection.concurrent.TrieMap + +/** + * This class is a wrapper around TrieMap that provides atomic updateWith + */ +final class AtomicTrieMap[K, V] { + private val trieMap = new TrieMap[K, V]() + private val concurrentMap = new ConcurrentHashMap[K, V] + + def get(key: K): Option[V] = trieMap.get(key) + + def contains(key: K): Boolean = trieMap.contains(key) + + def updateWith(key: K)(remappingFunc: Option[V] => Option[V]): Unit = { + val computeFunction = new java.util.function.BiFunction[K, V, V] { + override def apply(k: K, v: V): V = { + trieMap.get(key) match { + case Some(value) => + remappingFunc(Some(value)) match { + case Some(newValue) => + trieMap.update(key, newValue) + case None => + trieMap.remove(key) + } + case None => + remappingFunc(None).foreach(trieMap.update(key, _)) + } + null.asInstanceOf[V] + } + } + concurrentMap.compute(key, computeFunction) + } +} + +object AtomicTrieMap { + def empty[K, V]: AtomicTrieMap[K, V] = new AtomicTrieMap[K, V] +} diff --git a/mtags/src/main/scala/scala/meta/internal/mtags/SymbolIndexBucket.scala b/mtags/src/main/scala/scala/meta/internal/mtags/SymbolIndexBucket.scala index b8299179d0b..244029810aa 100644 --- a/mtags/src/main/scala/scala/meta/internal/mtags/SymbolIndexBucket.scala +++ b/mtags/src/main/scala/scala/meta/internal/mtags/SymbolIndexBucket.scala @@ -5,7 +5,6 @@ import java.nio.CharBuffer import java.util.logging.Level import java.util.logging.Logger -import scala.collection.concurrent.TrieMap import scala.util.Properties import scala.util.control.NonFatal @@ -35,8 +34,8 @@ final case class SymbolLocation( * while definitions contains only symbols generated by ScalaMtags. */ class SymbolIndexBucket( - toplevels: TrieMap[String, Set[AbsolutePath]], - definitions: TrieMap[String, Set[SymbolLocation]], + toplevels: AtomicTrieMap[String, Set[AbsolutePath]], + definitions: AtomicTrieMap[String, Set[SymbolLocation]], sourceJars: OpenClassLoader, toIndexSource: AbsolutePath => AbsolutePath = identity, mtags: Mtags, @@ -84,8 +83,10 @@ class SymbolIndexBucket( ): Unit = { if (sourceJars.addEntry(jar.toNIO)) { symbols.foreach { case (sym, path) => - val acc = toplevels.getOrElse(sym, Set.empty) - toplevels(sym) = acc + path + toplevels.updateWith(sym) { + case Some(acc) => Some(acc + path) + case None => Some(Set(path)) + } } } } @@ -96,8 +97,10 @@ class SymbolIndexBucket( ): List[String] = { val symbols = indexSource(source, dialect, sourceDirectory) symbols.foreach { symbol => - val acc = toplevels.getOrElse(symbol, Set.empty) - toplevels(symbol) = acc + source + toplevels.updateWith(symbol) { + case Some(acc) => Some(acc + source) + case None => Some(Set(source)) + } } symbols } @@ -132,8 +135,10 @@ class SymbolIndexBucket( toplevel: String ): Unit = { if (source.isAmmoniteScript || !isTrivialToplevelSymbol(path, toplevel)) { - val acc = toplevels.getOrElse(toplevel, Set.empty) - toplevels(toplevel) = acc + source + toplevels.updateWith(toplevel) { + case Some(acc) => Some(acc + source) + case None => Some(Set(source)) + } } } @@ -220,20 +225,20 @@ class SymbolIndexBucket( .map(_.map(_.path)) .getOrElse(Set.empty)).filter(_.exists) - toplevels.get(symbol.value) match { - case None => () + toplevels.updateWith(symbol.value) { + case None => None case Some(acc) => val updated = acc.filter(exists(_)) - if (updated.isEmpty) toplevels.remove(symbol.value) - else toplevels(symbol.value) = updated + if (updated.isEmpty) None + else Some(updated) } - definitions.get(symbol.value) match { - case None => () + definitions.updateWith(symbol.value) { + case None => None case Some(acc) => val updated = acc.filter(loc => exists(loc.path)) - if (updated.isEmpty) definitions.remove(symbol.value) - else definitions(symbol.value) = updated + if (updated.isEmpty) None + else Some(updated) } } @@ -272,8 +277,10 @@ class SymbolIndexBucket( docs.documents.foreach { document => document.occurrences.foreach { occ => if (occ.symbol.isGlobal && occ.role.isDefinition) { - val acc = definitions.getOrElse(occ.symbol, Set.empty) - definitions.put(occ.symbol, acc + SymbolLocation(file, occ.range)) + definitions.updateWith(occ.symbol) { + case Some(acc) => Some(acc + SymbolLocation(file, occ.range)) + case None => Some(Set(SymbolLocation(file, occ.range))) + } } else { // do nothing, we only care about global symbol definitions. } @@ -339,8 +346,8 @@ object SymbolIndexBucket { toIndexSource: AbsolutePath => AbsolutePath ): SymbolIndexBucket = new SymbolIndexBucket( - TrieMap.empty, - TrieMap.empty, + AtomicTrieMap.empty, + AtomicTrieMap.empty, new OpenClassLoader, toIndexSource, mtags,