diff --git a/mcas/shared/src/main/scala/dev/tauri/choam/internal/mcas/MutHamt.scala b/mcas/shared/src/main/scala/dev/tauri/choam/internal/mcas/MutHamt.scala index bddefd04..016a61a5 100644 --- a/mcas/shared/src/main/scala/dev/tauri/choam/internal/mcas/MutHamt.scala +++ b/mcas/shared/src/main/scala/dev/tauri/choam/internal/mcas/MutHamt.scala @@ -28,6 +28,11 @@ private[mcas] abstract class MutHamt[K, V, E, T1, T2, I <: Hamt[_, _, _, _, _, I private var contents: Array[AnyRef], ) extends AbstractHamt[K, V, E, T1, T2, H] { this: H => + // TODO: We're often recomputing things like physIdx + // TODO: or logIdxWidth; figure out if we can be + // TODO: faster by not repeating these computations + // TODO: (by inlining methods if necessary). + require(contents.length > 0) private[this] final val START_MASK = 0xFC00000000000000L @@ -175,7 +180,7 @@ private[mcas] abstract class MutHamt[K, V, E, T1, T2, I <: Hamt[_, _, _, _, _, I val contents = this.contents val logIdx = logicalIdx(hash, shift) val size = contents.length // always a power of 2 - val physIdx = physicalIdx(logIdx, size) + val physIdx = physicalIdx(logIdx, size, shift = shift) contents(physIdx) } @@ -217,7 +222,7 @@ private[mcas] abstract class MutHamt[K, V, E, T1, T2, I <: Hamt[_, _, _, _, _, I case null => packSizeDiffAndBlue(0, true) case newVal => - this.growLevel(newSize = necessarySize(logIdx, nodeLogIdx), shift = shift) + this.growLevel(newSize = necessarySize(logIdx, nodeLogIdx, shift = shift), shift = shift) // now we can insert the new value: this.visit(k, hash, tok, visitor, newValue = newVal, modify = modify, shift = shift) } @@ -268,7 +273,7 @@ private[mcas] abstract class MutHamt[K, V, E, T1, T2, I <: Hamt[_, _, _, _, _, I case null => packSizeDiffAndBlue(0, true) case newVal => - this.growLevel(newSize = necessarySize(logIdx, logIdxA), shift = shift) + this.growLevel(newSize = necessarySize(logIdx, logIdxA, shift = shift), shift = shift) this.visit(k, hash, tok, visitor, newValue = newVal, modify = modify, shift = shift) } } else { @@ -284,7 +289,7 @@ private[mcas] abstract class MutHamt[K, V, E, T1, T2, I <: Hamt[_, _, _, _, _, I val contents = this.contents val logIdx = logicalIdx(hash, shift) val size = contents.length // always a power of 2 - val physIdx = physicalIdx(logIdx, size) + val physIdx = physicalIdx(logIdx, size, shift = shift) contents(physIdx) match { case null => if (op == OP_UPDATE) { @@ -306,7 +311,7 @@ private[mcas] abstract class MutHamt[K, V, E, T1, T2, I <: Hamt[_, _, _, _, _, I throw new IllegalArgumentException } else { // we need to grow this level: - this.growLevel(newSize = necessarySize(logIdx, nodeLogIdx), shift = shift) + this.growLevel(newSize = necessarySize(logIdx, nodeLogIdx, shift = shift), shift = shift) // now we'll suceed for sure: this.insertOrOverwrite(hash, value, shift, op) } @@ -327,12 +332,13 @@ private[mcas] abstract class MutHamt[K, V, E, T1, T2, I <: Hamt[_, _, _, _, _, I if (logIdx == oLogIdx) { // hash collision at this level, // so we go down 1 level + val childShift = shift + W val childNode = { val cArr = new Array[AnyRef](2) // NB: 2 instead of 1 - cArr(physicalIdx(logicalIdx(oh, shift + W), size = 2)) = ov + cArr(physicalIdx(logicalIdx(oh, childShift), size = 2, shift = childShift)) = ov this.newNode(logIdx, cArr) } - val r = childNode.insertOrOverwrite(hash, value, shift = shift + W, op = op) + val r = childNode.insertOrOverwrite(hash, value, shift = childShift, op = op) contents(physIdx) = childNode assert(unpackSizeDiff(r) == 1) r @@ -343,7 +349,7 @@ private[mcas] abstract class MutHamt[K, V, E, T1, T2, I <: Hamt[_, _, _, _, _, I throw new IllegalArgumentException } else { // grow this level: - this.growLevel(newSize = necessarySize(logIdx, oLogIdx), shift = shift) + this.growLevel(newSize = necessarySize(logIdx, oLogIdx, shift = shift), shift = shift) // now we'll suceed for sure: this.insertOrOverwrite(hash, value, shift, op) } @@ -365,11 +371,11 @@ private[mcas] abstract class MutHamt[K, V, E, T1, T2, I <: Hamt[_, _, _, _, _, I () case node: MutHamt[_, _, _, _, _, _, _] => val logIdx = node.logIdx - val newPhysIdx = physicalIdx(logIdx, newSize) + val newPhysIdx = physicalIdx(logIdx, newSize, shift = shift) newContents(newPhysIdx) = node case value => val logIdx = logicalIdx(hashOf(keyOf(value.asInstanceOf[V])), shift) - val newPhysIdx = physicalIdx(logIdx, newSize) + val newPhysIdx = physicalIdx(logIdx, newSize, shift = shift) newContents(newPhysIdx) = value } idx += 1 @@ -423,16 +429,16 @@ private[mcas] abstract class MutHamt[K, V, E, T1, T2, I <: Hamt[_, _, _, _, _, I this.newImmutableNode(sizeAndBlue = packSizeAndBlue(size, isBlueSubtree), bitmap = bitmap, contents = arr) } - private[this] final def necessarySize(logIdx1: Int, logIdx2: Int): Int = { + private[this] final def necessarySize(logIdx1: Int, logIdx2: Int, shift: Int): Int = { assert(logIdx1 != logIdx2) val diff = logIdx1 ^ logIdx2 - val necessaryBits = java.lang.Integer.numberOfLeadingZeros(diff) - (32 - W) + 1 + val necessaryBits = java.lang.Integer.numberOfLeadingZeros(diff) - (32 - logIdxWidth(shift)) + 1 assert(necessaryBits <= W) 1 << necessaryBits } - private[mcas] final def necessarySize_public(logIdx1: Int, logIdx2: Int): Int = { - this.necessarySize(logIdx1, logIdx2) + private[mcas] final def necessarySize_public(logIdx1: Int, logIdx2: Int, shift: Int): Int = { + this.necessarySize(logIdx1, logIdx2, shift = shift) } /** Index into the imaginary 64-element sparse array */ @@ -442,17 +448,26 @@ private[mcas] abstract class MutHamt[K, V, E, T1, T2, I <: Hamt[_, _, _, _, _, I ((hash & mask) >>> sh).toInt } - private[this] final def physicalIdx(logIdx: Int, size: Int): Int = { + private[mcas] final def logicalIdx_public(hash: Long, shift: Int): Int = { + logicalIdx(hash, shift) + } + + @inline + private[this] final def logIdxWidth(shift: Int): Int = { + java.lang.Long.bitCount(START_MASK >>> shift) + } + + private[this] final def physicalIdx(logIdx: Int, size: Int, shift: Int): Int = { assert((logIdx >= 0) && (logIdx < 64)) - // size is always a power of 2 - val width = java.lang.Integer.numberOfTrailingZeros(size) - assert(width <= W) - val sh = W - width + val width = java.lang.Integer.numberOfTrailingZeros(size) // size is always a power of 2 + val logIdxW = logIdxWidth(shift) + assert((width <= logIdxW) && (logIdxW <= W)) + val sh = logIdxW - width logIdx >>> sh } - private[mcas] final def physicalIdx_public(logIdx: Int, size: Int): Int = { - physicalIdx(logIdx, size) + private[mcas] final def physicalIdx_public(logIdx: Int, size: Int, shift: Int): Int = { + physicalIdx(logIdx, size, shift = shift) } private[this] final def packSizeDiffAndBlue(sizeDiff: Int, isBlue: Boolean): Int = { diff --git a/mcas/shared/src/test/scala/dev/tauri/choam/internal/mcas/MutHamtSpec.scala b/mcas/shared/src/test/scala/dev/tauri/choam/internal/mcas/MutHamtSpec.scala index c5d3f829..dd7fb68e 100644 --- a/mcas/shared/src/test/scala/dev/tauri/choam/internal/mcas/MutHamtSpec.scala +++ b/mcas/shared/src/test/scala/dev/tauri/choam/internal/mcas/MutHamtSpec.scala @@ -28,6 +28,7 @@ import cats.syntax.all._ import munit.ScalaCheckSuite +import org.scalacheck.Gen import org.scalacheck.Prop.forAll final class MutHamtSpec extends ScalaCheckSuite with MUnitUtils with PropertyHelpers { @@ -40,34 +41,77 @@ final class MutHamtSpec extends ScalaCheckSuite with MUnitUtils with PropertyHel p.withMaxSize(p.maxSize * (if (isJvm()) 32 else 2)) } - // TODO: "HAMT logicalIdx" + property("HAMT logicalIdx") { + val h = LongMutHamt.newEmpty() + + def testLogicalIdx(n: Long): Unit = { + assertEquals(h.logicalIdx_public(n, shift = 0), ( n >>> 58 ).toInt) + assertEquals(h.logicalIdx_public(n, shift = 6), ((n >>> 52) & 63L).toInt) + assertEquals(h.logicalIdx_public(n, shift = 12), ((n >>> 46) & 63L).toInt) + assertEquals(h.logicalIdx_public(n, shift = 18), ((n >>> 40) & 63L).toInt) + assertEquals(h.logicalIdx_public(n, shift = 24), ((n >>> 34) & 63L).toInt) + assertEquals(h.logicalIdx_public(n, shift = 30), ((n >>> 28) & 63L).toInt) + assertEquals(h.logicalIdx_public(n, shift = 36), ((n >>> 22) & 63L).toInt) + assertEquals(h.logicalIdx_public(n, shift = 42), ((n >>> 16) & 63L).toInt) + assertEquals(h.logicalIdx_public(n, shift = 48), ((n >>> 10) & 63L).toInt) + assertEquals(h.logicalIdx_public(n, shift = 54), ((n >>> 4) & 63L).toInt) + assertEquals(h.logicalIdx_public(n, shift = 60), ( n & 15L).toInt) // this is the tricky one + } + + val prop1 = forAll(Gen.choose(Long.MinValue, Long.MaxValue)) { (n: Long) => + testLogicalIdx(n) + } + + val prop2 = forAll { (n: Long) => + testLogicalIdx(n) + } + + prop1 && prop2 + } test("necessarySize") { val h = LongMutHamt.newEmpty() - assertEquals(h.necessarySize_public(32, 0), 2) - assertEquals(h.necessarySize_public(63, 31), 2) - assertEquals(h.necessarySize_public(16, 0), 4) - assertEquals(h.necessarySize_public(31, 15), 4) - assertEquals(h.necessarySize_public(0, 1), 64) - assertEquals(h.necessarySize_public(54, 55), 64) + assertEquals(h.necessarySize_public(32, 0, shift = 0), 2) + assertEquals(h.necessarySize_public(63, 31, shift = 0), 2) + assertEquals(h.necessarySize_public(16, 0, shift = 0), 4) + assertEquals(h.necessarySize_public(31, 15, shift = 0), 4) + assertEquals(h.necessarySize_public(0, 1, shift = 0), 64) + assertEquals(h.necessarySize_public(0, 1, shift = 6), 64) + assertEquals(h.necessarySize_public(0, 1, shift = 12), 64) + assertEquals(h.necessarySize_public(0, 1, shift = 54), 64) + assertEquals(h.necessarySize_public(0, 1, shift = 60), 16) + assertEquals(h.necessarySize_public(54, 55, shift = 0), 64) } test("physicalIdx") { val h = LongMutHamt.newEmpty() - assertEquals(h.physicalIdx_public(0, size = 1), 0) - assertEquals(h.physicalIdx_public(63, size = 1), 0) - assertEquals(h.physicalIdx_public(0, size = 2), 0) - assertEquals(h.physicalIdx_public(31, size = 2), 0) - assertEquals(h.physicalIdx_public(32, size = 2), 1) - assertEquals(h.physicalIdx_public(63, size = 2), 1) - assertEquals(h.physicalIdx_public(0, size = 32), 0) - assertEquals(h.physicalIdx_public(1, size = 32), 0) - assertEquals(h.physicalIdx_public(2, size = 32), 1) - assertEquals(h.physicalIdx_public(3, size = 32), 1) - assertEquals(h.physicalIdx_public(62, size = 32), 31) - assertEquals(h.physicalIdx_public(63, size = 32), 31) + assertEquals(h.physicalIdx_public(0, size = 1, shift = 0), 0) + assertEquals(h.physicalIdx_public(63, size = 1, shift = 0), 0) + assertEquals(h.physicalIdx_public(0, size = 2, shift = 0), 0) + assertEquals(h.physicalIdx_public(31, size = 2, shift = 0), 0) + assertEquals(h.physicalIdx_public(32, size = 2, shift = 0), 1) + assertEquals(h.physicalIdx_public(63, size = 2, shift = 0), 1) + assertEquals(h.physicalIdx_public(0, size = 16, shift = 0), 0) + assertEquals(h.physicalIdx_public(1, size = 16, shift = 0), 0) + assertEquals(h.physicalIdx_public(4, size = 16, shift = 0), 1) + assertEquals(h.physicalIdx_public(0, size = 16, shift = 6), 0) + assertEquals(h.physicalIdx_public(1, size = 16, shift = 6), 0) + assertEquals(h.physicalIdx_public(4, size = 16, shift = 6), 1) + assertEquals(h.physicalIdx_public(0, size = 16, shift = 60), 0) + assertEquals(h.physicalIdx_public(1, size = 16, shift = 60), 1) + assertEquals(h.physicalIdx_public(4, size = 16, shift = 60), 4) + assertEquals(h.physicalIdx_public(16, size = 16, shift = 60), 16) + assertEquals(h.physicalIdx_public(0, size = 32, shift = 0), 0) + assertEquals(h.physicalIdx_public(1, size = 32, shift = 0), 0) + assertEquals(h.physicalIdx_public(2, size = 32, shift = 0), 1) + assertEquals(h.physicalIdx_public(3, size = 32, shift = 0), 1) + assertEquals(h.physicalIdx_public(62, size = 32, shift = 0), 31) + assertEquals(h.physicalIdx_public(63, size = 32, shift = 0), 31) for (logIdx <- 0 to 63) { - assertEquals(h.physicalIdx_public(logIdx, size = 64), logIdx) + assertEquals(h.physicalIdx_public(logIdx, size = 64, shift = 0), logIdx) + } + for (logIdx <- 0 to 15) { + assertEquals(h.physicalIdx_public(logIdx, size = 16, shift = 60), logIdx) } } @@ -161,6 +205,18 @@ final class MutHamtSpec extends ScalaCheckSuite with MUnitUtils with PropertyHel assertEquals(mutable.toArray.toList, immutable2.toArray.toList) } + test("HAMT examples (3)") { + val k0 = 0L + val k1 = 1L + val k2 = 2L + val mutable = LongMutHamt.newEmpty() + mutable.insert(Val(k0)) + mutable.insert(Val(k1)) + mutable.insert(Val(k2)) + val immutable = HamtSpec.LongHamt.empty.inserted(Val(k0)).inserted(Val(k1)).inserted(Val(k2)) + assertEquals(mutable.toArray.toList, immutable.toArray.toList) + } + property("HAMT lookup/upsert/toArray (default generator)") { forAll { (seed: Long, _nums: Set[Long]) => testBasics(seed, _nums)