From 218ac11953d99563534ff6ed5f005fc85cb2dd24 Mon Sep 17 00:00:00 2001 From: Daniel Urban Date: Mon, 4 Nov 2024 21:51:36 +0100 Subject: [PATCH] WIP: make sure to actually call the notify method --- .../internal/mcas/SimpleMemoryLocation.scala | 2 +- .../choam/internal/mcas/SpinLockMcas.scala | 1 + .../internal/mcas/ThreadConfinedMCAS.scala | 1 + .../tauri/choam/internal/mcas/mcasSpec.scala | 37 +++++++++++++++++++ 4 files changed, 40 insertions(+), 1 deletion(-) diff --git a/mcas/jvm/src/main/scala/dev/tauri/choam/internal/mcas/SimpleMemoryLocation.scala b/mcas/jvm/src/main/scala/dev/tauri/choam/internal/mcas/SimpleMemoryLocation.scala index 9f69719c..5d9bea4f 100644 --- a/mcas/jvm/src/main/scala/dev/tauri/choam/internal/mcas/SimpleMemoryLocation.scala +++ b/mcas/jvm/src/main/scala/dev/tauri/choam/internal/mcas/SimpleMemoryLocation.scala @@ -22,7 +22,7 @@ package mcas import java.lang.ref.WeakReference import java.util.concurrent.atomic.{ AtomicReference, AtomicLong } -private final class SimpleMemoryLocation[A](initial: A)( +private class SimpleMemoryLocation[A](initial: A)( override val id: Long, ) extends AtomicReference[A](initial) with MemoryLocation[A] { diff --git a/mcas/jvm/src/main/scala/dev/tauri/choam/internal/mcas/SpinLockMcas.scala b/mcas/jvm/src/main/scala/dev/tauri/choam/internal/mcas/SpinLockMcas.scala index 43138f6a..4a9ddc61 100644 --- a/mcas/jvm/src/main/scala/dev/tauri/choam/internal/mcas/SpinLockMcas.scala +++ b/mcas/jvm/src/main/scala/dev/tauri/choam/internal/mcas/SpinLockMcas.scala @@ -171,6 +171,7 @@ private object SpinLockMcas extends Mcas.UnsealedMcas { self => val wit = head.address.unsafeCmpxchgVersionV(ov, newVersion) assert(wit == ov) head.address.unsafeSetV(head.nv) + head.address.unsafeNotifyListeners() commit(tail, newVersion) } } diff --git a/mcas/shared/src/main/scala/dev/tauri/choam/internal/mcas/ThreadConfinedMCAS.scala b/mcas/shared/src/main/scala/dev/tauri/choam/internal/mcas/ThreadConfinedMCAS.scala index 899512c4..4c17320e 100644 --- a/mcas/shared/src/main/scala/dev/tauri/choam/internal/mcas/ThreadConfinedMCAS.scala +++ b/mcas/shared/src/main/scala/dev/tauri/choam/internal/mcas/ThreadConfinedMCAS.scala @@ -77,6 +77,7 @@ private object ThreadConfinedMCAS extends ThreadConfinedMCASPlatform { self => val ov = wd.address.unsafeGetVersionV() val wit = wd.address.unsafeCmpxchgVersionV(ov, newVersion) assert(wit == ov) + wd.address.unsafeNotifyListeners() execute(it, newVersion) } } diff --git a/mcas/shared/src/test/scala/dev/tauri/choam/internal/mcas/mcasSpec.scala b/mcas/shared/src/test/scala/dev/tauri/choam/internal/mcas/mcasSpec.scala index f737d39c..212e7b01 100644 --- a/mcas/shared/src/test/scala/dev/tauri/choam/internal/mcas/mcasSpec.scala +++ b/mcas/shared/src/test/scala/dev/tauri/choam/internal/mcas/mcasSpec.scala @@ -19,6 +19,8 @@ package dev.tauri.choam package internal package mcas +import java.util.concurrent.atomic.AtomicInteger + final class McasSpecThreadConfinedMcas extends McasSpec with SpecThreadConfinedMcas @@ -579,4 +581,39 @@ abstract class McasSpec extends BaseSpec { this: McasImplSpec => } assertEquals(ctx.readVersion(ref), ver) } + + test("subscriber notification should be called on success") { + val ctx = this.mcasImpl.currentContext() + val ctr = new AtomicInteger(0) + val ref = new SimpleMemoryLocation[String]("A")(ctx.refIdGen.nextId()) { + private[choam] final override def unsafeNotifyListeners(): Unit = { + ctr.getAndIncrement() + () + } + } + val d0 = ctx.start() + val Some((ov, d1)) = ctx.readMaybeFromLog(ref, d0) : @unchecked + val d2 = d1.overwrite(d1.getOrElseNull(ref).withNv(("B"))) + assert(ctx.tryPerformOk(d2)) + assertEquals(ctr.get(), 1) + assertEquals(ctx.readDirect(ref), "B") + } + + test("subscriber notification should NOT be called on failure") { + val ctx = this.mcasImpl.currentContext() + val ctr = new AtomicInteger(0) + val ref = new SimpleMemoryLocation[String]("A")(ctx.refIdGen.nextId()) { + private[choam] final override def unsafeNotifyListeners(): Unit = { + ctr.getAndIncrement() + () + } + } + val d0 = ctx.start() + val Some((ov, d1)) = ctx.readMaybeFromLog(ref, d0) : @unchecked + val e = d1.getOrElseNull(ref) + val d2 = d1.overwrite(LogEntry(e.address, "X", "B", e.version)) + assert(!ctx.tryPerformOk(d2)) + assertEquals(ctr.get(), 0) + assertEquals(ctx.readDirect(ref), "A") + } }