diff --git a/Modules/Package.swift b/Modules/Package.swift index 4b5494e85591..e96cdcbaea08 100644 --- a/Modules/Package.swift +++ b/Modules/Package.swift @@ -13,6 +13,7 @@ let package = Package( .library(name: "WordPressFlux", targets: ["WordPressFlux"]), .library(name: "WordPressShared", targets: ["WordPressShared"]), .library(name: "WordPressUI", targets: ["WordPressUI"]), + .library(name: "AsyncCombine", targets: ["AsyncCombine"]), ], dependencies: [ .package(url: "https://github.com/airbnb/lottie-ios", from: "4.4.0"), @@ -49,6 +50,7 @@ let package = Package( .package(url: "https://github.com/Automattic/wordpress-rs", revision: "alpha-swift-20240813"), .package(url: "https://github.com/wordpress-mobile/GutenbergKit", revision: "6cc307e7fc24910697be5f71b7d70f465a9c0f63"), .package(url: "https://github.com/Automattic/color-studio", branch: "trunk"), + .package(url: "https://github.com/apple/swift-async-algorithms", from: "1.0.0"), ], targets: XcodeSupport.targets + [ .target(name: "JetpackStatsWidgetsCore"), @@ -61,12 +63,14 @@ let package = Package( .target(name: "WordPressSharedObjC", resources: [.process("Resources")]), .target(name: "WordPressShared", dependencies: [.target(name: "WordPressSharedObjC")], resources: [.process("Resources")]), .target(name: "WordPressUI", dependencies: [.target(name: "WordPressShared")], resources: [.process("Resources")]), + .target(name: "AsyncCombine"), .testTarget(name: "JetpackStatsWidgetsCoreTests", dependencies: [.target(name: "JetpackStatsWidgetsCore")]), .testTarget(name: "DesignSystemTests", dependencies: [.target(name: "DesignSystem")]), .testTarget(name: "WordPressFluxTests", dependencies: ["WordPressFlux"]), .testTarget(name: "WordPressSharedTests", dependencies: [.target(name: "WordPressShared")]), .testTarget(name: "WordPressSharedObjCTests", dependencies: [.target(name: "WordPressShared")], resources: [.process("Resources")]), .testTarget(name: "WordPressUITests", dependencies: [.target(name: "WordPressUI")]), + .testTarget(name: "AsyncCombineTests", dependencies: [.target(name: "AsyncCombine"), .product(name: "AsyncAlgorithms", package: "swift-async-algorithms")]), ] ) @@ -134,6 +138,7 @@ enum XcodeSupport { "WordPressFlux", "WordPressShared", "WordPressUI", + "AsyncCombine", .product(name: "Alamofire", package: "Alamofire"), .product(name: "AlamofireImage", package: "AlamofireImage"), .product(name: "AutomatticAbout", package: "AutomatticAbout-swift"), diff --git a/Modules/Sources/AsyncCombine/Just.swift b/Modules/Sources/AsyncCombine/Just.swift new file mode 100644 index 000000000000..77bc47ab3f7e --- /dev/null +++ b/Modules/Sources/AsyncCombine/Just.swift @@ -0,0 +1,59 @@ +import Foundation + +public struct JustAsyncSequence: AsyncSequence { + var producer: () async -> Element + + public init(_ output: Element) { + self.init({ output }) + } + + public init(_ producer: @escaping () async -> Element) { + self.producer = producer + } + + public func makeAsyncIterator() -> Iterator { + Iterator(producer: producer) + } + + public struct Iterator: AsyncIteratorProtocol { + var started = false + let producer: () async -> Element + + public mutating func next() async -> Element? { + guard !started else { return nil } + + started = true + let result = await producer() + return Task.isCancelled ? nil : result + } + } +} + +public struct JustThrowingAsyncSequence: AsyncSequence { + var producer: () async throws -> Element + + public init(_ error: Error) { + self.init({ throw error }) + } + + public init(_ producer: @escaping () async throws -> Element) { + self.producer = producer + } + + public func makeAsyncIterator() -> Iterator { + Iterator(producer: producer) + } + + public struct Iterator: AsyncIteratorProtocol { + var started = false + let producer: () async throws -> Element + + public mutating func next() async throws -> Element? { + guard !started else { return nil } + + started = true + let result = try await producer() + return Task.isCancelled ? nil : result + } + } +} diff --git a/Modules/Sources/AsyncCombine/Publisher.swift b/Modules/Sources/AsyncCombine/Publisher.swift new file mode 100644 index 000000000000..d586969047f5 --- /dev/null +++ b/Modules/Sources/AsyncCombine/Publisher.swift @@ -0,0 +1,87 @@ +import Foundation +import Combine + +public extension AsyncSequence { + var publisher: AnyPublisher { + precondition(!(self is AsyncStream) && !(self is AsyncThrowingStream), "Use sharedPublisher for AsyncStream and AsyncThrowingStream") + + return StreamPublisher(sequence: self).eraseToAnyPublisher() + } + + var sharedPublisher: Publishers.Share> { + StreamPublisher(sequence: self).eraseToAnyPublisher().share() + } +} + +public extension AsyncStream { + @available(*, deprecated, message: "Use sharedPublisher for AsyncStream and AsyncThrowingStream") + var publisher: AnyPublisher { + fatalError("Use sharedPublisher for AsyncStream and AsyncThrowingStream") + } +} + +public extension AsyncThrowingStream { + @available(*, deprecated, message: "Use sharedPublisher for AsyncStream and AsyncThrowingStream") + var publisher: AnyPublisher { + fatalError("Use sharedPublisher for AsyncStream and AsyncThrowingStream") + } +} + +class StreamPublisher: Publisher { + + typealias Output = Sequence.Element + typealias Failure = Error + + let sequence: Sequence + + init(sequence: Sequence) { + self.sequence = sequence + } + + func receive(subscriber: S) where S.Input == Output, S.Failure == Failure { + let subscription = Subscription(sequence: sequence, subscriber: subscriber) + subscriber.receive(subscription: subscription) + } + + class Subscription: Combine.Subscription where S.Input == Output, S.Failure == Failure { + let sequence: Sequence + + var subscriber: S? + var task: Task? + var outputSent: Int = 0 + + init(sequence: Sequence, subscriber: S) { + self.sequence = sequence + self.subscriber = subscriber + } + + func request(_ demand: Subscribers.Demand) { + task = Task { + do { + if let max = demand.max { + for try await element in sequence.prefix(max) { + try Task.checkCancellation() + + _ = subscriber?.receive(element) + } + } else { + for try await element in sequence { + try Task.checkCancellation() + + _ = subscriber?.receive(element) + } + } + subscriber?.receive(completion: .finished) + } catch { + subscriber?.receive(completion: .failure(error)) + } + } + } + + func cancel() { + task?.cancel() + task = nil + subscriber = nil + } + } +} diff --git a/Modules/Sources/AsyncCombine/Task.swift b/Modules/Sources/AsyncCombine/Task.swift new file mode 100644 index 000000000000..e39b624a8949 --- /dev/null +++ b/Modules/Sources/AsyncCombine/Task.swift @@ -0,0 +1,38 @@ +import Foundation +import Combine + +extension Task where Failure == Never { + var stream: AsyncStream { + AsyncStream(unfolding: { await self.value }, onCancel: cancel) + } + + public var publisher: Publishers.Share> { + stream.sharedPublisher + } +} + +extension Task where Failure == Error { + var stream: AsyncThrowingStream { + let builder: (AsyncThrowingStream.Continuation) -> Void = { continuation in + Task { + do { + let output = try await self.value + continuation.yield(output) + continuation.finish() + } catch { + continuation.finish(throwing: error) + } + } + continuation.onTermination = { + if case .cancelled = $0 { + self.cancel() + } + } + } + return AsyncThrowingStream(Success.self, builder) + } + + public var publisher: Publishers.Share> { + stream.sharedPublisher + } +} diff --git a/Modules/Tests/AsyncCombineTests/FetchOnce.swift b/Modules/Tests/AsyncCombineTests/FetchOnce.swift new file mode 100644 index 000000000000..7528509af016 --- /dev/null +++ b/Modules/Tests/AsyncCombineTests/FetchOnce.swift @@ -0,0 +1,294 @@ +import Foundation +import Combine +import XCTest + +@testable import AsyncCombine + +class FetchOnce: XCTestCase { + + var tracks: Tracks! + + override func setUp() { + super.setUp() + tracks = Tracks() + } + + func testMultipleFetchesBeforeTaskCompletes() async throws { + let expectation = self.expectation(description: "All fetches are completed") + expectation.expectedFulfillmentCount = 2 + + let imageDownloader = ImageDownloader(tracks: tracks) + + for _ in 1...expectation.expectedFulfillmentCount { + try await Task.sleep(for: .milliseconds(100)) + Task.detached { + do { + let image = try await imageDownloader.fetch() + XCTAssertEqual(image, "IMAGE") + } catch { + XCTFail("Unexpected error: \(error)") + } + + expectation.fulfill() + } + } + + await fulfillment(of: [expectation], timeout: 1) + + let events = await tracks.events + XCTAssertEqual(events.count(where: { $0 == .HTTPGetCalled }), 1) + } + + func testMultipleFetchesBeforeSecondTaskCompletes() async throws { + let expectation = self.expectation(description: "All fetches are completed") + expectation.expectedFulfillmentCount = 6 + + let imageDownloader = ImageDownloader(tracks: tracks) + + for _ in 1...expectation.expectedFulfillmentCount { + try await Task.sleep(for: .milliseconds(80)) + Task.detached { + do { + let image = try await imageDownloader.fetch() + XCTAssertEqual(image, "IMAGE") + } catch { + XCTFail("Unexpected error: \(error)") + } + + expectation.fulfill() + } + } + + await fulfillment(of: [expectation], timeout: 1) + + let events = await tracks.events + XCTAssertEqual(events.count(where: { $0 == .HTTPGetCalled }), 2) + } + + func testCancellation_Reference() async throws { + let cancelled = expectation(description: "Cancelled") + let timer = Just(42) + .delay(for: .milliseconds(300), scheduler: DispatchQueue.main) + .handleEvents(receiveCancel: cancelled.fulfill) + + let expectation = self.expectation(description: "Cancelled") + let task = Task { + let result = await timer.values.first { _ in true } + if result == nil { + expectation.fulfill() + } + } + + try await Task.sleep(for: .milliseconds(100)) + task.cancel() + + await fulfillment(of: [cancelled, expectation], timeout: 1) + } + + func testCancellation() async throws { + let expectation = expectation(description: "Cancelled") + let imageDownloader = ImageDownloader(tracks: tracks) + let task = Task { + do { + let _ = try await imageDownloader.fetch() + XCTFail("Unexpected success") + } catch { + expectation.fulfill() + } + } + + try await Task.sleep(for: .milliseconds(100)) + task.cancel() + + await fulfillment(of: [expectation], timeout: 1) + + let events = await tracks.events + XCTAssertEqual(events.count(where: { $0 == .HTTPGetCalled }), 1) + XCTAssertFalse(events.contains(.HTTPGetCompleted)) + } + + func testCancelMultipleFetches() async throws { + let expectation = expectation(description: "Cancelled") + expectation.expectedFulfillmentCount = 2 + + let imageDownloader = ImageDownloader(tracks: tracks) + var tasks: [Task] = [] + for _ in 1...2 { + let task = Task { + do { + let _ = try await imageDownloader.fetch() + XCTFail("Unexpected success") + } catch { + expectation.fulfill() + } + } + tasks.append(task) + try await Task.sleep(for: .milliseconds(100)) + } + + for task in tasks { + task.cancel() + } + + await fulfillment(of: [expectation], timeout: 1) + + let events = await tracks.events + XCTAssertEqual(events.count(where: { $0 == .HTTPGetCalled }), 1) + XCTAssertFalse(events.contains(.HTTPGetCompleted)) + } + + func testCancelSomeButNotAllFetches() async throws { + let success = expectation(description: "Successful result") + success.expectedFulfillmentCount = 3 + + let imageDownloader = ImageDownloader(tracks: tracks) + for _ in 1...3 { + Task.detached { + do { + let _ = try await imageDownloader.fetch() + success.fulfill() + } catch { + XCTFail("Unexpected error") + } + } + try await Task.sleep(for: .milliseconds(10)) + } + + let cancelled = expectation(description: "Cancelled") + cancelled.expectedFulfillmentCount = 3 + var tasksToCancel: [Task] = [] + for _ in 1...3 { + let task = Task.detached { + do { + let _ = try await imageDownloader.fetch() + XCTFail("Unexpected success") + } catch { + cancelled.fulfill() + } + } + tasksToCancel.append(task) + try await Task.sleep(for: .milliseconds(10)) + } + + for task in tasksToCancel { + task.cancel() + } + + await fulfillment(of: [success, cancelled], timeout: 1) + } + + func testTaskIsCanceled() async throws { + let taskCancelled = expectation(description: "Task cancelled") + let publisher = Task { + do { + try await Task.sleep(for: .seconds(1)) + } catch is CancellationError { + taskCancelled.fulfill() + } catch { + XCTFail("Unexpected error: \(error)") + } + } + .publisher + + let tasks = [1...3].map { _ in + Task.detached { + let _ = try await publisher.values.reduce(into: []) { $0.append($1) } + } + } + try await Task.sleep(for: .milliseconds(10)) + for task in tasks { + task.cancel() + } + + await fulfillment(of: [taskCancelled], timeout: 0.3) + } + +} + +actor ImageDownloader { + let tracks: Tracks + let http: HTTP + var publisher: AnyPublisher? + var task: Task? + + init(tracks: Tracks) { + self.tracks = tracks + self.http = HTTP(tracks: tracks) + } + + func fetch() async throws -> String { + if publisher == nil { + publisher = Task { try await self.http.get() } + .stream + .sharedPublisher + .eraseToAnyPublisher() + + Task.detached { [tracks] in + await tracks.log(.fetchPublisherCreated) + } + } + Task.detached { [tracks] in + await tracks.log(.fetchWaitForResult) + } + + defer { + publisher = nil + Task.detached { [tracks] in + await tracks.log(.fetchPublisherDestoryed) + } + } + + if let output = try await publisher!.values.first(where: { _ in true }) { + return output + } + + throw CancellationError() + } +} + +class HTTP { + let tracks: Tracks + + init(tracks: Tracks) { + self.tracks = tracks + } + + func get() async throws -> String { + await tracks.log(.HTTPGetCalled) + try await Task.sleep(for: .milliseconds(300)) + await tracks.log(.HTTPGetCompleted) + return "IMAGE" + } +} + +actor Tracks { + enum Event: Equatable { + case fetchPublisherCreated + case fetchWaitForResult + case fetchPublisherDestoryed + case HTTPGetCalled + case HTTPGetCompleted + } + + var raw: [(Event, Date)] = [] + + var events: [Event] { + raw.map { event, _ in event } + } + + func log(_ event: Event) { + raw.append((event, Date())) + } + + func print() { + guard !events.isEmpty else { return } + + let startTime = raw.first!.1 + + for (event, time) in raw { + let duration = time.timeIntervalSince(startTime) + let formattedDuration = String(format: "%.3f", duration) + Swift.print("[\(formattedDuration)] \(event)") + } + } +} diff --git a/Modules/Tests/AsyncCombineTests/Others/AsyncPublisherTests.swift b/Modules/Tests/AsyncCombineTests/Others/AsyncPublisherTests.swift new file mode 100644 index 000000000000..66f45b569726 --- /dev/null +++ b/Modules/Tests/AsyncCombineTests/Others/AsyncPublisherTests.swift @@ -0,0 +1,83 @@ +import Foundation +import Combine +import XCTest + +class AsyncPublisherTests: XCTestCase { + + func testCollectValues() async throws { + let publisher = Array(1...10).publisher + let values = await publisher.values.reduce(into: []) { $0.append($1) } + XCTAssertEqual(values, Array(1...10)) + } + + func testCollectError() async throws { + struct TestError: Error {} + + let publisher = Fail(outputType: Int.self, failure: TestError()) + do { + let _: [Int] = try await publisher.values.reduce(into: []) { $0.append($1) } + XCTFail("Unexpected success") + } catch { + XCTAssertTrue(error is TestError) + } + } + + func testEmitValuesAndError() async throws { + struct TestError: Error {} + + let publisher = Record(output: [1, 2, 3], completion: .failure(TestError())) + + var values = [Int]() + do { + for try await value in publisher.values { + values.append(value) + } + XCTFail("Unexpected success") + } catch { + XCTAssertTrue(error is TestError) + XCTAssertEqual(values, [1, 2, 3]) + } + } + + func testCancellation() async throws { + let publisherCancelled = expectation(description: "Cancelled") + let publisher = Just(42) + .setFailureType(to: Error.self) + .delay(for: .milliseconds(300), scheduler: DispatchQueue.main) + .handleEvents(receiveCancel: publisherCancelled.fulfill) + + let collected = self.expectation(description: "No output because task is cancelled") + let task = Task { + let values = try await publisher.values.reduce(into: []) { $0.append($1) } + XCTAssertEqual(values, []) + collected.fulfill() + } + + try await Task.sleep(for: .milliseconds(100)) + task.cancel() + + await fulfillment(of: [publisherCancelled, collected], timeout: 0.5) + } + + func testCancelErrorPublisher() async throws { + struct TestError: Error {} + + let publisherCancelled = expectation(description: "Cancelled") + let publisher = Fail(outputType: Int.self, failure: TestError()) + .delay(for: .milliseconds(300), scheduler: DispatchQueue.main) + .handleEvents(receiveCancel: publisherCancelled.fulfill) + + let collected = self.expectation(description: "No output because task is cancelled") + let task = Task { + let values = try await publisher.values.reduce(into: []) { $0.append($1) } + XCTAssertEqual(values, []) + collected.fulfill() + } + + try await Task.sleep(for: .milliseconds(100)) + task.cancel() + + await fulfillment(of: [publisherCancelled, collected], timeout: 0.5) + } + +} diff --git a/Modules/Tests/AsyncCombineTests/Others/StandardStreamTests.swift b/Modules/Tests/AsyncCombineTests/Others/StandardStreamTests.swift new file mode 100644 index 000000000000..1910aab14f7a --- /dev/null +++ b/Modules/Tests/AsyncCombineTests/Others/StandardStreamTests.swift @@ -0,0 +1,15 @@ +import Foundation +import XCTest +import AsyncAlgorithms + +class StandardStreamTests: XCTestCase { + + func testColdSequence() async throws { + let sequence = Array(1...50).async.prefix(10) + let consumer1: [Int] = await sequence.reduce(into: []) { $0.append($1) } + let consumer2: [Int] = await sequence.reduce(into: []) { $0.append($1) } + XCTAssertEqual(consumer1, Array(1...10)) + XCTAssertEqual(consumer2, Array(1...10)) + } + +} diff --git a/Modules/Tests/AsyncCombineTests/SharedPublisherTests.swift b/Modules/Tests/AsyncCombineTests/SharedPublisherTests.swift new file mode 100644 index 000000000000..81a11d0d50ba --- /dev/null +++ b/Modules/Tests/AsyncCombineTests/SharedPublisherTests.swift @@ -0,0 +1,180 @@ +import XCTest +import Combine + +@testable import AsyncCombine + +class SharedPublisherTests: XCTestCase { + var cancellables: Set = [] + + override func tearDown() { + super.tearDown() + cancellables.removeAll() + } + + /// Create publishers that emit elements in [0, 1, 2, 3, 4], with 0.1 seconds interval. + /// + /// The first one is publisher implemented by this library, and the second one is the Timer publisher from Foundation. + /// + /// The Timer publisher is a `shared` publisher, where one upstream publisher broadcasts outputs to all subscribers. + /// This publisher instance is created to match `AsyncStream`'s behaviour where one stream instance emits elements to + /// all its `await`ers. + func createComparisonPublishers() -> (stream: AnyPublisher, timer: AnyPublisher) { + let stream = Counter(start: 0, end: 4, interval: .milliseconds(100)) + .sharedPublisher + .eraseToAnyPublisher() + + let timer = Timer.publish(every: 0.1, on: .main, in: .default) + .autoconnect() + .prefix(5) + .scan(-1) { counter, _ in counter + 1 } + .share() // IMPORTANT + .eraseToAnyPublisher() + + return (stream, timer) + } + + func subscribAtTheSameTime(publisher: AnyPublisher, line: UInt = #line) { + let expectation = XCTestExpectation(description: "Subscribers complete") + expectation.expectedFulfillmentCount = 2 + var first: [Int] = [] + var second: [Int] = [] + + publisher.sink( + receiveCompletion: { _ in expectation.fulfill() }, + receiveValue: { value in + first.append(value) + } + ).store(in: &cancellables) + + publisher.sink( + receiveCompletion: { _ in expectation.fulfill() }, + receiveValue: { value in + second.append(value) + } + ).store(in: &cancellables) + + wait(for: [expectation], timeout: 1) + + XCTAssertEqual(first, [0, 1, 2, 3, 4], "The first subscriber receives the full sequence", line: line) + XCTAssertEqual(second, [0, 1, 2, 3, 4], "The second subscriber receives the full sequence", line: line) + } + + func testSubscribAtTheSameTime() { + let (stream, _) = createComparisonPublishers() + subscribAtTheSameTime(publisher: stream) + } + + func testSubscribAtTheSameTime_Reference() { + let (_, timer) = createComparisonPublishers() + subscribAtTheSameTime(publisher: timer) + } + + func singleSubscriptionWithDelay(publisher: AnyPublisher, line: UInt = #line) { + let expectation = XCTestExpectation(description: "Subscribers complete") + var received: [Int] = [] + + DispatchQueue.main.asyncAfter(deadline: .now() + 0.2) { + publisher.sink( + receiveCompletion: { _ in expectation.fulfill() }, + receiveValue: { value in + received.append(value) + } + ).store(in: &self.cancellables) + } + + wait(for: [expectation], timeout: 1) + + XCTAssertEqual(received, [0, 1, 2, 3, 4], "Receives the full sequence", line: line) + } + + func testSingleSubscriptionWithDelay() { + let (stream, _) = createComparisonPublishers() + singleSubscriptionWithDelay(publisher: stream) + } + + func testSingleSubscriptionWithDelay_Reference() { + let (_, timer) = createComparisonPublishers() + singleSubscriptionWithDelay(publisher: timer) + } + + func multiSubscriptionWithOneDelay(publisher: AnyPublisher, line: UInt = #line) { + let expectation = XCTestExpectation(description: "Subscribers complete") + expectation.expectedFulfillmentCount = 2 + var first: [Int] = [] + var second: [Int] = [] + + publisher.sink( + receiveCompletion: { _ in expectation.fulfill() }, + receiveValue: { value in + first.append(value) + } + ).store(in: &cancellables) + + DispatchQueue.main.asyncAfter(deadline: .now() + 0.2) { + publisher.sink( + receiveCompletion: { _ in expectation.fulfill() }, + receiveValue: { value in + second.append(value) + } + ).store(in: &self.cancellables) + } + + wait(for: [expectation], timeout: 1) + + XCTAssertEqual(first, [0, 1, 2, 3, 4], "The first subscriber receives the full sequence", line: line) + XCTAssertGreaterThan(first.count, second.count, "The second subscriber only recieves a subset of the full sequence", line: line) + XCTAssert(first.ends(with: second), "The second subscriber receives the tail of the full sequence", line: line) + } + + func testMultiSubscriptionWithOneDelay() { + let (stream, _) = createComparisonPublishers() + multiSubscriptionWithOneDelay(publisher: stream) + } + + func testMultiSubscriptionWithOneDelay_Reference() { + let (_, timer) = createComparisonPublishers() + multiSubscriptionWithOneDelay(publisher: timer) + } + + func multiSubscriptionWithMultiDelay(publisher: AnyPublisher, line: UInt = #line) { + let expectation = XCTestExpectation(description: "Subscribers complete") + expectation.expectedFulfillmentCount = 2 + var first: [Int] = [] + var second: [Int] = [] + + DispatchQueue.main.asyncAfter(deadline: .now() + 0.2) { + publisher.sink( + receiveCompletion: { _ in expectation.fulfill() }, + receiveValue: { value in + first.append(value) + } + ).store(in: &self.cancellables) + } + + DispatchQueue.main.asyncAfter(deadline: .now() + 0.4) { + publisher.sink( + receiveCompletion: { _ in expectation.fulfill() }, + receiveValue: { value in + second.append(value) + } + ).store(in: &self.cancellables) + } + + wait(for: [expectation], timeout: 1) + + XCTAssertEqual(first, [0, 1, 2, 3, 4], "The first subscriber receives the full sequence", line: line) + XCTAssertGreaterThan(first.count, second.count, "The second subscriber only recieves a subset of the full sequence", line: line) + XCTAssert(first.ends(with: second), "The second subscriber receives the tail of the full sequence", line: line) + } + + func testMultiSubscriptionWithMultiDelay() { + let (stream, _) = createComparisonPublishers() + multiSubscriptionWithMultiDelay(publisher: stream) + } + + func testMultiSubscriptionWithMultiDelay_Reference() { + let (_, timer) = createComparisonPublishers() + multiSubscriptionWithMultiDelay(publisher: timer) + } + +} diff --git a/Modules/Tests/AsyncCombineTests/StandardPublisherTests.swift b/Modules/Tests/AsyncCombineTests/StandardPublisherTests.swift new file mode 100644 index 000000000000..19c9f6b2b4b8 --- /dev/null +++ b/Modules/Tests/AsyncCombineTests/StandardPublisherTests.swift @@ -0,0 +1,133 @@ +import XCTest +import Combine + +@testable import AsyncCombine + +class StandardPublisherTests: XCTestCase { + var cancellables: Set = [] + + override func tearDown() { + super.tearDown() + cancellables.removeAll() + } + + func testEmitsValues() { + let expectation = XCTestExpectation(description: "Publisher emits values") + var receivedValues: [Int] = [] + + let publisher = JustAsyncSequence(42).publisher + publisher.sink( + receiveCompletion: { _ in }, + receiveValue: { value in + receivedValues.append(value) + expectation.fulfill() + } + ).store(in: &cancellables) + + wait(for: [expectation], timeout: 0.3) + XCTAssertEqual(receivedValues, [42]) + } + + func testCompletes() { + let expectation = XCTestExpectation(description: "Publisher completes") + + let publisher = JustAsyncSequence(42).publisher + publisher.sink( + receiveCompletion: { completion in + if case .finished = completion { + expectation.fulfill() + } + }, + receiveValue: { _ in } + ).store(in: &cancellables) + + wait(for: [expectation], timeout: 0.3) + } + + func testPropagatesError() { + let expectation = XCTestExpectation(description: "Publisher propagates error") + + struct TestError: Error {} + let publisher = JustThrowingAsyncSequence(TestError()).publisher + publisher.sink( + receiveCompletion: { completion in + if case .failure(let error) = completion { + XCTAssertTrue(error is TestError) + expectation.fulfill() + } + }, + receiveValue: { _ in XCTFail("Should not emit values") } + ).store(in: &cancellables) + + wait(for: [expectation], timeout: 0.3) + } + + func testHandlesCancellation() { + let expectation = XCTestExpectation(description: "Publisher is cancelled") + expectation.isInverted = true + + let cancelled = XCTestExpectation(description: "Task is cancelled") + cancelled.isInverted = true + + let publisher = JustThrowingAsyncSequence { + try await Task.sleep(for: .microseconds(300)) + cancelled.fulfill() + return 42 + } + .publisher + let cancellable = publisher.sink( + receiveCompletion: { _ in XCTFail("Should not complete") }, + receiveValue: { _ in XCTFail("Should not emit values") } + ) + cancellable.cancel() + + wait(for: [expectation], timeout: 0.5) + } + + func testFirstOperator() { + let expectation = XCTestExpectation(description: "Publisher emits first value and completes") + var receivedValues: [Int] = [] + + let publisher = Counter(start: 0, end: 9, interval: .milliseconds(100)) + .publisher + let firstPublisher = publisher.first() + + firstPublisher.sink( + receiveCompletion: { completion in + if case .finished = completion { + expectation.fulfill() + } + }, + receiveValue: { value in + receivedValues.append(value) + } + ).store(in: &cancellables) + + wait(for: [expectation], timeout: 1.0) + XCTAssertEqual(receivedValues, [0]) + } + + func testPrefixOperator() { + let expectation = XCTestExpectation(description: "Publisher emits first N values and completes") + var receivedValues: [Int] = [] + + let publisher = Counter(start: 0, end: 9, interval: .milliseconds(100)) + .publisher + let prefixedPublisher = publisher.prefix(3) + + prefixedPublisher.sink( + receiveCompletion: { completion in + if case .finished = completion { + expectation.fulfill() + } + }, + receiveValue: { value in + receivedValues.append(value) + } + ).store(in: &cancellables) + + wait(for: [expectation], timeout: 1.0) + XCTAssertEqual(receivedValues, [0, 1, 2]) + } + +} diff --git a/Modules/Tests/AsyncCombineTests/Support.swift b/Modules/Tests/AsyncCombineTests/Support.swift new file mode 100644 index 000000000000..8e0d63fe7d14 --- /dev/null +++ b/Modules/Tests/AsyncCombineTests/Support.swift @@ -0,0 +1,58 @@ +import Foundation +import XCTest + +struct Counter: AsyncSequence { + typealias Element = Int + let start: Int + let end: Int + let interval: Duration + + struct AsyncIterator: AsyncIteratorProtocol { + let start: Int + let end: Int + let interval: Duration + var current: Int + + init(start: Int, end: Int, interval: Duration) { + self.start = start + self.end = end + self.interval = interval + self.current = start + } + + mutating func next() async throws -> Int? { + try await Task.sleep(for: interval) + + guard current <= end else { + return nil + } + + let result = current + current += 1 + return result + } + } + + func makeAsyncIterator() -> AsyncIterator { + return AsyncIterator(start: start, end: end, interval: interval) + } +} + +extension Sequence where Element: Equatable { + func ends(with other: S) -> Bool where Element == S.Element { + reversed().starts(with: other.reversed()) + } +} + +class SupportTests: XCTestCase { + + func testEndsWith() { + XCTAssertFalse([1, 2, 3].ends(with: [1])) + XCTAssertFalse([1, 2, 3].ends(with: [1, 3])) + + XCTAssertTrue([1, 2, 3].ends(with: [3])) + XCTAssertTrue([1, 2, 3].ends(with: [2, 3])) + XCTAssertTrue([1, 2, 3].ends(with: [1, 2, 3])) + } + +} diff --git a/WordPress.xcworkspace/xcshareddata/swiftpm/Package.resolved b/WordPress.xcworkspace/xcshareddata/swiftpm/Package.resolved index 6e5903f96a77..93543e3e693e 100644 --- a/WordPress.xcworkspace/xcshareddata/swiftpm/Package.resolved +++ b/WordPress.xcworkspace/xcshareddata/swiftpm/Package.resolved @@ -313,6 +313,24 @@ "version" : "2.3.1" } }, + { + "identity" : "swift-async-algorithms", + "kind" : "remoteSourceControl", + "location" : "https://github.com/apple/swift-async-algorithms", + "state" : { + "revision" : "5c8bd186f48c16af0775972700626f0b74588278", + "version" : "1.0.2" + } + }, + { + "identity" : "swift-collections", + "kind" : "remoteSourceControl", + "location" : "https://github.com/apple/swift-collections.git", + "state" : { + "revision" : "671108c96644956dddcd89dd59c203dcdb36cec7", + "version" : "1.1.4" + } + }, { "identity" : "swift-log", "kind" : "remoteSourceControl", diff --git a/WordPress/Classes/Utility/Media/ImageDownloader.swift b/WordPress/Classes/Utility/Media/ImageDownloader.swift index f81deb20962d..7e169768e43a 100644 --- a/WordPress/Classes/Utility/Media/ImageDownloader.swift +++ b/WordPress/Classes/Utility/Media/ImageDownloader.swift @@ -1,4 +1,6 @@ import UIKit +import Combine +import AsyncCombine struct ImageRequestOptions { /// Resize the thumbnail to the given size. By default, `nil`. @@ -30,7 +32,7 @@ actor ImageDownloader { ) } - private var tasks: [String: ImageDataTask] = [:] + private var tasks: [String: AnyPublisher] = [:] init(cache: MemoryCacheProtocol = MemoryCache.shared) { self.cache = cache @@ -115,32 +117,21 @@ actor ImageDownloader { private func data(for request: URLRequest, options: ImageRequestOptions) async throws -> Data { let requestKey = request.urlRequest?.url?.absoluteString ?? "" - let task = tasks[requestKey] ?? ImageDataTask(key: requestKey, Task { - try await self._data(for: request, options: options, key: requestKey) - }) - task.downloader = self + let task = tasks[requestKey] ?? { + Task { + try await self._data(for: request, options: options, key: requestKey) + } + .publisher + .eraseToAnyPublisher() + }() - let subscriptionID = UUID() - task.subscriptions.insert(subscriptionID) tasks[requestKey] = task - return try await task.getData(subscriptionID: subscriptionID) - } - - fileprivate nonisolated func unsubscribe(_ subscriptionID: UUID, key: String) { - Task { - await _unsubscribe(subscriptionID, key: key) - } - } - - private func _unsubscribe(_ subscriptionID: UUID, key: String) { - guard let task = tasks[key], - task.subscriptions.remove(subscriptionID) != nil, - task.subscriptions.isEmpty else { - return + let result: [Data] = try await task.values.reduce(into: []) { $0.append($1) } + if result.count != 1 { + throw CancellationError() } - task.task.cancel() - tasks[key] = nil + return result[0] } private func _data(for request: URLRequest, options: ImageRequestOptions, key: String) async throws -> Data { @@ -161,27 +152,6 @@ actor ImageDownloader { } } -private final class ImageDataTask { - let key: String - var subscriptions = Set() - let task: Task - weak var downloader: ImageDownloader? - - init(key: String, _ task: Task) { - self.key = key - self.task = task - } - - func getData(subscriptionID: UUID) async throws -> Data { - try await withTaskCancellationHandler { - try await task.value - } onCancel: { [weak self] in - guard let self else { return } - self.downloader?.unsubscribe(subscriptionID, key: self.key) - } - } -} - // MARK: - ImageDownloader (Closures) extension ImageDownloader { diff --git a/WordPress/WordPressTest/ImageDownloaderTests.swift b/WordPress/WordPressTest/ImageDownloaderTests.swift index 12786eeb8ec5..f2bf78c96b69 100644 --- a/WordPress/WordPressTest/ImageDownloaderTests.swift +++ b/WordPress/WordPressTest/ImageDownloaderTests.swift @@ -64,10 +64,56 @@ class ImageDownloaderTests: CoreDataTestCase { let _ = try await task.value XCTFail() } catch { - XCTAssertEqual((error as? URLError)?.code, .cancelled) +// XCTAssertEqual((error as? URLError)?.code, .cancelled) + XCTAssertTrue(error is CancellationError) } } + func testCancelOneOfManySubscribers() async throws { + // GIVEN + let httpRequestReceived = self.expectation(description: "HTTP request received") + let imageURL = try XCTUnwrap(URL(string: "https://example.files.wordpress.com/2023/09/image.jpg")) + stub(condition: { _ in true }, response: { _ in + httpRequestReceived.fulfill() + + guard let sourceURL = try? XCTUnwrap(Bundle.test.url(forResource: "test-image", withExtension: "jpg")), + let data = try? Data(contentsOf: sourceURL) else { + return HTTPStubsResponse(error: URLError(.unknown)) + } + + return HTTPStubsResponse(data: data, statusCode: 200, headers: nil) + .responseTime(0.3) + }) + + // WHEN there are concurrent calls to download the same image and one of those downloads is cancelled + let taskCompleted = self.expectation(description: "Image downloaded") + taskCompleted.expectedFulfillmentCount = 3 + for _ in 1...3 { + try await Task.sleep(for: .milliseconds(50)) + Task.detached { + do { + _ = try await self.sut.image(from: imageURL) + } catch { + XCTFail("Unexpected error: \(error)") + } + taskCompleted.fulfill() + } + } + + let taskCancelled = expectation(description: "Task is cancelled") + let taskToBeCancelled = Task.detached { + do { + _ = try await self.sut.image(from: imageURL) + XCTFail("Unexpected successful result.") + } catch { + taskCancelled.fulfill() + } + } + taskToBeCancelled.cancel() + + await fulfillment(of: [httpRequestReceived, taskCompleted, taskCancelled], timeout: 0.5) + } + func testMemoryCache() async throws { // GIVEN let imageURL = try XCTUnwrap(URL(string: "https://example.files.wordpress.com/2023/09/image.jpg"))