Skip to content

Commit

Permalink
WIP: work towards a proper STM
Browse files Browse the repository at this point in the history
  • Loading branch information
durban committed Nov 9, 2024
1 parent 1fb2aab commit a2008b7
Show file tree
Hide file tree
Showing 6 changed files with 353 additions and 33 deletions.
3 changes: 2 additions & 1 deletion build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,8 @@ lazy val core = crossProject(JVMPlatform, JSPlatform)
// "eu.timepit" %%% "refined" % "0.11.1",
),
mimaBinaryIssueFilters ++= Seq(
ProblemFilters.exclude[MissingClassProblem]("dev.tauri.choam.core.Rxn$Suspend") // private
ProblemFilters.exclude[MissingClassProblem]("dev.tauri.choam.core.Rxn$Suspend"), // private
ProblemFilters.exclude[DirectMissingMethodProblem]("dev.tauri.choam.core.Rxn#InterpreterState.this"), // private
),
)

Expand Down
218 changes: 190 additions & 28 deletions core/shared/src/main/scala/dev/tauri/choam/core/Rxn.scala
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,7 @@ sealed abstract class Rxn[-A, +B] // short for 'reaction'
x = a,
mcas = mcas,
strategy = strategy,
isStm = false,
).interpretSync()
}

Expand All @@ -267,6 +268,7 @@ sealed abstract class Rxn[-A, +B] // short for 'reaction'
a,
mcas = mcas,
strategy = strategy,
isStm = false,
).interpretAsync(F)
}

Expand All @@ -287,6 +289,7 @@ sealed abstract class Rxn[-A, +B] // short for 'reaction'
a,
ctx.impl,
strategy = str,
isStm = false,
).interpretSyncWithContext(ctx)
}

Expand All @@ -298,18 +301,35 @@ sealed abstract class Rxn[-A, +B] // short for 'reaction'
this.flatMapF { b => f(b).impl }
}

final override def orElse[Y >: B](that: Txn[Rxn.Anything, Y]): Txn[Rxn.Anything, Y] = {
this + that.impl // TODO: orElse/+ semantics
}

private[choam] final def castF[F[_]]: Txn[F, B] =
this.asInstanceOf[Txn[F, B]]

private[core] final override def impl: Axn[B] =
this.asInstanceOf[Rxn[Any, B]] // Note: this is unsafe in general, we must take care to only use it on Txns

private[choam] final def performStm[F[_], X >: B](
a: A,
mcas: Mcas,
)(implicit F: Async[F]): F[X] = {
new InterpreterState[A, X](
this,
a,
mcas = mcas,
strategy = RetryStrategy.sleep(),
isStm = true,
).interpretAsync(F)
}

// /STM
}

object Rxn extends RxnInstances0 {

type Anything[_]
final abstract class Anything[A]

private[this] final val interruptCheckPeriod =
16384
Expand Down Expand Up @@ -486,6 +506,11 @@ object Rxn extends RxnInstances0 {
): Rxn[D, Unit] = new FinishExchange(hole, restOtherContK, lenSelfContT)
}

private[choam] final object StmImpl {
private[choam] final def retryWhenChanged[A]: Axn[A] =
new RetryWhenChanged[A]
}

// Representation:

/** Only the interpreter can use this! */
Expand All @@ -494,6 +519,7 @@ object Rxn extends RxnInstances0 {
final override def toString: String = "Commit()"
}

