From ffcec8f52858151fa70e2c0582c487034bd854ce Mon Sep 17 00:00:00 2001 From: Hiroshi Horie <548776+hiroshihorie@users.noreply.github.com> Date: Tue, 9 Jan 2024 22:21:08 +0900 Subject: [PATCH] Support concurrent waiting for `AsyncCompleter` (#298) --- .../LiveKit/Core/DataChannelPairActor.swift | 2 +- Sources/LiveKit/Core/Engine.swift | 4 +- Sources/LiveKit/Core/SignalClient.swift | 4 +- Sources/LiveKit/Support/AsyncCompleter.swift | 144 ++++++++++-------- .../Track/Capturers/VideoCapturer.swift | 2 +- Tests/LiveKitTests/CompleterTests.swift | 55 ++++++- 6 files changed, 136 insertions(+), 75 deletions(-) diff --git a/Sources/LiveKit/Core/DataChannelPairActor.swift b/Sources/LiveKit/Core/DataChannelPairActor.swift index 4ef174fa0..c41cbf953 100644 --- a/Sources/LiveKit/Core/DataChannelPairActor.swift +++ b/Sources/LiveKit/Core/DataChannelPairActor.swift @@ -25,7 +25,7 @@ actor DataChannelPairActor: NSObject, Loggable { // MARK: - Public - public let openCompleter = AsyncCompleter(label: "Data channel open", timeOut: .defaultPublisherDataChannelOpen) + public let openCompleter = AsyncCompleter(label: "Data channel open", defaultTimeOut: .defaultPublisherDataChannelOpen) public var isOpen: Bool { guard let reliable = _reliableChannel, let lossy = _lossyChannel else { return false } diff --git a/Sources/LiveKit/Core/Engine.swift b/Sources/LiveKit/Core/Engine.swift index 38aa50a20..8e9754c30 100644 --- a/Sources/LiveKit/Core/Engine.swift +++ b/Sources/LiveKit/Core/Engine.swift @@ -40,8 +40,8 @@ class Engine: MulticastDelegate { var hasPublished: Bool = false } - let primaryTransportConnectedCompleter = AsyncCompleter(label: "Primary transport connect", timeOut: .defaultTransportState) - let publisherTransportConnectedCompleter = AsyncCompleter(label: "Publisher transport connect", timeOut: .defaultTransportState) + let primaryTransportConnectedCompleter = AsyncCompleter(label: "Primary transport connect", defaultTimeOut: .defaultTransportState) + let publisherTransportConnectedCompleter = AsyncCompleter(label: "Publisher transport connect", defaultTimeOut: .defaultTransportState) public var _state: StateSync diff --git a/Sources/LiveKit/Core/SignalClient.swift b/Sources/LiveKit/Core/SignalClient.swift index 08c0dc299..4c476fd7c 100644 --- a/Sources/LiveKit/Core/SignalClient.swift +++ b/Sources/LiveKit/Core/SignalClient.swift @@ -85,8 +85,8 @@ class SignalClient: MulticastDelegate { private var _messageLoopTask: Task? private var latestJoinResponse: Livekit_JoinResponse? - private let _connectResponseCompleter = AsyncCompleter(label: "Join response", timeOut: .defaultJoinResponse) - private let _addTrackCompleters = CompleterMapActor(label: "Completers for add track", timeOut: .defaultPublish) + private let _connectResponseCompleter = AsyncCompleter(label: "Join response", defaultTimeOut: .defaultJoinResponse) + private let _addTrackCompleters = CompleterMapActor(label: "Completers for add track", defaultTimeOut: .defaultPublish) private var _pingIntervalTimer: DispatchQueueTimer? private var _pingTimeoutTimer: DispatchQueueTimer? diff --git a/Sources/LiveKit/Support/AsyncCompleter.swift b/Sources/LiveKit/Support/AsyncCompleter.swift index 65e1155a6..51bf16d2b 100644 --- a/Sources/LiveKit/Support/AsyncCompleter.swift +++ b/Sources/LiveKit/Support/AsyncCompleter.swift @@ -24,12 +24,12 @@ actor CompleterMapActor { // MARK: - Private - private let _timeOut: DispatchTimeInterval + private let _defaultTimeOut: DispatchTimeInterval private var _completerMap = [String: AsyncCompleter]() - public init(label: String, timeOut: DispatchTimeInterval) { + public init(label: String, defaultTimeOut: DispatchTimeInterval) { self.label = label - _timeOut = timeOut + _defaultTimeOut = defaultTimeOut } public func completer(for key: String) -> AsyncCompleter { @@ -38,7 +38,7 @@ actor CompleterMapActor { return element } - let newCompleter = AsyncCompleter(label: label, timeOut: _timeOut) + let newCompleter = AsyncCompleter(label: label, defaultTimeOut: _defaultTimeOut) _completerMap[key] = newCompleter return newCompleter } @@ -60,110 +60,128 @@ actor CompleterMapActor { } class AsyncCompleter: Loggable { + // + struct WaitEntry { + let continuation: UnsafeContinuation + let timeOutBlock: DispatchWorkItem + + func cancel() { + continuation.resume(throwing: LiveKitError(.cancelled)) + timeOutBlock.cancel() + } + + func timeOut() { + continuation.resume(throwing: LiveKitError(.timedOut)) + timeOutBlock.cancel() + } + + func resume(with result: Result) { + continuation.resume(with: result) + timeOutBlock.cancel() + } + } + public let label: String - private let _timeOut: DispatchTimeInterval - private let _queue = DispatchQueue(label: "LiveKitSDK.AsyncCompleter", qos: .background) - // Internal states - private var _continuation: CheckedContinuation? - private var _timeOutBlock: DispatchWorkItem? + private let _defaultTimeOut: DispatchTimeInterval + private let _timerQueue = DispatchQueue(label: "LiveKitSDK.AsyncCompleter", qos: .background) - private var _returningValue: T? - private var _throwingError: Error? + // Internal states + private var _entries: [UUID: WaitEntry] = [:] + private var _result: Result? private let _lock = UnfairLock() - public init(label: String, timeOut: DispatchTimeInterval) { + public init(label: String, defaultTimeOut: DispatchTimeInterval) { self.label = label - _timeOut = timeOut + _defaultTimeOut = defaultTimeOut } deinit { reset() } - private func _cancelTimer() { - // Make sure time-out blocked doesn't fire - _timeOutBlock?.cancel() - _timeOutBlock = nil + public func reset() { + _lock.sync { + for entry in _entries.values { + entry.cancel() + } + _entries.removeAll() + _result = nil + } } - public func reset() { + public func resume(with result: Result) { _lock.sync { - _cancelTimer() - if let continuation = _continuation { - log("\(label) Cancelled") - continuation.resume(throwing: LiveKitError(.cancelled)) + for entry in _entries.values { + entry.resume(with: result) } - _continuation = nil - _returningValue = nil - _throwingError = nil + _entries.removeAll() + _result = result } } public func resume(returning value: T) { log("\(label)") - _lock.sync { - _cancelTimer() - _returningValue = value - _continuation?.resume(returning: value) - _continuation = nil - } + resume(with: .success(value)) } public func resume(throwing error: Error) { log("\(label)") - _lock.sync { - _cancelTimer() - _throwingError = error - _continuation?.resume(throwing: error) - _continuation = nil - } + resume(with: .failure(error)) } - public func wait() async throws -> T { - // resume(returning:) already called - if let returningValue = _lock.sync({ _returningValue }) { - log("\(label) returning value...") - return returningValue - } - - // resume(throwing:) already called - if let throwingError = _lock.sync({ _throwingError }) { - log("\(label) throwing error...") - throw throwingError + public func wait(timeOut: DispatchTimeInterval? = nil) async throws -> T { + // Read value + if let result = _lock.sync({ _result }) { + // Already resolved... + if case let .success(value) = result { + // resume(returning:) already called + log("\(label) returning value...") + return value + } else if case let .failure(error) = result { + // resume(throwing:) already called + log("\(label) throwing error...") + throw error + } } - log("\(label) waiting...") + // Create ids for continuation & timeOutBlock + let entryId = UUID() - // Cancel any previous waits - reset() + log("\(label) waiting with id: \(entryId)") // Create a cancel-aware timed continuation return try await withTaskCancellationHandler { - try await withCheckedThrowingContinuation { continuation in + try await withUnsafeThrowingContinuation { continuation in + // Create time-out block let timeOutBlock = DispatchWorkItem { [weak self] in guard let self else { return } - self.log("\(self.label) timedOut") + self.log("Wait \(entryId) timedOut") self._lock.sync { - self._continuation?.resume(throwing: LiveKitError(.timedOut, message: "\(self.label) AsyncCompleter timed out")) - self._continuation = nil + if let entry = self._entries[entryId] { + entry.timeOut() + } + self._entries.removeValue(forKey: entryId) } - self.reset() } + _lock.sync { // Schedule time-out block - _queue.asyncAfter(deadline: .now() + _timeOut, execute: timeOutBlock) - // Store reference to continuation - _continuation = continuation - // Store reference to time-out block - _timeOutBlock = timeOutBlock + _timerQueue.asyncAfter(deadline: .now() + (timeOut ?? _defaultTimeOut), execute: timeOutBlock) + // Store entry + _entries[entryId] = WaitEntry(continuation: continuation, timeOutBlock: timeOutBlock) } } } onCancel: { - // Cancel completer when Task gets cancelled - reset() + // Cancel only this completer when Task gets cancelled + _lock.sync { + if let entry = self._entries[entryId] { + entry.cancel() + } + self._entries.removeValue(forKey: entryId) + } } } } diff --git a/Sources/LiveKit/Track/Capturers/VideoCapturer.swift b/Sources/LiveKit/Track/Capturers/VideoCapturer.swift index af67b6f85..a9d30379c 100644 --- a/Sources/LiveKit/Track/Capturers/VideoCapturer.swift +++ b/Sources/LiveKit/Track/Capturers/VideoCapturer.swift @@ -65,7 +65,7 @@ public class VideoCapturer: NSObject, Loggable, VideoCapturerProtocol { weak var delegate: LKRTCVideoCapturerDelegate? - let dimensionsCompleter = AsyncCompleter(label: "Dimensions", timeOut: .defaultCaptureStart) + let dimensionsCompleter = AsyncCompleter(label: "Dimensions", defaultTimeOut: .defaultCaptureStart) struct State: Equatable { // Counts calls to start/stopCapturer so multiple Tracks can use the same VideoCapturer. diff --git a/Tests/LiveKitTests/CompleterTests.swift b/Tests/LiveKitTests/CompleterTests.swift index 07e4c1d25..051a0319e 100644 --- a/Tests/LiveKitTests/CompleterTests.swift +++ b/Tests/LiveKitTests/CompleterTests.swift @@ -23,7 +23,7 @@ class CompleterTests: XCTestCase { override func tearDown() async throws {} func testCompleterReuse() async throws { - let completer = AsyncCompleter(label: "Test01", timeOut: .seconds(1)) + let completer = AsyncCompleter(label: "Test01", defaultTimeOut: .seconds(1)) do { try await completer.wait() } catch let error as LiveKitError where error.type == .timedOut { @@ -38,7 +38,7 @@ class CompleterTests: XCTestCase { } func testCompleterCancel() async throws { - let completer = AsyncCompleter(label: "cancel-test", timeOut: .never) + let completer = AsyncCompleter(label: "cancel-test", defaultTimeOut: .never) do { // Run Tasks in parallel try await withThrowingTaskGroup(of: Void.self) { group in @@ -49,10 +49,10 @@ class CompleterTests: XCTestCase { } group.addTask { - print("Task 2: Started...") - // Cancel after 1 second - try await Task.sleep(until: .now + .seconds(1), clock: .continuous) - print("Task 2: Cancelling completer...") + print("Timer task: Started...") + // Cancel after 3 seconds + try await Task.sleep(until: .now + .seconds(3), clock: .continuous) + print("Timer task: Cancelling...") completer.reset() } @@ -66,4 +66,47 @@ class CompleterTests: XCTestCase { print("Unknown error: \(error)") } } + + func testCompleterConcurrentWait() async throws { + let completer = AsyncCompleter(label: "cancel-test", defaultTimeOut: .never) + do { + // Run Tasks in parallel + try await withThrowingTaskGroup(of: Void.self) { group in + + group.addTask { + print("Task 1: Waiting...") + try await completer.wait() + print("Task 1: Completed") + } + + group.addTask { + print("Task 2: Waiting...") + try await completer.wait() + print("Task 2: Completed") + } + + group.addTask { + print("Task 3: Waiting...") + try await completer.wait() + print("Task 3: Completed") + } + + group.addTask { + print("Timer task: Started...") + // Cancel after 3 seconds + try await Task.sleep(until: .now + .seconds(3), clock: .continuous) + print("Timer task: Completing...") + completer.resume(returning: ()) + } + + try await group.waitForAll() + } + } catch let error as LiveKitError where error.type == .timedOut { + print("Completer timed out") + } catch let error as LiveKitError where error.type == .cancelled { + print("Completer cancelled") + } catch { + print("Unknown error: \(error)") + } + } }