Skip to content

Commit

Permalink
Support concurrent waiting for AsyncCompleter (#298)
Browse files Browse the repository at this point in the history
  • Loading branch information
hiroshihorie authored Jan 9, 2024
1 parent da703ae commit ffcec8f
Show file tree
Hide file tree
Showing 6 changed files with 136 additions and 75 deletions.
2 changes: 1 addition & 1 deletion Sources/LiveKit/Core/DataChannelPairActor.swift
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ actor DataChannelPairActor: NSObject, Loggable {

// MARK: - Public

public let openCompleter = AsyncCompleter<Void>(label: "Data channel open", timeOut: .defaultPublisherDataChannelOpen)
public let openCompleter = AsyncCompleter<Void>(label: "Data channel open", defaultTimeOut: .defaultPublisherDataChannelOpen)

public var isOpen: Bool {
guard let reliable = _reliableChannel, let lossy = _lossyChannel else { return false }
Expand Down
4 changes: 2 additions & 2 deletions Sources/LiveKit/Core/Engine.swift
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ class Engine: MulticastDelegate<EngineDelegate> {
var hasPublished: Bool = false
}

let primaryTransportConnectedCompleter = AsyncCompleter<Void>(label: "Primary transport connect", timeOut: .defaultTransportState)
let publisherTransportConnectedCompleter = AsyncCompleter<Void>(label: "Publisher transport connect", timeOut: .defaultTransportState)
let primaryTransportConnectedCompleter = AsyncCompleter<Void>(label: "Primary transport connect", defaultTimeOut: .defaultTransportState)
let publisherTransportConnectedCompleter = AsyncCompleter<Void>(label: "Publisher transport connect", defaultTimeOut: .defaultTransportState)

public var _state: StateSync<State>

Expand Down
4 changes: 2 additions & 2 deletions Sources/LiveKit/Core/SignalClient.swift
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,8 @@ class SignalClient: MulticastDelegate<SignalClientDelegate> {
private var _messageLoopTask: Task<Void, Never>?
private var latestJoinResponse: Livekit_JoinResponse?

private let _connectResponseCompleter = AsyncCompleter<ConnectResponse>(label: "Join response", timeOut: .defaultJoinResponse)
private let _addTrackCompleters = CompleterMapActor<Livekit_TrackInfo>(label: "Completers for add track", timeOut: .defaultPublish)
private let _connectResponseCompleter = AsyncCompleter<ConnectResponse>(label: "Join response", defaultTimeOut: .defaultJoinResponse)
private let _addTrackCompleters = CompleterMapActor<Livekit_TrackInfo>(label: "Completers for add track", defaultTimeOut: .defaultPublish)

private var _pingIntervalTimer: DispatchQueueTimer?
private var _pingTimeoutTimer: DispatchQueueTimer?
Expand Down
144 changes: 81 additions & 63 deletions Sources/LiveKit/Support/AsyncCompleter.swift
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,12 @@ actor CompleterMapActor<T> {

// MARK: - Private

private let _timeOut: DispatchTimeInterval
private let _defaultTimeOut: DispatchTimeInterval
private var _completerMap = [String: AsyncCompleter<T>]()

public init(label: String, timeOut: DispatchTimeInterval) {
public init(label: String, defaultTimeOut: DispatchTimeInterval) {
self.label = label
_timeOut = timeOut
_defaultTimeOut = defaultTimeOut
}

public func completer(for key: String) -> AsyncCompleter<T> {
Expand All @@ -38,7 +38,7 @@ actor CompleterMapActor<T> {
return element
}

let newCompleter = AsyncCompleter<T>(label: label, timeOut: _timeOut)
let newCompleter = AsyncCompleter<T>(label: label, defaultTimeOut: _defaultTimeOut)
_completerMap[key] = newCompleter
return newCompleter
}
Expand All @@ -60,110 +60,128 @@ actor CompleterMapActor<T> {
}

class AsyncCompleter<T>: Loggable {
//
struct WaitEntry {
let continuation: UnsafeContinuation<T, any Error>
let timeOutBlock: DispatchWorkItem

func cancel() {
continuation.resume(throwing: LiveKitError(.cancelled))
timeOutBlock.cancel()
}

func timeOut() {
continuation.resume(throwing: LiveKitError(.timedOut))
timeOutBlock.cancel()
}

func resume(with result: Result<T, Error>) {
continuation.resume(with: result)
timeOutBlock.cancel()
}
}

public let label: String

private let _timeOut: DispatchTimeInterval
private let _queue = DispatchQueue(label: "LiveKitSDK.AsyncCompleter", qos: .background)
// Internal states
private var _continuation: CheckedContinuation<T, any Error>?
private var _timeOutBlock: DispatchWorkItem?
private let _defaultTimeOut: DispatchTimeInterval
private let _timerQueue = DispatchQueue(label: "LiveKitSDK.AsyncCompleter", qos: .background)

private var _returningValue: T?
private var _throwingError: Error?
// Internal states
private var _entries: [UUID: WaitEntry] = [:]
private var _result: Result<T, Error>?

private let _lock = UnfairLock()

public init(label: String, timeOut: DispatchTimeInterval) {
public init(label: String, defaultTimeOut: DispatchTimeInterval) {
self.label = label
_timeOut = timeOut
_defaultTimeOut = defaultTimeOut
}

deinit {
reset()
}

private func _cancelTimer() {
// Make sure time-out blocked doesn't fire
_timeOutBlock?.cancel()
_timeOutBlock = nil
public func reset() {
_lock.sync {
for entry in _entries.values {
entry.cancel()
}
_entries.removeAll()
_result = nil
}
}

public func reset() {
public func resume(with result: Result<T, Error>) {
_lock.sync {
_cancelTimer()
if let continuation = _continuation {
log("\(label) Cancelled")
continuation.resume(throwing: LiveKitError(.cancelled))
for entry in _entries.values {
entry.resume(with: result)
}
_continuation = nil
_returningValue = nil
_throwingError = nil
_entries.removeAll()
_result = result
}
}

public func resume(returning value: T) {
log("\(label)")
_lock.sync {
_cancelTimer()
_returningValue = value
_continuation?.resume(returning: value)
_continuation = nil
}
resume(with: .success(value))
}

public func resume(throwing error: Error) {
log("\(label)")
_lock.sync {
_cancelTimer()
_throwingError = error
_continuation?.resume(throwing: error)
_continuation = nil
}
resume(with: .failure(error))
}

public func wait() async throws -> T {
// resume(returning:) already called
if let returningValue = _lock.sync({ _returningValue }) {
log("\(label) returning value...")
return returningValue
}

// resume(throwing:) already called
if let throwingError = _lock.sync({ _throwingError }) {
log("\(label) throwing error...")
throw throwingError
public func wait(timeOut: DispatchTimeInterval? = nil) async throws -> T {
// Read value
if let result = _lock.sync({ _result }) {
// Already resolved...
if case let .success(value) = result {
// resume(returning:) already called
log("\(label) returning value...")
return value
} else if case let .failure(error) = result {
// resume(throwing:) already called
log("\(label) throwing error...")
throw error
}
}

log("\(label) waiting...")
// Create ids for continuation & timeOutBlock
let entryId = UUID()

// Cancel any previous waits
reset()
log("\(label) waiting with id: \(entryId)")

// Create a cancel-aware timed continuation
return try await withTaskCancellationHandler {
try await withCheckedThrowingContinuation { continuation in
try await withUnsafeThrowingContinuation { continuation in

// Create time-out block
let timeOutBlock = DispatchWorkItem { [weak self] in
guard let self else { return }
self.log("\(self.label) timedOut")
self.log("Wait \(entryId) timedOut")
self._lock.sync {
self._continuation?.resume(throwing: LiveKitError(.timedOut, message: "\(self.label) AsyncCompleter timed out"))
self._continuation = nil
if let entry = self._entries[entryId] {
entry.timeOut()
}
self._entries.removeValue(forKey: entryId)
}
self.reset()
}

_lock.sync {
// Schedule time-out block
_queue.asyncAfter(deadline: .now() + _timeOut, execute: timeOutBlock)
// Store reference to continuation
_continuation = continuation
// Store reference to time-out block
_timeOutBlock = timeOutBlock
_timerQueue.asyncAfter(deadline: .now() + (timeOut ?? _defaultTimeOut), execute: timeOutBlock)
// Store entry
_entries[entryId] = WaitEntry(continuation: continuation, timeOutBlock: timeOutBlock)
}
}
} onCancel: {
// Cancel completer when Task gets cancelled
reset()
// Cancel only this completer when Task gets cancelled
_lock.sync {
if let entry = self._entries[entryId] {
entry.cancel()
}
self._entries.removeValue(forKey: entryId)
}
}
}
}
2 changes: 1 addition & 1 deletion Sources/LiveKit/Track/Capturers/VideoCapturer.swift
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ public class VideoCapturer: NSObject, Loggable, VideoCapturerProtocol {

weak var delegate: LKRTCVideoCapturerDelegate?

let dimensionsCompleter = AsyncCompleter<Dimensions>(label: "Dimensions", timeOut: .defaultCaptureStart)
let dimensionsCompleter = AsyncCompleter<Dimensions>(label: "Dimensions", defaultTimeOut: .defaultCaptureStart)

struct State: Equatable {
// Counts calls to start/stopCapturer so multiple Tracks can use the same VideoCapturer.
Expand Down
55 changes: 49 additions & 6 deletions Tests/LiveKitTests/CompleterTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class CompleterTests: XCTestCase {
override func tearDown() async throws {}

func testCompleterReuse() async throws {
let completer = AsyncCompleter<Void>(label: "Test01", timeOut: .seconds(1))
let completer = AsyncCompleter<Void>(label: "Test01", defaultTimeOut: .seconds(1))
do {
try await completer.wait()
} catch let error as LiveKitError where error.type == .timedOut {
Expand All @@ -38,7 +38,7 @@ class CompleterTests: XCTestCase {
}

func testCompleterCancel() async throws {
let completer = AsyncCompleter<Void>(label: "cancel-test", timeOut: .never)
let completer = AsyncCompleter<Void>(label: "cancel-test", defaultTimeOut: .never)
do {
// Run Tasks in parallel
try await withThrowingTaskGroup(of: Void.self) { group in
Expand All @@ -49,10 +49,10 @@ class CompleterTests: XCTestCase {
}

group.addTask {
print("Task 2: Started...")
// Cancel after 1 second
try await Task.sleep(until: .now + .seconds(1), clock: .continuous)
print("Task 2: Cancelling completer...")
print("Timer task: Started...")
// Cancel after 3 seconds
try await Task.sleep(until: .now + .seconds(3), clock: .continuous)
print("Timer task: Cancelling...")
completer.reset()
}

Expand All @@ -66,4 +66,47 @@ class CompleterTests: XCTestCase {
print("Unknown error: \(error)")
}
}

func testCompleterConcurrentWait() async throws {
let completer = AsyncCompleter<Void>(label: "cancel-test", defaultTimeOut: .never)
do {
// Run Tasks in parallel
try await withThrowingTaskGroup(of: Void.self) { group in

group.addTask {
print("Task 1: Waiting...")
try await completer.wait()
print("Task 1: Completed")
}

group.addTask {
print("Task 2: Waiting...")
try await completer.wait()
print("Task 2: Completed")
}

group.addTask {
print("Task 3: Waiting...")
try await completer.wait()
print("Task 3: Completed")
}

group.addTask {
print("Timer task: Started...")
// Cancel after 3 seconds
try await Task.sleep(until: .now + .seconds(3), clock: .continuous)
print("Timer task: Completing...")
completer.resume(returning: ())
}

try await group.waitForAll()
}
} catch let error as LiveKitError where error.type == .timedOut {
print("Completer timed out")
} catch let error as LiveKitError where error.type == .cancelled {
print("Completer cancelled")
} catch {
print("Unknown error: \(error)")
}
}
}

0 comments on commit ffcec8f

Please sign in to comment.