Skip to content

Commit e524ed3

Browse files
authored
Provider state machine (#25)
* [WIP] provider state machine
* undoing changes
* moving tests, removing unnecessary changes
* moving tests back
* more unnecessary changes
* adding log
* this is so dumb
* Okay this might be the way to go
* let's see this
* this is less dumb
* removing thread local
* adding store.commit()
* timer suite
* taskCompletionListener -> taskFailureListener
* adding assertion that version == readStore.version
* fixing more tests
* returning same store
1 parent 5fab410 commit e524ed3

13 files changed

+509
-91
lines changed

Diff for: sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala

+1
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,7 @@ abstract class StatePartitionReaderBase(
126126
stateStoreColFamilySchema.keyStateEncoderSpec.get,
127127
useMultipleValuesPerKey = useMultipleValuesPerKey,
128128
isInternal = isInternal)
129+
store.commit()
129130
}
130131
provider
131132
}

Diff for: sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala

+97-18
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ private[sql] class RocksDBStateStoreProvider
4343
with SupportsFineGrainedReplay {
4444
import RocksDBStateStoreProvider._
4545

46-
class RocksDBStateStore(lastVersion: Long) extends StateStore {
46+
class RocksDBStateStore(lastVersion: Long, val stamp: Long) extends StateStore {
4747
/** Trait and classes representing the internal state of the store */
4848
trait STATE
4949
case object UPDATING extends STATE
@@ -58,6 +58,10 @@ private[sql] class RocksDBStateStoreProvider
5858

5959
@volatile private var state: STATE = UPDATING
6060

61+
/** Returns the stamp this store instance was constructed with. */
override def getReadStamp: Long = stamp
64+
6165
/**
6266
* Validates the expected state, throws exception if state is not as expected.
6367
* Returns the current state
@@ -81,6 +85,7 @@ private[sql] class RocksDBStateStoreProvider
8185
private def validateAndTransitionState(transition: TRANSITION): Unit = {
8286
val newState = transition match {
8387
case UPDATE =>
88+
stateMachine.verifyStamp(stamp)
8489
state match {
8590
case UPDATING => UPDATING
8691
case COMMITTED => throw StateStoreErrors.stateStoreOperationOutOfOrder(
@@ -90,14 +95,18 @@ private[sql] class RocksDBStateStoreProvider
9095
}
9196
case ABORT =>
9297
state match {
93-
case UPDATING => ABORTED
98+
case UPDATING =>
99+
stateMachine.verifyStamp(stamp)
100+
ABORTED
94101
case COMMITTED => throw StateStoreErrors.stateStoreOperationOutOfOrder(
95102
"Cannot abort after committed")
96103
case ABORTED => ABORTED
97104
}
98105
case COMMIT =>
99106
state match {
100-
case UPDATING => COMMITTED
107+
case UPDATING =>
108+
stateMachine.verifyStamp(stamp)
109+
COMMITTED
101110
case COMMITTED => throw StateStoreErrors.stateStoreOperationOutOfOrder(
102111
"Cannot commit after committed")
103112
case ABORTED => throw StateStoreErrors.stateStoreOperationOutOfOrder(
@@ -118,10 +127,14 @@ private[sql] class RocksDBStateStoreProvider
118127
Option(TaskContext.get()).foreach(_.addTaskCompletionListener[Unit] {
119128
_ =>
120129
try {
121-
abort()
130+
if (state == UPDATING) {
131+
abort()
132+
}
122133
} catch {
123134
case NonFatal(e) =>
124135
logWarning("Failed to abort state store", e)
136+
} finally {
137+
stateMachine.releaseStore(stamp, throwEx = false)
125138
}
126139
})
127140

@@ -318,15 +331,18 @@ private[sql] class RocksDBStateStoreProvider
318331
}
319332

320333
var checkpointInfo: Option[StateStoreCheckpointInfo] = None
334+
private var storedMetrics: Option[RocksDBMetrics] = None
335+
321336
override def commit(): Long = synchronized {
322337
validateState(List(UPDATING))
323-
324338
try {
325339
verify(state == UPDATING, "Cannot commit after already committed or aborted")
326340
val (newVersion, newCheckpointInfo) = rocksDB.commit()
327341
checkpointInfo = Some(newCheckpointInfo)
342+
storedMetrics = rocksDB.metricsOpt
328343
validateAndTransitionState(COMMIT)
329-
state = COMMITTED
344+
stateMachine.releaseStore(stamp)
345+
330346
logInfo(log"Committed ${MDC(VERSION_NUM, newVersion)} " +
331347
log"for ${MDC(STATE_STORE_ID, id)}")
332348
newVersion
@@ -342,6 +358,7 @@ private[sql] class RocksDBStateStoreProvider
342358
log"for ${MDC(STATE_STORE_ID, id)}")
343359
rocksDB.rollback()
344360
validateAndTransitionState(ABORT)
361+
stateMachine.releaseStore(stamp)
345362
}
346363
}
347364

@@ -541,15 +558,26 @@ private[sql] class RocksDBStateStoreProvider
541558

542559
override def stateStoreId: StateStoreId = stateStoreId_
543560

561+
private lazy val stateMachine: RocksDBStateStoreProviderStateMachine =
562+
new RocksDBStateStoreProviderStateMachine(stateStoreId, RocksDBConf(storeConf))
563+
544564
override def getStore(version: Long, uniqueId: Option[String] = None): StateStore = {
545565
try {
546566
if (version < 0) {
547567
throw QueryExecutionErrors.unexpectedStateStoreVersion(version)
548568
}
549-
rocksDB.load(
550-
version,
551-
stateStoreCkptId = if (storeConf.enableStateStoreCheckpointIds) uniqueId else None)
552-
new RocksDBStateStore(version)
569+
val stamp = stateMachine.acquireStore()
570+
try {
571+
rocksDB.load(
572+
version,
573+
stateStoreCkptId = if (storeConf.enableStateStoreCheckpointIds) uniqueId else None,
574+
readOnly = false)
575+
new RocksDBStateStore(version, stamp)
576+
} catch {
577+
case e: Throwable =>
578+
stateMachine.releaseStore(stamp)
579+
throw e
580+
}
553581
}
554582
catch {
555583
case e: SparkException
@@ -564,16 +592,58 @@ private[sql] class RocksDBStateStoreProvider
564592
}
565593
}
566594

567-
override def getReadStore(version: Long, uniqueId: Option[String] = None): StateStore = {
595+
/**
 * Reloads RocksDB for `version` in writable mode and hands back the very same store
 * instance that was passed in as `readStore`, so the stamp it already holds stays in use.
 *
 * @param readStore store previously handed out by this provider; must be a
 *                  [[RocksDBStateStore]] whose version equals `version`
 * @param version   version to load; must be non-negative
 * @param uniqueId  checkpoint id, only forwarded to RocksDB when
 *                  `storeConf.enableStateStoreCheckpointIds` is set
 * @return `readStore` itself, typed as a [[StateStore]]
 */
override def getWriteStore(
    readStore: ReadStateStore,
    version: Long,
    uniqueId: Option[String] = None): StateStore = {
  try {
    if (version < 0) {
      throw QueryExecutionErrors.unexpectedStateStoreVersion(version)
    }
    assert(version == readStore.version,
      s"Requested version $version does not match read store version ${readStore.version}")
    try {
      rocksDB.load(
        version,
        stateStoreCkptId = if (storeConf.enableStateStoreCheckpointIds) uniqueId else None,
        readOnly = false)
      readStore match {
        case stateStore: RocksDBStateStore =>
          stateStore
        case other =>
          throw new IllegalArgumentException(
            s"Expected a RocksDBStateStore as the read store but got ${other.getClass.getName}")
      }
    } catch {
      case e: Throwable =>
        // Best-effort release: with throwEx = false a stale or already released stamp
        // is ignored instead of raising "Invalid stamp for release", which would
        // otherwise replace the original failure `e` that callers need to see.
        stateMachine.releaseStore(readStore.getReadStamp, throwEx = false)
        throw e
    }
  } catch {
    case e: SparkException
      if Option(e.getCondition).exists(_.contains("CANNOT_LOAD_STATE_STORE")) =>
      // Already the canonical load error; rethrow without wrapping it again.
      throw e
    case e: OutOfMemoryError =>
      throw QueryExecutionErrors.notEnoughMemoryToLoadStore(
        stateStoreId.toString,
        "ROCKSDB_STORE_PROVIDER",
        e)
    case e: Throwable => throw QueryExecutionErrors.cannotLoadStore(e)
  }
}
632+
633+
override def getReadStore(version: Long, uniqueId: Option[String] = None): StateStore = {
634+
try {
635+
val stamp = stateMachine.acquireStore()
636+
try {
637+
rocksDB.load(
638+
version,
639+
stateStoreCkptId = if (storeConf.enableStateStoreCheckpointIds) uniqueId else None,
640+
readOnly = true)
641+
new RocksDBStateStore(version, stamp)
642+
} catch {
643+
case e: Throwable =>
644+
stateMachine.releaseStore(stamp)
645+
throw e
646+
}
577647
}
578648
catch {
579649
case e: SparkException
@@ -590,6 +660,7 @@ private[sql] class RocksDBStateStoreProvider
590660

591661
override def doMaintenance(): Unit = {
592662
try {
663+
stateMachine.maintenanceStore()
593664
rocksDB.doMaintenance()
594665
} catch {
595666
// SPARK-46547 - Swallow non-fatal exception in maintenance task to avoid deadlock between
@@ -601,6 +672,7 @@ private[sql] class RocksDBStateStoreProvider
601672
}
602673

603674
/**
 * Shuts this provider down.
 *
 * The state machine is closed first and the RocksDB instance second, so if
 * `stateMachine.closeStore()` throws, `rocksDB.close()` is not reached.
 */
override def close(): Unit = {
  stateMachine.closeStore()
  rocksDB.close()
}
606678

@@ -657,8 +729,15 @@ private[sql] class RocksDBStateStoreProvider
657729
if (endVersion < snapshotVersion) {
658730
throw QueryExecutionErrors.unexpectedStateStoreVersion(endVersion)
659731
}
660-
rocksDB.loadFromSnapshot(snapshotVersion, endVersion)
661-
new RocksDBStateStore(endVersion)
732+
val stamp = stateMachine.acquireStore()
733+
try {
734+
rocksDB.loadFromSnapshot(snapshotVersion, endVersion)
735+
new RocksDBStateStore(endVersion, stamp)
736+
} catch {
737+
case e: Throwable =>
738+
stateMachine.releaseStore(stamp)
739+
throw e
740+
}
662741
}
663742
catch {
664743
case e: OutOfMemoryError =>
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.execution.streaming.state
19+
20+
import java.util.concurrent.TimeUnit
21+
import java.util.concurrent.atomic.AtomicLong
22+
import javax.annotation.concurrent.GuardedBy
23+
24+
import org.apache.spark.internal.Logging
25+
import org.apache.spark.sql.errors.QueryExecutionErrors
26+
27+
/**
 * Guards exclusive use of one RocksDB state store provider instance.
 *
 * The machine is in one of three states: RELEASED (free), ACQUIRED (handed out to one
 * holder) or CLOSED. Each successful [[acquireStore]] returns a fresh, monotonically
 * increasing stamp; only the holder of the current stamp may release the store.
 * A stamp value of -1 means "not locked".
 *
 * @param stateStoreId identifies the store in timeout error messages
 * @param rocksDBConf  supplies `lockAcquireTimeoutMs`, the maximum time a caller waits for
 *                     the store to be released
 */
class RocksDBStateStoreProviderStateMachine(
    stateStoreId: StateStoreId,
    rocksDBConf: RocksDBConf) extends Logging {

  private sealed trait STATE
  private case object RELEASED extends STATE
  private case object ACQUIRED extends STATE
  private case object CLOSED extends STATE

  private sealed abstract class TRANSITION(name: String) {
    override def toString: String = name
  }
  private case object LOAD extends TRANSITION("load")
  private case object RELEASE extends TRANSITION("release")
  private case object CLOSE extends TRANSITION("close")
  private case object MAINTENANCE extends TRANSITION("maintenance")

  private val instanceLock = new Object()
  @GuardedBy("instanceLock")
  private var state: STATE = RELEASED
  @GuardedBy("instanceLock")
  private var acquiredThreadInfo: AcquiredThreadInfo = _

  // Can be read without holding any locks, but should only be updated when
  // instanceLock is held.
  // -1 indicates that the store is not locked.
  private[sql] val currentValidStamp = new AtomicLong(-1L)
  @GuardedBy("instanceLock")
  private var lastValidStamp: Long = 0L

  /** Issues the next stamp and publishes it as the current one. Instance lock must be held. */
  private def incAndGetStamp: Long = {
    lastValidStamp += 1
    currentValidStamp.set(lastValidStamp)
    lastValidStamp
  }

  /**
   * Blocks while the store is ACQUIRED, for at most `rocksDBConf.lockAcquireTimeoutMs`.
   * Instance lock must be held; it is given up while waiting.
   *
   * Throws the error built by `QueryExecutionErrors.unreleasedThreadError` if the store
   * is still ACQUIRED once the timeout has elapsed.
   */
  private def awaitNotLocked(transition: TRANSITION): Unit = {
    val waitStartTime = System.nanoTime()
    def timeWaitedMs = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - waitStartTime)

    // Bounded waits keep the timeout check live even if no notification arrives.
    while (state == ACQUIRED && timeWaitedMs < rocksDBConf.lockAcquireTimeoutMs) {
      instanceLock.wait(10)
    }
    if (state == ACQUIRED) {
      val newAcquiredThreadInfo = AcquiredThreadInfo()
      // The owning thread is reached through a reference whose `get` may be empty
      // (presumably a weak reference - confirm against AcquiredThreadInfo). Fall back to
      // an empty trace rather than failing while building the real error.
      val stackTraceOutput = acquiredThreadInfo.threadRef.get
        .map(_.getStackTrace.mkString("\n"))
        .getOrElse("")
      val loggingId = s"StateStoreId(opId=${stateStoreId.operatorId}," +
        s"partId=${stateStoreId.partitionId},name=${stateStoreId.storeName})"
      throw QueryExecutionErrors.unreleasedThreadError(loggingId, transition.toString,
        newAcquiredThreadInfo.toString(), acquiredThreadInfo.toString(), timeWaitedMs,
        stackTraceOutput)
    }
  }

  /**
   * Returns oldState, newState.
   * Throws error if transition is illegal.
   * MUST be called for every StateStoreProvider method.
   * Caller MUST hold instance lock.
   */
  private def validateAndTransitionState(transition: TRANSITION): (STATE, STATE) = {
    val oldState = state
    val newState = transition match {
      case LOAD =>
        oldState match {
          case RELEASED => ACQUIRED
          case ACQUIRED => throw new IllegalStateException("Cannot lock when state is LOCKED")
          case CLOSED => throw new IllegalStateException("Cannot lock when state is CLOSED")
        }
      case RELEASE =>
        oldState match {
          case RELEASED => throw new IllegalStateException("Cannot unlock when state is UNLOCKED")
          case ACQUIRED => RELEASED
          case CLOSED => throw new IllegalStateException("Cannot unlock when state is CLOSED")
        }
      case CLOSE =>
        oldState match {
          case RELEASED => CLOSED
          case ACQUIRED => throw new IllegalStateException("Cannot close when state is LOCKED")
          case CLOSED => CLOSED
        }
      case MAINTENANCE =>
        // Maintenance never changes the state; it is only rejected once closed.
        oldState match {
          case RELEASED => RELEASED
          case ACQUIRED => ACQUIRED
          case CLOSED => throw new IllegalStateException("Cannot do maintenance when state is " +
            "CLOSED")
        }
    }
    state = newState
    if (newState == ACQUIRED) {
      // Record who holds the store so a later timeout can report it.
      acquiredThreadInfo = AcquiredThreadInfo()
    }
    (oldState, newState)
  }

  /**
   * Checks that `stamp` is the stamp currently holding the store.
   *
   * @throws IllegalStateException if `stamp` differs from the current valid stamp
   */
  def verifyStamp(stamp: Long): Unit = {
    // Read once so the value reported is the value that was compared.
    val current = currentValidStamp.get()
    if (stamp != current) {
      throw new IllegalStateException(s"Invalid stamp $stamp, " +
        s"currentStamp: $current")
    }
  }

  /**
   * Releases the store if `stamp` is the current valid stamp.
   *
   * @param stamp   stamp returned by the matching [[acquireStore]] call
   * @param throwEx when true an invalid stamp raises IllegalStateException; when false it
   *                is reported through the return value instead
   * @return whether the store was successfully released
   */
  def releaseStore(stamp: Long, throwEx: Boolean = true): Boolean = instanceLock.synchronized {
    // -1 is the "not locked" sentinel, never a stamp handed to a holder; without this
    // guard it would match the idle value and then fail in the RELEASE transition even
    // when the caller asked for no exception.
    if (stamp != -1L && currentValidStamp.compareAndSet(stamp, -1L)) {
      validateAndTransitionState(RELEASE)
      // Wake threads parked in awaitNotLocked instead of leaving them to the next poll.
      instanceLock.notifyAll()
      true
    } else if (throwEx) {
      throw new IllegalStateException("Invalid stamp for release")
    } else {
      false
    }
  }

  /**
   * Waits for the store to be free, marks it ACQUIRED and returns the new stamp.
   * Fails if the wait times out or the machine is CLOSED.
   */
  def acquireStore(): Long = instanceLock.synchronized {
    awaitNotLocked(LOAD)
    validateAndTransitionState(LOAD)
    incAndGetStamp
  }

  /** Verifies that maintenance is allowed, i.e. the machine is not CLOSED. */
  def maintenanceStore(): Unit = instanceLock.synchronized {
    validateAndTransitionState(MAINTENANCE)
  }

  /** Waits for the store to be free and then moves the machine to CLOSED. */
  def closeStore(): Unit = instanceLock.synchronized {
    awaitNotLocked(CLOSE)
    validateAndTransitionState(CLOSE)
  }
}

0 commit comments

Comments
 (0)