Skip to content

Commit

Permalink
MutHamt: make sure the last level has array sizes <= 16
Browse files Browse the repository at this point in the history
  • Loading branch information
durban committed May 11, 2024
1 parent e48f1de commit 9fdd3c2
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 41 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@ private[mcas] abstract class MutHamt[K, V, E, T1, T2, I <: Hamt[_, _, _, _, _, I
private var contents: Array[AnyRef],
) extends AbstractHamt[K, V, E, T1, T2, H] { this: H =>

// TODO: We're often recomputing things like physIdx
// TODO: or logIdxWidth; figure out if we can be
// TODO: faster by not repeating these computations
// TODO: (by inlining methods if necessary).

require(contents.length > 0)

private[this] final val START_MASK = 0xFC00000000000000L
Expand Down Expand Up @@ -175,7 +180,7 @@ private[mcas] abstract class MutHamt[K, V, E, T1, T2, I <: Hamt[_, _, _, _, _, I
val contents = this.contents
val logIdx = logicalIdx(hash, shift)
val size = contents.length // always a power of 2
val physIdx = physicalIdx(logIdx, size)
val physIdx = physicalIdx(logIdx, size, shift = shift)
contents(physIdx)
}

Expand Down Expand Up @@ -217,7 +222,7 @@ private[mcas] abstract class MutHamt[K, V, E, T1, T2, I <: Hamt[_, _, _, _, _, I
case null =>
packSizeDiffAndBlue(0, true)
case newVal =>
this.growLevel(newSize = necessarySize(logIdx, nodeLogIdx), shift = shift)
this.growLevel(newSize = necessarySize(logIdx, nodeLogIdx, shift = shift), shift = shift)
// now we can insert the new value:
this.visit(k, hash, tok, visitor, newValue = newVal, modify = modify, shift = shift)
}
Expand Down Expand Up @@ -268,7 +273,7 @@ private[mcas] abstract class MutHamt[K, V, E, T1, T2, I <: Hamt[_, _, _, _, _, I
case null =>
packSizeDiffAndBlue(0, true)
case newVal =>
this.growLevel(newSize = necessarySize(logIdx, logIdxA), shift = shift)
this.growLevel(newSize = necessarySize(logIdx, logIdxA, shift = shift), shift = shift)
this.visit(k, hash, tok, visitor, newValue = newVal, modify = modify, shift = shift)
}
} else {
Expand All @@ -284,7 +289,7 @@ private[mcas] abstract class MutHamt[K, V, E, T1, T2, I <: Hamt[_, _, _, _, _, I
val contents = this.contents
val logIdx = logicalIdx(hash, shift)
val size = contents.length // always a power of 2
val physIdx = physicalIdx(logIdx, size)
val physIdx = physicalIdx(logIdx, size, shift = shift)
contents(physIdx) match {
case null =>
if (op == OP_UPDATE) {
Expand All @@ -306,7 +311,7 @@ private[mcas] abstract class MutHamt[K, V, E, T1, T2, I <: Hamt[_, _, _, _, _, I
throw new IllegalArgumentException
} else {
// we need to grow this level:
this.growLevel(newSize = necessarySize(logIdx, nodeLogIdx), shift = shift)
this.growLevel(newSize = necessarySize(logIdx, nodeLogIdx, shift = shift), shift = shift)
// now we'll suceed for sure:
this.insertOrOverwrite(hash, value, shift, op)
}
Expand All @@ -327,12 +332,13 @@ private[mcas] abstract class MutHamt[K, V, E, T1, T2, I <: Hamt[_, _, _, _, _, I
if (logIdx == oLogIdx) {
// hash collision at this level,
// so we go down 1 level
val childShift = shift + W
val childNode = {
val cArr = new Array[AnyRef](2) // NB: 2 instead of 1
cArr(physicalIdx(logicalIdx(oh, shift + W), size = 2)) = ov
cArr(physicalIdx(logicalIdx(oh, childShift), size = 2, shift = childShift)) = ov
this.newNode(logIdx, cArr)
}
val r = childNode.insertOrOverwrite(hash, value, shift = shift + W, op = op)
val r = childNode.insertOrOverwrite(hash, value, shift = childShift, op = op)
contents(physIdx) = childNode
assert(unpackSizeDiff(r) == 1)
r
Expand All @@ -343,7 +349,7 @@ private[mcas] abstract class MutHamt[K, V, E, T1, T2, I <: Hamt[_, _, _, _, _, I
throw new IllegalArgumentException
} else {
// grow this level:
this.growLevel(newSize = necessarySize(logIdx, oLogIdx), shift = shift)
this.growLevel(newSize = necessarySize(logIdx, oLogIdx, shift = shift), shift = shift)
// now we'll suceed for sure:
this.insertOrOverwrite(hash, value, shift, op)
}
Expand All @@ -365,11 +371,11 @@ private[mcas] abstract class MutHamt[K, V, E, T1, T2, I <: Hamt[_, _, _, _, _, I
()
case node: MutHamt[_, _, _, _, _, _, _] =>
val logIdx = node.logIdx
val newPhysIdx = physicalIdx(logIdx, newSize)
val newPhysIdx = physicalIdx(logIdx, newSize, shift = shift)
newContents(newPhysIdx) = node
case value =>
val logIdx = logicalIdx(hashOf(keyOf(value.asInstanceOf[V])), shift)
val newPhysIdx = physicalIdx(logIdx, newSize)
val newPhysIdx = physicalIdx(logIdx, newSize, shift = shift)
newContents(newPhysIdx) = value
}
idx += 1
Expand Down Expand Up @@ -423,16 +429,16 @@ private[mcas] abstract class MutHamt[K, V, E, T1, T2, I <: Hamt[_, _, _, _, _, I
this.newImmutableNode(sizeAndBlue = packSizeAndBlue(size, isBlueSubtree), bitmap = bitmap, contents = arr)
}

private[this] final def necessarySize(logIdx1: Int, logIdx2: Int): Int = {
private[this] final def necessarySize(logIdx1: Int, logIdx2: Int, shift: Int): Int = {
assert(logIdx1 != logIdx2)
val diff = logIdx1 ^ logIdx2
val necessaryBits = java.lang.Integer.numberOfLeadingZeros(diff) - (32 - W) + 1
val necessaryBits = java.lang.Integer.numberOfLeadingZeros(diff) - (32 - logIdxWidth(shift)) + 1
assert(necessaryBits <= W)
1 << necessaryBits
}

private[mcas] final def necessarySize_public(logIdx1: Int, logIdx2: Int): Int = {
this.necessarySize(logIdx1, logIdx2)
private[mcas] final def necessarySize_public(logIdx1: Int, logIdx2: Int, shift: Int): Int = {
this.necessarySize(logIdx1, logIdx2, shift = shift)
}

/** Index into the imaginary 64-element sparse array */
Expand All @@ -442,17 +448,26 @@ private[mcas] abstract class MutHamt[K, V, E, T1, T2, I <: Hamt[_, _, _, _, _, I
((hash & mask) >>> sh).toInt
}

private[this] final def physicalIdx(logIdx: Int, size: Int): Int = {
private[mcas] final def logicalIdx_public(hash: Long, shift: Int): Int = {
logicalIdx(hash, shift)
}

@inline
private[this] final def logIdxWidth(shift: Int): Int = {
java.lang.Long.bitCount(START_MASK >>> shift)
}

private[this] final def physicalIdx(logIdx: Int, size: Int, shift: Int): Int = {
assert((logIdx >= 0) && (logIdx < 64))
// size is always a power of 2
val width = java.lang.Integer.numberOfTrailingZeros(size)
assert(width <= W)
val sh = W - width
val width = java.lang.Integer.numberOfTrailingZeros(size) // size is always a power of 2
val logIdxW = logIdxWidth(shift)
assert((width <= logIdxW) && (logIdxW <= W))
val sh = logIdxW - width
logIdx >>> sh
}

private[mcas] final def physicalIdx_public(logIdx: Int, size: Int): Int = {
physicalIdx(logIdx, size)
private[mcas] final def physicalIdx_public(logIdx: Int, size: Int, shift: Int): Int = {
physicalIdx(logIdx, size, shift = shift)
}

private[this] final def packSizeDiffAndBlue(sizeDiff: Int, isBlue: Boolean): Int = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import cats.syntax.all._

import munit.ScalaCheckSuite

import org.scalacheck.Gen
import org.scalacheck.Prop.forAll

final class MutHamtSpec extends ScalaCheckSuite with MUnitUtils with PropertyHelpers {
Expand All @@ -40,34 +41,77 @@ final class MutHamtSpec extends ScalaCheckSuite with MUnitUtils with PropertyHel
p.withMaxSize(p.maxSize * (if (isJvm()) 32 else 2))
}

// TODO: "HAMT logicalIdx"
property("HAMT logicalIdx") {
val h = LongMutHamt.newEmpty()

def testLogicalIdx(n: Long): Unit = {
assertEquals(h.logicalIdx_public(n, shift = 0), ( n >>> 58 ).toInt)
assertEquals(h.logicalIdx_public(n, shift = 6), ((n >>> 52) & 63L).toInt)
assertEquals(h.logicalIdx_public(n, shift = 12), ((n >>> 46) & 63L).toInt)
assertEquals(h.logicalIdx_public(n, shift = 18), ((n >>> 40) & 63L).toInt)
assertEquals(h.logicalIdx_public(n, shift = 24), ((n >>> 34) & 63L).toInt)
assertEquals(h.logicalIdx_public(n, shift = 30), ((n >>> 28) & 63L).toInt)
assertEquals(h.logicalIdx_public(n, shift = 36), ((n >>> 22) & 63L).toInt)
assertEquals(h.logicalIdx_public(n, shift = 42), ((n >>> 16) & 63L).toInt)
assertEquals(h.logicalIdx_public(n, shift = 48), ((n >>> 10) & 63L).toInt)
assertEquals(h.logicalIdx_public(n, shift = 54), ((n >>> 4) & 63L).toInt)
assertEquals(h.logicalIdx_public(n, shift = 60), ( n & 15L).toInt) // this is the tricky one
}

val prop1 = forAll(Gen.choose(Long.MinValue, Long.MaxValue)) { (n: Long) =>
testLogicalIdx(n)
}

val prop2 = forAll { (n: Long) =>
testLogicalIdx(n)
}

prop1 && prop2
}

test("necessarySize") {
val h = LongMutHamt.newEmpty()
assertEquals(h.necessarySize_public(32, 0), 2)
assertEquals(h.necessarySize_public(63, 31), 2)
assertEquals(h.necessarySize_public(16, 0), 4)
assertEquals(h.necessarySize_public(31, 15), 4)
assertEquals(h.necessarySize_public(0, 1), 64)
assertEquals(h.necessarySize_public(54, 55), 64)
assertEquals(h.necessarySize_public(32, 0, shift = 0), 2)
assertEquals(h.necessarySize_public(63, 31, shift = 0), 2)
assertEquals(h.necessarySize_public(16, 0, shift = 0), 4)
assertEquals(h.necessarySize_public(31, 15, shift = 0), 4)
assertEquals(h.necessarySize_public(0, 1, shift = 0), 64)
assertEquals(h.necessarySize_public(0, 1, shift = 6), 64)
assertEquals(h.necessarySize_public(0, 1, shift = 12), 64)
assertEquals(h.necessarySize_public(0, 1, shift = 54), 64)
assertEquals(h.necessarySize_public(0, 1, shift = 60), 16)
assertEquals(h.necessarySize_public(54, 55, shift = 0), 64)
}

test("physicalIdx") {
val h = LongMutHamt.newEmpty()
assertEquals(h.physicalIdx_public(0, size = 1), 0)
assertEquals(h.physicalIdx_public(63, size = 1), 0)
assertEquals(h.physicalIdx_public(0, size = 2), 0)
assertEquals(h.physicalIdx_public(31, size = 2), 0)
assertEquals(h.physicalIdx_public(32, size = 2), 1)
assertEquals(h.physicalIdx_public(63, size = 2), 1)
assertEquals(h.physicalIdx_public(0, size = 32), 0)
assertEquals(h.physicalIdx_public(1, size = 32), 0)
assertEquals(h.physicalIdx_public(2, size = 32), 1)
assertEquals(h.physicalIdx_public(3, size = 32), 1)
assertEquals(h.physicalIdx_public(62, size = 32), 31)
assertEquals(h.physicalIdx_public(63, size = 32), 31)
assertEquals(h.physicalIdx_public(0, size = 1, shift = 0), 0)
assertEquals(h.physicalIdx_public(63, size = 1, shift = 0), 0)
assertEquals(h.physicalIdx_public(0, size = 2, shift = 0), 0)
assertEquals(h.physicalIdx_public(31, size = 2, shift = 0), 0)
assertEquals(h.physicalIdx_public(32, size = 2, shift = 0), 1)
assertEquals(h.physicalIdx_public(63, size = 2, shift = 0), 1)
assertEquals(h.physicalIdx_public(0, size = 16, shift = 0), 0)
assertEquals(h.physicalIdx_public(1, size = 16, shift = 0), 0)
assertEquals(h.physicalIdx_public(4, size = 16, shift = 0), 1)
assertEquals(h.physicalIdx_public(0, size = 16, shift = 6), 0)
assertEquals(h.physicalIdx_public(1, size = 16, shift = 6), 0)
assertEquals(h.physicalIdx_public(4, size = 16, shift = 6), 1)
assertEquals(h.physicalIdx_public(0, size = 16, shift = 60), 0)
assertEquals(h.physicalIdx_public(1, size = 16, shift = 60), 1)
assertEquals(h.physicalIdx_public(4, size = 16, shift = 60), 4)
assertEquals(h.physicalIdx_public(16, size = 16, shift = 60), 16)
assertEquals(h.physicalIdx_public(0, size = 32, shift = 0), 0)
assertEquals(h.physicalIdx_public(1, size = 32, shift = 0), 0)
assertEquals(h.physicalIdx_public(2, size = 32, shift = 0), 1)
assertEquals(h.physicalIdx_public(3, size = 32, shift = 0), 1)
assertEquals(h.physicalIdx_public(62, size = 32, shift = 0), 31)
assertEquals(h.physicalIdx_public(63, size = 32, shift = 0), 31)
for (logIdx <- 0 to 63) {
assertEquals(h.physicalIdx_public(logIdx, size = 64), logIdx)
assertEquals(h.physicalIdx_public(logIdx, size = 64, shift = 0), logIdx)
}
for (logIdx <- 0 to 15) {
assertEquals(h.physicalIdx_public(logIdx, size = 16, shift = 60), logIdx)
}
}

Expand Down Expand Up @@ -161,6 +205,18 @@ final class MutHamtSpec extends ScalaCheckSuite with MUnitUtils with PropertyHel
assertEquals(mutable.toArray.toList, immutable2.toArray.toList)
}

test("HAMT examples (3)") {
val k0 = 0L
val k1 = 1L
val k2 = 2L
val mutable = LongMutHamt.newEmpty()
mutable.insert(Val(k0))
mutable.insert(Val(k1))
mutable.insert(Val(k2))
val immutable = HamtSpec.LongHamt.empty.inserted(Val(k0)).inserted(Val(k1)).inserted(Val(k2))
assertEquals(mutable.toArray.toList, immutable.toArray.toList)
}

property("HAMT lookup/upsert/toArray (default generator)") {
forAll { (seed: Long, _nums: Set[Long]) =>
testBasics(seed, _nums)
Expand Down

0 comments on commit 9fdd3c2

Please sign in to comment.