Skip to content

Commit

Permalink
WIP: add some tests, fix some bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
durban committed Nov 6, 2024
1 parent f7c866a commit 4c1d005
Show file tree
Hide file tree
Showing 14 changed files with 74 additions and 22 deletions.
9 changes: 8 additions & 1 deletion core/shared/src/main/scala/dev/tauri/choam/core/Txn.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
package dev.tauri.choam
package core

import internal.mcas.Mcas

private[choam] trait Txn[F[_], +B] { // TODO: sealed

def map[C](f: B => C): Txn[F, C]
Expand All @@ -36,5 +38,10 @@ private[choam] object Txn {
Rxn.pure(a).castF[F]

final def retry[F[_], A]: Txn[F, A] =
Rxn.unsafe.retry[Any, A].castF[F]
Rxn.unsafe.retry[Any, A].castF[F] // TODO: retry when changed

private[choam] final object unsafe {
private[choam] final def delayContext[F[_], A](uf: Mcas.ThreadContext => A): Txn[F, A] =
Rxn.unsafe.delayContext(uf).castF[F]
}
}
11 changes: 11 additions & 0 deletions core/shared/src/test/scala/dev/tauri/choam/RefSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ import cats.kernel.{ Order, Hash }
import cats.effect.IO
import cats.effect.kernel.{ Ref => CatsRef }

import internal.mcas.MemoryLocation

final class RefSpec_Real_ThreadConfinedMcas_IO
extends BaseSpecIO
with SpecThreadConfinedMcas
Expand Down Expand Up @@ -264,4 +266,13 @@ trait RefLikeSpec[F[_]] extends BaseSpecAsyncF[F] { this: McasImplSpec =>
_ <- testCatsRef[Axn](r, initial = "a", run = this.rF)
} yield ()
}

