From 29d61f1b84c556f11c80ee6590420a9bd9d029e1 Mon Sep 17 00:00:00 2001 From: Hiroshi Horie <548776+hiroshihorie@users.noreply.github.com> Date: Thu, 9 Nov 2023 20:57:20 +0800 Subject: [PATCH] Make `Room.connect` cancellable (#273) * engine connect * connect flow * cancellable completer * cancellable WebSocket * completer cancel test * comment * check cancel for queue actor --- Package.swift | 1 - .../Broadcast/BroadcastScreenCapturer.swift | 12 ++--- Sources/LiveKit/Core/Engine.swift | 15 +++++++ Sources/LiveKit/Core/SignalClient.swift | 14 +++--- Sources/LiveKit/Core/Transport.swift | 2 +- Sources/LiveKit/Errors.swift | 2 + Sources/LiveKit/Support/AsyncCompleter.swift | 42 ++++++++++------- Sources/LiveKit/Support/AsyncQueueActor.swift | 6 ++- Sources/LiveKit/Support/WebSocket.swift | 23 +++++----- Tests/LiveKitTests/CompleterTests.swift | 45 ++++++++++++++++++- Tests/LiveKitTests/TimerTests.swift | 9 ++-- Tests/LiveKitTests/WebSocketTests.swift | 11 +---- 12 files changed, 121 insertions(+), 61 deletions(-) diff --git a/Package.swift b/Package.swift index abb706dd3..e98293bdd 100644 --- a/Package.swift +++ b/Package.swift @@ -10,7 +10,6 @@ let package = Package( .macOS(.v10_15), ], products: [ - // Products define the executables and libraries a package produces, and make them visible to other packages. .library( name: "LiveKit", targets: ["LiveKit"] diff --git a/Sources/LiveKit/Broadcast/BroadcastScreenCapturer.swift b/Sources/LiveKit/Broadcast/BroadcastScreenCapturer.swift index f660071e2..09d75abe9 100644 --- a/Sources/LiveKit/Broadcast/BroadcastScreenCapturer.swift +++ b/Sources/LiveKit/Broadcast/BroadcastScreenCapturer.swift @@ -14,15 +14,15 @@ * limitations under the License. */ -import Foundation +#if os(iOS) -#if canImport(UIKit) - import UIKit -#endif + import Foundation -@_implementationOnly import WebRTC + #if canImport(UIKit) + import UIKit + #endif -#if os(iOS) + @_implementationOnly import WebRTC class BroadcastScreenCapturer: BufferCapturer { static let kRTCScreensharingSocketFD = "rtc_SSFD" diff --git a/Sources/LiveKit/Core/Engine.swift b/Sources/LiveKit/Core/Engine.swift index a59683b6b..3536f6668 100644 --- a/Sources/LiveKit/Core/Engine.swift +++ b/Sources/LiveKit/Core/Engine.swift @@ -145,6 +145,7 @@ class Engine: MulticastDelegate { } try await cleanUp() + try Task.checkCancellation() _state.mutate { $0.connectionState = .connecting } @@ -154,6 +155,9 @@ class Engine: MulticastDelegate { // Connect sequence successful log("Connect sequence completed") + // Final check if cancelled, don't fire connected events + try Task.checkCancellation() + // update internal vars (only if connect succeeded) _state.mutate { $0.url = url @@ -161,6 +165,9 @@ class Engine: MulticastDelegate { $0.connectionState = .connected } + } catch is CancellationError { + // Cancelled by .user + try await cleanUp(reason: .user) } catch { try await cleanUp(reason: .networkError(error)) } @@ -344,10 +351,18 @@ extension Engine { connectOptions: _state.connectOptions, reconnectMode: _state.reconnectMode, adaptiveStream: room._state.options.adaptiveStream) + // Check cancellation after WebSocket connected + try Task.checkCancellation() let jr = try await signalClient.joinResponseCompleter.wait() + // Check cancellation after received join response + try Task.checkCancellation() + _state.mutate { $0.connectStopwatch.split(label: "signal") } try await configureTransports(joinResponse: jr) + // Check cancellation after configuring transports + try Task.checkCancellation() + try await signalClient.resumeResponseQueue() try await primaryTransportConnectedCompleter.wait() _state.mutate { $0.connectStopwatch.split(label: "engine") } diff --git a/Sources/LiveKit/Core/SignalClient.swift b/Sources/LiveKit/Core/SignalClient.swift index e71749db2..a9cf1a7b7 100644 --- a/Sources/LiveKit/Core/SignalClient.swift +++ b/Sources/LiveKit/Core/SignalClient.swift @@ -103,10 +103,8 @@ class SignalClient: MulticastDelegate { $0.connectionState = .connecting } - let socket = WebSocket(url: url) - do { - try await socket.connect() + let socket = try await WebSocket(url: url) _webSocket = socket _state.mutate { $0.connectionState = .connected } @@ -156,10 +154,8 @@ class SignalClient: MulticastDelegate { pingIntervalTimer = nil pingTimeoutTimer = nil - if let socket = _webSocket { - socket.reset() - _webSocket = nil - } + _webSocket?.close() + _webSocket = nil latestJoinResponse = nil @@ -311,7 +307,7 @@ private extension SignalClient { extension SignalClient { func resumeResponseQueue() async throws { - await _responseQueue.resume { response in + try await _responseQueue.resume { response in await processSignalResponse(response) } } @@ -321,7 +317,7 @@ extension SignalClient { extension SignalClient { func sendQueuedRequests() async throws { - await _requestQueue.resume { element in + try await _requestQueue.resume { element in do { try await sendRequest(element, enqueueIfReconnecting: false) } catch { diff --git a/Sources/LiveKit/Core/Transport.swift b/Sources/LiveKit/Core/Transport.swift index 237d87370..9f442d4c6 100644 --- a/Sources/LiveKit/Core/Transport.swift +++ b/Sources/LiveKit/Core/Transport.swift @@ -112,7 +112,7 @@ class Transport: MulticastDelegate { func set(remoteDescription sd: LKRTCSessionDescription) async throws { try await _pc.setRemoteDescription(sd) - await _pendingCandidatesQueue.resume { candidate in + try await _pendingCandidatesQueue.resume { candidate in do { try await add(iceCandidate: candidate) } catch { diff --git a/Sources/LiveKit/Errors.swift b/Sources/LiveKit/Errors.swift index 20441b23f..2361911cf 100644 --- a/Sources/LiveKit/Errors.swift +++ b/Sources/LiveKit/Errors.swift @@ -96,6 +96,7 @@ public enum TrackError: LiveKitError { } public enum SignalClientError: LiveKitError { + case cancelled case state(message: String? = nil) case socketError(rawError: Error?) case close(message: String? = nil) @@ -105,6 +106,7 @@ public enum SignalClientError: LiveKitError { public var description: String { switch self { + case .cancelled: return buildDescription("cancelled") case let .state(message): return buildDescription("state", message) case let .socketError(rawError): return buildDescription("socketError", rawError: rawError) case let .close(message): return buildDescription("close", message) diff --git a/Sources/LiveKit/Support/AsyncCompleter.swift b/Sources/LiveKit/Support/AsyncCompleter.swift index 36d9da7cd..fc7453451 100644 --- a/Sources/LiveKit/Support/AsyncCompleter.swift +++ b/Sources/LiveKit/Support/AsyncCompleter.swift @@ -96,6 +96,9 @@ class AsyncCompleter: Loggable { public func cancel() { _cancelTimer() + if _continuation != nil { + log("\(label) cancelled") + } _continuation?.resume(throwing: AsyncCompleterError.cancelled) _continuation = nil _returningValue = nil @@ -140,24 +143,29 @@ class AsyncCompleter: Loggable { // Cancel any previous waits cancel() - // Create a timed continuation - return try await withCheckedThrowingContinuation { continuation in - // Store reference to continuation - _continuation = continuation - - // Create time-out block - let timeOutBlock = DispatchWorkItem { [weak self] in - guard let self else { return } - self.log("\(self.label) timedOut") - self._continuation?.resume(throwing: AsyncCompleterError.timedOut) - self._continuation = nil - self.cancel() + // Create a cancel-aware timed continuation + return try await withTaskCancellationHandler { + try await withCheckedThrowingContinuation { continuation in + // Store reference to continuation + _continuation = continuation + + // Create time-out block + let timeOutBlock = DispatchWorkItem { [weak self] in + guard let self else { return } + self.log("\(self.label) timedOut") + self._continuation?.resume(throwing: AsyncCompleterError.timedOut) + self._continuation = nil + self.cancel() + } + + // Schedule time-out block + _queue.asyncAfter(deadline: .now() + _timeOut, execute: timeOutBlock) + // Store reference to time-out block + _timeOutBlock = timeOutBlock } - - // Schedule time-out block - _queue.asyncAfter(deadline: .now() + _timeOut, execute: timeOutBlock) - // Store reference to time-out block - _timeOutBlock = timeOutBlock + } onCancel: { + // Cancel completer when Task gets cancelled + cancel() } } } diff --git a/Sources/LiveKit/Support/AsyncQueueActor.swift b/Sources/LiveKit/Support/AsyncQueueActor.swift index 14a930dc1..2d7b32566 100644 --- a/Sources/LiveKit/Support/AsyncQueueActor.swift +++ b/Sources/LiveKit/Support/AsyncQueueActor.swift @@ -49,11 +49,13 @@ actor AsyncQueueActor { } /// Mark as `.resumed` and process each element with an async `block`. - func resume(_ block: (T) async -> Void) async { + func resume(_ block: (T) async throws -> Void) async throws { state = .resumed if queue.isEmpty { return } for element in queue { - await block(element) + // Check cancellation before processing next block... + try Task.checkCancellation() + try await block(element) } queue.removeAll() } diff --git a/Sources/LiveKit/Support/WebSocket.swift b/Sources/LiveKit/Support/WebSocket.swift index df19d76cb..856dabb30 100644 --- a/Sources/LiveKit/Support/WebSocket.swift +++ b/Sources/LiveKit/Support/WebSocket.swift @@ -42,24 +42,27 @@ class WebSocket: NSObject, Loggable, AsyncSequence, URLSessionWebSocketDelegate waitForNextValue() } - init(url: URL) { + init(url: URL) async throws { request = URLRequest(url: url, cachePolicy: .useProtocolCachePolicy, timeoutInterval: .defaultSocketConnect) + super.init() + try await withTaskCancellationHandler { + try await withCheckedThrowingContinuation { continuation in + connectContinuation = continuation + task.resume() + } + } onCancel: { + // Cancel(reset) when Task gets cancelled + close() + } } deinit { - reset() - } - - public func connect() async throws { - try await withCheckedThrowingContinuation { continuation in - connectContinuation = continuation - task.resume() - } + close() } - func reset() { + func close() { task.cancel(with: .goingAway, reason: nil) connectContinuation?.resume(throwing: SignalClientError.socketError(rawError: nil)) connectContinuation = nil diff --git a/Tests/LiveKitTests/CompleterTests.swift b/Tests/LiveKitTests/CompleterTests.swift index 22d778875..5f7ee7cc6 100644 --- a/Tests/LiveKitTests/CompleterTests.swift +++ b/Tests/LiveKitTests/CompleterTests.swift @@ -22,5 +22,48 @@ class CompleterTests: XCTestCase { override func tearDown() async throws {} - func testCompleter() async throws {} + func testCompleterReuse() async throws { + let completer = AsyncCompleter(label: "Test01", timeOut: .seconds(1)) + do { + try await completer.wait() + } catch AsyncCompleterError.timedOut { + print("Timed out 1") + } + // Re-use + do { + try await completer.wait() + } catch AsyncCompleterError.timedOut { + print("Timed out 2") + } + } + + func testCompleterCancel() async throws { + let completer = AsyncCompleter(label: "cancel-test", timeOut: .never) + do { + // Run Tasks in parallel + try await withThrowingTaskGroup(of: Void.self) { group in + + group.addTask { + print("Task 1: Waiting...") + try await completer.wait() + } + + group.addTask { + print("Task 2: Started...") + // Cancel after 1 second + try await Task.sleep(until: .now + .seconds(1), clock: .continuous) + print("Task 2: Cancelling completer...") + completer.cancel() + } + + try await group.waitForAll() + } + } catch let error as AsyncCompleterError where error == .timedOut { + print("Completer timed out") + } catch let error as AsyncCompleterError where error == .cancelled { + print("Completer cancelled") + } catch { + print("Unknown error: \(error)") + } + } } diff --git a/Tests/LiveKitTests/TimerTests.swift b/Tests/LiveKitTests/TimerTests.swift index 5c057b9b4..9b036b168 100644 --- a/Tests/LiveKitTests/TimerTests.swift +++ b/Tests/LiveKitTests/TimerTests.swift @@ -15,7 +15,6 @@ */ @testable import LiveKit -import Promises import XCTest class TimerTests: XCTestCase { @@ -35,10 +34,10 @@ class TimerTests: XCTestCase { if self.counter == 3 { print("suspending timer for 3s...") self.timer.suspend() - Promise(()).delay(3).then { - print("restarting timer...") - self.timer.restart() - } +// Promise(()).delay(3).then { +// print("restarting timer...") +// self.timer.restart() +// } } if self.counter == 5 { diff --git a/Tests/LiveKitTests/WebSocketTests.swift b/Tests/LiveKitTests/WebSocketTests.swift index d8bbe7200..87795a6b9 100644 --- a/Tests/LiveKitTests/WebSocketTests.swift +++ b/Tests/LiveKitTests/WebSocketTests.swift @@ -18,20 +18,13 @@ import XCTest class WebSocketTests: XCTestCase { - lazy var socket: WebSocket = { - let url = URL(string: "wss://socketsbay.com/wss/v2/1/demo/")! - return WebSocket(url: url) - }() - override func setUpWithError() throws {} override func tearDown() async throws {} - func testCompleter1() async throws { - // Read messages - + func testWebSocket01() async throws { print("Connecting...") - try await socket.connect() + let socket = try await WebSocket(url: URL(string: "wss://socketsbay.com/wss/v2/1/demo/")!) print("Connected. Waiting for messages...") do {