From eea9bb4d4fe83d59802c695b6f077351b34b6736 Mon Sep 17 00:00:00 2001 From: Daniel Urban Date: Mon, 28 Oct 2024 18:40:41 +0100 Subject: [PATCH] WIP: prepare for TRef having listeners --- .../choam/internal/mcas/MemoryLocation.scala | 10 ++++ .../scala/dev/tauri/choam/stm/TRefImpl.scala | 50 ++++++++++++++++++- 2 files changed, 59 insertions(+), 1 deletion(-) diff --git a/mcas/shared/src/main/scala/dev/tauri/choam/internal/mcas/MemoryLocation.scala b/mcas/shared/src/main/scala/dev/tauri/choam/internal/mcas/MemoryLocation.scala index 4ad62087..2f956ad8 100644 --- a/mcas/shared/src/main/scala/dev/tauri/choam/internal/mcas/MemoryLocation.scala +++ b/mcas/shared/src/main/scala/dev/tauri/choam/internal/mcas/MemoryLocation.scala @@ -103,6 +103,11 @@ trait MemoryLocation[A] extends Hamt.HasHash { final override def hash: Long = this.id + // listeners (for STM): + + private[choam] def withListeners: MemoryLocation.WithListeners[A] = + throw new UnsupportedOperationException + // private utilities: private[mcas] final def cast[B]: MemoryLocation[B] = @@ -111,6 +116,11 @@ trait MemoryLocation[A] extends Hamt.HasHash { object MemoryLocation extends MemoryLocationInstances0 { + private[choam] trait WithListeners[A] { + private[choam] def unsafeRegisterListener(listener: Null => Unit, lastSeenVersion: Long): Long + private[choam] def unsafeCancelListener(lid: Long): Unit + } + def unsafe[A](initial: A): MemoryLocation[A] = // TODO: remove this unsafeUnpadded[A](initial) diff --git a/stm/jvm/src/main/scala/dev/tauri/choam/stm/TRefImpl.scala b/stm/jvm/src/main/scala/dev/tauri/choam/stm/TRefImpl.scala index a169fe21..14801286 100644 --- a/stm/jvm/src/main/scala/dev/tauri/choam/stm/TRefImpl.scala +++ b/stm/jvm/src/main/scala/dev/tauri/choam/stm/TRefImpl.scala @@ -21,12 +21,14 @@ package stm import java.lang.ref.WeakReference import java.util.concurrent.atomic.{ AtomicReference, AtomicLong } +import scala.collection.immutable.LongMap + import internal.mcas.MemoryLocation private final class TRefImpl[F[_], A]( initial: A, final override val id: Long, -) extends MemoryLocation[A] with TRef.UnsealedTRef[F, A] { +) extends MemoryLocation[A] with MemoryLocation.WithListeners[A] with TRef.UnsealedTRef[F, A] { // TODO: use VarHandles @@ -39,6 +41,12 @@ private final class TRefImpl[F[_], A]( private[this] val marker = new AtomicReference[WeakReference[AnyRef]] + private[this] val listeners = + new AtomicReference[LongMap[Null => Unit]](LongMap.empty) + + private[this] val nextListenerId = + new AtomicLong(java.lang.Long.MIN_VALUE) + final override def unsafeGetV(): A = contents.get() @@ -90,4 +98,44 @@ private final class TRefImpl[F[_], A]( // identity) is fine for us. this.id.toInt } + + private[choam] final override def withListeners: this.type = + this + + private[choam] final override def unsafeRegisterListener(listener: Null => Unit, lastSeenVersion: Long): Long = { + val lid = nextListenerId.incrementAndGet() // could be opaque + assert(lid != java.lang.Long.MIN_VALUE) // detect overflow + + @tailrec + def go(ov: LongMap[Null => Unit]): Long = { + val nv = ov.updated(lid, listener) + val wit = listeners.compareAndExchange(ov, nv) + if (wit eq ov) { + lid + } else { + go(wit) + } + } + + go(listeners.get()) + + // TODO: double-check concurrent version change + // TODO: actually call listeners when needed + } + + private[choam] final override def unsafeCancelListener(lid: Long): Unit = { + + @tailrec + def go(ov: LongMap[Null => Unit]): Unit = { + val nv = ov.removed(lid) + if (nv ne ov) { + val wit = listeners.compareAndExchange(ov, nv) + if (wit ne ov) { + go(wit) + } // else: we're done + } // else: we're done + } + + go(listeners.get()) + } }