test("Regular Ref shouldn't have .withListeners") {
for {
r <- newRef("a")
loc <- F.delay(r.asInstanceOf[MemoryLocation[String]])
e = Either.catchOnly[UnsupportedOperationException] { loc.withListeners }
_ <- assertF(e.isLeft)
} yield ()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ private object SpinLockMcas extends Mcas.UnsealedMcas { self =>
go(ref.unsafeGetVersionV())
}

protected[mcas] final override def readVersion[A](ref: MemoryLocation[A]): Long = {
protected[choam] final override def readVersion[A](ref: MemoryLocation[A]): Long = {
@tailrec
def go(ver1: Long): Long = {
val _ = readOne(ref)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -420,6 +420,7 @@ private[mcas] final class Emcas extends GlobalContext { global =>
// help:
if (forMCAS) {
if (helpMCASforMCAS(parent, ctx = ctx, seen = seen, instRo = instRo)) {
// Note: `forMCAS` is true here, so we can return a reserved version
EmcasStatus.CycleDetected
} else {
// retry:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ private final class EmcasThreadContext(
hwd
}

protected[mcas] final override def readVersion[A](ref: MemoryLocation[A]): Long =
protected[choam] final override def readVersion[A](ref: MemoryLocation[A]): Long =
impl.readVersion(ref, this)

final override def start(): MutDescriptor =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ object FlakyEMCAS extends Mcas.UnsealedMcas { self =>
final override def readIntoHwd[A](ref: MemoryLocation[A]): LogEntry[A] =
emcasCtx.readIntoHwd(ref)

protected[mcas] final override def readVersion[A](ref: MemoryLocation[A]): Long =
protected[choam] final override def readVersion[A](ref: MemoryLocation[A]): Long =
emcasCtx.readVersion(ref)

final override def tryPerformInternal(desc: AbstractDescriptor, optimism: Long): Long =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ object Mcas extends McasCompanionPlatform { self =>
* @return the current version of `ref`, as if
* read by `readIntoHwd(ref).version`.
*/
private[mcas] def readVersion[A](ref: MemoryLocation[A]): Long
private[choam] def readVersion[A](ref: MemoryLocation[A]): Long

def validateAndTryExtend(
desc: AbstractDescriptor,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,13 @@ trait MemoryLocation[A] extends Hamt.HasHash {
object MemoryLocation extends MemoryLocationInstances0 {

private[choam] trait WithListeners {
private[choam] def unsafeRegisterListener(listener: Null => Unit, lastSeenVersion: Long): Long

private[choam] def unsafeRegisterListener(
ctx: Mcas.ThreadContext,
listener: Null => Unit,
lastSeenVersion: Long,
): Long

private[choam] def unsafeCancelListener(lid: Long): Unit
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ object NullMcas extends Mcas.UnsealedMcas { self =>
final override def readIntoHwd[A](ref: MemoryLocation[A]): LogEntry[A] =
throw new UnsupportedOperationException

private[mcas] final override def readVersion[A](ref: MemoryLocation[A]): Long =
private[choam] final override def readVersion[A](ref: MemoryLocation[A]): Long =
throw new UnsupportedOperationException

final override def validateAndTryExtend(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ private object ThreadConfinedMCAS extends ThreadConfinedMCASPlatform { self =>
LogEntry(ref, ov = v, nv = v, version = ref.unsafeGetVersionV())
}

protected[mcas] final override def readVersion[A](ref: MemoryLocation[A]): Long =
protected[choam] final override def readVersion[A](ref: MemoryLocation[A]): Long =
ref.unsafeGetVersionV()

final override def tryPerformInternal(desc: AbstractDescriptor, optimism: Long): Long = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ final class LogMap2Spec extends ScalaCheckSuite { self =>
self.fail("not implemented")
def readIntoHwd[A](ref: MemoryLocation[A]): LogEntry[A] =
self.fail("not implemented")
private[mcas] def readVersion[A](ref: MemoryLocation[A]): Long =
private[choam] def readVersion[A](ref: MemoryLocation[A]): Long =
if (ref eq r1) r1Version else 42L // we simulate one of the refs changing version
def refIdGen: RefIdGen =
RefIdGen.global
Expand Down
14 changes: 8 additions & 6 deletions stm/js/src/main/scala/dev/tauri/choam/stm/TRefImpl.scala
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,7 @@ import java.lang.ref.WeakReference

import scala.collection.mutable.{ LongMap => MutLongMap }

import internal.mcas.MemoryLocation
import internal.mcas.Consts
import internal.mcas.{ Mcas, MemoryLocation, Consts }

private final class TRefImpl[F[_], A](
initial: A,
Expand Down Expand Up @@ -114,17 +113,20 @@ private final class TRefImpl[F[_], A](
private[choam] final override def withListeners: this.type =
this

private[choam] final override def unsafeRegisterListener(listener: Null => Unit, lastSeenVersion: Long): Long = {
private[choam] final override def unsafeRegisterListener(
ctx: Mcas.ThreadContext,
listener: Null => Unit,
lastSeenVersion: Long,
): Long = {
val lid = previousListenerId + 1L
previousListenerId = lid
assert(lid != Consts.InvalidListenerId) // detect overflow

listeners.put(lid, listener) : Unit
val currVer = this.unsafeGetVersionV()
val currVer = ctx.readVersion(this)
if (currVer != lastSeenVersion) {
listeners.remove(lid) : Unit
Consts.InvalidListenerId
} else {
listeners.put(lid, listener) : Unit
lid
}
}
Expand Down
20 changes: 15 additions & 5 deletions stm/jvm/src/main/scala/dev/tauri/choam/stm/TRefImpl.scala
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,7 @@ import java.util.concurrent.atomic.{ AtomicReference, AtomicLong }

import scala.collection.immutable.LongMap

import internal.mcas.MemoryLocation
import internal.mcas.Consts
import internal.mcas.{ Mcas, MemoryLocation, Consts }

private final class TRefImpl[F[_], A](
initial: A,
Expand Down Expand Up @@ -103,7 +102,11 @@ private final class TRefImpl[F[_], A](
private[choam] final override def withListeners: this.type =
this

private[choam] final override def unsafeRegisterListener(listener: Null => Unit, lastSeenVersion: Long): Long = {
private[choam] final override def unsafeRegisterListener(
ctx: Mcas.ThreadContext,
listener: Null => Unit,
lastSeenVersion: Long,
): Long = {
val lid = previousListenerId.incrementAndGet() // could be opaque
assert(lid != Consts.InvalidListenerId) // detect overflow

Expand All @@ -117,9 +120,16 @@ private final class TRefImpl[F[_], A](
}

go(listeners.get())
val currVer = this.unsafeGetVersionV()
val currVer = ctx.readVersion(this)
if (currVer != lastSeenVersion) {
// TODO: remove the listener we inserted (to not to leak memory)
// already changed since our caller last seen it
// (it is possible that the callback will be called
// anyway, since there is a race between double-
// checking the version and a possible notification;
// it is the responsibility of the caller to check
// the return value of this method, and ignore calls
// to the callback if we return `InvalidListenerId`)
unsafeCancelListener(lid)
Consts.InvalidListenerId
} else {
lid
Expand Down
19 changes: 17 additions & 2 deletions stm/shared/src/test/scala/dev/tauri/choam/stm/TxnSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -49,21 +49,36 @@ trait TxnSpec[F[_]] extends TxnBaseSpec[F] { this: McasImplSpec =>
test("TRef should have .withListeners") {
def incr(ref: TRef[F, Int]): F[Unit] =
ref.get.flatMap { ov => ref.set(ov + 1) }.commit
def getVersion(loc: MemoryLocation[Int]): F[Long] =
Txn.unsafe.delayContext { ctx => ctx.readIntoHwd(loc).version }.commit
def regListener(wl: MemoryLocation.WithListeners, cb: Null => Unit, lastSeenVersion: Long): F[Long] =
Txn.unsafe.delayContext { ctx => wl.unsafeRegisterListener(ctx, cb, lastSeenVersion) }.commit
def check(ref: TRef[F, Int]): F[Unit] = for {
loc <- F.delay(ref.asInstanceOf[MemoryLocation[Int]])
wl <- F.delay(loc.withListeners)
ctr <- F.delay(new AtomicInteger(0))
lid <- F.delay(wl.unsafeRegisterListener({ _ => ctr.getAndIncrement(); () }, loc.unsafeGetVersionV()))
firstVersion <- getVersion(loc)
// registered listener should be called:
lid <- regListener(wl, { _ => ctr.getAndIncrement(); () }, firstVersion)
_ <- assertNotEqualsF(lid, Consts.InvalidListenerId)
_ <- incr(ref)
_ <- F.delay(assertEquals(ctr.get(), 1))
// after it was called once, it shouldn't any more:
_ <- incr(ref)
_ <- F.delay(assertEquals(ctr.get(), 1))
lid <- F.delay(wl.unsafeRegisterListener({ _ => ctr.getAndIncrement(); () }, loc.unsafeGetVersionV()))
// registered, but then cancelled listener shouldn't be called:
otherVersion <- getVersion(loc)
_ <- assertF(firstVersion < otherVersion)
lid <- regListener(wl, { _ => ctr.getAndIncrement(); () }, otherVersion)
_ <- assertNotEqualsF(lid, Consts.InvalidListenerId)
_ <- F.delay(wl.unsafeCancelListener(lid))
_ <- incr(ref)
_ <- F.delay(assertEquals(ctr.get(), 1))
// failed registration due to outdated `lastSeenVersion`:
lid <- regListener(wl, { _ => ctr.getAndIncrement(); () }, firstVersion)
_ <- assertEqualsF(lid, Consts.InvalidListenerId)
_ <- incr(ref)
_ <- F.delay(assertEquals(ctr.get(), 1))
} yield ()

for {
Expand Down

0 comments on commit 4c1d005

Please sign in to comment.