// TODO: this could be a singleton
private final class AlwaysRetry[A, B]() extends Rxn[A, B] {
private[core] final override def tag = 1
final override def toString: String = "AlwaysRetry()"
Expand All @@ -514,7 +540,11 @@ object Rxn extends RxnInstances0 {
final override def toString: String = "Computed(<function>)"
}

// tag = 5 (unused)
// TODO: this could be a singleton
private final class RetryWhenChanged[A]() extends Rxn[Any, A] { // STM
private[core] final override def tag = 5
final override def toString: String = "RetryWhenChanged()"
}

private final class Choice[A, B](val left: Rxn[A, B], val right: Rxn[A, B]) extends Rxn[A, B] {
private[core] final override def tag = 6
Expand Down Expand Up @@ -631,9 +661,110 @@ object Rxn extends RxnInstances0 {
}

/** Only the interpreter can use this! */
private final class SuspendUntil(val token: Long) extends Rxn[Any, Nothing] {
private sealed abstract class SuspendUntil extends Rxn[Any, Nothing] {

private[core] final override def tag = 27
final override def toString: String = s"SuspendUntil(${token.toHexString})"

def toF[F[_]](
mcasImpl: Mcas,
mcasCtx: Mcas.ThreadContext,
)(implicit F: Async[F]): F[Unit]
}

private final class SuspendUntilBackoff(val token: Long) extends SuspendUntil {

assert(!Backoff2.isPauseToken(token))

final override def toString: String =
s"SuspendUntilBackoff(${token.toHexString})"

final override def toF[F[_]](
mcasImpl: Mcas,
mcasCtx: Mcas.ThreadContext,
)(implicit F: Async[F]): F[Unit] =
Backoff2.tokenToF[F](token)
}

private final class SuspendUntilChanged(desc: AbstractDescriptor) extends SuspendUntil {

final override def toString: String =
s"SuspendUntilChanged(${desc})"

final override def toF[F[_]](
mcasImpl: Mcas,
mcasCtx: Mcas.ThreadContext,
)(implicit F: Async[F]): F[Unit] = {
if (desc ne null) {
F.asyncCheckAttempt[Unit] { cb =>
F.delay {
val rightUnit = Right(())
val cb2 = { (_: Null) =>
cb(rightUnit)
}
val (refs, cancelIds) = subscribe(mcasImpl, mcasCtx, cb2)
if (cancelIds eq null) {
// some ref already changed, we're done:
rightUnit
} else {
val cancelTsk = F.delay {
var idx = 0
while (idx < refs.length) {
refs(idx).unsafeCancelListener(cancelIds(idx))
idx += 1
}
}
Left(Some(cancelTsk))
// TODO: if one of the Refs wakes us, we still have to
// TODO: cancel all the other subscriptions (to not leak memory)
}
}
}
} else {
F.never
}
}

private[this] final def subscribe(
mcasImpl: Mcas,
mcasCtx: Mcas.ThreadContext,
cb: Null => Unit,
): (Array[MemoryLocation.WithListeners], Array[Long]) = {
val ctx = if (mcasImpl.isCurrentContext(mcasCtx)) {
mcasCtx
} else {
mcasImpl.currentContext()
}
val refs = new Array[MemoryLocation.WithListeners](desc.size)
val cancelIds = new Array[Long](desc.size)
val itr = desc.hwdIterator
var idx = 0
while (itr.hasNext) {
val hwd = itr.next()
val loc = hwd.address.withListeners
val cancelId = loc.unsafeRegisterListener(ctx, cb, hwd.oldVersion)
if (cancelId == Consts.InvalidListenerId) {
this.undoSubscribe(idx, refs, cancelIds)
return (null, null) // scalafix:ok
}
refs(idx) = loc
cancelIds(idx) = cancelId
idx += 1
}

(refs, cancelIds)
}

private[this] final def undoSubscribe(
count: Int,
refs: Array[MemoryLocation.WithListeners],
cancelIds: Array[Long],
): Unit = {
var idx = 0
while (idx < count) {
refs(idx).unsafeCancelListener(cancelIds(idx))
idx += 1
}
}
}

private final class TailRecM[X, A, B](val a: A, val f: A => Rxn[X, Either[A, B]]) extends Rxn[X, B] {
Expand Down Expand Up @@ -670,14 +801,18 @@ object Rxn extends RxnInstances0 {
x: X,
mcas: Mcas,
strategy: RetryStrategy,
isStm: Boolean,
) extends Hamt.EntryVisitor[MemoryLocation[Any], LogEntry[Any], Rxn[Any, Any]] {

private[this] val maxRetries: Int =
strategy.maxRetriesInt

private[this] val canSuspend: Boolean = {
val cs = strategy.canSuspend
assert((!cs) == strategy.isInstanceOf[RetryStrategy.Spin]) // just to be sure
assert( // just to be sure:
((!cs) == strategy.isInstanceOf[RetryStrategy.Spin]) &&
(cs || (!isStm))
)
cs
}

Expand Down Expand Up @@ -1033,9 +1168,12 @@ object Rxn extends RxnInstances0 {
}

private[this] final def retry(): Rxn[Any, Any] =
this.retry(this.canSuspend)
this.retry(canSuspend = this.canSuspend)

private[this] final def retry(canSuspend: Boolean): Rxn[Any, Any] =
this.retry(canSuspend = canSuspend, suspendUntilChanged = false)

private[this] final def retry(canSuspend: Boolean): Rxn[Any, Any] = {
private[this] final def retry(canSuspend: Boolean, suspendUntilChanged: Boolean): Rxn[Any, Any] = {
if (alts.nonEmpty) {
// we're not actually retrying,
// just going to the other side
Expand All @@ -1054,27 +1192,45 @@ object Rxn extends RxnInstances0 {
} else {
maybeCheckInterrupt(retriesNow)
}
// STM might still need these:
val d = if (this.isStm) this._desc else null
// TODO: we should also subscribe to refs we've read in *previous alts*
// restart everything:
clearDesc()
a = startA
resetConts()
pc.clear()
backoffAndNext(retriesNow, canSuspend)
backoffAndNext(
retriesNow,
canSuspend = canSuspend,
suspendUntilChanged = suspendUntilChanged,
desc = d,
)
}
}

private[this] final def backoffAndNext(retries: Int, canSuspend: Boolean): Rxn[Any, Any] = {
val token = Backoff2.backoffStrTok(
retries = retries,
strategy = this.strategy,
canSuspend = canSuspend,
)
if (Backoff2.spinIfPauseToken(token)) {
// ok, spinning done, restart:
this.startRxn
} else {
assert(canSuspend)
new SuspendUntil(token)
private[this] final def backoffAndNext(
retries: Int,
canSuspend: Boolean,
suspendUntilChanged: Boolean,
desc: AbstractDescriptor,
): Rxn[Any, Any] = {
if (!suspendUntilChanged) { // spin/cede/sleep
val token = Backoff2.backoffStrTok(
retries = retries,
strategy = this.strategy,
canSuspend = canSuspend,
)
if (Backoff2.spinIfPauseToken(token)) {
// ok, spinning done, restart:
this.startRxn
} else {
assert(canSuspend)
new SuspendUntilBackoff(token)
}
} else { // STM
assert(canSuspend && this.isStm)
new SuspendUntilChanged(desc)
}
}

Expand Down Expand Up @@ -1245,8 +1401,9 @@ object Rxn extends RxnInstances0 {
val nxt = c.f(a.asInstanceOf[A])
a = () : Any
loop(nxt)
case 5 => // (was DelayComputed)
impossible(s"Unknown tag 5 for ${curr}")
case 5 => // RetryWhenChanged (STM)
assert(this.canSuspend && this.isStm)
loop(retry(canSuspend = true, suspendUntilChanged = true))
case 6 => // Choice
val c = curr.asInstanceOf[Choice[A, B]]
saveAlt(c.right.asInstanceOf[Rxn[Any, R]])
Expand Down Expand Up @@ -1451,11 +1608,10 @@ object Rxn extends RxnInstances0 {
contK.push(a)
contK.push(c.f)
loop(c.rxn)
case 27 => // Suspend
case 27 => // SuspendUntil
assert(this.canSuspend)
assert(!Backoff2.isPauseToken(curr.asInstanceOf[SuspendUntil].token))
// user code can't access a `Suspend`, so
// we can abuse `R` and return `Suspend`:
// user code can't access a `SuspendUntil`, so
// we can abuse `R` and return `SuspendUntil`:
curr.asInstanceOf[R]
case 28 => // TailRecM
val c = curr.asInstanceOf[TailRecM[Any, Any, Any]]
Expand All @@ -1474,12 +1630,18 @@ object Rxn extends RxnInstances0 {
if (this.canSuspend) {
// cede or sleep strategy:
def step(poll: F ~> F): F[R] = F.defer {
this.ctx = mcas.currentContext()
// TODO: We could try passing ctx between
// TODO: steps, as it's likely that we remain
// TODO: on the same thread. But we have to
// TODO: always check.
val ctx = mcas.currentContext()
this.ctx = ctx
try {
loop(startRxn) match {
case s: SuspendUntil =>
assert(this._entryHolder eq null)
F.flatMap(poll(Backoff2.tokenToF[F](s.token))) { _ => step(poll) }
val sus: F[Unit] = s.toF[F](mcas, ctx)
F.flatMap(poll(sus)) { _ => step(poll) }
case r =>
assert(this._entryHolder eq null)
F.pure(r)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@ package core

import cats.effect.kernel.Async

private[choam] sealed trait Transactive[F[_]] extends Reactive[F] {
private[choam] sealed trait Transactive[F[_]]
extends Reactive[F] { // TODO: we probably shouldn't extend Reactive

def commit[B](txn: Txn[F, B]): F[B]
}

Expand All @@ -29,7 +31,7 @@ private[choam] object Transactive {
final def forAsync[F[_]](implicit F: Async[F]): Transactive[F] = {
new Reactive.SyncReactive[F](Rxn.DefaultMcas) with Transactive[F] {
final override def commit[B](txn: Txn[F, B]): F[B] = {
txn.impl.perform[F, B](null, this.mcasImpl, RetryStrategy.sleep())
txn.impl.performStm[F, B](null, this.mcasImpl)
}
}
}
Expand Down
Loading

0 comments on commit a2008b7

Please sign in to comment.