Skip to content

Commit

Permalink
HAMT: try to optimize getting the hash by using <:
Browse files Browse the repository at this point in the history
  • Loading branch information
durban committed May 12, 2024
1 parent d23fb1b commit e3b8e88
Show file tree
Hide file tree
Showing 11 changed files with 144 additions and 151 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,6 @@ private[mcas] final class LogMap2[A] private (
final def definitelyReadOnly: Boolean =
this.isBlueSubtree

protected final override def hashOf(k: MemoryLocation[A]): Long =
k.id

protected final override def keyOf(a: LogEntry[A]): MemoryLocation[A] =
a.address

protected final override def isBlue(a: LogEntry[A]): Boolean =
a.readOnly

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,14 @@ package mcas

import scala.util.hashing.MurmurHash3

sealed abstract class WdLike[A] {
sealed abstract class WdLike[A] extends Hamt.HasKey[MemoryLocation[A]] {
val address: MemoryLocation[A]
val ov: A
val nv: A
val oldVersion: Long

final override def key: MemoryLocation[A] =
this.address
}

final class LogEntry[A] private ( // formerly called HWD
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,6 @@ private[mcas] final class LogMap2[A] private[mcas] (
final def definitelyReadOnly: Boolean =
this.isBlueSubtree

protected final override def hashOf(k: MemoryLocation[A]): Long =
k.id

protected final override def keyOf(a: LogEntry[A]): MemoryLocation[A] =
a.address

protected final override def isBlue(a: LogEntry[A]): Boolean =
a.readOnly

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,6 @@ private[mcas] final class LogMapMut[A] private (
this.forAll(ctx)
}

protected final override def keyOf(a: LogEntry[A]): MemoryLocation[A] =
a.address

protected final override def hashOf(k: MemoryLocation[A]): Long =
k.id

protected final override def isBlue(a: LogEntry[A]): Boolean =
a.readOnly

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,15 @@ package mcas

import scala.util.hashing.MurmurHash3

sealed trait WdLike[A] {
sealed trait WdLike[A] extends Hamt.HasKey[MemoryLocation[A]] {
val address: MemoryLocation[A]
def ov: A
def nv: A
val oldVersion: Long
def cleanForGc(wasSuccessful: Boolean, sentinel: A): Unit

final override def key: MemoryLocation[A] =
this.address
}

// TODO: this is duplicated on JS
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,7 @@ package mcas

import scala.util.hashing.MurmurHash3

private[mcas] abstract class AbstractHamt[K, V, E, T1, T2, H <: AbstractHamt[K, V, E, T1, T2, H]] protected[mcas] () { this: H =>

protected def keyOf(a: V): K

protected def hashOf(k: K): Long
private[mcas] abstract class AbstractHamt[K <: Hamt.HasHash, V <: Hamt.HasKey[K], E, T1, T2, H <: AbstractHamt[K, V, E, T1, T2, H]] protected[mcas] () { this: H =>

protected def newArray(size: Int): Array[E]

Expand Down Expand Up @@ -189,7 +185,7 @@ private[mcas] abstract class AbstractHamt[K, V, E, T1, T2, H <: AbstractHamt[K,
case node: AbstractHamt[_, _, _, _, _, _] =>
curr = node.hashCodeInternal(curr)
case a =>
curr = MurmurHash3.mix(curr, (hashOf(keyOf(a.asInstanceOf[V])) >>> 32).toInt)
curr = MurmurHash3.mix(curr, (a.asInstanceOf[V].key.hash >>> 32).toInt)
curr = MurmurHash3.mix(curr, a.##)
}
i += 1
Expand Down
34 changes: 21 additions & 13 deletions mcas/shared/src/main/scala/dev/tauri/choam/internal/mcas/Hamt.scala
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ import java.util.Arrays
* Public methods are the "external" API. We take care never to call them
* on a node in lower levels (they assume they're called on the root).
*/
private[mcas] abstract class Hamt[K, V, E, T1, T2, H <: Hamt[K, V, E, T1, T2, H]] protected[mcas] (
private[mcas] abstract class Hamt[K <: Hamt.HasHash, V <: Hamt.HasKey[K], E, T1, T2, H <: Hamt[K, V, E, T1, T2, H]] protected[mcas] (

private val sizeAndBlue: Int,

Expand Down Expand Up @@ -139,15 +139,15 @@ private[mcas] abstract class Hamt[K, V, E, T1, T2, H <: Hamt[K, V, E, T1, T2, H]

/** Must already contain the key of `a` */
final def updated(a: V): H = {
this.insertOrOverwrite(hashOf(keyOf(a)), a, 0, OP_UPDATE) match {
this.insertOrOverwrite(a.key.hash, a, 0, OP_UPDATE) match {
case null => this
case newRoot => newRoot
}
}

/** Mustn't already contain the key of `a` */
final def inserted(a: V): H = {
val newRoot = this.insertOrOverwrite(hashOf(keyOf(a)), a, 0, OP_INSERT)
val newRoot = this.insertOrOverwrite(a.key.hash, a, 0, OP_INSERT)
assert(newRoot ne null)
newRoot
}
Expand All @@ -158,21 +158,21 @@ private[mcas] abstract class Hamt[K, V, E, T1, T2, H <: Hamt[K, V, E, T1, T2, H]

/** May or may not already contain the key of `a` */
final def upserted(a: V): H = {
this.insertOrOverwrite(hashOf(keyOf(a)), a, 0, OP_UPSERT) match {
this.insertOrOverwrite(a.key.hash, a, 0, OP_UPSERT) match {
case null => this
case newRoot => newRoot
}
}

final def computeIfAbsent[T](k: K, tok: T, visitor: Hamt.EntryVisitor[K, V, T]): H = {
this.visit(k, hashOf(k), tok, visitor, modify = false, shift = 0) match {
this.visit(k, k.hash, tok, visitor, modify = false, shift = 0) match {
case null => this
case newRoot => newRoot
}
}

final def computeOrModify[T](k: K, tok: T, visitor: Hamt.EntryVisitor[K, V, T]): H = {
this.visit(k, hashOf(k), tok, visitor, modify = true, shift = 0) match {
this.visit(k, k.hash, tok, visitor, modify = true, shift = 0) match {
case null => this
case newRoot => newRoot
}
Expand Down Expand Up @@ -219,7 +219,7 @@ private[mcas] abstract class Hamt[K, V, E, T1, T2, H <: Hamt[K, V, E, T1, T2, H]
node.lookupOrNull(hash, shift + W).asInstanceOf[V]
case value =>
val a = value.asInstanceOf[V]
val hashA = hashOf(keyOf(a))
val hashA = a.key.hash
if (hash == hashA) {
a
} else {
Expand Down Expand Up @@ -253,7 +253,7 @@ private[mcas] abstract class Hamt[K, V, E, T1, T2, H <: Hamt[K, V, E, T1, T2, H]
case null =>
nullOf[H]
case newVal =>
assert(hashOf(keyOf(newVal)) == hash)
assert(newVal.key.hash == hash)
// TODO: this will compute physIdx again:
this.insertOrOverwrite(hash, newVal, shift, op = OP_INSERT)
}
Expand All @@ -272,14 +272,14 @@ private[mcas] abstract class Hamt[K, V, E, T1, T2, H <: Hamt[K, V, E, T1, T2, H]
}
case value =>
val a = value.asInstanceOf[V]
val hashA = hashOf(keyOf(a))
val hashA = a.key.hash
if (hash == hashA) {
val newEntry = visitor.entryPresent(k, a, tok)
if (modify) {
if (equ(newEntry, a)) {
nullOf[H]
} else {
assert(hashOf(keyOf(newEntry)) == hashA)
assert(newEntry.key.hash == hashA)
this.insertOrOverwrite(hashA, newEntry, shift, op = OP_UPDATE)
}
} else {
Expand All @@ -291,7 +291,7 @@ private[mcas] abstract class Hamt[K, V, E, T1, T2, H <: Hamt[K, V, E, T1, T2, H]
case null =>
nullOf[H]
case newVal =>
assert(hashOf(keyOf(newVal)) == hash)
assert(newVal.key.hash == hash)
// TODO: this will compute physIdx again:
this.insertOrOverwrite(hash, newVal, shift, op = OP_INSERT)
}
Expand All @@ -316,7 +316,7 @@ private[mcas] abstract class Hamt[K, V, E, T1, T2, H <: Hamt[K, V, E, T1, T2, H]
this.withNode(this.size + (newNode.size - node.size), bitmap, newNode, physIdx)
}
case ov =>
val oh = hashOf(keyOf(ov.asInstanceOf[V]))
val oh = ov.asInstanceOf[V].key.hash
if (hash == oh) {
if (op == OP_INSERT) {
throw new IllegalArgumentException
Expand Down Expand Up @@ -433,7 +433,15 @@ private[mcas] abstract class Hamt[K, V, E, T1, T2, H <: Hamt[K, V, E, T1, T2, H]

private[choam] object Hamt {

trait EntryVisitor[K, V, T] { // TODO: maybe move this to AbstractHamt?
trait HasKey[K <: HasHash] {
def key: K
}

trait HasHash {
def hash: Long
}

trait EntryVisitor[K, V, T] {
def entryPresent(k: K, v: V, tok: T): V
def entryAbsent(k: K, tok: T): V
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ import cats.kernel.Order
* directly. Instead, use MCAS, or an even higher
* level abstraction.
*/
trait MemoryLocation[A] {
trait MemoryLocation[A] extends Hamt.HasHash {

// contents:

Expand Down Expand Up @@ -100,6 +100,9 @@ trait MemoryLocation[A] {

def id: Long

final override def hash: Long =
this.id

// private utilities:

private[mcas] final def cast[B]: MemoryLocation[B] =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ package mcas
/**
* Mutable HAMT; not thread safe; `null` values are forbidden.
*/
private[mcas] abstract class MutHamt[K, V, E, T1, T2, I <: Hamt[_, _, _, _, _, I], H <: MutHamt[K, V, E, T1, T2, I, H]] protected[mcas] (
private[mcas] abstract class MutHamt[K <: Hamt.HasHash, V <: Hamt.HasKey[K], E, T1, T2, I <: Hamt[_, _, _, _, _, I], H <: MutHamt[K, V, E, T1, T2, I, H]] protected[mcas] (
// NB: the root doesn't have a logical idx, so we're abusing this field to store the tree size (and also the blue bit)
private var logIdx: Int,
private var contents: Array[AnyRef],
Expand Down Expand Up @@ -86,14 +86,14 @@ private[mcas] abstract class MutHamt[K, V, E, T1, T2, I <: Hamt[_, _, _, _, _, I
}

final def update(a: V): Unit = {
val sdb = this.insertOrOverwrite(hashOf(keyOf(a)), a, shift = 0, op = OP_UPDATE)
val sdb = this.insertOrOverwrite(a.key.hash, a, shift = 0, op = OP_UPDATE)
val sizeDiff = unpackSizeDiff(sdb)
assert(sizeDiff == 0)
this.isBlueTree &= unpackIsBlue(sdb)
}

final def insert(a: V): Unit = {
val sdb = this.insertOrOverwrite(hashOf(keyOf(a)), a, shift = 0, op = OP_INSERT)
val sdb = this.insertOrOverwrite(a.key.hash, a, shift = 0, op = OP_INSERT)
val sizeDiff = unpackSizeDiff(sdb)
assert(sizeDiff == 1)
this.addToSize(1)
Expand All @@ -105,23 +105,23 @@ private[mcas] abstract class MutHamt[K, V, E, T1, T2, I <: Hamt[_, _, _, _, _, I
}

final def upsert(a: V): Unit = {
val sdb = this.insertOrOverwrite(hashOf(keyOf(a)), a, shift = 0, op = OP_UPSERT)
val sdb = this.insertOrOverwrite(a.key.hash, a, shift = 0, op = OP_UPSERT)
val sizeDiff = unpackSizeDiff(sdb)
assert((sizeDiff == 0) || (sizeDiff == 1))
this.addToSize(sizeDiff)
this.isBlueTree &= unpackIsBlue(sdb)
}

final def computeIfAbsent[T](k: K, tok: T, visitor: Hamt.EntryVisitor[K, V, T]): Unit = {
val sdb = this.visit(k, hashOf(k), tok, visitor, newValue = nullOf[V], modify = false, shift = 0)
val sdb = this.visit(k, k.hash, tok, visitor, newValue = nullOf[V], modify = false, shift = 0)
val sizeDiff = unpackSizeDiff(sdb)
assert((sizeDiff == 0) || (sizeDiff == 1))
this.addToSize(sizeDiff)
this.isBlueTree &= unpackIsBlue(sdb)
}

final def computeOrModify[T](k: K, tok: T, visitor: Hamt.EntryVisitor[K, V, T]): Unit = {
val sdb = this.visit(k, hashOf(k), tok, visitor, newValue = nullOf[V], modify = true, shift = 0)
val sdb = this.visit(k, k.hash, tok, visitor, newValue = nullOf[V], modify = true, shift = 0)
val sizeDiff = unpackSizeDiff(sdb)
assert((sizeDiff == 0) || (sizeDiff == 1))
this.addToSize(sizeDiff)
Expand Down Expand Up @@ -166,7 +166,7 @@ private[mcas] abstract class MutHamt[K, V, E, T1, T2, I <: Hamt[_, _, _, _, _, I
node.lookupOrNull(hash, shift + W).asInstanceOf[V]
case value =>
val a = value.asInstanceOf[V]
val hashA = hashOf(keyOf(a))
val hashA = a.key.hash
if (hash == hashA) {
a
} else {
Expand Down Expand Up @@ -204,7 +204,7 @@ private[mcas] abstract class MutHamt[K, V, E, T1, T2, I <: Hamt[_, _, _, _, _, I
case null =>
packSizeDiffAndBlue(0, true)
case newVal =>
assert(hashOf(keyOf(newVal)) == hash)
assert(newVal.key.hash == hash)
// TODO: this will compute physIdx again:
this.insertOrOverwrite(hash, newVal, shift, op = OP_INSERT)
}
Expand All @@ -231,7 +231,7 @@ private[mcas] abstract class MutHamt[K, V, E, T1, T2, I <: Hamt[_, _, _, _, _, I
}
case value =>
val a = value.asInstanceOf[V]
val hashA = hashOf(keyOf(a))
val hashA = a.key.hash
if (hash == hashA) {
val newVal = if (isNull(newValue)) {
visitor.entryPresent(k, a, tok)
Expand All @@ -242,7 +242,7 @@ private[mcas] abstract class MutHamt[K, V, E, T1, T2, I <: Hamt[_, _, _, _, _, I
if (equ(newValue, a)) {
packSizeDiffAndBlue(0, isBlue(a))
} else {
assert(hashOf(keyOf(newVal)) == hashA)
assert(newVal.key.hash == hashA)
this.insertOrOverwrite(hash, newVal, shift, op = OP_UPDATE)
}
} else {
Expand All @@ -263,7 +263,7 @@ private[mcas] abstract class MutHamt[K, V, E, T1, T2, I <: Hamt[_, _, _, _, _, I
case null =>
packSizeDiffAndBlue(0, true)
case newVal =>
assert(hashOf(keyOf(newVal)) == hash)
assert(newVal.key.hash == hash)
this.insertOrOverwrite(hash, newVal, shift, op = OP_INSERT)
}
} else {
Expand Down Expand Up @@ -316,7 +316,7 @@ private[mcas] abstract class MutHamt[K, V, E, T1, T2, I <: Hamt[_, _, _, _, _, I
}
}
case ov =>
val oh = hashOf(keyOf(ov.asInstanceOf[V]))
val oh = ov.asInstanceOf[V].key.hash
if (hash == oh) {
if (op == OP_INSERT) {
throw new IllegalArgumentException
Expand Down Expand Up @@ -373,7 +373,7 @@ private[mcas] abstract class MutHamt[K, V, E, T1, T2, I <: Hamt[_, _, _, _, _, I
val newPhysIdx = physicalIdx(logIdx, newSize, shift = shift)
newContents(newPhysIdx) = node
case value =>
val logIdx = logicalIdx(hashOf(keyOf(value.asInstanceOf[V])), shift)
val logIdx = logicalIdx(value.asInstanceOf[V].key.hash, shift)
val newPhysIdx = physicalIdx(logIdx, newSize, shift = shift)
newContents(newPhysIdx) = value
}
Expand Down Expand Up @@ -412,7 +412,7 @@ private[mcas] abstract class MutHamt[K, V, E, T1, T2, I <: Hamt[_, _, _, _, _, I
arr(arity) = box(child)
arity += 1
case value =>
bitmap |= (1L << logicalIdx(hashOf(keyOf(value.asInstanceOf[V])), shift = shift))
bitmap |= (1L << logicalIdx(value.asInstanceOf[V].key.hash, shift = shift))
size += 1
isBlueSubtree &= isBlue(value.asInstanceOf[V])
arr(arity) = value
Expand Down
Loading

0 comments on commit e3b8e88

Please sign in to comment.